mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[inductor] Fix angle decomposition return type (#115700)
The current decomposition always returns float32 when the input isn't complex. Instead, we should do proper type promotion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115700 Approved by: https://github.com/lezcano ghstack dependencies: #115677, #115699
This commit is contained in:
parent
9cdc80d581
commit
fb80f05ee2
|
|
@ -20,7 +20,11 @@ from torch._decomp.decompositions import (
|
|||
)
|
||||
from torch._decomp.decompositions_for_rng import extra_random_decomps
|
||||
from torch._higher_order_ops.out_dtype import out_dtype
|
||||
from torch._prims_common import type_to_dtype
|
||||
from torch._prims_common import (
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
type_to_dtype,
|
||||
)
|
||||
|
||||
from . import config, inductor_prims
|
||||
|
||||
|
|
@ -260,14 +264,18 @@ def angle(x):
|
|||
return torch.where(
|
||||
torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
|
||||
)
|
||||
else:
|
||||
# when x is real number
|
||||
# if x >= 0, return 0
|
||||
# if x < 0, return pi
|
||||
# if x is nan, return nan
|
||||
ret = torch.where(x < 0, math.pi, 0.0)
|
||||
nan = torch.where(torch.isnan(x), float("nan"), 0.0)
|
||||
return ret + nan
|
||||
|
||||
# when x is real number
|
||||
# if x >= 0, return 0
|
||||
# if x < 0, return pi
|
||||
# if x is nan, return nan
|
||||
_, dtype = elementwise_dtypes(
|
||||
x,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||
)
|
||||
pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
|
||||
ret = torch.where(x < 0, pi, 0.0)
|
||||
return torch.where(torch.isnan(x), float("nan"), ret)
|
||||
|
||||
|
||||
@register_decomposition([aten.add])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user