Skip to content

Commit

Permalink
Create the list of output nodes to test automatically
Browse files Browse the repository at this point in the history
Populate the list of output node names for each operator as part of the
`add_operator!` calls, instead of separately. This will avoid mistakes where a
new op is added in this test but is never actually tested.
  • Loading branch information
robertknight committed Jan 3, 2023
1 parent 9ce9b95 commit 0c86683
Showing 1 changed file with 22 additions and 44 deletions.
66 changes: 22 additions & 44 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,15 @@ mod tests {
let kernel_val = from_data(vec![1, 1, 1, 1], vec![0.5]);
let kernel = builder.add_float_constant(&kernel_val);

let add_operator =
// Names of all operator output nodes.
let mut op_outputs = Vec::new();

let mut add_operator =
|builder: &mut ModelBuilder, name: &str, op: OpType, input_nodes: &[Option<u32>]| {
let output_name = format!("{}_out", name);
let op_output_node = builder.add_value(&output_name);
builder.add_operator(name, op, input_nodes, &[op_output_node]);
op_outputs.push(output_name);
op_output_node
};

Expand Down Expand Up @@ -739,51 +743,25 @@ mod tests {

let model = Model::load(&buffer).unwrap();

// Outputs of ops tested with a 4D input (eg. NCHW image).
let outputs = vec![
"Add_out",
"AveragePool_out",
"BatchNormalization_out",
"Cast_out",
"Clip_out",
"Concat_out",
"ConstantOfShape_out",
"Conv_out",
"ConvTranspose_out",
"Cos_out",
"Div_out",
"Equal_out",
"Erf_out",
"Expand_out",
"Identity_out",
"Gather_out",
"GlobalAveragePool_out",
"LeakyRelu_out",
"Less_out",
"MaxPool_out",
"Mul_out",
"Pad_out",
"Pow_out",
"ReduceMean_out",
"Relu_out",
"Reshape_out",
"Resize_out",
"Shape_out",
"Sigmoid_out",
"Slice_out",
"Softmax_out",
"Sqrt_out",
"Squeeze_out",
"Sin_out",
"Sub_out",
"Tanh_out",
"Transpose_out",
"Unsqueeze_out",
];
// Most ops are tested with a 4D input (eg. NCHW image). A few require
// different shapes are tested separately.
let input = from_data(vec![1, 1, 3, 3], vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
for output in op_outputs {
if [
"Gemm_out",
"MatMul_out",
"Split_out_1",
"Split_out_2",
"Range_out",
"Where_out",
]
.contains(&output.as_str())
{
// This op requires special handling. See below.
continue;
}

for output in outputs {
let output_id = model.find_node(output).unwrap();
let output_id = model.find_node(&output).unwrap();
let result = model
.run(
&[(input_node as usize, (&input).into())],
Expand Down

0 comments on commit 0c86683

Please sign in to comment.