Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes aggressive SliceView removal and removes unnecessary squeezing in MatMult library nodes #1884

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions dace/libraries/blas/nodes/batched_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExpandBatchedMatMulPure(ExpandTransformation):
@staticmethod
def make_sdfg(node, parent_state, parent_sdfg):
# Get metadata from parent SDFG
((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b),
((edge_a, outer_array_a, shape_a, strides_a, _, _), (edge_b, outer_array_b, shape_b, strides_b, _, _),
cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)
outedge = parent_state.out_edges(node)[0]
cdesc = parent_sdfg.arrays[outedge.data.data]
Expand Down Expand Up @@ -52,7 +52,7 @@ def make_sdfg(node, parent_state, parent_sdfg):

_, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=storage)
_, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=storage)
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=storage)
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-3], storage=storage)

# Add an initialization state
init_state = sdfg.add_state()
Expand Down Expand Up @@ -91,7 +91,7 @@ class ExpandBatchedMatMulMKL(ExpandTransformation):
@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
(_, adesc, ashape, astrides, _, _), (_, bdesc, bshape, bstrides, _, _), _ = _get_matmul_operands(node, state, sdfg)
cdesc: dt.Array = sdfg.arrays[state.out_edges(node)[0].data.data]
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
dtype = cdesc.dtype.base_type
Expand Down Expand Up @@ -160,7 +160,7 @@ class ExpandBatchedMatMulOpenBLAS(ExpandTransformation):
@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
(_, adesc, ashape, astrides, _, _), (_, bdesc, bshape, bstrides, _, _), _ = _get_matmul_operands(node, state, sdfg)
cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
dtype = cdesc.dtype.base_type
Expand Down Expand Up @@ -446,10 +446,7 @@ def validate(self, sdfg, state):
f'may not match', UserWarning)
elif not res:
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
out_subset = dc(out_memlet.subset)
out_subset.squeeze()
size2 = out_subset.size()
if len(size2) != 3:
if len(out_memlet.subset) != 3:
raise ValueError("batched matrix-matrix product only supported on matrices")


Expand Down
33 changes: 12 additions & 21 deletions dace/libraries/blas/nodes/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from dace.frontend.common import op_repository as oprepo
import dace.sdfg.nodes
from dace.transformation.transformation import ExpandTransformation
from dace.libraries.blas.blas_helpers import (to_blastype, get_gemm_opts, check_access, dtype_to_cudadatatype,
to_cublas_computetype)
from dace.libraries.blas.blas_helpers import to_blastype, check_access, dtype_to_cudadatatype, to_cublas_computetype
from dace.libraries.blas.nodes.matmul import (_get_matmul_operands, _get_codegen_gemm_opts)
from .. import environments
import numpy as np
Expand Down Expand Up @@ -47,7 +46,7 @@ class ExpandGemmPure(ExpandTransformation):
def make_sdfg(node, parent_state, parent_sdfg):
sdfg = dace.SDFG(node.label + "_sdfg")

((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b),
((edge_a, outer_array_a, shape_a, strides_a, _, _), (edge_b, outer_array_b, shape_b, strides_b, _, _),
cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)

dtype_a = outer_array_a.dtype.type
Expand Down Expand Up @@ -79,7 +78,7 @@ def make_sdfg(node, parent_state, parent_sdfg):

_, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=outer_array_a.storage)
_, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage)
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage)
_, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-3], storage=cdata[1].storage)

if equal_valued(1, node.alpha):
mul_program = "__out = __a * __b"
Expand All @@ -93,7 +92,7 @@ def make_sdfg(node, parent_state, parent_sdfg):
state = sdfg.add_state_after(init_state, node.label + "_state")

if '_cin' in node.in_connectors:
sdfg.add_array("_cin", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage)
sdfg.add_array("_cin", shape_c, dtype_c, strides=cdata[-3], storage=cdata[1].storage)

