Skip to content

Commit

Permalink
Remove "static inputs" for reduction ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek Kulkarni committed Jul 21, 2020
1 parent a493480 commit 9383b70
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 60 deletions.
73 changes: 21 additions & 52 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2400,45 +2400,8 @@ static Status TranslateNonMaxSuppressionV4Op(
return Status::OK();
}

static Status TranslateReduceOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map,
std::function<std::shared_ptr<ng::Node>(
std::shared_ptr<ng::Node>, std::shared_ptr<ng::Node>, const bool)>
create_ng_node) {
shared_ptr<ng::Node> ng_input;
TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, &ng_input));
bool tf_keep_dims;
if (GetNodeAttr(op->attrs(), "keep_dims", &tf_keep_dims) != Status::OK()) {
tf_keep_dims = false;
}

std::vector<int64> axes;
TF_RETURN_IF_ERROR(GetStaticInputVector(op, 1, static_input_map, &axes));

ng::Shape input_shape = ng_input->get_shape();
size_t input_rank = input_shape.size();

TF_RETURN_IF_ERROR(CheckAxisDimInRange(axes, input_rank));

std::vector<size_t> ng_reduction_axes_vect(axes.size());
std::transform(
axes.begin(), axes.end(), ng_reduction_axes_vect.begin(),
[input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); });
auto ng_reduction_axes = ConstructNgNode<ng::opset3::Constant>(
op->name(), ng::element::i64, ng::Shape{ng_reduction_axes_vect.size()},
ng_reduction_axes_vect);

std::shared_ptr<ng::Node> ng_node =
create_ng_node(ng_input, ng_reduction_axes, tf_keep_dims);
Builder::SetTracingInfo(op->name(), ng_node);

SaveNgOp(ng_op_map, op->name(), ng_node);
return Status::OK();
}

template <typename T>
static Status TranslateDirectReduceOp(
static Status TranslateReduceOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
// ensure its either an arithmetic or a logical reduction
Expand All @@ -2448,13 +2411,19 @@ static Status TranslateDirectReduceOp(
"Expected node to be either a valid logical or arithmetic reduction "
"type");
}
return TranslateReduceOp(
op, static_input_map, ng_op_map,
[&op](std::shared_ptr<ng::Node> ng_input,
std::shared_ptr<ng::Node> ng_reduction_axes, const bool keep_dims) {
return ConstructNgNode<T>(op->name(), ng_input, ng_reduction_axes,
keep_dims);
});

shared_ptr<ng::Node> ng_input, ng_reduction_indices;
TF_RETURN_IF_ERROR(
GetInputNodes(ng_op_map, op, &ng_input, &ng_reduction_indices));
bool keep_dims;
if (GetNodeAttr(op->attrs(), "keep_dims", &keep_dims) != Status::OK()) {
keep_dims = false;
}

std::shared_ptr<ng::Node> ng_node =
ConstructNgNode<T>(op->name(), ng_input, ng_reduction_indices, keep_dims);
SaveNgOp(ng_op_map, op->name(), ng_node);
return Status::OK();
}

