From 3e0e6eebe6ae6a8861dc5d653b52b2ac89c2cbc5 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 30 Jan 2025 05:51:58 -0800 Subject: [PATCH 1/2] Add compile time Kineto trace --- .../components/compile_time/__init__.py | 6 +- tritonbench/components/compile_time/trace.py | 55 ++++++++++++++++++- tritonbench/utils/triton_op.py | 48 +++++++++++++++- 3 files changed, 104 insertions(+), 5 deletions(-) diff --git a/tritonbench/components/compile_time/__init__.py b/tritonbench/components/compile_time/__init__.py index 50c8f45f..51e7dfdc 100644 --- a/tritonbench/components/compile_time/__init__.py +++ b/tritonbench/components/compile_time/__init__.py @@ -1 +1,5 @@ -from .trace import do_compile_time_in_task, fbcode_do_compile_time_in_task # noqa F401 +from .trace import ( # noqa F401 + do_compile_kineto_trace_in_task, + do_compile_time_in_task, + fbcode_do_compile_time_in_task, +) diff --git a/tritonbench/components/compile_time/trace.py b/tritonbench/components/compile_time/trace.py index fbb1d133..2a4110b2 100644 --- a/tritonbench/components/compile_time/trace.py +++ b/tritonbench/components/compile_time/trace.py @@ -1,11 +1,26 @@ -from typing import Callable, Dict +import random +import string +from datetime import datetime +from functools import partial +from typing import Callable, Dict, Optional import torch +import torch.profiler as profiler from tritonbench.utils.env_utils import fresh_triton_cache, is_fbcode if is_fbcode(): from triton.fb.triton_util import triton_add_listener, TritonHook + from .fb.run_utils import trace_handler + +DEFAULT_PROFILE_OPTS = { + "record_shapes": True, + "profile_memory": True, + "with_stack": True, + "with_flops": True, + "with_modules": True, +} + def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]: # not yet getting results that make sense to me @@ -37,3 +52,41 @@ def do_compile_time_in_task(fn: Callable) -> float: torch.cuda.synchronize() # Wait for the events to be recorded! latency_with_compile = start_event.elapsed_time(end_event) return latency_with_compile + + +def do_compile_kineto_trace_in_task( + fn: Callable, + profile_opts: Optional[Dict[str, bool]] = None, + output_dir: Optional[str] = None, +) -> Optional[str]: + """Profile compilation stage using Kineto.""" + activity_groups = [ + profiler.ProfilerActivity.CUDA, + profiler.ProfilerActivity.CPU, + ] + if not profile_opts: + profile_opts = DEFAULT_PROFILE_OPTS + prefix = f"tritonbench_{fn._name}" + name = f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{''.join(random.choices(string.digits, k=10))}.json" + with fresh_triton_cache(): + with profiler.profile( + schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1), + activities=activity_groups, + record_shapes=profile_opts["record_shapes"], + profile_memory=profile_opts["profile_memory"], + with_stack=profile_opts["with_stack"], + with_flops=profile_opts["with_flops"], + with_modules=profile_opts["with_modules"], + on_trace_ready=( + partial(trace_handler, name) + if is_fbcode() + else profiler.tensorboard_trace_handler(output_dir) + ), + ) as prof: + fn() + prof.step() + print(f"output dir: {output_dir}") + if not hasattr(torch.version, "git_version"): + return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces" + else: + return output_dir diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 1c1abd8d..5dc6d045 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -199,6 +199,8 @@ class BenchmarkOperatorMetrics: compile_time: Optional[float] = None # stage breakdown of compile times compile_time_by_stage: Optional[Dict[str, float]] = None + # compile time with kineto trace + compile_trace: Optional[str] = None # ncu trace file ncu_trace: Optional[str] = None # ncu replay file @@ -1145,6 +1147,10 @@ def _init_extra_metrics() -> Dict[str, Any]: metrics.compile_time = compile_time if compile_time_by_stage: metrics.compile_time_by_stage = compile_time_by_stage + if "compile_trace" in self.required_metrics: + metrics.compile_trace = self.compile_time( + input_id, fn_name, metrics, kineto_trace=True + ) if "ncu_trace" in self.required_metrics: metrics.ncu_trace = self.ncu_trace(input_id, fn_name) # Collect NCU metrics if any required metrics match the ncu analyzer @@ -1236,6 +1242,29 @@ def _init_extra_metrics() -> Dict[str, Any]: metrics.all_configs = self.all_configs(fn) if "kernel_source_hash" in self.required_metrics: metrics.kernel_source_hash = self.kernel_hash(fn) + if "_compile_time_kineto_trace_in_task" in self.required_metrics: + assert ( + self.required_metrics == ["_compile_time_kineto_trace_in_task"] + and len(self._only) == 1 + and (self._input_id is not None) + ), ( + "_compile_time_kineto_trace_in_task must be measured by itself. " + f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}" + ) + from tritonbench.components.compile_time import ( + do_compile_kineto_trace_in_task, + ) + + kineto_trace_output_dir = self.get_temp_path("kineto_trace") + kineto_trace_output_dir.mkdir(parents=True, exist_ok=True) + metrics.extra_metrics["_compile_time_kineto_trace"] = ( + do_compile_kineto_trace_in_task( + fn, output_dir=str(kineto_trace_output_dir) + ) + ) + self._compile_time_kineto_trace = metrics.extra_metrics[ + "_compile_time_kineto_trace" + ] if "_compile_time_in_task" in self.required_metrics: assert ( self.required_metrics == ["_compile_time_in_task"] @@ -1591,8 +1620,12 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str: ) def compile_time( - self, input_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics - ) -> float: + self, + input_id: int, + fn_name: str, + metrics: BenchmarkOperatorMetrics, + kineto_trace: bool = False, + ) -> Union[float, str]: # We need to spawn a subprocess when user wants to measure the compile time # of multiple sample inputs and backends. from tritonbench.operators.op_task import OpTask @@ -1611,12 +1644,21 @@ def compile_time( "--input-id", str(input_id), "--metrics", - "_compile_time_in_task", + ( + "_compile_time_in_task" + if not kineto_trace + else "_compile_time_kineto_trace_in_task" + ), ] ) op_task = OpTask(name=self.name) op_task.make_operator_instance(args=op_task_args) op_task.run() + if kineto_trace: + kineto_trace_loc = op_task.get_attribute("_compile_time_kineto_trace") + print(f"kineto_trace_loc: {kineto_trace_loc}") + del op_task + return kineto_trace_loc if op_task.get_attribute("triton_hook_latency") is not None: compiled_time = op_task.get_attribute("triton_hook_latency") compile_time_by_stage = op_task.get_attribute("compile_time_by_stage") From 30124b0014ab63c828822fa54b0ad3af03404443 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 30 Jan 2025 05:56:26 -0800 Subject: [PATCH 2/2] Add compile time Kineto trace --- tritonbench/utils/triton_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 5dc6d045..e259ee41 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1656,7 +1656,6 @@ def compile_time( op_task.run() if kineto_trace: kineto_trace_loc = op_task.get_attribute("_compile_time_kineto_trace") - print(f"kineto_trace_loc: {kineto_trace_loc}") del op_task return kineto_trace_loc if op_task.get_attribute("triton_hook_latency") is not None: