
Commit

fix matmul special case (PaddlePaddle#851)
gglin001 authored Jun 30, 2022
1 parent 10c7406 commit d2da113
Showing 3 changed files with 107 additions and 36 deletions.
@@ -114,14 +114,29 @@ Node *matmul_handler(Graph *graph, Node *node) {
   auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X"));
   auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y"));
   auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha"));
-  auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
-  auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();
+  Node *x_node = GetInputVarNode("X", node);
+  Node *y_node = GetInputVarNode("Y", node);
+  int x_rank = x_node->Var()->GetShape().size();
+  int y_rank = y_node->Var()->GetShape().size();
 
+  auto gen_perm = [](const int rank) -> std::vector<int64_t> {
+    std::vector<int64_t> perm;
+    if (rank == 1) {
+      perm = std::vector<int64_t>{0};
+    } else if (rank == 2) {
+      perm = std::vector<int64_t>{1, 0};
+    } else if (rank == 3) {
+      perm = std::vector<int64_t>{0, 2, 1};
+    } else if (rank == 4) {
+      perm = std::vector<int64_t>{0, 1, 3, 2};
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "op matmul with input rank == %d", rank));
+    }
+    return perm;
+  };
+
-  int x_rank = x_shape.size();
-  std::vector<int64_t> perm;
-  if (x_rank == 1) {
-    perm = std::vector<int64_t>{0};
-  } else if (x_rank == 2) {
+  if (x_rank == 2) {
     if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) {
       return CreateBaseOp(
           graph,
@@ -137,18 +152,10 @@ Node *matmul_handler(Graph *graph, Node *node) {
           transpose_x,
           transpose_y,
           alpha);
-  } else if (x_rank == 3) {
-    perm = std::vector<int64_t>{0, 2, 1};
-  } else if (x_rank == 4) {
-    perm = std::vector<int64_t>{0, 1, 3, 2};
-  } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "op matmul with input rank == %d", x_rank));
   }
 
-  Node *x_node = GetInputVarNode("X", node);
-  Node *y_node = GetInputVarNode("Y", node);
   if (transpose_x) {
+    auto perm = gen_perm(x_rank);
     x_node = CreateBaseOp(graph,
                           node,
                           "popart_transpose",
@@ -158,6 +165,7 @@ Node *matmul_handler(Graph *graph, Node *node) {
     x_node = x_node->outputs[0];
   }
   if (transpose_y) {
+    auto perm = gen_perm(y_rank);
     y_node = CreateBaseOp(graph,
                           node,
                           "popart_transpose",
@@ -368,28 +376,30 @@ Node *matmul_v2_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x"));
   auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y"));
-  auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
-  auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();
-
-  std::vector<int64_t> perm;
-  int x_rank = x_shape.size();
-  if (x_rank == 1) {
-    perm = std::vector<int64_t>{0};
-  } else if (x_rank == 2) {
-    perm = std::vector<int64_t>{1, 0};
-  } else if (x_rank == 3) {
-    perm = std::vector<int64_t>{0, 2, 1};
-  } else if (x_rank == 4) {
-    perm = std::vector<int64_t>{0, 1, 3, 2};
-  } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "op matmul with input rank == %d", x_rank));
-  }
-
   Node *x_node = GetInputVarNode("X", node);
   Node *y_node = GetInputVarNode("Y", node);
+  int x_rank = x_node->Var()->GetShape().size();
+  int y_rank = y_node->Var()->GetShape().size();
 
+  auto gen_perm = [](const int rank) -> std::vector<int64_t> {
+    std::vector<int64_t> perm;
+    if (rank == 1) {
+      perm = std::vector<int64_t>{0};
+    } else if (rank == 2) {
+      perm = std::vector<int64_t>{1, 0};
+    } else if (rank == 3) {
+      perm = std::vector<int64_t>{0, 2, 1};
+    } else if (rank == 4) {
+      perm = std::vector<int64_t>{0, 1, 3, 2};
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "op matmul with input rank == %d", rank));
+    }
+    return perm;
+  };
+
   if (transpose_x) {
+    auto perm = gen_perm(x_rank);
     x_node = CreateBaseOp(graph,
                           node,
                           "popart_transpose",
@@ -399,6 +409,7 @@ Node *matmul_v2_handler(Graph *graph, Node *node) {
     x_node = x_node->outputs[0];
   }
   if (transpose_y) {
+    auto perm = gen_perm(y_rank);
     y_node = CreateBaseOp(graph,
                           node,
                           "popart_transpose",
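The fix in both handlers: the old code derived a single transpose permutation from X's rank and implicitly applied it to Y as well, which cannot work when X and Y have different ranks (e.g. a 3-D X against a 2-D Y). The new gen_perm lambda swaps the last two axes for whichever rank it is given, and each operand now gets a perm computed from its own rank. A minimal NumPy sketch of the intended semantics, assuming popart_transpose permutes axes the way np.transpose does (the Python gen_perm here is a hypothetical stand-in for the C++ lambda):

import numpy as np

def gen_perm(rank):
    # Swap the last two axes; a rank-1 tensor keeps the identity perm.
    return {1: [0], 2: [1, 0], 3: [0, 2, 1], 4: [0, 1, 3, 2]}[rank]

x = np.random.uniform(size=(4, 2, 3))  # rank 3 -> perm [0, 2, 1]
y = np.random.uniform(size=(2, 3))     # rank 2 -> perm [1, 0], not x's perm
out = np.matmul(x, np.transpose(y, gen_perm(y.ndim)))
print(out.shape)  # (4, 2, 2)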
34 changes: 32 additions & 2 deletions python/paddle/fluid/tests/unittests/ipu/test_matmul_op_ipu.py
@@ -157,8 +157,8 @@ def set_op_attrs(self):
 class TestCase7(TestBase):
 
     def set_data_feed(self):
-        x = np.random.uniform(size=[1, 12, 128, 64])
-        y = np.random.uniform(size=[1, 12, 128, 64])
+        x = np.random.uniform(size=[1, 3, 4, 5])
+        y = np.random.uniform(size=[1, 3, 4, 5])
 
         self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
         self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
@@ -205,5 +205,35 @@ def set_data_feed(self):
         self.feed_fp16 = {"x": x.astype(np.float16), "y": x.astype(np.float16)}
 
 
+class TestCase10(TestBase):
+
+    def set_op_attrs(self):
+        self.attrs = {
+            "transpose_y": True,
+        }
+
+    def set_data_feed(self):
+        x = np.random.uniform(size=[4, 2, 3])
+        y = np.random.uniform(size=[2, 3])
+
+        self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
+        self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
+
+
+class TestCase11(TestBase):
+
+    def set_op_attrs(self):
+        self.attrs = {
+            "transpose_x": True,
+        }
+
+    def set_data_feed(self):
+        x = np.random.uniform(size=[4, 3, 2])
+        y = np.random.uniform(size=[3, 2])
+
+        self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
+        self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
+
+
 if __name__ == "__main__":
     unittest.main()
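As a sanity check on the shapes the two new cases exercise, with NumPy as the reference and assuming the op broadcasts the 2-D operand against the batched 3-D one:

import numpy as np

# TestCase10: transpose_y=True
x = np.random.uniform(size=[4, 2, 3])
y = np.random.uniform(size=[2, 3])
assert np.matmul(x, y.T).shape == (4, 2, 2)

# TestCase11: transpose_x=True
x = np.random.uniform(size=[4, 3, 2])
y = np.random.uniform(size=[3, 2])
assert np.matmul(x.transpose(0, 2, 1), y).shape == (4, 2, 2)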
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_matmul_v2_op_ipu.py
@@ -150,5 +150,35 @@ def set_data_feed(self):
         }
 
 
+class TestCase9(TestBase):
+
+    def set_op_attrs(self):
+        self.attrs = {
+            "transpose_y": True,
+        }
+
+    def set_data_feed(self):
+        x = np.random.uniform(size=[4, 2, 3])
+        y = np.random.uniform(size=[2, 3])
+
+        self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
+        self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
+
+
+class TestCase10(TestBase):
+
+    def set_op_attrs(self):
+        self.attrs = {
+            "transpose_x": True,
+        }
+
+    def set_data_feed(self):
+        x = np.random.uniform(size=[4, 3, 2])
+        y = np.random.uniform(size=[3, 2])
+
+        self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
+        self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
+
+
 if __name__ == "__main__":
     unittest.main()
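The matmul_v2 suite mirrors the cases added above; since the v2 handler reads only trans_x/trans_y (there is no alpha attribute), the test attrs set only the transpose flags. Each file ends with unittest.main(), so it runs standalone (e.g. python test_matmul_v2_op_ipu.py), presumably on an IPU or the IPU model device.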
