From ac6288a6331df2a5369a6343d88bc696abe9f417 Mon Sep 17 00:00:00 2001 From: maekawatoshiki Date: Sat, 18 Nov 2023 18:34:46 +0900 Subject: [PATCH] Add test --- crates/core/src/onnx/load.rs | 6 +- crates/session-cpu/Cargo.toml | 1 + crates/session-cpu/tests/binop.rs | 95 +++++++++++++++++++++++++++++++ crates/session/tests/ort.rs | 4 +- 4 files changed, 102 insertions(+), 4 deletions(-) create mode 100644 crates/session-cpu/tests/binop.rs diff --git a/crates/core/src/onnx/load.rs b/crates/core/src/onnx/load.rs index 98479e8..e768699 100644 --- a/crates/core/src/onnx/load.rs +++ b/crates/core/src/onnx/load.rs @@ -70,8 +70,10 @@ pub fn load_onnx_from_model_proto(model_proto: ModelProto) -> Result opset_version = Some(opset_import.version()), - "" => return Err(ModelLoadError::DuplicateOpset), + "" | "ai.onnx" if opset_version.is_none() => { + opset_version = Some(opset_import.version()) + } + "" | "ai.onnx" => return Err(ModelLoadError::DuplicateOpset), domain => { return Err(ModelLoadError::Todo( format!("Custom domain ('{domain}') not supported yet").into(), diff --git a/crates/session-cpu/Cargo.toml b/crates/session-cpu/Cargo.toml index 4e7ff78..a40ed39 100644 --- a/crates/session-cpu/Cargo.toml +++ b/crates/session-cpu/Cargo.toml @@ -30,3 +30,4 @@ env_logger = "0.9.0" image = "0.24.2" structopt = "0.3.26" criterion = "0.4.0" +ort = "1.16.2" diff --git a/crates/session-cpu/tests/binop.rs b/crates/session-cpu/tests/binop.rs new file mode 100644 index 0000000..25e1597 --- /dev/null +++ b/crates/session-cpu/tests/binop.rs @@ -0,0 +1,95 @@ +use altius_core::{ + graph::Graph, + model::Model, + node::Node, + onnx::{load_onnx, save::save_onnx}, + op::Op, + tensor::{Tensor, TensorElemType, TypedFixedShape}, +}; +use altius_session_cpu::CPUSessionBuilder; +use ndarray::CowArray; +use ort::{Environment, ExecutionProvider, SessionBuilder, Value}; + +#[test] +fn cpu_add() { + env_logger::init(); + + let path = tempfile::NamedTempFile::new().unwrap(); + let path = path.path(); + export_onnx(path.to_str().unwrap()); + + let x_ = vec![1.0f32, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0]; + let y_ = vec![2.0f32, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]; + + let env = Environment::builder() + .with_execution_providers(&[ExecutionProvider::CPU(Default::default())]) + .build() + .unwrap() + .into_arc(); + let sess = SessionBuilder::new(&env) + .unwrap() + .with_model_from_file(path) + .unwrap(); + let x = CowArray::from(&x_) + .into_shape((4, 2)) + .unwrap() + .into_dimensionality() + .unwrap(); + let y = CowArray::from(&y_) + .into_shape((4, 2)) + .unwrap() + .into_dimensionality() + .unwrap(); + let x = Value::from_array(sess.allocator(), &x).unwrap(); + let y = Value::from_array(sess.allocator(), &y).unwrap(); + let z = &sess.run(vec![x, y]).unwrap()[0]; + let z = z.try_extract::().unwrap(); + let ort_z = z.view(); + assert!(ort_z.shape() == &[4, 2]); + + let sess = CPUSessionBuilder::new(load_onnx(path).unwrap()) + .build() + .unwrap(); + let x = Tensor::new(vec![4, 2].into(), x_); + let y = Tensor::new(vec![4, 2].into(), y_); + let altius_z = &sess.run(vec![x, y]).unwrap()[0]; + assert!(altius_z.dims().as_slice() == &[4, 2]); + + ort_z + .as_slice() + .unwrap() + .iter() + .zip(altius_z.data::()) + .for_each(|(ort, altius)| { + assert!((ort - altius).abs() < 1e-6); + }); +} + +#[cfg(test)] +fn export_onnx(path: &str) { + // TODO: We need a better interface for building models. + let mut model = Model { + graph: Graph::default(), + opset_version: 12, + }; + let x = model.graph.values.new_val_named_and_shaped( + "x", + TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), + ); + let y = model.graph.values.new_val_named_and_shaped( + "y", + TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), + ); + let z = model.graph.values.new_val_named_and_shaped( + "z", + TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32), + ); + model + .graph + .nodes + .alloc(Node::new(Op::Add).with_ins(vec![x, y]).with_outs(vec![z])); + model.graph.inputs.push(x); + model.graph.inputs.push(y); + model.graph.outputs.push(z); + save_onnx(&model, path).unwrap(); +} diff --git a/crates/session/tests/ort.rs b/crates/session/tests/ort.rs index 61d021f..d80eaca 100644 --- a/crates/session/tests/ort.rs +++ b/crates/session/tests/ort.rs @@ -11,8 +11,8 @@ use ort::{Environment, ExecutionProvider, SessionBuilder, Value}; #[test] fn ort_add() { - let path = "/tmp/add.onnx"; - export_onnx(path); + let path = tempfile::NamedTempFile::new().unwrap(); + export_onnx(path.path().to_str().unwrap()); let env = Environment::builder() .with_execution_providers(&[ExecutionProvider::CPU(Default::default())])