diff --git a/src/model.rs b/src/model.rs index 76da1a41..73599d5e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -573,11 +573,15 @@ mod tests { let kernel_val = from_data(vec![1, 1, 1, 1], vec![0.5]); let kernel = builder.add_float_constant(&kernel_val); - let add_operator = + // Names of all operator output nodes. + let mut op_outputs = Vec::new(); + + let mut add_operator = |builder: &mut ModelBuilder, name: &str, op: OpType, input_nodes: &[Option]| { let output_name = format!("{}_out", name); let op_output_node = builder.add_value(&output_name); builder.add_operator(name, op, input_nodes, &[op_output_node]); + op_outputs.push(output_name); op_output_node }; @@ -739,51 +743,25 @@ mod tests { let model = Model::load(&buffer).unwrap(); - // Outputs of ops tested with a 4D input (eg. NCHW image). - let outputs = vec![ - "Add_out", - "AveragePool_out", - "BatchNormalization_out", - "Cast_out", - "Clip_out", - "Concat_out", - "ConstantOfShape_out", - "Conv_out", - "ConvTranspose_out", - "Cos_out", - "Div_out", - "Equal_out", - "Erf_out", - "Expand_out", - "Identity_out", - "Gather_out", - "GlobalAveragePool_out", - "LeakyRelu_out", - "Less_out", - "MaxPool_out", - "Mul_out", - "Pad_out", - "Pow_out", - "ReduceMean_out", - "Relu_out", - "Reshape_out", - "Resize_out", - "Shape_out", - "Sigmoid_out", - "Slice_out", - "Softmax_out", - "Sqrt_out", - "Squeeze_out", - "Sin_out", - "Sub_out", - "Tanh_out", - "Transpose_out", - "Unsqueeze_out", - ]; + // Most ops are tested with a 4D input (eg. NCHW image). A few require + // different shapes are tested separately. let input = from_data(vec![1, 1, 3, 3], vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]); + for output in op_outputs { + if [ + "Gemm_out", + "MatMul_out", + "Split_out_1", + "Split_out_2", + "Range_out", + "Where_out", + ] + .contains(&output.as_str()) + { + // This op requires special handling. See below. + continue; + } - for output in outputs { - let output_id = model.find_node(output).unwrap(); + let output_id = model.find_node(&output).unwrap(); let result = model .run( &[(input_node as usize, (&input).into())],