Skip to content

Commit

Permalink
Fix Cuda convolution tests. Fixes autumnai#45.
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondLovesYou committed Apr 9, 2016
1 parent 080ffe6 commit 4049aff
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/convolution_specs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ mod convolution_spec_cuda {
let d1 = 3;
let k = 6;
let f = 3;
let w2 = (w1 - f + 0) / 1;
let h2 = (h1 - f + 0) / 1;
let w2 = (w1 - f + 0) / 1 + 1;
let h2 = (h1 - f + 0) / 1 + 1;
let mut x = SharedTensor::<T>::new(backend.device(), &(batch, d1, h1, w1)).unwrap();
let mut payload: &mut [T] = &mut ::std::iter::repeat(val).take(x.capacity()).collect::<Vec<T>>();
payload[0] = val2;
Expand Down Expand Up @@ -148,8 +148,15 @@ mod convolution_spec_cuda {
Ok(_) => {
result.sync(native.device()).unwrap();
if let Some(mem) = result.get(native.device()).unwrap().as_native() {
let mut payload: &mut [f64] = &mut ::std::iter::repeat(27f64).take(result.capacity()).collect::<Vec<f64>>();
payload[0] = 28f64;
let mut payload: &mut [f64] = &mut ::std::iter::repeat(27f64)
.take(result.capacity())
.collect::<Vec<f64>>();

let data = mem.as_slice::<f64>();

for i in 0..19 {
payload[i * result.desc()[result.desc().len() - 1] * result.desc()[result.desc().len() - 2]] = 28.0;
}
assert_eq!(payload, mem.as_slice::<f64>());
}
},
Expand Down

0 comments on commit 4049aff

Please sign in to comment.