[inductor] Fix angle decomposition return type (#115700)

The current decomposition always returns float32 when the input isn't complex.
Instead, we should do proper type promotion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115700
Approved by: https://github.com/lezcano
ghstack dependencies: #115677, #115699
This commit is contained in:
Peter Bell 2023-12-12 22:48:58 +00:00 committed by PyTorch MergeBot
parent 9cdc80d581
commit fb80f05ee2

View File

@ -20,7 +20,11 @@ from torch._decomp.decompositions import (
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import type_to_dtype
from torch._prims_common import (
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
type_to_dtype,
)
from . import config, inductor_prims
@ -260,14 +264,18 @@ def angle(x):
return torch.where(
torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
)
else:
# when x is real number
# if x >= 0, return 0
# if x < 0, return pi
# if x is nan, return nan
ret = torch.where(x < 0, math.pi, 0.0)
nan = torch.where(torch.isnan(x), float("nan"), 0.0)
return ret + nan
# when x is real number
# if x >= 0, return 0
# if x < 0, return pi
# if x is nan, return nan
_, dtype = elementwise_dtypes(
x,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
ret = torch.where(x < 0, pi, 0.0)
return torch.where(torch.isnan(x), float("nan"), ret)
@register_decomposition([aten.add])