[OpenReg][3/N] Migrate cpp_extensions_open_device_registration to OpenReg (#154181)
As stated in the title.

**Involved test cases**:
- test_open_device_quantized
- test_open_device_random
- test_open_device_tensor
- test_open_device_packed_sequence
- test_open_device_storage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154181
Approved by: https://github.com/albanD
ghstack dependencies: #153947, #154018, #154019, #154106
This commit is contained in:
parent
7e5f29b2de
commit
1e7989cad5
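For orientation only (not part of the commit): a minimal sketch of the kind of check the migrated tests perform, assuming the OpenReg test extension is installed and the PrivateUse1 backend has been renamed to "openreg".

```python
import torch

# Sketch only: assumes the OpenReg test extension has registered and renamed
# the PrivateUse1 backend to "openreg"; these calls mirror the migrated tests.
with torch.random.fork_rng(device_type="openreg"):  # exercises the RNG state hooks
    pass

t = torch.empty(4, 4, dtype=torch.float32, device="openreg")
print(t.type())  # expected: "torch.openreg.FloatTensor"
```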
@@ -16,7 +16,6 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cpu/Loops.h>
-#include <ATen/native/quantized/AffineQuantizer.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <ATen/ops/view.h>
@@ -30,23 +29,6 @@ static c10::DeviceIndex custom_device_index = 0;
 static uint64_t storageImpl_counter = 0;
 static uint64_t last_storageImpl_saved_value = 0;
 
-namespace {
-
-void quantize_tensor_per_tensor_affine_privateuse1(
-    const at::Tensor& rtensor,
-    at::Tensor& qtensor,
-    double scale,
-    int64_t zero_point) {
-  // do nothing
-}
-
-} // namespace
-
-namespace at::native {
-
-REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
-
-} // namespace at::native
 struct CustomBackendMetadata : public c10::BackendMeta {
   // for testing this field will mutate when clone() is called by shallow_copy_from.
   int backend_version_format_{-1};
@@ -184,7 +166,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("add.Tensor", &custom_add_Tensor);
   m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
   m.impl("set_.source_Storage", &custom_set_source_Storage);
-  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
 }
 
 void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@@ -4,10 +4,12 @@
 #include <ATen/TensorIterator.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/ops/as_strided_cpu_dispatch.h>
+#include <ATen/ops/quantize_per_tensor_native.h>
 #include <ATen/ops/set_cpu_dispatch.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
+#include <ATen/native/quantized/AffineQuantizer.h>
 
 #include <c10/core/Allocator.h>
 
@@ -254,11 +256,20 @@ int64_t _fused_sdp_choice_privateuse1(
   return static_cast<int64_t>(backend);
 }
 
+void quantize_tensor_per_tensor_affine_privateuse1(
+    const at::Tensor& rtensor,
+    at::Tensor& qtensor,
+    double scale,
+    int64_t zero_point) {
+  // Just test the process, so do nothing
+}
+
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("empty.memory_format", empty_openreg);
   m.impl("empty_strided", empty_strided_openreg);
   m.impl("as_strided", as_strided_openreg);
   m.impl("set_.source_Storage_storage_offset", set_openreg);
+  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
   m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
   m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
   m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
@@ -267,6 +278,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
 
 namespace at::native {
 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel);
+REGISTER_PRIVATEUSE1_DISPATCH(
+    quantize_tensor_per_tensor_affine_stub,
+    &openreg::quantize_tensor_per_tensor_affine_privateuse1);
 REGISTER_PRIVATEUSE1_DISPATCH(
     _fused_sdp_choice_stub,
     &openreg::_fused_sdp_choice_privateuse1);
@@ -55,117 +55,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
             verbose=True,
         )
 
-    def test_open_device_quantized(self):
-        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to(
-            "openreg"
-        )
-        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
-        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
-        self.assertEqual(quantized_tensor.dtype, torch.qint8)
-
-    def test_open_device_random(self):
-        # check if torch.openreg have implemented get_rng_state
-        with torch.random.fork_rng(device_type="openreg"):
-            pass
-
-    def test_open_device_tensor(self):
-        device = self.module.custom_device()
-
-        # check whether print tensor.type() meets the expectation
-        dtypes = {
-            torch.bool: "torch.openreg.BoolTensor",
-            torch.double: "torch.openreg.DoubleTensor",
-            torch.float32: "torch.openreg.FloatTensor",
-            torch.half: "torch.openreg.HalfTensor",
-            torch.int32: "torch.openreg.IntTensor",
-            torch.int64: "torch.openreg.LongTensor",
-            torch.int8: "torch.openreg.CharTensor",
-            torch.short: "torch.openreg.ShortTensor",
-            torch.uint8: "torch.openreg.ByteTensor",
-        }
-        for tt, dt in dtypes.items():
-            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
-            self.assertTrue(test_tensor.type() == dt)
-
-        # check whether the attributes and methods of the corresponding custom backend are generated correctly
-        x = torch.empty(4, 4)
-        self.assertFalse(x.is_openreg)
-
-        x = x.openreg(torch.device("openreg"))
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(x.is_openreg)
-
-        # test different device type input
-        y = torch.empty(4, 4)
-        self.assertFalse(y.is_openreg)
-
-        y = y.openreg(torch.device("openreg:0"))
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(y.is_openreg)
-
-        # test different device type input
-        z = torch.empty(4, 4)
-        self.assertFalse(z.is_openreg)
-
-        z = z.openreg(0)
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z.is_openreg)
-
-    def test_open_device_packed_sequence(self):
-        device = self.module.custom_device()  # noqa: F841
-        a = torch.rand(5, 3)
-        b = torch.tensor([1, 1, 1, 1, 1])
-        input = torch.nn.utils.rnn.PackedSequence(a, b)
-        self.assertFalse(input.is_openreg)
-        input_openreg = input.openreg()
-        self.assertTrue(input_openreg.is_openreg)
-
-    def test_open_device_storage(self):
-        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
-        x = torch.empty(4, 4)
-        z1 = x.storage()
-        self.assertFalse(z1.is_openreg)
-
-        z1 = z1.openreg()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z1.is_openreg)
-
-        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
-            z1.openreg(torch.device("cpu"))
-
-        z1 = z1.cpu()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertFalse(z1.is_openreg)
-
-        z1 = z1.openreg(device="openreg:0", non_blocking=False)
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z1.is_openreg)
-
-        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
-            z1.openreg(device="cuda:0", non_blocking=False)
-
-        # check UntypedStorage
-        y = torch.empty(4, 4)
-        z2 = y.untyped_storage()
-        self.assertFalse(z2.is_openreg)
-
-        z2 = z2.openreg()
-        self.assertFalse(self.module.custom_add_called())
-        self.assertTrue(z2.is_openreg)
-
-        # check custom StorageImpl create
-        self.module.custom_storage_registry()
-
-        z3 = y.untyped_storage()
-        self.assertFalse(self.module.custom_storageImpl_called())
-
-        z3 = z3.openreg()
-        self.assertTrue(self.module.custom_storageImpl_called())
-        self.assertFalse(self.module.custom_storageImpl_called())
-
-        z3 = z3[0:3]
-        self.assertTrue(self.module.custom_storageImpl_called())
-
     @unittest.skipIf(
         sys.version_info >= (3, 13),
         "Error: Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.",
@@ -88,6 +88,75 @@ class TestPrivateUse1(TestCase):
         torch.abs(x_openreg, out=o_openreg[:, :, 0:6:3])
         self.assertEqual(o_cpu, o_openreg.cpu())
 
+    def test_backend_tensor_type(self):
+        dtypes_map = {
+            torch.bool: "torch.openreg.BoolTensor",
+            torch.double: "torch.openreg.DoubleTensor",
+            torch.float32: "torch.openreg.FloatTensor",
+            torch.half: "torch.openreg.HalfTensor",
+            torch.int32: "torch.openreg.IntTensor",
+            torch.int64: "torch.openreg.LongTensor",
+            torch.int8: "torch.openreg.CharTensor",
+            torch.short: "torch.openreg.ShortTensor",
+            torch.uint8: "torch.openreg.ByteTensor",
+        }
+
+        for dtype, str in dtypes_map.items():
+            x = torch.empty(4, 4, dtype=dtype, device="openreg")
+            self.assertTrue(x.type() == str)
+
+    def test_backend_tensor_methods(self):
+        x = torch.empty(4, 4)
+        self.assertFalse(x.is_openreg)  # type: ignore[misc]
+
+        y = x.openreg(torch.device("openreg"))  # type: ignore[misc]
+        self.assertTrue(y.is_openreg)  # type: ignore[misc]
+        z = x.openreg(torch.device("openreg:0"))  # type: ignore[misc]
+        self.assertTrue(z.is_openreg)  # type: ignore[misc]
+        n = x.openreg(0)  # type: ignore[misc]
+        self.assertTrue(n.is_openreg)  # type: ignore[misc]
+
+    @unittest.skip("Need to support Parameter in openreg")
+    def test_backend_module_methods(self):
+        class FakeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.x = torch.nn.Parameter(torch.randn(3, 3))
+
+            def forward(self):
+                pass
+
+        module = FakeModule()
+        self.assertEqual(module.x.device.type, "cpu")
+        module.openreg()  # type: ignore[misc]
+        self.assertEqual(module.x.device.type, "openreg")
+
+    @unittest.skip("Need to support untyped_storage in openreg")
+    def test_backend_storage_methods(self):
+        x = torch.empty(4, 4)
+
+        x_cpu = x.storage()
+        self.assertFalse(x_cpu.is_openreg)  # type: ignore[misc]
+        x_openreg = x_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(x_openreg.is_openreg)  # type: ignore[misc]
+
+        y = torch.empty(4, 4)
+
+        y_cpu = y.untyped_storage()
+        self.assertFalse(y_cpu.is_openreg)  # type: ignore[misc]
+        y_openreg = y_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(y_openreg.is_openreg)  # type: ignore[misc]
+
+    def test_backend_packed_sequence_methods(self):
+        x = torch.rand(5, 3)
+        y = torch.tensor([1, 1, 1, 1, 1])
+
+        z_cpu = torch.nn.utils.rnn.PackedSequence(x, y)
+        self.assertFalse(z_cpu.is_openreg)  # type: ignore[misc]
+
+        z_openreg = z_cpu.openreg()  # type: ignore[misc]
+        self.assertTrue(z_openreg.is_openreg)  # type: ignore[misc]
+
     def test_backend_fallback(self):
         pass
 
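The `is_openreg` / `.openreg()` helpers exercised by these tests are not hand-written; they are generated for a renamed PrivateUse1 backend. A hedged sketch of the one-time setup this relies on (presumably done at import time in the OpenReg test module, not in this diff):

```python
import torch

# Sketch only: rename the PrivateUse1 key to "openreg" and generate the
# convenience helpers the tests above call.
torch.utils.rename_privateuse1_backend("openreg")
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True,           # Tensor.is_openreg / Tensor.openreg()
    for_module=True,           # nn.Module.openreg()
    for_storage=True,          # Storage / UntypedStorage helpers
    for_packed_sequence=True,  # PackedSequence helpers
)
```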
@@ -247,6 +316,12 @@ class TestOpenReg(TestCase):
         self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]]))
         self.assertEqual(x.data_ptr(), y.data_ptr())
 
+    def test_quantize(self):
+        x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
+        quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
+        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
+        self.assertEqual(quantized_tensor.dtype, torch.qint8)
+
 
 if __name__ == "__main__":
     run_tests()