pytorch/torch/csrc/jit/python
Shihao Xu 7d01888a75 [JIT] Register rpc.rpc_async(..) as a JIT operator (#33329)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33329

# Use case

```
import torch
import torch.distributed.rpc as rpc
from torch import Tensor

@torch.jit.script
def send_rpc_async(dst_worker_name, user_callable_qual_name, tensor):
    # type: (str, str, Tensor) -> None
    rpc._rpc_async_torchscript(
        dst_worker_name, user_callable_qual_name, args=(tensor,)
    )
```

# Problem

```
torch.jit.frontend.NotSupportedError: keyword-arg expansion is not supported:
  File "/data/users/shihaoxu/fbsource/fbcode/buck-out/dev/gen/caffe2/test/distributed/rpc/rpc_spawn#binary,link-tree/torch/distributed/rpc/api.py", line 722
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    fut = _invoke_rpc_torchscript(to, qualified_name, *args, **kwargs)
                                                               ~~~~~~ <--- HERE
    return fut
```
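The same forwarding pattern can be reproduced in plain Python to show why the frontend rejects it. `invoke` and this `rpc_async` are hypothetical stand-ins, not the real `torch.distributed.rpc` code. How many arguments reach `invoke` is only decided at run time, by the length of `args` and `kwargs`. TorchScript has to resolve each call while compiling and, as the error above shows, does not accept `**kwargs` expansion.

```python
# Hypothetical stand-ins that mirror the forwarding pattern in the traceback
# above; they are plain Python, not the real torch.distributed.rpc code.

def invoke(to, qualified_name, *args, **kwargs):
    # Report how many positional arguments and which keyword names arrived.
    return (to, qualified_name, len(args), sorted(kwargs))


def rpc_async(to, qualified_name, args=None, kwargs=None):
    # Same defaulting logic as the failing lines in api.py.
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    # The callee's arity depends on the runtime contents of args and kwargs.
    return invoke(to, qualified_name, *args, **kwargs)


print(rpc_async("worker1", "mod.fn", args=(1,)))
# ('worker1', 'mod.fn', 1, [])
print(rpc_async("worker1", "mod.fn", args=(1, 2, 3), kwargs={"scale": 2}))
# ('worker1', 'mod.fn', 3, ['scale'])
```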

# Solution

Register `rpc.rpc_async(..)` as a JIT operator so that it can handle variable-length argument lists.

# Plan

This PR makes the changes required to turn `rpc.rpc_async(..)` into a JIT prim operator that can dynamically handle a varying number of arguments.

- Register "prim::rpc_async" as a `Symbol` in "interned_string.h"
- Add an if branch to the `toSugarValue(py::object, ..)` entry utility function in "python_sugared_value.cpp" that defines how the JIT frontend converts the `torch.distributed.rpc.rpc_async(..)` Python function (a Python object) into a `SpecialFormValue` (an IR SugaredValue).
- Add a switch case for the "prim::rpc_async" Symbol to `emitApplySpecialForm(..)` in "ir_emitter.cpp" that defines how the JIT compiler provides inputs to the "prim::rpc_async" operator.
- Register "prim::rpc_async" as a `jit::Operator` and provide implementation in "register_distributed_ops.cpp".
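Together, the four steps above form a small pipeline. The frontend recognizes the Python function and maps it to a special form. The emitter then produces a node whose input count follows the call site. Finally, an operator consumes however many inputs that node has. The sketch below models this flow in plain Python. Every name in it (`to_sugared_value`, `emit_apply_special_form`, `OPERATORS`, `rpc_async_op`, `run`) is a hypothetical stand-in for the C++ code. Keyword arguments are left out, and the operator pushes a descriptive tuple instead of a real future.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

# All names below are hypothetical stand-ins for the C++ pieces in the plan.
RPC_ASYNC = "prim::rpc_async"  # plays the role of the interned Symbol


def rpc_async(to, qualified_name, args=None):
    """Placeholder for the Python-level torch.distributed.rpc.rpc_async."""
    raise RuntimeError("only meant to be recognized by the sketch frontend")


@dataclass
class SpecialFormValue:
    symbol: str  # which special form this Python object maps to


@dataclass
class Node:
    kind: str  # operator symbol, e.g. "prim::rpc_async"
    inputs: List[Any] = field(default_factory=list)


def to_sugared_value(obj: Any) -> SpecialFormValue:
    # toSugarValue analogue: map the Python function object to a special form.
    if obj is rpc_async:
        return SpecialFormValue(RPC_ASYNC)
    raise NotImplementedError(f"no special form for {obj!r}")


def emit_apply_special_form(value: SpecialFormValue, to: str,
                            qualified_name: str, args: tuple = ()) -> Node:
    # emitApplySpecialForm analogue: flatten the args tuple into node inputs,
    # so the node has as many inputs as the call site supplies.
    if value.symbol != RPC_ASYNC:
        raise NotImplementedError(value.symbol)
    return Node(RPC_ASYNC, [to, qualified_name, *args])


OPERATORS: Dict[str, Callable[[List[Any], int], None]] = {}


def register_operator(symbol: str):
    # Registry analogue: associate an implementation with a symbol.
    def decorator(fn):
        OPERATORS[symbol] = fn
        return fn
    return decorator


@register_operator(RPC_ASYNC)
def rpc_async_op(stack: List[Any], num_inputs: int) -> None:
    # Stack-based operator: pop exactly num_inputs values, push one result.
    start = len(stack) - num_inputs
    to, qualified_name, *user_args = stack[start:]
    del stack[start:]
    # A real implementation would send an RPC and push a future; the sketch
    # pushes a tuple describing the request instead.
    stack.append(("future", to, qualified_name, tuple(user_args)))


def run(node: Node) -> Any:
    # Interpreter analogue: push the inputs, call the operator with the
    # node's input count, and pop the single result.
    stack = list(node.inputs)
    OPERATORS[node.kind](stack, len(node.inputs))
    return stack.pop()


node = emit_apply_special_form(to_sugared_value(rpc_async), "worker1", "mod.fn", (1, 2))
print(node.inputs)  # ['worker1', 'mod.fn', 1, 2]
print(run(node))    # ('future', 'worker1', 'mod.fn', (1, 2))
```

In this sketch the input count lives on the node rather than in a fixed signature, which lets one operator serve call sites with different numbers of arguments.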

Note that because the distributed module is an optional part of the PyTorch build, the code added in this PR must be wrapped in a preprocessor macro:
```
#ifdef USE_DISTRIBUTED
// new code here
#endif
```

Test Plan:
Items that need to be confirmed in the test cases

https://fb.quip.com/DCvdA9ZLjeO0

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_call_python_function_remotely_from_script_not_supported
```

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_spawn
```

```
buck test mode/dev-nosan //caffe2/caffe2/python/operator_test:layer_norm_op_test-2.7 -- test_layer_norm_op_jit
```

Differential Revision: D5738300

fbshipit-source-id: a4604fe762e00be062dc8232ca9790df31fb2074
2020-03-03 19:57:42 -08:00
init.cpp Freezing Torchscript modules (#32178) 2020-03-02 11:38:36 -08:00
init.h
module_python.h
pybind_utils.h [resubmit] try to infer rref type from python (#33992) 2020-02-29 20:26:40 -08:00
pybind.h
python_arg_flatten.cpp
python_arg_flatten.h
python_custom_class.cpp
python_custom_class.h
python_interpreter.cpp
python_ir.cpp
python_ir.h
python_ivalue.h
python_sugared_value.cpp [JIT] Register rpc.rpc_async(..) as a JIT operator (#33329) 2020-03-03 19:57:42 -08:00
python_sugared_value.h [JIT] Add modulelist indexing for integer literal (#29236) 2020-03-03 14:47:31 -08:00
python_tracer.cpp
python_tracer.h
python_tree_views.cpp
python_tree_views.h
script_init.cpp [jit] Resolve type annotation names to types (#29623) 2020-02-28 18:35:10 -08:00
script_init.h
update_graph_executor_opt.cpp
update_graph_executor_opt.h