Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
[core IR] Remove trunc decomp and add trunc to core (#109902)
Following up from [this comment](https://github.com/pytorch/pytorch/pull/109319#discussion_r1330803226): remove the decomposition for `trunc` and add it as a core operator. Going forward, operators that map cleanly to hardware instructions should receive similar treatment.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109902
Approved by: https://github.com/peterbell10
This commit is contained in:
parent fe5e63f5db
commit 7de669f2f9
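Note: a minimal sketch (not part of the diff; the module `M` and its input are made up for illustration) of what "core" means in practice here, namely that an exported graph keeps `aten.trunc` as a single node instead of lowering it to where/floor/ceil:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.trunc(x)

ep = torch.export.export(M(), (torch.randn(4),))
print(ep.graph)  # expect a call to torch.ops.aten.trunc.default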
aten/src/ATen/native/native_functions.yaml
@@ -6043,7 +6043,7 @@
   dispatch:
     SparseCPU, SparseCUDA: trunc_sparse
     SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr
-  tags: pointwise
+  tags: [core, pointwise]
 
 - func: trunc_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: trunc.out
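Note: the tag change above is observable through operator introspection; a small check, assuming the `OpOverload.tags` API:

import torch

op = torch.ops.aten.trunc.default
# After this change, the functional overload should carry both tags.
assert torch.Tag.core in op.tags
assert torch.Tag.pointwise in op.tags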
@@ -496,6 +496,8 @@ aten::tril_indices
 aten::tril_indices.out
 aten::triu_indices
 aten::triu_indices.out
+aten::trunc
+aten::trunc.out
 aten::trunc_
 aten::unbind.int
 aten::unfold
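Note: this appears to be an expected-operator list (the file name is not visible in this view); the trunc entries it records correspond to the functional, out, and in-place overloads, all exercised by ordinary eager calls:

import torch

x = torch.tensor([-1.5, -0.5, 0.5, 1.5])
out = torch.empty_like(x)
torch.trunc(x, out=out)   # aten::trunc.out
y = x.clone()
y.trunc_()                # aten::trunc_
torch.testing.assert_close(out, torch.trunc(x))  # aten::trunc
torch.testing.assert_close(y, torch.trunc(x))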
test/test_ops.py
@@ -1873,7 +1873,6 @@ class TestRefsOpsInfo(TestCase):
         '_refs.tensor_split',
         '_refs.to',
         '_refs.true_divide',
-        '_refs.trunc',
         '_refs.trunc_divide',
         '_refs.vsplit',
         '_refs.vstack',
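Note: '_refs.trunc' leaves this set because the reference implementation (see the torch/_refs/__init__.py hunk below) is now registered in the decomposition table. A hedged check, assuming registration happens at import time:

import torch
import torch._refs  # noqa: F401  -- importing registers the refs
from torch._decomp import decomposition_table

assert torch.ops.aten.trunc.default in decomposition_table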
torch/_decomp/__init__.py
@@ -375,7 +375,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
             aten.tril_,
             aten.triu,
             aten.triu_,
-            aten.trunc,
             aten.unfold_backward,
             aten.unfold_copy,
             aten._unsafe_index,
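Note: the complementary check for the hunk above — `aten.trunc` is no longer decomposed by the core ATen table:

import torch
from torch._decomp import core_aten_decompositions

core_table = core_aten_decompositions()
assert torch.ops.aten.trunc.default not in core_table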
torch/_decomp/decompositions.py
@@ -4131,11 +4131,6 @@ def scaled_dot_product_flash_attention(
     )
 
 
-@register_decomposition([aten.trunc])
-def trunc(self: Tensor, **kwargs) -> Tensor:
-    return torch.where(self > 0, torch.floor(self), torch.ceil(self))
-
-
 def register_inplace(aten_op, outplace_op):
     @register_decomposition(aten_op)
     def inplace_op(*args, **kwargs):
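Note: the deleted decomposition is reproduced standalone below. It was numerically faithful — truncation toward zero is floor for positive values and ceil otherwise — it is simply no longer wanted in core IR:

import torch

def trunc_decomp(x: torch.Tensor) -> torch.Tensor:
    # The removed decomposition: round toward zero via floor/ceil.
    return torch.where(x > 0, torch.floor(x), torch.ceil(x))

x = torch.randn(1000) * 10
torch.testing.assert_close(trunc_decomp(x), torch.trunc(x))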
torch/_inductor/decomposition.py
@@ -69,7 +69,6 @@ decomps_to_exclude = [
     aten._scaled_dot_product_flash_attention.default,  # See comments in torch/_decomp/decompositions.py
     aten.clamp_max,
     aten.clamp_min,
-    aten.trunc,
 ]
 
 remove_decompositions(decompositions, decomps_to_exclude)
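Note: the `aten.trunc` exclusion is dropped because, with the decomposition gone from the tables Inductor starts from, there is nothing left to exclude. A sketch of the exclusion pattern itself, using names and ops taken from the hunk above:

import torch
from torch._decomp import core_aten_decompositions, remove_decompositions

aten = torch.ops.aten
# A backend builds a decomposition table, then strips ops it can lower
# natively so they reach the backend as single nodes.
table = core_aten_decompositions()
remove_decompositions(table, [aten.clamp_max, aten.clamp_min])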
torch/_refs/__init__.py
@@ -945,14 +945,9 @@ def tanh(a):
     return prims.tanh(a)
 
 
-@out_wrapper()
-@elementwise_unary_scalar_wrapper
-@elementwise_type_promotion_wrapper(
-    type_promoting_args=("a",),
-    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-)
-def trunc(a: TensorLikeType) -> TensorLikeType:
-    return handle_noncontiguous_outputs([a], prims.trunc(a))
+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
+def trunc(a):
+    return prims.trunc(a)
 
 
 # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
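Note: a quick check that the simplified reference still matches eager `torch.trunc`:

import torch
import torch._refs as refs

x = torch.tensor([-2.7, -0.3, 0.0, 0.4, 1.9])
torch.testing.assert_close(refs.trunc(x), torch.trunc(x))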