[primTorch] support one tensor and two scalars in _prims.where (#80146)

Fixes an issue with supporting two scalar arguments to `where` and other functions that take a similar set of arguments:

```
refs.where(a, 1, 0)
```

I had to skip `test_python_ref_executor` because the test causes a `Segmentation fault` when running with two scalars.
The issue https://github.com/csarofeen/pytorch/issues/1770 has been fixed by https://github.com/csarofeen/pytorch/pull/1774, so we can lift the skip once that fix is merged into upstream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80146
Approved by: https://github.com/ngimel
This commit is contained in:
Aidyn-A 2022-06-29 19:58:29 +00:00 committed by PyTorch MergeBot
parent 66f66faccf
commit 74fb6ee4c5
3 changed files with 24 additions and 9 deletions

View File

@ -419,6 +419,9 @@ def _elementwise_meta(
elif isinstance(arg, Number):
scalar_type = type(arg)
if dtype is None and scalar_type is not None:
dtype = utils.type_to_dtype(scalar_type)
# Acquires the device (if it exists) or number
device = None
number = None

View File

@ -639,19 +639,23 @@ def dtype_to_type(dtype: torch.dtype) -> type:
raise ValueError("Invalid dtype!")
_type_to_dtype_map = {
bool: torch.bool,
int: torch.int64,
float: torch.float64,
complex: torch.complex128,
}
def type_to_dtype(typ: type) -> torch.dtype:
    """
    Computes the corresponding dtype for a Number type.

    bool -> torch.bool and int -> torch.long are fixed mappings, while
    float and complex follow the current default dtype (see
    torch.get_default_dtype), so scalar arguments promote consistently
    with the rest of the stack instead of always using float64/complex128.

    Raises:
        ValueError: if ``typ`` is not one of bool, int, float, complex.
    """
    assert isinstance(typ, type)
    # NOTE: the branches below intentionally replace the old
    # _type_to_dtype_map lookup; the stale ``return`` from that version
    # made them unreachable and pinned float to torch.float64.
    if typ is bool:
        return torch.bool
    if typ is int:
        return torch.long
    if typ is float:
        return torch.get_default_dtype()
    if typ is complex:
        return corresponding_complex_dtype(torch.get_default_dtype())
    raise ValueError("Invalid type!")
_ordered_types = (bool, int, float, complex)

View File

@ -8902,6 +8902,13 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(a, args=(c, b))
# two python scalars
c = make_cond((10, 3), noncontiguous=True)
a = make_arg((1,)).item()
b = make_arg((1,)).item()
yield SampleInput(a, args=(c, b))
# NaN propagation
if dtype.is_floating_point or dtype.is_complex:
if dtype.is_floating_point:
@ -21211,6 +21218,7 @@ python_ref_db = [
"_refs.where",
torch_opinfo_name="where",
op=lambda self, condition, other: refs.where(condition, self, other),
supports_nvfuser=False,
),
]