Add max_pool3d backward pass for MPS (#157498)
Note on backward precision over fp16:

A float16 number has 1 sign bit, 5 exponent bits, and 10 mantissa bits. For a positive number (sign bit 0) with mantissa $m$ and exponent $e$ written in base 10, the value represented is $(1 + m/1024) \cdot 2^e$. ([source](https://en.wikipedia.org/wiki/Half-precision_floating-point_format))

Consider adding two numbers $a$ and $b$ with arbitrary mantissas, where $e_a = 1$ (so $2 \le a < 4$) and $e_b = -3$ (so $0.125 \le b < 0.25$), and assume the result has the same exponent as $a$. Since the exponents differ by 4, we effectively truncate the 4 rightmost bits of $b$'s mantissa, which introduces a maximum error on the order of $(2^4/1024) \cdot 2^{-3} \approx 0.002$. The error is nearly the same for $e_b = -2$ (so $0.25 \le b < 0.5$), where the 3 rightmost bits are truncated, giving a maximum error on the order of $(2^3/1024) \cdot 2^{-2} \approx 0.002$, and likewise for $e_b = -1$.

So if we add up nine numbers whose exponents are all -3, -2, or -1, and they sum to a number with exponent 1, we should expect a maximum error several times 0.002. In my comments above, summing those particular nine numbers in different ways gave results ranging from 3.1758 to 3.1816, a spread of $0.0058 \approx 2.9 \times 0.002$. That is within the acceptable bounds, so we can safely increase the error tolerance used in test_output_grad_match for the case of max_pool3d_backward with float16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157498
Approved by: https://github.com/malfet
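The truncation argument is easy to reproduce numerically. Below is a minimal sketch (the nine addends are illustrative, not the values from the review comments): it sums nine float16 values in $[0.125, 1.0)$ in two different orders and shows order-dependent rounding at the ~0.002 scale.

```python
import torch

# Nine addends with exponents in {-3, -2, -1} (values in [0.125, 1.0))
# whose exact sum lands in [2, 4), so the running total has exponent 1.
vals = [0.3141, 0.2718, 0.1414, 0.4142, 0.3333, 0.4999, 0.1732, 0.2236, 0.3163]

def fp16_sum(xs):
    # Accumulate in float16; each add rounds to 11 significant bits, so
    # low-order bits of the small addends are truncated once the running
    # total reaches exponent 1.
    total = torch.tensor(0.0, dtype=torch.float16)
    for x in xs:
        total = total + torch.tensor(x, dtype=torch.float16)
    return total.item()

print(sum(vals))                             # float64 reference
print(fp16_sum(sorted(vals)))                # one accumulation order
print(fp16_sum(sorted(vals, reverse=True)))  # another order; typically
# differs by a few multiples of 2^-9 ~= 0.002 at this magnitude
```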
parent 63a96eaeb8
commit 510c398a4f
@@ -27,3 +27,14 @@ struct PoolingParams {
   _ARRAY_NS::array<int64_t, N - 2> padding;
   _ARRAY_NS::array<int64_t, N - 2> dilation;
 };
+
+template <unsigned N = 5>
+struct PoolingBackwardParams {
+  int32_t dims;
+  int32_t pooling_dims;
+  _ARRAY_NS::array<int64_t, N> grad_input_sizes;
+  _ARRAY_NS::array<int64_t, N> grad_input_strides;
+  _ARRAY_NS::array<int64_t, N> grad_output_sizes;
+  _ARRAY_NS::array<int64_t, N> grad_output_strides;
+  _ARRAY_NS::array<int64_t, N> indices_strides;
+};
@@ -1,7 +1,10 @@
 #include <ATen/native/mps/kernels/Pooling.h>
+#include <c10/metal/atomic.h>
 #include <metal_array>
 #include <metal_stdlib>
 
 using namespace metal;
+using namespace c10::metal;
 
 // Iterates through all the input elements that this kernel needs to
 // apply max to. Specialized for 3 pooling dimensions.
@@ -83,6 +86,50 @@ void max_pool_3d_input_iter(
   *indices = max_index;
 }
 
+struct PoolOffsets {
+  int64_t output;
+  int64_t indices;
+  int64_t input_leading;
+
+  PoolOffsets() : output(0), indices(0), input_leading(0) {}
+};
+
+// Finds the offset of the output element that a forward pass thread will
+// calculate, `output[N, C, d, h, w]`. Also, find the offset of the input for
+// the leading dim indices, `input[N, C]`. Optionally, keep track of the output
+// pooling dimension indices, `[d, h, w]`.
+PoolOffsets find_pool_offsets(
+    constant int64_t* output_sizes,
+    constant int64_t* output_strides,
+    constant int64_t* indices_strides,
+    constant int64_t* input_strides,
+    device int64_t* work_pooling_dim_indices,
+    int32_t dims,
+    int32_t leading_dims,
+    uint tid) {
+  int64_t output_idx = static_cast<int64_t>(tid);
+  PoolOffsets offsets;
+
+  for (int64_t dim = dims - 1; dim >= 0; dim--) {
+    int64_t dim_idx = output_idx % (output_sizes[dim]);
+    offsets.output += output_strides[dim] * dim_idx;
+    offsets.indices += indices_strides[dim] * dim_idx;
+
+    if (dim < leading_dims) {
+      offsets.input_leading += input_strides[dim] * dim_idx;
+    } else {
+      // Keep track of pooling dimension indices of the output element, so we
+      // can use them in the input iteration later on.
+      if (work_pooling_dim_indices != nullptr) {
+        work_pooling_dim_indices[dim - leading_dims] = dim_idx;
+      }
+    }
+    output_idx = output_idx / output_sizes[dim];
+  }
+
+  return offsets;
+}
+
 // Kernel computes one element of the output per kernel call.
 template <typename T>
 kernel void max_pool(
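For readers following the index arithmetic, here is a Python sketch of what `find_pool_offsets` computes: it peels one coordinate per dimension off the thread's linear id, innermost dimension first, and accumulates strided offsets. The sizes/strides, and the reuse of one stride array for all three tensors, are illustrative only.

```python
# Python sketch of find_pool_offsets (illustrative, not the shipped code).
def find_pool_offsets(output_sizes, output_strides, indices_strides,
                      input_strides, leading_dims, tid,
                      work_pooling_dim_indices=None):
    output_off = indices_off = input_leading_off = 0
    output_idx = tid
    for dim in reversed(range(len(output_sizes))):
        dim_idx = output_idx % output_sizes[dim]
        output_off += output_strides[dim] * dim_idx
        indices_off += indices_strides[dim] * dim_idx
        if dim < leading_dims:
            # N and C contribute to the input's leading offset.
            input_leading_off += input_strides[dim] * dim_idx
        elif work_pooling_dim_indices is not None:
            # Remember the output's [d, h, w] for the window iteration.
            work_pooling_dim_indices[dim - leading_dims] = dim_idx
        output_idx //= output_sizes[dim]
    return output_off, indices_off, input_leading_off

# Example: shape [N=2, C=3, D=4, H=4, W=4] with contiguous strides.
sizes = [2, 3, 4, 4, 4]
strides = [192, 64, 16, 4, 1]
pool_dims = [0, 0, 0]
print(find_pool_offsets(sizes, strides, strides, strides, 2, 173, pool_dims))
# -> (173, 173, 128), with pool_dims now [2, 3, 1]
```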
@@ -113,32 +160,20 @@ kernel void max_pool(
   // element of the output. We need to fill it with the proper values below.
   device int64_t* work_pooling_dim_indices =
       work_pooling_dim_indices_ + tid * pooling_dims;
-  int64_t output_idx = static_cast<int64_t>(tid);
-  int64_t output_offset = 0;
-  int64_t indices_offset = 0;
-  int64_t input_leading_offset = 0;
-
-  // First, find the offset of the output element this thread will calculate,
-  // `output[N, C, d, h, w]`. Also, find the offset of the input for the leading
-  // dim indices, `input[N, C]` and keep track of the pooling dimension indices,
-  // `[d, h , w]`.
-  for (int64_t dim = dims - 1; dim >= 0; dim--) {
-    int64_t dim_idx = output_idx % (output_sizes[dim]);
-    output_offset += output_strides[dim] * dim_idx;
-    indices_offset += indices_strides[dim] * dim_idx;
-
-    if (dim < leading_dims) {
-      input_leading_offset += input_strides[dim] * dim_idx;
-    } else {
-      // Keep track of pooling dimension indices of the output element, so we
-      // can use them in the input iteration later on.
-      work_pooling_dim_indices[dim - leading_dims] = dim_idx;
-    }
-    output_idx = output_idx / output_sizes[dim];
-  }
-  output += output_offset;
-  indices += indices_offset;
-  input += input_leading_offset;
+  PoolOffsets offsets = find_pool_offsets(
+      output_sizes,
+      output_strides,
+      indices_strides,
+      input_strides,
+      work_pooling_dim_indices,
+      dims,
+      leading_dims,
+      tid);
+
+  output += offsets.output;
+  indices += offsets.indices;
+  input += offsets.input_leading;
 
   max_pool_3d_input_iter<T>(
       input,
@@ -153,6 +188,69 @@ kernel void max_pool(
       dilation);
 }
 
+// Finds the element in the grad input which corresponds to the index into the
+// pool, and then adds the grad output element to it.
+template <typename T>
+void max_pool_backward_impl(
+    device AtomicType_t<T>* grad_input,
+    T grad_output_element,
+    int32_t input_index,
+    constant int64_t* grad_input_sizes,
+    constant int64_t* grad_input_strides,
+    int32_t grad_input_leading_offset,
+    int32_t pooling_dims) {
+  int32_t size_prod = 1;
+  int32_t pool_offset = 0;
+
+  for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) {
+    int32_t next_size_prod = grad_input_sizes[dim] * size_prod;
+    pool_offset +=
+        grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod);
+    size_prod *= grad_input_sizes[dim];
+  }
+
+  AtomicType<T>::atomic_add(
+      grad_input, grad_input_leading_offset + pool_offset, grad_output_element);
+}
+
+// Kernel computes one element of the grad input per kernel call.
+template <typename T>
+kernel void max_pool_backward(
+    device AtomicType_t<T>* grad_input [[buffer(0)]],
+    constant T* grad_output [[buffer(1)]],
+    constant int64_t* indices [[buffer(2)]],
+    constant PoolingBackwardParams<5>& params [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  int32_t pooling_dims = params.pooling_dims;
+  int32_t dims = params.dims;
+  constant int64_t* grad_input_sizes = params.grad_input_sizes.data();
+  constant int64_t* grad_input_strides = params.grad_input_strides.data();
+  constant int64_t* grad_output_sizes = params.grad_output_sizes.data();
+  constant int64_t* grad_output_strides = params.grad_output_strides.data();
+  constant int64_t* indices_strides = params.indices_strides.data();
+
+  int32_t leading_dims = dims - pooling_dims;
+
+  PoolOffsets offsets = find_pool_offsets(
+      grad_output_sizes,
+      grad_output_strides,
+      indices_strides,
+      grad_input_strides,
+      nullptr,
+      dims,
+      leading_dims,
+      tid);
+
+  max_pool_backward_impl<T>(
+      grad_input,
+      grad_output[offsets.output],
+      indices[offsets.indices],
+      grad_input_sizes + leading_dims,
+      grad_input_strides + leading_dims,
+      offsets.input_leading,
+      pooling_dims);
+}
+
 #define REGISTER_MAX_POOL_OP(DTYPE)                                       \
   template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
       constant void* input_ [[buffer(0)]],                                \
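The core of the backward pass is the re-expansion of the stored flat pooling index into strided offsets. A Python sketch of `max_pool_backward_impl` follows; plain `+=` stands in for the Metal `atomic_add`, and the example shape is hypothetical.

```python
# Python sketch of max_pool_backward_impl (illustrative). The forward pass
# stored, for each output element, the argmax as a flat index over the
# pooling dims [D, H, W]. The backward pass re-expands that flat index into
# strided per-dim offsets and accumulates the grad there.
def max_pool_backward_impl(grad_input, grad_output_element, input_index,
                           grad_input_sizes, grad_input_strides,
                           grad_input_leading_offset):
    size_prod = 1
    pool_offset = 0
    for dim in reversed(range(len(grad_input_sizes))):
        next_size_prod = grad_input_sizes[dim] * size_prod
        pool_offset += grad_input_strides[dim] * (
            (input_index % next_size_prod) // size_prod)
        size_prod *= grad_input_sizes[dim]
    # On the GPU this is an atomic_add, since threads can collide.
    grad_input[grad_input_leading_offset + pool_offset] += grad_output_element

# Example: a single [D=2, H=2, W=2] channel with contiguous strides [4, 2, 1].
g = [0.0] * 8
max_pool_backward_impl(g, 1.0, input_index=5, grad_input_sizes=[2, 2, 2],
                       grad_input_strides=[4, 2, 1], grad_input_leading_offset=0)
print(g)  # the 1.0 lands at flat offset 5, i.e. (d=1, h=0, w=1)
```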
@@ -162,6 +260,15 @@ kernel void max_pool(
       constant PoolingParams<5>& params [[buffer(4)]], \
       uint tid [[thread_position_in_grid]]);
 
+#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE)                   \
+  template [[host_name("max_pool_backward_" #DTYPE)]]          \
+  kernel void max_pool_backward<DTYPE>(                        \
+      device AtomicType_t<DTYPE> * grad_input [[buffer(0)]],   \
+      constant DTYPE * grad_output_ [[buffer(1)]],             \
+      constant int64_t* grad_indices_ [[buffer(2)]],           \
+      constant PoolingBackwardParams<5>& params [[buffer(3)]], \
+      uint tid [[thread_position_in_grid]]);
+
 REGISTER_MAX_POOL_OP(float);
 REGISTER_MAX_POOL_OP(half);
 REGISTER_MAX_POOL_OP(int);
@@ -170,6 +277,11 @@ REGISTER_MAX_POOL_OP(short);
 REGISTER_MAX_POOL_OP(char);
 REGISTER_MAX_POOL_OP(uchar);
 REGISTER_MAX_POOL_OP(bool);
+
+REGISTER_MAX_POOL_BACKWARD_OP(float);
+REGISTER_MAX_POOL_BACKWARD_OP(half);
 
 #if __METAL_VERSION__ >= 310
 REGISTER_MAX_POOL_OP(bfloat);
+REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
 #endif
@@ -18,6 +18,7 @@
 #include <ATen/ops/max_pool2d_native.h>
 #include <ATen/ops/max_pool2d_with_indices_backward_native.h>
 #include <ATen/ops/max_pool2d_with_indices_native.h>
+#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
 #include <ATen/ops/max_pool3d_with_indices_native.h>
 #endif
@@ -270,16 +271,16 @@ static IntArrayRef tensor_to_intarrayref(const Tensor& tensor) {
   return IntArrayRef(data_ptr, length);
 }
 
-static void max_pool_with_indices_out_mps_template(const Tensor& output,
-                                                   const Tensor& indices,
-                                                   const Tensor& input,
-                                                   IntArrayRef kernel_size,
-                                                   IntArrayRef stride,
-                                                   IntArrayRef padding,
-                                                   IntArrayRef dilation,
-                                                   bool ceil_mode,
-                                                   const int32_t pooling_dims,
-                                                   const std::string& op_name) {
+using PoolSizes = std::tuple<int32_t, Tensor, Tensor, Tensor, Tensor, Tensor>;
+
+static PoolSizes process_pool_sizes(const Tensor& input,
+                                    IntArrayRef kernel_size,
+                                    IntArrayRef stride,
+                                    IntArrayRef padding,
+                                    IntArrayRef dilation,
+                                    bool ceil_mode,
+                                    const int32_t pooling_dims,
+                                    const std::string& op_name) {
   TORCH_INTERNAL_ASSERT(pooling_dims == 1 || pooling_dims == 2 || pooling_dims == 3);
 
   const int32_t dims = input.dim();
@@ -387,9 +388,27 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
 
   t_output_size.slice(0, leading_dims) = t_output_pooling_size;
 
+  return std::tuple<int32_t, Tensor, Tensor, Tensor, Tensor, Tensor>(
+      dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation);
+}
+
+static void max_pool_with_indices_out_mps_template(const Tensor& output,
+                                                   const Tensor& indices,
+                                                   const Tensor& input,
+                                                   IntArrayRef kernel_size,
+                                                   IntArrayRef stride,
+                                                   IntArrayRef padding,
+                                                   IntArrayRef dilation,
+                                                   bool ceil_mode,
+                                                   const int32_t pooling_dims,
+                                                   const std::string& op_name) {
+  auto [dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation] =
+      process_pool_sizes(input, kernel_size, stride, padding, dilation, ceil_mode, pooling_dims, op_name);
+
   IntArrayRef output_size = tensor_to_intarrayref(t_output_size);
-  output.resize_(output_size);
-  indices.resize_(output_size);
+  const auto memory_format = input.suggest_memory_format();
+  output.resize_(output_size, memory_format);
+  indices.resize_(output_size, memory_format);
 
   auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build();
 
@@ -436,6 +455,52 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
   });
 }
 
+static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
+                                                            const Tensor& indices,
+                                                            const Tensor& input,
+                                                            const Tensor& grad_output,
+                                                            IntArrayRef kernel_size,
+                                                            IntArrayRef stride,
+                                                            IntArrayRef padding,
+                                                            IntArrayRef dilation,
+                                                            bool ceil_mode,
+                                                            const int32_t pooling_dims,
+                                                            const std::string& op_name) {
+  auto [dims, t_output_size, t_kernel_size, t_stride, t_padding, t_dilation] =
+      process_pool_sizes(input, kernel_size, stride, padding, dilation, ceil_mode, pooling_dims, op_name);
+
+  const auto memory_format = input.suggest_memory_format();
+  grad_input.resize_(input.sizes(), memory_format);
+  grad_input.fill_(0);
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  const auto numThreads = grad_output.numel();
+  PoolingBackwardParams<5> params;
+
+  params.dims = dims;
+  params.pooling_dims = pooling_dims;
+  memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t));
+  memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t));
+  memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t));
+  memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t));
+  memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));
+
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_backward_" + scalarToMetalTypeString(input));
+
+      getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input});
+      [computeEncoder setComputePipelineState:maxPoolPSO];
+      mtl_setArgs(computeEncoder, grad_input, grad_output, indices, params);
+
+      mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads);
+      getMPSProfiler().endProfileKernel(maxPoolPSO);
+    }
+  });
+}
+
 static void avg_pool2d_template(const Tensor& input,
                                 const Tensor& output,
                                 const std::optional<Tensor>& grad_output_opt,
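The host template launches one thread per grad_output element. When the stride is smaller than the kernel size, pooling windows overlap, so several threads can select the same input element as their argmax; that is why the kernel accumulates with atomic_add rather than storing. A small CPU emulation of that scatter pattern (shapes are illustrative; plain `+=` stands in for the atomic):

```python
import torch
import torch.nn.functional as F

# Overlapping windows: stride (1) < kernel_size (2), so neighboring output
# elements can select the same input element as their max.
x = torch.randn(1, 1, 4, 4, 4, requires_grad=True)
out, idx = F.max_pool3d(x, kernel_size=2, stride=1, return_indices=True)

# One "thread" per output element adds its grad into the argmax slot;
# collisions must accumulate, not overwrite (atomic_add on the GPU).
grad_out = torch.ones_like(out)
grad_in = torch.zeros(x.numel())
for flat_idx, g in zip(idx.flatten().tolist(), grad_out.flatten().tolist()):
    grad_in[flat_idx] += g

out.backward(grad_out)
print(torch.allclose(grad_in.view_as(x), x.grad))  # expect True
```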
@@ -738,6 +803,52 @@ std::tuple<Tensor, Tensor> max_pool3d_with_indices_mps(const Tensor& input,
   return std::tuple<Tensor, Tensor>(output, indices);
 }
 
+Tensor& max_pool3d_with_indices_backward_out_mps(const Tensor& grad_output,
+                                                 const Tensor& input,
+                                                 IntArrayRef kernel_size,
+                                                 IntArrayRef stride,
+                                                 IntArrayRef padding,
+                                                 IntArrayRef dilation,
+                                                 bool ceil_mode,
+                                                 const Tensor& indices,
+                                                 Tensor& grad_input) {
+  mps::max_pool_with_indices_backward_out_mps_template(grad_input,
+                                                       indices,
+                                                       input,
+                                                       grad_output,
+                                                       kernel_size,
+                                                       stride,
+                                                       padding,
+                                                       dilation,
+                                                       ceil_mode,
+                                                       /*pooling_dims=*/3,
+                                                       "max_pool3d_backward");
+  return grad_input;
+}
+
+Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
+                                            const Tensor& input,
+                                            IntArrayRef kernel_size,
+                                            IntArrayRef stride,
+                                            IntArrayRef padding,
+                                            IntArrayRef dilation,
+                                            bool ceil_mode,
+                                            const Tensor& indices) {
+  auto grad_input = at::empty({0}, input.options());
+  mps::max_pool_with_indices_backward_out_mps_template(grad_input,
+                                                       indices,
+                                                       input,
+                                                       grad_output,
+                                                       kernel_size,
+                                                       stride,
+                                                       padding,
+                                                       dilation,
+                                                       ceil_mode,
+                                                       /*pooling_dims=*/3,
+                                                       "max_pool3d_backward");
+  return grad_input;
+}
+
 TORCH_IMPL_FUNC(avg_pool2d_out_mps)
 (const Tensor& input,
  int64_t kH,
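With the kernels, host template, and dispatch entries in place, the end-to-end path can be exercised as below. This is a hedged smoke test, not part of the PR: it assumes a build containing this change and an MPS-capable machine.

```python
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.randn(2, 3, 8, 8, 8, device="mps", requires_grad=True)
    F.max_pool3d(x, kernel_size=3, stride=2).sum().backward()

    # Compare against the CPU reference implementation.
    x_cpu = x.detach().cpu().requires_grad_()
    F.max_pool3d(x_cpu, kernel_size=3, stride=2).sum().backward()
    print(torch.allclose(x.grad.cpu(), x_cpu.grad))  # expect True
```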
@@ -12442,12 +12442,14 @@
   dispatch:
     CPU: max_pool3d_with_indices_backward_out_cpu
     CUDA: max_pool3d_with_indices_backward_out_cuda
+    MPS: max_pool3d_with_indices_backward_out_mps
 
 - func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
   python_module: nn
   dispatch:
     CPU: max_pool3d_with_indices_backward_cpu
     CUDA: max_pool3d_with_indices_backward_cuda
+    MPS: max_pool3d_with_indices_backward_mps
 
 - func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
@@ -1649,7 +1649,6 @@ torch.cuda.synchronize()
     def test_MaxPool2d_indices(self, device, dtype):
         self._test_maxpool_indices(2, device=device, dtype=dtype)
 
-    @expectedFailureMPS
     @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
     @dtypes(torch.float)
     def test_MaxPool3d_indices(self, device, dtype):
@@ -2001,7 +2000,6 @@ torch.cuda.synchronize()
             prec=0.05,
         )
 
-    @expectedFailureMPS  # max_pool3d_with_indices not supported on MPS device
     def test_maxpool3d_non_square_backward(self, device):
         # previous CUDA routine of this backward calculates kernel launch grid size
         # with last two dimensions interchanged, so the tailing along the longer dim
@@ -12244,6 +12244,11 @@ class TestConsistency(TestCaseMPS):
             atol, rtol = 3e-3, 3e-3
         if op.name == "logcumsumexp":
             atol, rtol = 4e-3, 1e-3
+        if op.name == "nn.functional.max_pool3d" and dtype == torch.float16:
+            # In a few cases where stride is smaller than kernel size,
+            # several output grad elements of similar magnitudes get summed
+            # together, introducing significant error for float16.
+            atol, rtol = 5e-3, 5e-3
         self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
     def test_fmax_mixed_dtypes(self, device):
|
@ -67,6 +67,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_select(AtenTensorHandle s
|
|||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
|
||||
|
|
|
|||
|
|
@@ -3864,9 +3864,6 @@ module_db: list[ModuleInfo] = [
     ModuleInfo(torch.nn.MaxPool3d,
                module_inputs_func=module_inputs_torch_nn_MaxPool3d,
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
-               skips=(
-                   # not supported on MPS backend
-                   DecorateInfo(skipIfMPS, device_type='mps'),)
                ),
     ModuleInfo(torch.nn.KLDivLoss,
                module_inputs_func=module_inputs_torch_nn_KLDivLoss,
@@ -849,7 +849,6 @@ if torch.backends.mps.is_available():
         "floor_divide": [torch.float16, torch.float32],
         # derivative for aten::narrow_copy is not implemented on CPU
         "narrow_copy": [torch.float16, torch.float32],
-        "nn.functional.max_pool3d": [torch.float16, torch.float32],
         # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
         "histogramdd": [torch.float16, torch.float32],
         # derivative for aten::histogram is not implemented