Skip to content

Commit

Permalink
Refactor compiler specializations to consider backend (#4734)
Browse files Browse the repository at this point in the history
In this PR I am trying to refactor the specializations that we apply to
the signature of a given function in Triton.

Basically, given a kernel, some properties of its arguments can help
compilation, e.g., a value being divisible by 16 or an integer argument
being equal to 1.

In a previous PR: #4716, I
needed other specializations to add buffer support in the AMD backend
(and get back some performance when we were using unaligned masked
loads).

So this is my attempt to redesign the specialization support to
introduce per-backend specializations. The idea is that
`AttrsDescriptor` is now the class that is taking care of doing the
analysis of the parameters and adding the specialization. It also has a
function table where more specializations can be added per-backend.
  • Loading branch information
giuseros authored Oct 2, 2024
1 parent 057a9d3 commit cd1cc2d
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 92 deletions.
6 changes: 3 additions & 3 deletions python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def walk_fn(op):
torch.empty((32, 32), device=device), # out_ptr
16, # BLOCK_SIZE
]
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={
Expand All @@ -69,12 +71,10 @@ def walk_fn(op):
constants={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
attrs=kernel._get_config(*args, ),
attrs=backend.get_attrs_descriptor(args, kernel.params),
)

context = triton._C.libtriton.ir.context()
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
codegen_fns = dict()
module_map = backend.get_module_map()
Expand Down
7 changes: 4 additions & 3 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import triton
import triton.language as tl
from triton.backends.compiler import AttrsDescriptor
from triton.compiler import ASTSource

target = triton.runtime.driver.active.get_current_target()
Expand All @@ -25,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):


def test_compile_in_subproc() -> None:
config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
proc.start()
Expand All @@ -47,7 +48,7 @@ def kernel_dot(Z):


def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
config = AttrsDescriptor.from_hints({0: 16})
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
proc.start()
Expand Down Expand Up @@ -86,7 +87,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
gc.disable()

# stage 1.p
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
config = AttrsDescriptor.from_hints({0: 16})
compile_empty_kernel_with_gc(config)

# stage 2.p
Expand Down
195 changes: 194 additions & 1 deletion python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
import hashlib
import subprocess

from abc import ABCMeta, abstractmethod, abstractclassmethod
Expand All @@ -8,6 +9,184 @@
from types import ModuleType


class AttrsDescriptor:
    """
    Handle compile-time properties of specific function parameters.

    Different backends can add more properties to the common ones. The class
    contains three fields:

    `arg_properties`: a dictionary mapping each property name to the indices
        of the parameters that have that property, e.g.
        {
            "prop0": (0, 2, 3),
            "prop1": (0, 4, 5),
        }
        Different backends might need different properties on those parameters
        to enable specific optimizations. The common compile-time properties
        handled by this class are:
            - "tt.divisibility", i.e., the given parameter is divisible by 16
            - "tt.equal_to", i.e., the given parameter is the integer constant 1
    `property_values`: a dictionary mapping each property name to its value, e.g.
        {
            "prop0": val0,
            "prop1": val1,
        }
    `constant_properties`: a set with the names of the properties that mark a
        parameter as a constant.

    In addition, every entry of `arg_properties` is mirrored into an attribute
    named `<property name without "tt.">_<property value>`; for the common
    properties these are `divisibility_16` and `equal_to_1`.
    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties.

        When both `params` (the parameters of the function) and `values` (the
        matching argument values) are given, the properties are deduced from
        the values and the matching parameter indices are stored in
        `arg_properties`. When either of them is missing, no property is
        deduced and the object is expected to be filled in through an
        alternative constructor (see `from_dict` or `from_hints`).
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        # The common slots always hold a list, even when nothing is deduced,
        # so that readers never hit an AttributeError on an unset slot.
        self.divisibility_16 = []
        self.equal_to_1 = []

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """ Register the common properties and, if possible, deduce them from `values` """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        self.constant_properties.add("tt.equal_to")

        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        divisible = []
        equal_to_one = []
        for param, arg in zip(params, values):
            # The value is tested first so that the `param` flags are only
            # read for values that actually have the property.
            if (AttrsDescriptor.is_divisible_by_16(arg) and not param.do_not_specialize
                    and not param.do_not_specialize_on_alignment):
                divisible.append(param.num)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize:
                equal_to_one.append(param.num)

        self.arg_properties["tt.divisibility"] = divisible
        self.arg_properties["tt.equal_to"] = equal_to_one

    def _add_backend_properties(self, params=None, values=None):
        """ Hook for subclasses to add their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Mirror every entry of `arg_properties` into its attribute.

        The attribute name is the property name without the "tt." prefix,
        followed by "_" and the property value (e.g. "tt.divisibility" with
        value 16 becomes `divisibility_16`). A property name that is missing
        from `property_values` raises `KeyError`.
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> Dict:
        """
        Get the function attributes as a dictionary.

        The returned dictionary maps each parameter index to the list of
        `(property name, property value)` pairs that apply to it:
            {
                arg0: [(prop_name00, val00), (prop_name01, val01), ...],
                arg1: [(prop_name10, val10), (prop_name11, val11), ...],
            }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> Dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """ Return a deep copy of this object without the properties marked as constants """
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the documented type: `constant_properties` is a set, not a dict.
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a SHA-256 hex digest identifying this set of properties.

        The digest is computed over (property name, parameter indices) pairs
        rather than over the bare index lists: otherwise two descriptors that
        assign the same index lists to different properties would collide.
        """
        values = [sorted((name, sorted(ids)) for name, ids in self.arg_properties.items())]
        values += [sorted(self.property_values.items())]
        values += [sorted(self.constant_properties)]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """ Return the `arg_properties` dictionary (the object's own dictionary, not a copy) """
        return self.arg_properties

    @staticmethod
    def from_dict(data):
        """ Create the class from a dictionary shaped like the result of `to_dict` """
        attrsDescriptor = AttrsDescriptor()
        for prop_name, param_ids in data.items():
            attrsDescriptor.arg_properties[prop_name] = param_ids
        attrsDescriptor._init_slots()
        return attrsDescriptor

    @staticmethod
    def from_hints(hints: dict[int, int]):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and
        values, the user can pass in `hints`, a mapping
        `{param_index: val}` (an iterable of `(param_index, val)` pairs is
        accepted as well). If `val` matches the value of one of the properties
        (e.g., `property_values[prop0]`), then `param_index` is inserted into
        the matching list (e.g., `arg_properties[prop0]`).
        """
        pairs = list(hints.items()) if hasattr(hints, "items") else list(hints)
        attrsDescriptor = AttrsDescriptor()
        for prop_name, prop_val in attrsDescriptor.property_values.items():
            attrsDescriptor.arg_properties[prop_name] = [i for i, h in pairs if h == prop_val]
        attrsDescriptor._init_slots()
        return attrsDescriptor

    @staticmethod
    def is_divisible_by_16(x):
        """
        Return if the argument is a multiple of 16.

        Objects exposing `data_ptr()` are judged on that address, integers on
        their value, and `None` counts as divisible. Anything else is not.
        """
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        if isinstance(x, int):
            return x % 16 == 0
        return x is None

    @staticmethod
    def is_equal_to_1(x):
        """ Return if the argument is the integer constant 1 (booleans do not count) """
        return isinstance(x, int) and not isinstance(x, bool) and x == 1

    @staticmethod
    def get_property_key(val, align):
        """
        Return the one-character key describing `val`.

        "D" when `align` is set and `val` is divisible by 16, otherwise "1"
        when `val` is the integer constant 1, otherwise "N".
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"


