[OpenReg][2/N] Migrate cpp_extensions_open_device_registration to OpenReg (#156401)

As the title states.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156401
Approved by: https://github.com/albanD
ghstack dependencies: #156400
FFFrog 2025-06-22 21:53:12 +08:00 committed by PyTorch MergeBot
parent 1d522325b4
commit a28e6ae38f
2 changed files with 38 additions and 32 deletions
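
For context, the .openreg(), is_openreg, and torch.openreg.* names exercised in the diffs below come from PyTorch's PrivateUse1 renaming flow. A minimal sketch of that setup (not part of this commit, and assuming the OpenReg C++ extension is built and loaded so a real device backs the "openreg" name):

import torch

# Rename the PrivateUse1 dispatch key to "openreg" and generate the helper
# methods the tests rely on: Tensor.openreg(), Tensor.is_openreg, and the
# per-dtype Storage hooks. The torch.openreg device module itself is
# registered separately by the extension (via torch._register_device_module).
torch.utils.rename_privateuse1_backend("openreg")
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True, for_module=True, for_storage=True
)

x = torch.randn([8]).float().openreg()   # needs the extension loaded
print(x.type())             # "torch.openreg.FloatTensor"
print(x.storage().type())   # "torch.storage.TypedStorage" until a FloatStorage is defined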


@@ -53,38 +53,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
             verbose=True,
         )

-    def test_open_device_storage_type(self):
-        # test cpu float storage
-        cpu_tensor = torch.randn([8]).float()
-        cpu_storage = cpu_tensor.storage()
-        self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
-
-        # test custom float storage before defining FloatStorage
-        openreg_tensor = cpu_tensor.openreg()
-        openreg_storage = openreg_tensor.storage()
-        self.assertEqual(openreg_storage.type(), "torch.storage.TypedStorage")
-
-        class CustomFloatStorage:
-            @property
-            def __module__(self):
-                return "torch." + torch._C._get_privateuse1_backend_name()
-
-            @property
-            def __name__(self):
-                return "FloatStorage"
-
-        # test custom float storage after defining FloatStorage
-        try:
-            torch.openreg.FloatStorage = CustomFloatStorage()
-            self.assertEqual(openreg_storage.type(), "torch.openreg.FloatStorage")
-
-            # test custom int storage after defining FloatStorage
-            openreg_tensor2 = torch.randn([8]).int().openreg()
-            openreg_storage2 = openreg_tensor2.storage()
-            self.assertEqual(openreg_storage2.type(), "torch.storage.TypedStorage")
-        finally:
-            torch.openreg.FloatStorage = None
-
     def test_open_device_faketensor(self):
         with torch._subclasses.fake_tensor.FakeTensorMode.push():
             a = torch.empty(1, device="openreg")


@@ -107,6 +107,44 @@ class TestPrivateUse1(TestCase):
             x = torch.empty(4, 4, dtype=dtype, device="openreg")
             self.assertTrue(x.type() == str)

+    # Note that all dtype-d Tensor objects here are only for legacy reasons
+    # and should NOT be used.
+    def test_backend_type_methods(self):
+        # Tensor
+        tensor_cpu = torch.randn([8]).float()
+        self.assertEqual(tensor_cpu.type(), "torch.FloatTensor")
+
+        tensor_openreg = tensor_cpu.openreg()
+        self.assertEqual(tensor_openreg.type(), "torch.openreg.FloatTensor")
+
+        # Storage
+        storage_cpu = tensor_cpu.storage()
+        self.assertEqual(storage_cpu.type(), "torch.FloatStorage")
+
+        tensor_openreg = tensor_cpu.openreg()
+        storage_openreg = tensor_openreg.storage()
+        self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
+
+        class CustomFloatStorage:
+            @property
+            def __module__(self):
+                return "torch." + torch._C._get_privateuse1_backend_name()
+
+            @property
+            def __name__(self):
+                return "FloatStorage"
+
+        try:
+            torch.openreg.FloatStorage = CustomFloatStorage()
+            self.assertEqual(storage_openreg.type(), "torch.openreg.FloatStorage")
+
+            # test custom int storage after defining FloatStorage
+            tensor_openreg = tensor_cpu.int().openreg()
+            storage_openreg = tensor_openreg.storage()
+            self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
+        finally:
+            torch.openreg.FloatStorage = None
+
     def test_backend_tensor_methods(self):
         x = torch.empty(4, 4)
         self.assertFalse(x.is_openreg)  # type: ignore[misc]
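
On the "legacy reasons" note in the new test above: the dtype-qualified type strings ("torch.openreg.FloatTensor", "torch.openreg.FloatStorage") are kept only for backward compatibility, and new code should query device and dtype directly. A short illustration, assuming the backend setup sketched after the commit header:

x = torch.randn([8]).float().openreg()

# Legacy string-based form, exercised by the test but discouraged elsewhere:
assert x.type() == "torch.openreg.FloatTensor"

# Preferred structured checks:
assert x.device.type == "openreg"
assert x.dtype == torch.float32
assert x.is_openreg   # generated by generate_methods_for_privateuse1_backend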