mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
We want users to be able to define custom ops in C++ but put the abstract impl in Python (since it is easier to write them in Python and the abstract impl better models device semantics and data-dependent operators).

`m.impl_abstract_pystub(opname, python_module, context)` declares that the abstract impl of the operator exists in the given Python module. When the abstract impl needs to be accessed (either via FakeTensor or Meta) and it does not exist, the PyTorch Dispatcher raises a descriptive error.

Some details:
- We construct a new global AbstractImplPyStub mapping in Dispatcher.cpp. Reads and writes to this map are protected by the Dispatcher lock.
- We add a new Meta Tensor fallback kernel. The fallback errors out if there is no meta kernel, but also offers a nicer error message if we see that there is a pystub.
- We create a `torch._utils_internal.throw_abstract_impl_not_imported_error` helper function to throw errors. This way, we can throw different error messages in OSS PyTorch vs internal PyTorch. To invoke this from C++, we added a `PyInterpreter::throw_abstract_impl_not_imported_error`.

Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109529
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
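The lookup behaviour described in the commit message can be modelled in a few lines of plain Python. This is only an illustrative sketch, not the PyTorch implementation: the real mapping lives in Dispatcher.cpp, and the names `_pystubs`, `_abstract_impls`, `get_abstract_impl` and the module name `my_package.ops` are hypothetical. It shows the three outcomes: an abstract impl is found, a pystub exists but its module was never imported, or nothing is registered at all.

```python
import threading

# Hypothetical stand-ins for the dispatcher state (assumption: the real
# AbstractImplPyStub map is a C++ structure guarded by the Dispatcher lock).
_pystubs = {}          # opname -> (python_module, context)
_abstract_impls = {}   # opname -> Python callable
_lock = threading.Lock()


def impl_abstract_pystub(opname, python_module, context=""):
    """Declare that the abstract impl of `opname` lives in `python_module`."""
    with _lock:
        _pystubs[opname] = (python_module, context)


def impl_abstract(opname):
    """Decorator that registers the abstract impl for `opname`."""
    def decorator(fn):
        with _lock:
            _abstract_impls[opname] = fn
        return fn
    return decorator


def get_abstract_impl(opname):
    """Toy fallback: return the impl, or raise the most helpful error available."""
    with _lock:
        fn = _abstract_impls.get(opname)
        stub = _pystubs.get(opname)
    if fn is not None:
        return fn
    if stub is not None:
        module, context = stub
        # A pystub exists, so we can tell the user which module to import.
        raise RuntimeError(
            f"{opname}: the abstract impl is declared to live in the Python "
            f"module '{module}', which has not been imported. {context}".strip()
        )
    raise NotImplementedError(f"{opname}: no abstract impl or meta kernel found.")


# The C++ side declares the stub; the Python module has not been imported yet.
impl_abstract_pystub("custom::nonzero", "my_package.ops")
try:
    get_abstract_impl("custom::nonzero")
except RuntimeError as err:
    message = str(err)


# Importing the module would run this registration.
@impl_abstract("custom::nonzero")
def nonzero_abstract(x):
    return x


found = get_abstract_impl("custom::nonzero")
```

Keeping the stub as plain data (a module name plus optional context) is what lets the error name the module to import without importing anything eagerly.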
15 lines
369 B
Python
import torch
import torch._custom_ops as library
from model import get_custom_op_library_path

# Load the C++ library that defines the custom operators.
torch.ops.load_library(get_custom_op_library_path())


@library.impl_abstract("custom::nonzero")
def nonzero_abstract(x):
    n = x.dim()
    ctx = library.get_ctx()
    # The number of nonzero elements is data-dependent, so it is
    # represented by an unbacked symbolic integer.
    nnz = ctx.create_unbacked_symint()
    shape = [nnz, n]
    return x.new_empty(shape, dtype=torch.long)
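To see why the abstract impl returns shape `[nnz, n]` with a symbolic `nnz`, a pure-Python reference of the nonzero semantics may help. This is a sketch under assumptions: it works on rectangular nested lists rather than tensors, and `shape_of` and `reference_nonzero` are hypothetical helpers, not part of PyTorch or of the custom op library.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return shape


def reference_nonzero(nested):
    """Return one index row per nonzero element, in row-major order."""
    shape = shape_of(nested)
    rows = []
    for index in product(*(range(size) for size in shape)):
        value = nested
        for i in index:
            value = value[i]
        if value != 0:
            rows.append(list(index))
    return rows


# Two inputs with the same shape [2, 3] but different contents.
a = [[0, 1, 0], [2, 0, 3]]
b = [[0, 0, 0], [0, 0, 4]]

rows_a = reference_nonzero(a)
rows_b = reference_nonzero(b)

# The second dimension always equals the input rank (n = 2 here); the first
# dimension depends on the values, which shape inference alone cannot know.
out_shape_a = [len(rows_a), len(shape_of(a))]
out_shape_b = [len(rows_b), len(shape_of(b))]
```

Both inputs have identical shapes yet produce outputs with different row counts, which is the situation `ctx.create_unbacked_symint()` models in the abstract impl.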