forked from malinjawi/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DilatedMaxPool3d.cpp
299 lines (262 loc) · 9.99 KB
/
DilatedMaxPool3d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Pool.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#endif
namespace at::native {
namespace {
// Shared CPU implementation behind max_pool3d_with_indices and its out variant.
//
// Parses the 3-D pooling hyper-parameters, validates `input`, resizes
// `output` and `indices` in place to the pooled shape, then invokes the
// `max_pool3d_kernel` dispatch stub for kCPU to fill both tensors.
//
// Parameters:
//   output      - resized to the pooled shape; filled by the kernel.
//   indices     - resized to the same shape as `output`; filled by the kernel.
//   input       - 4-D (slices, time, height, width) or 5-D with a leading
//                 batch dimension; a channels_last_3d input must be 5-D.
//   kernel_size - one value used for T/H/W, or three values (T, H, W).
//   stride      - empty (defaults to the kernel size), one value, or three.
//   padding     - one value or three values (T, H, W).
//   dilation    - one value or three values (T, H, W).
//   ceil_mode   - forwarded to pooling_output_shape for the output extents.
//
// Errors are reported through TORCH_CHECK and pool3d_shape_check.
void max_pool3d_with_indices_out_cpu_template(
          Tensor& output,
          Tensor& indices,
          const Tensor& input,
          IntArrayRef kernel_size,
          IntArrayRef stride,
          IntArrayRef padding,
          IntArrayRef dilation,
          bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  // An omitted stride defaults to the kernel size in each dimension.
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  // Only contiguous and channels_last_3d layouts are handled; the latter
  // requires a batched (5-D) input.
  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast3d) {
    TORCH_CHECK(input.ndimension() == 5,
      "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous");
  }

