Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 15, 2024
1 parent a6590fe commit 725d591
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 42 deletions.
77 changes: 41 additions & 36 deletions src/converter/converter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import argparse
import logging
import sys
import traceback

from .pytorch_converter import PyTorchConverter
from .text_converter import TextConverter


def setup_logging(log_filename: str) -> None:
"""Set up logging to file and stream handlers."""
formatter = logging.Formatter("%(levelname)s [%(asctime)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")

file_handler = logging.FileHandler(log_filename, mode="w")
Expand All @@ -21,46 +20,52 @@ def setup_logging(log_filename: str) -> None:
logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, stream_handler])


def convert_text(args: argparse.Namespace) -> None:
"""Convert text input trace to Chakra execution trace."""
converter = TextConverter(args.input_trace, args.output_trace, args.num_npus, args.num_passes)
converter.convert()


def convert_pytorch(args: argparse.Namespace) -> None:
"""Convert PyTorch input trace to Chakra execution trace."""
converter = PyTorchConverter(args.input_trace, args.output_trace, simulate=args.simulate)
converter.convert()


def main() -> None:
parser = argparse.ArgumentParser(description="Execution Trace Converter")
parser.add_argument("--input_type", type=str, default=None, required=True, help="Input execution trace type")
parser.add_argument(
"--input_filename", type=str, default=None, required=True, help="Input execution trace filename"
)
parser.add_argument(
"--output_filename", type=str, default=None, required=True, help="Output Chakra execution trace filename"
)
parser.add_argument(
"--num_npus", type=int, default=None, required="Text" in sys.argv, help="Number of NPUs in a system"
"""Convert to Chakra execution trace in the protobuf format."""
parser = argparse.ArgumentParser(
description=(
"Chakra execution trace converter. This converter takes an input file in another format and generates "
"a Chakra execution trace output in the protobuf format. This converter is designed for any downstream "
"simulators that take Chakra execution traces in the protobuf format."
)
)
parser.add_argument(
"--num_passes", type=int, default=None, required="Text" in sys.argv, help="Number of training passes"

subparsers = parser.add_subparsers(title="subcommands", description="Valid subcommands", help="Input type")

text_parser = subparsers.add_parser("Text", help="Convert Text trace")
text_parser.add_argument("--input-trace", type=str, required=True, help="Input execution trace filename")
text_parser.add_argument("--output-trace", type=str, required=True, help="Output Chakra execution trace filename")
text_parser.add_argument("--num-npus", type=int, required=True, help="Number of NPUs in a system")
text_parser.add_argument("--num-passes", type=int, required=True, help="Number of training passes")
text_parser.add_argument("--log-filename", type=str, default="debug.log", help="Log filename")
text_parser.set_defaults(func=convert_text)

pytorch_parser = subparsers.add_parser("PyTorch", help="Convert PyTorch trace")
pytorch_parser.add_argument("--input-trace", type=str, required=True, help="Input execution trace filename")
pytorch_parser.add_argument(
"--output-trace", type=str, required=True, help="Output Chakra execution trace filename"
)
parser.add_argument("--simulate", action="store_true", help="Run simulate_execution if set")
parser.add_argument("--log_filename", type=str, default="debug.log", help="Log filename")
pytorch_parser.add_argument("--simulate", action="store_true", help="Run simulate_execution if set")
pytorch_parser.add_argument("--log-filename", type=str, default="debug.log", help="Log filename")
pytorch_parser.set_defaults(func=convert_pytorch)

args = parser.parse_args()

setup_logging(args.log_filename)
logging.debug(" ".join(sys.argv))

try:
if args.input_type == "Text":
converter = TextConverter(args.input_filename, args.output_filename, args.num_npus, args.num_passes)
converter.convert()
elif args.input_type == "PyTorch":
converter = PyTorchConverter(args.input_filename, args.output_filename, simulate=args.simulate)
converter.convert()
else:
supported_types = ["Text", "PyTorch"]
logging.error(
f"The input type '{args.input_type}' is not supported. "
f"Supported types are: {', '.join(supported_types)}."
)
sys.exit(1)
except Exception:
traceback.print_exc()
logging.debug(traceback.format_exc())
sys.exit(1)

args.func(args)


if __name__ == "__main__":
Expand Down
8 changes: 2 additions & 6 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,8 @@ def load_pytorch_execution_traces(self) -> Dict:
Dict: The loaded PyTorch execution trace data.
"""
logging.info("Loading PyTorch execution traces from file.")
try:
with open(self.input_filename, "r") as pytorch_et:
return json.load(pytorch_et)
except IOError as e:
logging.error(f"Error opening file {self.input_filename}: {e}")
raise Exception(f"Could not open file {self.input_filename}") from e
with open(self.input_filename, "r") as pytorch_et:
return json.load(pytorch_et)

def _parse_and_instantiate_nodes(
self, pytorch_et_data: Dict
Expand Down

0 comments on commit 725d591

Please sign in to comment.