[OpenReg][3/N] Migrate cpp_extensions_open_device_registration to OpenReg (#154181)
As the title states.

**Involved testcases**:
- test_open_device_quantized
- test_open_device_random
- test_open_device_tensor
- test_open_device_packed_sequence
- test_open_device_storage

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154181
Approved by: https://github.com/albanD
ghstack dependencies: #153947, #154018, #154019, #154106
This commit is contained in:
parent
7e5f29b2de
commit
1e7989cad5
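
For context, the behavior the migrated tests rely on can be summarized by a minimal sketch like the one below (hypothetical usage, not part of the diff; it assumes the OpenReg test extension has already been built, imported, and registered as the "openreg" PrivateUse1 backend — the RNG and tensor-type checks mirror test_open_device_random and test_open_device_tensor):

import torch

# Assumes the OpenReg extension is already imported and registered
# as the "openreg" PrivateUse1 backend (setup not shown).

# fork_rng only works when the backend provides RNG state hooks
# (what test_open_device_random exercises).
with torch.random.fork_rng(device_type="openreg"):
    pass

# Per-backend tensor type names are generated for the custom device
# (what test_open_device_tensor checks).
x = torch.empty(4, 4, dtype=torch.float32, device="openreg")
assert x.type() == "torch.openreg.FloatTensor"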
@@ -16,7 +16,6 @@
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/view.h>

@@ -30,23 +29,6 @@ static c10::DeviceIndex custom_device_index = 0;
static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;

namespace {

void quantize_tensor_per_tensor_affine_privateuse1(
    const at::Tensor& rtensor,
    at::Tensor& qtensor,
    double scale,
    int64_t zero_point) {
  // do nothing
}

} // namespace

namespace at::native {

REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);

} // namespace at::native

struct CustomBackendMetadata : public c10::BackendMeta {
  // for testing this field will mutate when clone() is called by shallow_copy_from.
  int backend_version_format_{-1};

@@ -184,7 +166,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("set_.source_Storage", &custom_set_source_Storage);
  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {

@@ -4,10 +4,12 @@
#include <ATen/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/quantized/AffineQuantizer.h>

#include <c10/core/Allocator.h>

@@ -254,11 +256,20 @@ int64_t _fused_sdp_choice_privateuse1(
  return static_cast<int64_t>(backend);
}

void quantize_tensor_per_tensor_affine_privateuse1(
    const at::Tensor& rtensor,
    at::Tensor& qtensor,
    double scale,
    int64_t zero_point) {
  // Just testing the dispatch flow, so do nothing
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("empty.memory_format", empty_openreg);
  m.impl("empty_strided", empty_strided_openreg);
  m.impl("as_strided", as_strided_openreg);
  m.impl("set_.source_Storage_storage_offset", set_openreg);
  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
  m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
  m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
  m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);

@@ -267,6 +278,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {

namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(
    quantize_tensor_per_tensor_affine_stub,
    &openreg::quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(
    _fused_sdp_choice_stub,
    &openreg::_fused_sdp_choice_privateuse1);
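
With the stub registrations above in place, quantization calls on openreg tensors dispatch to the PrivateUse1 kernels; the new test_quantize further down verifies exactly this. A minimal sketch of the behavior (hypothetical usage, not part of the diff; assumes the OpenReg extension is loaded):

import torch

# Assumes the OpenReg extension is loaded and "openreg" is registered
# as the PrivateUse1 backend (setup not shown).
x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")

# Routed to quantize_tensor_per_tensor_affine_privateuse1 via the
# REGISTER_PRIVATEUSE1_DISPATCH call above; the kernel is a no-op,
# but the returned tensor keeps the custom device and quantized dtype.
q = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
assert q.device == torch.device("openreg:0")
assert q.dtype == torch.qint8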
@@ -55,117 +55,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
            verbose=True,
        )

    def test_open_device_quantized(self):
        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to(
            "openreg"
        )
        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    def test_open_device_random(self):
        # check if torch.openreg has implemented get_rng_state
        with torch.random.fork_rng(device_type="openreg"):
            pass

    def test_open_device_tensor(self):
        device = self.module.custom_device()

        # check whether tensor.type() meets the expectation
        dtypes = {
            torch.bool: "torch.openreg.BoolTensor",
            torch.double: "torch.openreg.DoubleTensor",
            torch.float32: "torch.openreg.FloatTensor",
            torch.half: "torch.openreg.HalfTensor",
            torch.int32: "torch.openreg.IntTensor",
            torch.int64: "torch.openreg.LongTensor",
            torch.int8: "torch.openreg.CharTensor",
            torch.short: "torch.openreg.ShortTensor",
            torch.uint8: "torch.openreg.ByteTensor",
        }
        for tt, dt in dtypes.items():
            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
            self.assertTrue(test_tensor.type() == dt)

        # check whether the attributes and methods of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        self.assertFalse(x.is_openreg)

        x = x.openreg(torch.device("openreg"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(x.is_openreg)

        # test different device type input
        y = torch.empty(4, 4)
        self.assertFalse(y.is_openreg)

        y = y.openreg(torch.device("openreg:0"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(y.is_openreg)

        # test different device type input
        z = torch.empty(4, 4)
        self.assertFalse(z.is_openreg)

        z = z.openreg(0)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z.is_openreg)

    def test_open_device_packed_sequence(self):
        device = self.module.custom_device()  # noqa: F841
        a = torch.rand(5, 3)
        b = torch.tensor([1, 1, 1, 1, 1])
        input = torch.nn.utils.rnn.PackedSequence(a, b)
        self.assertFalse(input.is_openreg)
        input_openreg = input.openreg()
        self.assertTrue(input_openreg.is_openreg)

    def test_open_device_storage(self):
        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        z1 = x.storage()
        self.assertFalse(z1.is_openreg)

        z1 = z1.openreg()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_openreg)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.openreg(torch.device("cpu"))

        z1 = z1.cpu()
        self.assertFalse(self.module.custom_add_called())
        self.assertFalse(z1.is_openreg)

        z1 = z1.openreg(device="openreg:0", non_blocking=False)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_openreg)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.openreg(device="cuda:0", non_blocking=False)

        # check UntypedStorage
        y = torch.empty(4, 4)
        z2 = y.untyped_storage()
        self.assertFalse(z2.is_openreg)

        z2 = z2.openreg()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z2.is_openreg)

        # check custom StorageImpl create
        self.module.custom_storage_registry()

        z3 = y.untyped_storage()
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3.openreg()
        self.assertTrue(self.module.custom_storageImpl_called())
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3[0:3]
        self.assertTrue(self.module.custom_storageImpl_called())

    @unittest.skipIf(
        sys.version_info >= (3, 13),
        "Error: Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first.",
@@ -88,6 +88,75 @@ class TestPrivateUse1(TestCase):
        torch.abs(x_openreg, out=o_openreg[:, :, 0:6:3])
        self.assertEqual(o_cpu, o_openreg.cpu())

    def test_backend_tensor_type(self):
        dtypes_map = {
            torch.bool: "torch.openreg.BoolTensor",
            torch.double: "torch.openreg.DoubleTensor",
            torch.float32: "torch.openreg.FloatTensor",
            torch.half: "torch.openreg.HalfTensor",
            torch.int32: "torch.openreg.IntTensor",
            torch.int64: "torch.openreg.LongTensor",
            torch.int8: "torch.openreg.CharTensor",
            torch.short: "torch.openreg.ShortTensor",
            torch.uint8: "torch.openreg.ByteTensor",
        }

        for dtype, str in dtypes_map.items():
            x = torch.empty(4, 4, dtype=dtype, device="openreg")
            self.assertTrue(x.type() == str)

    def test_backend_tensor_methods(self):
        x = torch.empty(4, 4)
        self.assertFalse(x.is_openreg)  # type: ignore[misc]

        y = x.openreg(torch.device("openreg"))  # type: ignore[misc]
        self.assertTrue(y.is_openreg)  # type: ignore[misc]
        z = x.openreg(torch.device("openreg:0"))  # type: ignore[misc]
        self.assertTrue(z.is_openreg)  # type: ignore[misc]
        n = x.openreg(0)  # type: ignore[misc]
        self.assertTrue(n.is_openreg)  # type: ignore[misc]

    @unittest.skip("Need to support Parameter in openreg")
    def test_backend_module_methods(self):
        class FakeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.x = torch.nn.Parameter(torch.randn(3, 3))

            def forward(self):
                pass

        module = FakeModule()
        self.assertEqual(module.x.device.type, "cpu")
        module.openreg()  # type: ignore[misc]
        self.assertEqual(module.x.device.type, "openreg")

    @unittest.skip("Need to support untyped_storage in openreg")
    def test_backend_storage_methods(self):
        x = torch.empty(4, 4)

        x_cpu = x.storage()
        self.assertFalse(x_cpu.is_openreg)  # type: ignore[misc]
        x_openreg = x_cpu.openreg()  # type: ignore[misc]
        self.assertTrue(x_openreg.is_openreg)  # type: ignore[misc]

        y = torch.empty(4, 4)

        y_cpu = y.untyped_storage()
        self.assertFalse(y_cpu.is_openreg)  # type: ignore[misc]
        y_openreg = y_cpu.openreg()  # type: ignore[misc]
        self.assertTrue(y_openreg.is_openreg)  # type: ignore[misc]

    def test_backend_packed_sequence_methods(self):
        x = torch.rand(5, 3)
        y = torch.tensor([1, 1, 1, 1, 1])

        z_cpu = torch.nn.utils.rnn.PackedSequence(x, y)
        self.assertFalse(z_cpu.is_openreg)  # type: ignore[misc]

        z_openreg = z_cpu.openreg()  # type: ignore[misc]
        self.assertTrue(z_openreg.is_openreg)  # type: ignore[misc]

    def test_backend_fallback(self):
        pass

@@ -247,6 +316,12 @@ class TestOpenReg(TestCase):
        self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]]))
        self.assertEqual(x.data_ptr(), y.data_ptr())

    def test_quantize(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
        quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)


if __name__ == "__main__":
    run_tests()