add support for metrics calculation (#196)
Summary:
Add support for metrics calculation.

1. Iteration end-to-end (E2E) time
2. Bandwidth (see the sketch below)

This is a copy of #195 for importing it into Meta.
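
For reference, a minimal sketch (not part of this commit) of how the two metrics could be derived from per-iteration measurements; the function name, the byte-count argument, and the synchronization points are illustrative assumptions:

# Illustrative sketch only: iteration E2E time from wall-clock timestamps around
# one replayed iteration, and bandwidth as bytes moved divided by that time.
import time

import torch


def measure_iteration(run_iteration, message_bytes: int) -> tuple[float, float]:
    """Return (e2e_seconds, bandwidth_gb_per_s) for one replayed iteration."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain previously queued GPU work before timing
    start = time.monotonic()
    run_iteration()  # replay one iteration's collectives
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for every launched kernel to finish
    e2e_seconds = time.monotonic() - start
    bandwidth_gb_per_s = message_bytes / e2e_seconds / 1e9
    return e2e_seconds, bandwidth_gb_per_s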


Differential Revision: D68038126

Pulled By: shengfukevin
sanshang-nv authored and facebook-github-bot committed Jan 23, 2025
1 parent c5f8d06 commit 410aae0
Showing 5 changed files with 490 additions and 71 deletions.
9 changes: 8 additions & 1 deletion et_replay/comm/backend/pytorch_dist_backend.py
@@ -623,6 +623,13 @@ def barrier(self, collectiveArgs, name="dummy", retFlag=False):
         if retFlag:
             return retObj
 
+    def barrier_all_ranks(self):
+        dist.barrier(
+            device_ids=[self.get_device().index]
+            if dist.get_backend() == "nccl"
+            else None
+        )
+
     def sync_barrier(self, collectiveArgs, desc="dummy"):
         # ensure all streams have finished outstanding events before calling barrier
         self.complete_accel_ops(collectiveArgs)
@@ -1031,7 +1038,7 @@ def initialize_groups(self, backend="gloo"):
         # even if they are not going to be members of the group.
         sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
         sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
-        torch.distributed.barrier()
+        self.barrier_all_ranks()
 
         idxed_group_ranks_to_pgId: dict[tuple[int], list[int]] = defaultdict(list)
         for i in range(self.get_world_size()):
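The new barrier_all_ranks helper passes device_ids only for the NCCL backend, so the barrier is issued on the GPU this rank actually owns; backends such as gloo do not accept device_ids and get None. A self-contained sketch of the same pattern, with illustrative setup outside the backend class:

# Standalone sketch of the NCCL-aware barrier pattern added above; only the
# dist.barrier(...) call mirrors the commit, the rest is illustrative.
import torch.distributed as dist


def barrier_all_ranks(device_index: int) -> None:
    # NCCL needs the index of the GPU this rank drives; gloo/mpi reject
    # device_ids, so pass None for them.
    dist.barrier(
        device_ids=[device_index] if dist.get_backend() == "nccl" else None
    )

The second hunk swaps the bare torch.distributed.barrier() in initialize_groups for this helper, so group setup synchronizes on the correct device under NCCL instead of relying on NCCL's device guess.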
73 changes: 54 additions & 19 deletions et_replay/comm/comms_utils.py
@@ -19,11 +19,12 @@
 from io import StringIO
 from typing import Any
 
+from torch._C._profiler import ProfilerActivity
+
 try:
-    from param_bench.train.comms.pt.fb.internals import (
-        fbInitProfiler,
-        fbSampleProfiler,
-        fbStartProfiler,
+    from et_replay.fb.internals import (
+        get_fb_profiler_activities,
+        get_fb_profiler_trace_handler,
         initialize_collectiveArgs_internal,
         remove_quantization_handlers,
     )
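
The rewritten import keeps the Meta-internal helpers optional. A reduced sketch of the guard this presumably sits in, assuming the has_internal_libs flag referenced later in this file is set by the same try/except:

# Assumed shape of the optional-internals guard around the import above.
try:
    from et_replay.fb.internals import (
        get_fb_profiler_activities,
        get_fb_profiler_trace_handler,
    )

    has_internal_libs = True
except ImportError:
    # Open-source build: fall back to the stock torch.profiler configuration below.
    has_internal_libs = False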
@@ -390,45 +391,79 @@ def ensureTensorFlush(tensors: list[torch.Tensor] | torch.Tensor) -> Any:
     return x
 
 
-def startProfiler(rank: int, device: str, numWarmupIters: int, numIters: int) -> bool:
+_torch_profiler = None
+
+
+def startProfiler(
+    rank: int, device: str, numWarmupIters: int, numIters: int, output_path: str
+):
     """
-    Starts internal profiler with given parameters.
+    Starts pytorch profiler with given parameters.
     Args:
         rank: Global rank.
         device: Type of device "cuda", "cpu", etc.
         numWarmupIters: Number of warmup iterations.
         numIters: Number of real iterations.
-    Returns:
-        bool: Returns if internal profile was able to start or not.
+        output_path: Path to save profiler trace.
     """
+    global _torch_profiler
+
     if has_internal_libs:
-        fbInitProfiler(
-            rank=rank,
-            device=device,
+        activities = get_fb_profiler_activities()
+        trace_handler = get_fb_profiler_trace_handler()
+    else:
+        activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+        def trace_handler(p):
+            import pathlib
+
+            folder_path = os.path.join(output_path, "profiler_trace")
+            try:
+                pathlib.Path(folder_path).mkdir(parents=True, exist_ok=True)
+            except PermissionError:
+                logger.error(f"Permission denied to create directory {folder_path}")
+            p.export_chrome_trace(os.path.join(folder_path, f"rank-{rank}.json"))
+
+    logger.debug("GPU Trace Collection: Enabled")
+    _torch_profiler = torch.profiler.profile(
+        schedule=torch.profiler.schedule(
+            wait=1,
             warmup=numWarmupIters,
-            iters=numIters,
-        )
-        fbStartProfiler()
+            active=numIters,
+            repeat=1,
+        ),
+        on_trace_ready=trace_handler,
+        activities=activities,
+    )
+
+    if _torch_profiler:
+        logger.debug("GPU Trace Profiler: Start")
+        _torch_profiler.start()
         return True
     else:
-        logger.debug("Internal profiler is not available, skip...")
+        logger.debug("GPU Trace Profiler: Fail to start")
         return False
 
 
 def sampleProfiler(stop: bool = False) -> None:
     """
-    Starts internal sample profiler.
+    Starts sample profiler.
     Args:
         stop: Bool to be passed into sample profiler.
     Returns:
         None
     """
-    if has_internal_libs:
-        fbSampleProfiler(stop)
+    global _torch_profiler
+    if _torch_profiler:
+        _torch_profiler.step()
+        if stop:
+            _torch_profiler.stop()
+            _torch_profiler = None
+            logger.debug("GPU Trace Profiler: Stop")
     else:
-        logger.debug("Internal profiler is not available, skip...")
+        logger.debug("GPU Trace Profiler: not enabled")
 
 
 class commsArgs:
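Putting the two helpers together: startProfiler builds a torch.profiler.profile whose schedule waits one step, warms up for numWarmupIters steps, and records numIters active steps, while sampleProfiler advances that schedule by one step per call and shuts the profiler down when stop=True. A hedged sketch of how a replay loop might drive them; the loop body, argument values, and output path are illustrative, not this repository's replay code:

# Illustrative driver for the profiler helpers above; values and the loop body
# are placeholders, only the call signatures follow the code in this commit.
from et_replay.comm.comms_utils import sampleProfiler, startProfiler

num_warmup_iters, num_iters = 5, 20

startProfiler(
    rank=0,
    device="cuda",
    numWarmupIters=num_warmup_iters,
    numIters=num_iters,
    output_path="/tmp/et_replay_out",  # trace lands in /tmp/et_replay_out/profiler_trace/rank-0.json
)

# One extra step for the schedule's wait=1 phase, then warmup + active steps.
for _ in range(1 + num_warmup_iters + num_iters):
    # ... replay one iteration of collectives here ...
    sampleProfiler()  # advances the profiler schedule by one step

sampleProfiler(stop=True)  # stop the profiler; the chrome trace is written by the on_trace_ready handler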
(Diffs for the remaining three changed files are not shown here.)
