Deprecate _preserve_ops and consolidate with decomp_table (#135080)
In this PR, we deprecate the _preserve_ops feature of the run_decompositions API. We can't remove it completely because the ExecuTorch team depends on it, and since syncing between the two repos is non-trivial, the argument is left as deprecated for now and will be removed in the next PR. After this PR, run_decompositions only decomposes what is in the decomp table and preserves everything else by default. Note that this behaviour is rolled out to OSS only for now; the old code path is kept behind the IS_FBCODE flag. Differential Revision: [D62163161](https://our.internmc.facebook.com/intern/diff/D62163161/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135080 Approved by: https://github.com/justinchuby, https://github.com/avikchaudhuri, https://github.com/bdhirsh
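For illustration, a minimal sketch of the new workflow described above (the `model` and `example_inputs` names are placeholders, and aten.conv2d is just one example of a preservable CompositeImplicitAutograd op, following the tests in this diff):

```python
import torch
from torch._decomp import core_aten_decompositions

ep = torch.export.export(model, example_inputs)

# Every entry in the table is decomposed; anything not in the table is preserved.
decomp_table = core_aten_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]  # keep conv2d in the graph
ep = ep.run_decompositions(decomp_table)
```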
This commit is contained in:
parent
357b7fb579
commit
382fad58b3
@@ -18,7 +18,11 @@ import torch._dynamo as torchdynamo
 import torch.nn.functional as F
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
-from torch._decomp import get_decompositions
+from torch._decomp import (
+    _decomp_table_to_post_autograd_aten,
+    core_aten_decompositions,
+    get_decompositions,
+)
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import normalize_gm
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
@@ -1068,14 +1072,17 @@ graph():
                 x = self.linear(x)
                 return torch.ops.aten.chunk.default(x, 3, 0)
 
-        gm = (
-            torch.export.export(
-                Foo(),
-                (torch.randn(3, 3),),
-            )
-            .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
-            .graph_module
-        )
+        ep = torch.export.export(Foo(), (torch.randn(3, 3),))
+        if IS_FBCODE:
+            ep = ep.run_decompositions(
+                {}, _preserve_ops=(torch.ops.aten.linear.default,)
+            )
+        else:
+            decomp_table = _decomp_table_to_post_autograd_aten()
+            del decomp_table[torch.ops.aten.linear.default]
+            ep = ep.run_decompositions(decomp_table)
 
+        gm = ep.graph_module
         # linear is CompositeImplicitAutograd functional op so we should preserve it
         # chunk is CompositeImplicitAutograd non-functional op we decompose.
         self.assertExpectedInline(
@@ -1436,6 +1443,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
                 "dy - 6 = 6" not in exc.args[0]
             )  # don't suggest fix for non-root dim
 
+    @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
     def test_keep_composite_ops_invalid(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:
@@ -1446,32 +1454,29 @@ def forward(self, p_linear_weight, p_linear_bias, x):
                 x = self.linear(x)
                 return torch.ops.aten.chunk.default(x, 3, 0)
 
-        with self.assertRaisesRegex(
-            RuntimeError, "aten.chunk.default is a mutating/aliasing op"
-        ):
-            _ = torch.export.export(
-                Foo(),
-                (torch.randn(3, 3),),
-            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,))
+        def _(*args, **kwargs):
+            return NotImplemented
 
-        with self.assertRaisesRegex(
-            RuntimeError, "aten.sym_size.default is a metadata query function"
-        ):
+        with self.assertWarnsRegex(UserWarning, "The op aten.chunk.default"):
             _ = torch.export.export(
                 Foo(),
                 (torch.randn(3, 3),),
-            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,))
+            ).run_decompositions({torch.ops.aten.chunk.default: _})
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "We can't detect aten.native_batch_norm.default as a functional op statically",
-        ):
+        with self.assertWarnsRegex(UserWarning, "The op aten.sym_size.default"):
+            _ = torch.export.export(
+                Foo(),
+                (torch.randn(3, 3),),
+            ).run_decompositions({torch.ops.aten.sym_size.default: _})
+
+        with self.assertWarnsRegex(
+            UserWarning,
+            "The op aten.native_batch_norm.default",
+        ):
             _ = torch.export.export(
                 Foo(),
                 (torch.randn(3, 3),),
-            ).run_decompositions(
-                {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,)
-            )
+            ).run_decompositions({torch.ops.aten.native_batch_norm.default: _})
 
     def test_keep_composite_ops_linear_convd(self):
         class MyLinear(torch.nn.Module):
@ -1499,10 +1504,14 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||||
ep = torch.export.export(
|
ep = torch.export.export(
|
||||||
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
|
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
|
||||||
)
|
)
|
||||||
ep_has_linear_convd = ep.run_decompositions(
|
if IS_FBCODE:
|
||||||
decomp_table={},
|
ep_has_linear_convd = ep.run_decompositions(
|
||||||
_preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
|
{},
|
||||||
)
|
_preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ep_has_linear_convd = ep.run_decompositions({})
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_linear_convd.graph_module.code).strip(),
|
str(ep_has_linear_convd.graph_module.code).strip(),
|
||||||
"""\
|
"""\
|
||||||
|
|
@ -1516,13 +1525,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_has_convd = ep.run_decompositions(
|
if IS_FBCODE:
|
||||||
decomp_table=None,
|
ep_has_convd = ep.run_decompositions(
|
||||||
_preserve_ops=[
|
_preserve_ops=(
|
||||||
torch.ops.aten.conv2d.default,
|
torch.ops.aten.conv2d.default,
|
||||||
torch.ops.aten.conv1d.default,
|
torch.ops.aten.conv1d.default,
|
||||||
],
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
del decomp_table[torch.ops.aten.conv2d.default]
|
||||||
|
del decomp_table[torch.ops.aten.conv1d.default]
|
||||||
|
|
||||||
|
ep_has_convd = ep.run_decompositions(decomp_table=decomp_table)
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_convd.graph_module.code).strip(),
|
str(ep_has_convd.graph_module.code).strip(),
|
||||||
"""\
|
"""\
|
||||||
|
|
@ -1538,10 +1553,15 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||||
add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
if IS_FBCODE:
|
||||||
|
ep_has_convd = ep_has_convd.run_decompositions(
|
||||||
|
_preserve_ops=(torch.ops.aten.conv2d.default,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
del decomp_table[torch.ops.aten.conv2d.default]
|
||||||
|
|
||||||
ep_has_convd = ep_has_convd.run_decompositions(
|
ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table)
|
||||||
decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
|
|
||||||
)
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_convd.graph_module.code).strip(),
|
str(ep_has_convd.graph_module.code).strip(),
|
||||||
"""\
|
"""\
|
||||||
|
|
@ -1584,10 +1604,16 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||||
ep = torch.export.export_for_training(
|
ep = torch.export.export_for_training(
|
||||||
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
|
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
|
||||||
)
|
)
|
||||||
ep_has_linear_convd = ep.run_decompositions(
|
|
||||||
decomp_table={},
|
if IS_FBCODE:
|
||||||
_preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
|
ep_has_linear_convd = ep.run_decompositions(
|
||||||
)
|
{},
|
||||||
|
_preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ep_has_linear_convd = ep.run_decompositions(
|
||||||
|
decomp_table={},
|
||||||
|
)
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_linear_convd.graph_module.code).strip(),
|
str(ep_has_linear_convd.graph_module.code).strip(),
|
||||||
|
|
@ -1602,13 +1628,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_has_convd = ep.run_decompositions(
|
if IS_FBCODE:
|
||||||
decomp_table=None,
|
ep_has_convd = ep.run_decompositions(
|
||||||
_preserve_ops=[
|
_preserve_ops=(
|
||||||
torch.ops.aten.conv2d.default,
|
torch.ops.aten.conv2d.default,
|
||||||
torch.ops.aten.conv1d.default,
|
torch.ops.aten.conv1d.default,
|
||||||
],
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
del decomp_table[torch.ops.aten.conv2d.default]
|
||||||
|
del decomp_table[torch.ops.aten.conv1d.default]
|
||||||
|
|
||||||
|
ep_has_convd = ep.run_decompositions(decomp_table=decomp_table)
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_convd.graph_module.code).strip(),
|
str(ep_has_convd.graph_module.code).strip(),
|
||||||
|
|
@ -1626,9 +1658,14 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_has_convd = ep_has_convd.run_decompositions(
|
if IS_FBCODE:
|
||||||
decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
|
ep_has_convd = ep_has_convd.run_decompositions(
|
||||||
)
|
_preserve_ops=(torch.ops.aten.conv2d.default,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
del decomp_table[torch.ops.aten.conv2d.default]
|
||||||
|
ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table)
|
||||||
|
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
str(ep_has_convd.graph_module.code).strip(),
|
str(ep_has_convd.graph_module.code).strip(),
|
||||||
|
|
@ -1646,6 +1683,57 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
|
||||||
return (add,)""",
|
return (add,)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skip("See https://github.com/pytorch/pytorch/issues/135759")
|
||||||
|
def test_error_when_passing_mutating_primitive_op(self):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.sin()
|
||||||
|
|
||||||
|
ep = export(Foo(), (torch.ones(3, 3),))
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning,
|
||||||
|
"The op aten.index_put_.default",
|
||||||
|
):
|
||||||
|
ep.run_decompositions({torch.ops.aten.index_put_.default: None})
|
||||||
|
|
||||||
|
def test_if_post_autograd_op_preserved(self):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x.sin() + x.sum()
|
||||||
|
|
||||||
|
ep = export(Foo(), (torch.ones(3, 3),))
|
||||||
|
if IS_FBCODE:
|
||||||
|
ep_preserve_sum = ep.run_decompositions(
|
||||||
|
_preserve_ops=(torch.ops.aten.sum.default,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
del decomp_table[torch.ops.aten.sum.default]
|
||||||
|
ep_preserve_sum = ep.run_decompositions(decomp_table)
|
||||||
|
|
||||||
|
# Even though we are decomposing to core aten which should make
|
||||||
|
# sum into sum.dim_IntList, we explicitly marked it to not do that.
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(ep_preserve_sum.graph_module.code).strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, x):
|
||||||
|
sin = torch.ops.aten.sin.default(x)
|
||||||
|
sum_1 = torch.ops.aten.sum.default(x); x = None
|
||||||
|
add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None
|
||||||
|
return (add,)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
ep_no_preserve_sum = ep.run_decompositions()
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(ep_no_preserve_sum.graph_module.code).strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, x):
|
||||||
|
sin = torch.ops.aten.sin.default(x)
|
||||||
|
sum_1 = torch.ops.aten.sum.dim_IntList(x, []); x = None
|
||||||
|
add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None
|
||||||
|
return (add,)""",
|
||||||
|
)
|
||||||
|
|
||||||
def test_set_grad_empty(self):
|
def test_set_grad_empty(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
@ -4674,6 +4762,84 @@ def forward(self, b_a_buffer, x):
|
||||||
self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
|
self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
|
||||||
self.assertEqual(id(state_dict), id(ep.state_dict))
|
self.assertEqual(id(state_dict), id(ep.state_dict))
|
||||||
|
|
||||||
|
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
|
||||||
|
def test_export_decomp_torture_case_1(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.lin = torch.nn.Linear(10, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.lin(x)
|
||||||
|
|
||||||
|
inp = (torch.randn(5, 10),)
|
||||||
|
m = M()
|
||||||
|
ep = export(m, inp)
|
||||||
|
|
||||||
|
def custom_decomp_callable(x, weight, bias):
|
||||||
|
return x + bias
|
||||||
|
|
||||||
|
decomp_table = core_aten_decompositions()
|
||||||
|
decomp_table[torch.ops.aten.linear.default] = custom_decomp_callable
|
||||||
|
core_aten_ep = ep.run_decompositions(decomp_table)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(core_aten_ep.graph_module.code).strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, p_lin_weight, p_lin_bias, x):
|
||||||
|
add = torch.ops.aten.add.Tensor(x, p_lin_bias); x = p_lin_bias = None
|
||||||
|
return (add,)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
|
||||||
|
def test_export_decomp_torture_case_2(self):
|
||||||
|
class MyLinear(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.randn(20, 98)
|
||||||
|
self.bias = torch.randn(20)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(16, 33, 3)
|
||||||
|
self.conv1d = torch.nn.Conv1d(16, 33, 3)
|
||||||
|
self.linear = MyLinear()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x_conv = self.conv(x)
|
||||||
|
y_conv_1d = self.conv1d(y)
|
||||||
|
x_linear = self.linear(x_conv)
|
||||||
|
return x_linear.cos() + y_conv_1d.sum()
|
||||||
|
|
||||||
|
ep = export(Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)))
|
||||||
|
ep_has_linear_convd = ep.run_decompositions(decomp_table={})
|
||||||
|
|
||||||
|
def _decompose_linear_custom(x, weight, bias):
|
||||||
|
return torch.matmul(x, weight.T) + 2 * bias
|
||||||
|
|
||||||
|
ep_decompose_linear = ep_has_linear_convd.run_decompositions(
|
||||||
|
decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(ep_decompose_linear.graph_module.code).strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
|
||||||
|
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
|
||||||
|
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
|
||||||
|
permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
|
||||||
|
matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None
|
||||||
|
mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None
|
||||||
|
add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None
|
||||||
|
cos = torch.ops.aten.cos.default(add); add = None
|
||||||
|
sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
|
||||||
|
add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
||||||
|
return (add_1,)""",
|
||||||
|
)
|
||||||
|
|
||||||
def test_export_decomps_dynamic(self):
|
def test_export_decomps_dynamic(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
@@ -5244,7 +5410,7 @@ graph():
                 eps,
             ),
         )
-        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
+        ep.run_decompositions()
         self.assertEqual(
             ep.module()(
                 input, weight, bias, running_mean, running_var, training, momentum, eps
@@ -5474,7 +5640,7 @@ graph():
         output = model(t, dim, index, src)
 
         ep = torch.export.export(model, args=(t, dim, index, src))
-        ep.run_decompositions(decomp_table=torch._decomp.decomposition_table)
+        ep = ep.run_decompositions()
         self.assertEqual(ep.module()(t, dim, index, src), output)
 
     def test_fqn(self):
@@ -8041,10 +8207,15 @@ class TestExportCustomClass(TorchTestCase):
             ep.graph_module.code
         )
 
-        ep = ep.run_decompositions(
-            decomp_table=get_decompositions([torch.ops.aten.elu.default]),
-            _preserve_ops=[torch.ops.aten.elu.default],
-        )
+        if IS_FBCODE:
+            ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.elu.default,))
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.elu.default]
+
+            ep = ep.run_decompositions(
+                decomp_table=decomp_table,
+            )
         FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
             ep.graph_module.code
         )
@@ -8066,12 +8237,17 @@ class TestExportCustomClass(TorchTestCase):
             "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
         ).run(ep.graph_module.code)
 
-        decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec])
-        ep = ep.run_decompositions(
-            decomp_table=decomp_table,
-            _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec],
-        )
-        assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table
+        if IS_FBCODE:
+            ep = ep.run_decompositions(
+                _preserve_ops=(torch.ops.aten.upsample_bilinear2d.vec,)
+            )
+        else:
+            decomp_table = core_aten_decompositions()
+            del decomp_table[torch.ops.aten.upsample_bilinear2d.vec]
+            ep = ep.run_decompositions(
+                decomp_table=decomp_table,
+            )
+
         FileCheck().check_count(
             "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
         ).run(ep.graph_module.code)
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: export"]
 import torch
+from torch.testing._internal.common_utils import IS_FBCODE
 
 
 try:
@@ -15,9 +16,11 @@ test_classes = {}
 
 def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
     ep = torch.export.export_for_training(*args, **kwargs)
-    return ep.run_decompositions(
-        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
-    )
+    if IS_FBCODE:
+        return ep.run_decompositions(
+            {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
+        )
+    return ep.run_decompositions({})
 
 
 def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
@@ -25,9 +28,12 @@ def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
         ep = torch.export.export_for_training(*args, **kwargs)
     else:
         ep = torch.export.export_for_training(*args, **kwargs, strict=False)
-    return ep.run_decompositions(
-        {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
-    )
+
+    if IS_FBCODE:
+        return ep.run_decompositions(
+            {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
+        )
+    return ep.run_decompositions({})
 
 
 def make_dynamic_cls(cls, strict):
@@ -10,7 +10,7 @@ from functools import partial
 import torch._inductor.decomposition
 import torch.autograd
 from torch import Tensor
-from torch._decomp import core_aten_decompositions, decomposition_table
+from torch._decomp import _is_cia_op, core_aten_decompositions, decomposition_table
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import DispatchKey
 from torch.testing import make_tensor
@@ -62,7 +62,7 @@ decomposition_names = {
 core_decomposition_names = {
     overload_to_aten_name(k)
     for k in core_aten_decompositions()
-    if isinstance(k, torch._ops.OpOverload)
+    if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k)
 }
 _decomp_test_ops = [
     op
@@ -1,15 +1,26 @@
 # mypy: allow-untyped-defs
 import inspect
 from collections import defaultdict
-from functools import wraps
+from functools import lru_cache, partial, wraps
 from itertools import chain
-from typing import Callable, Dict, List, Sequence, TypeVar, Union
+from typing import (
+    Callable,
+    Dict,
+    FrozenSet,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    TypeVar,
+    Union,
+)
 from typing_extensions import ParamSpec
 
 import torch
 import torch.library
-from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
+from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket
 from torch._prims_common import CustomOutParamAnnotation
+from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.utils import _pytree as pytree
@@ -20,6 +31,8 @@ __all__ = [
     "register_decomposition",
     "get_decompositions",
     "core_aten_decompositions",
+    "_decomp_table_to_post_autograd_aten",
+    "_special_op_to_preserve_cia",
 ]
 
 _T = TypeVar("_T")
@@ -250,13 +263,184 @@ import torch._decomp.decompositions
 import torch._refs
 
 
+# Our strategy for deciding if we can preserve a op is following:
+# 1. The op should be known statically that it is functional
+# 2. If it is maybe aliasing, we decompose because we must know if an op
+#    is mutating or aliasing.
+# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
+# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
+def _check_valid_to_preserve(op_overload):
+    if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
+        return False
+    if op_overload in FunctionalTensor.metadata_fns:
+        return False
+
+    alias_info = len(
+        [i for i in op_overload._schema.arguments if i.alias_info is not None]
+    )
+
+    is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
+
+    if is_mutating_or_aliasing:
+        return False
+
+    if not torch._C._dispatch_has_kernel(op_overload.name()):
+        return False
+
+    return True
+
+
+def _is_cia_op(op: "OpOverload") -> bool:
+    return (
+        torch._C._dispatch_has_kernel_for_dispatch_key(
+            op.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        )
+        or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels
+    )
+
+
+@lru_cache(maxsize=1)
+def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
+    """
+    This is an util function that gets the all CIA functional ops.
+
+    The algorithm is in 2 steps:
+    1. We first query C++ dispatcher to get the list of CIA ops
+       and then we call getattr on torch.ops.aten to lazily populate
+       them.
+
+    2. Sometimes, handful of ops have CIA registered in python dispatcher
+       but not on the C++ side, these can't be caught at the first step.
+       So we walk again to get the final list.
+
+    Note that the output of this function should never be modified
+    """
+    # First step to lazily populate torch.ops.aten
+    cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key(
+        "CompositeImplicitAutograd"
+    )
+    # Ignore quantized namespace ops
+    cia_ops = [name[6:] for name in cia_ops if name.startswith("aten::")]
+    # Materialize all CIA ops first
+    for op in cia_ops:
+        split_list = op.split(".")
+        # Sometime overload could be missing
+        assert len(split_list) == 1 or len(split_list) == 2
+        op_name = split_list[0]
+        op_overload_name = "default"
+        if len(split_list) == 2:
+            op_overload_name = split_list[1]
+
+        _ = getattr(getattr(torch.ops.aten, op_name), op_overload_name)
+
+    # Second step to finally compile the list of all valid ops
+    cia_ops = set()
+    for op in torch.ops.aten:
+        op_packet = getattr(torch.ops.aten, op)
+        for overload in op_packet.overloads():
+            op_overload = getattr(op_packet, overload)
+            if _check_valid_to_preserve(op_overload) and _is_cia_op(op_overload):
+                cia_ops.add(op_overload)
+    return cia_ops
+
+
+def _get_decomp_for_cia(op):
+    # [NOTE] Seperating out func.decompose
+    # Ideally we should be able to just register func.decompose but
+    # we can't as this decomp is gonna be registered to the py_impl.
+    # As a result it will infinitely recurse. So we first check if the op
+    # has py_impl entry for CIA and if it is we use that first. If not,
+    # we register C++ query to py_impl.
+    dk = torch._C.DispatchKey.CompositeImplicitAutograd
+    if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey):
+        return op.py_kernels[dk]
+
+    def _special_op_to_decompose_cia(*args, **kwargs):
+        kernel = kwargs["kernel"]
+        del kwargs["kernel"]
+        # Can't call kernel.decompose due to infinite recursion as
+        # we register this kernel to py_impl directly
+        dk = torch._C.DispatchKey.CompositeImplicitAutograd
+        if torch._C._dispatch_has_kernel_for_dispatch_key(
+            kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
+        ):
+            return kernel._op_dk(dk, *args, **kwargs)
+        else:
+            raise AssertionError(
+                f"Expected {kernel} to have CompositeImplicitAutograd kernel"
+            )
+
+    return partial(_special_op_to_decompose_cia, kernel=op)
+
+
 # See NOTE [Core ATen Ops]
 #
 # list was copied from torch/_inductor/decomposition.py
 # excluding decompositions that results in prim ops
 # Resulting opset of decomposition is core aten ops
 def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
+    decomp_table = _core_aten_decompositions_post_autograd()
+
+    # If it is fbcode change, we return the old decomposition list
+    from torch._inductor import config
+
+    if config.is_fbcode():
+        return decomp_table
+
     aten = torch.ops.aten
+
+    # We are deleting custom decomp in core_aten_decomp
+    # for CIA ops but it should be fine technically
+    # because this table is only meant to be used in export context
+    # in which we really carefully control the decomp behaviour
+    # In any case, C++ decomps should be preferred
+    cia_ops_that_should_be_removed = [
+        aten.all.dimname,
+        aten.index_add.dimname,
+        aten.index_copy.dimname,
+        aten.index_fill.Dimname_Scalar,
+        aten.index_fill.Dimname_Tensor,
+        aten.norm.names_ScalarOpt_dim_dtype,
+        aten.norm.names_ScalarOpt_dim,
+        aten.silu_backward.default,
+        aten.std.default,
+        aten.std.dim,
+        aten.std.names_dim,
+        aten.std.correction_names,
+        aten.std_mean.default,
+        aten.std_mean.dim,
+        aten.std_mean.names_dim,
+        aten.std_mean.correction_names,
+        aten.upsample_bilinear2d.vec,
+        aten.upsample_trilinear3d.vec,
+    ]
+
+    for k in list(decomp_table.keys()):
+        if k in cia_ops_that_should_be_removed:
+            del decomp_table[k]
+
+    for op in _collect_all_valid_cia_ops():
+        decomp_table[op] = _get_decomp_for_cia(op)
+    return decomp_table
+
+
+# This table is a stop-gap table which replicates
+# the old behaviour of post-dispatch IR.
+# This table contains all functional CIA ops mapping
+# to their default decomp. In old export, this will
+# be decomposed implicitly.
+def _decomp_table_to_post_autograd_aten():
+    decomp_table = {}
+    for k in _collect_all_valid_cia_ops():
+        decomp_table[k] = _get_decomp_for_cia(k)
+    return decomp_table
+
+
+def _core_aten_decompositions_post_autograd() -> (
+    Dict[torch._ops.OperatorBase, Callable]
+):
+    aten = torch.ops.aten
+    # TODO Delete all mutating or CIA ops from this list
     return get_decompositions(
         [
             aten.addcdiv,
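As a reading aid for the hunk above, a minimal sketch of how the two new tables are meant to be consumed (the op choices follow the tests earlier in this diff and are illustrative only):

```python
import torch
from torch._decomp import (
    _decomp_table_to_post_autograd_aten,
    core_aten_decompositions,
)

# Core ATen table: post-autograd decomps plus every functional CIA op mapped to
# its default decomposition; deleting an entry tells export to preserve that op.
core_table = core_aten_decompositions()
del core_table[torch.ops.aten.conv2d.default]

# Stop-gap table replicating the old post-dispatch IR behaviour: only the
# functional CIA ops, each mapped to its default decomposition.
post_autograd_table = _decomp_table_to_post_autograd_aten()
del post_autograd_table[torch.ops.aten.linear.default]
```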
@@ -57,8 +57,8 @@ def _export_forward_backward(
 
     ep = _decompose_exported_program(
         ep,
-        decomp_table=core_aten_decompositions(),
-        _preserve_ops=(),  # type: ignore[arg-type]
+        cia_to_decomp={},
+        python_decomp_table=core_aten_decompositions(),
         joint_loss_index=joint_loss_index,
     )
     gm, new_graph_signature = _copy_graph_module_and_signature(ep)
@@ -54,7 +54,6 @@ from torch._export.utils import (
 from torch._export.verifier import Verifier
 from torch._guards import detect_fake_mode
 from torch._subclasses.fake_tensor import unset_fake_temporarily
-from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.export._tree_utils import is_equivalent, reorder_kwargs
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
@@ -174,7 +173,7 @@ _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
 
 
 @contextmanager
-def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True):
+def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True):
     # This function overrides CompositeImplicitAutograd decomp for
     # functional composite ops that user specified. Ideally we want to not-decompose
     # ALL composite ops but today's C++ functinalization relies on
@@ -192,48 +191,7 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
 
     saved_tables = {}
     patched_ops = set()
-    removed_decomps = {}
-    for op_overload in ops_to_preserve:
-        # Our strategy for deciding if we can preserve CIA is following:
-        # 1. The op should be known statically that it is functional
-        # 2. If it is maybe aliasing, we decompose because we must know if an op
-        # is mutating or aliasing.
-        # TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
-        # decomp part. (https://github.com/pytorch/pytorch/issues/129431)
-        def assert_valid_to_preserve(op_overload):
-            if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
-                raise RuntimeError(
-                    f"We can't detect {op_overload} as a functional op statically, so we can't preserve it"
-                )
-            if op_overload in FunctionalTensor.metadata_fns:
-                raise RuntimeError(
-                    f"{op_overload} is a metadata query function, "
-                    "it will be preserved implicitly in our tracing system. "
-                    "Please file an issue on github if you see otherwise"
-                )
-
-            alias_info = len(
-                [i for i in op_overload._schema.arguments if i.alias_info is not None]
-            )
-
-            is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
-
-            if is_mutating_or_aliasing:
-                raise RuntimeError(
-                    f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
-                )
-
-            if not torch._C._dispatch_has_kernel(op_overload.name()):
-                raise RuntimeError(
-                    f"{op_overload} is a TorchScript op, we can't preserve it as is"
-                )
-
-            return True
-
-        if safe:
-            # If we didn't error, it means we can go ahead
-            assert_valid_to_preserve(op_overload)
-
+    for op_overload, decomp_callable in cia_ops_to_callable.items():
         saved_tables[op_overload] = op_overload.py_kernels.copy()
         patched_ops.add(op_overload)
@@ -247,11 +205,9 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
             del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]
 
         if safe:
-
-            def _(*args, **kwargs):
-                return NotImplemented
-
-            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)
+            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(
+                decomp_callable
+            )
 
         # For fake tensor prop, we do want to register meta kernel directly
         if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
@@ -259,10 +215,6 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
                 functools.partial(_register_cia_to_meta, kernel=op_overload)
             )
 
-        if op_overload in decomp_table:
-            removed_decomps[op_overload] = decomp_table[op_overload]
-            del decomp_table[op_overload]
-
     try:
         yield
     finally:
@@ -271,8 +223,10 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True
             op.py_kernels.update(saved_tables[op])
             op._dispatch_cache.clear()
 
-        for op, decomp in removed_decomps.items():
-            decomp_table[op] = decomp
+
+def _special_op_to_preserve_cia(*args, **kwargs):
+    "This is an special marker that tells our infra that we shouldn't decompose this op"
+    return NotImplemented
 
 
 @contextmanager
@@ -281,18 +235,65 @@ def _override_decomp_aten_to_variants():
     # and their CompositeImplicitAutograd kernels will not become NotImplemented.
     # We will later replace them with aten._to_copy when functionalizing.
     with _override_composite_implicit_decomp(
-        (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype),
-        {},
+        {
+            torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia,
+            torch.ops.aten.to.dtype: _special_op_to_preserve_cia,
+        },
         safe=False,
     ):
         yield
 
 
+def _split_decomp_table_to_cia_and_python_decomp(
+    decomp_table: Dict[torch._ops.OperatorBase, Callable]
+) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
+    from torch._decomp import _collect_all_valid_cia_ops, _get_decomp_for_cia
+
+    all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
+    cia_ops_to_callable = {}
+
+    for op in list(decomp_table.keys()):
+        # TODO we are silently allowing non-safe(non-functional) ops through a crack
+        # due to core aten decomp table having non-functional entries. Once we have
+        # a tigher check around core aten decomp, we should warn users about them.
+        # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759)
+
+        # if it is a valid CIA op we can mess with in export, we check if it is:
+        #  1. Has been marked as to be decomposed. Example:
+        #     decomp_table = decomp_table_to_core_aten()
+        #     del decomp_table[aten.linear]
+        #     In this case, user says decompose everything except for aten.linear
+        #  2. Has been marked with custom decomp behavour. Example:
+        #     decomp_table = {aten.linear: some_op}
+        # For (1), we want to remove all the CIA ops that weren't handled by user as
+        # it suggests they are safe to decompose, so we should remove from preservable_list.
+        # for (2), we just plumb the custom decomp to AOTDIspatcher.
+        # In both cases, we want to remove this CIA op from the decomp_table as it is special
+        # handled.
+        if op in all_preservable_cia_ops:
+            # TODO this is annpying case where aten.item has
+            # prim decomposition which later calls into aten.item
+            # and recurses infinitely. (https://github.com/pytorch/pytorch/issues/136050)
+            if op == torch.ops.aten.item.default:
+                cia_ops_to_callable[op] = _get_decomp_for_cia(op)
+            else:
+                cia_ops_to_callable[op] = decomp_table[op]
+            all_preservable_cia_ops.remove(op)
+            del decomp_table[op]
+
+    # If we reached here, it means user intentionally deleted these CIA ops from
+    # decomp table.
+    for k in all_preservable_cia_ops:
+        cia_ops_to_callable[k] = _special_op_to_preserve_cia
+
+    return cia_ops_to_callable, decomp_table
+
+
 def _decompose_and_get_gm_with_new_signature_constants(
     ep,
     *,
-    decomp_table: Dict[torch._ops.OperatorBase, Callable],
-    _preserve_ops: Tuple[torch._ops.OpOverload],
+    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
 ):
     from torch._functorch.aot_autograd import aot_export_module
@@ -356,8 +357,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
     with _ignore_backend_decomps(), (
         fake_mode
     ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
-        _preserve_ops,
-        decomp_table,
+        cia_to_decomp,
     ):
         aten_export_artifact = _export_to_aten_ir(
             mod,
@@ -369,7 +369,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
             {},
             fake_params_buffers,
             constant_attrs,
-            decomp_table=decomp_table,
+            decomp_table=python_decomp_table,
             _check_autograd_state=False,
         )
 
@@ -404,13 +404,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
     fake_mode = detect_fake_mode(fake_args)
     fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
     with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
-        _preserve_ops,
-        decomp_table,
+        cia_to_decomp
     ):
         gm, graph_signature = aot_export_module(
             ep.graph_module,
             fake_args,
-            decompositions=decomp_table,
+            decompositions=python_decomp_table,
             trace_joint=True if joint_loss_index is not None else False,
             output_loss_index=joint_loss_index
             if joint_loss_index is not None
@@ -610,14 +609,14 @@ def _common_getitem_elimination_pass(
 def _decompose_exported_program(
     ep,
     *,
-    decomp_table: Dict[torch._ops.OperatorBase, Callable],
-    _preserve_ops: Tuple[torch._ops.OpOverload],
+    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
 ):
     gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
         ep,
-        decomp_table=decomp_table,
-        _preserve_ops=_preserve_ops,
+        cia_to_decomp=cia_to_decomp,
+        python_decomp_table=python_decomp_table,
         joint_loss_index=joint_loss_index,
     )
@@ -994,16 +993,83 @@ class ExportedProgram:
         `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
 
         For now, we do not decompose joint graphs.
-        """
-        from torch._decomp import core_aten_decompositions
 
-        if decomp_table is None:
+        Args:
+            decomp_table:
+                An optional argument that specifies decomp behaviour for Aten ops
+                (1) If None, we decompose to core aten decompositions
+                (2) If empty, we don't decompose any operator
+
+
+        Some examples:
+
+        If you don't want to decompose anything
+
+        .. code-block:: python
+
+            ep = torch.export.export(model, ...)
+            ep = ep.run_decompositions(decomp_table={})
+
+        If you want to get a core aten operator set except for certain operator, you can do following:
+
+        .. code-block:: python
+
+            ep = torch.export.export(model, ...)
+            from torch._decomp import core_aten_decompositions
             decomp_table = core_aten_decompositions()
+            decomp_table[your_op] = your_custom_decomp
+            ep = ep.run_decompositions(decomp_table=decomp_table)
+        """
+        from torch._decomp import (
+            _decomp_table_to_post_autograd_aten,
+            core_aten_decompositions,
+        )
+        from torch._inductor import config
+
+        # FIXME delete this option after PTC, Executorch syncing is
+        # bit annoying so can't get rid of it easily
+        if _preserve_ops != ():
+            warnings.warn(
+                "This API is deprecated and soon will be removed. "
+                "Please look at the docstring to see how to preserve "
+                "an operator."
+            )
+
+        _decomp_table = (
+            core_aten_decompositions() if decomp_table is None else dict(decomp_table)
+        )
+
+        if config.is_fbcode():
+            # This means the decomp_table would only be containing post-autograd ops
+            # We should manually add CIA decomps
+            for k, v in _decomp_table_to_post_autograd_aten().items():
+                _decomp_table[k] = v
+
+        for op in _preserve_ops:
+            if op in _decomp_table:
+                del _decomp_table[op]
+
+        # Note [Seperating decomp_table into CIA decomps and non-CIA decomps]
+        # At this point, we have a decomp_table that contains decomp behaviour for
+        # both CIA and post-autograd ops.
+        # We need to separate the op into two categories:
+        # 1. CIA op: These are the ops that we want to override
+        #    CompositeImplicitAutograd decomp for. For them, we need to use
+        #    _override_composite_implicit_decomp context manager to plumb it
+        #    through AOTDispatcher
+        # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just
+        #    checking if they are statically functional is enough.
+        # For joint IR case tho, we need to use the old path because we can't register
+        # custom decomps this way because we can't use context manager as it installs
+        # autograd_error node.
+        (
+            cia_to_decomp,
+            python_decomp_table,
+        ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table)
 
         return _decompose_exported_program(
             self,
-            decomp_table=decomp_table,
-            _preserve_ops=_preserve_ops,  # type: ignore[arg-type]
+            cia_to_decomp=cia_to_decomp,
+            python_decomp_table=python_decomp_table,
             joint_loss_index=None,
         )
@@ -1,5 +1,3 @@
-"""Build decomp table from PyTorch."""
-
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
@@ -83,7 +81,14 @@ def create_onnx_friendly_decomposition_table(
         Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
         decomposition functions.
     """
-    decomposition_table: dict[torch._ops.OperatorBase, Callable] = {}
+    # This table contains all CIA decomps, so we should filter out ONNX supported CIAs from it
+    decomposition_table: dict[torch._ops.OperatorBase, Callable] = (
+        torch._decomp._decomp_table_to_post_autograd_aten()
+    )
+    can_preserve = get_preserve_ops().intersection(onnx_registered_ops)
+    for op in list(decomposition_table.keys()):
+        if op in can_preserve:
+            del decomposition_table[op]
 
     # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single
     # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your
@@ -95,6 +100,9 @@ def create_onnx_friendly_decomposition_table(
         # not exportable anyways.
         if op_overload in onnx_registered_ops:
             continue
+        # If it is HOP, we filter those out as well.
+        if not hasattr(op_overload, "_schema"):
+            continue
         decomposition_table[op_overload] = decomp_fn
 
     return decomposition_table
@@ -17,11 +17,7 @@ def decompose_with_registry(
     """
    onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry))
    decomp_table = _decomp.create_onnx_friendly_decomposition_table(onnx_registered_ops)
-    # Try to preserve some known CompositeImplicitAutograd ops
-    to_preserve = _decomp.get_preserve_ops()
-    # We can only preserve implemented ops
-    can_preserve = tuple(to_preserve.intersection(onnx_registered_ops))
-    return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve)
+    return exported_program.run_decompositions(decomp_table)
 
 
 def insert_type_promotion_nodes(