Commit
Add support for dynamic dims (iree-org#178)
This PR adds support for dynamic dimensions in the
kernels. The user specifies the dynamic dimensions by:
- not adding them to the hyperparameter dictionary, and
- explicitly listing them with the dynamic_symbols kwarg,
  together with the dynamic_symbols_mapping kwarg, which
  supplies the values to use for the dynamic dims at
  runtime (see the usage sketch below).

This PR does not modify the codegen, so incorrect or
unsupported values for the dynamic dims will produce
incorrect results (garbage in -> garbage out).
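
For reference, a minimal usage sketch modeled on the test_dynamic_copy lit test added in this PR.
The tensor shape and hyperparameter values are illustrative, and the exact launch-context signature
accepting dynamic_symbols_mapping is an assumption based on the kwargs described above, not verbatim API:

```python
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw

# Symbols for the (dynamic) tensor shape and the (static) tile sizes.
M = tkl.sym.M
N = tkl.sym.N
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
    tkw.HardwareConstraint(
        threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
    ),
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    tkw.WaveConstraint(M, BLOCK_M),
    tkw.WaveConstraint(N, BLOCK_N),
]

@tkw.wave(constraints)
def copy(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
    b = tkw.read(a, elements_per_thread=16)
    tkw.write(b, a, elements_per_thread=16)

a = torch.randn(128, 256, dtype=torch.float16)

# M and N are left out of the hyperparameters and declared dynamic instead;
# their runtime values are supplied via dynamic_symbols_mapping (assumed kwarg name).
hyperparams = {
    BLOCK_M: 16,
    BLOCK_N: 16,
    ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}
with tk.gen.TestLaunchContext(
    hyperparams,
    canonicalize=True,
    dynamic_symbols=[M, N],
    dynamic_symbols_mapping={M: a.shape[0], N: a.shape[1]},
):
    print(copy(a).module_op)
```

At the IR level, each dynamic dim becomes an extra index operand: it is bound as a
SYMBOL_VALUE argument on the dispatch, forwarded to stream.binding.subspan as a dynamic
dim, and used to compute the workgroup count (see the diffs below).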

---------

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod authored Oct 1, 2024
1 parent 84320ea commit 553e929
Showing 8 changed files with 318 additions and 67 deletions.
94 changes: 83 additions & 11 deletions lit_tests/kernel/wave/codegen.py
@@ -21,18 +21,24 @@
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def codegen_test_context(canonicalize: bool = False):
def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
bindings = {
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}

# Remove dynamic symbols from the bindings.
for sym in dynamic_symbols:
if sym in bindings:
del bindings[sym]

return tk.gen.TestLaunchContext(
{
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
},
canonicalize=canonicalize,
bindings, canonicalize=canonicalize, dynamic_symbols=dynamic_symbols
)


@@ -328,6 +334,72 @@ def test(
# CHECK-SAME: strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16>


@run_test
def test_dynamic_copy():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
b = tkw.read(a, elements_per_thread=16)
tkw.write(b, a, elements_per_thread=16)

with codegen_test_context(canonicalize=True, dynamic_symbols=[M, N]):
a = torch.randn(16, 16, dtype=torch.float16)
print(test(a).module_op)

# CHECK: stream.executable.export public @test workgroups(%[[ARG0:.*]]: index, %[[ARG1:.*]]:
# CHECK-SAME: index) -> (index, index, index) {
# CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
# CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
# CHECK: %[[D0:.+]] = arith.ceildivsi %[[ARG0]], %[[C16]] : index
# CHECK: %[[D1:.+]] = arith.ceildivsi %[[ARG1]], %[[C16]] : index
# CHECK: stream.return %[[D0]], %[[D1]], %[[C1]] : index, index, index
# CHECK: }
# CHECK: func.func @test(%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK-SAME: attributes {translation_info = #[[TRANSLATION:.+]]} {
# CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<16xf16>
# CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> :
# CHECK-SAME: vector<16xindex>
# CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
# CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
# CHECK-DAG: %[[C16]] = arith.constant 16 : index
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index
# CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index
# CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x
# CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y
# CHECK: %[[D0]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf16>{%[[ARG1]],
# CHECK-SAME: %[[ARG2]]}
# CHECK: %[[D1]] = arith.muli %[[WORKGROUP_ID_0]], %[[C16]] : index
# CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index
# CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index
# CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index
# CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index
# CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C16]] : index
# CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] : index
# CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index
# CHECK: %[[D9:.+]] = vector.splat %[[D8]] : vector<16xindex>
# CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[CST_0]] : vector<16xindex>
# CHECK: %[[D11:.+]] = vector.splat %[[ARG2]] : vector<16xindex>
# CHECK: %[[D12:.+]] = arith.cmpi slt, %[[D10]], %[[D11]] : vector<16xindex>
# CHECK: %[[D13:.+]] = arith.cmpi slt, %[[D5]], %[[ARG1]] : index
# CHECK: %[[D14:.+]] = vector.splat %[[D13]] : vector<16xi1>
# CHECK: %[[D15:.+]] = arith.andi %[[D12]], %[[D14]] : vector<16xi1>
# CHECK: %[[D16:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[CST]] : memref<?x?xf16>,
# CHECK-SAME: vector<16xi1>, vector<16xf16> into vector<16xf16>
# CHECK: vector.maskedstore %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[D16]] : memref<?x?xf16>, vector<16xi1>,
# CHECK-SAME: vector<16xf16>
# CHECK: return


@run_test
def test_mma():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
75 changes: 54 additions & 21 deletions shark_turbine/kernel/compiler/dispatch_codegen.py
@@ -7,9 +7,7 @@

from typing import Any, Callable, Optional, Type

from .._support.indexing import (
IndexingContext,
)
from .._support.indexing import IndexingContext, IndexSymbol, IndexExpr

from .base import (
CodegenError,
@@ -99,6 +97,7 @@ def define_entrypoint(
grid: Grid,
workgroup_size: list[int] = None,
subgroup_size: int = None,
dynamic_symbols: list[IndexSymbol] = [],
) -> "DispatchEntrypoint":
"""Defines a dispatch function with a signature like:
@@ -119,26 +118,24 @@
The given name is not uniqued (must be unique as given by the caller).
"""
kb_input_bindings = sig.kernel_buffer_input_bindings
kb_temp_bindings = sig.kernel_buffer_temporary_bindings
kb_output_bindings = sig.kernel_buffer_output_bindings
# TODO: The way we are doing grid bindings is wrong. The Grid type
# should be parameterized with special grid axis symbols which are
# algebraically related to concrete shape dim symbols. For now, we are
# just assuming that the grid dims can be resolved to constants, when
# in reality, we should pass the workload and parameterize the grid
# dims on the workloads.
workload_axis_bindings = []
dynamic_dim_bindings = sig.dynamic_dim_bindings

# Input bindings are always user specified.
# Grid/workgroup bindings are in the inputs section but are implied.
# Temp bindings are a special kind of output bindings.
# Output bindings are the real outputs.
linear_bindings = (
kb_input_bindings
+ workload_axis_bindings
+ kb_temp_bindings
+ kb_output_bindings
)
# Dynamic dim bindings are the dynamic dims of the input and output tensors.
linear_bindings = kb_input_bindings + dynamic_dim_bindings + kb_output_bindings

dynamic_dim_indices = {
"begin": len(kb_input_bindings),
"end": len(linear_bindings) - len(kb_output_bindings),
}

# TODO: This is sloppy. This assert will hit on some user errors for
# unsupported type combinations and is just a last resort right now.
@@ -177,28 +174,59 @@ def abi_type(binding: BindingDesc):
with InsertionPoint.at_block_begin(self._exe_block):
export_op = stream_d.ExecutableExportOp(name, name)
export_block = export_op.workgroup_count.blocks.append(
*([b.as_mlir_type() for b in workload_axis_bindings])
*([b.as_mlir_type() for b in dynamic_dim_bindings])
)

workgroup_builder = WorkgroupBuilder(
export_block, lambda vs: stream_d.ReturnOp(vs)
)

# TODO: Support passing workload to the dispatch function.
from ..wave.codegen import gen_sympy_index

# Map dynamic symbols to block arguments.
dynamic_symbols_mapping = {
k: v
for k, v in zip(
dynamic_symbols, workgroup_builder.entry_block.arguments
)
}

with InsertionPoint(workgroup_builder.entry_block):
result_type = IndexType.get()
workgroup_values = [
arith_d.constant(result_type, IntegerAttr.get(result_type, dim))
for dim in grid.dims
]
workgroup_values = []
for dim in grid.dims:
if isinstance(dim, IndexExpr):
workgroup_values.append(
gen_sympy_index(dynamic_symbols_mapping, dim)
)
else:
workgroup_values.append(
arith_d.constant(
result_type, IntegerAttr.get(result_type, dim)
)
)

while len(workgroup_values) < 3:
workgroup_values.append(
arith_d.constant(result_type, IntegerAttr.get(result_type, 1))
)
workgroup_builder.terminate(workgroup_values)

return DispatchEntrypoint(sig, def_func_block, linear_bindings)
# Map dynamic symbols to func arguments for dispatch entrypoint.
dynamic_symbols_mapping = {
k: v
for k, v in zip(
dynamic_symbols,
def_func_args[
dynamic_dim_indices["begin"] : dynamic_dim_indices["end"]
],
)
}

return DispatchEntrypoint(
sig, def_func_block, linear_bindings, dynamic_symbols_mapping
)


class WorkgroupBuilder:
@@ -231,8 +259,10 @@ def __init__(
sig: KernelSignature,
entry_block: Block,
linear_bindings: list[BindingDesc],
dynamic_symbols_mapping: dict[IndexSymbol, Value],
):
super().__init__(sig, entry_block)
self.dynamic_symbols_mapping = dynamic_symbols_mapping
self._abi_value_by_reference: dict[tuple[str, Any], Value] = {
b.reference: value
for value, b in zip(entry_block.arguments, linear_bindings)
@@ -250,12 +280,15 @@ def resolve(self, binding: BindingDesc) -> Value:
result_type = IndexType.get()
zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0))
linear_arg_value = self._abi_value_by_reference[binding.reference]
# TODO: Need to also look up dynamic symbol values.
return stream_d.binding_subspan(
binding.as_mlir_type(),
linear_arg_value,
byte_offset=zero_value,
dynamic_dims=[],
dynamic_dims=[
self.dynamic_symbols_mapping[dim]
for dim in binding.kernel_buffer_type.symbolic_shape
if dim in self.dynamic_symbols_mapping
],
)

raise ValidationError(f"Unhandled binding type: {binding}")
40 changes: 37 additions & 3 deletions shark_turbine/kernel/compiler/host_codegen.py
@@ -8,6 +8,7 @@
from .ir import (
Block,
FunctionType,
IndexType,
InsertionPoint,
IrType,
Location,
@@ -19,6 +20,9 @@
func_d,
)

from .._support.indexing import IndexSymbol
from .kernel_codegen import BindingDesc


def memref_to_tensor(memrefs: list[IrType]):
tensors = []
@@ -29,30 +33,60 @@ def memref_to_tensor(memrefs: list[IrType]):
return tensors


def get_dynamic_dims(bindings: list[BindingDesc], dynamic_symbols: list[IndexSymbol]):
dynamic_dims: list[IndexSymbol] = []
for b in bindings:
for dim in b.kernel_buffer_type.symbolic_shape:
if dim in dynamic_symbols:
dynamic_dims.append(dim)
return dynamic_dims


def isolated_test_call(
mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str
mb: ModuleBuilder,
exe: StreamExecutable,
sig: KernelSignature,
entrypoint: str,
dynamic_symbols: list[IndexSymbol] = [],
):
with InsertionPoint(mb.body_block), Location.unknown():
input_types = [b.as_mlir_type() for b in sig.kernel_buffer_input_bindings]
input_tensors = memref_to_tensor(input_types)
argument_dims = get_dynamic_dims(
sig.kernel_buffer_input_bindings, dynamic_symbols
)
input_tensors += [IndexType.get() for _ in argument_dims]

output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings]
output_tensors = memref_to_tensor(output_types)
result_dims = get_dynamic_dims(
sig.kernel_buffer_output_bindings, dynamic_symbols
)

ftype = FunctionType.get(input_tensors, output_tensors)
func_op = func_d.FuncOp("isolated_benchmark", ftype)
arg_locs = [
(Location.name(b.name) if b.name is not None else Location.unknown())
for b in sig.kernel_buffer_input_bindings
for b in sig.kernel_buffer_input_bindings + sig.dynamic_dim_bindings
]
entry_block = func_op.add_entry_block(arg_locs)
offset = len(sig.kernel_buffer_input_bindings)
dynamic_argument_map = {
k: v for k, v in zip(dynamic_symbols, entry_block.arguments[offset:])
}
with InsertionPoint(entry_block):
assert isinstance(entry_block, Block)
# Create a flow.dispatch op to the kernel
dispatch = SymbolRefAttr.get([exe.sym_name.value, entrypoint])
entrypoints = ArrayAttr.get([dispatch])

out = flow_d.DispatchOp(
output_tensors, [], entrypoints, entry_block.arguments, [], []
output_tensors,
[dynamic_argument_map[dim] for dim in dynamic_symbols],
entrypoints,
entry_block.arguments,
[dynamic_argument_map[dim] for dim in argument_dims],
[dynamic_argument_map[dim] for dim in result_dims],
)

func_d.ReturnOp(out)
16 changes: 16 additions & 0 deletions shark_turbine/kernel/compiler/kernel_codegen.py
@@ -177,6 +177,22 @@ def kernel_buffer_temporary_bindings(self) -> list[BindingDesc]:
and b.kernel_buffer_type.usage == KernelBufferUsage.TEMPORARY
]

@property
def dynamic_dim_bindings(self) -> list[BindingDesc]:
"""Gets all dynamic dimension bindings."""
return [b for b in self.bindings if b.binding_type == BindingType.SYMBOL_VALUE]

def add_from_dynamic_symbols(self, dynamic_symbols: list[IndexSymbol]):
for symbol in dynamic_symbols:
self.bindings.append(
BindingDesc(
("symbol", symbol),
BindingType.SYMBOL_VALUE,
name=symbol.name,
symbol_type=symbol,
)
)

def add_from_graph_placeholders(self, graph: fx.Graph):
# Extract all placeholder nodes.
placeholder_nodes = filter_fx_graph(graph, is_placeholder)
(remaining changed files not shown)
