Commit

Merge pull request #20 from mlcommons/lint

Integrate flake8 and pyre for automated linting and type checking
srinivas212 authored Feb 6, 2024
2 parents 5925827 + adc5f9a commit 120ce3d
Showing 13 changed files with 313 additions and 335 deletions.
39 changes: 39 additions & 0 deletions .flake8
@@ -0,0 +1,39 @@
[flake8]
enable-extensions = G
select = B,C,E,F,G,P,SIM1,T4,W,B9,TOR0,TOR1,TOR2
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
    # fix these lints in the future
    E275,
    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
    # to line this up with executable bit
    EXE001,
    # these ignores are from flake8-bugbear; please fix!
    B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907
    # these ignores are from flake8-comprehensions; please fix!
    C407,
    # these ignores are from flake8-logging-format; please fix!
    G100,G101,G200,G201,G202
    # these ignores are from flake8-simplify. please fix or ignore with commented reason
    SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
    # flake8-simplify code styles
    SIM102,SIM103,SIM106,SIM112,
    # TorchFix codes that don't make sense for PyTorch itself:
    # removed and deprecated PyTorch functions.
    TOR001,TOR101,
    # TODO(kit1980): fix all TOR102 issues
    # `torch.load` without `weights_only` parameter is unsafe
    TOR102,
    P201,
per-file-ignores =
    __init__.py: F401
optional-ascii-coding = True
exclude =
    ./.git,
    ./build,
    ./et_def/et_def_pb2.py,
    ./et_def/et_def_pb2_grpc.py,
    ./third_party/utils/protolib.py,
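For context, this configuration still selects E722 (bare `except`), one of the patterns this same commit cleans up in flexflow2chakra_converter.py below. A minimal sketch of what the linter flags and the corresponding fix; the function name and label format here are hypothetical, not from this repository:

# bare_except_example.py -- illustrative sketch only, not part of this commit.

def parse_runtime(label: str) -> int:
    try:
        return int(label.split("=")[1])
    except:  # flake8 E722: a bare except also swallows SystemExit and KeyboardInterrupt
        raise ValueError(f"Cannot parse runtime from \"{label}\"")

def parse_runtime_fixed(label: str) -> int:
    try:
        return int(label.split("=")[1])
    except Exception:  # narrowed to Exception, matching the fix applied below
        raise ValueError(f"Cannot parse runtime from \"{label}\"")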
27 changes: 27 additions & 0 deletions .github/workflows/python_lint.yml
@@ -0,0 +1,27 @@
name: Python Linting

on: [push, pull_request]

jobs:
  lint-and-format:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.8'

    - name: Install dependencies
      run: |
        pip install flake8
        pip install pyre-check
        pip install .
    - name: Run Flake8
      run: flake8 .

    - name: Run Pyre Check
      run: pyre check
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@ build/
 __pycache__/
 *.egg
 *.et
-*.dot
+*.dot
+.pyre
7 changes: 7 additions & 0 deletions .pyre_configuration
@@ -0,0 +1,7 @@
{
  "source_directories": [
    "timeline_visualizer",
    "et_converter"
  ],
  "search_path": ["/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages"]
}
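This restricts type checking to the repository's two Python packages and points pyre at the CI runner's site-packages for third-party stubs; the hard-coded hosted-toolcache path matches the Python 3.8 runner set up in the workflow above. As a minimal sketch, with hypothetical code, of the class of error `pyre check` reports:

# pyre_example.py -- illustrative sketch only, not part of this commit.

def get_node_id(node_name: str) -> int:
    # pyre reports: Incompatible return type [7]: Expected `int` but got `str`.
    return node_name.replace("node", "")

def get_node_id_fixed(node_name: str) -> int:
    return int(node_name.replace("node", ""))  # the cast satisfies the annotation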
123 changes: 57 additions & 66 deletions et_converter/et_converter.py
@@ -12,8 +12,8 @@

 def get_logger(log_filename: str) -> logging.Logger:
     formatter = logging.Formatter(
-        "%(levelname)s [%(asctime)s] %(message)s",
-        datefmt="%m/%d/%Y %I:%M:%S %p")
+        "%(levelname)s [%(asctime)s] %(message)s",
+        datefmt="%m/%d/%Y %I:%M:%S %p")

     file_handler = FileHandler(log_filename, mode="w")
     file_handler.setLevel(logging.DEBUG)
@@ -32,63 +32,54 @@ def get_logger(log_filename: str) -> logging.Logger:

 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Execution Trace Converter"
-    )
+        description="Execution Trace Converter")
     parser.add_argument(
-        "--input_type",
-        type=str,
-        default=None,
-        required=True,
-        help="Input execution trace type"
-    )
+        "--input_type",
+        type=str,
+        default=None,
+        required=True,
+        help="Input execution trace type")
     parser.add_argument(
-        "--input_filename",
-        type=str,
-        default=None,
-        required=True,
-        help="Input execution trace filename"
-    )
+        "--input_filename",
+        type=str,
+        default=None,
+        required=True,
+        help="Input execution trace filename")
     parser.add_argument(
-        "--output_filename",
-        type=str,
-        default=None,
-        required=True,
-        help="Output Chakra execution trace filename"
-    )
+        "--output_filename",
+        type=str,
+        default=None,
+        required=True,
+        help="Output Chakra execution trace filename")
     parser.add_argument(
-        "--num_dims",
-        type=int,
-        default=None,
-        required=True,
-        help="Number of dimensions in the network topology"
-    )
+        "--num_dims",
+        type=int,
+        default=None,
+        required=True,
+        help="Number of dimensions in the network topology")
     parser.add_argument(
-        "--num_npus",
-        type=int,
-        default=None,
-        required="Text" in sys.argv,
-        help="Number of NPUs in a system"
-    )
+        "--num_npus",
+        type=int,
+        default=None,
+        required="Text" in sys.argv,
+        help="Number of NPUs in a system")
     parser.add_argument(
-        "--num_passes",
-        type=int,
-        default=None,
-        required="Text" in sys.argv,
-        help="Number of training passes"
-    )
+        "--num_passes",
+        type=int,
+        default=None,
+        required="Text" in sys.argv,
+        help="Number of training passes")
     parser.add_argument(
-        "--npu_frequency",
-        type=int,
-        default=None,
-        required="FlexFlow" in sys.argv,
-        help="NPU frequency in MHz"
-    )
+        "--npu_frequency",
+        type=int,
+        default=None,
+        required="FlexFlow" in sys.argv,
+        help="NPU frequency in MHz")
     parser.add_argument(
-        "--log_filename",
-        type=str,
-        default="debug.log",
-        help="Log filename"
-    )
+        "--log_filename",
+        type=str,
+        default="debug.log",
+        help="Log filename")
     args = parser.parse_args()

     logger = get_logger(args.log_filename)
@@ -97,27 +88,27 @@ def main() -> None:
     try:
         if args.input_type == "Text":
             converter = Text2ChakraConverter(
-                args.input_filename,
-                args.output_filename,
-                args.num_dims,
-                args.num_npus,
-                args.num_passes,
-                logger)
+                args.input_filename,
+                args.output_filename,
+                args.num_dims,
+                args.num_npus,
+                args.num_passes,
+                logger)
             converter.convert()
         elif args.input_type == "FlexFlow":
             converter = FlexFlow2ChakraConverter(
-                args.input_filename,
-                args.output_filename,
-                args.num_dims,
-                args.npu_frequency,
-                logger)
+                args.input_filename,
+                args.output_filename,
+                args.num_dims,
+                args.npu_frequency,
+                logger)
             converter.convert()
         elif args.input_type == "PyTorch":
             converter = PyTorch2ChakraConverter(
-                args.input_filename,
-                args.output_filename,
-                args.num_dims,
-                logger)
+                args.input_filename,
+                args.output_filename,
+                args.num_dims,
+                logger)
             converter.convert()
         else:
             logger.error(f"{args.input_type} unsupported")
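A side note on the CLI above (behavior unchanged by this commit): `required="Text" in sys.argv` makes `--num_npus` and `--num_passes` mandatory only when the Text input type appears on the command line, and `--npu_frequency` only for FlexFlow. A self-contained sketch of the same pattern, with hypothetical flag names:

# conditional_required.py -- illustrative sketch only, not part of this commit.
import argparse
import sys

parser = argparse.ArgumentParser(description="Conditionally required arguments")
parser.add_argument("--input_type", type=str, required=True)
# argparse accepts any bool for `required`, so it can be computed when the
# parser is built. Caveat: this matches the literal string "Text" anywhere
# on the command line, including inside an unrelated filename.
parser.add_argument("--num_npus", type=int, required="Text" in sys.argv)
args = parser.parse_args()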
49 changes: 22 additions & 27 deletions et_converter/flexflow2chakra_converter.py
@@ -7,6 +7,7 @@

 from chakra.third_party.utils.protolib import encodeMessage as encode_message
 from chakra.et_def.et_def_pb2 import (
+    NodeType as ChakraNodeType,
     Node as ChakraNode,
     AttributeProto as ChakraAttr,
     COMP_NODE,
@@ -36,31 +37,31 @@ def get_label(self, ff_node: Any) -> str:
         try:
             label = ff_node.get_attributes()["label"]
             return label.replace("\"", "")[1:-1]
-        except:
-            raise ValueError(f"Cannot retrieve label from a FlexFlow node")
+        except Exception:
+            raise ValueError("Cannot retrieve label from a FlexFlow node")

     def get_id(self, ff_node: Any) -> int:
         ff_node_name = ff_node.get_name()
         try:
             return int(ff_node_name.replace("node", ""))
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve id from \"{ff_node_name}\"")

     def get_npu_id(self, ff_node: Any) -> int:
         label = self.get_label(ff_node)
         try:
             return int(label.split("|")[0].strip().split("=")[1])
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve npu_id from \"{label}\"")

     def get_name(self, ff_node: Any) -> str:
         label = self.get_label(ff_node)
         try:
             return label.split("|")[1].strip()
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve name from \"{label}\"")

-    def get_node_type(self, ff_node: Any) -> int:
+    def get_node_type(self, ff_node: Any) -> ChakraNodeType:
         label = self.get_label(ff_node)
         try:
             node_type = label.split("|")[3].strip()
@@ -70,36 +71,36 @@ def get_node_type(self, ff_node: Any) -> int:
                 return COMM_SEND_NODE
             else:
                 raise ValueError(f"Unsupported node_type, \"{node_type}\"")
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve node_type from \"{label}\"")

     def get_runtime(self, ff_node: Any) -> int:
         label = self.get_label(ff_node)
         try:
             wall_clock_time = float(label.split("|")[4].strip().split("=")[1])
             return int(round(wall_clock_time * self.num_cycles_per_sec))
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve runtime from \"{label}\"")

     def get_comm_src(self, ff_node: Any) -> int:
         label = self.get_label(ff_node)
         try:
             return int(label.split("|")[4].strip().split("=")[1])
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve comm_src from \"{label}\"")

     def get_comm_dst(self, ff_node: Any) -> int:
         label = self.get_label(ff_node)
         try:
             return int(label.split("|")[5].strip().split("=")[1])
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve comm_dst from \"{label}\"")

     def get_comm_size(self, ff_node: Any) -> int:
         label = self.get_label(ff_node)
         try:
             return int(label.split("|")[6].strip().split("=")[1])
-        except:
+        except Exception:
             raise ValueError(f"Cannot retrieve comm_size from \"{label}\"")

     def convert_FF_node_to_CK_node(self, ff_node: Any) -> Any:
@@ -137,7 +138,7 @@ def convert(self) -> None:
             src_id = int(edge.get_source().replace("node", ""))
             dst_id = int(edge.get_destination().replace("node", ""))
             ck_node = self.node_id_node_dict[dst_id]
-            ck_node.parent.append(src_id)
+            ck_node.data_deps.append(src_id)
             num_ff_edges += 1
         self.logger.info(f"Converted {num_ff_nodes} nodes and {num_ff_edges} edges")

@@ -165,7 +166,7 @@ def convert(self) -> None:
                 # communication nodes
                 elif (ck_node.type == COMM_SEND_NODE):
                     if (self.node_id_comm_info_dict[ck_node.id]["comm_src"] == npu_id)\
-                        or (self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id):
+                            or (self.node_id_comm_info_dict[ck_node.id]["comm_dst"] == npu_id):
                         comm_src = self.node_id_comm_info_dict[ck_node.id]["comm_src"]
                         comm_dst = self.node_id_comm_info_dict[ck_node.id]["comm_dst"]
                         comm_key = f"{ck_node.id}_{comm_src}_{comm_dst}"
@@ -187,27 +188,21 @@ def convert(self) -> None:
                         ck_comm_node.type = COMM_RECV_NODE
                         ck_comm_node.name += f"_{ck_node.name}"

-                    ck_comm_node.attr.append(
-                        ChakraAttr(name="comm_src",
-                                   int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_src"]))
-                    ck_comm_node.attr.append(
-                        ChakraAttr(name="comm_dst",
-                                   int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_dst"]))
-                    ck_comm_node.attr.append(
-                        ChakraAttr(name="comm_size",
-                                   int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_size"]))
-                    ck_comm_node.attr.append(
-                        ChakraAttr(name="comm_tag",
-                                   int64_val=comm_tag))
+                    ck_comm_node.attr.extend([
+                        ChakraAttr(name="comm_src", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_src"]),
+                        ChakraAttr(name="comm_dst", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_dst"]),
+                        ChakraAttr(name="comm_size", int64_val=self.node_id_comm_info_dict[ck_node.id]["comm_size"]),
+                        ChakraAttr(name="comm_tag", int64_val=comm_tag)
+                    ])

                     per_npu_comm_nodes += 1
                     total_comm_nodes += 1

                     # transfer dependencies
-                    for parent_node_id in ck_node.parent:
+                    for parent_node_id in ck_node.data_deps:
                         parent_node = self.node_id_node_dict[parent_node_id]
                         if self.node_id_npu_id_dict[parent_node.id] == npu_id:
-                            ck_comm_node.parent.append(parent_node_id)
+                            ck_comm_node.data_deps.append(parent_node_id)

                     npu_id_node_id_node_dict[npu_id].update({node_id: ck_comm_node})
                 self.logger.info(f"NPU[{npu_id}]: {per_npu_comp_nodes} compute nodes and {per_npu_comm_nodes} communication nodes")
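The renames in this file are mechanical but worth noting: dependency edges now live in `data_deps` rather than `parent`, and four separate `attr.append` calls are batched into a single `attr.extend`. A rough sketch of the resulting construction pattern, assuming the generated `et_def_pb2` module imported above (ids, names, and sizes are illustrative):

# chakra_node_sketch.py -- illustrative sketch only, not part of this commit.
from chakra.et_def.et_def_pb2 import (
    Node as ChakraNode,
    AttributeProto as ChakraAttr,
    COMM_SEND_NODE,
)

node = ChakraNode()
node.id = 42                     # hypothetical node id
node.name = "COMM_SEND_NODE_42"  # hypothetical name
node.type = COMM_SEND_NODE

# Attributes are batched with extend(), as the refactored converter does.
node.attr.extend([
    ChakraAttr(name="comm_src", int64_val=0),
    ChakraAttr(name="comm_dst", int64_val=1),
    ChakraAttr(name="comm_size", int64_val=1024),
])

# Dependency edges are recorded in data_deps (formerly "parent").
node.data_deps.append(7)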
