Allow registering decomps for HigherOrderOp; add decomp for out_dtype (#108080)
We allow registering decomps for HigherOrderOp via the existing decomp mechanisms:

- I refactored those APIs to accept torch._ops.OperatorBase, which is the base class for torch.ops.HigherOrderOperator and torch.ops.OpOverload.
- HigherOrderOps must directly call maybe_handle_decomp in their ProxyTorchDispatchMode handling in order to resolve decompositions. We can change this in the future so that they do not need to do this.

Next, we add an inductor decomp for out_dtype. This decomp shouldn't be generally available because we want to preserve out_dtype in the graph handed to other backends (e.g. executorch).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108080
Approved by: https://github.com/HDCharles
parent 95e3126370
commit 0e4752bafc
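As a rough sketch of what this registration path looks like in use (not part of the commit itself; `my_out_dtype_decomp` is a hypothetical name, and `out_dtype` already gets a real decomp in the diff below):

```python
# Sketch only, assuming this commit: a HigherOrderOperator is registered
# through the same decorator that OpOverloads use.
import torch
from torch._decomp import get_decompositions, register_decomposition
from torch._higher_order_ops.out_dtype import out_dtype, out_dtype_dense


@register_decomposition(out_dtype)  # key is a HigherOrderOperator, not an OpOverload
def my_out_dtype_decomp(*args, **kwargs):  # hypothetical name
    return out_dtype_dense(*args, **kwargs)


# HigherOrderOps can now be requested alongside OpOverloads and packets:
table = get_decompositions([out_dtype])
assert out_dtype in table
```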
```diff
@@ -38,13 +38,19 @@ aten = torch.ops.aten
 # TODO: this isn't going to work with non-aten namespaces
-def overload_to_aten_name(overload):
-    return overload._schema.name.split("::")[1]
+def overload_to_aten_name(op):
+    return op._schema.name.split("::")[1]


 # All operators that can have decomp tests
-decomposition_names = {overload_to_aten_name(k) for k in decomposition_table}
-core_decomposition_names = {overload_to_aten_name(k) for k in core_aten_decompositions()}
+decomposition_names = {
+    overload_to_aten_name(k) for k in decomposition_table
+    if isinstance(k, torch._ops.OpOverload)
+}
+core_decomposition_names = {
+    overload_to_aten_name(k) for k in core_aten_decompositions()
+    if isinstance(k, torch._ops.OpOverload)
+}
 _decomp_test_ops = [
     op
     for op in op_db
```
```diff
@@ -908,7 +914,8 @@ class HasDecompTest(TestCase):
         # Some decompositions are registered for CompositeImplicitAutograd
         # operators, which never appear in AOTAutograd's graph so are never used.
         useful_decomps = {op for op in decomposition_table.keys()
-                          if self._can_appear_in_trace(op)}
+                          if isinstance(op, torch._ops.OpOverload) and
+                          self._can_appear_in_trace(op)}
         core_decomps = torch._decomp.core_aten_decompositions().keys()
         core_aten_ops = useful_decomps - core_decomps
         self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
```
```diff
@@ -3,11 +3,17 @@ import unittest

 import torch
 import torch._dynamo
+import torch._inductor
+import torch._inductor.decomposition
 import torch._export
 from torch._higher_order_ops.out_dtype import out_dtype
+from torch.fx.experimental.proxy_tensor import make_fx
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    run_tests, TestCase, IS_WINDOWS, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, TEST_CUDA
+)
 from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
 from torch.testing import FileCheck
+from torch.testing._internal.common_cuda import SM80OrLater, _get_torch_cuda_version


 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
```
```diff
@@ -152,6 +158,58 @@ class TestOutDtypeOp(TestCase):
         loss = out - torch.ones(out.shape)
         loss.backward()

+    @unittest.skipIf(IS_WINDOWS, "_int_mm unavailable")
+    @unittest.skipIf(TEST_WITH_ROCM, "_int_mm unavailable")
+    @unittest.skipIf(not SM80OrLater, "_int_mm unavailable")
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "_int_mm unavailable")
+    @unittest.skipIf(not TEST_CUDA, "_int_mm unavailable")
+    @skipIfNoDynamoSupport
+    def test_out_dtype_inductor_decomp(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        ref = torch._int_mm(x, w)
+        test_out = func(x, w)
+        func_comp = torch.compile(func, fullgraph=True, mode="max-autotune")
+        test_out_c = func_comp(x, w)
+        self.assertTrue(torch.allclose(ref, test_out))
+        self.assertTrue(torch.allclose(ref, test_out_c))
+
+    @unittest.skipIf(not TEST_CUDA, "cuda only")
+    def test_out_dtype_inductor_decomp_trace(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        # Check that make_fx with inductor decomps produces _int_mm
+        decomp_table = torch._inductor.decomposition.select_decomp_table()
+        gm = make_fx(func, decomp_table, tracing_mode="symbolic")(x, w)
+        self.assertExpectedInline(gm.code.strip(), """\
+def forward(self, x_1, w_1):
+    _int_mm = torch.ops.aten._int_mm.default(x_1, w_1); x_1 = w_1 = None
+    return _int_mm""")
+
+    @unittest.skipIf(not TEST_CUDA, "cuda only")
+    def test_out_dtype_int_mm_default_trace(self) -> None:
+        def func(x, w):
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
+
+        w = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+        x = torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda")
+
+        # By default, out_dtype is preserved in the trace
+        gm = make_fx(func, tracing_mode="symbolic")(x, w)
+        self.assertExpectedInline(gm.code.strip(), """\
+def forward(self, x_1, w_1):
+    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, x_1, w_1); x_1 = w_1 = None
+    return out_dtype""")
+
     def test_out_dtype_wrong_output(self) -> None:
         def multiple_out(x):
             return out_dtype(
```
```diff
@@ -2,11 +2,11 @@ import inspect
 from collections import defaultdict
 from functools import wraps
 from itertools import chain
-from typing import Callable, Dict, Sequence, Union
+from typing import Callable, Dict, List, Sequence, Union

 import torch
 import torch.library
-from torch._ops import OpOverload, OpOverloadPacket
+from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
 from torch.utils._pytree import tree_map

 __all__ = [
```
```diff
@@ -21,7 +21,9 @@ __all__ = [

 # TODO: relax key type here; torch registrations should be possible to; but
 # right now this type is accurate
-global_decomposition_table: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
+global_decomposition_table: Dict[
+    str, Dict[torch._ops.OperatorBase, Callable]
+] = defaultdict(dict)

 decomposition_table = global_decomposition_table["post_autograd"]
 pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
```
```diff
@@ -35,8 +37,12 @@ def _add_op_to_registry(registry, op, fn):
     If op is OpOverload, it will be added to the registry directly.
     If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
     """
-    overloads = []
-    if isinstance(op, OpOverload):
+    overloads: List[Union[torch._ops.OperatorBase]] = []
+    if isinstance(op, HigherOrderOperator):
+        # There's no concept of overloads for HigherOrderOperator
+        registry[op] = fn
+        return
+    elif isinstance(op, OpOverload):
         overloads.append(op)
     else:
         assert isinstance(op, OpOverloadPacket)
```
```diff
@@ -46,7 +52,6 @@ def _add_op_to_registry(registry, op, fn):
     for op_overload in overloads:
         if op_overload in registry:
             raise RuntimeError(f"duplicate registrations for {op_overload}")
-
         # TorchScript dumps a bunch of extra nonsense overloads
         # which don't have corresponding dispatcher entries, we need
         # to filter those out, e.g aten.add.float_int
```
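A quick illustration of the three branches above; `registry` and `fn` are stand-ins, and `_add_op_to_registry` is a private helper, so treat this as a sketch rather than supported API:

```python
# Sketch: how each op kind lands in the registry dict after this change.
import torch
from torch._decomp import _add_op_to_registry
from torch._higher_order_ops.out_dtype import out_dtype


def fn(*args, **kwargs):  # stand-in decomposition
    pass


registry = {}
_add_op_to_registry(registry, out_dtype, fn)                  # HigherOrderOperator: stored directly
_add_op_to_registry(registry, torch.ops.aten.mm.default, fn)  # OpOverload: stored as-is
_add_op_to_registry(registry, torch.ops.aten.add, fn)         # OpOverloadPacket: each valid overload stored
```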
```diff
@@ -143,9 +148,9 @@ def register_decomposition(


 def get_decompositions(
-    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
+    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
     type: str = "post_autograd",
-) -> Dict[OpOverload, Callable]:
+) -> Dict[torch._ops.OperatorBase, Callable]:
     """
     Retrieve a dictionary of decompositions corresponding to the list of
     operator overloads and overload packets passed as input. Overload
```
```diff
@@ -162,13 +167,14 @@ def get_decompositions(
     registry = global_decomposition_table[type]
     packets_to_overloads = defaultdict(list)
     for opo in registry:
-        packets_to_overloads[opo.overloadpacket].append(opo)
-    decompositions = {}
+        if isinstance(opo, (OpOverload, OpOverloadPacket)):
+            packets_to_overloads[opo.overloadpacket].append(opo)
+    decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
     for op in aten_ops:
         if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
             for op_overload in packets_to_overloads[op]:
                 decompositions[op_overload] = registry[op_overload]
-        elif isinstance(op, OpOverload) and op in registry:
+        elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
             decompositions[op] = registry[op]
     return decompositions
```
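A hedged usage sketch of the widened `get_decompositions`; the aten ops chosen here are illustrative and assume their decomps are registered in the build:

```python
# Sketch: one query can now mix packets, overloads, and HigherOrderOps.
import torch
from torch._decomp import get_decompositions
from torch._higher_order_ops.out_dtype import out_dtype

decomps = get_decompositions([
    torch.ops.aten.addmm,       # OpOverloadPacket: expanded to its registered overloads
    torch.ops.aten.mv.default,  # OpOverload: direct registry lookup
    out_dtype,                  # HigherOrderOperator: matched by the new OperatorBase branch
])
assert out_dtype in decomps
```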
```diff
@@ -202,7 +208,7 @@ import torch._refs

 # list was copied from torch/_inductor/decomposition.py
 # excluding decompositions that results in prim ops
 # Resulting opset of decomposition is core aten ops
-def core_aten_decompositions() -> Dict[OpOverload, Callable]:
+def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
     aten = torch.ops.aten
     return get_decompositions(
         [
```
```diff
@@ -12,6 +12,7 @@ import torch._prims_common as utils
 import torch.nn.functional as F
 from torch import sym_float, sym_int, Tensor
 from torch._decomp import register_decomposition
+from torch._higher_order_ops.out_dtype import out_dtype
 from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
 from torch._prims_common.wrappers import (
     _maybe_convert_to_dtype,
```
```diff
@@ -3828,6 +3829,13 @@ def arange_start(
     )


+@register_decomposition(out_dtype)
+def out_dtype_decomp(*args, **kwargs):
+    from torch._higher_order_ops.out_dtype import out_dtype_dense
+
+    return out_dtype_dense(*args, **kwargs)
+
+
 @register_decomposition(aten.multi_margin_loss)
 @aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
 @out_wrapper()
```
```diff
@@ -6,7 +6,7 @@ import torch._decomp
 from torch import Tensor

 decomposition_table = torch._decomp.decomposition_table
-decomposition_table_for_jvp: Dict[torch._ops.OpOverload, Callable] = {}
+decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
 register_decomposition = torch._decomp.register_decomposition
 aten = torch.ops.aten
```
```diff
@@ -5,6 +5,7 @@ from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     ProxyTorchDispatchMode,
     track_tensor_tree,
+    maybe_handle_decomp,
 )
 from torch.utils._python_dispatch import (
     _get_current_dispatch_mode,
```
```diff
@@ -90,6 +91,13 @@ out_dtype.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]


 def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
+    # NB: Long-term we should put the decomposition logic into
+    # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
+    # in all HigherOrderOp proxy implementations.
+    r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {})
+    if r is not NotImplemented:
+        return r
+
     with disable_proxy_modes_tracing():
         # This is a simplified implementation of this operator just for tracing.
         # Actual implementation may also first promote the arguments
```
```diff
@@ -115,6 +123,24 @@ def out_dtype_dense(
     output_dtype: torch.dtype,
     *args
 ):
+    if is_int_mm(op, output_dtype, args):
+        return torch._int_mm(*args)
+    return out_dtype_fallback(op, output_dtype, *args)
+
+
+def is_int_mm(op, output_dtype, args):
+    return (
+        op == torch.ops.aten.mm.default and
+        output_dtype == torch.int32 and
+        len(args) == 2 and
+        args[0].dtype == torch.int8 and
+        args[1].dtype == torch.int8 and
+        args[0].is_cuda and
+        args[1].is_cuda
+    )
+
+
+def out_dtype_fallback(op, output_dtype, *args):
     flat_inputs = pytree.tree_flatten(args)[0] + [torch.ones(1, dtype=output_dtype)]
     promote_dtype: torch.dtype = elementwise_dtypes(
         *flat_inputs,
```
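A small eager-mode sketch of the semantics above, not taken from the PR: CPU tensors are used so the `is_int_mm` guard fails and the generic fallback runs instead of `torch._int_mm`:

```python
# Sketch: on CPU is_int_mm is False (inputs are not CUDA), so out_dtype_fallback
# computes mm in a promoted dtype and casts the result to int32.
import torch
from torch._higher_order_ops.out_dtype import out_dtype

x = torch.randint(-128, 127, (8, 8), dtype=torch.int8)
w = torch.randint(-128, 127, (8, 8), dtype=torch.int8)
out = out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)
assert out.dtype == torch.int32  # accumulation did not overflow int8
```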
```diff
@@ -17,6 +17,7 @@ from torch._decomp.decompositions import (
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
+from torch._higher_order_ops.out_dtype import out_dtype

 from . import config
```
```diff
@@ -54,6 +55,7 @@ inductor_decompositions = get_decompositions(
         aten.sqrt_,
         aten.std,
         aten.std_mean,
+        out_dtype,
         aten._to_copy,
         aten.tril_indices,
         aten.triu_indices,
```
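One hedged way to confirm this registration took effect (`select_decomp_table` is the same entry point the new test uses):

```python
# Sketch: the inductor decomp table should now map the out_dtype
# HigherOrderOperator to its dense implementation.
import torch._inductor.decomposition as inductor_decomposition
from torch._higher_order_ops.out_dtype import out_dtype

table = inductor_decomposition.select_decomp_table()
assert out_dtype in table
```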
```diff
@@ -5668,6 +5668,12 @@ def activate_meta():
             activate_meta_table[opo] = registry[opo]

     for op_overload, fn in activate_meta_table.items():
+        # Don't register meta for HigherOrderOp's decomp.
+        # We can reconsider this in the future, but in general,
+        # the way you do a meta for a HigherOrderOp is different from
+        # OpOverload.
+        if isinstance(op_overload, torch._ops.HigherOrderOperator):
+            continue
         assert isinstance(op_overload, OpOverload)

         op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
```
```diff
@@ -42,7 +42,7 @@ prim = torch.ops.prim
 log = logging.getLogger(__name__)
 not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")

-CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
+CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OperatorBase, Callable] = {}

 CONSTANT_NUMEL_LIMIT = 1
```
```diff
@@ -252,11 +252,9 @@ def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
         not_implemented_log.debug("ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", unrecognized_types)
         return NotImplemented

-    if func in CURRENT_DECOMPOSITION_TABLE:
-        with proxy_mode:
-            r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
-            if r is not NotImplemented:
-                return r
+    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
+    if r is not NotImplemented:
+        return r

     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
     if not pre_dispatch:
```
```diff
@@ -872,6 +870,13 @@ def disable_proxy_modes_tracing(enable_current=False):
         torch._C._set_dispatch_mode(maybe_old)


+def maybe_handle_decomp(proxy_mode, op, args, kwargs):
+    if op in CURRENT_DECOMPOSITION_TABLE:
+        with proxy_mode:
+            return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
+    return NotImplemented
+
+
 def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
     """A helper function used to get the GraphModule for the given func.
```
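A minimal trace-time sketch, assuming this commit: once `out_dtype`'s decomp is in the table, `make_fx` resolves the HigherOrderOp through `maybe_handle_decomp` and the traced graph contains only aten ops:

```python
import torch
from torch._decomp import get_decompositions
from torch._higher_order_ops.out_dtype import out_dtype
from torch.fx.experimental.proxy_tensor import make_fx


def f(x, w):
    return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)


x = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
gm = make_fx(f, get_decompositions([out_dtype]))(x, w)
# The higher_order.out_dtype call is gone; its dense implementation was traced.
assert "higher_order.out_dtype" not in gm.code
```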
```diff
@@ -18,17 +18,17 @@ from torch.onnx._internal.fx import registration
 @_beartype.beartype
 def _create_onnx_supports_op_overload_table(
     registry,
-) -> Set[Union[torch._ops.OpOverload, Callable]]:
+) -> Set[Union[torch._ops.OperatorBase, Callable]]:
     """
-    Creates a set of OpOverload and Callable objects that represent ONNX-supported PyTorch operations.
+    Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations.

     Args:
         registry (OnnxRegistry): The ONNX registry for PyTorch.

     Returns:
-        A collection of OpOverload and Callable objects representing ONNX-supported PyTorch operations.
+        A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations.
     """
-    table: Set[Union[torch._ops.OpOverload, Callable]] = set()
+    table: Set[Union[torch._ops.OperatorBase, Callable]] = set()

     # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`,
     # but retrievable via explicit lookup.
```
```diff
@@ -80,7 +80,7 @@ def _create_onnx_supports_op_overload_table(
 @_beartype.beartype
 def create_onnx_friendly_decomposition_table(
     registry,
-) -> Dict[torch._ops.OpOverload, Callable]:
+) -> Dict[torch._ops.OperatorBase, Callable]:
     """
     This function creates a dictionary of op overloads and their decomposition functions
     for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function,
```
```diff
@@ -91,10 +91,10 @@ def create_onnx_friendly_decomposition_table(
         registry (torch.onnx.OnnxRegistry): The ONNX registry for PyTorch.

     Returns:
-        Dict[torch._ops.OpOverload, Callable]: A dictionary that maps op overloads to their corresponding
+        Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
         decomposition functions.
     """
-    decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
+    decomposition_table: Dict[torch._ops.OperatorBase, Callable] = {}
     # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g.,
     # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add".
     _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry)
```