Add avg_pool3d backward pass for MPS (#159089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159089 Approved by: https://github.com/malfet
This commit is contained in:
parent 57ab39f7e4
commit b59b61a099

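Since this commit wires up the avg_pool3d backward pass on MPS, a quick end-to-end sanity check is to compare gradients against the CPU implementation. The snippet below is an illustrative sketch, not part of the diff; it assumes a machine where torch.backends.mps.is_available() returns True.

import torch
import torch.nn.functional as F

# Same input on CPU and on MPS; float32 is the dtype the MPS tests exercise.
x_cpu = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
x_mps = x_cpu.detach().clone().to("mps").requires_grad_(True)

kwargs = dict(kernel_size=3, stride=2, padding=1, count_include_pad=False)
out_cpu = F.avg_pool3d(x_cpu, **kwargs)
out_mps = F.avg_pool3d(x_mps, **kwargs)

# Backward runs the new avg_pool_backward Metal kernel for the MPS tensor.
out_cpu.sum().backward()
out_mps.sum().backward()

torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad)
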
@@ -448,6 +448,65 @@ void avg_pool_3d_input_iter(
  *output = value_sum / static_cast<T>(divisor);
}

template <typename T>
void avg_pool_backward_3d_input_iter(
    device AtomicType_t<T>* grad_input,
    constant T* grad_output,
    constant int32_t* grad_input_sizes,
    constant int32_t* grad_input_strides,
    int32_t grad_input_leading_offset,
    thread int32_t (&pooling_dim_indices)[3],
    constant int32_t* kernel_size,
    constant int32_t* stride,
    constant int32_t* padding,
    bool count_include_pad,
    bool has_divisor_override,
    int32_t divisor_override) {
  auto bounds0 = get_avg_pool_input_iter_bounds<0>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);
  auto bounds1 = get_avg_pool_input_iter_bounds<1>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);
  auto bounds2 = get_avg_pool_input_iter_bounds<2>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);

  auto divisor = has_divisor_override
      ? divisor_override
      : (bounds0.count) * (bounds1.count) * (bounds2.count);
  auto grad_val = *grad_output / static_cast<T>(divisor);
  auto size12 = grad_input_sizes[1] * grad_input_sizes[2];

  for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
    auto offset0 = grad_input_strides[0] * i0;

    for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
      auto offset1 = grad_input_strides[1] * i1;

      for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
        auto offset2 = grad_input_strides[2] * i2;
        auto pool_offset = offset0 + offset1 + offset2;

        AtomicType<T>::atomic_add(
            grad_input, grad_input_leading_offset + pool_offset, grad_val);
      }
    }
  }
}

// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(

@@ -500,6 +559,57 @@ kernel void avg_pool(
      params.divisor_override);
}

template <typename T>
kernel void avg_pool_backward(
    device AtomicType_t<T>* grad_input [[buffer(0)]],
    constant T* grad_output [[buffer(1)]],
    constant AvgPoolingParams<5>& params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  auto pooling_dims = params.pooling_dims;
  auto dims = params.dims;
  auto grad_input_sizes = params.input_sizes.data();
  auto grad_input_strides = params.input_strides.data();
  auto grad_output_sizes = params.output_sizes.data();
  auto grad_output_strides = params.output_strides.data();
  auto kernel_size = params.kernel_size.data();
  auto stride = params.stride.data();
  auto padding = params.padding.data();
  auto leading_dims = dims - pooling_dims;

  // This buffer keeps track of the pooling dimension indices of this thread's
  // element of the output. We need to fill it with the proper values below.
  int32_t pooling_dim_indices[3];

  PoolOffsets offsets = find_pool_offsets(
      grad_output_sizes,
      grad_output_strides,
      /*indices_strides=*/nullptr,
      grad_input_strides,
      pooling_dim_indices,
      dims,
      leading_dims,
      /*return_indices=*/false,
      tid);

  grad_output += offsets.output;
  grad_input_sizes += leading_dims;
  grad_input_strides += leading_dims;

  avg_pool_backward_3d_input_iter<T>(
      grad_input,
      grad_output,
      grad_input_sizes,
      grad_input_strides,
      offsets.input_leading,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      params.count_include_pad,
      params.has_divisor_override,
      params.divisor_override);
}

#define REGISTER_POOL_OP(DTYPE) \
  template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
      constant DTYPE * input [[buffer(0)]], \

@@ -521,13 +631,20 @@ kernel void avg_pool(
      constant AvgPoolingParams<5> & params [[buffer(2)]], \
      uint tid [[thread_position_in_grid]]);

#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
#define REGISTER_POOL_BACKWARD_OP(DTYPE) \
  template [[host_name("max_pool_backward_" #DTYPE)]] \
  kernel void max_pool_backward<DTYPE>( \
      device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
      constant DTYPE * grad_output_ [[buffer(1)]], \
      constant int64_t* grad_indices_ [[buffer(2)]], \
      constant PoolingBackwardParams<5>& params [[buffer(3)]], \
      uint tid [[thread_position_in_grid]]); \
  \
  template [[host_name("avg_pool_backward_" #DTYPE)]] \
  kernel void avg_pool_backward<DTYPE>( \
      device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
      constant DTYPE * grad_output [[buffer(1)]], \
      constant AvgPoolingParams<5> & params [[buffer(2)]], \
      uint tid [[thread_position_in_grid]]);

REGISTER_POOL_OP(float);

@@ -540,6 +657,6 @@ REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);

REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
REGISTER_POOL_BACKWARD_OP(float);
REGISTER_POOL_BACKWARD_OP(half);
REGISTER_POOL_BACKWARD_OP(bfloat);

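The kernels above launch one thread per grad_output element: avg_pool_backward locates the thread's element, and avg_pool_backward_3d_input_iter divides the incoming gradient by the window's divisor and atomically accumulates it into every grad_input cell of the window (atomic, because windows overlap whenever stride < kernel_size). A rough single-channel NumPy sketch of that scatter, ignoring ceil_mode and divisor_override, may help when reading the loops; the function name and argument layout here are illustrative, not from the commit.

import numpy as np

def avg_pool3d_backward_ref(grad_output, in_shape, kernel, stride, padding,
                            count_include_pad=True):
    # grad_output: (D_out, H_out, W_out) for one channel; kernel/stride/padding
    # are length-3 sequences. Returns grad_input with shape in_shape.
    grad_input = np.zeros(in_shape, dtype=grad_output.dtype)
    for od in range(grad_output.shape[0]):
        for oh in range(grad_output.shape[1]):
            for ow in range(grad_output.shape[2]):
                # Window in padded coordinates, then clipped to the real input
                # (roughly what get_avg_pool_input_iter_bounds does per dimension).
                starts = [o * s - p for o, s, p in zip((od, oh, ow), stride, padding)]
                ends = [st + k for st, k in zip(starts, kernel)]
                clipped = [(max(st, 0), min(en, sz))
                           for st, en, sz in zip(starts, ends, in_shape)]
                if count_include_pad:
                    divisor = int(np.prod(kernel))  # padded positions count toward the average
                else:
                    divisor = int(np.prod([en - st for st, en in clipped]))
                g = grad_output[od, oh, ow] / divisor   # grad_val in the kernel
                for d in range(clipped[0][0], clipped[0][1]):
                    for h in range(clipped[1][0], clipped[1][1]):
                        for w in range(clipped[2][0], clipped[2][1]):
                            grad_input[d, h, w] += g    # AtomicType<T>::atomic_add
    return grad_input
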
@@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>

@@ -725,6 +726,64 @@ static void avg_pool_out_mps_template(const Tensor& output,
  });
}

static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
                                               const Tensor& input,
                                               const Tensor& grad_output,
                                               IntArrayRef _kernel_size,
                                               IntArrayRef _stride,
                                               IntArrayRef _padding,
                                               bool ceil_mode,
                                               bool count_include_pad,
                                               std::optional<int64_t> divisor_override,
                                               const int32_t pooling_dims,
                                               const std::string& op_name) {
  auto [dims, _, kernel_size, stride, padding, __] =
      process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);

  const auto memory_format = input.suggest_memory_format();
  grad_input.resize_(input.sizes(), memory_format);
  grad_input.fill_(0);

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  const auto numThreads = grad_output.numel();

  AvgPoolingParams<5> params;

  params.dims = dims;
  params.pooling_dims = pooling_dims;
  params.count_include_pad = count_include_pad;
  params.has_divisor_override = divisor_override.has_value();
  if (divisor_override.has_value()) {
    params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
  }

  for (const auto dim : c10::irange(dims)) {
    params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
    params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
    params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
    params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
  }

  memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
  memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
  memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));

      getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
      [computeEncoder setComputePipelineState:PSO];
      mtl_setArgs(computeEncoder, grad_input, grad_output, params);

      mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
      getMPSProfiler().endProfileKernel(PSO);
    }
  });
}

} // namespace mps

Tensor mps_max_pool2d(const Tensor& input,

@@ -1083,4 +1142,26 @@ TORCH_IMPL_FUNC(avg_pool3d_out_mps)
                        "avg_pool3d");
}

TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
                                             const Tensor& input,
                                             IntArrayRef kernel_size,
                                             IntArrayRef stride,
                                             IntArrayRef padding,
                                             bool ceil_mode,
                                             bool count_include_pad,
                                             std::optional<int64_t> divisor_override,
                                             const Tensor& grad_input) {
  mps::avg_pool_backward_out_mps_template(grad_input,
                                          input,
                                          grad_output,
                                          kernel_size,
                                          stride,
                                          padding,
                                          ceil_mode,
                                          count_include_pad,
                                          divisor_override,
                                          /*pooling_dims=*/3,
                                          "avg_pool3d_backward");
}

} // namespace at::native

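On the host side, avg_pool_backward_out_mps_template zero-fills grad_input and dispatches one thread per grad_output element, so each input cell ends up with (number of windows covering it) x upstream gradient / divisor. A small check of that identity using divisor_override, runnable on CPU or, with this commit, on "mps", is sketched below as an illustration (not part of the diff):

import torch
import torch.nn.functional as F

device = "cpu"  # switch to "mps" on Apple silicon to exercise the new kernel
x = torch.randn(1, 1, 4, 4, 4, device=device, requires_grad=True)

# Non-overlapping 2x2x2 windows with divisor_override=5: the forward divides each
# window sum by 5, so the backward gives every input cell a gradient of exactly 1/5.
out = F.avg_pool3d(x, kernel_size=2, stride=2, divisor_override=5)
out.sum().backward()

torch.testing.assert_close(x.grad, torch.full_like(x, 1.0 / 5))
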
@@ -12378,6 +12378,7 @@
  dispatch:
    CPU: avg_pool3d_backward_out_cpu
    CUDA: avg_pool3d_backward_out_cuda
    MPS: avg_pool3d_backward_out_mps
    MkldnnCPU: mkldnn_avg_pool3d_backward_out

- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor

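The schema above is the overload the inductor tests call directly. An invocation that mirrors it, with positional arguments in the order given by the - func: line, would look roughly like this on an MPS machine (shapes are only an example):

import torch

aten = torch.ops.aten

x = torch.randn(1, 2, 6, 6, 6, device="mps")
grad_output = torch.ones(1, 2, 3, 3, 3, device="mps")  # output shape of a k=2, s=2, p=0 pool on x

# (grad_output, self, kernel_size, stride, padding, ceil_mode,
#  count_include_pad, divisor_override) -> Tensor
grad_input = aten.avg_pool3d_backward(
    grad_output, x, [2, 2, 2], [2, 2, 2], [0, 0, 0], False, True, None)
print(grad_input.shape)  # torch.Size([1, 2, 6, 6, 6])
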
@@ -9570,7 +9570,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        )
        assertGeneratedKernelCountEqual(self, 0)

    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(

@@ -9592,7 +9591,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
                ],
            )

    @xfail_if_mps_unimplemented
    @skip_if_halide  # compiles for 5+ minutes
    def test_avg_pool3d_backward2(self):
        def fn(a, b):

@@ -9615,7 +9613,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
                ],
            )

    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward3(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(

@@ -9639,7 +9636,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            )
            assertGeneratedKernelCountEqual(self, 1)

    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward4(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(

@@ -39,6 +39,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, Ate
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);

@@ -4064,11 +4064,6 @@ module_db: list[ModuleInfo] = [
               ),
    ModuleInfo(torch.nn.LocalResponseNorm,
               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
               skips=(
                   # uses avg_pool3d which is not supported on MPS backend
                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'),
                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'),
                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'),)
               ),
    ModuleInfo(torch.nn.LayerNorm,
               module_inputs_func=module_inputs_torch_nn_LayerNorm,

@@ -817,7 +817,6 @@ if torch.backends.mps.is_available():
        "round": [torch.float16],
        # topk fails with duplicate indices
        "topk": [torch.float16],
        "nn.functional.avg_pool3d": [torch.float32],
    }

    SKIPLIST_GRAD = {