Skip to content

Commit

Permalink
Auto merge of autumnai#40 - autumnai:fix/0_allocation, r=hobofan
Browse files Browse the repository at this point in the history
fix/convolution: workaround for 0 memory allocation
  • Loading branch information
homu committed Mar 3, 2016
2 parents e782a8a + e30b59d commit 4094b48
Showing 1 changed file with 13 additions and 42 deletions.
55 changes: 13 additions & 42 deletions src/frameworks/cuda/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,6 @@ impl ConvForwardAlgo {
};
Ok(ConvForwardAlgo::from_cudnn(&algo))
}

/// Check if the algo needs a cudnn workspace.
fn needs_cudnn_workspace(&self) -> Result<bool, ::co::error::Error> {
Ok(match *self {
ConvForwardAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvForwardAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))),
ConvForwardAlgo::GEMM => true,
ConvForwardAlgo::ImplicitGEMM => false,
ConvForwardAlgo::ImplicitPrecompiledGEMM => true,
ConvForwardAlgo::FFT => true,
ConvForwardAlgo::FFTTiling => true,
ConvForwardAlgo::Direct => true,
})
}
}

impl ConvBackwardFilterAlgo {
Expand Down Expand Up @@ -209,17 +196,6 @@ impl ConvBackwardFilterAlgo {
};
Ok(ConvBackwardFilterAlgo::from_cudnn(&algo))
}

/// Check if the algo needs a cudnn workspace.
fn needs_cudnn_workspace(&self) -> Result<bool, ::co::error::Error> {
Ok(match *self {
ConvBackwardFilterAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvBackwardFilterAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))),
ConvBackwardFilterAlgo::ImplicitGEMM => false,
ConvBackwardFilterAlgo::ImplicitGEMMSum => true,
ConvBackwardFilterAlgo::ImplicitPrecompiledGEMMSum => true,
ConvBackwardFilterAlgo::FFT => true,
})
}
}

impl ConvBackwardDataAlgo {
Expand Down Expand Up @@ -262,17 +238,6 @@ impl ConvBackwardDataAlgo {
};
Ok(ConvBackwardDataAlgo::from_cudnn(&algo))
}

/// Check if the algo needs a cudnn workspace.
fn needs_cudnn_workspace(&self) -> Result<bool, ::co::error::Error> {
Ok(match *self {
ConvBackwardDataAlgo::Auto => return Err(::co::error::Error::Plugin(::co::plugin::Error::Plugin("Can't check necessary workspace size for ConvBackwardDataAlgo::Auto. Use `find_cudnn_algo` to find an algorithm."))),
ConvBackwardDataAlgo::ImplicitGEMM => false,
ConvBackwardDataAlgo::ImplicitGEMMSum => false,
ConvBackwardDataAlgo::FFT => true,
ConvBackwardDataAlgo::FFTTiling => true,
})
}
}

macro_rules! impl_convolution_for_cuda_backend {
Expand Down Expand Up @@ -304,13 +269,19 @@ macro_rules! impl_convolution_for_cuda_backend {
let useable_algo_bwd_filter = try!(algo_bwd_filter.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc));
let useable_algo_bwd_data = try!(algo_bwd_data.find_cudnn_algo(&filter_desc, &conv_desc, &src_desc, &dest_desc));

let workspace_size_fwd = API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let workspace_size_bwd_filter = API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
// let workspace_size_bwd_data = API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let workspace_size_bwd_data = match try!(useable_algo_bwd_data.needs_cudnn_workspace()) {
false => 1,
true => API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap(),
};
let mut workspace_size_fwd = API::get_convolution_forward_workspace_size(*CUDNN.id_c(), useable_algo_fwd.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let mut workspace_size_bwd_filter = API::get_convolution_backward_filter_workspace_size(*CUDNN.id_c(), useable_algo_bwd_filter.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();
let mut workspace_size_bwd_data = API::get_convolution_backward_data_workspace_size(*CUDNN.id_c(), useable_algo_bwd_data.as_cudnn().unwrap(), *filter_desc.id_c(), *conv_desc.id_c(), *src_desc.id_c(), *dest_desc.id_c()).unwrap();

if workspace_size_fwd == 0 {
workspace_size_fwd = 8;
}
if workspace_size_bwd_filter == 0 {
workspace_size_bwd_filter = 8;
}
if workspace_size_bwd_data == 0 {
workspace_size_bwd_data = 8;
}

Ok(
::cudnn::utils::ConvolutionConfig::new(
Expand Down

0 comments on commit 4094b48

Please sign in to comment.