Refactor and Improve the OpenReg Module (#158090)

----
# Refactor and Improve the OpenReg Module

## Background

Since PrivateUse1 became the main path for integrating new devices with PyTorch, there have been a number of feature requests about its interfaces, documentation, reference examples, and so on, such as the following:

- https://github.com/pytorch/pytorch/issues/155864
- https://github.com/pytorch/pytorch/issues/144955
- https://github.com/pytorch/pytorch/issues/144845

Taking these requests into account, together with OpenReg's role as the current test backend for PrivateUse1, I'm planning the following improvements:

- Optimize the implementation of OpenReg so that it follows the standard way a real backend integrates at the C++ level, serving as reference code for new device integration.
- Add comprehensive documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html) to guide new accelerator integration, functioning as a reference manual.

## Design Principles:

- Minimization Principle: Keep the code small and clear; only implement the minimum set of code required for verification and as an integration reference.
- Authenticity Principle: Integrate OpenReg in the same way that real accelerators integrate with PyTorch.

## More Info:

Please refer to [this](6b8020f1ab/test/cpp_extensions/open_registration_extension/torch_openreg/README.md) for more information about `OpenReg`.

## Current Progress:
- Refactored all of OpenReg's code, using the implementation of [torch_xla](https://github.com/pytorch/xla) as a reference, to make it easier to understand.
- Ensured that all tests in [test/test_openreg.py](https://github.com/FFFrog/pytorch/blob/openreg/test/test_openreg.py) pass after the refactoring.

## Next Steps:
- Add more features to cover all integration points.
- Gradually add user guides and documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158090
Approved by: https://github.com/seemethere, https://github.com/albanD
This commit is contained in:
FFFrog 2025-07-15 12:26:58 +08:00 committed by PyTorch MergeBot
parent 6c5227ba00
commit 1b389025ba
51 changed files with 2568 additions and 1792 deletions

View File

@ -1,37 +0,0 @@
# PyTorch OpenReg
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.
## How to use
Install as standalone with `python -m pip install -e .` (or `python -m pip install .`)
from this folder. You can run the tests via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`.
## Design principles
For simplicity, anything that can be implemented from Python is done in Python.
A real implementation will most likely want to call these different APIs from C++ directly.
The current version sends everything back to Python and contains enough implementation to run a basic model, transfer data between host and device, and print tensors.
The codebase is split as follows:
- `pytorch_openreg/__init__.py`
- imports torch to get core state initialized.
- imports `._aten_impl` to register our aten op implementations to torch.
- imports `._C` to load our C++ extension that registers more ops, the allocator, and hooks.
- renames the PrivateUse1 backend and registers our python-side module.
- `pytorch_openreg/_aten_impl.py`
- Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kinds of native functions: computing the output metadata, allocating it, and only calling into the device daemon to perform the computation.
- `pytorch_openreg/_device_daemon.py`
- contains the Allocator (responsible for allocating memory on the device side and host side, as int8 buffers).
- contains `Driver`, the user-process driver that handles the state and requests that need to be dealt with on the driver side.
- contains `Executor`, the device-process executor that runs the device-side logic.
- `pytorch_openreg/_meta_parser.py` mainly contains utilities to send objects over the wire from the user process to the device process.
- The main class there is `OpenRegTensorMeta`, which contains all the metadata sent to the device; this should be enough for it to populate the output Tensor.
## Next steps
The main next step would be to:
- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.

View File

@ -1,122 +0,0 @@
import types
import torch
# Create our python implementation dict so that the C++ module
# can access it during its initialization and also register aten impls.
from ._aten_impl import impl_factory as impl_factory # noqa: F401
from ._device_daemon import driver
# Load the C++ Module
import pytorch_openreg._C # isort:skip # type: ignore[import] # noqa: F401
def _create_module():
module = types.ModuleType("_OpenRegMod")
class device:
r"""Context-manager that changes the selected device.
Args:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device):
self.idx = torch.accelerator._get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = driver.exec("exchangeDevice", self.idx)
def __exit__(self, type, value, traceback):
self.idx = driver.exec("uncheckedSetDevice", self.prev_idx)
return False
def device_count() -> int:
return driver.exec("deviceCount")
def is_available():
return True
def current_device():
return torch.accelerator.current_device_index()
def get_rng_state(device="openreg"):
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("openreg", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = pytorch_openreg._C._get_default_generator(idx)
return default_generator.get_state()
def set_rng_state(new_state, device="openreg"):
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("openreg", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = pytorch_openreg._C._get_default_generator(idx)
default_generator.set_state(new_state)
def initial_seed() -> int:
_lazy_init()
idx = current_device()
default_generator = pytorch_openreg._C._get_default_generator(idx)
return default_generator.initial_seed()
def manual_seed(seed: int) -> None:
seed = int(seed)
idx = current_device()
default_generator = pytorch_openreg._C._get_default_generator(idx)
default_generator.manual_seed(seed)
def manual_seed_all(seed: int) -> None:
seed = int(seed)
for idx in range(device_count()):
default_generator = pytorch_openreg._C._get_default_generator(idx)
default_generator.manual_seed(seed)
def is_initialized():
return module._initialized
def _is_in_bad_fork():
return False
def _lazy_init():
if is_initialized():
return
pytorch_openreg._C._init()
module._initialized = True
module.is_available = is_available # type: ignore[assignment]
module._initialized = False # type: ignore[assignment]
module._lazy_init = _lazy_init # type: ignore[assignment]
module.is_initialized = is_initialized # type: ignore[assignment]
module.device = device # type: ignore[assignment]
module.device_count = device_count # type: ignore[assignment]
module.current_device = current_device # type: ignore[assignment]
module.get_rng_state = get_rng_state # type: ignore[assignment]
module.set_rng_state = set_rng_state # type: ignore[assignment]
module._is_in_bad_fork = _is_in_bad_fork # type: ignore[assignment]
module.initial_seed = initial_seed # type: ignore[assignment]
module.manual_seed = manual_seed # type: ignore[assignment]
module.manual_seed_all = manual_seed_all # type: ignore[assignment]
return module
# Set all the appropriate state on PyTorch
torch.utils.rename_privateuse1_backend("openreg")
torch._register_device_module("openreg", _create_module())
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)

View File

@ -1,186 +0,0 @@
import logging
import torch
from torch.utils._pytree import tree_any
log = logging.getLogger(__name__)
from ._device_daemon import driver
from ._meta_parser import prepare_for_sending, to_device_no_copy
_IMPL_REGISTRY = {}
def impl_factory(name):
if name in _IMPL_REGISTRY:
return _IMPL_REGISTRY[name]
def _(*args, **kwargs):
log.info("Calling hook %s", name)
return driver.exec(name, *args, **kwargs)
_IMPL_REGISTRY[name] = _
return _
def _openreg_kernel_fallback(op, *args, **kwargs):
def get_tensor_device(*args):
for arg in args:
if isinstance(arg, torch.Tensor) and arg.device.type == "openreg":
return arg.device
device = get_tensor_device(*args)
if device is None:
return _kernel_fallback(op, *args, **kwargs)
# Mimics the DeviceGuard system we have in aten
with torch.openreg.device(device): # type: ignore[misc]
return _kernel_fallback(op, *args, **kwargs)
def _kernel_fallback(op, *args, **kwargs):
log.info("Calling kernel %s", op)
op_name = None
post_process = None
if "out" in op._overloadname:
# Note that all structured native ops will call here
if isinstance(kwargs["out"], tuple):
raise RuntimeError(f"out= variant {op} with tuple out= not supported")
if kwargs["out"].nelement() == 0:
# Out variant that needs a resize, convert to an out of place
# and handle generically below
orig_out = kwargs["out"]
del kwargs["out"]
if op._overloadname != "out":
raise RuntimeError(
"Cannot retranslate non-default out= variant form 0 size"
)
op = op.overloadpacket.default
def _post_process():
nonlocal real_res
orig_out.set_(real_res)
real_res = orig_out
post_process = _post_process
else:
# No metadata update to do, just run the op on the device
op_name = op.overloadpacket._qualified_op_name
real_res = kwargs["out"]
elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
# No Tensor argument means factory function
# They should decompose and be handled on our C++ side directly
raise RuntimeError(f"{op} not handled yet.")
elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
# Only handle inplace ops returning their first arg
assert len(args) >= 1, f"Inplace {op} needs at least one arg"
assert len(op._schema.returns) == 1, (
f"NYI Inplace {op} with more than one return"
)
op_name = op.overloadpacket._qualified_op_name
real_res = args[0]
elif any(r.alias_info is not None for r in op._schema.returns):
# View ops
if op is torch.ops.aten.view.default:
return torch.ops.aten._unsafe_view(*args, **kwargs)
raise RuntimeError(f"{op} view op is not handled yet")
if op_name is None:
# 1. Compute updated metadata
if torch.Tag.dynamic_output_shape not in op.tags:
# Usual case: run the meta op to see the output metadata
meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
meta_res = op(*meta_args, **meta_kwargs)
# 2. Allocate the output
real_res, _ = to_device_no_copy("openreg", meta_res, {})
else:
# Slow version for data-dependent functions:
# Run the op on the device just to get the output shape
args_, kwargs_ = prepare_for_sending(args, kwargs)
shape = driver.exec(
"get_op_output_shape",
op.overloadpacket._qualified_op_name,
args_,
kwargs_,
)
# 2. Allocate the output
real_res = args[0].new(shape)
# 3. Move to out variant
kwargs["out"] = real_res
# Let overload resolution find the out= overload
op_name = op.overloadpacket._qualified_op_name
# 4. Run the compute and populate the output on the device
args, kwargs = prepare_for_sending(args, kwargs)
driver.exec("run_op", op_name, args, kwargs)
if post_process is not None:
post_process()
return real_res
def copy_from_device(from_):
with torch.openreg.device(from_.device): # type: ignore[misc]
args, _ = prepare_for_sending((from_,), {})
return driver.exec("send_data", *args)
def copy_from_host_to_device(from_, to_):
with torch.openreg.device(to_.device): # type: ignore[misc]
args, _ = prepare_for_sending((to_,), {})
driver.exec("recv_data", from_, *args)
return to_
def _copy_from(from_, to_):
if from_.device.type == to_.device.type:
assert from_.device.type == "openreg"
if from_.device.index == to_.device.index:
op = torch.ops.aten.copy_.default
return _openreg_kernel_fallback(op, to_, from_)
else:
host_mem = copy_from_device(from_)
return copy_from_host_to_device(host_mem, to_)
elif from_.device.type == "openreg":
host_mem = copy_from_device(from_)
return to_.copy_(host_mem)
elif to_.device.type == "openreg":
return copy_from_host_to_device(from_, to_)
else:
raise RuntimeError("Should not happen")
def _set_source_tensor(ten1, ten2):
return torch.ops.aten.set_.source_Storage_storage_offset(
ten1,
ten2.untyped_storage(),
ten2.storage_offset(),
ten2.size(),
ten2.stride(),
)
def _local_scalar_dense(ten):
host_mem = copy_from_device(ten)
return host_mem.item()
_openreg_lib = torch.library.Library("_", "IMPL")
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
_openreg_lib_aten = torch.library.Library("aten", "IMPL")
_openreg_lib_aten.impl("_copy_from", _copy_from, dispatch_key="PrivateUse1")
_openreg_lib_aten.impl(
"set_.source_Tensor", _set_source_tensor, dispatch_key="PrivateUse1"
)
_openreg_lib_aten.impl(
"_local_scalar_dense", _local_scalar_dense, dispatch_key="PrivateUse1"
)

View File

@ -1,391 +0,0 @@
import ctypes
import logging
import threading
import time
import torch
from ._meta_parser import (
OpenRegTensorData,
receive_after_sending,
safe_str,
validate_send_queue_args,
)
log = logging.getLogger(__name__)
mp_context = torch.multiprocessing.get_context("spawn")
# Constant properties of our device
NUM_DEVICES = 2
# Our allocator
class Allocator:
def __init__(self):
self.allocated = {}
def malloc(self, size):
mem = ctypes.create_string_buffer(size)
ptr = ctypes.addressof(mem)
self.allocated[ptr] = (size, mem)
return ptr
def free(self, ptr):
if ptr not in self.allocated:
return False
else:
del self.allocated[ptr]
return True
class HostAllocator(Allocator):
def is_pinned_ptr(self, ptr):
return ptr in self.allocated or any(
ptr_ <= ptr and ptr < ptr_ + size
for ptr_, (size, _) in self.allocated.items()
)
class DeviceAllocator(Allocator):
def tensor_from_meta(self, meta):
def create_tensor_from_data_ptr(ptr, size):
storage = torch._C._construct_storage_from_data_pointer(
ptr, torch.device("cpu"), size
)
return torch.Tensor(storage)
found_base = None
# Usual case, we're receiving a known Tensor
if meta.data_ptr in self.allocated:
found_base = create_tensor_from_data_ptr(
meta.data_ptr, self.allocated[meta.data_ptr][0]
)
# Might be a rewrap of another storage at a different offset
# Slow path to try and find the corresponding storage
if found_base is None:
for tag, (size, _) in self.allocated.items():
# t is always a 1D uint8 storage!
if meta.data_ptr > tag and meta.data_ptr < tag + size:
# Blame @ngimel for this
slice_size = size - (meta.data_ptr - tag)
found_base = create_tensor_from_data_ptr(meta.data_ptr, slice_size)
# Might be an empty tensor
if found_base is None and meta.nelem_in_bytes == 0:
found_base = torch.tensor((), dtype=torch.uint8)
# This pointer is not allocated here, segfault !
if found_base is None:
log.info("Currently allocated blocks:\n %s", safe_str(self.allocated))
log.info("Trying to access %s", meta)
raise RuntimeError("SEGFAULT!")
# Raw 1d uint8 data
raw = found_base
# Reinterpret cast in the right dtype
as_dtype = raw.view(dtype=meta.dtype)
# View to the right shape/stride/offset
view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset)
return view
def register(registry):
def func(fn):
registry[fn.__name__] = fn
return fn
return func
class Driver:
def __init__(self, num_devices):
super().__init__()
self.num_devices = num_devices
self.is_initialized = False
# State of our driver
self.curr_device_idx = 0
self.curr_streams = {}
# Which device each allocated pointer belongs to
self.memory_belong = {}
self.host_allocator = HostAllocator()
self.event_belong = {}
self.rlock = threading.RLock()
def _lazy_init(self):
if self.is_initialized:
return
self.devices = []
for i in range(self.num_devices):
req_queue = mp_context.Queue()
ans_queue = mp_context.Queue()
runner = mp_context.Process(
target=_Executor(i).run_forever,
args=(req_queue, ans_queue),
daemon=True,
)
runner.start()
self.devices.append((req_queue, ans_queue, runner))
self.is_initialized = True
def exec(self, cmd, *args):
with self.rlock:
log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
if cmd in Driver.registry:
res = Driver.registry[cmd](self, *args)
else:
res = self.run_on_executor(self.curr_device_idx, cmd, *args)
log.info("Main process result for %s received: %s", cmd, safe_str(res))
if res == "ERROR":
raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
else:
return res
def run_on_executor(self, device_idx, cmd, *args):
self._lazy_init()
req_queue, ans_queue, _ = self.devices[device_idx]
stream = self.getStream(device_idx)
validate_send_queue_args(cmd, args)
req_queue.put((stream, cmd) + args)
return ans_queue.get()
registry = {}
@register(registry)
def hasPrimaryContext(self, device_idx):
return device_idx >= 0 and device_idx < self.num_devices
@register(registry)
def deviceCount(self, *args):
assert len(args) == 0
return self.num_devices
@register(registry)
def getDevice(self):
return self.curr_device_idx
@register(registry)
def setDevice(self, device_idx):
assert device_idx >= 0 and device_idx < self.num_devices
self.curr_device_idx = device_idx
@register(registry)
def uncheckedSetDevice(self, *args):
assert len(args) == 1
self.curr_device_idx = int(args[0])
@register(registry)
def exchangeDevice(self, *args):
assert len(args) == 1
res = self.curr_device_idx
self.curr_device_idx = int(args[0])
return res
@register(registry)
def malloc(self, size):
ptr = self.run_on_executor(self.curr_device_idx, "malloc", size)
self.memory_belong[ptr] = self.curr_device_idx
return ptr
@register(registry)
def free(self, ptr):
device_idx = self.memory_belong.pop(ptr, None)
if device_idx is None:
return False
return self.run_on_executor(device_idx, "free", ptr)
@register(registry)
def isPinnedPtr(self, ptr):
return self.host_allocator.is_pinned_ptr(ptr)
@register(registry)
def hostMalloc(self, size):
return self.host_allocator.malloc(size)
@register(registry)
def hostFree(self, ptr):
return self.host_allocator.free(ptr)
@register(registry)
def getNewStream(self, device_idx, priority):
return self.run_on_executor(device_idx, "getNewStream", priority)
@register(registry)
def queryStream(self, stream):
return self.run_on_executor(
stream.device_index, "queryStream", stream.stream_id
)
@register(registry)
def getStream(self, device_idx):
return self.curr_streams.get(device_idx, 0)
@register(registry)
def exchangeStream(self, stream):
stream_id = self.curr_streams.get(stream.device_index, 0)
self.curr_streams[stream.device_index] = stream.stream_id
return stream_id
@register(registry)
def synchronizeStream(self, stream):
self.run_on_executor(stream.device_index, "synchronizeStream", stream.stream_id)
@register(registry)
def record(self, event, stream, device_index, flags):
event_ptr = ctypes.cast(event, ctypes.POINTER(ctypes.c_int64))
# Create event if needed
if event_ptr.contents.value == 0:
event_ptr.contents.value = self.run_on_executor(
stream.device_index, "eventCreateWithFlags", flags
)
self.event_belong[event_ptr.contents.value] = stream.device_index
# Record event
self.run_on_executor(
stream.device_index,
"eventRecord",
event_ptr.contents.value,
stream.stream_id,
)
@register(registry)
def destroyEvent(self, event, device_index):
self.run_on_executor(device_index, "eventDestroy", event)
self.event_belong.pop(event)
@register(registry)
def synchronizeEvent(self, event):
self.run_on_executor(self.event_belong[event], "eventSynchronize", event)
@register(registry)
def queryEvent(self, event):
return self.run_on_executor(self.event_belong[event], "eventQuery", event)
@register(registry)
def elapsedTime(self, e1, e2, device_index):
return self.run_on_executor(device_index, "eventElapsedTime", e1, e2)
@register(registry)
def block(self, event, stream):
self.run_on_executor(stream.device_index, "block", event, stream.stream_id)
class _Executor:
def __init__(self, id):
self.id = id
self.allocator = DeviceAllocator()
self.stream = 0
self.event_incr_id = 0
self.events = {}
def run_forever(self, req_queue, ans_queue):
# Serve all requests
while True:
# Ignore stream since cpu backend doesn't support asynchronous execution
_, cmd, *args = req_queue.get()
log.info("Worker executing: %s", cmd)
if cmd in _Executor.registry:
res = _Executor.registry[cmd](self, *args)
else:
log.warning("Bad command in worker")
res = "ERROR"
log.info("Worker answering to: %s", cmd)
ans_queue.put(res)
registry = {}
@register(registry)
def malloc(self, size):
return self.allocator.malloc(size)
@register(registry)
def free(self, ptr):
return self.allocator.free(ptr)
def _run_op(self, op_name, args, kwargs):
op, _ = torch._C._jit_get_operation(op_name)
args, kwargs = receive_after_sending(self.allocator, args, kwargs)
return op(*args, **kwargs)
@register(registry)
def run_op(self, op_name, args, kwargs):
self._run_op(op_name, args, kwargs)
@register(registry)
def get_op_output_shape(self, op_name, args, kwargs):
return self._run_op(op_name, args, kwargs).size()
@register(registry)
def send_data(self, *args):
assert len(args) == 1
return OpenRegTensorData.from_meta(self.allocator, args[0])
@register(registry)
def recv_data(self, host_tensor, dev_mem):
dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem)
dev_tensor.copy_(host_tensor)
@register(registry)
def getNewStream(self, priority):
self.stream += 1
return self.stream
@register(registry)
def queryStream(self, stream):
return True
@register(registry)
def synchronizeStream(self, stream):
# no-op
pass
@register(registry)
def eventCreateWithFlags(self, flags):
self.event_incr_id += 1
self.events[self.event_incr_id] = [flags, None]
return self.event_incr_id
@register(registry)
def eventRecord(self, event, stream):
# Only flags == 1 enables timing
if self.events[event][0] == 1:
self.events[event][1] = time.time() * 1000
return 0
@register(registry)
def eventDestroy(self, event):
self.events.pop(event)
@register(registry)
def eventSynchronize(self, event):
assert self.events.get(event) is not None
return 0
@register(registry)
def eventQuery(self, event):
assert self.events.get(event) is not None
return True
@register(registry)
def eventElapsedTime(self, e1, e2):
time_1 = self.events[e1][1]
time_2 = self.events[e2][1]
assert time_1 is not None and time_2 is not None
return time_2 - time_1
@register(registry)
def block(self, event, stream):
# no-op
pass
driver = Driver(NUM_DEVICES)

View File

@ -1,103 +0,0 @@
import pprint
import torch
from torch.utils._pytree import tree_map, tree_map_only
class OpenRegTensorMeta:
def __init__(self, tensor, checked=True):
if checked and not tensor.device.type == "openreg":
raise RuntimeError(
"Creating OpenRegTensorMeta is only for Tensors on openreg device"
)
self.data_ptr = tensor.untyped_storage().data_ptr()
self.size = tensor.size()
self.stride = tensor.stride()
self.storage_offset = tensor.storage_offset()
self.dtype = tensor.dtype
self.nelem_in_bytes = tensor.nelement() * tensor.element_size()
def __repr__(self):
return (
f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, "
f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})"
)
class OpenRegTensorData(torch.Tensor):
@staticmethod
def from_meta(allocator, tensor_meta):
return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta))
VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float}
VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str}
def safe_str(args):
def convert(obj):
if isinstance(obj, torch.Tensor):
return str(OpenRegTensorMeta(obj, checked=False))
else:
return obj
new_args = tree_map(convert, args)
return pprint.pformat(new_args)
def validate_send_queue_args(cmd, args):
def check(obj):
if type(obj) not in VALID_QUEUE_TYPES_OUT:
if (
cmd == "recv_data"
and type(obj) in [torch.Tensor, OpenRegTensorData]
and obj.device.type == "cpu"
):
# Only HtoD copy command can send cpu Tensors over
return
raise RuntimeError(
f"Trying to send invalid object through queue: {type(obj)}"
)
tree_map(check, args)
def prepare_for_sending(args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_IN:
raise RuntimeError(
f"Cannot send object of type {type(obj)} over openreg device pipe."
)
if isinstance(obj, torch.Tensor):
return OpenRegTensorMeta(obj)
else:
return obj
return tree_map(convert, (args, kwargs))
def receive_after_sending(allocator, args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_OUT:
raise RuntimeError(
f"Received invalid object of type {type(obj)} over openreg device pipe."
)
if isinstance(obj, OpenRegTensorMeta):
return allocator.tensor_from_meta(obj)
else:
return obj
return tree_map(convert, (args, kwargs))
def to_device_no_copy(device, args, kwargs):
def safe_to(t):
if device == "meta":
return t.to(device=device)
else:
return torch.empty_like(t, device=device)
return tree_map_only(torch.Tensor, safe_to, (args, kwargs))

View File

@ -1,51 +0,0 @@
#include "OpenReg.h"
#include <ATen/Context.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_numbers.h>
static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
THPUtils_checkLong(arg),
"_get_default_generator expects an int, but got ",
THPUtils_typename(arg));
auto idx = static_cast<int>(THPUtils_unpackLong(arg));
return THPGenerator_initDefaultGenerator(
at::globalContext().defaultGenerator(
c10::Device(c10::DeviceType::PrivateUse1, idx)));
END_HANDLE_TH_ERRORS
}
static PyMethodDef methods[] = {
{"_init", _initExtension, METH_NOARGS, nullptr},
{"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
{nullptr, nullptr, 0, nullptr}
};
static struct PyModuleDef openreg_C_module =
{PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods};
PyMODINIT_FUNC PyInit__C(void) {
PyObject* mod = PyModule_Create(&openreg_C_module);
py::object openreg_mod = py::module_::import("pytorch_openreg");
// Only borrowed from the python side!
openreg::set_impl_factory(openreg_mod.attr("impl_factory").ptr());
return mod;
}

View File

@ -1,50 +0,0 @@
#pragma once
#include <torch/csrc/utils/pybind.h>
namespace openreg {
using openreg_ptr_t = uint64_t;
void set_impl_factory(PyObject* factory);
py::function get_method(const char* name);
static constexpr char kFreeMethod[] = "free";
static constexpr char kHostFreeMethod[] = "hostFree";
template <const char* name>
static void ReportAndDelete(void* ptr) {
if (!ptr || !Py_IsInitialized()) {
return;
}
py::gil_scoped_acquire acquire;
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
// Always stash, this will be a no-op if there is no error
PyErr_Fetch(&type, &value, &traceback);
TORCH_CHECK(
get_method(name)(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
"Failed to free memory pointer at ",
ptr);
// If that user code raised an error, just print it without raising it
if (PyErr_Occurred()) {
PyErr_Print();
}
// Restore the original error
PyErr_Restore(type, value, traceback);
}
#define REGISTER_PRIVATEUSE1_SERIALIZATION( \
FOR_SERIALIZATION, FOR_DESERIALIZATION) \
static int register_serialization() { \
torch::jit::TensorBackendMetaRegistry( \
c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \
return 0; \
} \
static const int _temp = register_serialization();
} // namespace openreg

View File

@ -1,350 +0,0 @@
#include "OpenReg.h"
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
namespace openreg {
namespace {
// Python factory function where real implementations can be found
PyObject* py_factory;
struct HostAllocator final : at::Allocator {
HostAllocator() = default;
at::DataPtr allocate(size_t nbytes) override {
py::gil_scoped_acquire acquire;
void* data = nullptr;
if (nbytes > 0) {
data = reinterpret_cast<void*>(
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
}
return {data, data, &ReportAndDelete<kHostFreeMethod>, at::Device(at::kCPU)};
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete<kHostFreeMethod>;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
py::gil_scoped_acquire acquire;
get_method("hostCopyData")(
reinterpret_cast<openreg_ptr_t>(dest),
reinterpret_cast<openreg_ptr_t>(src),
count);
}
};
static HostAllocator global_host_alloc;
static c10::DeviceIndex device_count() {
py::gil_scoped_acquire acquire;
return get_method("deviceCount")().cast<c10::DeviceIndex>();
}
static c10::DeviceIndex current_device_idx() {
py::gil_scoped_acquire acquire;
return get_method("getDevice")().cast<c10::DeviceIndex>();
}
class OpenRegGeneratorImpl : public at::CPUGeneratorImpl {
public:
OpenRegGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~OpenRegGeneratorImpl() override = default;
};
static at::Generator make_openreg_generator(c10::DeviceIndex device_index) {
return at::make_generator<OpenRegGeneratorImpl>(device_index);
}
// Default, global generators, one per device.
static std::vector<at::Generator> default_generators;
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
OpenRegHooksInterface() {};
~OpenRegHooksInterface() override = default;
bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
py::gil_scoped_acquire acquire;
return get_method("hasPrimaryContext")(device_index).cast<bool>();
}
at::Allocator* getPinnedMemoryAllocator() const override {
return &global_host_alloc;
}
bool isPinnedPtr(const void* data) const override {
py::gil_scoped_acquire acquire;
return get_method("isPinnedPtr")(reinterpret_cast<openreg_ptr_t>(data))
.cast<bool>();
}
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
static bool flag [[maybe_unused]] = []() {
auto device_nums = device_count();
default_generators.resize(device_nums);
for (auto i = 0; i < device_nums; i++) {
default_generators[i] = make_openreg_generator(i);
default_generators[i].seed();
}
return true;
}();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = current_device_idx();
} else {
TORCH_CHECK(idx >= 0 && idx < device_count());
}
return default_generators[idx];
}
at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
return make_openreg_generator(device_index);
}
};
static bool register_hook_flag [[maybe_unused]] = []() {
at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
return true;
}();
// Device guard registration
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;
OpenRegGuardImpl() = default;
explicit OpenRegGuardImpl(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == static_type);
}
/**
* Return the type of device managed by this guard implementation.
*/
c10::DeviceType type() const override {
return static_type;
}
/**
* Set the current device to Device, and return the previous c10::Device.
*/
c10::Device exchangeDevice(c10::Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_privateuseone());
py::gil_scoped_acquire acquire;
auto old_device_index =
get_method("exchangeDevice")(d.index()).cast<c10::DeviceIndex>();
return c10::Device(static_type, old_device_index);
}
/**
* Get the current device.
*/
c10::Device getDevice() const override {
return c10::Device(static_type, current_device_idx());
}
/**
* Set the current device to c10::Device.
*/
void setDevice(c10::Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_privateuseone());
py::gil_scoped_acquire acquire;
auto device = get_method("setDevice")(d.index());
}
/**
* Set the current device to c10::Device, without checking for errors
* (so, e.g., this can be called from a destructor).
*/
void uncheckedSetDevice(c10::Device d) const noexcept override {
py::gil_scoped_acquire acquire;
auto device = get_method("uncheckedSetDevice")(d.index());
}
/**
* Get the current stream for a given device.
*/
c10::Stream getStream(c10::Device d) const noexcept override {
py::gil_scoped_acquire acquire;
auto stream_id = get_method("getStream")(d.index()).cast<c10::StreamId>();
return c10::Stream(c10::Stream::UNSAFE, d, stream_id);
}
/**
* Get the default stream for a given device.
*/
c10::Stream getDefaultStream(c10::Device d) const override {
py::gil_scoped_acquire acquire;
return get_method("getDefaultStream")(d.index()).cast<c10::Stream>();
}
/**
* Get a stream from the global pool for a given device.
*/
c10::Stream getStreamFromGlobalPool(
c10::Device d,
bool isHighPriority = false) const override {
py::gil_scoped_acquire acquire;
return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority)
.cast<c10::Stream>();
}
/**
* Return a new stream for a given device and priority. The stream will be
* copied and shared around, device backend should be able to correctly handle
* the lifetime of the stream.
*/
c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
py::gil_scoped_acquire acquire;
auto stream_id =
get_method("getNewStream")(d.index(), priority).cast<c10::StreamId>();
return c10::Stream(c10::Stream::UNSAFE, d, stream_id);
}
/**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
*/
c10::Stream exchangeStream(c10::Stream s) const noexcept override {
py::gil_scoped_acquire acquire;
auto stream_id = get_method("exchangeStream")(s).cast<c10::StreamId>();
return c10::Stream(c10::Stream::UNSAFE, s.device(), stream_id);
}
/**
* Destroys the given event.
*/
void destroyEvent(void* event, const c10::DeviceIndex device_index)
const noexcept override {
py::gil_scoped_acquire acquire;
get_method("destroyEvent")((int64_t)event, device_index);
}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream processes that job,
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
* */
void record(
void** event,
const c10::Stream& stream,
const c10::DeviceIndex device_index,
const c10::EventFlag flag) const override {
py::gil_scoped_acquire acquire;
get_method("record")((int64_t)event, stream, device_index, (int64_t)flag);
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
void block(void* event, const c10::Stream& stream) const override {
py::gil_scoped_acquire acquire;
get_method("block")((int64_t)event, stream);
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
bool queryEvent(void* event) const override {
py::gil_scoped_acquire acquire;
return get_method("queryEvent")((int64_t)event).cast<bool>();
}
/**
* Get the number of devices. WARNING: This is REQUIRED to not raise
* an exception. If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
c10::DeviceIndex deviceCount() const noexcept override {
return device_count();
}
/**
* Return true if all the work previously enqueued on the stream for
* asynchronous execution has completed running on the device.
*/
bool queryStream(const c10::Stream& stream) const override {
py::gil_scoped_acquire acquire;
return get_method("queryStream")(stream).cast<bool>();
}
/**
* Wait (by blocking the calling thread) until all the work previously
* enqueued on the stream has completed running on the device.
*/
virtual void synchronizeStream(const c10::Stream& stream) const override {
py::gil_scoped_acquire acquire;
get_method("synchronizeStream")(stream);
}
/**
* Wait (by blocking the calling thread) until all the work previously
* recorded on the event has completed running on the device.
*/
void synchronizeEvent(void* event) const override {
py::gil_scoped_acquire acquire;
get_method("synchronizeEvent")((int64_t)event);
}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
*/
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const c10::Stream& stream) const override {
py::gil_scoped_acquire acquire;
get_method("recordDataPtrOnStream")(data_ptr, stream);
}
/**
* Fetch the elapsed time between two recorded events.
*/
double elapsedTime(
void* event1,
void* event2,
const c10::DeviceIndex device_index) const override {
py::gil_scoped_acquire acquire;
return get_method("elapsedTime")(
(int64_t)event1, (int64_t)event2, device_index)
.cast<double>();
}
};
// Register our device guard
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
} // namespace
// Setter for the python dictionary with implementations
void set_impl_factory(PyObject* factory) {
py_factory = factory;
}
py::function get_method(const char* name) {
auto factory = py::cast<py::function>(py_factory);
return factory(name);
}
} // namespace openreg

View File

@ -1,418 +0,0 @@
#include "OpenReg.h"
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/ops/set_native.h>
#include <c10/core/Allocator.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/library.h>
namespace openreg {
namespace {
struct OpenRegAllocator final : at::Allocator {
OpenRegAllocator() = default;
at::DataPtr allocate(size_t nbytes) override {
py::gil_scoped_acquire acquire;
auto curr_device_idx = get_method("getDevice")().cast<c10::DeviceIndex>();
auto curr_device =
c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx);
void* data = nullptr;
if (nbytes > 0) {
data = reinterpret_cast<void*>(
get_method("malloc")(nbytes).cast<openreg_ptr_t>());
TORCH_CHECK(
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
}
return {data, data, &ReportAndDelete<kFreeMethod>, curr_device};
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete<kFreeMethod>;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
py::gil_scoped_acquire acquire;
get_method("copy_data")(
reinterpret_cast<openreg_ptr_t>(dest),
reinterpret_cast<openreg_ptr_t>(src),
count);
}
};
static OpenRegAllocator global_openreg_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc);
// Empty op needs C++ code and cannot be handled by python side fallback
at::Tensor empty_openreg(
c10::IntArrayRef size,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(
c10::layout_or_default(layout_opt) == c10::Layout::Strided,
"Non strided layout not supported");
TORCH_CHECK(
!c10::pinned_memory_or_default(pin_memory_opt),
"Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(
size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt);
}
at::Tensor empty_strided_openreg(
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(
c10::layout_or_default(layout_opt) == c10::Layout::Strided,
"Non strided layout not supported");
TORCH_CHECK(
!c10::pinned_memory_or_default(pin_memory_opt),
"Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_strided_generic(
size, stride, &global_openreg_alloc, pu1_dks, dtype);
}
at::Tensor as_strided_openreg(
const at::Tensor& self,
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<int64_t> storage_offset_) {
// Metadata-only change so we re-use the cpu impl
return at::cpu::as_strided(self, size, stride, storage_offset_);
}
const at::Tensor& resize__openreg(
const at::Tensor& self,
c10::SymIntArrayRef size,
::std::optional<at::MemoryFormat> memory_format) {
return at::native::resize_(
self, C10_AS_INTARRAYREF_SLOW(size), memory_format);
}
at::Tensor& set_source_Storage_storage_offsetset_openreg(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
return at::cpu::set_(result, storage, storage_offset, size, stride);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const std::optional<at::Tensor> & attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);
auto opts = query.options();
auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor & grad_out,
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const at::Tensor & attn_bias,
std::array<bool,4> grad_input_mask,
const at::Tensor & out,
const at::Tensor & logsumexp,
const at::Tensor & cum_seq_q,
const at::Tensor & cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor & philox_seed,
const at::Tensor & philox_offset,
std::optional<double> scale) {
return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
at::empty_like(query),
at::empty_like(key),
at::empty_like(value),
at::empty_like(attn_bias));
}
}
// Using the simplest way to obtain contiguous Tensor data and process it.
// This is a demo of the operand API; you can add more complex logic for the
// input and output tensors based on your custom device kernel.
void abs_kernel(at::TensorIteratorBase& iter) {
// Abs has only one input tensor and one output tensor.
auto& output_operand = iter.operand(0);
auto& input_operand = iter.operand(1);
auto& output_tensor_base = output_operand.tensor_base();
auto& input_tensor_base = input_operand.tensor_base();
TORCH_CHECK(!input_operand.original_tensor_base().defined(),
"input original tensor is defined.");
TORCH_CHECK(!output_operand.original_tensor_base().defined(),
"output original tensor is defined.");
// For easy testing, only accept contiguous input tensors for the calculation.
auto memory_format = input_tensor_base.suggest_memory_format();
TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
"Input tensor need be contiguous.");
// Add necessary restrictions to ensure the safety of the demo.
TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
"Input and output tensor sizes are not equal.");
// The common dtype is computed in TensorIteratorBase.
TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
"Only support float type.")
// Use a plain loop for the abs computation.
auto abs_function = [](float* output_ptr, const float* input_ptr,
const int64_t NUM) {
for (int64_t i = 0; i < NUM; ++i) {
*(output_ptr + i) = std::abs(*(input_ptr + i));
}
};
// To simplify the logic of the test demo code, we only compute on contiguous
// tensors on the device side, using the input tensor's memory format.
if (iter.is_contiguous()) {
// Check the will_resize flag. You can convert to a different tensor memory
// format when will_resize is True.
// If the TensorIteratorConfig resize_outputs_ flag is true, there are two
// situations:
// 1) The out tensor is undefined, and TensorIterator sets will_resize to true;
// 2) The out tensor is defined and its size is not equal to the input tensor
//    size; TensorIterator sets will_resize to true and calls
//    set_output_raw_strided to resize the output tensor.
// When the output operand's will_resize flag is true, a dummy device can
// convert the tensor to its preferred memory format. Here we don't convert the
// memory format, because it becomes complex when the dummy device wants to
// keep the same memory format for training a network.
TORCH_CHECK(output_operand.will_resize,
"output operand will_resize flag needs to be True.");
abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
} else {
// Strided copy is not supported for this demo device, so use the CPU device instead.
// For the abs op, the remaining situation is: the output tensor is not contiguous
// and the operand's will_resize is False.
TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
// Get a contiguous tensor with input memory format.
at::Tensor output = at::empty(output_tensor_base.sizes(),
input_tensor_base.options()
.memory_format(memory_format));
// For a structured op inherited from TensorIteratorBase, you may need to call
// set_output_raw_strided to update the output stored in the structured op.
// The abs op does not need to do this.
output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
(float*)iter.data_ptr(1), iter.numel());
// Copy the tensor base back to the original tensor base, keeping the same
// scalar type and stride as on CPU and GPU.
if (output_operand.original_tensor_base().defined() &&
!output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
output_operand.original_tensor().copy_(output_operand.tensor());
output_operand.restore_original_tensor();
}
}
}
int64_t _fused_sdp_choice_privateuse1(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}
void quantize_tensor_per_tensor_affine_privateuse1(
const at::Tensor& rtensor,
at::Tensor& qtensor,
double scale,
int64_t zero_point) {
// Just test the process, so do nothing
}
struct CustomAutogradFnReturnsSelf
: public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
static at::Tensor forward(
torch::autograd::AutogradContext* ctx,
at::Tensor self) {
return self;
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
struct CustomAutogradFnAliasing
: public torch::autograd::Function<CustomAutogradFnAliasing> {
static at::Tensor forward(
torch::autograd::AutogradContext* ctx,
at::Tensor self) {
return self.view_symint(self.sym_sizes());
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
return CustomAutogradFnReturnsSelf::apply(x);
}
at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
return CustomAutogradFnAliasing::apply(x);
}
/* Notes:
*
* OpenReg is currently designed to simulate device memory through multiple
* subprocesses on purpose, to ensure we don't mistakenly poke at the "device's
* memory" from the main process, and to simulate the same thing that
* happens with other accelerators: any metadata-only change is cpu-only
* (main process), any data change must go through to the device (other process)
* and any data transfer between the two is expensive (serializing the whole
* Tensor).
*
* Currently, for IPC efficiency, most operations only pass the Tensor
* metadata; only a small number of copy-related operations serialize and pass
* the Tensor body, via the custom pickler provided by torch.multiprocessing.
*
* Therefore, in principle, only operations related to Metadata modification can
* be directly implemented at the C++ level and registered in PrivateUse1; but
* if memory access is involved, the relevant operations must be implemented at
* the Python level, otherwise invalid memory access will result.
*/
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("empty.memory_format", empty_openreg);
m.impl("empty_strided", empty_strided_openreg);
m.impl("as_strided", as_strided_openreg);
m.impl("resize_", resize__openreg);
m.impl("set_.source_Storage", at::native::set_);
m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg);
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
}
struct OpenRegBackendMeta : public c10::BackendMeta {
OpenRegBackendMeta(int version_number, int format_number)
: version_number_(version_number), format_number_(format_number) {}
int version_number_{-1};
int format_number_{-1};
};
void for_serialization(
const at::Tensor& t,
std::unordered_map<std::string, bool>& m) {
auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta();
if (meta_ptr != nullptr) {
auto o_meta_ptr = dynamic_cast<OpenRegBackendMeta*>(meta_ptr);
if (o_meta_ptr->version_number_ == 1) {
m["version_number"] = true;
}
if (o_meta_ptr->format_number_ == 29) {
m["format_number"] = true;
}
}
}
void for_deserialization(
const at::Tensor& t,
std::unordered_map<std::string, bool>& m) {
int version_number{-1};
int format_number{-1};
if (m.find("version_number") != m.end()) {
version_number = 1;
}
if (m.find("format_number") != m.end()) {
format_number = 29;
}
c10::intrusive_ptr<c10::BackendMeta> meta{std::unique_ptr<c10::BackendMeta>(
new OpenRegBackendMeta(version_number, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(meta);
}
REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization)
} // namespace openreg
namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(
quantize_tensor_per_tensor_affine_stub,
&openreg::quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(
_fused_sdp_choice_stub,
&openreg::_fused_sdp_choice_privateuse1);
} // namespace at::native
TORCH_LIBRARY(openreg, m) {
m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing);
m.impl(
"custom_autograd_fn_returns_self",
&openreg::custom_autograd_fn_returns_self);
}

View File

@ -1,78 +0,0 @@
import distutils.command.clean
import os
import platform
import shutil
import sys
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
PACKAGE_NAME = "pytorch_openreg"
version = "1.0"
ROOT_DIR = Path(__file__).absolute().parent
CSRS_DIR = ROOT_DIR / "pytorch_openreg/csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove pytorch_openreg extension
for path in (ROOT_DIR / "pytorch_openreg").glob("**/*.so"):
path.unlink()
# Remove build directory
build_dirs = [
ROOT_DIR / "build",
]
for path in build_dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
if __name__ == "__main__":
if sys.platform == "win32":
vc_version = os.getenv("VCToolsVersion", "")
if vc_version.startswith("14.16."):
CXX_FLAGS = ["/sdl"]
else:
CXX_FLAGS = ["/sdl", "/permissive-"]
elif platform.machine() == "s390x":
# no -Werror on s390x due to newer compiler
CXX_FLAGS = {"cxx": ["-g", "-Wall"]}
else:
CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]}
sources = list(CSRS_DIR.glob("*.cpp"))
# Note that we always compile with debug info
ext_modules = [
CppExtension(
name="pytorch_openreg._C",
sources=sorted(str(s) for s in sources),
include_dirs=[CSRS_DIR],
extra_compile_args=CXX_FLAGS,
)
]
setup(
name=PACKAGE_NAME,
version=version,
author="PyTorch Core Team",
description="Example for PyTorch out of tree registration",
packages=find_packages(exclude=("test",)),
package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=ext_modules,
python_requires=">=3.8",
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
)

View File

@ -0,0 +1,38 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(TORCH_OPENREG CXX C)
include(GNUInstallDirs)
include(CheckCXXCompilerFlag)
include(CMakeDependentOption)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/")
set(LINUX TRUE)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_INSTALL_LIBDIR lib)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch)
find_package(Torch REQUIRED)
include_directories(${PYTORCH_INSTALL_DIR}/include)
if(DEFINED PYTHON_INCLUDE_DIR)
include_directories(${PYTHON_INCLUDE_DIR})
else()
message(FATAL_ERROR "Cannot find Python directory")
endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)

View File

@ -0,0 +1,177 @@
# PyTorch OpenReg
## Background
The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem.
**Note:**
The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**.
### Purpose
- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD.
- **Integration Example**: To serve as a reference example for new backend integration.
- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code.
### Design Principles
- **Minimality Principle**: The fundamental goal is to enable and verify all of the integration paths and mechanisms a new backend needs in order to integrate with PyTorch. All functionality follows a "just right" strategy to ensure the correctness of the relevant integration capabilities.
- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch.
## Directory Structure
```shell
torch_openreg/
├── CMakeLists.txt
├── csrc
│ ├── aten
│ │ ├── native
│ │ │ ├── Extra.cpp
│ │ │ ├── Minimal.cpp
│ │ │ └── ...
│ │ ├── OpenRegExtra.cpp
│ │ └── OpenRegMinimal.cpp
│ ├── CMakeLists.txt
│ └── runtime
│ ├── OpenRegDeviceAllocator.cpp
│ ├── OpenRegDeviceAllocator.h
│ ├── OpenRegFunctions.cpp
│ ├── OpenRegFunctions.h
│ ├── OpenRegGenerator.cpp
│ ├── OpenRegGenerator.h
│ ├── OpenRegGuard.cpp
│ ├── OpenRegGuard.h
│ ├── OpenRegHooks.cpp
│ ├── OpenRegHooks.h
│ ├── OpenRegHostAllocator.cpp
│ ├── OpenRegHostAllocator.h
│ └── ...
├── README.md
├── setup.py
├── third_party
│ └── openreg
└── torch_openreg
├── csrc
│ ├── CMakeLists.txt
│ ├── Module.cpp
│ └── stub.c
├── __init__.py
└── openreg
├── __init__.py
└── random.py
```
**Dependencies**:
```mermaid
graph LR
A[Python]
B[_C.so]
C[libtorch_bindings.so]
D[libtorch_openreg.so]
E[libopenreg.so]
A --> B --> C --> D --> E
```
- `_C.so`: `torch_openreg/csrc/stub.c`
- `libtorch_bindings.so`: `torch_openreg/csrc/*.cpp`
- `libtorch_openreg.so`: `csrc/`
- `libopenreg.so`: `third_party/openreg/`
**Key Directories**:
- `csrc/`: Core device implementation, including operator registration, runtime, etc.
- `csrc/aten/`: Operator registration
- `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
- `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
- `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators.
- `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
- `torch_openreg/csrc/`: Python C++ binding code.
- `torch_openreg/openreg/`: Python API.
## Currently Implemented Features
### Operator Registration
- Operator Implementation
- `TORCH_LIBRARY` form
- Registering a specific operator for an existing schema: See `empty.memory_format`
- Registering an operator with a custom schema
- Extending an existing namespace: (TODO)
- Custom namespace: See `custom_autograd_fn_returns_self`
- Autograd: See `custom_autograd_fn_returns_self`
- STUB form: See `abs_stub`
- Fallback
- Global Fallback: See `wrapper_cpu_fallback`
- Per-operator Fallback: (TODO)
- AMP (TODO)
### Memory Management
- Device Memory Management (TODO)
- Host Memory Management (TODO)
### Custom Storage
- Adding custom device descriptions (TODO)
- Serialization support (TODO)
### Autoload
- (TODO)
...
## Installation and Usage
### Installation
```shell
pip3 install -r requirements.txt
python setup.py develop  # or: python setup.py install
```
### Usage Example
After installation, you can use the `openreg` device in Python just like any other regular device.
```python
import torch
import torch_openreg
if not torch.openreg.is_available():
print("OpenReg backend is not available in this build.")
exit()
print("OpenReg backend is available!")
device = torch.device("openreg")
try:
x = torch.tensor([[1., 2.], [3., 4.]], device=device)
y = x + 2
print("Result y:\n", y)
print(f"Device of y: {y.device}")
z = y.cpu()
print("Result z:\n", z)
print(f"Device of z: {z.device}")
except Exception as e:
print(f"\nAn error occurred: {e}")
```
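
The custom operators registered under the `openreg` namespace (see "Operator Registration" above) can also be exercised from Python through `torch.ops`. The snippet below is a minimal, illustrative sketch rather than a guaranteed API: it assumes the package is installed, the `openreg` device is available, and it relies only on the schemas defined in `csrc/aten/OpenRegExtra.cpp`.
```python
import torch
import torch_openreg  # noqa: F401

x = torch.ones(3, device="openreg", requires_grad=True)
y = torch.ops.openreg.custom_autograd_fn_returns_self(x)
y.sum().backward()  # the AutogradPrivateUse1 backward scales the gradient by 0.5
print(x.grad)
```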
## Future Plans
- **Enhance Features**: AMP, memory management, generators, distributed computing, etc. (to reiterate, the fundamental goal is to verify the integration mechanism).
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.

View File

@ -0,0 +1,12 @@
set(LIBRARY_NAME torch_openreg)
file(GLOB_RECURSE SOURCE_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu)
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

View File

@ -0,0 +1,138 @@
#include "native/Extra.h"
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <torch/library.h>
namespace at::openreg {
at::Tensor wrapper_quantize_per_tensor(
const at::Tensor& self,
double scale,
int64_t zero_point,
at::ScalarType dtype) {
return at::native::quantize_per_tensor_openreg(
self, scale, zero_point, dtype);
}
int64_t wrapper__fused_sdp_choice(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
return at::native::_fused_sdp_choice_openreg(
query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa);
}
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
wrapper__scaled_dot_product_fused_attention_overrideable(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
return at::native::_scaled_dot_product_fused_attention_overrideable_openreg(
query,
key,
value,
attn_bias,
dropout_p,
is_causal,
return_debug_mask,
scale);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
wrapper_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale) {
return at::native::
_scaled_dot_product_fused_attention_overrideable_backward_openreg(
grad_out,
query,
key,
value,
attn_bias,
grad_input_mask,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale);
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor);
m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice);
m.impl(
"_scaled_dot_product_fused_attention_overrideable",
&wrapper__scaled_dot_product_fused_attention_overrideable);
m.impl(
"_scaled_dot_product_fused_attention_overrideable_backward",
&wrapper_scaled_dot_product_fused_attention_overrideable_backward);
}
} // namespace at::openreg
namespace at::openreg {
TORCH_LIBRARY(openreg, m) {
m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
m.impl(
"custom_autograd_fn_returns_self",
&at::native::custom_autograd_fn_returns_self);
m.impl(
"custom_autograd_fn_aliasing", &at::native::custom_autograd_fn_aliasing);
}
} // namespace at::openreg
namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg);
REGISTER_PRIVATEUSE1_DISPATCH(
quantize_tensor_per_tensor_affine_stub,
&quantize_tensor_per_tensor_affine_stub_openreg);
REGISTER_PRIVATEUSE1_DISPATCH(
_fused_sdp_choice_stub,
&_fused_sdp_choice_openreg);
} // namespace at::native

View File

@ -0,0 +1,128 @@
#include "native/Minimal.h"
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <torch/library.h>
namespace at::openreg {
at::Tensor wrapper_empty_memory_format(
c10::IntArrayRef size,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
return at::native::empty_memory_format_openreg(
size,
dtype_opt,
layout_opt,
device_opt,
pin_memory_opt,
memory_format_opt);
}
at::Tensor wrapper_empty_strided(
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt) {
return at::native::empty_strided_openreg(
size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}
at::Tensor wrapper_as_strided(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride,
std::optional<c10::SymInt> storage_offset) {
return at::native::as_strided_openreg(self, size, stride, storage_offset);
}
const at::Tensor& wrapper_resize_(
const at::Tensor& self,
c10::SymIntArrayRef size,
::std::optional<at::MemoryFormat> memory_format) {
return at::native::resize_openreg_(self, size, memory_format);
}
at::Tensor wrapper__reshape_alias(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride) {
return at::native::_reshape_alias_openreg(self, size, stride);
}
at::Tensor wrapper__copy_from(
const at::Tensor& self,
const at::Tensor& dst,
bool non_blocking) {
return at::native::_copy_from_openreg(self, dst, non_blocking);
}
at::Tensor wrapper__copy_from_and_resize(
const at::Tensor& self,
const at::Tensor& dst) {
return at::native::_copy_from_and_resize_openreg(self, dst);
}
at::Scalar wrapper__local_scalar_dense(const at::Tensor& self) {
return at::native::_local_scalar_dense_openreg(self);
}
at::Tensor& wrapper_set_source_Tensor_(
at::Tensor& self,
const at::Tensor& source) {
return at::native::set_source_Tensor_openreg_(self, source);
}
at::Tensor& wrapper_set_source_Storage_(at::Tensor& self, at::Storage source) {
return at::native::set_source_Storage_openreg_(self, source);
}
at::Tensor& wrapper_set_source_Storage_storage_offset_(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
return at::native::set_source_Storage_storage_offset_openreg_(
result, storage, storage_offset, size, stride);
}
at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) {
return at::native::view_openreg(self, size);
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("empty.memory_format", wrapper_empty_memory_format);
m.impl("empty_strided", wrapper_empty_strided);
m.impl("as_strided", wrapper_as_strided);
m.impl("resize_", wrapper_resize_);
m.impl("_reshape_alias", wrapper__reshape_alias);
m.impl("_copy_from", wrapper__copy_from);
m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize);
m.impl("_local_scalar_dense", wrapper__local_scalar_densor);
m.impl("set_.source_Tensor", wrapper_set_source_Tensor_);
m.impl("set_.source_Storage", wrapper_set_source_Storage_);
m.impl(
"set_.source_Storage_storage_offset",
      wrapper_set_source_Storage_storage_offset_);
m.impl("view", wrapper_view);
}
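// Global boxed fallback: any aten op without a dedicated PrivateUse1 kernel is
// boxed and redirected to its CPU implementation via cpu_fallback_openreg below.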
void wrapper_cpu_fallback(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
at::native::cpu_fallback_openreg(op, stack);
}
TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
m.fallback(
torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>());
}
} // namespace at::openreg

View File

@ -0,0 +1,106 @@
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/core/blob.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/_local_scalar_dense_native.h>
#include <ATen/ops/_reshape_alias_native.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/ops/set_native.h>
#include <ATen/ops/view_native.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function_hook.h>
#include <c10/core/Allocator.h>
#include <set>
#include <include/openreg.h>
namespace at::native {
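// MemoryGuard temporarily lifts the PROT_NONE protection on the simulated
// device allocations referenced by the given stack/tensor arguments (via
// orMemoryUnprotect), so host code can read and write them; protection is
// restored in the destructor (via orMemoryProtect).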
class MemoryGuard {
public:
explicit MemoryGuard(const torch::jit::Stack& stack) {
for (const c10::IValue& ivalue : stack) {
find_and_unprotect_tensors(ivalue);
}
}
template <typename... Args>
explicit MemoryGuard(const Args&... args) {
(handler(args), ...);
}
~MemoryGuard() {
for (void* ptr : unprotected_pointers_) {
orMemoryProtect(ptr);
}
}
MemoryGuard(const MemoryGuard&) = delete;
MemoryGuard& operator=(const MemoryGuard&) = delete;
MemoryGuard(MemoryGuard&&) = delete;
MemoryGuard& operator=(MemoryGuard&&) = delete;
private:
void find_and_unprotect_tensors(const c10::IValue& ivalue) {
if (ivalue.isTensor()) {
unprotect_if_needed(ivalue.toTensor());
} else if (ivalue.isTensorList()) {
for (const at::Tensor& tensor : ivalue.toTensorList()) {
unprotect_if_needed(tensor);
}
} else if (ivalue.isList()) {
for (const c10::IValue& element : ivalue.toListRef()) {
find_and_unprotect_tensors(element);
}
} else if (ivalue.isGenericDict()) {
for (const auto& pair : ivalue.toGenericDict()) {
find_and_unprotect_tensors(pair.key());
find_and_unprotect_tensors(pair.value());
}
}
}
void unprotect_if_needed(const at::Tensor& tensor) {
if (!tensor.defined() || !tensor.has_storage()) {
return;
}
void* ptr = tensor.data_ptr();
orPointerAttributes attr;
if (orPointerGetAttributes(&attr, ptr) == orSuccess) {
if (attr.type == orMemoryTypeDevice) {
if (unprotected_pointers_.find(attr.pointer) ==
unprotected_pointers_.end()) {
orMemoryUnprotect(attr.pointer);
unprotected_pointers_.insert(attr.pointer);
}
}
}
}
template <typename T>
void handler(const T& x) {
if constexpr (std::is_same_v<std::decay_t<T>, at::Tensor>) {
unprotect_if_needed(x);
}
}
std::set<void*> unprotected_pointers_;
};
} // namespace at::native

View File

@ -0,0 +1,238 @@
#include "Extra.h"
namespace at::native {
at::Tensor quantize_per_tensor_openreg(
const at::Tensor& self,
double scale,
int64_t zero_point,
at::ScalarType dtype) {
return at::native::quantize_per_tensor(self, scale, zero_point, dtype);
}
int64_t _fused_sdp_choice_openreg(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
_scaled_dot_product_fused_attention_overrideable_openreg(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);
auto opts = query.options();
auto output =
at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp =
at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty(
{batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(
output,
logsumexp,
at::Tensor(),
at::Tensor(),
max_seqlen_q,
max_seqlen_kv,
philox_seed,
philox_offset,
debug_attn_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward_openreg(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale) {
return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
at::empty_like(query),
at::empty_like(key),
at::empty_like(value),
at::empty_like(attn_bias));
}
} // namespace at::native
namespace at::native {
void abs_kernel_openreg(at::TensorIteratorBase& iter) {
  // Abs has exactly one input tensor and one output tensor.
auto& output_operand = iter.operand(0);
auto& input_operand = iter.operand(1);
auto& output_tensor_base = output_operand.tensor_base();
auto& input_tensor_base = input_operand.tensor_base();
TORCH_CHECK(
!input_operand.original_tensor_base().defined(),
"input original tensor is defined.");
TORCH_CHECK(
!output_operand.original_tensor_base().defined(),
"output original tensor is defined.");
  // To keep the demo simple, only contiguous input tensors are accepted.
  auto memory_format = input_tensor_base.suggest_memory_format();
  TORCH_CHECK(
      input_tensor_base.is_contiguous(memory_format),
      "Input tensor needs to be contiguous.");
  // Add necessary restrictions to ensure the safety of the demo.
  TORCH_CHECK(
      input_tensor_base.sizes() == output_tensor_base.sizes(),
      "Input and output tensor sizes are not equal.");
  // The common dtype is computed in TensorIteratorBase.
  TORCH_CHECK(
      iter.common_dtype() == at::ScalarType::Float, "Only float is supported.");
  // Compute abs element-wise with a plain for loop.
auto abs_function =
[](float* output_ptr, const float* input_ptr, const int64_t NUM) {
for (int64_t i = 0; i < NUM; ++i) {
*(output_ptr + i) = std::abs(*(input_ptr + i));
}
};
  // To simplify the demo logic, only contiguous tensors are computed on the
  // device side, using the input tensor's memory format.
if (iter.is_contiguous()) {
    // Check the will_resize flag. You may convert the output to a different
    // memory format when will_resize is true.
    // If TensorIteratorConfig's resize_outputs_ flag is true, there are two
    // situations:
    // 1) the out tensor is undefined, and TensorIterator sets will_resize to
    //    true;
    // 2) the out tensor is defined but its size differs from the input
    //    tensor's size; TensorIterator sets will_resize to true and calls
    //    set_output_raw_strided to resize the output tensor.
    // When the output operand's will_resize flag is true, a dummy device may
    // convert the tensor to its preferred memory format. We do not convert
    // the memory format here, because it becomes complicated when the dummy
    // device wants to keep the same memory format while training a network.
    TORCH_CHECK(
        output_operand.will_resize,
        "output operand will_resize flag needs to be true.");
abs_function(
(float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
} else {
    // Strided copy is not supported on this dummy device, so compute into a
    // contiguous temporary tensor instead. For the abs op, the remaining case
    // is: the output tensor is not contiguous and will_resize is false.
TORCH_CHECK(
!output_operand.will_resize, "output operand will_resize is True.");
// Get a contiguous tensor with input memory format.
at::Tensor output = at::empty(
output_tensor_base.sizes(),
input_tensor_base.options().memory_format(memory_format));
    // For a structured op that inherits from TensorIteratorBase, you may need
    // to call set_output_raw_strided to update the output stored in the
    // structured op. The abs op does not need this.
output_operand.exchange_tensor(
c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
abs_function(
(float*)output_operand.tensor_base().mutable_data_ptr(),
(float*)iter.data_ptr(1),
iter.numel());
    // Copy the temporary tensor back to the original output tensor, keeping
    // the scalar type and strides consistent with the CPU/GPU behavior.
if (output_operand.original_tensor_base().defined() &&
!output_operand.original_tensor_base().is_same(
output_operand.tensor_base())) {
output_operand.original_tensor().copy_(output_operand.tensor());
output_operand.restore_original_tensor();
}
}
}
void quantize_tensor_per_tensor_affine_stub_openreg(
const at::Tensor& rtensor,
at::Tensor& qtensor,
double scale,
int64_t zero_point) {}
} // namespace at::native
namespace at::native {
namespace {
struct CustomAutogradFnReturnsSelf
: public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
static at::Tensor forward(
torch::autograd::AutogradContext* ctx,
at::Tensor self) {
return self;
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
struct CustomAutogradFnAliasing
: public torch::autograd::Function<CustomAutogradFnAliasing> {
static at::Tensor forward(
torch::autograd::AutogradContext* ctx,
at::Tensor self) {
return self.view_symint(self.sym_sizes());
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
} // namespace
at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
return CustomAutogradFnReturnsSelf::apply(x);
}
at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
return CustomAutogradFnAliasing::apply(x);
}
} // namespace at::native

View File

@ -0,0 +1,70 @@
#include "Common.h"
namespace at::native {
at::Tensor quantize_per_tensor_openreg(
const at::Tensor& self,
double scale,
int64_t zero_point,
at::ScalarType dtype);
int64_t _fused_sdp_choice_openreg(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_mask,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa);
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
_scaled_dot_product_fused_attention_overrideable_openreg(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const std::optional<at::Tensor>& attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward_openreg(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& attn_bias,
std::array<bool, 4> grad_input_mask,
const at::Tensor& out,
const at::Tensor& logsumexp,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
std::optional<double> scale);
} // namespace at::native
namespace at::native {
void abs_kernel_openreg(at::TensorIteratorBase& iter);
void quantize_tensor_per_tensor_affine_stub_openreg(
const at::Tensor& rtensor,
at::Tensor& qtensor,
double scale,
int64_t zero_point);
} // namespace at::native
namespace at::native {
at::Tensor custom_autograd_fn_returns_self(at::Tensor x);
at::Tensor custom_autograd_fn_aliasing(at::Tensor x);
} // namespace at::native

View File

@ -0,0 +1,173 @@
#include "Minimal.h"
namespace at::native {
at::Tensor empty_memory_format_openreg(
c10::IntArrayRef size,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(
c10::layout_or_default(layout_opt) == c10::Layout::Strided,
"Non strided layout not supported");
TORCH_CHECK(
!c10::pinned_memory_or_default(pin_memory_opt),
"Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
auto allocator = at::GetAllocator(at::kPrivateUse1);
return at::detail::empty_generic(
size, allocator, pu1_dks, dtype, memory_format_opt);
}
at::Tensor empty_strided_openreg(
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt) {
const auto device = c10::device_or_default(device_opt);
const auto dtype = c10::dtype_or_default(dtype_opt);
TORCH_CHECK(device.is_privateuseone());
TORCH_CHECK(
c10::layout_or_default(layout_opt) == c10::Layout::Strided,
"Non strided layout not supported");
TORCH_CHECK(
!c10::pinned_memory_or_default(pin_memory_opt),
"Pin memory can only be on CPU");
const c10::DeviceGuard device_guard(device);
constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1);
auto allocator = at::GetAllocator(at::kPrivateUse1);
return at::detail::empty_strided_generic(
size, stride, allocator, pu1_dks, dtype);
}
at::Tensor as_strided_openreg(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride,
std::optional<c10::SymInt> storage_offset) {
MemoryGuard guard(self);
return at::cpu::as_strided_symint(self, size, stride, storage_offset);
}
const at::Tensor& resize_openreg_(
const at::Tensor& self,
c10::SymIntArrayRef size,
::std::optional<at::MemoryFormat> memory_format) {
return at::native::resize_(
self, C10_AS_INTARRAYREF_SLOW(size), memory_format);
}
at::Tensor _reshape_alias_openreg(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride) {
return at::native::_reshape_alias(
self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride));
}
at::Tensor _copy_from_openreg(
const at::Tensor& self,
const at::Tensor& dst,
bool non_blocking) {
TORCH_CHECK(self.defined(), "Source tensor (self) is not defined.");
TORCH_CHECK(dst.defined(), "Destination tensor (dst) is not defined.");
MemoryGuard guard(self, dst);
if (self.device() == dst.device()) {
at::Tensor dst_as_cpu = at::from_blob(
dst.data_ptr(),
dst.sizes(),
dst.strides(),
dst.options().device(at::kCPU));
const at::Tensor self_as_cpu = at::from_blob(
self.data_ptr(),
self.sizes(),
self.strides(),
self.options().device(at::kCPU));
at::native::copy_(
const_cast<at::Tensor&>(dst_as_cpu), self_as_cpu, non_blocking);
} else {
if (self.is_cpu()) {
at::Tensor dst_as_cpu = at::from_blob(
dst.data_ptr(),
dst.sizes(),
dst.strides(),
dst.options().device(at::kCPU));
at::native::copy_(
const_cast<at::Tensor&>(dst_as_cpu), self, non_blocking);
} else {
at::Tensor self_as_cpu = at::from_blob(
self.data_ptr(),
self.sizes(),
self.strides(),
self.options().device(at::kCPU));
at::native::copy_(
const_cast<at::Tensor&>(dst), self_as_cpu, non_blocking);
}
}
return dst;
}
at::Tensor _copy_from_and_resize_openreg(
const at::Tensor& self,
const at::Tensor& dst) {
at::native::resize_(dst, self.sizes(), std::nullopt);
MemoryGuard guard(self, dst);
return at::native::copy_(const_cast<at::Tensor&>(dst), self, false);
}
at::Scalar _local_scalar_dense_openreg(const at::Tensor& self) {
MemoryGuard guard(self);
return at::native::_local_scalar_dense_cpu(self);
}
at::Tensor& set_source_Tensor_openreg_(
at::Tensor& self,
const at::Tensor& source) {
return at::native::set_tensor_(self, source);
}
at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source) {
return at::native::set_(self, source);
}
at::Tensor& set_source_Storage_storage_offset_openreg_(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
  // Delegate to the CPU implementation of set_.
return at::cpu::set_(result, storage, storage_offset, size, stride);
}
at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size) {
MemoryGuard guard(self);
return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size));
}
void cpu_fallback_openreg(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
at::native::cpu_fallback(op, stack);
}
} // namespace at::native

View File

@ -0,0 +1,67 @@
#include "Common.h"
namespace at::native {
at::Tensor empty_memory_format_openreg(
c10::IntArrayRef size,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
at::Tensor empty_strided_openreg(
c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<c10::ScalarType> dtype_opt,
std::optional<c10::Layout> layout_opt,
std::optional<c10::Device> device_opt,
std::optional<bool> pin_memory_opt);
at::Tensor as_strided_openreg(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride,
std::optional<c10::SymInt> storage_offset);
const at::Tensor& resize_openreg_(
const at::Tensor& self,
c10::SymIntArrayRef size,
::std::optional<at::MemoryFormat> memory_format);
at::Tensor _reshape_alias_openreg(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride);
at::Tensor _copy_from_openreg(
const at::Tensor& self,
const at::Tensor& dst,
bool non_blocking);
at::Tensor _copy_from_and_resize_openreg(
const at::Tensor& self,
const at::Tensor& dst);
at::Scalar _local_scalar_dense_openreg(const at::Tensor& self);
at::Tensor& set_source_Tensor_openreg_(
at::Tensor& self,
const at::Tensor& source);
at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source);
at::Tensor& set_source_Storage_storage_offset_openreg_(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride);
at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size);
void cpu_fallback_openreg(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
} // namespace at::native

View File

@ -0,0 +1,8 @@
#include "OpenRegDeviceAllocator.h"
namespace c10::openreg {
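// Register a single global device allocator for the PrivateUse1 (openreg)
// device; factory kernels obtain it via at::GetAllocator(at::kPrivateUse1).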
static OpenRegDeviceAllocator global_openreg_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc);
} // namespace c10::openreg

View File

@ -0,0 +1,43 @@
#include <ATen/core/CachingHostAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <include/openreg.h>
namespace c10::openreg {
struct OpenRegDeviceAllocator final : at::Allocator {
OpenRegDeviceAllocator() = default;
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
    orFree(ptr);
}
at::DataPtr allocate(size_t nbytes) override {
int current_device_index = -1;
orGetDevice(&current_device_index);
auto curr_device =
c10::Device(c10::DeviceType::PrivateUse1, current_device_index);
void* data = nullptr;
if (nbytes > 0) {
orMalloc(&data, nbytes);
TORCH_CHECK(
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
}
return {data, data, &ReportAndDelete, curr_device};
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
orMemcpy(dest, src, count, orMemcpyDeviceToDevice);
}
};
} // namespace c10::openreg

View File

@ -0,0 +1,73 @@
#include <include/openreg.h>
#include "OpenRegFunctions.h"
namespace c10::openreg {
orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(c10::DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
return err;
}
orError_t SetDevice(c10::DeviceIndex device) {
int cur_device = -1;
orGetDevice(&cur_device);
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
int device_count_impl() {
int count = 0;
GetDeviceCount(&count);
return count;
}
c10::DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_INTERNAL_ASSERT(
result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<c10::DeviceIndex>(count);
}
c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
return cur_device;
}
void set_device(c10::DeviceIndex device) {
SetDevice(device);
}
DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
orGetDevice(&current_device);
if (device != current_device) {
orSetDevice(device);
}
return current_device;
}
} // namespace c10::openreg

View File

@ -0,0 +1,16 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <limits>
namespace c10::openreg {
c10::DeviceIndex device_count() noexcept;
DeviceIndex current_device();
void set_device(c10::DeviceIndex device);
DeviceIndex ExchangeDevice(DeviceIndex device);
} // namespace c10::openreg

View File

@ -0,0 +1,28 @@
#include "OpenRegGenerator.h"
// Default, global generators, one per device.
static std::vector<at::Generator> default_generators;
namespace c10::openreg {
const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) {
static bool flag [[maybe_unused]] = []() {
    auto device_nums = device_count();
    default_generators.resize(device_nums);
    for (auto i = 0; i < device_nums; i++) {
default_generators[i] = at::make_generator<OpenRegGeneratorImpl>(i);
default_generators[i].seed();
}
return true;
}();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = current_device();
} else {
TORCH_CHECK(idx >= 0 && idx < device_count());
}
return default_generators[idx];
}
} // namespace c10::openreg

View File

@ -0,0 +1,21 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <c10/core/Device.h>
#include "OpenRegFunctions.h"
namespace c10::openreg {
class OpenRegGeneratorImpl : public at::CPUGeneratorImpl {
public:
OpenRegGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~OpenRegGeneratorImpl() override = default;
};
const at::Generator& getDefaultOpenRegGenerator(
c10::DeviceIndex device_index = -1);
} // namespace c10::openreg

View File

@ -0,0 +1,7 @@
#include "OpenRegGuard.h"
namespace c10::openreg {
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
} // namespace c10::openreg

View File

@ -0,0 +1,197 @@
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <include/openreg.h>
#include "OpenRegFunctions.h"
namespace c10::openreg {
// Device guard registration
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;
OpenRegGuardImpl() = default;
explicit OpenRegGuardImpl(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == static_type);
}
/**
* Return the type of device managed by this guard implementation.
*/
c10::DeviceType type() const override {
return static_type;
}
/**
* Set the current device to Device, and return the previous c10::Device.
*/
c10::Device exchangeDevice(c10::Device d) const override {
TORCH_CHECK(d.is_privateuseone());
auto old_device_index = ExchangeDevice(d.index());
return c10::Device(static_type, old_device_index);
}
/**
* Get the current device.
*/
c10::Device getDevice() const override {
int device_index = current_device();
return c10::Device(static_type, device_index);
}
/**
* Set the current device to c10::Device.
*/
void setDevice(c10::Device d) const override {
TORCH_CHECK(d.is_privateuseone());
set_device(d.index());
}
/**
* Set the current device to c10::Device, without checking for errors
* (so, e.g., this can be called from a destructor).
*/
void uncheckedSetDevice(c10::Device d) const noexcept override {
TORCH_CHECK(d.is_privateuseone());
set_device(d.index());
}
/**
* Get the current stream for a given device.
*/
c10::Stream getStream(c10::Device d) const noexcept override {
return c10::Stream(c10::Stream::DEFAULT, d);
}
/**
* Get the default stream for a given device.
*/
c10::Stream getDefaultStream(c10::Device d) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
}
/**
* Get a stream from the global pool for a given device.
*/
c10::Stream getStreamFromGlobalPool(
c10::Device d,
bool isHighPriority = false) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
}
/**
* Return a new stream for a given device and priority. The stream will be
   * copied and shared around; the device backend should be able to correctly
   * handle the lifetime of the stream.
*/
c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
}
/**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
*/
c10::Stream exchangeStream(c10::Stream s) const noexcept override {
return s;
}
/**
* Destroys the given event.
*/
void destroyEvent(void* event, const c10::DeviceIndex device_index)
const noexcept override {}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream process that job
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
   */
void record(
void** event,
const c10::Stream& stream,
const c10::DeviceIndex device_index,
const c10::EventFlag flag) const override {
static int event_id = 1;
if (!*event)
*event = reinterpret_cast<void*>(event_id++);
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
void block(void* event, const c10::Stream& stream) const override {}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
bool queryEvent(void* event) const override {
return true;
}
/**
* Get the number of devices. WARNING: This is REQUIRED to not raise
* an exception. If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
c10::DeviceIndex deviceCount() const noexcept override {
int device_index = -1;
orGetDeviceCount(&device_index);
return device_index;
}
/**
* Return true if all the work previously enqueued on the stream for
* asynchronous execution has completed running on the device.
*/
bool queryStream(const c10::Stream& stream) const override {
return true;
}
/**
* Wait (by blocking the calling thread) until all the work previously
* enqueued on the stream has completed running on the device.
*/
void synchronizeStream(const c10::Stream& stream) const override {}
/**
* Wait (by blocking the calling thread) until all the work previously
* recorded on the event has completed running on the device.
*/
void synchronizeEvent(void* event) const override {}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
*/
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const c10::Stream& stream) const override {}
/**
* Fetch the elapsed time between two recorded events.
*/
double elapsedTime(
void* event1,
void* event2,
const c10::DeviceIndex device_index) const override {
return 1;
}
};
} // namespace c10::openreg

View File

@ -0,0 +1,11 @@
#include "OpenRegHooks.h"
namespace c10::openreg {
static bool register_hook_flag [[maybe_unused]] = []() {
at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
return true;
}();
} // namespace c10::openreg

View File

@ -0,0 +1,41 @@
#include <ATen/core/CachingHostAllocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <include/openreg.h>
#include "OpenRegGenerator.h"
namespace c10::openreg {
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
OpenRegHooksInterface() {};
~OpenRegHooksInterface() override = default;
bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
return true;
}
at::Allocator* getPinnedMemoryAllocator() const override {
return at::getHostAllocator(at::kPrivateUse1);
}
bool isPinnedPtr(const void* data) const override {
orPointerAttributes attr{};
orPointerGetAttributes(&attr, data);
return attr.type == orMemoryTypeHost;
}
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
return getDefaultOpenRegGenerator(device_index);
}
at::Generator getNewGenerator(c10::DeviceIndex device_index) const override {
return at::make_generator<OpenRegGeneratorImpl>(device_index);
}
};
} // namespace c10::openreg

View File

@ -0,0 +1,8 @@
#include "OpenRegHostAllocator.h"
namespace c10::openreg {
OpenRegHostAllocator caching_host_allocator;
REGISTER_HOST_ALLOCATOR(at::kPrivateUse1, &caching_host_allocator);
} // namespace c10::openreg

View File

@ -0,0 +1,48 @@
#include <ATen/core/CachingHostAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <include/openreg.h>
namespace c10::openreg {
struct OpenRegHostAllocator final : at::HostAllocator {
OpenRegHostAllocator() = default;
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
orFreeHost(ptr);
}
at::DataPtr allocate(size_t nbytes) override {
void* data = nullptr;
if (nbytes > 0) {
orMallocHost(&data, nbytes);
      TORCH_CHECK(data, "Failed to allocate ", nbytes, " bytes on host.");
}
return {data, data, &ReportAndDelete, at::Device(at::kCPU)};
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
orMemcpy(dest, src, count, orMemcpyHostToHost);
}
  // The remaining host-allocator hooks are no-ops for this simulated device.
bool record_event(void* ptr, void* ctx, c10::Stream stream) override {
return true;
}
void empty_cache() override {}
at::HostStats get_stats() override {
return at::HostStats();
}
void reset_accumulated_stats() override {}
void reset_peak_stats() override {}
};
} // namespace c10::openreg

View File

@ -0,0 +1,48 @@
#include "OpenRegSerialization.h"
namespace c10::openreg {
struct OpenRegBackendMeta : public c10::BackendMeta {
OpenRegBackendMeta(int version_number, int format_number)
: version_number_(version_number), format_number_(format_number) {}
int version_number_{-1};
int format_number_{-1};
};
void for_serialization(
const at::Tensor& t,
std::unordered_map<std::string, bool>& m) {
auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta();
if (meta_ptr != nullptr) {
auto o_meta_ptr = dynamic_cast<OpenRegBackendMeta*>(meta_ptr);
if (o_meta_ptr->version_number_ == 1) {
m["version_number"] = true;
}
if (o_meta_ptr->format_number_ == 29) {
m["format_number"] = true;
}
}
}
void for_deserialization(
const at::Tensor& t,
std::unordered_map<std::string, bool>& m) {
int version_number{-1};
int format_number{-1};
if (m.find("version_number") != m.end()) {
version_number = 1;
}
if (m.find("format_number") != m.end()) {
format_number = 29;
}
c10::intrusive_ptr<c10::BackendMeta> meta{std::unique_ptr<c10::BackendMeta>(
new OpenRegBackendMeta(version_number, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(meta);
}
REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization)
} // namespace c10::openreg

View File

@ -0,0 +1,10 @@
#include <torch/csrc/jit/serialization/pickler.h>
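// Registers hooks that serialize/deserialize PrivateUse1 tensor backend
// metadata (c10::BackendMeta) when tensors are saved and loaded.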
#define REGISTER_PRIVATEUSE1_SERIALIZATION( \
FOR_SERIALIZATION, FOR_DESERIALIZATION) \
static int register_serialization() { \
torch::jit::TensorBackendMetaRegistry( \
c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \
return 0; \
} \
static const int _temp = register_serialization();

View File

@ -0,0 +1,2 @@
torch
pybind11

View File

@ -0,0 +1,102 @@
import multiprocessing
import os
import shutil
import subprocess
import sys
import sysconfig
from distutils.command.clean import clean
from setuptools import Extension, find_packages, setup
PACKAGE_NAME = "torch_openreg"
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
def get_pytorch_dir():
import torch
return os.path.dirname(os.path.realpath(torch.__file__))
def build_deps():
build_dir = os.path.join(BASE_DIR, "build")
os.makedirs(build_dir, exist_ok=True)
cmake_args = [
"-DCMAKE_INSTALL_PREFIX="
+ os.path.realpath(os.path.join(BASE_DIR, "torch_openreg")),
"-DPYTHON_INCLUDE_DIR=" + sysconfig.get_paths().get("include"),
"-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir(),
]
subprocess.check_call(
["cmake", BASE_DIR] + cmake_args, cwd=build_dir, env=os.environ
)
build_args = [
"--build",
".",
"--target",
"install",
"--",
]
build_args += ["-j", str(multiprocessing.cpu_count())]
command = ["cmake"] + build_args
subprocess.check_call(command, cwd=build_dir, env=os.environ)
class BuildClean(clean):
def run(self):
for i in ["build", "install", "torch_openreg.egg-info", "torch_openreg/lib"]:
dirs = os.path.join(BASE_DIR, i)
if os.path.exists(dirs) and os.path.isdir(dirs):
shutil.rmtree(dirs)
for dirpath, _, filenames in os.walk(os.path.join(BASE_DIR, "torch_openreg")):
for filename in filenames:
if filename.endswith(".so"):
os.remove(os.path.join(dirpath, filename))
RUN_BUILD_DEPS = any(arg == "clean" for arg in sys.argv)
def main():
if not RUN_BUILD_DEPS:
build_deps()
ext_modules = [
Extension(
name="torch_openreg._C",
sources=["torch_openreg/csrc/stub.c"],
extra_compile_args=["-g", "-Wall", "-Werror"],
libraries=["torch_bindings"],
library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")],
extra_link_args=["-Wl,-rpath,$ORIGIN/lib"],
)
]
package_data = {PACKAGE_NAME: ["lib/*.so*"]}
setup(
name=PACKAGE_NAME,
version="0.0.1",
author="PyTorch Core Team",
description="Example for PyTorch out of tree registration",
packages=find_packages(exclude=("test",)),
package_data=package_data,
install_requires=[
"torch",
],
ext_modules=ext_modules,
python_requires=">=3.8",
cmdclass={
"clean": BuildClean, # type: ignore[misc]
},
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,11 @@
set(LIBRARY_NAME openreg)
file(GLOB_RECURSE SOURCE_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

View File

@ -0,0 +1,137 @@
# OpenReg: An Accelerator Backend that Simulates CUDA Behavior on a CPU
## Introduction
OpenReg is a C++ backend library that simulates the behavior of a CUDA-like device on a CPU. Its core objective is **not to accelerate computation or improve performance**, but rather to **simulate modern CUDA programming, enabling developers to prototype and test in an environment without actual GPU hardware**. The current design principles are as follows:
* **API Consistency**: Provide an interface consistent with the CUDA Runtime API, allowing upper-level applications (like PyTorch's PrivateUse1 backend) to switch and test seamlessly.
* **Functional Consistency**: Provide behavior consistent with the CUDA Runtime, such as memory isolation, device context management, etc.
* **Completeness**: Aim to support PrivateUse1 device integration and safeguard the third-party device integration mechanism, without striving to cover all capabilities of the CUDA Runtime.
## Directory Structure
The project's code is organized with a clear structure and separation of responsibilities:
```text
openreg/
├── CMakeLists.txt # Top-level CMake build script, used to compile and generate libopenreg.so
├── include/
│ └── openreg.h # Public API header file, external users only need to include this file
└── csrc/
├── device.cpp # Implementation of device management-related APIs
└── memory.cpp # Implementation of APIs for memory management, copying, and protection
```
* `include/openreg.h`: Defines all externally exposed C-style APIs, data structures, and enums. It is the "public face" of this library.
* `csrc/`: Contains the C++ implementation source code for all core functionalities.
* `device.cpp`: Implements device discovery (`orGetDeviceCount`) and thread context management (`orSetDevice`/`orGetDevice`).
* `memory.cpp`: Implements the core functions of memory allocation (`orMalloc`/`orMallocHost`), deallocation, copying, and memory protection (`orMemoryProtect`, `orMemoryUnprotect`).
* `CMakeLists.txt`: Responsible for compiling and linking all source files under the `csrc/` directory to generate the final `libopenreg.so` shared library.
## Implemented APIs
OpenReg currently provides a set of APIs covering basic memory and device management.
### Device Management APIs
| OpenReg | CUDA | Feature Description |
| :------------------- | :------------------- | :------------------------------------------------ |
| `orGetDeviceCount` | `cudaGetDeviceCount` | Get the number of devices |
| `orSetDevice` | `cudaSetDevice` | Set the current device for the current thread |
| `orGetDevice` | `cudaGetDevice` | Get the current device for the current thread |
### Memory Management APIs
| OpenReg | CUDA | Feature Description |
| :----------------------- | :--------------------------- | :----------------------------------------- |
| `orMalloc` | `cudaMalloc` | Allocate device memory |
| `orFree` | `cudaFree` | Free device memory |
| `orMallocHost` | `cudaMallocHost` | Allocate page-locked (Pinned) host memory |
| `orFreeHost` | `cudaFreeHost` | Free page-locked host memory |
| `orMemcpy` | `cudaMemcpy` | Synchronous memory copy |
| `orMemcpyAsync` | `cudaMemcpyAsync` | Asynchronous memory copy |
| `orPointerGetAttributes` | `cudaPointerGetAttributes` | Get pointer attributes |
| `orMemoryUnprotect` | - | (Internal use) Unprotect memory |
| `orMemoryProtect` | - | (Internal use) Restore memory protection |
## Implementation Principles
### Device Management Principles
Simulating multiple devices and thread-safe device context switching:
1. **Device Count**: The total number of simulated devices is defined by the compile-time constant `DEVICE_COUNT` in `csrc/device.cpp`.
2. **Device Switching**: Device switching in multi-threaded scenarios is simulated using a **TLS (Thread-Local Storage) global variable**.
### Memory Management Principles
Simulating device memory, host memory, and memory copies:
1. **Allocation**: A page-aligned memory block is allocated using `mmap` + `mprotect` with the permission flag `PROT_NONE`. Read, write, and execute operations on this memory region are all prohibited.
2. **Deallocation**: Memory is freed using `munmap`.
3. **Authorization**: When a legitimate memory access is required, an RAII guard restores the memory permissions to `PROT_READ | PROT_WRITE`. The permissions are automatically reverted to `PROT_NONE` when the scope is exited.
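
The pattern above can be sketched as follows. This is an illustrative sketch only (names such as `sim_device_alloc` and `AccessGuard` are hypothetical and not part of `libopenreg`); it assumes a Linux host with `mmap`/`mprotect`, and the real logic lives in `csrc/memory.cpp`.
```cpp
#include <sys/mman.h>
#include <cstddef>

// Allocate simulated "device" memory: map pages, then revoke all access.
void* sim_device_alloc(size_t size) {
  void* p = mmap(nullptr, size, PROT_READ | PROT_WRITE,
                 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (p == MAP_FAILED) return nullptr;
  mprotect(p, size, PROT_NONE);  // any direct CPU access now faults
  return p;
}

// RAII guard: grant read/write access for the current scope only.
class AccessGuard {
 public:
  AccessGuard(void* p, size_t size) : p_(p), size_(size) {
    mprotect(p_, size_, PROT_READ | PROT_WRITE);
  }
  ~AccessGuard() {
    mprotect(p_, size_, PROT_NONE);  // protection restored on scope exit
  }

 private:
  void* p_;
  size_t size_;
};
```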
## Usage Example
The following is a simple code snippet demonstrating how to use the core features of the OpenReg library.
```cpp
#include "openreg.h"
#include <iostream>
#include <vector>
#include <cstdio>
#define OR_CHECK(call) do { \
orError_t err = call; \
if (err != orSuccess) { \
fprintf(stderr, "OR Error code %d in %s at line %d\n", err, __FILE__, __LINE__); \
exit(EXIT_FAILURE); \
} \
} while (0)
int main() {
int device_count = 0;
OR_CHECK(orGetDeviceCount(&device_count));
std::cout << "Found " << device_count << " simulated devices." << std::endl;
int current_device = -1;
OR_CHECK(orSetDevice(1));
OR_CHECK(orGetDevice(&current_device));
std::cout << "Set current device to " << current_device << "." << std::endl;
const int n = 1024;
const size_t size = n * sizeof(int);
int *h_a, *d_a;
OR_CHECK(orMallocHost((void**)&h_a, size));
OR_CHECK(orMalloc((void**)&d_a, size));
orPointerAttributes attr;
OR_CHECK(orPointerGetAttributes(&attr, d_a));
std::cout << "Pointer " << (void*)d_a << " is of type " << attr.type
<< " on device " << attr.device << std::endl;
for (int i = 0; i < n; ++i) {
h_a[i] = i;
}
OR_CHECK(orMemcpy(d_a, h_a, size, orMemcpyHostToDevice));
std::cout << "Data copied from Host to Device." << std::endl;
// std::cout << "Trying to access device memory directly from CPU..." << std::endl;
// int val = d_a[0]; // CRASH!
// Clean up resources
OR_CHECK(orFree(d_a));
OR_CHECK(orFreeHost(h_a));
std::cout << "Resources freed." << std::endl;
return 0;
}
```
## Next Steps
To better support PrivateUse1 device integration, the following capabilities are planned for the future:
* **Stream Support**: Provide the ability to simulate CUDA Streams.
* **Event Support**: Provide the ability to simulate CUDA Events.
* **Cross-Platform Support**: Add support for Windows and macOS (low priority).

View File

@ -0,0 +1,35 @@
#include <include/openreg.h>
namespace {
// Total device numbers
constexpr int DEVICE_COUNT = 2;
// Current device index
thread_local int gCurrentDevice = 0;
} // namespace
orError_t orGetDeviceCount(int* count) {
if (!count) {
return orErrorUnknown;
}
*count = DEVICE_COUNT;
return orSuccess;
}
orError_t orGetDevice(int* device) {
if (!device) {
return orErrorUnknown;
}
*device = gCurrentDevice;
return orSuccess;
}
orError_t orSetDevice(int device) {
if (device < 0 || device >= DEVICE_COUNT) {
return orErrorUnknown;
}
gCurrentDevice = device;
return orSuccess;
}

View File

@ -0,0 +1,249 @@
#include <include/openreg.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>
namespace openreg {
namespace internal {
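// RAII helper: while alive, grants read/write access to a registered device
// allocation; PROT_NONE protection is restored on destruction.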
class ScopedMemoryProtector {
public:
ScopedMemoryProtector(const orPointerAttributes& info)
: m_info(info), m_protected(false) {
if (m_info.type == orMemoryType::orMemoryTypeDevice) {
if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) ==
0) {
m_protected = true;
}
}
}
~ScopedMemoryProtector() {
if (m_protected) {
mprotect(m_info.pointer, m_info.size, PROT_NONE);
}
}
ScopedMemoryProtector(const ScopedMemoryProtector&) = delete;
ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete;
private:
orPointerAttributes m_info;
bool m_protected;
};
class MemoryManager {
public:
static MemoryManager& getInstance() {
static MemoryManager instance;
return instance;
}
orError_t allocate(void** ptr, size_t size, orMemoryType type) {
if (!ptr || size == 0)
return orErrorUnknown;
std::lock_guard<std::mutex> lock(m_mutex);
long page_size = sysconf(_SC_PAGESIZE);
size_t aligned_size = ((size - 1) / page_size + 1) * page_size;
void* mem = nullptr;
int current_device = -1;
if (type == orMemoryType::orMemoryTypeDevice) {
orGetDevice(&current_device);
mem = mmap(
nullptr,
aligned_size,
PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS,
-1,
0);
if (mem == MAP_FAILED)
return orErrorUnknown;
if (mprotect(mem, aligned_size, PROT_NONE) != 0) {
munmap(mem, aligned_size);
return orErrorUnknown;
}
} else {
if (posix_memalign(&mem, page_size, aligned_size) != 0) {
return orErrorUnknown;
}
}
m_registry[mem] = {type, current_device, mem, aligned_size};
*ptr = mem;
return orSuccess;
}
orError_t free(void* ptr) {
if (!ptr)
return orSuccess;
std::lock_guard<std::mutex> lock(m_mutex);
auto it = m_registry.find(ptr);
if (it == m_registry.end())
return orErrorUnknown;
const auto& info = it->second;
if (info.type == orMemoryType::orMemoryTypeDevice) {
mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE);
munmap(info.pointer, info.size);
} else {
::free(info.pointer);
}
m_registry.erase(it);
return orSuccess;
}
orError_t memcpy(
void* dst,
const void* src,
size_t count,
orMemcpyKind kind) {
if (!dst || !src || count == 0)
return orErrorUnknown;
std::lock_guard<std::mutex> lock(m_mutex);
orPointerAttributes dst_info = getPointerInfo(dst);
orPointerAttributes src_info = getPointerInfo(src);
switch (kind) {
case orMemcpyHostToDevice:
if (dst_info.type != orMemoryType::orMemoryTypeDevice ||
src_info.type == orMemoryType::orMemoryTypeDevice)
return orErrorUnknown;
break;
case orMemcpyDeviceToHost:
if (dst_info.type == orMemoryType::orMemoryTypeDevice ||
src_info.type != orMemoryType::orMemoryTypeDevice)
return orErrorUnknown;
break;
case orMemcpyDeviceToDevice:
if (dst_info.type != orMemoryType::orMemoryTypeDevice ||
src_info.type != orMemoryType::orMemoryTypeDevice)
return orErrorUnknown;
break;
case orMemcpyHostToHost:
if (dst_info.type == orMemoryType::orMemoryTypeDevice ||
src_info.type == orMemoryType::orMemoryTypeDevice)
return orErrorUnknown;
break;
}
{
ScopedMemoryProtector dst_protector(dst_info);
ScopedMemoryProtector src_protector(src_info);
::memcpy(dst, src, count);
}
return orSuccess;
}
orError_t getPointerAttributes(
orPointerAttributes* attributes,
const void* ptr) {
if (!attributes || !ptr)
return orErrorUnknown;
    std::lock_guard<std::mutex> lock(m_mutex);
orPointerAttributes info = getPointerInfo(ptr);
attributes->type = info.type;
if (info.type == orMemoryType::orMemoryTypeUnmanaged) {
attributes->device = -1;
attributes->pointer = const_cast<void*>(ptr);
attributes->size = 0;
} else {
attributes->device = info.device;
attributes->pointer = info.pointer;
attributes->size = info.size;
}
return orSuccess;
}
orError_t unprotect(void* ptr) {
std::lock_guard<std::mutex> lock(m_mutex);
orPointerAttributes info = getPointerInfo(ptr);
if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown;
}
if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) {
return orErrorUnknown;
}
return orSuccess;
}
orError_t protect(void* ptr) {
std::lock_guard<std::mutex> lock(m_mutex);
orPointerAttributes info = getPointerInfo(ptr);
if (info.type != orMemoryType::orMemoryTypeDevice) {
return orErrorUnknown;
}
if (mprotect(info.pointer, info.size, PROT_NONE) != 0) {
return orErrorUnknown;
}
return orSuccess;
}
private:
MemoryManager() = default;
orPointerAttributes getPointerInfo(const void* ptr) {
auto it = m_registry.upper_bound(const_cast<void*>(ptr));
if (it == m_registry.begin())
return {};
--it;
const char* p_char = static_cast<const char*>(ptr);
const char* base_char = static_cast<const char*>(it->first);
if (p_char >= base_char && p_char < (base_char + it->second.size)) {
return it->second;
}
return {};
}
std::map<void*, orPointerAttributes> m_registry;
std::mutex m_mutex;
};
} // namespace internal
} // namespace openreg
orError_t orMalloc(void** devPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate(
devPtr, size, orMemoryType::orMemoryTypeDevice);
}
orError_t orFree(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().free(devPtr);
}
orError_t orMallocHost(void** hostPtr, size_t size) {
return openreg::internal::MemoryManager::getInstance().allocate(
hostPtr, size, orMemoryType::orMemoryTypeHost);
}
orError_t orFreeHost(void* hostPtr) {
return openreg::internal::MemoryManager::getInstance().free(hostPtr);
}
orError_t orMemcpy(
void* dst,
const void* src,
size_t count,
orMemcpyKind kind) {
return openreg::internal::MemoryManager::getInstance().memcpy(
dst, src, count, kind);
}
orError_t orPointerGetAttributes(
orPointerAttributes* attributes,
const void* ptr) {
return openreg::internal::MemoryManager::getInstance().getPointerAttributes(
attributes, ptr);
}
orError_t orMemoryUnprotect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().unprotect(devPtr);
}
orError_t orMemoryProtect(void* devPtr) {
return openreg::internal::MemoryManager::getInstance().protect(devPtr);
}
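
For orientation, a minimal sketch of how this driver layer gets exercised once `torch_openreg` is installed. The exact routing (PrivateUse1 allocator → `orMalloc`/`orFree`, copy kernels → `orMemcpy`) is assumed from the design of this extension rather than shown in this hunk:

```python
import torch
import torch_openreg  # noqa: F401  (registers the "openreg" backend)

# Creating a tensor on the device allocates through the PrivateUse1
# allocator, which in this extension is expected to bottom out in orMalloc.
x = torch.ones(4, device="openreg")

# Host <-> device transfers are expected to funnel into orMemcpy with
# orMemcpyDeviceToHost / orMemcpyHostToDevice respectively.
y = x.cpu()
z = y.to("openreg")
print(z.device)  # openreg:0
```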

View File

@ -0,0 +1,49 @@
#pragma once
#include <cstddef>
#ifdef __cplusplus
extern "C" {
#endif
typedef enum orError_t { orSuccess = 0, orErrorUnknown = 1 } orError_t;
typedef enum orMemcpyKind {
orMemcpyHostToHost = 0,
orMemcpyHostToDevice = 1,
orMemcpyDeviceToHost = 2,
orMemcpyDeviceToDevice = 3
} orMemcpyKind;
typedef enum orMemoryType {
orMemoryTypeUnmanaged = 0,
orMemoryTypeHost = 1,
orMemoryTypeDevice = 2
} orMemoryType;
struct orPointerAttributes {
orMemoryType type = orMemoryType::orMemoryTypeUnmanaged;
int device;
void* pointer;
size_t size;
};
orError_t orMalloc(void** devPtr, size_t size);
orError_t orFree(void* devPtr);
orError_t orMallocHost(void** hostPtr, size_t size);
orError_t orFreeHost(void* hostPtr);
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
orError_t orMemoryUnprotect(void* devPtr);
orError_t orMemoryProtect(void* devPtr);
orError_t orGetDeviceCount(int* count);
orError_t orSetDevice(int device);
orError_t orGetDevice(int* device);
orError_t orPointerGetAttributes(
orPointerAttributes* attributes,
const void* ptr);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -0,0 +1,8 @@
import torch
import torch_openreg._C # type: ignore[misc]
import torch_openreg.openreg
torch.utils.rename_privateuse1_backend("openreg")
torch._register_device_module("openreg", torch_openreg.openreg)
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
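
A quick illustration of what this top-level `__init__.py` enables: after the rename and device-module registration, the backend is addressable as `"openreg"` and the device module as `torch.openreg`. The `is_openreg` helper is assumed to come from `generate_methods_for_privateuse1_backend` with its default options:

```python
import torch
import torch_openreg  # noqa: F401

x = torch.ones(2, 3, device="openreg")  # PrivateUse1 now answers to "openreg"
print(torch.openreg.device_count())     # the registered device module
print(torch.openreg.current_device())

# Generated Tensor helper for the renamed backend (assumed default behaviour
# of generate_methods_for_privateuse1_backend).
print(x.is_openreg)  # True
```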

View File

@ -0,0 +1,12 @@
set(LIBRARY_NAME torch_bindings)
file(GLOB_RECURSE SOURCE_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg)
target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib)
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

View File

@ -0,0 +1,99 @@
#include <ATen/Context.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_numbers.h>
#include <runtime/OpenRegFunctions.h>
static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
THPUtils_checkLong(arg),
"_get_default_generator expects an int, but got ",
THPUtils_typename(arg));
auto idx = static_cast<int>(THPUtils_unpackLong(arg));
return THPGenerator_initDefaultGenerator(
at::globalContext().defaultGenerator(
c10::Device(c10::DeviceType::PrivateUse1, idx)));
END_HANDLE_TH_ERRORS
}
PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
auto device = THPUtils_unpackLong(arg);
torch::utils::device_lazy_init(at::kPrivateUse1);
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
auto device_index = THPUtils_unpackDeviceIndex(arg);
if (device_index < 0) {
return THPUtils_packInt32(-1);
}
torch::utils::device_lazy_init(at::kPrivateUse1);
auto current_device = c10::openreg::ExchangeDevice(device_index);
return THPUtils_packDeviceIndex(current_device);
END_HANDLE_TH_ERRORS
}
PyObject* _getDevice(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
torch::utils::device_lazy_init(at::kPrivateUse1);
auto device = static_cast<int32_t>(c10::openreg::current_device());
return THPUtils_packInt32(device);
END_HANDLE_TH_ERRORS
}
PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return THPUtils_packUInt64(c10::openreg::device_count());
END_HANDLE_TH_ERRORS
}
static PyMethodDef methods[] = {
{"_init", _initExtension, METH_NOARGS, nullptr},
{"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
{"_get_device", _getDevice, METH_NOARGS, nullptr},
{"_set_device", _setDevice, METH_O, nullptr},
{"_exchangeDevice", _exchangeDevice, METH_O, nullptr},
{"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
/*
 * When ASAN is enabled, PyTorch modifies the dlopen flags during import,
 * which exposes all global and weak symbols in _C.so and its dependent
 * libraries to the global symbol scope; symbols with the same name in
 * libraries loaded later can then be intercepted. The module-init entry
 * point therefore must not be named initModule here, otherwise the
 * initModule in torch/csrc/Module.cpp would be resolved instead and
 * initialization would fail.
 */
extern "C" PyObject* initOpenRegModule(void) {
static struct PyModuleDef openreg_C_module = {
PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods};
PyObject* mod = PyModule_Create(&openreg_C_module);
return mod;
}
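
The bindings above are private plumbing, but a short sketch of how the Python layer drives them may help; the underscore-prefixed names are exactly those registered in the `PyMethodDef` table above:

```python
import torch_openreg._C as _C

_C._init()                          # lazily initialize the PrivateUse1 backend
print(_C._get_device_count())       # number of simulated devices
prev = _C._exchangeDevice(0)        # switch to device 0, returns the previous index
_C._set_device(prev)                # restore the previous device
gen = _C._get_default_generator(0)  # per-device default torch.Generator
print(gen.initial_seed())
```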

View File

@ -0,0 +1,15 @@
#include <Python.h>
extern PyObject* initOpenRegModule(void);
#ifndef _WIN32
#ifdef __cplusplus
extern "C"
#endif
__attribute__((visibility("default"))) PyObject* PyInit__C(void);
#endif
PyMODINIT_FUNC PyInit__C(void)
{
return initOpenRegModule();
}

View File

@ -0,0 +1,72 @@
import torch
import torch_openreg._C # type: ignore[misc]
_initialized = False
class device:
r"""Context-manager that changes the selected device.
Args:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device):
self.idx = torch.accelerator._get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = torch_openreg._C._exchangeDevice(self.idx)
def __exit__(self, type, value, traceback):
self.idx = torch_openreg._C._set_device(self.prev_idx)
return False
def is_available():
return True
def device_count() -> int:
return torch_openreg._C._get_device_count()
def current_device():
return torch_openreg._C._get_device()
def set_device(device) -> None:
return torch_openreg._C._set_device(device)
def is_initialized():
return _initialized
def _lazy_init():
global _initialized
if is_initialized():
return
torch_openreg._C._init()
_initialized = True
from .random import * # noqa: F403
__all__ = [
"device",
"device_count",
"current_device",
"set_device",
"initial_seed",
"is_available",
"is_initialized",
"random",
"manual_seed",
"manual_seed_all",
"get_rng_state",
"set_rng_state",
]
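
Typical use of this module, following the `torch.cuda`-style conventions it mirrors (a short sketch; the same module is also reachable as `torch.openreg` once registered by the package `__init__`):

```python
import torch
import torch_openreg

print(torch_openreg.openreg.device_count())
print(torch_openreg.openreg.current_device())

# The context manager swaps the current device and restores it on exit.
with torch_openreg.openreg.device(0):
    x = torch.ones(3, device="openreg")

torch_openreg.openreg.set_device(0)
```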

View File

@ -0,0 +1,60 @@
import torch
import torch_openreg._C # type: ignore[misc]
from . import _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"set_rng_state",
"manual_seed",
"manual_seed_all",
"initial_seed",
]
def get_rng_state(device="openreg"):
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("openreg", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch_openreg._C._get_default_generator(idx)
return default_generator.get_state()
def set_rng_state(new_state, device="openreg"):
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("openreg", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch_openreg._C._get_default_generator(idx)
default_generator.set_state(new_state)
def initial_seed() -> int:
_lazy_init()
idx = current_device()
default_generator = torch_openreg._C._get_default_generator(idx)
return default_generator.initial_seed()
def manual_seed(seed: int) -> None:
seed = int(seed)
idx = current_device()
default_generator = torch_openreg._C._get_default_generator(idx)
default_generator.manual_seed(seed)
def manual_seed_all(seed: int) -> None:
seed = int(seed)
for idx in range(device_count()):
default_generator = torch_openreg._C._get_default_generator(idx)
default_generator.manual_seed(seed)
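
And the RNG helpers in use (sketch; these functions are also reachable as `torch.openreg.*` through the registered device module, which is how the test suite calls `initial_seed`):

```python
import torch_openreg

torch_openreg.openreg.manual_seed(2024)
print(torch_openreg.openreg.initial_seed())  # 2024

# Save and restore the RNG state of device 0.
state = torch_openreg.openreg.get_rng_state("openreg:0")
torch_openreg.openreg.set_rng_state(state, "openreg:0")

# Seed every simulated device at once.
torch_openreg.openreg.manual_seed_all(2024)
```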

View File

@ -28,6 +28,7 @@ from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
get_report_path,
IS_CI,
+    IS_LINUX,
IS_MACOS,
retry_shell,
set_cwd,
@ -913,8 +914,12 @@ def _test_autoload(test_directory, options, enable=True):
def run_test_with_openreg(test_module, test_directory, options):
+    # TODO(FFFrog): Will remove this later when windows/macos are supported.
+    if not IS_LINUX:
+        return 0
openreg_dir = os.path.join(
-        test_directory, "cpp_extensions", "open_registration_extension"
+        test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
)
install_dir, return_code = install_cpp_extensions(openreg_dir)
if return_code != 0:

View File

@ -3,7 +3,7 @@
import os
import unittest
-import pytorch_openreg  # noqa: F401
+import torch_openreg  # noqa: F401
import torch
import torch.testing._internal.common_utils as common

View File

@ -10,7 +10,7 @@ from unittest.mock import patch
import numpy as np
import psutil
-import pytorch_openreg  # noqa: F401
+import torch_openreg  # noqa: F401
import torch
from torch.serialization import safe_globals
@ -285,7 +285,6 @@ class TestOpenReg(TestCase):
self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc]
# Autograd
-    @unittest.skipIf(not IS_LINUX, "Only works on linux")
def test_autograd_init(self):
# Make sure autograd is initialized
torch.ones(2, requires_grad=True, device="openreg").sum().backward()
@ -584,4 +583,5 @@ class TestOpenReg(TestCase):
if __name__ == "__main__":
-    run_tests()
+    if IS_LINUX:
+        run_tests()

View File

@ -4,7 +4,7 @@ import unittest
from collections import namedtuple
from functools import partial
-import pytorch_openreg  # noqa: F401
+import torch_openreg  # noqa: F401
import torch
from torch.nn.attention import SDPBackend