@dataclass(frozen=True)
class GPUTarget(object):
# Target backend, e.g., cuda, hip
Expand Down Expand Up @@ -79,6 +258,20 @@ def load_dialects(self, context):
@abstractmethod
def get_module_map(self) -> Dict[str, ModuleType]:
"""
Return a map of interface modules to their device-specific implementations.
Return a map of interface modules to their device-specific implementations
"""
raise NotImplementedError

def get_attrs_descriptor(self, params, args):
    """
    Build the attribute descriptor for a kernel call.

    `params` (the function parameters) and `args` (the matching argument
    values) are forwarded, in that order, to `AttrsDescriptor`, which stores
    the compile-time properties that can improve code generation. Different
    backends might benefit from different properties.
    """
    descriptor = AttrsDescriptor(params, args)
    return descriptor

def compute_spec_key(self, arg, align):
    """
    Compute the ASCII specialization key of one argument.

    The work is delegated to `AttrsDescriptor.get_property_key`, which is
    given the argument and the `align` flag unchanged.
    """
    spec_key = AttrsDescriptor.get_property_key(arg, align)
    return spec_key
2 changes: 1 addition & 1 deletion python/triton/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
from .errors import CompilationError

__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
12 changes: 8 additions & 4 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ def kernel_suffix(signature, specialization):
suffix += str(i)
if i in specialization.equal_to_1:
suffix += 'c'
if i in specialization.divisible_by_16:
if i in specialization.divisibility_16:
suffix += 'd'
return suffix

Expand All @@ -1284,17 +1284,21 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
gscope = fn.__globals__.copy()
function_name = fn.repr(specialization)
tys = list(specialization.signature.values())
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}
new_constants = attrs.get_constants()
for k in new_constants:
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
new_constants[k] = True

new_attrs = attrs.filter_out_constants()
fn_attrs = new_attrs.get_fn_attrs()
all_constants = constants.copy()
all_constants.update(new_constants)
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
file_name, begin_line = get_jit_fn_file_line(fn)

prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
generator.visit(fn.parse())

Expand Down
28 changes: 1 addition & 27 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,19 @@
import json
from .._C.libtriton import get_cache_invalidating_env_vars, ir
from ..backends import backends
from ..backends.compiler import GPUTarget
from ..backends.compiler import GPUTarget, AttrsDescriptor
from .. import __version__
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..tools.disasm import get_sass
# TODO: this shouldn't be here
from dataclasses import dataclass
from .code_generator import ast_to_ttir
from pathlib import Path
import re
import functools
import os


@dataclass
class AttrsDescriptor:
divisible_by_16: set = None
equal_to_1: set = None

def __post_init__(self):
if self.divisible_by_16 is None:
self.divisible_by_16 = set()
if self.equal_to_1 is None:
self.equal_to_1 = set()

def to_dict(self):
return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)}

@staticmethod
def from_dict(data):
return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])),
equal_to_1=set(data.get('equal_to_1', [])))

def hash(self):
key = str([sorted(x) for x in self.__dict__.values()])
return hashlib.sha256(key.encode("utf-8")).hexdigest()


# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
Expand Down
Loading

0 comments on commit cd1cc2d

Please sign in to comment.