[OpenReg][3/N] Migrate cpp_extensions_open_device_registration to OpenReg (#154181)

As the title states.

**Involved test cases**:
- test_open_device_quantized
- test_open_device_random
- test_open_device_tensor
- test_open_device_packed_sequence
- test_open_device_storage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154181
Approved by: https://github.com/albanD
ghstack dependencies: #153947, #154018, #154019, #154106
Author: FFFrog
Date: 2025-06-13 16:41:10 +08:00
Committed by: PyTorch MergeBot
Parent: 7e5f29b2de
Commit: 1e7989cad5

4 changed files with 89 additions and 130 deletions
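
For context on the diffs below: the openreg-flavored attributes and methods the tests exercise (`x.is_openreg`, `x.openreg()`, the `torch.openreg.*Tensor` type names, and the Storage/PackedSequence variants) are not hand-written per backend; they are autogenerated by PyTorch's PrivateUse1 renaming machinery. A minimal sketch of that setup, assuming a hypothetical `torch_openreg` extension module that loads the C++ kernels:

    import torch
    import torch_openreg  # hypothetical: loads the C++ extension registering PrivateUse1 kernels

    # Rename the PrivateUse1 dispatch key so devices print as "openreg:0".
    torch.utils.rename_privateuse1_backend("openreg")

    # Autogenerate the helpers the tests rely on: Tensor.is_openreg,
    # Tensor.openreg(), Storage.openreg(), PackedSequence.openreg(), ...
    torch.utils.generate_methods_for_privateuse1_backend(
        for_tensor=True, for_module=True, for_storage=True, for_packed_sequence=True
    )

    x = torch.empty(4, 4, device="openreg")
    print(x.is_openreg)  # True
    print(x.type())      # torch.openreg.FloatTensor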

File 1 of 4

@@ -16,7 +16,6 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cpu/Loops.h>
-#include <ATen/native/quantized/AffineQuantizer.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <ATen/ops/view.h>
@@ -30,23 +29,6 @@ static c10::DeviceIndex custom_device_index = 0;
 static uint64_t storageImpl_counter = 0;
 static uint64_t last_storageImpl_saved_value = 0;
-
-namespace {
-
-void quantize_tensor_per_tensor_affine_privateuse1(
-    const at::Tensor& rtensor,
-    at::Tensor& qtensor,
-    double scale,
-    int64_t zero_point) {
-  // do nothing
-}
-
-} // namespace
-
-namespace at::native {
-REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
-} // namespace at::native
-
 struct CustomBackendMetadata : public c10::BackendMeta {
   // for testing this field will mutate when clone() is called by shallow_copy_from.
   int backend_version_format_{-1};
@@ -184,7 +166,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("add.Tensor", &custom_add_Tensor);
   m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
   m.impl("set_.source_Storage", &custom_set_source_Storage);
-  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
 }

 void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
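
The `m.impl(...)` lines above (including the removed `quantize_per_tensor` registration) use PyTorch's operator-override mechanism: `TORCH_LIBRARY_IMPL(aten, PrivateUse1, m)` routes specific aten ops to backend kernels under the PrivateUse1 dispatch key. The same mechanism is available from Python through `torch.library`; a hedged sketch with a toy kernel (illustrative only, not what this extension registers):

    import torch

    # Open the aten namespace for implementation-only registrations.
    lib = torch.library.Library("aten", "IMPL")

    def custom_add_tensor(x, y, *, alpha=1):
        # Toy stand-in for a real backend kernel: correct shape/dtype only.
        return torch.empty_like(x)

    # Route aten::add.Tensor to the kernel above for PrivateUse1 tensors.
    lib.impl("add.Tensor", custom_add_tensor, "PrivateUse1")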

File 2 of 4

@@ -4,10 +4,12 @@
 #include <ATen/TensorIterator.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/ops/as_strided_cpu_dispatch.h>
+#include <ATen/ops/quantize_per_tensor_native.h>
 #include <ATen/ops/set_cpu_dispatch.h>

 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
+#include <ATen/native/quantized/AffineQuantizer.h>

 #include <c10/core/Allocator.h>
@@ -254,11 +256,20 @@ int64_t _fused_sdp_choice_privateuse1(
   return static_cast<int64_t>(backend);
 }

+void quantize_tensor_per_tensor_affine_privateuse1(
+    const at::Tensor& rtensor,
+    at::Tensor& qtensor,
+    double scale,
+    int64_t zero_point) {
+  // Just test the process, so do nothing
+}
+
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("empty.memory_format", empty_openreg);
   m.impl("empty_strided", empty_strided_openreg);
   m.impl("as_strided", as_strided_openreg);
   m.impl("set_.source_Storage_storage_offset", set_openreg);
+  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
   m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
   m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
   m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
@@ -267,6 +278,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {

 namespace at::native {
 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(
+    quantize_tensor_per_tensor_affine_stub,
+    &openreg::quantize_tensor_per_tensor_affine_privateuse1);
 REGISTER_PRIVATEUSE1_DISPATCH(
     _fused_sdp_choice_stub,
     &openreg::_fused_sdp_choice_privateuse1);
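
With both pieces in place, quantization resolves end-to-end on the openreg device: `torch.quantize_per_tensor` dispatches to `at::native::quantize_per_tensor` for PrivateUse1, which reaches the no-op OpenReg kernel through `quantize_tensor_per_tensor_affine_stub`. The Python-visible effect, mirroring the new `test_quantize` in the last file (assuming the extension is loaded):

    import torch

    x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
    q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.qint8)
    assert q.device == torch.device("openreg:0")
    assert q.dtype == torch.qint8  # metadata only: the backend kernel writes nothing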

File 3 of 4

@@ -55,117 +55,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
             verbose=True,
         )
-
-    def test_open_device_quantized(self):
-        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to(
-            "openreg"
-        )
-        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
-        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
-        self.assertEqual(quantized_tensor.dtype, torch.qint8)
-
-    def test_open_device_random(self):
-        # check if torch.openreg have implemented get_rng_state
-        with torch.random.fork_rng(device_type="openreg"):
-            pass
-
-    def test_open_device_tensor(self):
-        device = self.module.custom_device()
-        # check whether print tensor.type() meets the expectation
-        dtypes = {
-            torch.bool: "torch.openreg.BoolTensor",
-            torch.double: "torch.openreg.DoubleTensor",
-            torch.float32: "torch.openreg.FloatTensor",
-            torch.half: "torch.openreg.HalfTensor",
-            torch.int32: "torch.openreg.IntTensor",
-            torch.int64: "torch.openreg.LongTensor",
-            torch.int8: "torch.openreg.CharTensor",
-            torch.short: "torch.openreg.ShortTensor",
-            torch.uint8: "torch.openreg.ByteTensor",
-        }
-        for tt, dt in dtypes.items():
-            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
-            self.assertTrue(test_tensor.type() == dt)
-
-        # check whether the attributes and methods of the corresponding custom backend are generated correctly
-        x = torch.empty(4, 4)
-        self.assertFalse(x.is_openreg)
-
-        x = x.openreg(torch.device("openreg"))
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(x.is_openreg)
-
-        # test different device type input
-        y = torch.empty(4, 4)
-        self.assertFalse(y.is_openreg)
-
-        y = y.openreg(torch.device("openreg:0"))
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(y.is_openreg)
-
-        # test different device type input
-        z = torch.empty(4, 4)
-        self.assertFalse(z.is_openreg)
-
-        z = z.openreg(0)
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z.is_openreg)
-
-    def test_open_device_packed_sequence(self):
-        device = self.module.custom_device()  # noqa: F841
-        a = torch.rand(5, 3)
-        b = torch.tensor([1, 1, 1, 1, 1])
-        input = torch.nn.utils.rnn.PackedSequence(a, b)
-        self.assertFalse(input.is_openreg)
-        input_openreg = input.openreg()
-        self.assertTrue(input_openreg.is_openreg)
-
-    def test_open_device_storage(self):
-        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
-        x = torch.empty(4, 4)
-        z1 = x.storage()
-        self.assertFalse(z1.is_openreg)
-
-        z1 = z1.openreg()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z1.is_openreg)
-
-        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
-            z1.openreg(torch.device("cpu"))
-
-        z1 = z1.cpu()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertFalse(z1.is_openreg)
-
-        z1 = z1.openreg(device="openreg:0", non_blocking=False)
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z1.is_openreg)
-
-        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
-            z1.openreg(device="cuda:0", non_blocking=False)
-
-        # check UntypedStorage
-        y = torch.empty(4, 4)
-        z2 = y.untyped_storage()
-        self.assertFalse(z2.is_openreg)
-
-        z2 = z2.openreg()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z2.is_openreg)
-
-        # check custom StorageImpl create
-        self.module.custom_storage_registry()
-
-        z3 = y.untyped_storage()
-        self.assertFalse(self.module.custom_storageImpl_called())
-
-        z3 = z3.openreg()
-        self.assertTrue(self.module.custom_storageImpl_called())
-        self.assertFalse(self.module.custom_storageImpl_called())
-
-        z3 = z3[0:3]
-        self.assertTrue(self.module.custom_storageImpl_called())
-
     @unittest.skipIf(
         sys.version_info >= (3, 13),
         "Error: Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.",

File 4 of 4

@@ -88,6 +88,75 @@ class TestPrivateUse1(TestCase):
         torch.abs(x_openreg, out=o_openreg[:, :, 0:6:3])
         self.assertEqual(o_cpu, o_openreg.cpu())

+    def test_backend_tensor_type(self):
+        dtypes_map = {
+            torch.bool: "torch.openreg.BoolTensor",
+            torch.double: "torch.openreg.DoubleTensor",
+            torch.float32: "torch.openreg.FloatTensor",
+            torch.half: "torch.openreg.HalfTensor",
+            torch.int32: "torch.openreg.IntTensor",
+            torch.int64: "torch.openreg.LongTensor",
+            torch.int8: "torch.openreg.CharTensor",
+            torch.short: "torch.openreg.ShortTensor",
+            torch.uint8: "torch.openreg.ByteTensor",
+        }
+
+        for dtype, str in dtypes_map.items():
+            x = torch.empty(4, 4, dtype=dtype, device="openreg")
+            self.assertTrue(x.type() == str)
+
+    def test_backend_tensor_methods(self):
+        x = torch.empty(4, 4)
+        self.assertFalse(x.is_openreg)  # type: ignore[misc]
+
+        y = x.openreg(torch.device("openreg"))  # type: ignore[misc]
+        self.assertTrue(y.is_openreg)  # type: ignore[misc]
+
+        z = x.openreg(torch.device("openreg:0"))  # type: ignore[misc]
+        self.assertTrue(z.is_openreg)  # type: ignore[misc]
+
+        n = x.openreg(0)  # type: ignore[misc]
+        self.assertTrue(n.is_openreg)  # type: ignore[misc]
+
+    @unittest.skip("Need to support Parameter in openreg")
+    def test_backend_module_methods(self):
+        class FakeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.x = torch.nn.Parameter(torch.randn(3, 3))
+
+            def forward(self):
+                pass
+
+        module = FakeModule()
+        self.assertEqual(module.x.device.type, "cpu")
+        module.openreg()  # type: ignore[misc]
+        self.assertEqual(module.x.device.type, "openreg")
+
+    @unittest.skip("Need to support untyped_storage in openreg")
+    def test_backend_storage_methods(self):
+        x = torch.empty(4, 4)
+
+        x_cpu = x.storage()
+        self.assertFalse(x_cpu.is_openreg)  # type: ignore[misc]
+        x_openreg = x_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(x_openreg.is_openreg)  # type: ignore[misc]
+
+        y = torch.empty(4, 4)
+        y_cpu = y.untyped_storage()
+        self.assertFalse(y_cpu.is_openreg)  # type: ignore[misc]
+        y_openreg = y_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(y_openreg.is_openreg)  # type: ignore[misc]
+
+    def test_backend_packed_sequence_methods(self):
+        x = torch.rand(5, 3)
+        y = torch.tensor([1, 1, 1, 1, 1])
+
+        z_cpu = torch.nn.utils.rnn.PackedSequence(x, y)
+        self.assertFalse(z_cpu.is_openreg)  # type: ignore[misc]
+        z_openreg = z_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(z_openreg.is_openreg)  # type: ignore[misc]
+
     def test_backend_fallback(self):
         pass
@@ -247,6 +316,12 @@ class TestOpenReg(TestCase):
         self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]]))
         self.assertEqual(x.data_ptr(), y.data_ptr())

+    def test_quantize(self):
+        x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
+        quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
+        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
+        self.assertEqual(quantized_tensor.dtype, torch.qint8)
+

 if __name__ == "__main__":
     run_tests()