forked from malinjawi/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AffineGridGenerator.cpp
126 lines (104 loc) · 3.05 KB
/
AffineGridGenerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_affine_grid_generator_backward_native.h>
#include <ATen/ops/cudnn_affine_grid_generator_native.h>
#include <ATen/ops/empty.h>
#endif
#if !AT_CUDNN_ENABLED()
namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]
//
// Fallback definitions compiled when ATen is built without cuDNN. Both entry
// points unconditionally raise via AT_ERROR, so every argument is unused; the
// parameter names are kept as comments (mirroring the cuDNN-enabled
// definitions) for documentation and to avoid unused-parameter warnings.

// Forward stub: always errors because cuDNN support is not compiled in.
Tensor cudnn_affine_grid_generator_forward(
    const Tensor& /*theta_t*/,
    int64_t /*N*/,
    int64_t /*C*/,
    int64_t /*H*/,
    int64_t /*W*/) {
  AT_ERROR(
      "cudnn_affine_grid_generator_forward: ATen not compiled with cuDNN support");
}

// Backward stub: always errors because cuDNN support is not compiled in.
// The tensor argument is the gradient with respect to the generated grid
// (named grad_grid_t in the cuDNN-enabled definition), not theta.
Tensor cudnn_affine_grid_generator_backward(
    const Tensor& /*grad_grid_t*/,
    int64_t /*N*/,
    int64_t /*C*/,
    int64_t /*H*/,
    int64_t /*W*/) {
  AT_ERROR(
      "cudnn_affine_grid_generator_backward: ATen not compiled with cuDNN support");
}

} // namespace native
} // namespace at
#else // AT_CUDNN_ENABLED()
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/TensorUtils.h>
namespace at {
namespace native {
namespace {
void setSamplerDescriptor(
SpatialTransformerDescriptor& desc,
cudnnDataType_t dataType,
int N,
int C,
int H,
int W) {
int inputSize[4] = {N, C, H, W};
desc.set(dataType, 4, inputSize);
}
} // namespace
Tensor cudnn_affine_grid_generator_forward(
const Tensor& theta_t,
int64_t N,
int64_t C,
int64_t H,
int64_t W) {
auto theta_t_contig = theta_t.contiguous();
TensorArg theta{theta_t_contig, "theta", 1};
CheckedFrom c = "cudnn_affine_grid_generator_forward";
checkContiguous(c, theta);
checkSize(c, theta, {N, 2, 3});
auto grid_t = at::empty({0}, theta->options());
grid_t.resize_({N, H, W, 2});
auto dataType = getCudnnDataType(*theta);
SpatialTransformerDescriptor desc;
setSamplerDescriptor(desc, dataType, N, C, H, W);
AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorForward(
getCudnnHandle(), desc.desc(), theta->data_ptr(), grid_t.data_ptr()));
return grid_t;
}
Tensor cudnn_affine_grid_generator_backward(
const Tensor& grad_grid_t,
int64_t N,
int64_t C,
int64_t H,
int64_t W) {
auto grad_grid_contig = grad_grid_t.contiguous();
TensorArg grad_grid{grad_grid_contig, "grad_grid", 1};
CheckedFrom c = "cudnn_affine_grid_generator_backward";
checkContiguous(c, grad_grid);
checkSize(c, grad_grid, {N, H, W, 2});
auto grad_theta_t = at::empty({0}, grad_grid->options());
grad_theta_t.resize_({N, 2, 3});
auto dataType = getCudnnDataType(grad_theta_t);
SpatialTransformerDescriptor desc;
setSamplerDescriptor(desc, dataType, N, C, H, W);
AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorBackward(
getCudnnHandle(),
desc.desc(),
grad_grid->data_ptr(),
grad_theta_t.data_ptr()));
return grad_theta_t;
}
} // namespace native
} // namespace at
#endif // AT_CUDNN_ENABLED()