Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37139

See design doc in https://github.com/pytorch/pytorch/issues/37136

ghstack-source-id: 105926270

Test Plan:
TODO:
- Make the generated Interface usable. https://github.com/pytorch/pytorch/pull/37139#discussion_r434190978
- Avoid generating the same template instances for a Module that is not scriptable.
- Remove "infer_module_interface_cls".
- Use Python format instead of a CodeTemplate.
- Use Python tempfile to track and delete the generated files. Does it still work if there is a crash?

```
buck test mode/dev-nosan //caffe2/test/distributed/nn/jit:test_instantiator

buck build mode/dev-nosan //caffe2/test/distributed/nn/jit:test_instantiator && \
buck-out/gen/caffe2/test/distributed/nn/jit/test_instantiator\#binary.par -r test_instantiate_scripted_remote_module_template

buck build mode/dev-nosan //caffe2/test/distributed/nn/jit:test_instantiator && \
buck-out/gen/caffe2/test/distributed/nn/jit/test_instantiator\#binary.par -r test_instantiate_non_scripted_remote_module_template
```

```
buck test mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_spawn
```

```
buck test mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork

buck build mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork && \
buck-out/gen/caffe2/test/distributed/nn/api/remote_module_fork\#binary.par -r test_user_provided_global_unique_name

buck build mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork && \
buck-out/gen/caffe2/test/distributed/nn/api/remote_module_fork\#binary.par -r test_forward_async_script

buck build mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork && \
buck-out/gen/caffe2/test/distributed/nn/api/remote_module_fork\#binary.par -r test_forward_sync_script

buck build mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork && \
buck-out/gen/caffe2/test/distributed/nn/api/remote_module_fork\#binary.par -r test_forward_with_kwargs

buck build mode/dev-nosan //caffe2/test/distributed/nn/api:remote_module_fork && \
buck-out/gen/caffe2/test/distributed/nn/api/remote_module_fork\#binary.par -r test_user_provided_global_unique_name
```

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc:rpc_fork
```

```
buck test mode/opt-asan //caffe2/test:jit -- 'test_script_forward_method_replacement'

buck build mode/dev-nosan //caffe2/test:jit && \
buck-out/gen/caffe2/test/jit\#binary.par -r 'test_script_forward_method_replacement'

buck build mode/dev-nosan //caffe2/test:jit && \
buck-out/gen/caffe2/test/jit\#binary.par -r 'test_imported_classes'
```

Differential Revision: D20499658

fbshipit-source-id: dd9383ae4eb2343366c11127664f845b91ca3b0a
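One of the TODO items above is to switch the instantiator to Python ``tempfile`` for tracking and deleting the generated files. A minimal sketch of that approach, under the assumption of a generated-source string and a loader helper (``_GENERATED_SRC`` and ``load_generated_module`` are hypothetical names, not the instantiator's API):

```
import importlib.util
import tempfile

# Hypothetical generated source; the real instantiator fills in a
# per-interface template.
_GENERATED_SRC = "def _generated_method(self, input):\n    return input\n"


def load_generated_module(module_name):
    # TemporaryDirectory registers a finalizer, so the directory and the
    # generated file are removed at interpreter exit -- but not after a hard
    # crash, which is the open question in the TODO above.
    tmp_dir = tempfile.TemporaryDirectory(prefix="_remote_module_")
    file_path = f"{tmp_dir.name}/{module_name}.py"
    with open(file_path, "w") as f:
        f.write(_GENERATED_SRC)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Keep the directory handle alive for as long as the module is in use.
    module._tmp_dir = tmp_dir
    return module
```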
#!/usr/bin/python3
import types
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.distributed.rpc as rpc
from torch import nn
from torch.distributed.nn.jit import instantiator


_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
    instantiator.instantiate_non_scriptable_remote_module_template()
)


# RPC handler, run on the remote side: make sure the generated module template
# for ``module_interface_cls`` has been instantiated there before it is used.
def _instantiate_template(module_interface_cls):
    instantiator.instantiate_scriptable_remote_module_template(module_interface_cls)


def _create_module(module_cls, args, kwargs, module_interface_cls=None):
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expect `module_cls(*args, **kwargs)` to return an instance of <class nn.Module>, "
            f"but it returns an instance of {type(module)}."
        )
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)


class _RemoteModule(nn.Module):
    def __init__(
        self,
        on: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        _module_interface_cls: Any = None,
    ):
"""
|
|
A RemoteModule instance can only be created after RPC initialization.
|
|
It creates a user-specified module on a specified remote node.
|
|
It behaves like a regular ``nn.Module`` except that the ``forward`` method is
|
|
executed on the remote node.
|
|
It takes care of autograd recording to ensure the backward pass propogates
|
|
gradients back to the corresponding remote module.
|
|
|
|
The arguments of ``forward_async`` and ``forward`` are the same as
|
|
the ``forward`` method of the module returned by the ``module_cls``.
|
|
|
|
For example, if ``module_cls`` returns an instace of ``nn.Linear``,
|
|
that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
|
|
the generated ``RemoteModule`` will have 2 methods in signature of
|
|
``def forward(input: Tensor) -> Tensor:`` and
|
|
``def forward_async(input: Tensor) -> Future[Tensor]:``.
|
|
|
|
Arguments:
|
|
on (str or WorkerInfo): id or name of the destination worker.
|
|
module_cls (nn.Module): For example,
|
|
>>> class MyModule(nn.Module):
|
|
>>> def forward(input):
|
|
>>> return input + 1
|
|
>>>
|
|
>>> module_cls = MyModule
|
|
args (Sequence, optional): args to be passed to ``module_cls``.
|
|
kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
|
|
_module_interface_cls (type, optional): The TorchScript interface type for the module
|
|
to be created. The type object should be decorated by @torch.jit.interface.
|
|
If not provided, the generated RemoteModule is not torchscript-able.
|
|
Warning, this is an experimental API and susceptible to frequent changes.
|
|
|
|
Returns:
|
|
A remote module instance which wraps the :class:`~nn.Module` created by the
|
|
user-provided ``module_cls``, it has a blocking ``forward`` method and an
|
|
asynchronous ``forward_async`` method that returns a future of the ``forward`` call
|
|
on the user-provided module on the remote side.
|
|
|
|
Example::
|
|
Run the following code in two different processes:
|
|
|
|
>>> # On worker 0:
|
|
>>> import torch
|
|
>>> import torch.distributed.rpc as rpc
|
|
>>> from torch import nn, Tensor
|
|
>>> from torch.distributed.nn.api.remote_module import RemoteModule
|
|
>>>
|
|
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
|
|
>>> remote_linear_module = RemoteModule(
|
|
>>> "worker1", nn.Linear, args=(20, 30),
|
|
>>> )
|
|
>>> input = torch.randn(128, 20)
|
|
>>> ret_fut = remote_linear_module.forward_async(input)
|
|
>>> ret = ret_fut.wait()
|
|
>>> rpc.shutdown()
|
|
|
|
>>> # On worker 1:
|
|
>>> import torch
|
|
>>> import torch.distributed.rpc as rpc
|
|
>>>
|
|
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
|
|
>>> rpc.shutdown()
|
|
"""
|
|
        super().__init__()

        # Sanity check.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on remote side.
            fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,))

            # Instantiate template on local side.
            generated_module = instantiator.instantiate_scriptable_remote_module_template(
                _module_interface_cls
            )
            generated_methods = generated_module._generated_methods

            fut.wait()  # Ensure remote_module_cls is available on remote side.
        else:
            self.is_scriptable = False
            generated_methods = _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods

        # Create the module on the remote side.
        self.module_rref = rpc.rpc_sync(
            on,
            _create_module,
            (module_cls, args, kwargs, _module_interface_cls),
        )

        # Install generated methods.
        for method in generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

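
# For reference, each generated method essentially forwards the call to the
# wrapped module over RPC. A rough sketch of the kind of code the instantiator
# templates produce (``_run_forward`` is a hypothetical helper name, not the
# actual generated code):
#
#     def _run_forward(module_rref, *args, **kwargs):
#         return module_rref.local_value().forward(*args, **kwargs)
#
#     def forward_async(self, *args, **kwargs):
#         return rpc.rpc_async(
#             self.module_rref.owner(),
#             _run_forward,
#             (self.module_rref,) + args,
#             kwargs,
#         )
#
#     def forward(self, *args, **kwargs):
#         return self.forward_async(*args, **kwargs).wait()

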
class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.
    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    The arguments of ``forward_async`` and ``forward`` are the same as
    the ``forward`` method of the module returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``
    whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have two methods with the signatures
    ``def forward(input: Tensor) -> Tensor:`` and
    ``def forward_async(input: Tensor) -> Future[Tensor]:``.

    Arguments:
        on (str or WorkerInfo): id or name of the destination worker.
        module_cls (nn.Module): For example,
            >>> class MyModule(nn.Module):
            >>>     def forward(input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule
        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``. It has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the
        ``forward`` call on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """

    def __init__(
        self,
        on: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(on, module_cls, args, kwargs)
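

# For completeness, a sketch of the scriptable path, which requires a type
# decorated by @torch.jit.interface (see ``_module_interface_cls`` above).
# ``MyModuleInterface`` is a hypothetical name that only illustrates the
# expected shape of the decorated type:
#
#     from torch import Tensor
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, input: Tensor) -> Tensor:
#             pass
#
#     remote_module = _RemoteModule(
#         "worker1",
#         nn.Linear,
#         args=(20, 30),
#         _module_interface_cls=MyModuleInterface,
#     )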