From 74fb6ee4c5303da9980bb6658d4989fbba0c9432 Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Wed, 29 Jun 2022 19:58:29 +0000 Subject: [PATCH] [primTorch] support one tensor and two scalars in _prims.where (#80146) Fixes an issue of supporting two scalar arguments for `where` and other functions with similar set of arguments: ``` refs.where(a, 1, 0) ``` I had to skip `test_python_ref_executor` because the test causes a `Segmentation fault` when running with two scalars. The issue https://github.com/csarofeen/pytorch/issues/1770 has been fixed https://github.com/csarofeen/pytorch/pull/1774, so we can lift the skip when its merged to the upstream. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80146 Approved by: https://github.com/ngimel --- torch/_prims/__init__.py | 3 +++ torch/_prims/utils.py | 22 +++++++++++-------- .../_internal/common_methods_invocations.py | 8 +++++++ 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 07f2e1e61e7..6df66e04d6a 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -419,6 +419,9 @@ def _elementwise_meta( elif isinstance(arg, Number): scalar_type = type(arg) + if dtype is None and scalar_type is not None: + dtype = utils.type_to_dtype(scalar_type) + # Acquires the device (if it exists) or number device = None number = None diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py index 14301ee2bb5..0e775fc68d2 100644 --- a/torch/_prims/utils.py +++ b/torch/_prims/utils.py @@ -639,19 +639,23 @@ def dtype_to_type(dtype: torch.dtype) -> type: raise ValueError("Invalid dtype!") -_type_to_dtype_map = { - bool: torch.bool, - int: torch.int64, - float: torch.float64, - complex: torch.complex128, -} - - def type_to_dtype(typ: type) -> torch.dtype: """ Computes the corresponding dtype for a Number type. """ - return _type_to_dtype_map[typ] + + assert isinstance(typ, type) + + if typ is bool: + return torch.bool + if typ is int: + return torch.long + if typ is float: + return torch.get_default_dtype() + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError("Invalid type!") _ordered_types = (bool, int, float, complex) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index fcb950b9101..074bc8e3995 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8902,6 +8902,13 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): yield SampleInput(a, args=(c, b)) + # two python scalars + c = make_cond((10, 3), noncontiguous=True) + a = make_arg((1,)).item() + b = make_arg((1,)).item() + + yield SampleInput(a, args=(c, b)) + # NaN propagation if dtype.is_floating_point or dtype.is_complex: if dtype.is_floating_point: @@ -21211,6 +21218,7 @@ python_ref_db = [ "_refs.where", torch_opinfo_name="where", op=lambda self, condition, other: refs.where(condition, self, other), + supports_nvfuser=False, ), ]