  // The trailing four dims are read the same way for 4-D and 5-D inputs.
  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, pT, dT, dilationT, ceil_mode);
  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);

  pool3d_shape_check(
    input,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_out_cpu_template()");

  if (input.dim() == 4) { /* non-batch mode */
    /* resize output */
    output.resize_({nslices, otime, oheight, owidth});
    /* indices will contain ti,i,j locations for each output point */
    indices.resize_({nslices, otime, oheight, owidth});
  } else { /* batch mode */
    const int64_t nbatch = input.size(0);
    // In batch mode both results adopt the input's suggested memory format.
    /* resize output */
    output.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
    /* indices will contain ti,i,j locations for each output point */
    indices.resize_({nbatch, nslices, otime, oheight, owidth}, memory_format);
  }

  // Note: the kernel takes its per-dimension arguments in W, H, T order.
  max_pool3d_kernel(
      kCPU, output, indices, input,
      kW, kH, kT,
      dW, dH, dT,
      pW, pH, pT,
      dilationW, dilationH, dilationT);
}
// Shared CPU implementation behind max_pool3d_with_indices_backward and its
// out variant.
//
// Parses the pooling hyper-parameters (same rules as the forward pass),
// validates `input`, `gradOutput` and `indices` against them, then resizes
// `gradInput` to `input`'s sizes (in the input's suggested memory format),
// zeroes it, and invokes the `max_pool3d_backward_kernel` dispatch stub for
// kCPU.
//
// All argument validation runs before `gradInput` is touched, so a call that
// is rejected by these checks leaves a caller-supplied out tensor unmodified.
//
// Parameters:
//   gradInput   - destination for the input gradient; resized and zeroed here.
//   gradOutput  - must have the same dtype as `input`; its last three dims
//                 are taken as the pooled (time, height, width) extents.
//   input       - forward input: 4-D, or 5-D with a leading batch dimension.
//   indices     - passed to the shape check and the kernel.
//   kernel_size, stride, padding, dilation, ceil_mode - as in the forward op.
//
// Returns: a reference to `gradInput`.
Tensor& max_pool3d_with_indices_backward_out_cpu_template(
          Tensor& gradInput,
          const Tensor& gradOutput,
          const Tensor& input,
          const Tensor& indices,
          IntArrayRef kernel_size,
          IntArrayRef stride,
          IntArrayRef padding,
          IntArrayRef dilation,
          bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3,
    "max_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

  // An omitted stride defaults to the kernel size in each dimension.
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
    "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty() ? kH :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dT : safe_downcast<int, int64_t>(stride[2]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
    "max_pool3d: padding must either be a single int, or a tuple of three ints");
  const int pT = safe_downcast<int, int64_t>(padding[0]);
  const int pH = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[1]);
  const int pW = padding.size() == 1 ? pT : safe_downcast<int, int64_t>(padding[2]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3,
    "max_pool3d: dilation must be either a single int, or a tuple of three ints");
  const int dilationT = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[1]);
  const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast<int, int64_t>(dilation[2]);

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());

  // Only contiguous and channels_last_3d layouts are handled; the latter
  // requires a batched (5-D) input.
  const auto memory_format = input.suggest_memory_format();
  if (memory_format == at::MemoryFormat::ChannelsLast3d) {
    TORCH_CHECK(input.ndimension() == 5,
      "non-empty 5D (batch mode) tensor expected for input with channels_last_3d layout");
  } else if (memory_format == at::MemoryFormat::Contiguous) {
    TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input");
  } else {
    TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous");
  }

  const int64_t nslices = input.size(-4);
  const int64_t itime = input.size(-3);
  const int64_t iheight = input.size(-2);
  const int64_t iwidth = input.size(-1);

  const int64_t otime = gradOutput.size(-3);
  const int64_t oheight = gradOutput.size(-2);
  const int64_t owidth = gradOutput.size(-1);

  max_pool3d_backward_shape_check(
    input,
    gradOutput,
    indices,
    nslices,
    kT, kH, kW,
    dT, dH, dW,
    pT, pH, pW,
    dilationT, dilationH, dilationW,
    itime, iheight, iwidth,
    otime, oheight, owidth,
    "max_pool3d_with_indices_backward_out_cpu_template()");

  // Resize and zero only after validation succeeded, so an invalid call does
  // not clobber the caller's out tensor. The kernel then writes into this
  // zero-initialised buffer.
  gradInput.resize_(input.sizes(), memory_format);
  gradInput.zero_();

  max_pool3d_backward_kernel(
      kCPU, gradInput,
      gradOutput, indices);

  return gradInput;
}
} // namespace
std::tuple<Tensor&, Tensor&> max_pool3d_with_indices_out_cpu(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
Tensor& output,
Tensor& indices)
{
max_pool3d_with_indices_out_cpu_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
return std::tuple<Tensor&, Tensor&>(output, indices);
}
std::tuple<Tensor, Tensor> max_pool3d_with_indices_cpu(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
NoNamesGuard guard;
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
max_pool3d_with_indices_out_cpu_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
guard.reset();
namedinference::propagate_names(output, input);
namedinference::propagate_names(indices, input);
return std::tuple<Tensor, Tensor>(output, indices);
}
// Out-variant CPU entry point for max_pool3d_with_indices_backward.
//
// Writes the input gradient into the caller-provided `gradInput` via the
// shared template and returns a reference to it.
Tensor& max_pool3d_with_indices_backward_out_cpu(const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices,
  Tensor& gradInput)
{
  // The template returns a reference to the very tensor it was handed, so
  // its result can be forwarded directly.
  return max_pool3d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput_, input, indices,
      kernel_size, stride, padding, dilation, ceil_mode);
}
// Functional CPU entry point for max_pool3d_with_indices_backward.
//
// Allocates a fresh gradient tensor with `input`'s options, fills it via the
// shared template, and returns it by value.
Tensor max_pool3d_with_indices_backward_cpu(
  const Tensor& gradOutput_,
  const Tensor& input,
  IntArrayRef kernel_size,
  IntArrayRef stride,
  IntArrayRef padding,
  IntArrayRef dilation,
  bool ceil_mode,
  const Tensor& indices)
{
  // Empty placeholder; the template resizes it to input.sizes() itself.
  Tensor grad_input = at::empty({0}, input.options());
  max_pool3d_with_indices_backward_out_cpu_template(
      grad_input,
      gradOutput_,
      input,
      indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);
  return grad_input;
}
// Definitions of the two dispatch stubs invoked (with kCPU) by the templates
// in this file.
// NOTE(review): the matching DECLARE_DISPATCH and the kernel registrations
// (REGISTER_DISPATCH) are not visible here. They are presumably in
// ATen/native/Pool.h and a cpu/ kernel file respectively; confirm.
DEFINE_DISPATCH(max_pool3d_kernel);
DEFINE_DISPATCH(max_pool3d_backward_kernel);
} // namespace at::native