diff --git a/benches/network_benches.rs b/benches/network_benches.rs
index e27253d7..fe6909a0 100644
--- a/benches/network_benches.rs
+++ b/benches/network_benches.rs
@@ -69,7 +69,7 @@ mod cuda {
     #[bench]
     #[ignore]
     #[cfg(feature = "cuda")]
-    fn bench_mnsit_forward_1(b: &mut Bencher) {
+    fn bench_mnsit_forward_1(_b: &mut Bencher) {
         let mut cfg = SequentialConfig::default();
         // set up input
         cfg.add_input("in", &vec![1, 30, 30]);
@@ -96,7 +96,7 @@ mod cuda {
             backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));
 
         let _ = timeit_loops!(10, {
-            let inp = SharedTensor::<f32>::new(backend.device(), &vec![1, 30, 30]).unwrap();
+            let inp = SharedTensor::<f32>::new(&[1, 30, 30]);
             let inp_lock = Arc::new(RwLock::new(inp));
 
             network.forward(&[inp_lock]);
@@ -260,7 +260,7 @@ mod cuda {
 
         let func = || {
             let forward_time = timeit_loops!(1, {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
+                let inp = SharedTensor::new(&[128, 3, 112, 112]);
                 let inp_lock = Arc::new(RwLock::new(inp));
 
                 network.forward(&[inp_lock]);
@@ -416,7 +416,7 @@ mod cuda {
             backend.clone(), &LayerConfig::new("network", LayerType::Sequential(cfg)));
 
         let mut func = || {
-            let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 112, 112]).unwrap();
+            let inp = SharedTensor::<f32>::new(&[128, 3, 112, 112]);
             let inp_lock = Arc::new(RwLock::new(inp));
 
             network.forward(&[inp_lock]);
diff --git a/examples/benchmarks.rs b/examples/benchmarks.rs
index 5075a63f..41e83b62 100644
--- a/examples/benchmarks.rs
+++ b/examples/benchmarks.rs
@@ -160,8 +160,7 @@ fn bench_alexnet() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 224, 224]).unwrap();
-
+                let inp = SharedTensor::<f32>::new(&[128, 3, 224, 224]);
                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
             }
@@ -242,8 +241,7 @@ fn bench_overfeat() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![128, 3, 231, 231]).unwrap();
-
+                let inp = SharedTensor::new(&[128, 3, 231, 231]);
                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
             }
@@ -339,7 +337,7 @@ fn bench_vgg_a() {
     let func = || {
         let forward_time = timeit_loops!(1, {
             {
-                let inp = SharedTensor::<f32>::new(backend.device(), &vec![64, 3, 224, 224]).unwrap();
+                let inp = SharedTensor::new(&[64, 3, 224, 224]);
 
                 let inp_lock = Arc::new(RwLock::new(inp));
                 network.forward(&[inp_lock.clone()]);
diff --git a/tests/layer_specs.rs b/tests/layer_specs.rs
index d368a172..94b4bb5e 100644
--- a/tests/layer_specs.rs
+++ b/tests/layer_specs.rs
@@ -66,8 +66,10 @@ mod layer_spec {
         let loaded_weights = loaded_layer.learnable_weights_data();
         let loaded_weight_lock = loaded_weights[0].read().unwrap();
 
-        let original_weight = original_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
-        let loaded_weight = loaded_weight_lock.get(native_backend().device()).unwrap().as_native().unwrap().as_slice::<f32>();
+        let original_weight = original_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();
+        let loaded_weight = loaded_weight_lock.read(native_backend().device())
+            .unwrap().as_native().unwrap().as_slice::<f32>();
 
         assert_eq!(original_weight, loaded_weight);
     }
@@ -131,27 +133,28 @@ mod layer_spec {
         let mut reshape_network = Layer::from_config(cuda_backend.clone(),
             &LayerConfig::new("reshape_model", LayerType::Sequential(reshape_model)));
 
         let input = vec![1f32, 1f32, 2f32];
 
-        let mut normal_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut normal_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut normal_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        let mut reshape_tensor = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
+        let mut reshape_tensor = SharedTensor::<f32>::new(&[3]);
         // let mut reshape_tensor_output = SharedTensor::<f32>::new(native_backend.device(), &(3)).unwrap();
-        write_to_memory(normal_tensor.get_mut(native_backend.device()).unwrap(), &input);
-        write_to_memory(reshape_tensor.get_mut(native_backend.device()).unwrap(), &input);
+        write_to_memory(normal_tensor.write_only(native_backend.device()).unwrap(), &input);
+        write_to_memory(reshape_tensor.write_only(native_backend.device()).unwrap(), &input);
 
         let normal_tensor_output = normal_network.forward(&[Arc::new(RwLock::new(normal_tensor))])[0].clone();
-        let _ = normal_tensor_output.write().unwrap().add_device(native_backend.device());
-        normal_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let normal_tensor_output_native_ = normal_tensor_output.read().unwrap();
-        let normal_tensor_output_native = normal_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], normal_tensor_output_native.as_slice::<f32>());
+        let normal_tensor_output_native = normal_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   normal_tensor_output_native.as_slice::<f32>());
 
         let reshape_tensor_output = reshape_network.forward(&[Arc::new(RwLock::new(reshape_tensor))])[0].clone();
-        let _ = reshape_tensor_output.write().unwrap().add_device(native_backend.device());
-        reshape_tensor_output.write().unwrap().sync(native_backend.device()).unwrap();
         let reshape_tensor_output_native_ = reshape_tensor_output.read().unwrap();
-        let reshape_tensor_output_native = reshape_tensor_output_native_.get(native_backend.device()).unwrap().as_native().unwrap();
-        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32], reshape_tensor_output_native.as_slice::<f32>());
-        assert_eq!(normal_tensor_output_native.as_slice::<f32>(), reshape_tensor_output_native.as_slice::<f32>());
+        let reshape_tensor_output_native = reshape_tensor_output_native_
+            .read(native_backend.device()).unwrap().as_native().unwrap();
+        assert_eq!(&[0.7310585786f32, 0.7310586f32, 0.880797f32],
+                   reshape_tensor_output_native.as_slice::<f32>());
+        assert_eq!(normal_tensor_output_native.as_slice::<f32>(),
+                   reshape_tensor_output_native.as_slice::<f32>());
     }
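Note for reviewers: the diff only shows the call sites, so below is a minimal before/after sketch of the tensor API this patch migrates to. It is not part of the patch; the imports (`collenchyma` as `co`, `write_to_memory` from leaf's util module) and the `Backend::<Native>::default()` setup are assumptions mirroring the test files above, while the `SharedTensor` calls themselves are exactly those appearing in the added lines.

    // Sketch of the migrated SharedTensor usage (assumptions noted above).
    extern crate collenchyma as co;
    extern crate leaf;

    use co::prelude::*;
    use leaf::util::write_to_memory;

    fn main() {
        let backend = Backend::<Native>::default().unwrap();

        // Old API (removed by this diff): construction took a device and returned
        // a Result, and access went through get()/get_mut() with explicit
        // add_device()/sync() calls before reading on another device.
        //
        //     let mut x = SharedTensor::<f32>::new(backend.device(), &vec![3]).unwrap();
        //     write_to_memory(x.get_mut(backend.device()).unwrap(), &[1f32, 1f32, 2f32]);
        //     let s = x.get(backend.device()).unwrap().as_native().unwrap().as_slice::<f32>();

        // New API (added by this diff): construction is device-free and infallible;
        // write_only()/read() take the device and handle allocation and sync.
        let mut x = SharedTensor::<f32>::new(&[3]);
        write_to_memory(x.write_only(backend.device()).unwrap(), &[1f32, 1f32, 2f32]);
        let s = x.read(backend.device()).unwrap().as_native().unwrap().as_slice::<f32>();
        assert_eq!(&[1f32, 1f32, 2f32], s);
    }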