pytorch/torch/csrc/jit/python
BowenBao 346dc88bfa [ONNX] Support registering custom export for prim::PythonOp from torch.autograd.Function (#55630) (#57600)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57600

Demo script:

```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scalar_tuple, scalar, scalar_list):
        ctx.save_for_backward(input)
        return input.clamp(min=scalar)
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
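        # backward must return one gradient per forward input; the non-tensor arguments get None.
        return grad_input, None, None, None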

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_a = torch.nn.Linear(2, 2)
        self.linear_b = torch.nn.Linear(2, 2)
        self.relu = MyReLU.apply
    def forward(self, x):
        h = self.linear_a(x)
        h = self.relu(h, (5, 3), 2, [1, 2, 3])
        h = self.linear_b(h)
        return h

"""
The user defines how prim::PythonOp is exported as a custom op.
"""
def symbolic_pythonop(g, n, *args, **kwargs):
    # Print information:
    print('arguments of ', kwargs['name'], ':')
    print('original node: ', n)
    for i, out in enumerate(n.outputs()):
        print('original output {}: {}, requires grad: {}'.format(i, out, out.requiresGrad()))
    import torch.onnx.symbolic_helper as sym_helper
    for i, arg in enumerate(args):
        print('arg {}: {}, requires grad: {}'.format(i, arg, arg.requiresGrad() if sym_helper._is_value(arg) else False))
    for k, v in kwargs.items():
        print('key: ', k, ' v: ', v)

    # TODO: all inputs (tensors and scalars) are in args.
    #       The backend can define CustomDomain::PythonOp and store this information however it sees fit.
    return g.op("CustomDomain::PythonOp", args[0], name_s=kwargs['name'])

torch.onnx.register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 9)

# Define input.
x = torch.tensor([[0.3971, 0.7544],
                  [0.5695, 0.4388]], requires_grad=True)

model = MyModule()
# Forward.
y = model(x)

torch.onnx.export(model, (x,), 'model.onnx', opset_version=12, verbose=True)
```
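
The TODO in the demo notes that tensors and scalars all arrive in `args`, and that the backend decides how to store them. One possible approach, shown here only as a sketch, is to forward the graph values as node inputs and encode the Python constants as attributes. The helper `split_pythonop_args`, the `arg<N>` attribute names, and the `FakeValue` stand-in are all hypothetical and not part of this commit. Inside a real symbolic function the predicate would be `sym_helper._is_value`, as used in the demo. The sketch also assumes the `g.op` keyword convention in which a `_i` or `_f` suffix covers both a scalar and a list of that type.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def split_pythonop_args(
    args: Sequence[Any],
    is_value: Callable[[Any], bool],
) -> Tuple[List[Any], Dict[str, Any]]:
    """Separate graph values from Python constants (hypothetical helper).

    Returns the graph values in their original order, plus a dict of
    keyword arguments named "arg<N>_<suffix>", where N is the position
    of the constant in ``args`` and the suffix encodes its type.
    """
    values: List[Any] = []
    attrs: Dict[str, Any] = {}
    for index, arg in enumerate(args):
        if is_value(arg):
            values.append(arg)
            continue
        key = "arg{}".format(index)  # hypothetical naming scheme
        items = list(arg) if isinstance(arg, (list, tuple)) else [arg]
        if items and all(isinstance(a, int) for a in items):
            suffix = "_i"  # ints and lists of ints
        elif items and all(isinstance(a, float) for a in items):
            suffix = "_f"  # floats and lists of floats
        else:
            raise TypeError("unsupported constant at position {}: {!r}".format(index, arg))
        # Keep scalars as scalars; normalize tuples to lists.
        attrs[key + suffix] = items if isinstance(arg, (list, tuple)) else arg
    return values, attrs


class FakeValue:
    """Stand-in for a graph value so the sketch runs without torch."""

    def __init__(self, name: str) -> None:
        self.name = name


# Mirrors the call in the demo: self.relu(h, (5, 3), 2, [1, 2, 3])
h = FakeValue("h")
values, attrs = split_pythonop_args(
    (h, (5, 3), 2, [1, 2, 3]),
    is_value=lambda a: isinstance(a, FakeValue),
)
```

In a symbolic function the result could then be passed along as `g.op("CustomDomain::PythonOp", *values, name_s=kwargs['name'], **attrs)`, assuming the backend's custom op is defined to accept those attributes.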

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D28393528

Pulled By: SplitInfinity

fbshipit-source-id: e0d55b7c737c5916fda08a3b26b3306037f970df

Co-authored-by: BowenBao <bowbao@microsoft.com>
2021-05-13 13:42:49 -07:00
..
init.cpp Add pybind type caster for c10::Device (#57292) 2021-05-01 16:11:10 -07:00
init.h
module_python.h
pybind_utils.cpp Revert D27448156: irange for size_t 2021-04-03 19:14:00 -07:00
pybind_utils.h Pass reference to parent future in callbacks (#57635) 2021-05-07 03:59:18 -07:00
pybind.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_arg_flatten.cpp Replace all direct cdata access with THPVariable_Unpack (#55799) 2021-04-15 08:57:04 -07:00
python_arg_flatten.h
python_custom_class.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_custom_class.h
python_interpreter.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_ir.cpp [ONNX] Support registering custom export for prim::PythonOp from torch.autograd.Function (#55630) (#57600) 2021-05-13 13:42:49 -07:00
python_ir.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_ivalue.h Make DataPtr extraction in CUDAFuture faster for Python values (#56918) 2021-05-06 01:12:53 -07:00
python_sugared_value.cpp Add cuda device synchronization support in JIT (#55469) 2021-04-14 09:13:07 -07:00
python_sugared_value.h Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_tracer.cpp [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tracer.h [Usability] Capture argument names for traced functions and modules (#51775) 2021-02-10 18:28:08 -08:00
python_tree_views.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
python_tree_views.h
script_init.cpp [ONNX] Handle PackedParams inputs for _propagate_and_assign_input_shapes (#56449) (#57079) 2021-05-12 15:20:26 -07:00
script_init.h
update_graph_executor_opt.cpp Make PyTorch code-base clang-tidy compliant (#56892) 2021-04-28 14:10:25 -07:00
update_graph_executor_opt.h