mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57600

Demo script:

```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scalar_tuple, scalar, scalar_list):
        ctx.save_for_backward(input)
        return input.clamp(min=scalar)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_a = torch.nn.Linear(2, 2)
        self.linear_b = torch.nn.Linear(2, 2)
        self.relu = MyReLU.apply

    def forward(self, x):
        h = self.linear_a(x)
        h = self.relu(h, (5, 3), 2, [1, 2, 3])
        h = self.linear_b(h)
        return h

"""
User define how to export prim::PythonOp into custom op.
"""
def symbolic_pythonop(g, n, *args, **kwargs):
    # Print information:
    print('arguments of ', kwargs['name'], ':')
    print('original node: ', n)
    for i, out in enumerate(n.outputs()):
        print('original output {}: {}, requires grad: {}'.format(i, out, out.requiresGrad()))
    import torch.onnx.symbolic_helper as sym_helper
    for i, arg in enumerate(args):
        print('arg {}: {}, requires grad: {}'.format(i, arg, arg.requiresGrad() if sym_helper._is_value(arg) else False))
    for k, v in kwargs.items():
        print('key: ', k, ' v: ', v)

    # TODO: all inputs (tensors and scalars) are in args.
    # backend can define CustomDomain::PythonOp and how info are stored however it deem fit.
    return g.op("CustomDomain::PythonOp", args[0], name_s=kwargs['name'])

torch.onnx.register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 9)

# Define input.
x = torch.tensor([[0.3971, 0.7544], [0.5695, 0.4388]], requires_grad=True)

model = MyModule()
# Forward.
y = model(x)

torch.onnx.export(model, (x,), 'model.onnx', opset_version=12, verbose=True)
```

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D28393528

Pulled By: SplitInfinity

fbshipit-source-id: e0d55b7c737c5916fda08a3b26b3306037f970df

Co-authored-by: BowenBao <bowbao@microsoft.com>

| | | |
|---|---|---|
| .. | ||
| init.cpp | ||
| init.h | ||
| module_python.h | ||
| pybind_utils.cpp | ||
| pybind_utils.h | ||
| pybind.h | ||
| python_arg_flatten.cpp | ||
| python_arg_flatten.h | ||
| python_custom_class.cpp | ||
| python_custom_class.h | ||
| python_interpreter.cpp | ||
| python_ir.cpp | ||
| python_ir.h | ||
| python_ivalue.h | ||
| python_sugared_value.cpp | ||
| python_sugared_value.h | ||
| python_tracer.cpp | ||
| python_tracer.h | ||
| python_tree_views.cpp | ||
| python_tree_views.h | ||
| script_init.cpp | ||
| script_init.h | ||
| update_graph_executor_opt.cpp | ||
| update_graph_executor_opt.h | ||