static Status TranslateOneHotOp(
Expand Down Expand Up @@ -3908,8 +3877,8 @@ const static std::map<
{"Add", TranslateBinaryOp<ngraph::opset3::Add>},
{"AddN", TranslateAddNOp},
{"AddV2", TranslateBinaryOp<ngraph::opset3::Add>},
{"Any", TranslateDirectReduceOp<ng::opset3::ReduceLogicalOr>},
{"All", TranslateDirectReduceOp<ng::opset3::ReduceLogicalAnd>},
{"Any", TranslateReduceOp<ng::opset3::ReduceLogicalOr>},
{"All", TranslateReduceOp<ng::opset3::ReduceLogicalAnd>},
{"ArgMax", TranslateArgMinMaxOp<ng::op::ArgMax>},
{"ArgMin", TranslateArgMinMaxOp<ng::op::ArgMin>},
{"Asin", TranslateUnaryOp<ngraph::opset3::Asin>},
Expand Down Expand Up @@ -3961,13 +3930,13 @@ const static std::map<
{"LogicalNot", TranslateUnaryOp<ngraph::opset3::LogicalNot>},
{"LogicalOr", TranslateBinaryOp<ngraph::opset3::LogicalOr>},
{"MatMul", TranslateMatMulOp},
{"Max", TranslateDirectReduceOp<ng::opset3::ReduceMax>},
{"Max", TranslateReduceOp<ng::opset3::ReduceMax>},
{"Maximum", TranslateBinaryOp<ngraph::opset3::Maximum>},
{"MaxPool", TranslateMaxPoolOp},
{"MaxPool3D", TranslateMaxPool3DOp},
{"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op},
{"Mean", TranslateDirectReduceOp<ng::opset3::ReduceMean>},
{"Min", TranslateDirectReduceOp<ng::opset3::ReduceMin>},
{"Mean", TranslateReduceOp<ng::opset3::ReduceMean>},
{"Min", TranslateReduceOp<ng::opset3::ReduceMin>},
{"Minimum", TranslateBinaryOp<ngraph::opset3::Minimum>},
{"MirrorPad", TranslatePadOp},
{"Mul", TranslateBinaryOp<ngraph::opset3::Multiply>},
Expand All @@ -3985,7 +3954,7 @@ const static std::map<
{"Pow", TranslateBinaryOp<ngraph::opset3::Power>},
// PreventGradient is just Identity in data-flow terms, so reuse that.
{"PreventGradient", TranslateIdentityOp},
{"Prod", TranslateDirectReduceOp<ng::opset3::ReduceProd>},
{"Prod", TranslateReduceOp<ng::opset3::ReduceProd>},
{"QuantizeAndDequantizeV2", TranslateQuantizeAndDequantizeV2Op},
{"QuantizedAvgPool", TranslateQuantizedAvgPoolOp},
{"QuantizedConcat", TranslateQuantizedConcatOp},
Expand Down Expand Up @@ -4029,7 +3998,7 @@ const static std::map<
{"Squeeze", TranslateSqueezeOp},
{"StridedSlice", TranslateStridedSliceOp},
{"Sub", TranslateBinaryOp<ngraph::opset3::Subtract>},
{"Sum", TranslateDirectReduceOp<ng::opset3::ReduceSum>},
{"Sum", TranslateReduceOp<ng::opset3::ReduceSum>},
{"Tan", TranslateUnaryOp<ngraph::opset3::Tan>},
{"Tanh", TranslateUnaryOp<ngraph::opset3::Tanh>},
{"Tile", TranslateTileOp},
Expand Down
8 changes: 0 additions & 8 deletions ngraph_bridge/ngraph_mark_for_clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,18 @@ const std::map<std::string, SetAttributesFunction>& GetAttributeSetters() {

if (!initialized) {
// Set Additional Attributes (if any)
set_attributes_map["Any"] = SetStaticInputs({1});
set_attributes_map["All"] = SetStaticInputs({1});
set_attributes_map["ArgMax"] = SetStaticInputs({1});
set_attributes_map["ArgMin"] = SetStaticInputs({1});
set_attributes_map["ConcatV2"] = SetStaticInputs({-1});
set_attributes_map["Conv2DBackpropInput"] = SetStaticInputs({0});
set_attributes_map["ExpandDims"] = SetStaticInputs({1});
set_attributes_map["Fill"] = SetStaticInputs({0});
set_attributes_map["GatherV2"] = SetStaticInputs({2});
set_attributes_map["Max"] = SetStaticInputs({1});
set_attributes_map["Mean"] = SetStaticInputs({1});
set_attributes_map["Min"] = SetStaticInputs({1});
set_attributes_map["MirrorPad"] = SetStaticInputs({1});
set_attributes_map["NonMaxSuppressionV4"] = SetStaticInputs({2, 3, 4});
set_attributes_map["OneHot"] = SetStaticInputs({1});
set_attributes_map["Pad"] = SetStaticInputs({1});
set_attributes_map["PadV2"] = SetStaticInputs({1, 2});
set_attributes_map["Prod"] = SetStaticInputs({1});

set_attributes_map["QuantizeAndDequantizeV2"] = SetStaticInputs({1, 2});
set_attributes_map["QuantizedConcat"] = [](Node* n) {
SetStaticInputs(n, {0}); // the axis
Expand Down Expand Up @@ -242,7 +235,6 @@ const std::map<std::string, SetAttributesFunction>& GetAttributeSetters() {
set_attributes_map["Split"] = SetStaticInputs({0});
set_attributes_map["SplitV"] = SetStaticInputs({1, 2});
set_attributes_map["StridedSlice"] = SetStaticInputs({1, 2, 3});
set_attributes_map["Sum"] = SetStaticInputs({1});
set_attributes_map["TopKV2"] = SetStaticInputs({1});
set_attributes_map["Tile"] = SetStaticInputs({1});
set_attributes_map["Transpose"] = SetStaticInputs({1});
Expand Down

0 comments on commit 9383b70

Please sign in to comment.