pytorch/torch/distributed/__init__.py
Yanli Zhao ab39a55331 python udf over rpc (#23569)
Summary:
This diff supports Python user-defined functions (UDFs) over RPC for https://github.com/pytorch/pytorch/issues/23110. The workflow is:
1. Pickle the Python UDF.
2. Pass the pickled bytes to C++.
3. C++ sends them over RPC from the client to the server.
4. The server uses the Python embedder to call the Python function runPythonUDF(), which unpickles and runs the UDF and pickles its result.
5. The serialized result is passed back from the server to the client.
6. The client calls the Python function loadPythonUDFResult() to unpickle the result.
7. The result is returned to Python.
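
The pickle round trip above can be sketched in plain Python with the standard `pickle` module. This is a minimal sketch, not the PR's implementation: the helper names `serialize_udf`, `run_python_udf` and `load_python_udf_result` are hypothetical stand-ins chosen to mirror runPythonUDF() and loadPythonUDFResult(), and the C++ hand-off and RPC transport are replaced by passing the bytes directly.

```python
import pickle


def serialize_udf(func, *args, **kwargs):
    # Client side: pickle the callable together with its arguments.
    # Functions are pickled by qualified name, so the server must be able
    # to import the same function when it unpickles the payload.
    return pickle.dumps((func, args, kwargs))


def run_python_udf(payload):
    # Server side (hypothetical stand-in for runPythonUDF()): unpickle the
    # request, run the UDF, and pickle its result for the trip back.
    func, args, kwargs = pickle.loads(payload)
    result = func(*args, **kwargs)
    return pickle.dumps(result)


def load_python_udf_result(payload):
    # Client side (hypothetical stand-in for loadPythonUDFResult()):
    # unpickle the serialized result returned by the server.
    return pickle.loads(payload)
```

For example, `load_python_udf_result(run_python_udf(serialize_udf(max, 2, 3)))` returns `3`. Because unpickling can execute arbitrary code, a scheme like this is only appropriate between trusted peers.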

For now, rpc_sync_builtin() and rpc_async_builtin() are added as temporary interfaces for remote calls to builtin operators. They accept a qualified name string and execute the builtin operator in C++.

rpc_sync() and rpc_async() currently accept only Python callables, which can be either user-defined Python functions or the Python functions for builtin operators; these callables are executed in Python.

Once builtin-operator Python callables can be resolved to qualified name strings, rpc_sync_builtin() can be merged into rpc_sync().
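
The planned merge can be illustrated with a hypothetical router: a callable that resolves to a qualified name takes the builtin (C++) path, and anything else is pickled and takes the Python UDF path. This is a sketch under stated assumptions, not the torch.distributed code: the function `route_rpc_call`, the registry, and the use of `operator.add` as a stand-in for a builtin operator mapped to `"aten::add"` are all illustrative.

```python
import operator
import pickle

# Hypothetical registry from builtin-operator callables to qualified names.
# operator.add is only a stand-in for a real builtin-operator callable.
_BUILTIN_QUALIFIED_NAMES = {operator.add: "aten::add"}


def route_rpc_call(func, *args, **kwargs):
    qualified_name = _BUILTIN_QUALIFIED_NAMES.get(func)
    if qualified_name is not None:
        # Builtin path: send the qualified name so the server can run the
        # operator without going through Python.
        return ("builtin", qualified_name, args, kwargs)
    # UDF path: pickle the callable and its arguments for the server to run
    # in Python.
    return ("python_udf", pickle.dumps((func, args, kwargs)))
```

With this split, a single entry point could serve both kinds of callables, which is the merge the note above anticipates.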
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23569

Test Plan: unit tests

Differential Revision: D16390764

Pulled By: zhaojuanmao

fbshipit-source-id: 2cf2c22a979646830b5581bd75eabf8b3cca564c
2019-08-14 23:13:33 -07:00


from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import sys


def is_available():
    return hasattr(torch._C, "_c10d_init") and hasattr(torch._C, "_rpc_init")


if is_available() and not (torch._C._c10d_init() and torch._C._rpc_init()):
    raise RuntimeError("Failed to initialize PyTorch distributed support")


if is_available():
    from .distributed_c10d import *  # noqa: F401
    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import _backend  # noqa: F401
    if sys.version_info >= (3, 0):
        from .rpc import *  # noqa: F401
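
Because the c10d and RPC names are only imported when is_available() returns True, calling code can check it before touching the package. A minimal sketch, assuming only that torch may or may not be installed; the helper name `distributed_available` is hypothetical.

```python
def distributed_available():
    # Treat a missing torch installation as "not available".
    try:
        import torch.distributed as dist
    except ImportError:
        return False
    # Importing torch.distributed succeeds even when distributed support is
    # not compiled in; is_available() reports whether it actually is.
    return dist.is_available()
```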