Add avg_pool3d backward pass for MPS (#159089)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159089
Approved by: https://github.com/malfet
Author: Kurt Mohler, 2025-08-04 18:48:29 -05:00
Committed by: PyTorch MergeBot
Parent: 57ab39f7e4
Commit: b59b61a099
7 changed files with 204 additions and 14 deletions

@@ -448,6 +448,65 @@ void avg_pool_3d_input_iter(
  *output = value_sum / static_cast<T>(divisor);
}

// Distributes one grad_output element's gradient evenly across its pooling
// window in grad_input, using atomic adds since windows can overlap.
template <typename T>
void avg_pool_backward_3d_input_iter(
    device AtomicType_t<T>* grad_input,
    constant T* grad_output,
    constant int32_t* grad_input_sizes,
    constant int32_t* grad_input_strides,
    int32_t grad_input_leading_offset,
    thread int32_t (&pooling_dim_indices)[3],
    constant int32_t* kernel_size,
    constant int32_t* stride,
    constant int32_t* padding,
    bool count_include_pad,
    bool has_divisor_override,
    int32_t divisor_override) {
  auto bounds0 = get_avg_pool_input_iter_bounds<0>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);
  auto bounds1 = get_avg_pool_input_iter_bounds<1>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);
  auto bounds2 = get_avg_pool_input_iter_bounds<2>(
      grad_input_sizes,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      count_include_pad);

  auto divisor = has_divisor_override
      ? divisor_override
      : (bounds0.count) * (bounds1.count) * (bounds2.count);
  auto grad_val = *grad_output / static_cast<T>(divisor);
  auto size12 = grad_input_sizes[1] * grad_input_sizes[2];

  for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
    auto offset0 = grad_input_strides[0] * i0;
    for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
      auto offset1 = grad_input_strides[1] * i1;
      for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
        auto offset2 = grad_input_strides[2] * i2;
        auto pool_offset = offset0 + offset1 + offset2;
        AtomicType<T>::atomic_add(
            grad_input, grad_input_leading_offset + pool_offset, grad_val);
      }
    }
  }
}

// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(
@@ -500,6 +559,57 @@ kernel void avg_pool(
      params.divisor_override);
}

// Each thread handles one element of grad_output and scatters its
// contribution into grad_input via avg_pool_backward_3d_input_iter.
template <typename T>
kernel void avg_pool_backward(
    device AtomicType_t<T>* grad_input [[buffer(0)]],
    constant T* grad_output [[buffer(1)]],
    constant AvgPoolingParams<5>& params [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  auto pooling_dims = params.pooling_dims;
  auto dims = params.dims;
  auto grad_input_sizes = params.input_sizes.data();
  auto grad_input_strides = params.input_strides.data();
  auto grad_output_sizes = params.output_sizes.data();
  auto grad_output_strides = params.output_strides.data();
  auto kernel_size = params.kernel_size.data();
  auto stride = params.stride.data();
  auto padding = params.padding.data();
  auto leading_dims = dims - pooling_dims;

  // This buffer keeps track of the pooling dimension indices of this thread's
  // element of the output. We need to fill it with the proper values below.
  int32_t pooling_dim_indices[3];

  PoolOffsets offsets = find_pool_offsets(
      grad_output_sizes,
      grad_output_strides,
      /*indices_strides=*/nullptr,
      grad_input_strides,
      pooling_dim_indices,
      dims,
      leading_dims,
      /*return_indices=*/false,
      tid);

  grad_output += offsets.output;
  grad_input_sizes += leading_dims;
  grad_input_strides += leading_dims;

  avg_pool_backward_3d_input_iter<T>(
      grad_input,
      grad_output,
      grad_input_sizes,
      grad_input_strides,
      offsets.input_leading,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      params.count_include_pad,
      params.has_divisor_override,
      params.divisor_override);
}

#define REGISTER_POOL_OP(DTYPE)                                            \
  template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>(  \
      constant DTYPE * input [[buffer(0)]],                                \
@@ -521,13 +631,20 @@ kernel void avg_pool(
      constant AvgPoolingParams<5> & params [[buffer(2)]],       \
      uint tid [[thread_position_in_grid]]);

-#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE)                    \
+#define REGISTER_POOL_BACKWARD_OP(DTYPE)                        \
  template [[host_name("max_pool_backward_" #DTYPE)]]            \
  kernel void max_pool_backward<DTYPE>(                          \
      device AtomicType_t<DTYPE> * grad_input [[buffer(0)]],     \
      constant DTYPE * grad_output_ [[buffer(1)]],               \
      constant int64_t* grad_indices_ [[buffer(2)]],             \
      constant PoolingBackwardParams<5>& params [[buffer(3)]],   \
      uint tid [[thread_position_in_grid]]);                     \
                                                                 \
  template [[host_name("avg_pool_backward_" #DTYPE)]]            \
  kernel void avg_pool_backward<DTYPE>(                          \
      device AtomicType_t<DTYPE> * grad_input [[buffer(0)]],     \
      constant DTYPE * grad_output [[buffer(1)]],                \
      constant AvgPoolingParams<5> & params [[buffer(2)]],       \
      uint tid [[thread_position_in_grid]]);
REGISTER_POOL_OP(float);
@@ -540,6 +657,6 @@ REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);
-REGISTER_MAX_POOL_BACKWARD_OP(float);
-REGISTER_MAX_POOL_BACKWARD_OP(half);
-REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
+REGISTER_POOL_BACKWARD_OP(float);
+REGISTER_POOL_BACKWARD_OP(half);
+REGISTER_POOL_BACKWARD_OP(bfloat);
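For reference, the math that avg_pool_backward_3d_input_iter performs can be sketched in Python. This is an illustration written for this note, not code from the PR; the name avg_pool3d_backward_ref is hypothetical, the layout is assumed NCDHW, and ceil_mode edge handling is ignored:

import torch

# Hypothetical reference: distribute each grad_output element evenly over
# its pooling window, mirroring the atomic scatter-add in the Metal kernel.
def avg_pool3d_backward_ref(grad_out, inp, kernel_size, stride, padding,
                            count_include_pad=True, divisor_override=None):
    kD, kH, kW = kernel_size
    sD, sH, sW = stride
    pD, pH, pW = padding
    _, _, D, H, W = inp.shape
    oD, oH, oW = grad_out.shape[-3:]
    grad_in = torch.zeros_like(inp)
    for od in range(oD):
        for oh in range(oH):
            for ow in range(oW):
                # Window bounds in input coordinates, before clipping.
                d0, h0, w0 = od * sD - pD, oh * sH - pH, ow * sW - pW
                d1, h1, w1 = d0 + kD, h0 + kH, w0 + kW
                if divisor_override is not None:
                    divisor = divisor_override
                elif count_include_pad:
                    divisor = kD * kH * kW  # padded positions count
                else:
                    divisor = ((min(d1, D) - max(d0, 0)) *
                               (min(h1, H) - max(h0, 0)) *
                               (min(w1, W) - max(w0, 0)))
                g = grad_out[:, :, od, oh, ow] / divisor
                grad_in[:, :, max(d0, 0):min(d1, D),
                        max(h0, 0):min(h1, H),
                        max(w0, 0):min(w1, W)] += g[..., None, None, None]
    return grad_in

Because neighboring windows overlap when stride is smaller than kernel_size, several grad_output elements can touch the same grad_input position, which is why the kernel accumulates with AtomicType<T>::atomic_add rather than plain stores.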

@@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
@@ -725,6 +726,64 @@ static void avg_pool_out_mps_template(const Tensor& output,
  });
}

static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
                                               const Tensor& input,
                                               const Tensor& grad_output,
                                               IntArrayRef _kernel_size,
                                               IntArrayRef _stride,
                                               IntArrayRef _padding,
                                               bool ceil_mode,
                                               bool count_include_pad,
                                               std::optional<int64_t> divisor_override,
                                               const int32_t pooling_dims,
                                               const std::string& op_name) {
  auto [dims, _, kernel_size, stride, padding, __] =
      process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);

  const auto memory_format = input.suggest_memory_format();
  grad_input.resize_(input.sizes(), memory_format);
  grad_input.fill_(0);

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  const auto numThreads = grad_output.numel();

  AvgPoolingParams<5> params;

  params.dims = dims;
  params.pooling_dims = pooling_dims;
  params.count_include_pad = count_include_pad;
  params.has_divisor_override = divisor_override.has_value();
  if (divisor_override.has_value()) {
    params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
  }

  for (const auto dim : c10::irange(dims)) {
    params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
    params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
    params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
    params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
  }

  memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
  memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
  memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));

      getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
      [computeEncoder setComputePipelineState:PSO];
      mtl_setArgs(computeEncoder, grad_input, grad_output, params);
      mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
      getMPSProfiler().endProfileKernel(PSO);
    }
  });
}

} // namespace mps
Tensor mps_max_pool2d(const Tensor& input,
@@ -1083,4 +1142,26 @@ TORCH_IMPL_FUNC(avg_pool3d_out_mps)
                                 "avg_pool3d");
}

TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
                                             const Tensor& input,
                                             IntArrayRef kernel_size,
                                             IntArrayRef stride,
                                             IntArrayRef padding,
                                             bool ceil_mode,
                                             bool count_include_pad,
                                             std::optional<int64_t> divisor_override,
                                             const Tensor& grad_input) {
  mps::avg_pool_backward_out_mps_template(grad_input,
                                          input,
                                          grad_output,
                                          kernel_size,
                                          stride,
                                          padding,
                                          ceil_mode,
                                          count_include_pad,
                                          divisor_override,
                                          /*pooling_dims=*/3,
                                          "avg_pool3d_backward");
}

} // namespace at::native
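
A quick way to exercise this host-side path (assuming a build where torch.backends.mps.is_available() is true) is to compare MPS gradients against the CPU reference through autograd. This snippet is illustrative and not part of the PR:

import torch
import torch.nn.functional as F

# Illustrative smoke test: same input on CPU and MPS, gradients must match.
x_cpu = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
x_mps = x_cpu.detach().to("mps").requires_grad_()
for x in (x_cpu, x_mps):
    F.avg_pool3d(x, kernel_size=3, stride=2, padding=1).sum().backward()
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad)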

@@ -12378,6 +12378,7 @@
  dispatch:
    CPU: avg_pool3d_backward_out_cpu
    CUDA: avg_pool3d_backward_out_cuda
    MPS: avg_pool3d_backward_out_mps
    MkldnnCPU: mkldnn_avg_pool3d_backward_out
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
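
With the MPS entry in the dispatch table, the schema above can also be invoked directly through torch.ops.aten. A minimal sketch (argument order follows the schema; requires an MPS device):

import torch

# Sketch only: calls the op that the new dispatch entry routes to MPS.
grad_output = torch.randn(1, 1, 3, 3, 3, device="mps")
self_tensor = torch.randn(1, 1, 6, 6, 6, device="mps")
grad_input = torch.ops.aten.avg_pool3d_backward(
    grad_output, self_tensor,
    [2, 2, 2], [2, 2, 2], [0, 0, 0],  # kernel_size, stride, padding
    False, True, None)                # ceil_mode, count_include_pad, divisor_override
assert grad_input.shape == self_tensor.shape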

@@ -9570,7 +9570,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        )
        assertGeneratedKernelCountEqual(self, 0)

-    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(
@@ -9592,7 +9591,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            ],
        )

-    @xfail_if_mps_unimplemented
    @skip_if_halide  # compiles for 5+ minutes
    def test_avg_pool3d_backward2(self):
        def fn(a, b):
@@ -9615,7 +9613,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            ],
        )

-    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward3(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(
@@ -9639,7 +9636,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        )
        assertGeneratedKernelCountEqual(self, 1)

-    @xfail_if_mps_unimplemented
    def test_avg_pool3d_backward4(self):
        def fn(a, b):
            return aten.avg_pool3d_backward(
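
With the xfail decorators removed, these tests now run on MPS under Inductor. A standalone repro in the same shape as the tests, with illustrative shapes and arguments (not copied from the test bodies), assuming an MPS-capable build:

import torch

aten = torch.ops.aten

def fn(grad_out, inp):
    # Same op the tests above exercise; argument values are illustrative.
    return aten.avg_pool3d_backward(
        grad_out, inp, [2, 2, 2], [2, 2, 2], [0, 0, 0], False, True, None)

compiled = torch.compile(fn)
inp = torch.randn(1, 1, 8, 8, 8, device="mps")
grad_out = torch.randn(1, 1, 4, 4, 4, device="mps")
torch.testing.assert_close(compiled(grad_out, inp), fn(grad_out, inp))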

@@ -39,6 +39,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, Ate
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);

@@ -4064,11 +4064,6 @@ module_db: list[ModuleInfo] = [
               ),
    ModuleInfo(torch.nn.LocalResponseNorm,
               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
-              skips=(
-                  # uses avg_pool3d which is not supported on MPS backend
-                  DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'),
-                  DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'),
-                  DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'),)
               ),
    ModuleInfo(torch.nn.LayerNorm,
               module_inputs_func=module_inputs_torch_nn_LayerNorm,

@@ -817,7 +817,6 @@ if torch.backends.mps.is_available():
        "round": [torch.float16],
        # topk fails with duplicate indices
        "topk": [torch.float16],
-       "nn.functional.avg_pool3d": [torch.float32],
    }

    SKIPLIST_GRAD = {