Deprecate _preserve_ops and consolidate with decomp_table (#135080)

In this PR, we deprecate the _preserve_ops argument of the run_decompositions API. We can't remove it outright because the Executorch team still depends on it, and syncing between the two repos is non-trivial, so for now the argument is only marked as deprecated. It will be removed in the next PR.

After this PR, run_decompositions decomposes only the operators listed in the decomp table and preserves everything else by default. Note that this behavior is rolled out to OSS only for now; the old code path is kept behind the IS_FBCODE flag.
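
For example, the new flow looks roughly like this (a minimal sketch mirroring the updated tests; the module and shapes are illustrative):

    import torch
    from torch._decomp import core_aten_decompositions

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.lin(x).sum()

    ep = torch.export.export(M(), (torch.randn(2, 4),))

    # Preserve aten.linear: delete it from the default table before decomposing.
    decomp_table = core_aten_decompositions()
    del decomp_table[torch.ops.aten.linear.default]
    ep_preserved = ep.run_decompositions(decomp_table)

    # Decompose nothing at all by passing an empty table.
    ep_untouched = ep.run_decompositions({})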

Differential Revision: [D62163161](https://our.internmc.facebook.com/intern/diff/D62163161/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135080
Approved by: https://github.com/justinchuby, https://github.com/avikchaudhuri, https://github.com/bdhirsh
Tugsbayasgalan Manlaibaatar 2024-09-13 22:09:00 -07:00 committed by PyTorch MergeBot
parent 357b7fb579
commit 382fad58b3
8 changed files with 596 additions and 160 deletions


@@ -18,7 +18,11 @@ import torch._dynamo as torchdynamo
 import torch.nn.functional as F
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
-from torch._decomp import get_decompositions
+from torch._decomp import (
+    _decomp_table_to_post_autograd_aten,
+    core_aten_decompositions,
+    get_decompositions,
+)
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import normalize_gm
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
@@ -1068,14 +1072,17 @@ graph():
                 x = self.linear(x)
                 return torch.ops.aten.chunk.default(x, 3, 0)

-        gm = (
-            torch.export.export(
-                Foo(),
-                (torch.randn(3, 3),),
-            )
-            .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
-            .graph_module
-        )
+        ep = torch.export.export(Foo(), (torch.randn(3, 3),))
+        if IS_FBCODE:
+            ep = ep.run_decompositions(
+                {}, _preserve_ops=(torch.ops.aten.linear.default,)
+            )
+        else:
+            decomp_table = _decomp_table_to_post_autograd_aten()
+            del decomp_table[torch.ops.aten.linear.default]
+            ep = ep.run_decompositions(decomp_table)
+        gm = ep.graph_module

         # linear is CompositeImplicitAutograd functional op so we should preserve it
         # chunk is CompositeImplicitAutograd non-functional op we decompose.
         self.assertExpectedInline(
@@ -1436,6 +1443,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
             "dy - 6 = 6" not in exc.args[0]
         )  # don't suggest fix for non-root dim

+    @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
     def test_keep_composite_ops_invalid(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
@@ -1446,32 +1454,29 @@ def forward(self, p_linear_weight, p_linear_bias, x):
                 x = self.linear(x)
                 return torch.ops.aten.chunk.default(x, 3, 0)

-        with self.assertRaisesRegex(
-            RuntimeError, "aten.chunk.default is a mutating/aliasing op"
-        ):
-            _ = torch.export.export(
-                Foo(),
-                (torch.randn(3, 3),),
-            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,))
+        def _(*args, **kwargs):
+            return NotImplemented

-        with self.assertRaisesRegex(
-            RuntimeError, "aten.sym_size.default is a metadata query function"
-        ):
+        with self.assertWarnsRegex(UserWarning, "The op aten.chunk.default"):
             _ = torch.export.export(
                 Foo(),
                 (torch.randn(3, 3),),
-            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,))
+            ).run_decompositions({torch.ops.aten.chunk.default: _})

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "We can't detect aten.native_batch_norm.default as a functional op statically",
+        with self.assertWarnsRegex(UserWarning, "The op aten.sym_size.default"):
+            _ = torch.export.export(
+                Foo(),
+                (torch.randn(3, 3),),
+            ).run_decompositions({torch.ops.aten.sym_size.default: _})
+
+        with self.assertWarnsRegex(
+            UserWarning,
+            "The op aten.native_batch_norm.default",
         ):
             _ = torch.export.export(
                 Foo(),
                 (torch.randn(3, 3),),
-            ).run_decompositions(
-                {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,)
-            )
+            ).run_decompositions({torch.ops.aten.native_batch_norm.default: _})

     def test_keep_composite_ops_linear_convd(self):
         class MyLinear(torch.nn.Module):
@@ -1499,10 +1504,14 @@ def forward(self, p_linear_weight, p_linear_bias, x):
         ep = torch.export.export(
             Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
         )
-        ep_has_linear_convd = ep.run_decompositions(
-            decomp_table={},
-            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
-        )
+        if IS_FBCODE:
+            ep_has_linear_convd = ep.run_decompositions(
+                {},
+                _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
+            )
+        else:
+            ep_has_linear_convd = ep.run_decompositions({})
         self.assertExpectedInline(
             str(ep_has_linear_convd.graph_module.code).strip(),
             """\
@@ -1516,13 +1525,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
     return (add,)""",
         )

-        ep_has_convd = ep.run_decompositions(
-            decomp_table=None,
-            _preserve_ops=[
-                torch.ops.aten.conv2d.default,
-                torch.ops.aten.conv1d.default,
-            ],
-        )
+        if IS_FBCODE:
+            ep_has_convd = ep.run_decompositions(
+                _preserve_ops=(
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv1d.default,
+                )
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.conv2d.default]
+            del decomp_table[torch.ops.aten.conv1d.default]
+            ep_has_convd = ep.run_decompositions(decomp_table=decomp_table)
         self.assertExpectedInline(
             str(ep_has_convd.graph_module.code).strip(),
             """\
@@ -1538,10 +1553,15 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
     add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
     return (add,)""",
         )

-        ep_has_convd = ep_has_convd.run_decompositions(
-            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
-        )
+        if IS_FBCODE:
+            ep_has_convd = ep_has_convd.run_decompositions(
+                _preserve_ops=(torch.ops.aten.conv2d.default,)
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.conv2d.default]
+            ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table)
         self.assertExpectedInline(
             str(ep_has_convd.graph_module.code).strip(),
             """\
@@ -1584,9 +1604,15 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
         ep = torch.export.export_for_training(
             Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
         )
-        ep_has_linear_convd = ep.run_decompositions(
-            decomp_table={},
-            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
-        )
+        if IS_FBCODE:
+            ep_has_linear_convd = ep.run_decompositions(
+                {},
+                _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
+            )
+        else:
+            ep_has_linear_convd = ep.run_decompositions(
+                decomp_table={},
+            )
         self.assertExpectedInline(
@@ -1602,13 +1628,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
     return (add,)""",
         )

-        ep_has_convd = ep.run_decompositions(
-            decomp_table=None,
-            _preserve_ops=[
-                torch.ops.aten.conv2d.default,
-                torch.ops.aten.conv1d.default,
-            ],
-        )
+        if IS_FBCODE:
+            ep_has_convd = ep.run_decompositions(
+                _preserve_ops=(
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv1d.default,
+                )
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.conv2d.default]
+            del decomp_table[torch.ops.aten.conv1d.default]
+            ep_has_convd = ep.run_decompositions(decomp_table=decomp_table)

         self.assertExpectedInline(
             str(ep_has_convd.graph_module.code).strip(),
@@ -1626,9 +1658,14 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
     return (add,)""",
         )

-        ep_has_convd = ep_has_convd.run_decompositions(
-            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
-        )
+        if IS_FBCODE:
+            ep_has_convd = ep_has_convd.run_decompositions(
+                _preserve_ops=(torch.ops.aten.conv2d.default,)
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.conv2d.default]
+            ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table)

         self.assertExpectedInline(
             str(ep_has_convd.graph_module.code).strip(),
@@ -1646,6 +1683,57 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
     return (add,)""",
         )

+    @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
+    def test_error_when_passing_mutating_primitive_op(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x.sin()
+
+        ep = export(Foo(), (torch.ones(3, 3),))
+        with self.assertWarnsRegex(
+            UserWarning,
+            "The op aten.index_put_.default",
+        ):
+            ep.run_decompositions({torch.ops.aten.index_put_.default: None})
+
+    def test_if_post_autograd_op_preserved(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                return x.sin() + x.sum()
+
+        ep = export(Foo(), (torch.ones(3, 3),))
+        if IS_FBCODE:
+            ep_preserve_sum = ep.run_decompositions(
+                _preserve_ops=(torch.ops.aten.sum.default,)
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.sum.default]
+            ep_preserve_sum = ep.run_decompositions(decomp_table)
+
+        # Even though we are decomposing to core aten which should make
+        # sum into sum.dim_IntList, we explicitly marked it to not do that.
+        self.assertExpectedInline(
+            str(ep_preserve_sum.graph_module.code).strip(),
+            """\
+def forward(self, x):
+    sin = torch.ops.aten.sin.default(x)
+    sum_1 = torch.ops.aten.sum.default(x); x = None
+    add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None
+    return (add,)""",
+        )
+
+        ep_no_preserve_sum = ep.run_decompositions()
+        self.assertExpectedInline(
+            str(ep_no_preserve_sum.graph_module.code).strip(),
+            """\
+def forward(self, x):
+    sin = torch.ops.aten.sin.default(x)
+    sum_1 = torch.ops.aten.sum.dim_IntList(x, []); x = None
+    add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None
+    return (add,)""",
+        )
+
     def test_set_grad_empty(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -4674,6 +4762,84 @@ def forward(self, b_a_buffer, x):
         self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
         self.assertEqual(id(state_dict), id(ep.state_dict))

+    @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
+    def test_export_decomp_torture_case_1(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.lin = torch.nn.Linear(10, 1)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        inp = (torch.randn(5, 10),)
+        m = M()
+        ep = export(m, inp)
+
+        def custom_decomp_callable(x, weight, bias):
+            return x + bias
+
+        decomp_table = core_aten_decompositions()
+        decomp_table[torch.ops.aten.linear.default] = custom_decomp_callable
+        core_aten_ep = ep.run_decompositions(decomp_table)
+        self.assertExpectedInline(
+            str(core_aten_ep.graph_module.code).strip(),
+            """\
+def forward(self, p_lin_weight, p_lin_bias, x):
+    add = torch.ops.aten.add.Tensor(x, p_lin_bias); x = p_lin_bias = None
+    return (add,)""",
+        )
+
+    @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
+    def test_export_decomp_torture_case_2(self):
+        class MyLinear(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.weight = torch.randn(20, 98)
+                self.bias = torch.randn(20)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.weight, self.bias)
+
+        class Foo(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.conv = torch.nn.Conv2d(16, 33, 3)
+                self.conv1d = torch.nn.Conv1d(16, 33, 3)
+                self.linear = MyLinear()
+
+            def forward(self, x, y):
+                x_conv = self.conv(x)
+                y_conv_1d = self.conv1d(y)
+                x_linear = self.linear(x_conv)
+                return x_linear.cos() + y_conv_1d.sum()
+
+        ep = export(Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)))
+        ep_has_linear_convd = ep.run_decompositions(decomp_table={})
+
+        def _decompose_linear_custom(x, weight, bias):
+            return torch.matmul(x, weight.T) + 2 * bias
+
+        ep_decompose_linear = ep_has_linear_convd.run_decompositions(
+            decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom}
+        )
+
+        self.assertExpectedInline(
+            str(ep_decompose_linear.graph_module.code).strip(),
+            """\
+def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
+    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
+    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
+    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
+    matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None
+    mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None
+    add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None
+    cos = torch.ops.aten.cos.default(add); add = None
+    sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
+    add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
+    return (add_1,)""",
+        )
+
     def test_export_decomps_dynamic(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:
@@ -5244,7 +5410,7 @@ graph():
                 eps,
             ),
         )
-        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
+        ep.run_decompositions()
         self.assertEqual(
             ep.module()(
                 input, weight, bias, running_mean, running_var, training, momentum, eps
@@ -5474,7 +5640,7 @@ graph():
         output = model(t, dim, index, src)
         ep = torch.export.export(model, args=(t, dim, index, src))
-        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
+        ep = ep.run_decompositions()
         self.assertEqual(ep.module()(t, dim, index, src), output)

     def test_fqn(self):
@@ -8041,9 +8207,14 @@ class TestExportCustomClass(TorchTestCase):
             ep.graph_module.code
         )

-        ep = ep.run_decompositions(
-            decomp_table=get_decompositions([torch.ops.aten.elu.default]),
-            _preserve_ops=[torch.ops.aten.elu.default],
-        )
+        if IS_FBCODE:
+            ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.elu.default,))
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.elu.default]
+            ep = ep.run_decompositions(
+                decomp_table=decomp_table,
+            )
         FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
             ep.graph_module.code
@@ -8066,12 +8237,17 @@ class TestExportCustomClass(TorchTestCase):
             "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
         ).run(ep.graph_module.code)

-        decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec])
-        ep = ep.run_decompositions(
-            decomp_table=decomp_table,
-            _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec],
-        )
-        assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table
+        if IS_FBCODE:
+            ep = ep.run_decompositions(
+                _preserve_ops=(torch.ops.aten.upsample_bilinear2d.vec,)
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
+            ep = ep.run_decompositions(
+                decomp_table=decomp_table,
+            )
         FileCheck().check_count(
             "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
         ).run(ep.graph_module.code)


@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: export"]
 import torch
+from torch.testing._internal.common_utils import IS_FBCODE

 try:
@@ -15,9 +16,11 @@ test_classes = {}
 def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
     ep = torch.export.export_for_training(*args, **kwargs)
-    return ep.run_decompositions(
-        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
-    )
+    if IS_FBCODE:
+        return ep.run_decompositions(
+            {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
+        )
+    return ep.run_decompositions({})


 def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
@@ -25,9 +28,12 @@ def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
         ep = torch.export.export_for_training(*args, **kwargs)
     else:
         ep = torch.export.export_for_training(*args, **kwargs, strict=False)
-    return ep.run_decompositions(
-        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
-    )
+    if IS_FBCODE:
+        return ep.run_decompositions(
+            {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
+        )
+    return ep.run_decompositions({})


 def make_dynamic_cls(cls, strict):


@@ -10,7 +10,7 @@ from functools import partial
 import torch._inductor.decomposition
 import torch.autograd
 from torch import Tensor
-from torch._decomp import core_aten_decompositions, decomposition_table
+from torch._decomp import _is_cia_op, core_aten_decompositions, decomposition_table
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import DispatchKey
 from torch.testing import make_tensor
@@ -62,7 +62,7 @@ decomposition_names = {
 core_decomposition_names = {
     overload_to_aten_name(k)
     for k in core_aten_decompositions()
-    if isinstance(k, torch._ops.OpOverload)
+    if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k)
 }
 _decomp_test_ops = [
     op


@@ -1,15 +1,26 @@
 # mypy: allow-untyped-defs
 import inspect
 from collections import defaultdict
-from functools import wraps
+from functools import lru_cache, partial, wraps
 from itertools import chain
-from typing import Callable, Dict, List, Sequence, TypeVar, Union
+from typing import (
+    Callable,
+    Dict,
+    FrozenSet,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    TypeVar,
+    Union,
+)
 from typing_extensions import ParamSpec

 import torch
 import torch.library
-from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
+from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket
 from torch._prims_common import CustomOutParamAnnotation
+from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.utils import _pytree as pytree
@@ -20,6 +31,8 @@ __all__ = [
     "register_decomposition",
     "get_decompositions",
     "core_aten_decompositions",
+    "_decomp_table_to_post_autograd_aten",
+    "_special_op_to_preserve_cia",
 ]

 _T = TypeVar("_T")
@@ -250,13 +263,184 @@ import torch._decomp.decompositions
 import torch._refs

+# Our strategy for deciding if we can preserve a op is following:
+# 1. The op should be known statically that it is functional
+# 2. If it is maybe aliasing, we decompose because we must know if an op
+#    is mutating or aliasing.
+# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
+# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
+def _check_valid_to_preserve(op_overload):
+    if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
+        return False
+    if op_overload in FunctionalTensor.metadata_fns:
+        return False
+
+    alias_info = len(
+        [i for i in op_overload._schema.arguments if i.alias_info is not None]
+    )
+    is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
+
+    if is_mutating_or_aliasing:
+        return False
+
+    if not torch._C._dispatch_has_kernel(op_overload.name()):
+        return False
+
+    return True
+
+
+def _is_cia_op(op: "OpOverload") -> bool:
+    return (
+        torch._C._dispatch_has_kernel_for_dispatch_key(
+            op.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        )
+        or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels
+    )
+
+
+@lru_cache(maxsize=1)
+def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
+    """
+    This is an util function that gets the all CIA functional ops.
+
+    The algorithm is in 2 steps:
+    1. We first query C++ dispatcher to get the list of CIA ops
+       and then we call getattr on torch.ops.aten to lazily populate
+       them.
+
+    2. Sometimes, handful of ops have CIA registered in python dispatcher
+       but not on the C++ side, these can't be caught at the first step.
+       So we walk again to get the final list.
+
+    Note that the output of this function should never be modified
+    """
+    # First step to lazily populate torch.ops.aten
+    cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key(
+        "CompositeImplicitAutograd"
+    )
+    # Ignore quantized namespace ops
+    cia_ops = [name[6:] for name in cia_ops if name.startswith("aten::")]
+    # Materialize all CIA ops first
+    for op in cia_ops:
+        split_list = op.split(".")
+        # Sometime overload could be missing
+        assert len(split_list) == 1 or len(split_list) == 2
+        op_name = split_list[0]
+        op_overload_name = "default"
+        if len(split_list) == 2:
+            op_overload_name = split_list[1]
+
+        _ = getattr(getattr(torch.ops.aten, op_name), op_overload_name)
+
+    # Second step to finally compile the list of all valid ops
+    cia_ops = set()
+    for op in torch.ops.aten:
+        op_packet = getattr(torch.ops.aten, op)
+        for overload in op_packet.overloads():
+            op_overload = getattr(op_packet, overload)
+            if _check_valid_to_preserve(op_overload) and _is_cia_op(op_overload):
+                cia_ops.add(op_overload)
+    return cia_ops
+
+
+def _get_decomp_for_cia(op):
+    # [NOTE] Seperating out func.decompose
+    # Ideally we should be able to just register func.decompose but
+    # we can't as this decomp is gonna be registered to the py_impl.
+    # As a result it will infinitely recurse. So we first check if the op
+    # has py_impl entry for CIA and if it is we use that first. If not,
+    # we register C++ query to py_impl.
+    dk = torch._C.DispatchKey.CompositeImplicitAutograd
+    if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey):
+        return op.py_kernels[dk]
+
+    def _special_op_to_decompose_cia(*args, **kwargs):
+        kernel = kwargs["kernel"]
+        del kwargs["kernel"]
+        # Can't call kernel.decompose due to infinite recursion as
+        # we register this kernel to py_impl directly
+        dk = torch._C.DispatchKey.CompositeImplicitAutograd
+        if torch._C._dispatch_has_kernel_for_dispatch_key(
+            kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        ):
+            return kernel._op_dk(dk, *args, **kwargs)
+        else:
+            raise AssertionError(
+                f"Expected {kernel} to have CompositeImplicitAutograd kernel"
+            )
+
+    return partial(_special_op_to_decompose_cia, kernel=op)
+
+
 # See NOTE [Core ATen Ops]
 #
 # list was copied from torch/_inductor/decomposition.py
 # excluding decompositions that results in prim ops
 # Resulting opset of decomposition is core aten ops
 def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
+    decomp_table = _core_aten_decompositions_post_autograd()
+
+    # If it is fbcode change, we return the old decomposition list
+    from torch._inductor import config
+
+    if config.is_fbcode():
+        return decomp_table
+
     aten = torch.ops.aten
+
+    # We are deleting custom decomp in core_aten_decomp
+    # for CIA ops but it should be fine technically
+    # because this table is only meant to be used in export context
+    # in which we really carefully control the decomp behaviour
+    # In any case, C++ decomps should be preferred
+    cia_ops_that_should_be_removed = [
+        aten.all.dimname,
+        aten.index_add.dimname,
+        aten.index_copy.dimname,
+        aten.index_fill.Dimname_Scalar,
+        aten.index_fill.Dimname_Tensor,
+        aten.norm.names_ScalarOpt_dim_dtype,
+        aten.norm.names_ScalarOpt_dim,
+        aten.silu_backward.default,
+        aten.std.default,
+        aten.std.dim,
+        aten.std.names_dim,
+        aten.std.correction_names,
+        aten.std_mean.default,
+        aten.std_mean.dim,
+        aten.std_mean.names_dim,
+        aten.std_mean.correction_names,
+        aten.upsample_bilinear2d.vec,
+        aten.upsample_trilinear3d.vec,
+    ]
+
+    for k in list(decomp_table.keys()):
+        if k in cia_ops_that_should_be_removed:
+            del decomp_table[k]
+
+    for op in _collect_all_valid_cia_ops():
+        decomp_table[op] = _get_decomp_for_cia(op)
+    return decomp_table
+
+
+# This table is a stop-gap table which replicates
+# the old behaviour of post-dispatch IR.
+# This table contains all functional CIA ops mapping
+# to their default decomp. In old export, this will
+# be decomposed implicitly.
+def _decomp_table_to_post_autograd_aten():
+    decomp_table = {}
+    for k in _collect_all_valid_cia_ops():
+        decomp_table[k] = _get_decomp_for_cia(k)
+    return decomp_table
+
+
+def _core_aten_decompositions_post_autograd() -> (
+    Dict[torch._ops.OperatorBase, Callable]
+):
+    aten = torch.ops.aten
+    # TODO Delete all mutating or CIA ops from this list
     return get_decompositions(
         [
             aten.addcdiv,


@@ -57,8 +57,8 @@ def _export_forward_backward(
     ep = _decompose_exported_program(
         ep,
-        decomp_table=core_aten_decompositions(),
-        _preserve_ops=(),  # type: ignore[arg-type]
+        cia_to_decomp={},
+        python_decomp_table=core_aten_decompositions(),
         joint_loss_index=joint_loss_index,
     )
     gm, new_graph_signature = _copy_graph_module_and_signature(ep)


@@ -54,7 +54,6 @@ from torch._export.utils import (
 from torch._export.verifier import Verifier
 from torch._guards import detect_fake_mode
 from torch._subclasses.fake_tensor import unset_fake_temporarily
-from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.export._tree_utils import is_equivalent, reorder_kwargs
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
@@ -174,7 +173,7 @@ _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
 @contextmanager
-def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True):
+def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True):
     # This function overrides CompositeImplicitAutograd decomp for
     # functional composite ops that user specified. Ideally we want to not-decompose
     # ALL composite ops but today's C++ functinalization relies on
@@ -192,48 +191,7 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
     saved_tables = {}
     patched_ops = set()
-    removed_decomps = {}
-    for op_overload in ops_to_preserve:
-        # Our strategy for deciding if we can preserve CIA is following:
-        # 1. The op should be known statically that it is functional
-        # 2. If it is maybe aliasing, we decompose because we must know if an op
-        #    is mutating or aliasing.
-        # TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
-        # decomp part. (https://github.com/pytorch/pytorch/issues/129431)
-        def assert_valid_to_preserve(op_overload):
-            if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
-                raise RuntimeError(
-                    f"We can't detect {op_overload} as a functional op statically, so we can't preserve it"
-                )
-            if op_overload in FunctionalTensor.metadata_fns:
-                raise RuntimeError(
-                    f"{op_overload} is a metadata query function, "
-                    "it will be preserved implicitly in our tracing system. "
-                    "Please file an issue on github if you see otherwise"
-                )
-
-            alias_info = len(
-                [i for i in op_overload._schema.arguments if i.alias_info is not None]
-            )
-            is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
-
-            if is_mutating_or_aliasing:
-                raise RuntimeError(
-                    f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
-                )
-
-            if not torch._C._dispatch_has_kernel(op_overload.name()):
-                raise RuntimeError(
-                    f"{op_overload} is a TorchScript op, we can't preserve it as is"
-                )
-
-            return True
-
-        if safe:
-            # If we didn't error, it means we can go ahead
-            assert_valid_to_preserve(op_overload)
-
+    for op_overload, decomp_callable in cia_ops_to_callable.items():
         saved_tables[op_overload] = op_overload.py_kernels.copy()
         patched_ops.add(op_overload)
@@ -247,11 +205,9 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
         del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

         if safe:
-
-            def _(*args, **kwargs):
-                return NotImplemented
-
-            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)
+            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(
+                decomp_callable
+            )

         # For fake tensor prop, we do want to register meta kernel directly
         if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
@@ -259,10 +215,6 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
                 functools.partial(_register_cia_to_meta, kernel=op_overload)
             )

-        if op_overload in decomp_table:
-            removed_decomps[op_overload] = decomp_table[op_overload]
-            del decomp_table[op_overload]
-
     try:
         yield
     finally:
@@ -271,8 +223,10 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
             op.py_kernels.update(saved_tables[op])
             op._dispatch_cache.clear()

-        for op, decomp in removed_decomps.items():
-            decomp_table[op] = decomp
+
+def _special_op_to_preserve_cia(*args, **kwargs):
+    "This is an special marker that tells our infra that we shouldn't decompose this op"
+    return NotImplemented


 @contextmanager
@@ -281,18 +235,65 @@ def _override_decomp_aten_to_variants():
     # and their CompositeImplicitAutograd kernels will not become NotImplemented.
     # We will later replace them with aten._to_copy when functionalizing.
     with _override_composite_implicit_decomp(
-        (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype),
-        {},
+        {
+            torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia,
+            torch.ops.aten.to.dtype: _special_op_to_preserve_cia,
+        },
         safe=False,
     ):
         yield


+def _split_decomp_table_to_cia_and_python_decomp(
+    decomp_table: Dict[torch._ops.OperatorBase, Callable]
+) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
+    from torch._decomp import _collect_all_valid_cia_ops, _get_decomp_for_cia
+
+    all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
+    cia_ops_to_callable = {}
+
+    for op in list(decomp_table.keys()):
+        # TODO we are silently allowing non-safe(non-functional) ops through a crack
+        # due to core aten decomp table having non-functional entries. Once we have
+        # a tigher check around core aten decomp, we should warn users about them.
+        # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759)
+
+        # if it is a valid CIA op we can mess with in export, we check if it is:
+        #  1. Has been marked as to be decomposed. Example:
+        #     decomp_table = decomp_table_to_core_aten()
+        #     del decomp_table[aten.linear]
+        #     In this case, user says decompose everything except for aten.linear
+        #  2. Has been marked with custom decomp behavour. Example:
+        #     decomp_table = {aten.linear: some_op}
+        # For (1), we want to remove all the CIA ops that weren't handled by user as
+        # it suggests they are safe to decompose, so we should remove from preservable_list.
+        # for (2), we just plumb the custom decomp to AOTDIspatcher.
+        # In both cases, we want to remove this CIA op from the decomp_table as it is special
+        # handled.
+        if op in all_preservable_cia_ops:
+            # TODO this is annpying case where aten.item has
+            # prim decomposition which later calls into aten.item
+            # and recurses infinitely. (https://github.com/pytorch/pytorch/issues/136050)
+            if op == torch.ops.aten.item.default:
+                cia_ops_to_callable[op] = _get_decomp_for_cia(op)
+            else:
+                cia_ops_to_callable[op] = decomp_table[op]
+            all_preservable_cia_ops.remove(op)
+            del decomp_table[op]
+
+    # If we reached here, it means user intentionally deleted these CIA ops from
+    # decomp table.
+    for k in all_preservable_cia_ops:
+        cia_ops_to_callable[k] = _special_op_to_preserve_cia
+
+    return cia_ops_to_callable, decomp_table
+
+
 def _decompose_and_get_gm_with_new_signature_constants(
     ep,
     *,
-    decomp_table: Dict[torch._ops.OperatorBase, Callable],
-    _preserve_ops: Tuple[torch._ops.OpOverload],
+    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
 ):
     from torch._functorch.aot_autograd import aot_export_module
@@ -356,8 +357,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
     with _ignore_backend_decomps(), (
         fake_mode
     ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
-        _preserve_ops,
-        decomp_table,
+        cia_to_decomp,
     ):
         aten_export_artifact = _export_to_aten_ir(
             mod,
@@ -369,7 +369,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
             {},
             fake_params_buffers,
             constant_attrs,
-            decomp_table=decomp_table,
+            decomp_table=python_decomp_table,
             _check_autograd_state=False,
         )
@@ -404,13 +404,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
         fake_mode = detect_fake_mode(fake_args)
         fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
         with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
-            _preserve_ops,
-            decomp_table,
+            cia_to_decomp
         ):
             gm, graph_signature = aot_export_module(
                 ep.graph_module,
                 fake_args,
-                decompositions=decomp_table,
+                decompositions=python_decomp_table,
                 trace_joint=True if joint_loss_index is not None else False,
                 output_loss_index=joint_loss_index
                 if joint_loss_index is not None
@@ -610,14 +609,14 @@ def _common_getitem_elimination_pass(
 def _decompose_exported_program(
     ep,
     *,
-    decomp_table: Dict[torch._ops.OperatorBase, Callable],
-    _preserve_ops: Tuple[torch._ops.OpOverload],
+    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
 ):
     gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
         ep,
-        decomp_table=decomp_table,
-        _preserve_ops=_preserve_ops,
+        cia_to_decomp=cia_to_decomp,
+        python_decomp_table=python_decomp_table,
         joint_loss_index=joint_loss_index,
     )
@@ -994,16 +993,83 @@ class ExportedProgram:
         `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

         For now, we do not decompose joint graphs.
-        """
-        from torch._decomp import core_aten_decompositions

-        if decomp_table is None:
+        Args:
+            decomp_table:
+             An optional argument that specifies decomp behaviour for Aten ops
+             (1) If None, we decompose to core aten decompositions
+             (2) If empty, we don't decompose any operator
+
+        Some examples:
+
+        If you don't want to decompose anything
+
+        .. code-block:: python
+
+            ep = torch.export.export(model, ...)
+            ep = ep.run_decompositions(decomp_table={})
+
+        If you want to get a core aten operator set except for certain operator, you can do following:
+
+        .. code-block:: python
+
+            ep = torch.export.export(model, ...)
+            from torch._decomp import core_aten_decompositions
+
             decomp_table = core_aten_decompositions()
+            decomp_table[your_op] = your_custom_decomp
+            ep = ep.run_decompositions(decomp_table=decomp_table)
+        """
+        from torch._decomp import (
+            _decomp_table_to_post_autograd_aten,
+            core_aten_decompositions,
+        )
+        from torch._inductor import config
+
+        # FIXME delete this option after PTC, Executorch syncing is
+        # bit annoying so can't get rid of it easily
+        if _preserve_ops != ():
+            warnings.warn(
+                "This API is deprecated and soon will be removed. "
+                "Please look at the docstring to see how to preserve "
+                "an operator."
+            )
+
+        _decomp_table = (
+            core_aten_decompositions() if decomp_table is None else dict(decomp_table)
+        )
+
+        if config.is_fbcode():
+            # This means the decomp_table would only be containing post-autograd ops
+            # We should manually add CIA decomps
+            for k, v in _decomp_table_to_post_autograd_aten().items():
+                _decomp_table[k] = v
+
+            for op in _preserve_ops:
+                if op in _decomp_table:
+                    del _decomp_table[op]
+
+        # Note [Seperating decomp_table into CIA decomps and non-CIA decomps]
+        # At this point, we have a decomp_table that contains decomp behaviour for
+        # both CIA and post-autograd ops.
+        # We need to separate the op into two categories:
+        # 1. CIA op: These are the ops that we want to override
+        #    CompositeImplicitAutograd decomp for. For them, we need to use
+        #    _override_composite_implicit_decomp context manager to plumb it through AOTDispatcher
+        # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just
+        #    checking if they are statically functional is enough.
+        # For joint IR case tho, we need to use the old path because we can't register
+        # custom decomps this way because we can't use context manager as it installs
+        # autograd_error node.
+        (
+            cia_to_decomp,
+            python_decomp_table,
+        ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table)

         return _decompose_exported_program(
             self,
-            decomp_table=decomp_table,
-            _preserve_ops=_preserve_ops,  # type: ignore[arg-type]
+            cia_to_decomp=cia_to_decomp,
+            python_decomp_table=python_decomp_table,
             joint_loss_index=None,
         )


@@ -1,5 +1,3 @@
-"""Build decomp table from PyTorch."""
-
 # mypy: allow-untyped-defs
 from __future__ import annotations
@@ -83,7 +81,14 @@ def create_onnx_friendly_decomposition_table(
         Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
         decomposition functions.
     """
-    decomposition_table: dict[torch._ops.OperatorBase, Callable] = {}
+    # This table contains all CIA decomps, so we should filter out ONNX supported CIAs from it
+    decomposition_table: dict[torch._ops.OperatorBase, Callable] = (
+        torch._decomp._decomp_table_to_post_autograd_aten()
+    )
+    can_preserve = get_preserve_ops().intersection(onnx_registered_ops)
+    for op in list(decomposition_table.keys()):
+        if op in can_preserve:
+            del decomposition_table[op]

     # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single
     # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your
@@ -95,6 +100,9 @@ def create_onnx_friendly_decomposition_table(
         # not exportable anyways.
         if op_overload in onnx_registered_ops:
             continue
+        # If it is HOP, we filter those out as well.
+        if not hasattr(op_overload, "_schema"):
+            continue
         decomposition_table[op_overload] = decomp_fn
     return decomposition_table


@@ -17,11 +17,7 @@ def decompose_with_registry(
     """
     onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry))
    decomp_table = _decomp.create_onnx_friendly_decomposition_table(onnx_registered_ops)
-    # Try to preserve some known CompositeImplicitAutograd ops
-    to_preserve = _decomp.get_preserve_ops()
-    # We can only preserve implemented ops
-    can_preserve = tuple(to_preserve.intersection(onnx_registered_ops))
-    return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve)
+    return exported_program.run_decompositions(decomp_table)


 def insert_type_promotion_nodes(