mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add device logic handling for functions which allow scalar inputs as tensors (#86149)
Some functions allow scalars as tensor inputs. Add handling for them in device logic. Fix for https://github.com/pytorch/torchdynamo/issues/1445 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86149 Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
parent
d6b030856b
commit
9da5646cdb
|
|
@ -396,6 +396,13 @@ class FakeTensorTest(TestCase):
|
|||
self.checkType(b.new(device='cuda'), "cuda", [0])
|
||||
self.checkType(a.new(torch.rand([1])), "cpu", [1])
|
||||
|
||||
def test_scalar_inputs(self):
|
||||
with FakeTensorMode():
|
||||
self.checkType(torch.div(3, 2), "cpu", [])
|
||||
ten = torch.zeros(2, dtype=torch.int32) * 2.0
|
||||
self.assertEqual(ten.dtype, torch.float)
|
||||
self.checkType(ten, "cpu", [2])
|
||||
|
||||
|
||||
class FakeTensorConstHandling(TestCase):
|
||||
def assertConst(self, *args):
|
||||
|
|
|
|||
|
|
@ -846,6 +846,7 @@ def _add_meta_to_tls_dispatch_include() -> None: ...
|
|||
def _meta_in_tls_dispatch_include() -> _bool: ...
|
||||
def _remove_meta_from_tls_dispatch_include() -> None: ...
|
||||
def _has_storage(x: Tensor) -> _bool: ...
|
||||
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
|
||||
# NB: There is no Capsule type in typing, see
|
||||
# https://code.activestate.com/lists/python-dev/139675/
|
||||
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
|
||||
|
|
|
|||
|
|
@ -572,6 +572,17 @@ class FakeTensor(torch.Tensor):
|
|||
tree_map(merge_devices, args)
|
||||
tree_map(merge_devices, kwargs)
|
||||
|
||||
# some functions that allow Python numbers to bind to Tensors
|
||||
# if we have failed to find a device, and we're running one of these operators,
|
||||
# we must have scalar only inputs
|
||||
if (
|
||||
torch._C._should_allow_numbers_as_tensors(
|
||||
func.name().split("::")[-1].split(".")[0]
|
||||
)
|
||||
and common_device is None
|
||||
):
|
||||
common_device = torch.device("cpu")
|
||||
|
||||
assert common_device is not None, f"Could not find common device for {func}"
|
||||
|
||||
return common_device
|
||||
|
|
|
|||
|
|
@ -1408,6 +1408,11 @@ Call this whenever a new thread is created in order to propagate values from
|
|||
std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
|
||||
});
|
||||
|
||||
py_module.def(
|
||||
"_should_allow_numbers_as_tensors", [](const std::string& name) {
|
||||
return torch::should_allow_numbers_as_tensors(name);
|
||||
});
|
||||
|
||||
py_module.def("_is_deploy_enabled", []() {
|
||||
#if defined(USE_DEPLOY)
|
||||
return true;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user