mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[OpenReg][2/N] Migrate cpp_extensions_open_device_registration to OpenReg (#156401)
As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156401 Approved by: https://github.com/albanD ghstack dependencies: #156400
This commit is contained in:
parent
1d522325b4
commit
a28e6ae38f
|
|
@ -53,38 +53,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_open_device_storage_type(self):
|
|
||||||
# test cpu float storage
|
|
||||||
cpu_tensor = torch.randn([8]).float()
|
|
||||||
cpu_storage = cpu_tensor.storage()
|
|
||||||
self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
|
|
||||||
|
|
||||||
# test custom float storage before defining FloatStorage
|
|
||||||
openreg_tensor = cpu_tensor.openreg()
|
|
||||||
openreg_storage = openreg_tensor.storage()
|
|
||||||
self.assertEqual(openreg_storage.type(), "torch.storage.TypedStorage")
|
|
||||||
|
|
||||||
class CustomFloatStorage:
|
|
||||||
@property
|
|
||||||
def __module__(self):
|
|
||||||
return "torch." + torch._C._get_privateuse1_backend_name()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def __name__(self):
|
|
||||||
return "FloatStorage"
|
|
||||||
|
|
||||||
# test custom float storage after defining FloatStorage
|
|
||||||
try:
|
|
||||||
torch.openreg.FloatStorage = CustomFloatStorage()
|
|
||||||
self.assertEqual(openreg_storage.type(), "torch.openreg.FloatStorage")
|
|
||||||
|
|
||||||
# test custom int storage after defining FloatStorage
|
|
||||||
openreg_tensor2 = torch.randn([8]).int().openreg()
|
|
||||||
openreg_storage2 = openreg_tensor2.storage()
|
|
||||||
self.assertEqual(openreg_storage2.type(), "torch.storage.TypedStorage")
|
|
||||||
finally:
|
|
||||||
torch.openreg.FloatStorage = None
|
|
||||||
|
|
||||||
def test_open_device_faketensor(self):
|
def test_open_device_faketensor(self):
|
||||||
with torch._subclasses.fake_tensor.FakeTensorMode.push():
|
with torch._subclasses.fake_tensor.FakeTensorMode.push():
|
||||||
a = torch.empty(1, device="openreg")
|
a = torch.empty(1, device="openreg")
|
||||||
|
|
|
||||||
|
|
@ -107,6 +107,44 @@ class TestPrivateUse1(TestCase):
|
||||||
x = torch.empty(4, 4, dtype=dtype, device="openreg")
|
x = torch.empty(4, 4, dtype=dtype, device="openreg")
|
||||||
self.assertTrue(x.type() == str)
|
self.assertTrue(x.type() == str)
|
||||||
|
|
||||||
|
# Note that all dtype-d Tensor objects here are only for legacy reasons
|
||||||
|
# and should NOT be used.
|
||||||
|
def test_backend_type_methods(self):
|
||||||
|
# Tensor
|
||||||
|
tensor_cpu = torch.randn([8]).float()
|
||||||
|
self.assertEqual(tensor_cpu.type(), "torch.FloatTensor")
|
||||||
|
|
||||||
|
tensor_openreg = tensor_cpu.openreg()
|
||||||
|
self.assertEqual(tensor_openreg.type(), "torch.openreg.FloatTensor")
|
||||||
|
|
||||||
|
# Storage
|
||||||
|
storage_cpu = tensor_cpu.storage()
|
||||||
|
self.assertEqual(storage_cpu.type(), "torch.FloatStorage")
|
||||||
|
|
||||||
|
tensor_openreg = tensor_cpu.openreg()
|
||||||
|
storage_openreg = tensor_openreg.storage()
|
||||||
|
self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
|
||||||
|
|
||||||
|
class CustomFloatStorage:
|
||||||
|
@property
|
||||||
|
def __module__(self):
|
||||||
|
return "torch." + torch._C._get_privateuse1_backend_name()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __name__(self):
|
||||||
|
return "FloatStorage"
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.openreg.FloatStorage = CustomFloatStorage()
|
||||||
|
self.assertEqual(storage_openreg.type(), "torch.openreg.FloatStorage")
|
||||||
|
|
||||||
|
# test custom int storage after defining FloatStorage
|
||||||
|
tensor_openreg = tensor_cpu.int().openreg()
|
||||||
|
storage_openreg = tensor_openreg.storage()
|
||||||
|
self.assertEqual(storage_openreg.type(), "torch.storage.TypedStorage")
|
||||||
|
finally:
|
||||||
|
torch.openreg.FloatStorage = None
|
||||||
|
|
||||||
def test_backend_tensor_methods(self):
|
def test_backend_tensor_methods(self):
|
||||||
x = torch.empty(4, 4)
|
x = torch.empty(4, 4)
|
||||||
self.assertFalse(x.is_openreg) # type: ignore[misc]
|
self.assertFalse(x.is_openreg) # type: ignore[misc]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user