mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
torch.utils._content_store: fix error in hash_storage on XPU (#147785)
See https://github.com/pytorch/pytorch/actions/runs/13508573465/job/37745227468 for an example error. This is triggering after the merge of #147541, which enabled Dynamo compilation on XPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147785 Approved by: https://github.com/jansel
This commit is contained in:
parent
915eb012e1
commit
723f3a9eab
|
|
@ -111,12 +111,14 @@ class TestContentStore(TestCase):
|
|||
self.assertIsInstance(x4, FakeTensor)
|
||||
same_meta_as_x(x4)
|
||||
|
||||
# Check fp64 works
|
||||
x5 = torch.ops.debugprims.load_tensor.default(
|
||||
"x", (4,), (1,), dtype=torch.float64, device=device
|
||||
)
|
||||
self.assertEqual(x5.float(), x)
|
||||
self.assertEqual(x5.dtype, torch.float64)
|
||||
# Check fp64 works on non-MPS platforms, since MPS doesn't currently
|
||||
# support fp64.
|
||||
if not device.startswith("mps"):
|
||||
x5 = torch.ops.debugprims.load_tensor.default(
|
||||
"x", (4,), (1,), dtype=torch.float64, device=device
|
||||
)
|
||||
self.assertEqual(x5.float(), x)
|
||||
self.assertEqual(x5.dtype, torch.float64)
|
||||
|
||||
x6 = torch.ops.debugprims.load_tensor.default(
|
||||
"x", (4,), (1,), dtype=torch.float32, device=device
|
||||
|
|
@ -124,7 +126,9 @@ class TestContentStore(TestCase):
|
|||
same_meta_as_x(x6)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestContentStore, globals())
|
||||
instantiate_device_type_tests(
|
||||
TestContentStore, globals(), allow_mps=True, allow_xpu=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ import torch
|
|||
import torch._prims as prims
|
||||
import torch._utils
|
||||
import torch.nn.functional as F
|
||||
from torch._C import default_generator
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
|
||||
|
|
@ -111,11 +110,13 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) ->
|
|||
|
||||
# TODO: factor this into a random utility
|
||||
if device_type == "cpu":
|
||||
generator = default_generator
|
||||
generator = torch._C.default_generator
|
||||
elif device_type == "cuda":
|
||||
import torch.cuda
|
||||
|
||||
generator = torch.cuda.default_generators[storage.device.index]
|
||||
elif device_type == "mps":
|
||||
generator = torch.mps._get_default_mps_generator()
|
||||
elif device_type == "xpu":
|
||||
generator = torch.xpu.default_generators[storage.device.index]
|
||||
else:
|
||||
raise AssertionError(f"unhandled device type {device_type}")
|
||||
state = generator.get_state()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user