From 7de669f2f9918fabb502ede109b51ca9354ed4dd Mon Sep 17 00:00:00 2001
From: SS-JIA
Date: Fri, 22 Sep 2023 21:15:49 -0700
Subject: [PATCH] [core IR] Remove trunc decomp and add trunc to core (#109902)

Following up from [this comment](https://github.com/pytorch/pytorch/pull/109319#discussion_r1330803226). Remove the decomposition for `trunc`, and add it as a core operator. Going forward, provide similar treatment for operators that map cleanly to hardware instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109902
Approved by: https://github.com/peterbell10
---
 aten/src/ATen/native/native_functions.yaml         |  2 +-
 .../HasDecompTest.test_aten_core_operators.expect  |  2 ++
 test/test_ops.py                                   |  1 -
 torch/_decomp/__init__.py                          |  1 -
 torch/_decomp/decompositions.py                    |  5 -----
 torch/_inductor/decomposition.py                   |  1 -
 torch/_refs/__init__.py                            | 11 +++--------
 7 files changed, 6 insertions(+), 17 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 314305f14df..9937859a556 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6043,7 +6043,7 @@
   dispatch:
     SparseCPU, SparseCUDA: trunc_sparse
     SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: trunc_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: trunc.out
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 3af2496c1e3..237f0df4a86 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -496,6 +496,8 @@ aten::tril_indices
 aten::tril_indices.out
 aten::triu_indices
 aten::triu_indices.out
+aten::trunc
+aten::trunc.out
 aten::trunc_
 aten::unbind.int
 aten::unfold
diff --git a/test/test_ops.py b/test/test_ops.py
index abcd6a35095..001c6d4d7b5 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1873,7 +1873,6 @@ class TestRefsOpsInfo(TestCase):
         '_refs.tensor_split',
         '_refs.to',
         '_refs.true_divide',
-        '_refs.trunc',
         '_refs.trunc_divide',
         '_refs.vsplit',
         '_refs.vstack',
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index c45a0e45216..dafbcf963e1 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.tril_,
             aten.triu,
             aten.triu_,
-            aten.trunc,
             aten.unfold_backward,
             aten.unfold_copy,
             aten._unsafe_index,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 59c057237cb..b51b4488470 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -4131,11 +4131,6 @@ def scaled_dot_product_flash_attention(
     )
 
 
-@register_decomposition([aten.trunc])
-def trunc(self: Tensor, **kwargs) -> Tensor:
-    return torch.where(self > 0, torch.floor(self), torch.ceil(self))
-
-
 def register_inplace(aten_op, outplace_op):
     @register_decomposition(aten_op)
     def inplace_op(*args, **kwargs):
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index da407f65ebb..a3ae11e2c3d 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -69,7 +69,6 @@ decomps_to_exclude = [
     aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
     aten.clamp_max,
     aten.clamp_min,
-    aten.trunc,
 ]
 
 remove_decompositions(decompositions, decomps_to_exclude)
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 13902fbca44..df6a1b8a386 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -945,14 +945,9 @@ def tanh(a):
     return prims.tanh(a)
 
 
-@out_wrapper()
-@elementwise_unary_scalar_wrapper
-@elementwise_type_promotion_wrapper(
-    type_promoting_args=("a",),
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-)
-def trunc(a: TensorLikeType) -> TensorLikeType:
-    return handle_noncontiguous_outputs([a], prims.trunc(a))
+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
+def trunc(a):
+    return prims.trunc(a)
 
 
 # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
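
For context, a minimal sketch (illustrative, not part of the patch) of what the removed decomposition computed versus the core op: the decomposition deleted from torch/_decomp/decompositions.py expressed `trunc` via floor/ceil, which produces the same round-toward-zero result as the core `aten::trunc` op that backends can now lower directly.

    import torch

    x = torch.tensor([1.7, -1.7, 2.3, -2.3])

    # What the removed decomposition computed (round toward zero via floor/ceil):
    decomposed = torch.where(x > 0, torch.floor(x), torch.ceil(x))

    # The core op, which a backend may map straight to a hardware trunc instruction:
    core = torch.trunc(x)

    assert torch.equal(decomposed, core)  # tensor([ 1., -1.,  2., -2.])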