mul_out, mul_out_array = "_c", array_c
output_nodes = None
Expand Down Expand Up @@ -159,7 +158,7 @@ class ExpandGemmOpenBLAS(ExpandTransformation):
@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
(_, adesc, _, _, _, _), (_, bdesc, _, _, _, _), _ = _get_matmul_operands(node, state, sdfg)
dtype = adesc.dtype.base_type
func = to_blastype(dtype.type).lower() + 'gemm'
alpha = f'{dtype.ctype}({node.alpha})'
Expand Down Expand Up @@ -458,7 +457,7 @@ class ExpandGemmPBLAS(ExpandTransformation):
@staticmethod
def expansion(node, state, sdfg):
node.validate(sdfg, state)
(_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
(_, adesc, ashape, _, _, _), (_, bdesc, bshape, _, _, _), _ = _get_matmul_operands(node, state, sdfg)
dtype = adesc.dtype.base_type

if not equal_valued(0, node.beta):
Expand Down Expand Up @@ -513,8 +512,8 @@ def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None):
:return:
"""

((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b),
(edge_c, outer_array_c, shape_c, strides_c)) = _get_matmul_operands(node, parent_state, parent_sdfg)
((edge_a, outer_array_a, shape_a, strides_a, _, _), (edge_b, outer_array_b, shape_b, strides_b, _, _),
(edge_c, outer_array_c, shape_c, strides_c, _, _)) = _get_matmul_operands(node, parent_state, parent_sdfg)

dtype_a = outer_array_a.dtype.type
dtype_b = outer_array_b.dtype.type
Expand Down Expand Up @@ -1013,17 +1012,11 @@ def validate(self, sdfg, state):
size2 = None
for _, _, _, dst_conn, memlet in state.in_edges(self):
if dst_conn == '_a':
subset = dc(memlet.subset)
subset.squeeze()
size0 = subset.size()
size0 = memlet.subset.size()
if dst_conn == '_b':
subset = dc(memlet.subset)
subset.squeeze()
size1 = subset.size()
size1 = memlet.subset.size()
if dst_conn == '_c':
subset = dc(memlet.subset)
subset.squeeze()
size2 = subset.size()
size2 = memlet.subset.size()

if self.transA:
size0 = list(reversed(size0))
Expand All @@ -1043,9 +1036,7 @@ def validate(self, sdfg, state):
UserWarning)
elif not res:
raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
out_subset = dc(out_memlet.subset)
out_subset.squeeze()
size3 = out_subset.size()
size3 = out_memlet.subset.size()
if size2 is not None:
res = [equal(s0, s1) for s0, s1 in zip(size2, size3)]
fail = any([r is False for r in res])
Expand Down
46 changes: 10 additions & 36 deletions dace/libraries/blas/nodes/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dace.libraries.blas import blas_helpers
from dace.frontend.common import op_repository as oprepo
from dace.libraries.blas import environments
from dace.sdfg import nodes, utils as sdutils
import numpy as np
import warnings

Expand All @@ -24,13 +23,8 @@ class ExpandGemvPure(ExpandTransformation):
def expansion(node, parent_state, parent_sdfg, **kwargs):
node.validate(parent_sdfg, parent_state)
sdfg = dace.SDFG(node.label + "_sdfg")
((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x),
(edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node,
parent_state,
parent_sdfg,
name_lhs="_A",
name_rhs="_x",
name_out="_y")
((edge_a, outer_array_a, _, _, shape_a, strides_a), (edge_x, outer_array_x, _, _, shape_x, strides_x),
(edge_y, outer_array_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, parent_state, parent_sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype_x = outer_array_x.dtype.type
dtype_y = outer_array_y.dtype.type
Expand Down Expand Up @@ -154,13 +148,8 @@ def expansion(node, parent_state, parent_sdfg, tile_size_x=None, tile_size_y=Non
beta = node.beta

# Get input/output data (the method also considers the presence of view nodes)
((edge_a, desc_a, shape_a, strides_a), (edge_x, desc_x, shape_x, strides_x),
(edge_y, desc_y, shape_y, strides_y)) = _get_matmul_operands(node,
parent_state,
parent_sdfg,
name_lhs="_A",
name_rhs="_x",
name_out="_y")
((edge_a, desc_a, _, _, shape_a, strides_a), (edge_x, desc_x, _, _, shape_x, strides_x),
(edge_y, desc_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, parent_state, parent_sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")

# Create local versions of input/output data nodes
_, desc_a = sdfg.add_array("_A",
Expand Down Expand Up @@ -618,13 +607,8 @@ class ExpandGemvCuBLAS(ExpandTransformation):
def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
node.validate(sdfg, state)

((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x),
(edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node,
state,
sdfg,
name_lhs="_A",
name_rhs="_x",
name_out="_y")
((edge_a, outer_array_a, _, _, shape_a, strides_a), (edge_x, outer_array_x, _, _, shape_x, strides_x),
(edge_y, outer_array_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype = outer_array_x.dtype.base_type
veclen = outer_array_x.dtype.veclen
Expand Down Expand Up @@ -720,13 +704,8 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):

node.validate(sdfg, state)

((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x),
(edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node,
state,
sdfg,
name_lhs="_A",
name_rhs="_x",
name_out="_y")
((edge_a, outer_array_a, _, _, shape_a, strides_a), (edge_x, outer_array_x, _, _, shape_x, strides_x),
(edge_y, outer_array_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype = outer_array_x.dtype.base_type
veclen = outer_array_x.dtype.veclen
Expand Down Expand Up @@ -806,13 +785,8 @@ class ExpandGemvPBLAS(ExpandTransformation):
@staticmethod
def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
node.validate(sdfg, state)
((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x),
(edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node,
state,
sdfg,
name_lhs="_A",
name_rhs="_x",
name_out="_y")
((edge_a, outer_array_a, _, _, shape_a, strides_a), (edge_x, outer_array_x, _, _, shape_x, strides_x),
(edge_y, outer_array_y, _, _, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
dtype_a = outer_array_a.dtype.type
dtype = outer_array_x.dtype.base_type
veclen = outer_array_x.dtype.veclen
Expand Down
7 changes: 1 addition & 6 deletions dace/libraries/blas/nodes/ger.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from dace.properties import Property, SymbolicProperty
from dace.properties import SymbolicProperty
from dace.transformation.transformation import ExpandTransformation
from dace.frontend.common import op_repository as oprepo
from dace.sdfg.nodes import LibraryNode
from dace.libraries.blas.nodes.matmul import _get_matmul_operands
import dace.library as library
from dace.sdfg import SDFG, SDFGState, nodes
from dace import data as dt, memlet as mm, subsets as sbs
import dace
import copy
import numpy as np

import dace.library
import dace.properties
import dace.sdfg.nodes

from dace import dtypes
from dace.memlet import Memlet


@library.expansion
class ExpandGerPure(ExpandTransformation):
Expand Down
30 changes: 17 additions & 13 deletions dace/libraries/blas/nodes/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,27 @@ def _get_matmul_operands(node, state, sdfg, name_lhs="_a", name_rhs="_b", name_o
res_rhs = None
for edge in state.all_edges(node):
if edge.dst_conn in [name_lhs, name_rhs]:
subset = dc(edge.data.subset)
squeezed = subset.squeeze()
size = subset.size()
size = edge.data.subset.size()
squeezed = dc(edge.data.subset)
squeezed_dims = squeezed.squeeze()
squeezed_size = squeezed.size()
outer_array = sdfg.data(dace.sdfg.find_input_arraynode(state, edge).data)
strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed]
res = edge, outer_array, size, strides
strides = list(outer_array.strides)
squeezed_strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed_dims]
res = edge, outer_array, size, strides, squeezed_size, squeezed_strides
if edge.dst_conn == name_lhs:
res_lhs = res
else:
res_rhs = res
elif edge.src_conn == name_out:
subset = dc(edge.data.subset)
squeezed = subset.squeeze()
size = subset.size()
size = edge.data.subset.size()
squeezed = dc(edge.data.subset)
squeezed_dims = squeezed.squeeze()
squeezed_size = squeezed.size()
outer_array = sdfg.data(dace.sdfg.find_output_arraynode(state, edge).data)
strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed]
res_out = edge, outer_array, size, strides
strides = list(outer_array.strides)
squeezed_strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed_dims]
res_out = edge, outer_array, size, strides, squeezed_size, squeezed_strides
for res, name in ((res_lhs, name_lhs), (res_rhs, name_rhs), (res_out, name_out)):
if res is None:
raise ValueError("Matrix multiplication connector " "\"{}\" not found.".format(name))
Expand Down Expand Up @@ -85,7 +89,7 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
from dace.codegen.common import sym2cpp
from dace.libraries.blas.blas_helpers import get_gemm_opts

(_, _, ashape, astride), (_, _, bshape, bstride), (_, _, cshape, cstride) = _get_matmul_operands(node, state, sdfg)
(_, _, ashape, astride, _, _), (_, _, bshape, bstride, _, _), (_, _, cshape, cstride, _, _) = _get_matmul_operands(node, state, sdfg)

if getattr(node, 'transA', False):
ashape = list(reversed(ashape))
Expand Down Expand Up @@ -141,8 +145,8 @@ class SpecializeMatMul(dace.transformation.transformation.ExpandTransformation):
@staticmethod
def expansion(node, state, sdfg):
a, b, c = _get_matmul_operands(node, state, sdfg)
size_a = a[2]
size_b = b[2]
size_a = a[4]
size_b = b[4]
if len(size_a) == 2 and len(size_b) == 2:
# Matrix and matrix -> GEMM
from dace.libraries.blas.nodes.gemm import Gemm
Expand Down
Loading
Loading