Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawatoshiki committed Nov 18, 2023
1 parent 0c03bd5 commit ac6288a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
6 changes: 4 additions & 2 deletions crates/core/src/onnx/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ pub fn load_onnx_from_model_proto(model_proto: ModelProto) -> Result<Model, Mode
let mut opset_version = None;
for opset_import in &model_proto.opset_import {
match opset_import.domain() {
"" if opset_version.is_none() => opset_version = Some(opset_import.version()),
"" => return Err(ModelLoadError::DuplicateOpset),
"" | "ai.onnx" if opset_version.is_none() => {
opset_version = Some(opset_import.version())
}
"" | "ai.onnx" => return Err(ModelLoadError::DuplicateOpset),
domain => {
return Err(ModelLoadError::Todo(
format!("Custom domain ('{domain}') not supported yet").into(),
Expand Down
1 change: 1 addition & 0 deletions crates/session-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ env_logger = "0.9.0"
image = "0.24.2"
structopt = "0.3.26"
criterion = "0.4.0"
ort = "1.16.2"
95 changes: 95 additions & 0 deletions crates/session-cpu/tests/binop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use altius_core::{
graph::Graph,
model::Model,
node::Node,
onnx::{load_onnx, save::save_onnx},
op::Op,
tensor::{Tensor, TensorElemType, TypedFixedShape},
};
use altius_session_cpu::CPUSessionBuilder;
use ndarray::CowArray;
use ort::{Environment, ExecutionProvider, SessionBuilder, Value};

#[test]
fn cpu_add() {
env_logger::init();

let path = tempfile::NamedTempFile::new().unwrap();
let path = path.path();
export_onnx(path.to_str().unwrap());

let x_ = vec![1.0f32, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
let y_ = vec![2.0f32, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];

let env = Environment::builder()
.with_execution_providers(&[ExecutionProvider::CPU(Default::default())])
.build()
.unwrap()
.into_arc();
let sess = SessionBuilder::new(&env)
.unwrap()
.with_model_from_file(path)
.unwrap();
let x = CowArray::from(&x_)
.into_shape((4, 2))
.unwrap()
.into_dimensionality()
.unwrap();
let y = CowArray::from(&y_)
.into_shape((4, 2))
.unwrap()
.into_dimensionality()
.unwrap();
let x = Value::from_array(sess.allocator(), &x).unwrap();
let y = Value::from_array(sess.allocator(), &y).unwrap();
let z = &sess.run(vec![x, y]).unwrap()[0];
let z = z.try_extract::<f32>().unwrap();
let ort_z = z.view();
assert!(ort_z.shape() == &[4, 2]);

let sess = CPUSessionBuilder::new(load_onnx(path).unwrap())
.build()
.unwrap();
let x = Tensor::new(vec![4, 2].into(), x_);
let y = Tensor::new(vec![4, 2].into(), y_);
let altius_z = &sess.run(vec![x, y]).unwrap()[0];
assert!(altius_z.dims().as_slice() == &[4, 2]);

ort_z
.as_slice()
.unwrap()
.iter()
.zip(altius_z.data::<f32>())
.for_each(|(ort, altius)| {
assert!((ort - altius).abs() < 1e-6);
});
}

#[cfg(test)]
fn export_onnx(path: &str) {
// TODO: We need a better interface for building models.
let mut model = Model {
graph: Graph::default(),
opset_version: 12,
};
let x = model.graph.values.new_val_named_and_shaped(
"x",
TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32),
);
let y = model.graph.values.new_val_named_and_shaped(
"y",
TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32),
);
let z = model.graph.values.new_val_named_and_shaped(
"z",
TypedFixedShape::new(vec![4, 2].into(), TensorElemType::F32),
);
model
.graph
.nodes
.alloc(Node::new(Op::Add).with_ins(vec![x, y]).with_outs(vec![z]));
model.graph.inputs.push(x);
model.graph.inputs.push(y);
model.graph.outputs.push(z);
save_onnx(&model, path).unwrap();
}
4 changes: 2 additions & 2 deletions crates/session/tests/ort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use ort::{Environment, ExecutionProvider, SessionBuilder, Value};

#[test]
fn ort_add() {
let path = "/tmp/add.onnx";
export_onnx(path);
let path = tempfile::NamedTempFile::new().unwrap();
export_onnx(path.path().to_str().unwrap());

let env = Environment::builder()
.with_execution_providers(&[ExecutionProvider::CPU(Default::default())])
Expand Down

0 comments on commit ac6288a

Please sign in to comment.