[core IR] Remove trunc decomp and add trunc to core (#109902)

Following up on [this comment](https://github.com/pytorch/pytorch/pull/109319#discussion_r1330803226): remove the decomposition for `trunc` and add `trunc` to the core ATen operator set instead.

Going forward, operators that, like `trunc`, map cleanly to a single hardware instruction should receive the same treatment: tag them as core operators rather than decomposing them into more primitive ops.
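For context, `trunc` rounds toward zero, which differs from `floor` and `ceil` on negative inputs, and most ISAs implement it directly (e.g. x86 `ROUNDSD` with the truncate rounding mode, AArch64 `FRINTZ`). A quick illustration:

```python
import torch

x = torch.tensor([2.7, -2.7, 0.5, -0.5])
torch.trunc(x)  #  2., -2., 0., -0.  -- rounds toward zero
torch.floor(x)  #  2., -3., 0., -1.  -- rounds toward -inf
torch.ceil(x)   #  3., -2., 1., -0.  -- rounds toward +inf
```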

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109902
Approved by: https://github.com/peterbell10
Author:    SS-JIA
Date:      2023-09-22 21:15:49 -07:00
Committer: PyTorch MergeBot
Commit:    7de669f2f9
Parent:    fe5e63f5db

7 changed files with 6 additions and 17 deletions

File: aten/src/ATen/native/native_functions.yaml

@@ -6043,7 +6043,7 @@
   dispatch:
     SparseCPU, SparseCUDA: trunc_sparse
     SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: trunc_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: trunc.out
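One way to confirm the new tag from Python, assuming a build that includes this change (`OpOverload.tags` exposes the tags declared in `native_functions.yaml`):

```python
import torch

tags = torch.ops.aten.trunc.default.tags
assert torch.Tag.core in tags and torch.Tag.pointwise in tags
```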

File: test/expect/HasDecompTest.test_aten_core_operators.expect

@@ -496,6 +496,8 @@ aten::tril_indices
 aten::tril_indices.out
 aten::triu_indices
 aten::triu_indices.out
+aten::trunc
+aten::trunc.out
 aten::trunc_
 aten::unbind.int
 aten::unfold

File: test/test_ops.py

@@ -1873,7 +1873,6 @@ class TestRefsOpsInfo(TestCase):
         '_refs.tensor_split',
         '_refs.to',
         '_refs.true_divide',
-        '_refs.trunc',
         '_refs.trunc_divide',
         '_refs.vsplit',
         '_refs.vstack',

File: torch/_decomp/__init__.py

@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.tril_,
             aten.triu,
             aten.triu_,
-            aten.trunc,
             aten.unfold_backward,
             aten.unfold_copy,
             aten._unsafe_index,
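With `aten.trunc` dropped from this table, tracing through the core ATen decompositions should leave `trunc` in the graph instead of expanding it. A minimal sketch of how one might check this (the expected graph contents are an assumption, not asserted by the commit):

```python
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.trunc(x)

gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
# Expect a call to aten.trunc here, not aten.where/aten.floor/aten.ceil.
print(gm.graph)
```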

File: torch/_decomp/decompositions.py

@@ -4131,11 +4131,6 @@ def scaled_dot_product_flash_attention(
     )
 
 
-@register_decomposition([aten.trunc])
-def trunc(self: Tensor, **kwargs) -> Tensor:
-    return torch.where(self > 0, torch.floor(self), torch.ceil(self))
-
-
 def register_inplace(aten_op, outplace_op):
     @register_decomposition(aten_op)
     def inplace_op(*args, **kwargs):
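For reference, the deleted decomposition expressed `trunc` in terms of `where`, `floor`, and `ceil`; a quick equivalence check (a sketch, not part of the commit):

```python
import torch

def trunc_decomp(x):
    # The decomposition removed above: floor positive values, ceil the rest,
    # so everything rounds toward zero.
    return torch.where(x > 0, torch.floor(x), torch.ceil(x))

x = torch.randn(1000) * 10
assert torch.equal(trunc_decomp(x), torch.trunc(x))
```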

File: torch/_inductor/decomposition.py

@@ -69,7 +69,6 @@ decomps_to_exclude = [
     aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
     aten.clamp_max,
     aten.clamp_min,
-    aten.trunc,
 ]
 remove_decompositions(decompositions, decomps_to_exclude)
 
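The exclusion mechanism is generic: `remove_decompositions` strips every overload of the listed ops from a decomposition table, so Inductor falls back to its native lowering for them. A minimal sketch of the pattern, reusing `aten.clamp_min` from the list above:

```python
import torch
from torch._decomp import get_decompositions, remove_decompositions

aten = torch.ops.aten
table = get_decompositions([aten.clamp_min])
remove_decompositions(table, [aten.clamp_min])
# All clamp_min overloads are gone; the backend must now lower it natively.
assert not table
```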

File: torch/_refs/__init__.py

@@ -945,14 +945,9 @@ def tanh(a):
     return prims.tanh(a)
 
 
-@out_wrapper()
-@elementwise_unary_scalar_wrapper
-@elementwise_type_promotion_wrapper(
-    type_promoting_args=("a",),
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-)
-def trunc(a: TensorLikeType) -> TensorLikeType:
-    return handle_noncontiguous_outputs([a], prims.trunc(a))
+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
+def trunc(a):
+    return prims.trunc(a)
 
 
 # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
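The rewritten ref should be behaviorally identical for real inputs; `_make_elementwise_unary_reference` wires up the standard elementwise type promotion and should also register the ref as a decomposition for `aten.trunc`, which is why `'_refs.trunc'` comes off the skip list in `test/test_ops.py` above. A quick sanity check one might run (a sketch):

```python
import torch
import torch._refs as refs

x = torch.tensor([1.9, -1.9, 0.0, -0.2])
assert torch.equal(refs.trunc(x), torch.trunc(x))
```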