Add device logic handling for functions which allow scalar inputs as tensors (#86149)

Some functions allow Python scalars to be passed where Tensor inputs are expected. Add handling for them in the fake-tensor device logic so that scalar-only calls fall back to the CPU device (illustrated below).

Fix for https://github.com/pytorch/torchdynamo/issues/1445
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86149
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
Elias Ellison 2022-10-04 16:15:56 +00:00 committed by PyTorch MergeBot
parent d6b030856b
commit 9da5646cdb
4 changed files with 24 additions and 0 deletions
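
For orientation, a minimal sketch (not part of the commit) of the behavior this change enables, assuming a PyTorch build that includes the patch:

# Sketch only: with scalar-only inputs there is no tensor to take a device
# from, so the fake-tensor device logic now falls back to CPU.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    out = torch.div(3, 2)   # both inputs are Python scalars
    print(out.device)       # cpu
    print(out.shape)        # torch.Size([]) -- a 0-dim fake tensor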


@@ -396,6 +396,13 @@ class FakeTensorTest(TestCase):
             self.checkType(b.new(device='cuda'), "cuda", [0])
             self.checkType(a.new(torch.rand([1])), "cpu", [1])
+    def test_scalar_inputs(self):
+        with FakeTensorMode():
+            self.checkType(torch.div(3, 2), "cpu", [])
+            ten = torch.zeros(2, dtype=torch.int32) * 2.0
+            self.assertEqual(ten.dtype, torch.float)
+            self.checkType(ten, "cpu", [2])
 class FakeTensorConstHandling(TestCase):
     def assertConst(self, *args):
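
The dtype assertion in the new test leans on PyTorch's standard type-promotion rules: a Python float scalar promotes an integer tensor to the default float dtype, and the fake-tensor result is expected to match eager mode. A small eager-mode illustration (not part of the diff):

import torch

# An int32 tensor multiplied by a Python float scalar promotes to the default
# float dtype, so the fake-tensor path should report torch.float32 as well.
eager = torch.zeros(2, dtype=torch.int32) * 2.0
print(eager.dtype)   # torch.float32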


@@ -846,6 +846,7 @@ def _add_meta_to_tls_dispatch_include() -> None: ...
 def _meta_in_tls_dispatch_include() -> _bool: ...
 def _remove_meta_from_tls_dispatch_include() -> None: ...
 def _has_storage(x: Tensor) -> _bool: ...
+def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
 # NB: There is no Capsule type in typing, see
 # https://code.activestate.com/lists/python-dev/139675/
 def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack


@ -572,6 +572,17 @@ class FakeTensor(torch.Tensor):
tree_map(merge_devices, args)
tree_map(merge_devices, kwargs)
# some functions that allow Python numbers to bind to Tensors
# if we have failed to find a device, and we're running one of these operators,
# we must have scalar only inputs
if (
torch._C._should_allow_numbers_as_tensors(
func.name().split("::")[-1].split(".")[0]
)
and common_device is None
):
common_device = torch.device("cpu")
assert common_device is not None, f"Could not find common device for {func}"
return common_device
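
The split chain above strips the namespace and overload suffix from the operator name before it is handed to the new binding. A hypothetical illustration (the helper name and the example operator names are assumptions for clarity, not part of the diff):

# Reduces a qualified overload name such as "aten::div.Tensor_mode" to the
# base operator name ("div") that the binding expects.
def base_op_name(qualified_name: str) -> str:
    return qualified_name.split("::")[-1].split(".")[0]

print(base_op_name("aten::div.Tensor_mode"))  # div
print(base_op_name("aten::add.Tensor"))       # add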


@@ -1408,6 +1408,11 @@ Call this whenever a new thread is created in order to propagate values from
     std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
   });
+  py_module.def(
+      "_should_allow_numbers_as_tensors", [](const std::string& name) {
+        return torch::should_allow_numbers_as_tensors(name);
+      });
   py_module.def("_is_deploy_enabled", []() {
 #if defined(USE_DEPLOY)
     return true;
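
Once the binding above is registered, the Python side can query it directly. A short usage sketch (assumes a build containing this commit; the printed values are expectations based on which operators accept Python numbers as Tensors, not verified output):

import torch

# Ask whether an operator accepts Python numbers where Tensors are expected.
print(torch._C._should_allow_numbers_as_tensors("div"))     # expected: True
print(torch._C._should_allow_numbers_as_tensors("conv2d"))  # expected: False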