[MPS] Move max_pool2d to mps dispatch key (#90772)
Related issue: #77394. This PR also modifies some assertions in the codegen; an explanatory comment has been added for them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90772 Approved by: https://github.com/albanD
This commit is contained in:
parent 250c054bdd
commit e8dc34eaeb
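In user code nothing changes: `torch.nn.functional.max_pool2d` keeps working on MPS tensors; it simply reaches the Metal kernel through the dispatcher (MPS dispatch key) instead of an explicit `_mps_max_pool2d` call. A minimal usage sketch, assuming an MPS-enabled macOS build:

import torch
import torch.nn.functional as F

# Requires an Apple-silicon build with MPS support.
if torch.backends.mps.is_available():
    x = torch.randn(1, 3, 32, 32, device="mps")
    # After this PR, max_pool2d routes to the MPS kernel via the dispatcher
    # rather than an explicit at::_mps_max_pool2d call.
    y = F.max_pool2d(x, kernel_size=2, stride=2)
    print(y.shape)  # torch.Size([1, 3, 16, 16])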
@@ -9,7 +9,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_mps_max_pool2d.h>
#include <ATen/ops/adaptive_avg_pool1d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_max_pool1d_native.h>
@@ -141,12 +140,6 @@ Tensor max_pool2d(
    return at::mkldnn_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
#ifdef USE_MPS
  if (self.is_mps()) {
    return at::_mps_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
#endif
#if defined(C10_MOBILE)
  if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
      dilation, ceil_mode)) {
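The hunk above drops the explicit `#ifdef USE_MPS` / `self.is_mps()` branch from the composite `max_pool2d` implementation; the MPS kernel is now selected by the dispatcher like any other backend. One hedged way to see which backends have kernels registered for an operator is the dispatcher dump below; it is a private, unstable helper and shown purely for illustration:

import torch

# Private, unstable API -- illustrative only. The dump should now show an MPS
# kernel registered for max_pool2d alongside the CompositeImplicitAutograd entry.
print(torch._C._dispatch_dump("aten::max_pool2d"))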
@@ -308,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,

} // namespace mps

Tensor _mps_max_pool2d(
Tensor mps_max_pool2d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
@@ -3567,19 +3567,14 @@
- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor

- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor

# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
# https://github.com/pytorch/pytorch/issues/77394
- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
  dispatch:
    MPS: _mps_max_pool2d
  autogen: _mps_max_pool2d.out
    CompositeImplicitAutograd: max_pool2d
    MPS: mps_max_pool2d

- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
  dispatch:
    MPS: mps_max_pool2d_backward
  autogen: mps_max_pool2d_backward.out
  autogen: max_pool2d_backward.out

- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
  dispatch:
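In native_functions.yaml, the standalone `_mps_max_pool2d` declaration (and its TODO referencing issue #77394) goes away; `max_pool2d` itself now carries a dispatch table with a `CompositeImplicitAutograd` entry plus `MPS: mps_max_pool2d`, and the backward schema is renamed from `mps_max_pool2d_backward` to `max_pool2d_backward`. As a rough analogue of what a per-backend dispatch table means, here is a hedged sketch using the public `torch.library` API with a made-up `demo::double` op (this is not the mechanism the ATen codegen itself uses):

import torch
from torch.library import Library

# Illustrative only: a tiny custom op with a per-backend dispatch table,
# mirroring how max_pool2d now lists CompositeImplicitAutograd and MPS entries.
lib = Library("demo", "DEF")
lib.define("double(Tensor self) -> Tensor")

# Default (composite) implementation used when no backend-specific kernel exists.
lib.impl("double", lambda self: self + self, "CompositeImplicitAutograd")

# A backend such as MPS could register its own kernel under its dispatch key:
# lib.impl("double", double_mps, "MPS")

print(torch.ops.demo.double(torch.ones(2)))  # tensor([2., 2.])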
@@ -377,8 +377,6 @@ aten::_mps_convolution
aten::_mps_convolution.out
aten::_mps_convolution_transpose
aten::_mps_convolution_transpose.out
aten::_mps_max_pool2d
aten::_mps_max_pool2d.out
aten::_native_batch_norm_legit.no_stats_out
aten::_native_batch_norm_legit.out
aten::_native_decoder_only_multi_head_attention
@@ -857,6 +855,8 @@ aten::max
aten::max.dim
aten::max.dim_max
aten::max.unary_out
aten::max_pool2d_backward
aten::max_pool2d_backward.out
aten::max_pool2d_with_indices
aten::max_pool2d_with_indices.out
aten::max_pool2d_with_indices_backward
@@ -930,8 +930,6 @@ aten::mps_convolution_backward
aten::mps_convolution_backward.out
aten::mps_convolution_transpose_backward
aten::mps_convolution_transpose_backward.out
aten::mps_max_pool2d_backward
aten::mps_max_pool2d_backward.out
aten::multi_margin_loss
aten::multi_margin_loss.out
aten::multi_margin_loss_backward
@@ -150,6 +150,10 @@ ALLOW_LIST = [
    ("aten::sum.SymInt", datetime.date(2022, 11, 30)),
    ("aten::mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
    ("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
    ("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
    ("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
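Because `aten::_mps_max_pool2d` and `aten::mps_max_pool2d_backward` (plus their autogen'd `.out` variants) disappear from the operator schema, they are added to the forward/backward-compatibility ALLOW_LIST with a far-future expiry date, so the schema-compatibility check does not flag the removal. A simplified sketch of how such an allowlist gate can work (hypothetical helper, not the actual check script):

import datetime
import re

# Simplified stand-in for the BC/FC allowlist; the real script is more involved.
ALLOW_LIST = [
    ("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
]

def is_allowed(op_name: str, today: datetime.date) -> bool:
    # An entry suppresses the compatibility failure until its expiry date.
    return any(re.match(pattern, op_name) and today < expiry
               for pattern, expiry in ALLOW_LIST)

print(is_allowed("aten::_mps_max_pool2d", datetime.date.today()))  # True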
@@ -2170,8 +2170,8 @@
  input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)

#mps
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
  self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
  self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)

- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
  self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
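In derivatives.yaml the autograd formula moves from the private `_mps_max_pool2d` entry to `max_pool2d`, now expressed in terms of the renamed `max_pool2d_backward`. A quick, hedged sanity check of the backward path (assumes an MPS-enabled build; falls back to CPU otherwise):

import torch
import torch.nn.functional as F

# Rough sanity check of the autograd path.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(1, 1, 4, 4, device=device, requires_grad=True)
y = F.max_pool2d(x, kernel_size=2)
y.sum().backward()
# Each 2x2 window routes its gradient to the location of its maximum.
print(x.grad)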
@@ -638,6 +638,7 @@ class NativeFunction:
        raw_dispatch = e.pop("dispatch", None)
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
        dispatch: Dict[DispatchKey, BackendMetadata] = {}
        num_dispatch_keys: int = 0
        if raw_dispatch is not None:
            assert not manual_kernel_registration, (
                "cannot specify both manual_kernel_registration and dispatch; with "
@@ -650,6 +651,8 @@ class NativeFunction:
                assert isinstance(ks, str), e
                for k in ks.split(","):
                    dispatch_key = DispatchKey.parse(k.strip())
                    num_dispatch_keys += 1

                    if ignore_keys and dispatch_key in ignore_keys:
                        continue
                    assert dispatch_key in dispatch_keys, (
@@ -677,7 +680,12 @@ class NativeFunction:
        ):
            redundant_composite_implicit_autograd = True

        assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
        # We count the number of dispatch keys which have not been ignored to prevent a dispatch table
        # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
        # from being treated as redundant.
        assert not (
            num_dispatch_keys == 1 and redundant_composite_implicit_autograd
        ), (
            "unnecessary dispatch table for this function; just delete the dispatch "
            "key entirely"
        )
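This is the codegen change the commit message refers to: `torchgen/model.py` now tracks `num_dispatch_keys` alongside the `dispatch` dict as keys are parsed, and the redundancy assertion checks `num_dispatch_keys == 1` instead of `len(dispatch) == 1`. That way, a dispatch table whose backend keys were filtered out via `ignore_keys` (leaving only the CompositeImplicitAutograd entry, as can now happen for `max_pool2d`'s MPS key) is no longer misreported as redundant. A standalone sketch of the logic (not the actual torchgen code; kernel names are stand-ins):

# Minimal sketch, assuming plain string keys instead of DispatchKey objects.
def check_dispatch(raw_dispatch: dict, ignore_keys: set) -> dict:
    dispatch = {}
    num_dispatch_keys = 0
    for key, kernel in raw_dispatch.items():
        num_dispatch_keys += 1          # counted even if ignored below
        if key in ignore_keys:
            continue
        dispatch[key] = kernel

    redundant = dispatch.get("CompositeImplicitAutograd") == "default_kernel"
    # Old check: len(dispatch) == 1 -- would fire when a backend key was merely ignored.
    # New check: num_dispatch_keys == 1 -- only fires if the table really held a
    # single, redundant CompositeImplicitAutograd entry.
    assert not (num_dispatch_keys == 1 and redundant), (
        "unnecessary dispatch table for this function; just delete the dispatch key entirely"
    )
    return dispatch

# Example: MPS key ignored by this build, CompositeImplicitAutograd kept -- no assertion error.
print(check_dispatch(
    {"CompositeImplicitAutograd": "default_kernel", "MPS": "mps_kernel"},
    ignore_keys={"MPS"},
))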
@@ -687,6 +695,7 @@ class NativeFunction:
            structured_delegate
            or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
            or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
            or num_dispatch_keys != 1
        ), (
            f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
            f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "