Allow registering decomps for HigherOrderOp; add decomp for out_dtype (#108080)

We allow registering decomps for HigherOrderOps via the existing decomp
mechanisms:
- I refactored those APIs to accept torch._ops.OperatorBase, the common base
  class of torch._ops.HigherOrderOperator and torch._ops.OpOverload (see the
  sketch below).
- HigherOrderOps must directly call maybe_handle_decomp in their
  ProxyTorchDispatchMode handling in order to resolve decompositions. We can
  lift this requirement in the future by moving the decomposition logic into
  ProxyTorchDispatchMode itself.
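To make the first point concrete, here is a minimal sketch (not part of this
PR's diff; `my_registry` and `my_out_dtype_decomp` are made-up names) of
registering and looking up a decomposition keyed by a HigherOrderOperator,
assuming the register_decomposition/get_decompositions signatures in
torch._decomp as of this change:

    # Hedged sketch: HigherOrderOperators are now valid keys in decomp registries.
    import torch
    from torch._decomp import register_decomposition, get_decompositions
    from torch._higher_order_ops.out_dtype import out_dtype, out_dtype_dense

    my_registry = {}  # hypothetical private registry; leaves the global table alone

    # out_dtype is a torch._ops.HigherOrderOperator. HigherOrderOps have no
    # overloads, so _add_op_to_registry stores the op itself as the key.
    @register_decomposition(out_dtype, my_registry)
    def my_out_dtype_decomp(op, output_dtype, *args):
        return out_dtype_dense(op, output_dtype, *args)

    assert out_dtype in my_registry

    # get_decompositions now accepts OperatorBase alongside OpOverloadPacket,
    # so a HigherOrderOp can be requested like any aten overload:
    table = get_decompositions([out_dtype])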

Next, we add an inductor decomp for out_dtype. This decomp is deliberately not
made generally available: other backends (e.g., ExecuTorch) want out_dtype
preserved in the graph they receive, so the decomp is only pulled into
inductor's decomposition table.
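A sketch of the resulting behavior (mirroring the new tests below; the shapes,
int8 inputs, and CUDA device are assumptions taken from the _int_mm fast path):

    # Hedged sketch: out_dtype is preserved by default and only decomposed when
    # inductor's decomposition table is used.
    import torch
    import torch._inductor.decomposition
    from torch._higher_order_ops.out_dtype import out_dtype
    from torch.fx.experimental.proxy_tensor import make_fx

    def func(x, w):
        return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)

    x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
    w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")

    # Default tracing keeps the higher-order op in the graph (what export-style
    # backends such as ExecuTorch consume):
    gm = make_fx(func)(x, w)      # graph calls torch.ops.higher_order.out_dtype
    # With inductor's table, the same program lowers to an int32 matmul:
    decomp_table = torch._inductor.decomposition.select_decomp_table()
    gm2 = make_fx(func, decomp_table)(x, w)  # graph calls torch.ops.aten._int_mm.default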

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108080
Approved by: https://github.com/HDCharles
Author: rzou
Date: 2023-08-30 07:02:12 -07:00 (committed by PyTorch MergeBot)
Commit: 0e4752bafc (parent: 95e3126370)
10 changed files with 150 additions and 32 deletions

View File

@@ -38,13 +38,19 @@ aten = torch.ops.aten

 # TODO: this isn't going to work with non-aten namespaces
-def overload_to_aten_name(overload):
-    return overload._schema.name.split("::")[1]
+def overload_to_aten_name(op):
+    return op._schema.name.split("::")[1]


 # All operators that can have decomp tests
-decomposition_names = {overload_to_aten_name(k) for k in decomposition_table}
-core_decomposition_names = {overload_to_aten_name(k) for k in core_aten_decompositions()}
+decomposition_names = {
+    overload_to_aten_name(k) for k in decomposition_table
+    if isinstance(k, torch._ops.OpOverload)
+}
+core_decomposition_names = {
+    overload_to_aten_name(k) for k in core_aten_decompositions()
+    if isinstance(k, torch._ops.OpOverload)
+}
 _decomp_test_ops = [
     op
     for op in op_db
@@ -908,7 +914,8 @@ class HasDecompTest(TestCase):
         # Some decompositions are registered for CompositeImplicitAutograd
         # operators, which never appear in AOTAutograd's graph so are never used.
         useful_decomps = {op for op in decomposition_table.keys()
-                          if self._can_appear_in_trace(op)}
+                          if isinstance(op, torch._ops.OpOverload) and
+                          self._can_appear_in_trace(op)}
         core_decomps = torch._decomp.core_aten_decompositions().keys()
         core_aten_ops = useful_decomps - core_decomps
         self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))

View File

@@ -3,11 +3,17 @@ import unittest
 import torch
 import torch._dynamo
 import torch._inductor
+import torch._inductor.decomposition
 import torch._export
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    run_tests, TestCase, IS_WINDOWS, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, TEST_CUDA
+)
 from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
 from torch.testing import FileCheck
+from torch.testing._internal.common_cuda import SM80OrLater, _get_torch_cuda_version


 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
@@ -152,6 +158,58 @@ class TestOutDtypeOp(TestCase):
         loss = out - torch.ones(out.shape)
         loss.backward()

+    @unittest.skipIf(IS_WINDOWS, "_int_mm unavailable")
+    @unittest.skipIf(TEST_WITH_ROCM, "_int_mm unavailable")
+    @unittest.skipIf(not SM80OrLater, "_int_mm unavailable")
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "_int_mm unavailable")
+    @unittest.skipIf(not TEST_CUDA, "_int_mm unavailable")
+    @skipIfNoDynamoSupport
+    def test_out_dtype_inductor_decomp(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        ref = torch._int_mm(x, w)
+        test_out = func(x, w)
+        func_comp = torch.compile(func, fullgraph=True, mode="max-autotune")
+        test_out_c = func_comp(x, w)
+        self.assertTrue(torch.allclose(ref, test_out))
+        self.assertTrue(torch.allclose(ref, test_out_c))
+
+    @unittest.skipIf(not TEST_CUDA, "cuda only")
+    def test_out_dtype_inductor_decomp_trace(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        # Check that make_fx with inductor decomps produces _int_mm
+        decomp_table = torch._inductor.decomposition.select_decomp_table()
+        gm = make_fx(func, decomp_table, tracing_mode="symbolic")(x, w)
+        self.assertExpectedInline(gm.code.strip(), """\
+def forward(self, x_1, w_1):
+    _int_mm = torch.ops.aten._int_mm.default(x_1, w_1); x_1 = w_1 = None
+    return _int_mm""")
+
+    @unittest.skipIf(not TEST_CUDA, "cuda only")
+    def test_out_dtype_int_mm_default_trace(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        # By default, out_dtype is preserved in the trace
+        gm = make_fx(func, tracing_mode="symbolic")(x, w)
+        self.assertExpectedInline(gm.code.strip(), """\
+def forward(self, x_1, w_1):
+    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, x_1, w_1); x_1 = w_1 = None
+    return out_dtype""")
+
     def test_out_dtype_wrong_output(self) -> None:
         def multiple_out(x):
             return out_dtype(
View File

@@ -2,11 +2,11 @@ import inspect
 from collections import defaultdict
 from functools import wraps
 from itertools import chain
-from typing import Callable, Dict, Sequence, Union
+from typing import Callable, Dict, List, Sequence, Union

 import torch
 import torch.library
-from torch._ops import OpOverload, OpOverloadPacket
+from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
 from torch.utils._pytree import tree_map

 __all__ = [
@@ -21,7 +21,9 @@ __all__ = [

 # TODO: relax key type here; torch registrations should be possible to; but
 # right now this type is accurate
-global_decomposition_table: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
+global_decomposition_table: Dict[
+    str, Dict[torch._ops.OperatorBase, Callable]
+] = defaultdict(dict)

 decomposition_table = global_decomposition_table["post_autograd"]
 pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
@@ -35,8 +37,12 @@ def _add_op_to_registry(registry, op, fn):
     If op is OpOverload, it will be added to the registry directly.
     If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
     """
-    overloads = []
-    if isinstance(op, OpOverload):
+    overloads: List[Union[torch._ops.OperatorBase]] = []
+    if isinstance(op, HigherOrderOperator):
+        # There's no concept of overloads for HigherOrderOperator
+        registry[op] = fn
+        return
+    elif isinstance(op, OpOverload):
         overloads.append(op)
     else:
         assert isinstance(op, OpOverloadPacket)
@@ -46,7 +52,6 @@ def _add_op_to_registry(registry, op, fn):
     for op_overload in overloads:
         if op_overload in registry:
             raise RuntimeError(f"duplicate registrations for {op_overload}")
-
         # TorchScript dumps a bunch of extra nonsense overloads
         # which don't have corresponding dispatcher entries, we need
         # to filter those out, e.g aten.add.float_int
@@ -143,9 +148,9 @@ def register_decomposition(


 def get_decompositions(
-    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
+    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
     type: str = "post_autograd",
-) -> Dict[OpOverload, Callable]:
+) -> Dict[torch._ops.OperatorBase, Callable]:
     """
     Retrieve a dictionary of decompositions corresponding to the list of
    operator overloads and overload packets passed as input. Overload
@@ -162,13 +167,14 @@ def get_decompositions(
     registry = global_decomposition_table[type]
     packets_to_overloads = defaultdict(list)
     for opo in registry:
-        packets_to_overloads[opo.overloadpacket].append(opo)
-    decompositions = {}
+        if isinstance(opo, (OpOverload, OpOverloadPacket)):
+            packets_to_overloads[opo.overloadpacket].append(opo)
+    decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
     for op in aten_ops:
         if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
             for op_overload in packets_to_overloads[op]:
                 decompositions[op_overload] = registry[op_overload]
-        elif isinstance(op, OpOverload) and op in registry:
+        elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
             decompositions[op] = registry[op]
     return decompositions
@@ -202,7 +208,7 @@ import torch._refs
 # list was copied from torch/_inductor/decomposition.py
 # excluding decompositions that results in prim ops
 # Resulting opset of decomposition is core aten ops
-def core_aten_decompositions() -> Dict[OpOverload, Callable]:
+def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
     aten = torch.ops.aten
     return get_decompositions(
         [

View File

@@ -12,6 +12,7 @@ import torch._prims_common as utils
 import torch.nn.functional as F
 from torch import sym_float, sym_int, Tensor
 from torch._decomp import register_decomposition
+from torch._higher_order_ops.out_dtype import out_dtype
 from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
 from torch._prims_common.wrappers import (
     _maybe_convert_to_dtype,
@@ -3828,6 +3829,13 @@ def arange_start(
     )


+@register_decomposition(out_dtype)
+def out_dtype_decomp(*args, **kwargs):
+    from torch._higher_order_ops.out_dtype import out_dtype_dense
+
+    return out_dtype_dense(*args, **kwargs)
+
+
 @register_decomposition(aten.multi_margin_loss)
 @aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
 @out_wrapper()

View File

@@ -6,7 +6,7 @@ import torch._decomp
 from torch import Tensor

 decomposition_table = torch._decomp.decomposition_table
-decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
+decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
 register_decomposition = torch._decomp.register_decomposition
 aten = torch.ops.aten

View File

@@ -5,6 +5,7 @@ from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     ProxyTorchDispatchMode,
     track_tensor_tree,
+    maybe_handle_decomp,
 )
 from torch.utils._python_dispatch import (
     _get_current_dispatch_mode,
@@ -90,6 +91,13 @@ out_dtype.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]

 def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
+    # NB: Long-term we should put the decomposition logic into
+    # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
+    # in all HigherOrderOp proxy implementations.
+    r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {})
+    if r is not NotImplemented:
+        return r
+
     with disable_proxy_modes_tracing():
         # This is a simplified implementation of this operator just for tracing.
         # Actual implementation may also first promote the arguments
@@ -115,6 +123,24 @@ def out_dtype_dense(
     output_dtype: torch.dtype,
     *args
 ):
+    if is_int_mm(op, output_dtype, args):
+        return torch._int_mm(*args)
+    return out_dtype_fallback(op, output_dtype, *args)
+
+
+def is_int_mm(op, output_dtype, args):
+    return (
+        op == torch.ops.aten.mm.default and
+        output_dtype == torch.int32 and
+        len(args) == 2 and
+        args[0].dtype == torch.int8 and
+        args[1].dtype == torch.int8 and
+        args[0].is_cuda and
+        args[1].is_cuda
+    )
+
+
+def out_dtype_fallback(op, output_dtype, *args):
     flat_inputs = pytree.tree_flatten(args)[0] + [torch.ones(1, dtype=output_dtype)]
     promote_dtype: torch.dtype = elementwise_dtypes(
         *flat_inputs,

View File

@@ -17,6 +17,7 @@ from torch._decomp.decompositions import (
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
+from torch._higher_order_ops.out_dtype import out_dtype

 from . import config
@@ -54,6 +55,7 @@ inductor_decompositions = get_decompositions(
         aten.sqrt_,
         aten.std,
         aten.std_mean,
+        out_dtype,
         aten._to_copy,
         aten.tril_indices,
         aten.triu_indices,

View File

@@ -5668,6 +5668,12 @@ def activate_meta():
                 activate_meta_table[opo] = registry[opo]

     for op_overload, fn in activate_meta_table.items():
+        # Don't register meta for HigherOrderOp's decomp.
+        # We can reconsider this in the future, but in general,
+        # the way you do a meta for a HigherOrderOp is different from
+        # OpOverload.
+        if isinstance(op_overload, torch._ops.HigherOrderOperator):
+            continue
         assert isinstance(op_overload, OpOverload)
         op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

View File

@@ -42,7 +42,7 @@ prim = torch.ops.prim
 log = logging.getLogger(__name__)
 not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")

-CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
+CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OperatorBase, Callable] = {}

 CONSTANT_NUMEL_LIMIT = 1
@@ -252,11 +252,9 @@ def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
         not_implemented_log.debug("ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", unrecognized_types)
         return NotImplemented

-    if func in CURRENT_DECOMPOSITION_TABLE:
-        with proxy_mode:
-            r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
-            if r is not NotImplemented:
-                return r
+    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
+    if r is not NotImplemented:
+        return r

     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
     if not pre_dispatch:
@@ -872,6 +870,13 @@ def disable_proxy_modes_tracing(enable_current=False):
         torch._C._set_dispatch_mode(maybe_old)


+def maybe_handle_decomp(proxy_mode, op, args, kwargs):
+    if op in CURRENT_DECOMPOSITION_TABLE:
+        with proxy_mode:
+            return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
+    return NotImplemented
+
+
 def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
     """A helper function used to get the GraphModule for the given func.

View File

@@ -18,17 +18,17 @@ from torch.onnx._internal.fx import registration

 @_beartype.beartype
 def _create_onnx_supports_op_overload_table(
     registry,
-) -> Set[Union[torch._ops.OpOverload, Callable]]:
+) -> Set[Union[torch._ops.OperatorBase, Callable]]:
     """
-    Creates a set of OpOverload and Callable objects that represent ONNX-supported PyTorch operations.
+    Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations.

     Args:
         registry (OnnxRegistry): The ONNX registry for PyTorch.

     Returns:
-        A collection of OpOverload and Callable objects representing ONNX-supported PyTorch operations.
+        A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations.
     """
-    table: Set[Union[torch._ops.OpOverload, Callable]] = set()
+    table: Set[Union[torch._ops.OperatorBase, Callable]] = set()

     # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`,
     # but retrievable via explicit lookup.
@@ -80,7 +80,7 @@ def _create_onnx_supports_op_overload_table(
 @_beartype.beartype
 def create_onnx_friendly_decomposition_table(
     registry,
-) -> Dict[torch._ops.OpOverload, Callable]:
+) -> Dict[torch._ops.OperatorBase, Callable]:
     """
     This function creates a dictionary of op overloads and their decomposition functions
     for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function,
@@ -91,10 +91,10 @@ def create_onnx_friendly_decomposition_table(
         registry (torch.onnx.OnnxRegistry): The ONNX registry for PyTorch.

     Returns:
-        Dict[torch._ops.OpOverload, Callable]: A dictionary that maps op overloads to their corresponding
+        Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
         decomposition functions.
     """
-    decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
+    decomposition_table: Dict[torch._ops.OperatorBase, Callable] = {}

     # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g.,
     # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add".
     _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry)