torch.utils._content_store: fix error in hash_storage on XPU (#147785)

See https://github.com/pytorch/pytorch/actions/runs/13508573465/job/37745227468 for an example of the error. It began triggering after the merge of #147541, which enabled Dynamo compilation on XPU.
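For reference, a minimal repro sketch of the failure mode (assumes an XPU-enabled build; hash_storage is the function patched below):

import torch
from torch.utils._content_store import hash_storage

# Before this fix, hash_storage had no "xpu" branch in its generator lookup,
# so hashing XPU-resident storage hit the
# `raise AssertionError(f"unhandled device type {device_type}")` fallthrough.
x = torch.randn(4, device="xpu")
digest = hash_storage(x.untyped_storage())  # with this patch, returns a hash string
print(digest)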

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147785
Approved by: https://github.com/jansel
Benjamin Glass 2025-02-26 23:57:59 +00:00 committed by PyTorch MergeBot
parent 915eb012e1
commit 723f3a9eab
2 changed files with 16 additions and 11 deletions

test/test_content_store.py

@@ -111,12 +111,14 @@ class TestContentStore(TestCase):
         self.assertIsInstance(x4, FakeTensor)
         same_meta_as_x(x4)
-        # Check fp64 works
-        x5 = torch.ops.debugprims.load_tensor.default(
-            "x", (4,), (1,), dtype=torch.float64, device=device
-        )
-        self.assertEqual(x5.float(), x)
-        self.assertEqual(x5.dtype, torch.float64)
+        # Check fp64 works on non-MPS platforms, since MPS doesn't currently
+        # support fp64.
+        if not device.startswith("mps"):
+            x5 = torch.ops.debugprims.load_tensor.default(
+                "x", (4,), (1,), dtype=torch.float64, device=device
+            )
+            self.assertEqual(x5.float(), x)
+            self.assertEqual(x5.dtype, torch.float64)
         x6 = torch.ops.debugprims.load_tensor.default(
             "x", (4,), (1,), dtype=torch.float32, device=device
@@ -124,7 +126,9 @@ class TestContentStore(TestCase):
         same_meta_as_x(x6)

-instantiate_device_type_tests(TestContentStore, globals())
+instantiate_device_type_tests(
+    TestContentStore, globals(), allow_mps=True, allow_xpu=True
+)

 if __name__ == "__main__":
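The test-side change above is what actually exercises the fix in CI: instantiate_device_type_tests generates one concrete test class per available backend, and the MPS and XPU backends are opt-in, so without allow_mps=True / allow_xpu=True the XPU path would never run. A rough, self-contained sketch of the mechanism (ExampleDeviceTests is a made-up class; the keyword arguments mirror the diff above):

from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class ExampleDeviceTests(TestCase):
    def test_device_string(self, device):
        # The instantiated device string may carry an index (e.g. "xpu:0",
        # "mps:0"), which is why the fp64 gate above checks
        # device.startswith("mps") rather than device == "mps".
        self.assertIsInstance(device, str)

# Creates e.g. ExampleDeviceTestsCPU / ExampleDeviceTestsXPU in globals(),
# one per backend detected at import time.
instantiate_device_type_tests(
    ExampleDeviceTests, globals(), allow_mps=True, allow_xpu=True
)

if __name__ == "__main__":
    run_tests()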

torch/utils/_content_store.py

@@ -40,7 +40,6 @@
 import torch
 import torch._prims as prims
 import torch._utils
 import torch.nn.functional as F
-from torch._C import default_generator
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -111,11 +110,13 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str:
     # TODO: factor this into a random utility
     if device_type == "cpu":
-        generator = default_generator
+        generator = torch._C.default_generator
     elif device_type == "cuda":
         import torch.cuda

         generator = torch.cuda.default_generators[storage.device.index]
     elif device_type == "mps":
         generator = torch.mps._get_default_mps_generator()
+    elif device_type == "xpu":
+        generator = torch.xpu.default_generators[storage.device.index]
     else:
         raise AssertionError(f"unhandled device type {device_type}")
     state = generator.get_state()
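The trailing state = generator.get_state() is why hash_storage needs a per-device default generator in the first place: it snapshots the device's RNG state, presumably so the hash computation can reseed deterministically and restore the state afterwards. A minimal sketch of that save/restore pattern (with_preserved_rng is a hypothetical helper, not part of the module):

import torch

def with_preserved_rng(generator: torch.Generator, fn):
    # Snapshot the generator's state, run fn (which may consume or reseed the
    # RNG), then restore the state so callers observe no RNG side effects.
    state = generator.get_state()
    try:
        return fn()
    finally:
        generator.set_state(state)

# e.g. on CPU, draws random numbers without disturbing global RNG state:
noise = with_preserved_rng(torch._C.default_generator, lambda: torch.randn(4))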