mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33329

# Use case
```
@torch.jit.script
def send_rpc_async(dst_worker_name, user_callable_qual_name, tensor):
    # type: (str, str, Tensor) -> None
    rpc._rpc_async_torchscript(
        dst_worker_name, user_callable_qual_name, args=(tensor,)
    )
```

# Problem
```
torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:
  File "/data/users/shihaoxu/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/rpc/rpc_spawn#binary,link-tree/torch/distributed/rpc/api.py", line 722
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    fut = _invoke_rpc_torchscript(to, qualified_name, *args, **kwargs)
                                                               ~~~~~~ <--- HERE
    return fut
```

# Solution
Register `rpc.rpc_async(..)` as a JIT operator so that it can handle a variable-length argument list.

# Plan
This PR contains the changes required to make `rpc.rpc_async(..)` a JIT prim operator that can dynamically handle a varying number of arguments.

- Register "prim::rpc_async" as a `Symbol` in "interned_strings.h".
- Add an if branch to the `toSugaredValue(py::object, ..)` entry utility function in "python_sugared_value.cpp" to define how the JIT frontend converts the `torch.distributed.rpc.rpc_async(..)` Python function (a Python object) into a `SpecialFormValue` (an IR SugaredValue).
- Add a switch case for the "prim::rpc_async" Symbol to `emitApplySpecialForm(..)` in "ir_emitter.cpp" to define how the JIT compiler provides inputs to the "prim::rpc_async" Operator.
- Register "prim::rpc_async" as a `jit::Operator` and provide its implementation in "register_distributed_ops.cpp".

Note that the distributed module is an optional part of the PyTorch build, so the code added in this PR should be wrapped in a preprocessor macro.
```
#ifdef USE_DISTRIBUTED
new code here
#endif
```

Test Plan:
Items that need to be confirmed in the test cases

https://fb.quip.com/DCvdA9ZLjeO0

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
\
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_call_python_function_remotely_from_script_not_supported
```

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_spawn
```

```
buck test mode/dev-nosan //caffe2/caffe2/python/operator_test:layer_norm_op_test-2.7 -- test_layer_norm_op_jit
```

Differential Revision: D5738300

fbshipit-source-id: a4604fe762e00be062dc8232ca9790df31fb2074
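The four steps in the plan can be modeled loosely in plain Python to show how they fit together. This is only a toy sketch, not PyTorch code: `SPECIAL_FORMS`, `OPERATORS`, `to_sugared_value`, `emit_apply_special_form`, `call`, and the placeholder `rpc_async` are hypothetical names chosen to mirror the plan, and the "remote" call simply runs locally.

```python
# Toy model of the plan; every name here is hypothetical, not a PyTorch API.

def rpc_async(to, func, args=None, kwargs=None):
    """Placeholder for the Python-level function the frontend recognizes."""
    raise RuntimeError("never called directly in this sketch")

def add(a, b):
    """Example user callable."""
    return a + b

# Step 1: an interned symbol naming the operator.
PRIM_RPC_ASYNC = "prim::rpc_async"

# Step 2: the frontend maps a known Python object to a special form.
SPECIAL_FORMS = {rpc_async: PRIM_RPC_ASYNC}

def to_sugared_value(obj):
    """Return the special-form symbol for obj, or None for ordinary callables."""
    return SPECIAL_FORMS.get(obj)

# Step 4: the operator implementation receives a fixed-shape input list,
# so it never needs *args / **kwargs expansion at the call site.
def _rpc_async_op(inputs):
    _dst, func, args, kwargs = inputs
    return func(*args, **kwargs)  # assumption: run locally instead of remotely

OPERATORS = {PRIM_RPC_ASYNC: _rpc_async_op}

# Step 3: the emitter packs a variable number of user arguments into
# the fixed input layout the operator expects.
def emit_apply_special_form(symbol, to, func, args=(), kwargs=None):
    inputs = [to, func, tuple(args), dict(kwargs or {})]
    return OPERATORS[symbol](inputs)

def call(obj, *call_args, **call_kwargs):
    """Dispatch: special forms go through the emitter, others are called as-is."""
    symbol = to_sugared_value(obj)
    if symbol is None:
        return obj(*call_args, **call_kwargs)
    return emit_apply_special_form(symbol, *call_args, **call_kwargs)

result = call(rpc_async, "worker1", add, args=(1, 2))
```

The point of the sketch is the division of labor: the call site can pass any number of arguments, while the registered operator always sees the same four inputs.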
| Name |
|---|
| init.cpp |
| init.h |
| module_python.h |
| pybind_utils.h |
| pybind.h |
| python_arg_flatten.cpp |
| python_arg_flatten.h |
| python_custom_class.cpp |
| python_custom_class.h |
| python_interpreter.cpp |
| python_ir.cpp |
| python_ir.h |
| python_ivalue.h |
| python_sugared_value.cpp |
| python_sugared_value.h |
| python_tracer.cpp |
| python_tracer.h |
| python_tree_views.cpp |
| python_tree_views.h |
| script_init.cpp |
| script_init.h |
| update_graph_executor_opt.cpp |
| update_graph_executor_opt.h |