Skip to content

Commit

Permalink
[luci/service] Add test cases for reshape
Browse files Browse the repository at this point in the history
This commit adds test cases for reshape operation.

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 9, 2024
1 parent 8a7d433 commit 436119f
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,119 @@ TEST(CloneNodeTest, clone_Reshape)
ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0));
ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
}

TEST(ShapeRuleTest, reshape_by_input_const_static)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_TRUE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = -1;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_TRUE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_input_tensor_undefined_NEG)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::UNDEFINED);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
}

TEST(ShapeRuleTest, reshape_input_shape_undefined_NEG)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = 6;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::UNDEFINED);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_FALSE(shape_inf_rule.infer(node_reshape, output_shape));
}

0 comments on commit 436119f

Please sign in to comment.