pytorch/test/custom_operator/op.cpp
rzou 8124a6c40c [TORCH_LIBRARY] Add impl_abstract_pystub (#109529)
We want users to be able to define custom ops in C++ but put the
abstract impl in Python (since it is easier to write them in Python and
the abstract impl better models device semantics and data-dependent
operators).

`m.impl_abstract_pystub(opname, python_module, context)` declares the
abstract_impl of the operator to exist in the given python module.
When the abstract_impl needs to be accessed (either via FakeTensor or
Meta) and it does not exist, the PyTorch Dispatcher raises an error
with a descriptive message.

Some details:
- We construct a new global AbstractImplPyStub mapping in
  Dispatcher.cpp. Reads and writes to this map are protected by the
  Dispatcher lock.
- We add a new Meta Tensor fallback kernel. The fallback errors out if there is
  no meta kernel, but also offers a nicer error message if we see that there is
  a pystub.
- We create a `torch._utils_internal.throw_abstract_impl_not_imported_error`
  helper function to throw errors. This way, we can throw different error
  messages in OSS PyTorch vs internal PyTorch. To invoke this from C++, we
  added a PyInterpreter::throw_abstract_impl_not_imported_error.

Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753/)

Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109529
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2023-09-22 04:55:36 +00:00

90 lines
2.6 KiB
C++

#include <c10/util/irange.h>
#include <torch/script.h>
#include "op.h"
#include <cstddef>
#include <string>
// Builds a list containing `repeat` entries, each one the product
// `tensor * scalar`. The product is evaluated once per entry, so every
// element of the returned list is the result of its own multiplication.
torch::List<torch::Tensor> custom_op(
    torch::Tensor tensor,
    double scalar,
    int64_t repeat) {
  torch::List<torch::Tensor> scaled_copies;
  scaled_copies.reserve(repeat);
  // The loop counter only drives the iteration count; its value is unused.
  for ([[maybe_unused]] const auto step : c10::irange(repeat)) {
    scaled_copies.push_back(tensor * scalar);
  }
  return scaled_copies;
}
// Compares `s1` against `s2` with std::string::compare and returns the
// result widened to int64_t: negative when s1 orders before s2, zero when
// the strings are equal, positive when s1 orders after s2.
int64_t custom_op2(std::string s1, std::string s2) {
  const int ordering = s1.compare(s2);
  return static_cast<int64_t>(ordering);
}
// Custom autograd function with a hand-written backward pass.
//
// forward computes
//     var1 + mul * var2 + var1 * var2            (var3 absent)
//     var1 + mul * var2 + var1 * var2 + var3     (var3 present)
// and backward returns one gradient slot per forward input, in the order
// (var1, mul, var2, var3).
struct CustomOpAutogradFunction
    : public torch::autograd::Function<CustomOpAutogradFunction> {
  // Stashes everything backward needs on `ctx`:
  //   - saved_data["mul"]            : the integer multiplier
  //   - saved_data["var3_has_value"] : whether the optional var3 was supplied
  //   - saved variables              : {var1, var2}
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor var1,
      int64_t mul,
      torch::Tensor var2,
      c10::optional<torch::Tensor> var3) {
    ctx->saved_data["mul"] = mul;
    ctx->saved_data["var3_has_value"] = var3.has_value();
    ctx->save_for_backward({var1, var2});
    if (var3) {
      return var1 + mul * var2 + var1 * var2 + var3.value();
    }
    return var1 + mul * var2 + var1 * var2;
  }

  // Computes the gradients of the forward expression with respect to each
  // input, using only grad_output[0].
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // `mul` was saved as an int64_t in forward; read it back at full width
    // so multipliers outside the 32-bit range are not truncated.
    const int64_t mul = ctx->saved_data["mul"].toInt();
    const bool var3_has_value = ctx->saved_data["var3_has_value"].toBool();

    auto saved = ctx->get_saved_variables();
    auto var1 = saved[0];
    auto var2 = saved[1];

    // var3 enters the forward result additively, so its gradient is the
    // incoming gradient itself; an undefined tensor is returned when var3
    // was not supplied.
    auto var3_grad = var3_has_value ? grad_output[0] : torch::Tensor();

    torch::autograd::variable_list output = {
        // d/dvar1 = 1 + var2
        grad_output[0] + grad_output[0] * var2,
        // `mul` is a plain integer input: no gradient.
        torch::Tensor(),
        // d/dvar2 = mul + var1
        grad_output[0] * mul + grad_output[0] * var1,
        var3_grad};
    return output;
  }
};
// Operator entry point that forwards all four arguments, unchanged and in
// order, to CustomOpAutogradFunction::apply and returns its result.
torch::Tensor custom_op_with_autograd(
    torch::Tensor var1,
    int64_t mul,
    torch::Tensor var2,
    c10::optional<torch::Tensor> var3) {
  torch::Tensor result =
      CustomOpAutogradFunction::apply(var1, mul, var2, var3);
  return result;
}
// Kernel for the `nonzero` custom operator: returns the result of calling
// `nonzero()` on the input tensor.
torch::Tensor custom_nonzero(torch::Tensor x) {
  torch::Tensor nonzero_result = x.nonzero();
  return nonzero_result;
}
// Kernel for the `sin` custom operator: returns the result of calling
// `sin()` on the input tensor.
torch::Tensor custom_sin(torch::Tensor x) {
  torch::Tensor sin_result = x.sin();
  return sin_result;
}
// Operator definitions for the `custom` library.
TORCH_LIBRARY_FRAGMENT(custom, m) {
// Name-only definitions bound directly to C++ functions; no schema string
// is given here (presumably the schema is inferred from the function
// signature -- confirm against the TORCH_LIBRARY documentation).
m.def("op", custom_op);
m.def("op2", custom_op2);
// Explicit schemas. `op_with_defaults` is backed by the same custom_op
// function as "op" but declares default values for `scalar` and `repeat`.
m.def("op_with_defaults(Tensor tensor, float scalar = 1, int repeat = 1) -> Tensor[]", custom_op);
m.def("op_with_autograd(Tensor var1, int mul, Tensor var2, Tensor? var3=None) -> Tensor", custom_op_with_autograd);
// Schema-only definitions: no kernel is passed here. Each is paired with
// impl_abstract_pystub, which declares that the operator's abstract impl
// lives in the named Python module.
// NOTE(review): "sin" points at "my_custom_ops2" while "nonzero" points at
// "my_custom_ops" -- looks deliberate (different modules), confirm with the
// tests that load this library.
m.def("sin(Tensor x) -> Tensor");
m.impl_abstract_pystub("sin", "my_custom_ops2");
m.def("nonzero(Tensor x) -> Tensor");
m.impl_abstract_pystub("nonzero", "my_custom_ops");
}
// CPU implementations for the `custom` library operators that were defined
// with a schema only ("nonzero" and "sin").
TORCH_LIBRARY_IMPL(custom, CPU, m) {
m.impl("nonzero", &custom_nonzero);
m.impl("sin", &custom_sin);
}