From 410aae07cb6497be69e9ce1b9253cae73a1a1386 Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Wed, 22 Jan 2025 16:39:47 -0800
Subject: [PATCH] Add support for metrics calculation (#196)

Summary:
Add support for metrics calculation:
1. Iteration E2E time
2. Bandwidth

This is a copy of https://github.com/facebookresearch/param/pull/195, imported into Meta.

Differential Revision: D68038126

Pulled By: shengfukevin
---
 .../comm/backend/pytorch_dist_backend.py  |   9 +-
 et_replay/comm/comms_utils.py             |  73 +++-
 et_replay/comm/profiler_trace_analysis.py | 336 ++++++++++++++++++
 et_replay/pyproject.toml                  |   4 +
 et_replay/tools/comm_replay.py            | 139 +++++---
 5 files changed, 490 insertions(+), 71 deletions(-)
 create mode 100644 et_replay/comm/profiler_trace_analysis.py

diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 1ac42bf8..6847005a 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -623,6 +623,13 @@ def barrier(self, collectiveArgs, name="dummy", retFlag=False):
         if retFlag:
             return retObj
 
+    def barrier_all_ranks(self):
+        dist.barrier(
+            device_ids=[self.get_device().index]
+            if dist.get_backend() == "nccl"
+            else None
+        )
+
     def sync_barrier(self, collectiveArgs, desc="dummy"):
         # ensure all streams have finished outstanding events before calling barrier
         self.complete_accel_ops(collectiveArgs)
@@ -1031,7 +1038,7 @@ def initialize_groups(self, backend="gloo"):
         # even if they are not going to be members of the group.
         sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
         sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
-        torch.distributed.barrier()
+        self.barrier_all_ranks()
 
         idxed_group_ranks_to_pgId: dict[tuple[int], list[int]] = defaultdict(list)
         for i in range(self.get_world_size()):
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 38c82533..2681a06b 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -19,11 +19,12 @@
 from io import StringIO
 from typing import Any
 
+from torch.profiler import ProfilerActivity
+
 try:
-    from param_bench.train.comms.pt.fb.internals import (
-        fbInitProfiler,
-        fbSampleProfiler,
-        fbStartProfiler,
+    from et_replay.fb.internals import (
+        get_fb_profiler_activities,
+        get_fb_profiler_trace_handler,
         initialize_collectiveArgs_internal,
         remove_quantization_handlers,
     )
@@ -390,45 +391,81 @@ def ensureTensorFlush(tensors: list[torch.Tensor] | torch.Tensor) -> Any:
     return x
 
 
-def startProfiler(rank: int, device: str, numWarmupIters: int, numIters: int) -> bool:
+_torch_profiler = None
+
+
+def startProfiler(
+    rank: int, device: str, numWarmupIters: int, numIters: int, output_path: str
+):
     """
-    Starts internal profiler with given parameters.
+    Starts the PyTorch profiler with the given parameters.
 
     Args:
         rank: Global rank.
         device: Type of device "cuda", "cpu", etc.
         numWarmupIters: Number of warmup iterations.
         numIters: Number of real iterations.
-    Returns:
-        bool: Returns if internal profile was able to start or not.
+        output_path: Path to save the profiler trace.
+    Returns:
+        bool: True if the profiler was started, False otherwise.
     """
+    global _torch_profiler
+
     if has_internal_libs:
-        fbInitProfiler(
-            rank=rank,
-            device=device,
+        activities = get_fb_profiler_activities()
+        trace_handler = get_fb_profiler_trace_handler()
+    else:
+        activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+        def trace_handler(p):
+            import pathlib
+
+            folder_path = os.path.join(output_path, "profiler_trace")
+            try:
+                pathlib.Path(folder_path).mkdir(parents=True, exist_ok=True)
+            except PermissionError:
+                logger.error(f"Permission denied to create directory {folder_path}")
+            p.export_chrome_trace(os.path.join(folder_path, f"rank-{rank}.json"))
+
+    logger.debug("GPU Trace Collection: Enabled")
+    _torch_profiler = torch.profiler.profile(
+        schedule=torch.profiler.schedule(
+            wait=1,
             warmup=numWarmupIters,
-            iters=numIters,
-        )
-        fbStartProfiler()
+            active=numIters,
+            repeat=1,
+        ),
+        on_trace_ready=trace_handler,
+        activities=activities,
+    )
+
+    if _torch_profiler:
+        logger.debug("GPU Trace Profiler: Start")
+        _torch_profiler.start()
         return True
     else:
-        logger.debug("Internal profiler is not available, skip...")
+        logger.debug("GPU Trace Profiler: failed to start")
         return False
 
 
 def sampleProfiler(stop: bool = False) -> None:
     """
-    Starts internal sample profiler.
+    Steps the PyTorch profiler and optionally stops it.
 
     Args:
         stop: Bool to be passed into sample profiler.
     Returns:
         None
     """
-    if has_internal_libs:
-        fbSampleProfiler(stop)
+    global _torch_profiler
+    if _torch_profiler:
+        _torch_profiler.step()
+        if stop:
+            _torch_profiler.stop()
+            _torch_profiler = None
+            logger.debug("GPU Trace Profiler: Stop")
     else:
-        logger.debug("Internal profiler is not available, skip...")
+        logger.debug("GPU Trace Profiler: not enabled")
 
 
 class commsArgs:
diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
new file mode 100644
index 00000000..7d0ad2f3
--- /dev/null
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -0,0 +1,336 @@
+import ast
+import json
+import logging
+import os
+import pathlib
+from collections import defaultdict
+from typing import Any, Callable, Dict
+
+import numpy as np
+from intervaltree import Interval, IntervalTree
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# refer to:
+# https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
+_dtype_size_map: Dict[str, int] = {
+    "Byte": 1,
+    "Char": 1,
+    "Short": 2,
+    "Int": 4,
+    "Long": 8,
+    "Half": 2,
+    "Float": 4,
+    "Double": 8,
+    "ComplexHalf": 4,
+    "ComplexFloat": 8,
+    "ComplexDouble": 16,
+    "Bool": 1,
+    "QInt8": 1,
+    "QUInt8": 1,
+    "QInt32": 4,
+    "BFloat16": 2,
+    "QUInt4x2": 1,
+    "QUInt2x4": 1,
+    "Bits1x8": 1,
+    "Bits2x4": 1,
+    "Bits4x2": 1,
+    "Bits8": 1,
+    "Bits16": 2,
+    "Float8_e5m2": 1,
+    "Float8_e4m3fn": 1,
+    "Float8_e5m2fnuz": 1,
+    "Float8_e4m3fnuz": 1,
+    "UInt16": 2,
+    "UInt32": 4,
+    "UInt64": 8,
+}
+
+# refer to: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
+_busbw_correction_factors_tbl: Dict[str, Callable[[int], float]] = {
+    "all_reduce": (lambda n: 2 * (n - 1) / n),
+    "all_gather": (lambda n: (n - 1) / n),
+    "all_to_all": (lambda n: (n - 1) / n),
+    "reduce_scatter": (lambda n: (n - 1) / n),
+    "reduce": (lambda n: 1),
+    "scatter": (lambda n: (n - 1) / n),
+    "gather": (lambda n: (n - 1) / n),
+    "broadcast": (lambda n: 1),
+    "send": (lambda n: 1),
+    "recv": (lambda n: 1),
+}
+
+# map the collective name of an event to its busbw correction factor function
+_collname_to_busbw_corr_factor_func: Dict[str, Callable[[int], float]] = {
+    "allreduce":
_busbw_correction_factors_tbl["all_reduce"], + "all_gather": _busbw_correction_factors_tbl["all_gather"], + "_allgather_base": _busbw_correction_factors_tbl["all_gather"], + "reduce_scatter": _busbw_correction_factors_tbl["reduce_scatter"], + "_reduce_scatter_base": _busbw_correction_factors_tbl["reduce_scatter"], + "all_to_all": _busbw_correction_factors_tbl["all_to_all"], + "all_to_allv": _busbw_correction_factors_tbl["all_to_all"], + "broadcast": _busbw_correction_factors_tbl["broadcast"], + "reduce": _busbw_correction_factors_tbl["reduce"], + "gather": _busbw_correction_factors_tbl["gather"], + "scatter": _busbw_correction_factors_tbl["scatter"], + "send": _busbw_correction_factors_tbl["send"], + "recv": _busbw_correction_factors_tbl["recv"], +} + + +def _get_dict_value(d, k, err_msg): + if k not in d: + raise ValueError(err_msg) + return d.get(k) + + +def _calculate_event_data_size(evt): + return ( + max(evt["args"]["In msg nelems"], evt["args"]["Out msg nelems"]) + * _dtype_size_map[evt["args"]["dtype"]] + ) + + +def _calculate_algbw(evt: Dict[str, Any]) -> float: + duration_us = _get_dict_value(evt, "dur", f'Missing "dur" in event: {evt}') + total_bytes = _calculate_event_data_size(evt) + + # NCCL tests use 1024^3 to convert B to GB (but not 1e9) + # https://github.com/NVIDIA/nccl-tests/blob/8dfeab9eb9bdfdf13503e71e1f33e7f8a208b540/src/common.cu#L102 + # but it uses 1e9 to convert bw from B/s to GB/s + # https://github.com/NVIDIA/nccl-tests/blob/8dfeab9eb9bdfdf13503e71e1f33e7f8a208b540/src/all_gather.cu#L41 + return round((total_bytes / duration_us) / 1e3, 2) + + +def _get_event_busbw_factor(evt): + coll_name = _get_dict_value( + evt["args"], "Collective name", f'Missing "Collective name" in event: {evt}' + ) + + # barrier is implemented using AllReduce + if coll_name in [ + "barrier", + ]: + return 0 + + group_size = _get_dict_value( + evt["args"], "Group size", f'Missing "Group size" in event: {evt}' + ) + correction_factor_func = _get_dict_value( + _collname_to_busbw_corr_factor_func, + coll_name, + f"Unsupported collective op for busbw calculation: {coll_name}", + ) + + return correction_factor_func(group_size) + + +def calculate_bw_(trace_data): + nccl_events = [ + i + for i in trace_data["traceEvents"] + if i.get("cat", "") == "kernel" and i["name"].startswith("ncclDevKernel_") + ] + for evt in nccl_events: + try: + coll_name = _get_dict_value( + evt["args"], + "Collective name", + f'Missing "Collective name" in event: {evt}', + ) + + # barrier is implemented using AllReduce + if coll_name in [ + "barrier", + ]: + continue + + algbw = _calculate_algbw(evt) + busbw_factor = _get_event_busbw_factor(evt) + busbw = round(algbw * busbw_factor, 2) + + evt["args"]["algbw (GB/sec)"] = algbw + evt["args"]["busbw (GB/sec)"] = busbw + evt["args"]["busbw_factor"] = busbw_factor + except ValueError as e: + logger.error("Error processing event: %s", e) + + +def calculate_sbw(trace_data): + # calculate shared bw per rank + nccl_events = [ + i + for i in trace_data["traceEvents"] + if i.get("cat", "") == "kernel" + and i["name"].startswith("ncclDevKernel_") + and "busbw_factor" in i["args"] + ] + + if not len(nccl_events): + return 0 + + total_data_size = sum( + [ + _calculate_event_data_size(evt) * _get_event_busbw_factor(evt) + for evt in nccl_events + ] + ) + + time_range_tree = IntervalTree( + [Interval(evt["ts"], evt["ts"] + evt["dur"]) for evt in nccl_events] + ) + time_range_tree.merge_overlaps() + + begin_time_point = min([i.begin for i in time_range_tree]) + end_time_point = 
max([i.end for i in time_range_tree]) + + sorted_tr = sorted(time_range_tree) + total_idle_time = ( + sum( + [ + sorted_tr[i + 1].begin - sorted_tr[i].end + for i in range(len(sorted_tr) - 1) + ] + ) + if len(sorted_tr) > 1 + else 0 + ) + + return total_data_size / (end_time_point - begin_time_point - total_idle_time) / 1e3 + + +def pick_iter_e2e_time_(trace_data, tl): + tl.extend( + [ + evt["dur"] + for evt in trace_data["traceEvents"] + if evt.get("cat", "") == "user_annotation" + and evt["name"].startswith("ProfilerStep#") + ] + ) + + +def pick_comm_bw_(trace_data, comm_bw_data): + rank = trace_data["distributedInfo"]["rank"] + nccl_events = [ + i + for i in trace_data["traceEvents"] + if i.get("cat", "") == "kernel" + and i["name"].startswith("ncclDevKernel_") + and "algbw (GB/sec)" in i["args"] + ] + for evt in nccl_events: + knl_name = evt["name"][: evt["name"].index("(")] + data_size = _calculate_event_data_size(evt) + ranks_count = evt["args"]["Group size"] + + ranks = ast.literal_eval(evt["args"]["Process Group Ranks"]) + pg_id = int(evt["args"]["Process Group Name"]) + pg = tuple([*ranks, pg_id]) if rank == min(ranks) else None + + comm_bw_data[(knl_name, data_size, ranks_count)].append( + [ + evt["dur"], + evt["args"]["algbw (GB/sec)"], + evt["args"]["busbw (GB/sec)"], + pg, + ] + ) + + +def analyze_profiler_trace(trace_dir: str, report_dir: str): + """ + Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report. + + Args: + trace_dir (str): dir path of input traces, where trace name should be in "rank-n.json" format. + report_dir (str): dir path for generated reports + """ + logger.info( + f'Parse profiler trace from "{trace_dir}" and generate reports to "{report_dir}"' + ) + + processed_trace_dir = os.path.join(report_dir, "profiler_trace_processed") + pathlib.Path(processed_trace_dir).mkdir(parents=True, exist_ok=True) + + # list of iteration time in all ranks + iter_e2e_time = [] + + # list of shared bw + sbw_lst = [] + + # key is (kernel_name, data size, ranks number) + # value is list of [dur, algbw, busbw, pg] + comm_bw_data = defaultdict(list) + + for fpath in os.scandir(trace_dir): + if not fpath.is_file(): + continue + + with open(fpath.path, "r", encoding="utf-8") as f: + trace = json.load(f) + + calculate_bw_(trace) + with open( + os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8" + ) as f: + json.dump(trace, f) + + sbw_lst.append(calculate_sbw(trace)) + + pick_iter_e2e_time_(trace, iter_e2e_time) + pick_comm_bw_(trace, comm_bw_data) + + comm_bw_summary = {} + for k, v in comm_bw_data.items(): + t_lst = [i[0] for i in v] + busbw_lst = [i[2] for i in v] + pg_set = set([i[3] for i in v if i[3]]) + comm_bw_summary[k] = [ + len(pg_set), + np.average(t_lst), + np.average(busbw_lst), + np.percentile(busbw_lst, 1), + np.percentile(busbw_lst, 50), + np.percentile(busbw_lst, 90), + np.percentile(busbw_lst, 99), + ] + comm_bw_summary = dict(sorted(comm_bw_summary.items())) + + # dump summary report + with open( + os.path.join(report_dir, "profiler_trace_summary_report.txt"), + "w", + encoding="utf-8", + ) as f: + f.write( + f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n" + ) + f.write( + f"avg. SharedBW (i.e. 
sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n" + ) + + f.write( + f'\n{" ":>70s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n' + ) + + f.write( + f'{"kernel":>50s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} ' + ) + for i in range(5): # average, p01, p50, p90, p99 + f.write(f'{" busbw":>8s}|') + f.write("\n") + + f.write( + f'{" ":>50s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} ' + ) + for i in range(5): # average, p50, p90, p99 + f.write(f'{"(GB/s)":>8s}|') + f.write("\n") + + for k, v in comm_bw_summary.items(): + f.write(f"{k[0]:>50s} {k[1]:>12d} {k[2]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ") + for i in range(2, len(v)): + f.write(f"{v[i]:>8.2f}|") + f.write("\n") diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml index e5ee8cac..19dbf0ea 100644 --- a/et_replay/pyproject.toml +++ b/et_replay/pyproject.toml @@ -5,6 +5,10 @@ build-backend = "setuptools.build_meta" [project] name = "et_replay" version = "0.5.0" +dependencies = [ + "numpy", + "intervaltree", +] [tool.setuptools.package-dir] "et_replay" = "." diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 5f258fa9..ba474343 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -15,7 +15,7 @@ import numpy as np import torch -from et_replay.comm import comms_utils, commsTraceParser +from et_replay.comm import comms_utils, commsTraceParser, profiler_trace_analysis from et_replay.comm.backend.base_backend import supportedP2pOps from et_replay.comm.comms_utils import ( bootstrap_info_holder, @@ -26,12 +26,7 @@ paramToCommName, ) from et_replay.comm.param_profile import paramProfile, paramTimer - -try: - # pyre-ignore[21]: - from trainer_iteration_wrapper import setTrainingIteration -except ImportError: - pass +from et_replay.vendor_internal import fb_internal logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -77,14 +72,12 @@ def writeCommDetails(commsTracePerf: list, rank: int, folder: str = "./") -> Non if saveToLocal: try: - import subprocess + import pathlib + + pathlib.Path(folder).mkdir(parents=True, exist_ok=True) + except PermissionError: + logger.error(f"Permission denied to create directory {folder}") - subprocess.check_output(["mkdir", "-p", str(folder)], text=True) - except Exception as err: - logger.error( - "\t Error: {} while creating directory: {} ".format(err, folder) - ) - pass with open(comms_file, "w") as write_file: json.dump(commsTracePerf, write_file, indent=2) @@ -113,7 +106,7 @@ def __init__(self): self.max_msg_cnt = 0 # 0 means no limit self.num_msg = 0 self.is_blocking = False - self.warmup_iter = 5 + self.do_warm_up = False self.reuse_tensors = False self.allowList = "" @@ -214,10 +207,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: help="Only replay first N operations (0 means no limit)", ) parser.add_argument( - "--warmup-iter", - type=int, - default=self.warmup_iter, - help="Number of warmup iterations", + "--do-warm-up", + action="store_true", + default=self.do_warm_up, + help="Toggle to perform extra replaying for warm-up", ) parser.add_argument( "--reuse-tensors", @@ -232,20 +225,6 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: default="all", help="List of desired collectives (separate by comma) to be replayed, e.g., `--allow-ops all_reduce,all_to_allv,wait`, typo or not supported collectives will be ignored.", ) - parser.add_argument( - 
"--output-path", - type=str, - default=self.out_path, - nargs="?", - const="", - help='Output path to write the replayed trace for post performance analysis. Set as empty string, i.e., "", to skip output', - ) - parser.add_argument( - "--output-ranks", - type=str, - default="all", - help="List of ranks separated by comma or a range specified by start:end to generate replayed trace for post performance analysis. Default including all ranks.", - ) parser.add_argument( "--colls-per-batch", type=int, @@ -270,18 +249,36 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: default=self.num_replays, help="Number of times to replay the given trace, used to get more accurate replay for small traces.", ) + + parser.add_argument( + "--output-path", + type=str, + default=self.out_path, + nargs="?", + const="", + help="Path to store generated results (e.g., replayed trace, profiler trace) for post performance analysis. (Default: %(default)s)", + ) + + parser.add_argument( + "--output-ranks", + type=str, + default=None, + help="List of ranks separated by comma (e.g. 1,2,3) OR a range specified by start:end (e.g., 1:3) to enable replayed trace dumping for post performance analysis. (Default: %(default)s)", + ) + parser.add_argument( "--profiler-num-replays-start", type=int, default=self.profiler_num_replays_start, - help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True", + help="Index of replay iteration to start collecting profiler trace after warmup in all ranks. (Default: %(default)s)", ) parser.add_argument( "--profiler-num-replays", type=int, default=self.profiler_num_replays, - help=f"Number of replay iterations to collect profiler. Default profile {self.profiler_num_replays} replays if --enables-profiler is True.", + help="Number of replay iterations to collect profiler trace in all ranks. (Default: %(default)s)", ) + args, _ = parser.parse_known_args() return args @@ -305,17 +302,30 @@ def checkArgs(self, args: argparse.Namespace) -> None: f"The specified trace path '{self.trace_file}' is neither a " "file nor a directory. Please provide a valid path." ) - comms_utils.gracefulExit() + if args.disable_parallel_read and not args.use_one_trace: raise ValueError( "--disable-parallel-read is valid only when --use-one-trace is used." ) - comms_utils.gracefulExit() + if args.trace_type not in VALID_TRACE_TYPES: raise ValueError( f"Trace type {self.trace_type} is not valid! Please specify one supported trace type from {str(VALID_TRACE_TYPES)} by using --trace-type." ) - comms_utils.gracefulExit() + """ + if ( + args.output_ranks is not None + and len(args.output_ranks) > 0 + and not len(args.output_path) + ): + raise ValueError('"--output-path" is not set for replay trace dumping') + + if ( + args.enable_profiler + and not len(args.output_path) + ): + raise ValueError('"--output-path" is not set for profiler trace dumping') + """ def reportBenchTime(self): """ @@ -392,9 +402,11 @@ def reportBenchTime(self): if not self.is_dry_run: print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20)) print( - "{}\n Total latency (us) of comms in trace: {}. 
\n{}".format( + "{}\nE2E latency (us): {} for {} iters, {:10.2f} per iter in avg\n{}".format( "-" * 50, self.totalTraceLatency, + self.num_replays, + self.totalTraceLatency / self.num_replays, "-" * 50, ) ) @@ -632,9 +644,8 @@ def generate_io_tensors( return super().prepComm(curComm, commsParams) else: commsOpHash = self.hashEtCommsOp(curComm) + # Allocate input/output tensors if first time replay, otherwise reuse the previous ones. if commsOpHash in self.et_to_tensors: - # Allocate input/output tensors if first time replay, otherwise the previous ones. - super().prepComm(curComm, commsParams, False) (ipTensor, opTensor) = self.et_to_tensors[commsOpHash] else: (ipTensor, opTensor) = super().prepComm(curComm, commsParams, True) @@ -1217,8 +1228,13 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: device=self.collectiveArgs.device, numWarmupIters=numWarmupIters, numIters=numProfileIters, + output_path=self.out_path, ) + # sync everything before starting real runs + with paramProfile(description="# PARAM replay warmup post-replay global sync"): + self.backendFuncs.sync_barrier(self.collectiveArgs) + if self.backendFuncs.get_global_rank() == 0: logger.info( f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}" @@ -1226,6 +1242,22 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: for coll, sizes in self.collInMsgBytes.items(): logger.info(f"\t{coll}: {len(sizes)}") + traceStartTime = time.monotonic_ns() + + def trace_handler(p): + import pathlib + + folder_path = os.path.join(self.out_path, "profiler_trace") + try: + pathlib.Path(folder_path).mkdir(parents=True, exist_ok=True) + except PermissionError: + logger.error(f"Permission denied to create directory {folder_path}") + p.export_chrome_trace( + os.path.join( + folder_path, f"rank-{self.backendFuncs.get_global_rank()}.json" + ) + ) + traceStartTime = 0 for i in range(self.warmup_iter + self.num_replays): if i == self.warmup_iter: @@ -1234,12 +1266,6 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: if self.collectiveArgs.enable_profiler: comms_utils.sampleProfiler() - # set training iteration number in NCCL - try: - setTrainingIteration(i + 1) - except NameError: - pass - if self.backendFuncs.get_global_rank() == 0: s = time.monotonic_ns() @@ -1270,6 +1296,8 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: # cleanup any memory left in use self.backendFuncs.clear_memory(self.collectiveArgs) + self.backendFuncs.barrier_all_ranks() + def runBench( self, commsParams: commsParamsHolderBase, @@ -1322,12 +1350,21 @@ def runBench( if self.backendFuncs.get_global_rank() in self.outputRanks: writeCommDetails( self.traceWithPerf, - folder=self.out_path, + folder=os.path.join(self.out_path, "replayed_trace"), rank=global_rank, ) # TODO: collect perf. from all ranks to rank 0 and detect any imbalanced perf? 
- self.backendFuncs.barrier(self.collectiveArgs) - self.backendFuncs.complete_accel_ops(self.collectiveArgs) + + if ( + commsParams.enable_profiler + and not fb_internal.has_fb_internal_libs + and self.backendFuncs.get_global_rank() == 0 + ): + profiler_trace_analysis.analyze_profiler_trace( + os.path.join(self.out_path, "profiler_trace"), self.out_path + ) + + self.backendFuncs.barrier_all_ranks() def replayInit( self, @@ -1456,7 +1493,7 @@ def initBench( self.shrink = args.auto_shrink self.max_msg_cnt = args.max_msg_cnt self.is_blocking = args.blocking - self.warmup_iter = args.warmup_iter + self.do_warm_up = args.do_warm_up self.reuse_tensors = args.reuse_tensors self.allowList = args.allow_ops if args.output_ranks == "all":
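For reference, the bandwidth numbers that calculate_bw_ writes into the processed traces follow the nccl-tests convention: algbw is message bytes divided by kernel duration, and busbw is algbw scaled by a per-collective correction factor (e.g., 2*(n-1)/n for all_reduce, (n-1)/n for all_gather). Below is a minimal, self-contained sketch of that arithmetic, not the module's own API: the event field names mirror what calculate_bw_ reads from a Kineto trace, while the helper name, sizes, and duration are hypothetical.

# Standalone illustration of the algbw/busbw formulas used in profiler_trace_analysis.py.
# All concrete values below are made up for illustration.
def busbw_factor(coll_name: str, group_size: int) -> float:
    # Correction factors from https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
    n = group_size
    return {
        "all_reduce": 2 * (n - 1) / n,
        "all_gather": (n - 1) / n,
        "reduce_scatter": (n - 1) / n,
        "all_to_all": (n - 1) / n,
    }.get(coll_name, 1.0)

# A hypothetical ncclDevKernel event in the shape produced by the PyTorch profiler (Kineto).
event = {
    "dur": 5000,  # kernel duration in microseconds
    "args": {
        "Collective name": "all_gather",
        "Group size": 8,
        "In msg nelems": 33_554_432,    # per-rank input elements
        "Out msg nelems": 268_435_456,  # gathered output elements (1 GiB of Float)
        "dtype": "Float",               # 4 bytes per element
    },
}

dtype_size = 4  # bytes per "Float" element
total_bytes = max(event["args"]["In msg nelems"], event["args"]["Out msg nelems"]) * dtype_size
algbw = (total_bytes / event["dur"]) / 1e3  # bytes/us -> GB/s (1 GB = 1e9 bytes)
busbw = algbw * busbw_factor("all_gather", event["args"]["Group size"])
print(f"algbw = {algbw:.2f} GB/s, busbw = {busbw:.2f} GB/s")
# algbw = 214.75 GB/s, busbw = 187.90 GB/s

With profiling enabled during a replay, each rank exports its trace to <output-path>/profiler_trace/rank-<N>.json; rank 0 then runs profiler_trace_analysis.analyze_profiler_trace() on that directory to produce annotated trace copies (under profiler_trace_processed) and profiler_trace_summary_report.txt in the output path.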