mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Currently, a custom op symbolic is registered for one specific opset version only. For example, all torchvision custom ops are registered for opset 11 and cannot be exported to higher opset versions. This PR extends custom op registration so that a symbolic registered for a given opset version also applies to higher opset versions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32943 Reviewed By: hl475 Differential Revision: D19739406 Pulled By: houseroad fbshipit-source-id: dd8b616de3a69a529d135fdd02608a17a8e421bc
56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
import unittest
|
|
import torch
|
|
import torch.utils.cpp_extension
|
|
|
|
import onnx
|
|
import caffe2.python.onnx.backend as c2
|
|
|
|
import numpy as np
|
|
|
|
from test_pytorch_onnx_caffe2 import do_export
|
|
|
|
class TestCustomOps(unittest.TestCase):
|
|
|
|
def test_custom_add(self):
|
|
op_source = """
|
|
#include <torch/script.h>
|
|
|
|
torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
|
|
return self + other;
|
|
}
|
|
|
|
static auto registry =
|
|
torch::RegisterOperators("custom_namespace::custom_add", &custom_add);
|
|
"""
|
|
|
|
torch.utils.cpp_extension.load_inline(
|
|
name="custom_add",
|
|
cpp_sources=op_source,
|
|
is_python_module=False,
|
|
verbose=True,
|
|
)
|
|
|
|
class CustomAddModel(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
return torch.ops.custom_namespace.custom_add(a, b)
|
|
|
|
def symbolic_custom_add(g, self, other):
|
|
return g.op('Add', self, other)
|
|
|
|
from torch.onnx import register_custom_op_symbolic
|
|
register_custom_op_symbolic('custom_namespace::custom_add', symbolic_custom_add, 9)
|
|
|
|
x = torch.randn(2, 3, 4, requires_grad=False)
|
|
y = torch.randn(2, 3, 4, requires_grad=False)
|
|
|
|
model = CustomAddModel()
|
|
onnxir, _ = do_export(model, (x, y), opset_version=11)
|
|
onnx_model = onnx.ModelProto.FromString(onnxir)
|
|
prepared = c2.prepare(onnx_model)
|
|
caffe2_out = prepared.run(inputs=[x.cpu().numpy(), y.cpu().numpy()])
|
|
np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy())
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
|