Skip to content

Commit

Permalink
refactor/tensor: convert to the new memory API
Browse files Browse the repository at this point in the history
Fix SharedTensor::new() usage.

REFERENCE: autumnai/collenchyma#37, autumnai/collenchyma#62
  • Loading branch information
alexandermorozov committed Apr 30, 2016
1 parent 1d1b854 commit 704e45e
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ fn run_mnist(model_name: Option<String>, batch_size: Option<usize>, learning_rat

// set up backends
let backend = ::std::rc::Rc::new(Backend::<Cuda>::default().unwrap());
let native_backend = ::std::rc::Rc::new(Backend::<Native>::default().unwrap());

// set up solver
let mut solver_cfg = SolverConfig { minibatch_size: batch_size, base_lr: learning_rate, momentum: momentum, .. SolverConfig::default() };
Expand All @@ -168,9 +167,8 @@ fn run_mnist(model_name: Option<String>, batch_size: Option<usize>, learning_rat
let mut confusion = ::leaf::solver::ConfusionMatrix::new(10);
confusion.set_capacity(Some(1000));

let mut inp = SharedTensor::<f32>::new(backend.device(), &vec![batch_size, 1, 28, 28]).unwrap();
let label = SharedTensor::<f32>::new(native_backend.device(), &vec![batch_size, 1]).unwrap();
inp.add_device(native_backend.device()).unwrap();
let inp = SharedTensor::<f32>::new(&[batch_size, 1, 28, 28]);
let label = SharedTensor::<f32>::new(&[batch_size, 1]);

let inp_lock = Arc::new(RwLock::new(inp));
let label_lock = Arc::new(RwLock::new(label));
Expand Down

0 comments on commit 704e45e

Please sign in to comment.