mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[primTorch] support one tensor and two scalars in _prims.where (#80146)
Fixes an issue of supporting two scalar arguments for `where` and other functions with similar set of arguments: ``` refs.where(a, 1, 0) ``` I had to skip `test_python_ref_executor` because the test causes a `Segmentation fault` when running with two scalars. The issue https://github.com/csarofeen/pytorch/issues/1770 has been fixed https://github.com/csarofeen/pytorch/pull/1774, so we can lift the skip when its merged to the upstream. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80146 Approved by: https://github.com/ngimel
This commit is contained in:
parent
66f66faccf
commit
74fb6ee4c5
|
|
@ -419,6 +419,9 @@ def _elementwise_meta(
|
|||
elif isinstance(arg, Number):
|
||||
scalar_type = type(arg)
|
||||
|
||||
if dtype is None and scalar_type is not None:
|
||||
dtype = utils.type_to_dtype(scalar_type)
|
||||
|
||||
# Acquires the device (if it exists) or number
|
||||
device = None
|
||||
number = None
|
||||
|
|
|
|||
|
|
@ -639,19 +639,23 @@ def dtype_to_type(dtype: torch.dtype) -> type:
|
|||
raise ValueError("Invalid dtype!")
|
||||
|
||||
|
||||
_type_to_dtype_map = {
|
||||
bool: torch.bool,
|
||||
int: torch.int64,
|
||||
float: torch.float64,
|
||||
complex: torch.complex128,
|
||||
}
|
||||
|
||||
|
||||
def type_to_dtype(typ: type) -> torch.dtype:
|
||||
"""
|
||||
Computes the corresponding dtype for a Number type.
|
||||
"""
|
||||
return _type_to_dtype_map[typ]
|
||||
|
||||
assert isinstance(typ, type)
|
||||
|
||||
if typ is bool:
|
||||
return torch.bool
|
||||
if typ is int:
|
||||
return torch.long
|
||||
if typ is float:
|
||||
return torch.get_default_dtype()
|
||||
if typ is complex:
|
||||
return corresponding_complex_dtype(torch.get_default_dtype())
|
||||
|
||||
raise ValueError("Invalid type!")
|
||||
|
||||
|
||||
_ordered_types = (bool, int, float, complex)
|
||||
|
|
|
|||
|
|
@ -8902,6 +8902,13 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
|
|||
|
||||
yield SampleInput(a, args=(c, b))
|
||||
|
||||
# two python scalars
|
||||
c = make_cond((10, 3), noncontiguous=True)
|
||||
a = make_arg((1,)).item()
|
||||
b = make_arg((1,)).item()
|
||||
|
||||
yield SampleInput(a, args=(c, b))
|
||||
|
||||
# NaN propagation
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
if dtype.is_floating_point:
|
||||
|
|
@ -21211,6 +21218,7 @@ python_ref_db = [
|
|||
"_refs.where",
|
||||
torch_opinfo_name="where",
|
||||
op=lambda self, condition, other: refs.where(condition, self, other),
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user