pytorch/torch/_ops.py
Peter Goldsborough c101a57a74 Build mechanism for custom operators (#10226)
Summary:
This is the last step in the custom operator implementation: providing a way to build from C++ and Python. For this I:

1. Created a `FindTorch.cmake` taken largely from ebetica with a CMake function to easily create simple custom op libraries.
2. Created a `torch/op.h` header for easy inclusion of the necessary headers.
3. Created a test directory `pytorch/test/custom_operator` which includes the basic setup for a custom op.
    1. It defines an op in `op.{h,cpp}`
    2. Registers it with the JIT using `RegisterOperators`
    3. Builds it into a shared library via a `CMakeLists.txt`
    4. Binds it into Python using a `setup.py`. This step reuses our existing C++ extension setup. No work, yay!

The pure C++ and the Python builds are separate and not coupled in any way.

zdevito soumith dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10226

Differential Revision: D9296839

Pulled By: goldsborough

fbshipit-source-id: 32f74cafb6e3d86cada8dfca8136d0dfb1f197a0
2018-08-16 18:56:17 -07:00


import torch._C

import contextlib
import ctypes
import sys

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if _SET_GLOBAL_FLAGS:
        old_flags = sys.getdlopenflags()
        sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        # Restore the original flags even if loading the library raised.
        if _SET_GLOBAL_FLAGS:
            sys.setdlopenflags(old_flags)


class _OpNamespace(object):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:
    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch._C._jit_get_operation`, a function bound from C++, and then in
       a similar fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
       and subsequent accesses will incur no further lookup (the namespace and
       operation will already exist).
    """

    def __init__(self, name):
        self.name = name

    def __getattr__(self, op_name):
        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there is more than one.
        op = torch._C._jit_get_operation('{}::{}'.format(self.name, op_name))
        # Cache the op on this namespace so later accesses skip `__getattr__`.
        setattr(self, op_name, op)
        return op


class _Ops(object):
    def __init__(self):
        self.loaded_libraries = set()

    def __getattr__(self, name):
        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        return namespace

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Arguments:
            path (str): A path to a shared library to load.
        """
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)


# The ops "namespace"
ops = _Ops()
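
The lazy binding described in the `_OpNamespace` docstring can be tried without a PyTorch build. The following is a minimal sketch of the same `__getattr__`-plus-`setattr` caching pattern, not the module above: `_FAKE_REGISTRY`, `fake_get_operation`, `LazyOpNamespace`, `LazyOps`, `demo_ops`, and `lookups` are hypothetical names introduced only for illustration, and a plain dictionary stands in for the C++ registry that `torch._C._jit_get_operation` queries.

```python
# Hypothetical stand-in for the C++ operator registry. The real module calls
# torch._C._jit_get_operation instead of looking names up in a dictionary.
_FAKE_REGISTRY = {"my_namespace::my_op": lambda x: x + 1}

# Records every registry lookup so the caching behaviour can be observed.
lookups = []


def fake_get_operation(qualified_name):
    lookups.append(qualified_name)
    try:
        return _FAKE_REGISTRY[qualified_name]
    except KeyError:
        raise RuntimeError("No such operator: " + qualified_name)


class LazyOpNamespace(object):
    def __init__(self, name):
        self.name = name

    def __getattr__(self, op_name):
        # Only reached when normal attribute lookup fails, i.e. on the first
        # access of each op name.
        op = fake_get_operation("{}::{}".format(self.name, op_name))
        # Cache the op on the instance so later accesses bypass __getattr__.
        setattr(self, op_name, op)
        return op


class LazyOps(object):
    def __getattr__(self, name):
        # Create the namespace on first access and cache it the same way.
        namespace = LazyOpNamespace(name)
        setattr(self, name, namespace)
        return namespace


demo_ops = LazyOps()
first = demo_ops.my_namespace.my_op(41)   # triggers one registry lookup
second = demo_ops.my_namespace.my_op(1)   # served from the cached attribute
```

After both calls, `lookups` holds a single entry, because Python consults `__getattr__` only when ordinary attribute lookup fails and the `setattr` calls make the second access an ordinary one.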