Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compile_trace] Add compile time Kineto trace #148

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tritonbench/components/compile_time/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .trace import do_compile_time_in_task, fbcode_do_compile_time_in_task # noqa F401
from .trace import ( # noqa F401
do_compile_kineto_trace_in_task,
do_compile_time_in_task,
fbcode_do_compile_time_in_task,
)
55 changes: 54 additions & 1 deletion tritonbench/components/compile_time/trace.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from typing import Callable, Dict
import random
import string
from datetime import datetime
from functools import partial
from typing import Callable, Dict, Optional

import torch
import torch.profiler as profiler
from tritonbench.utils.env_utils import fresh_triton_cache, is_fbcode

if is_fbcode():
from triton.fb.triton_util import triton_add_listener, TritonHook

from .fb.run_utils import trace_handler

# Default torch.profiler.profile() keyword options used by
# do_compile_kineto_trace_in_task when the caller does not pass profile_opts.
# Every key listed here is read by that function when configuring the profiler.
DEFAULT_PROFILE_OPTS = {
    "record_shapes": True,
    "profile_memory": True,
    "with_stack": True,
    "with_flops": True,
    "with_modules": True,
}


def fbcode_do_compile_time_in_task(fn: Callable) -> Dict[str, float]:
# not yet getting results that make sense to me
Expand Down Expand Up @@ -37,3 +52,41 @@ def do_compile_time_in_task(fn: Callable) -> float:
torch.cuda.synchronize() # Wait for the events to be recorded!
latency_with_compile = start_event.elapsed_time(end_event)
return latency_with_compile


def do_compile_kineto_trace_in_task(
    fn: Callable,
    profile_opts: Optional[Dict[str, bool]] = None,
    output_dir: Optional[str] = None,
) -> Optional[str]:
    """Profile the compilation stage of ``fn`` using a Kineto trace.

    Runs ``fn`` once inside a fresh triton cache (so compilation is not
    short-circuited by previously cached kernels) under the PyTorch
    profiler, capturing both CPU and CUDA activity.

    Args:
        fn: Zero-argument callable that triggers the compilation to be
            profiled. Must expose a ``_name`` attribute (tritonbench
            convention) used to label the trace file.
        profile_opts: Optional per-key overrides for the profiler options;
            any key not supplied falls back to ``DEFAULT_PROFILE_OPTS``.
        output_dir: Directory the trace is written to outside fbcode;
            required in that case (tensorboard_trace_handler needs a path).

    Returns:
        In fbcode builds, the internal trace-viewer URL for the uploaded
        trace; otherwise ``output_dir``, where the trace file was written.
    """
    activity_groups = [
        profiler.ProfilerActivity.CUDA,
        profiler.ProfilerActivity.CPU,
    ]
    # Merge caller overrides over the defaults. The previous code replaced
    # the whole dict and then indexed required keys directly, so a partial
    # override dict raised KeyError; merging keeps partial dicts valid.
    opts = {**DEFAULT_PROFILE_OPTS, **(profile_opts or {})}
    prefix = f"tritonbench_{fn._name}"
    name = (
        f"{prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_"
        f"{''.join(random.choices(string.digits, k=10))}.json"
    )
    with fresh_triton_cache():
        with profiler.profile(
            # A single active step: one fn() call is the whole measurement.
            schedule=profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
            activities=activity_groups,
            record_shapes=opts["record_shapes"],
            profile_memory=opts["profile_memory"],
            with_stack=opts["with_stack"],
            with_flops=opts["with_flops"],
            with_modules=opts["with_modules"],
            on_trace_ready=(
                partial(trace_handler, name)
                if is_fbcode()
                else profiler.tensorboard_trace_handler(output_dir)
            ),
        ) as prof:
            fn()
            prof.step()
    # Removed leftover debug print of output_dir.
    if not hasattr(torch.version, "git_version"):
        # fbcode build: trace_handler uploaded the (gzipped) trace; return
        # the internal viewer URL for it.
        return f"https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/{name}.gz&bucket=pyper_traces"
    return output_dir
47 changes: 44 additions & 3 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ class BenchmarkOperatorMetrics:
compile_time: Optional[float] = None
# stage breakdown of compile times
compile_time_by_stage: Optional[Dict[str, float]] = None
# compile time with kineto trace
compile_trace: Optional[str] = None
# ncu trace file
ncu_trace: Optional[str] = None
# ncu replay file
Expand Down Expand Up @@ -1145,6 +1147,10 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.compile_time = compile_time
if compile_time_by_stage:
metrics.compile_time_by_stage = compile_time_by_stage
if "compile_trace" in self.required_metrics:
metrics.compile_trace = self.compile_time(
input_id, fn_name, metrics, kineto_trace=True
)
if "ncu_trace" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
# Collect NCU metrics if any required metrics match the ncu analyzer
Expand Down Expand Up @@ -1236,6 +1242,29 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.all_configs = self.all_configs(fn)
if "kernel_source_hash" in self.required_metrics:
metrics.kernel_source_hash = self.kernel_hash(fn)
if "_compile_time_kineto_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_compile_time_kineto_trace_in_task"]
and len(self._only) == 1
and (self._input_id is not None)
), (
"_compile_time_kineto_trace_in_task must be measured by itself. "
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
)
from tritonbench.components.compile_time import (
do_compile_kineto_trace_in_task,
)

kineto_trace_output_dir = self.get_temp_path("kineto_trace")
kineto_trace_output_dir.mkdir(parents=True, exist_ok=True)
metrics.extra_metrics["_compile_time_kineto_trace"] = (
do_compile_kineto_trace_in_task(
fn, output_dir=str(kineto_trace_output_dir)
)
)
self._compile_time_kineto_trace = metrics.extra_metrics[
"_compile_time_kineto_trace"
]
if "_compile_time_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_compile_time_in_task"]
Expand Down Expand Up @@ -1591,8 +1620,12 @@ def kineto_trace(self, input_id: int, fn: Callable) -> str:
)

def compile_time(
self, input_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics
) -> float:
self,
input_id: int,
fn_name: str,
metrics: BenchmarkOperatorMetrics,
kineto_trace: bool = False,
) -> Union[float, str]:
# We need to spawn a subprocess when user wants to measure the compile time
# of multiple sample inputs and backends.
from tritonbench.operators.op_task import OpTask
Expand All @@ -1611,12 +1644,20 @@ def compile_time(
"--input-id",
str(input_id),
"--metrics",
"_compile_time_in_task",
(
"_compile_time_in_task"
if not kineto_trace
else "_compile_time_kineto_trace_in_task"
),
]
)
op_task = OpTask(name=self.name)
op_task.make_operator_instance(args=op_task_args)
op_task.run()
if kineto_trace:
kineto_trace_loc = op_task.get_attribute("_compile_time_kineto_trace")
del op_task
return kineto_trace_loc
if op_task.get_attribute("triton_hook_latency") is not None:
compiled_time = op_task.get_attribute("triton_hook_latency")
compile_time_by_stage = op_task.get_attribute("compile_time_by_stage")
Expand Down