[MPS] Move max_pool2d to mps dispatch key (#90772)

Related issue: #77394

This PR also modifies some assertions in the codegen, an explanatory comment for it has been added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90772
Approved by: https://github.com/albanD
This commit is contained in:
Li-Huai (Allan) Lin 2023-02-16 01:13:08 +00:00 committed by PyTorch MergeBot
parent 250c054bdd
commit e8dc34eaeb
7 changed files with 23 additions and 24 deletions

View File

@ -9,7 +9,6 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_mps_max_pool2d.h>
#include <ATen/ops/adaptive_avg_pool1d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_max_pool1d_native.h>
@ -141,12 +140,6 @@ Tensor max_pool2d(
return at::mkldnn_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#ifdef USE_MPS
if (self.is_mps()) {
return at::_mps_max_pool2d(
self, kernel_size, stride, padding, dilation, ceil_mode);
}
#endif
#if defined(C10_MOBILE)
if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
dilation, ceil_mode)) {

View File

@ -308,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
} // namespace mps
Tensor _mps_max_pool2d(
Tensor mps_max_pool2d(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,

View File

@ -3567,19 +3567,14 @@
- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
# https://github.com/pytorch/pytorch/issues/77394
- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:
MPS: _mps_max_pool2d
autogen: _mps_max_pool2d.out
CompositeImplicitAutograd: max_pool2d
MPS: mps_max_pool2d
- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:
MPS: mps_max_pool2d_backward
autogen: mps_max_pool2d_backward.out
autogen: max_pool2d_backward.out
- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
dispatch:

View File

@ -377,8 +377,6 @@ aten::_mps_convolution
aten::_mps_convolution.out
aten::_mps_convolution_transpose
aten::_mps_convolution_transpose.out
aten::_mps_max_pool2d
aten::_mps_max_pool2d.out
aten::_native_batch_norm_legit.no_stats_out
aten::_native_batch_norm_legit.out
aten::_native_decoder_only_multi_head_attention
@ -857,6 +855,8 @@ aten::max
aten::max.dim
aten::max.dim_max
aten::max.unary_out
aten::max_pool2d_backward
aten::max_pool2d_backward.out
aten::max_pool2d_with_indices
aten::max_pool2d_with_indices.out
aten::max_pool2d_with_indices_backward
@ -930,8 +930,6 @@ aten::mps_convolution_backward
aten::mps_convolution_backward.out
aten::mps_convolution_transpose_backward
aten::mps_convolution_transpose_backward.out
aten::mps_max_pool2d_backward
aten::mps_max_pool2d_backward.out
aten::multi_margin_loss
aten::multi_margin_loss.out
aten::multi_margin_loss_backward

View File

@ -150,6 +150,10 @@ ALLOW_LIST = [
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
("aten::mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),

View File

@ -2170,8 +2170,8 @@
input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
#mps
- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

View File

@ -638,6 +638,7 @@ class NativeFunction:
raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: Dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0
if raw_dispatch is not None:
assert not manual_kernel_registration, (
"cannot specify both manual_kernel_registration and dispatch; with "
@ -650,6 +651,8 @@ class NativeFunction:
assert isinstance(ks, str), e
for k in ks.split(","):
dispatch_key = DispatchKey.parse(k.strip())
num_dispatch_keys += 1
if ignore_keys and dispatch_key in ignore_keys:
continue
assert dispatch_key in dispatch_keys, (
@ -677,7 +680,12 @@ class NativeFunction:
):
redundant_composite_implicit_autograd = True
assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
# We count the number of dispatch keys which have not been ignored to prevent a dispatch table
# in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
# from being treated as redundant.
assert not (
num_dispatch_keys == 1 and redundant_composite_implicit_autograd
), (
"unnecessary dispatch table for this function; just delete the dispatch "
"key entirely"
)
@ -687,6 +695,7 @@ class NativeFunction:
structured_delegate
or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
or num_dispatch_keys != 1
), (
f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "