From fe94ece375f46d09a223d9796cac02b0796f13f1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 11 Feb 2025 03:17:59 +0000 Subject: [PATCH] Revert "Exclude upsample_bilinear2d.vec from default core ATen decomposition table (#141791)" This reverts commit 3d604b17d91b928c850ded83b2ec25ea066bb3f6. Reverted https://github.com/pytorch/pytorch/pull/141791 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/141791#issuecomment-2649717140)) --- test/export/test_export.py | 56 +++++------------------------------- torch/export/decomp_utils.py | 14 +-------- 2 files changed, 8 insertions(+), 62 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 8e3b911bbaf..d724d13b0c6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -11881,48 +11881,6 @@ class GraphModule(torch.nn.Module): ] self.assertEqual(len(shift_op), 1) - def test_default_decomposition_core_cia_ops(self): - """ - Verify that core ATen ops with Composite Implicit Autograd dispatch are not - decomposed by default. - """ - - # TODO Add avg_pool1d, and adaptive_avg_pool1d when ready. - # See issue #116684. - core_cia_ops = { - "torch.ops.aten.upsample_bilinear2d.vec": ( - torch.ops.aten.upsample_bilinear2d.vec, - { - "align_corners": False, - "scale_factors": [2, 2], - "output_size": None, - }, - ), - "torch.ops.aten.upsample_nearest2d.vec": ( - torch.ops.aten.upsample_nearest2d.vec, - { - "scale_factors": [2, 2], - "output_size": None, - }, - ), - } - - for op_name, (op, kwargs) in core_cia_ops.items(): - - class M(torch.nn.Module): - def forward(self, x): - return op(x, **kwargs) - - ep = export(M(), (torch.randn(2, 3, 4, 5),)) - FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code) - - decomp_table = default_decompositions() - - ep = ep.run_decompositions( - decomp_table=decomp_table, - ) - FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code) - @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): @@ -12528,30 +12486,30 @@ class TestExportCustomClass(TorchTestCase): torch.distributed.destroy_process_group() def test_preserve_cia_op(self): - class StaticResizeTrilinear2dModule(torch.nn.Module): + class StaticResizeBilinear2dModule(torch.nn.Module): def forward(self, x): a = torch.nn.functional.interpolate( x, - size=(x.shape[2] * 2, x.shape[3] * 3, x.shape[4] * 4), - mode="trilinear", + size=(x.shape[2] * 2, x.shape[3] * 3), + mode="bilinear", align_corners=False, antialias=False, ) return a - ep = export(StaticResizeTrilinear2dModule(), (torch.randn(2, 3, 4, 5, 6),)) + ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),)) FileCheck().check_count( - "torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True + "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) decomp_table = default_decompositions() - del decomp_table[torch.ops.aten.upsample_trilinear3d.vec] + del decomp_table[torch.ops.aten.upsample_bilinear2d.vec] ep = ep.run_decompositions( decomp_table=decomp_table, ) FileCheck().check_count( - "torch.ops.aten.upsample_trilinear3d.vec", 1, exactly=True + "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py index 2f4c86617cb..ac6d107149d 100644 --- a/torch/export/decomp_utils.py +++ b/torch/export/decomp_utils.py @@ -13,17 +13,6 @@ from torch._export.utils import ( __all__ = ["CustomDecompTable"] -""" -Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition -by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all -backends are ready, this list allows opt-in one at a time. -""" -PRESERVED_ATEN_CIA_OPS = { - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.upsample_nearest2d.vec, -} - - class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]): """ This is a custom dictionary that is specifically used for handling decomp_table in export. @@ -49,8 +38,7 @@ class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]): self.decomp_table = _core_aten_decompositions_post_autograd() for op in _collect_all_valid_cia_ops_for_aten_namespace(): - if op not in PRESERVED_ATEN_CIA_OPS: - self.decomp_table[op] = _get_decomp_for_cia(op) + self.decomp_table[op] = _get_decomp_for_cia(op) # This is to track the *pending* deleted custom ops that haven't been materialized yet self.deleted_custom_ops = set()