pytorch/torch/distributed/nn/api/remote_module.py
Yi Wang 8b17fd2516 Add remote_parameters() into RemoteModule class. (#43906)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43906

This method returns a list of RRefs of remote parameters that can be fed into the DistributedOptimizer.
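A minimal usage sketch (names are illustrative; assumes an initialized RPC group and an existing RemoteModule instance ``remote_linear_module``):

    from torch import optim
    from torch.distributed.optim import DistributedOptimizer

    param_rrefs = remote_linear_module.remote_parameters()
    dist_optim = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)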

Original PR issue: RemoteModule enhancements #40550

Test Plan: buck test caffe2/test/distributed/rpc:process_group_agent -- RemoteModule

Reviewed By: rohan-varma

Differential Revision: D23399586

fbshipit-source-id: 4b0f1ccf2e47c8a9e4f79cb2c8668f3cdbdff820
2020-09-04 16:22:40 -07:00


#!/usr/bin/python3
import types
from typing import Any, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from torch import nn
from torch.distributed.nn.jit import instantiator
from torch.nn.parameter import Parameter
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
instantiator.instantiate_non_scriptable_remote_module_template()
)
# RPC handler.
def _instantiate_template(module_interface_cls):
instantiator.instantiate_scriptable_remote_module_template(module_interface_cls)
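# Runs on the remote worker: constructs the user-provided module there, scripts it
# if a TorchScript interface class is given, and returns an RRef to the module.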
def _create_module(module_cls, args, kwargs, module_interface_cls=None):
module = module_cls(*args, **kwargs)
if not isinstance(module, nn.Module):
raise ValueError(
"Expect `module_cls(*args, **kwargs)` returns an instance of <class nn.Module>, "
f"but it returns an instance of {type(module)}."
)
if module_interface_cls is not None:
module = torch.jit.script(module)
return rpc.RRef(module, module_interface_cls)
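# Runs on the worker that owns ``module_rref``: wraps each parameter of the local
# module in an RRef so the list can be returned to the caller (e.g. for use with a
# distributed optimizer).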
def _param_rrefs(module_rref, recurse):
ret = []
for param in module_rref.local_value().parameters(recurse):
ret.append(rpc.RRef(param))
return ret
class _RemoteModule(nn.Module):
def __init__(
self,
on: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
_module_interface_cls: Any = None,
):
"""
A RemoteModule instance can only be created after RPC initialization.
It creates a user-specified module on a specified remote node.
It behaves like a regular ``nn.Module`` except that the ``forward`` method is
executed on the remote node.
It takes care of autograd recording to ensure the backward pass propagates
gradients back to the corresponding remote module.
The arguments of ``forward_async`` and ``forward`` are the same as
those of the ``forward`` method of the module returned by ``module_cls``.
For example, if ``module_cls`` returns an instance of ``nn.Linear``
whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
the generated ``RemoteModule`` will have 2 methods with the signatures
``def forward(input: Tensor) -> Tensor:`` and
``def forward_async(input: Tensor) -> Future[Tensor]:``.
Arguments:
on (str or WorkerInfo): id or name of the destination worker.
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
>>> return input + 1
>>>
>>> module_cls = MyModule
args (Sequence, optional): args to be passed to ``module_cls``.
kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
_module_interface_cls (type, optional): The TorchScript interface type for the module
to be created. The type object should be decorated by @torch.jit.interface.
If not provided, the generated RemoteModule is not TorchScript-able.
Warning: this is an experimental API and subject to frequent changes.
Returns:
A remote module instance which wraps the :class:`~nn.Module` created by the
user-provided ``module_cls``. It has a blocking ``forward`` method and an
asynchronous ``forward_async`` method that returns a future of the ``forward`` call
on the user-provided module on the remote side.
Example::
Run the following code in two different processes:
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch import nn, Tensor
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
>>> ret = ret_fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
super().__init__()
# Sanity checks.
assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
# Default arguments preparation.
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}
self.on = on
if _module_interface_cls is not None:
# Users rely on this field to know whether this generated RemoteModule is TorchScript-able.
self.is_scriptable = True
# Instantiate template on remote side.
fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,))
# Instantiate template on local side.
generated_module = instantiator.instantiate_scriptable_remote_module_template(
_module_interface_cls
)
generated_methods = generated_module._generated_methods
# Create the module on the remote side.
fut.wait() # Ensure remote_module_cls is available on remote side.
else:
self.is_scriptable = False
generated_methods = _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
# Create the module on the remote side.
self.module_rref = rpc.rpc_sync(
on, _create_module, (module_cls, args, kwargs, _module_interface_cls)
)
# Install generated methods.
for method in generated_methods:
method_name = method.__name__
method = torch.jit.export(method)
setattr(self, method_name, types.MethodType(method, self))
def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
r"""Returns a list of RRefs of remote module parameters.
This is typically passed to a distributed optimizer.
Args:
recurse (bool): if True, returns the parameters of the remote module
and all of its submodules.
Otherwise, returns only the parameters that are direct members of the remote module.
Returns:
A list of RRefs to remote module parameters.
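Example (a minimal sketch; assumes an initialized RPC group, a remote
``nn.Linear`` named ``remote_linear_module`` as in the class docstring, and
use of the distributed autograd and distributed optimizer APIs)::
>>> import torch.distributed.autograd as dist_autograd
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> param_rrefs = remote_linear_module.remote_parameters()
>>> dist_optim = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)
>>> with dist_autograd.context() as context_id:
>>>     loss = remote_linear_module.forward(torch.randn(128, 20)).sum()
>>>     dist_autograd.backward(context_id, [loss])
>>>     dist_optim.step(context_id)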
"""
return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))
class RemoteModule(_RemoteModule):
"""
A RemoteModule instance can only be created after RPC initialization.
It creates a user-specified module on a specified remote node.
It behaves like a regular ``nn.Module`` except that the ``forward`` method is
executed on the remote node.
It takes care of autograd recording to ensure the backward pass propagates
gradients back to the corresponding remote module.
The arguments of ``forward_async`` and ``forward`` are the same as
those of the ``forward`` method of the module returned by ``module_cls``.
For example, if ``module_cls`` returns an instance of ``nn.Linear``
whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
the generated ``RemoteModule`` will have 2 methods with the signatures
``def forward(input: Tensor) -> Tensor:`` and
``def forward_async(input: Tensor) -> Future[Tensor]:``.
Arguments:
on (str or WorkerInfo): id or name of the destination worker.
module_cls (nn.Module): For example,
>>> class MyModule(nn.Module):
>>> def forward(input):
>>> return input + 1
>>>
>>> module_cls = MyModule
args (Sequence, optional): args to be passed to ``module_cls``.
kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
Returns:
A remote module instance which wraps the :class:`~nn.Module` created by the
user-provided ``module_cls``. It has a blocking ``forward`` method and an
asynchronous ``forward_async`` method that returns a future of the ``forward`` call
on the user-provided module on the remote side.
Example::
Run the following code in two different processes:
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch import nn, Tensor
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>> "worker1", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
>>> ret = ret_fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
def __init__(
self,
on: str,
module_cls: nn.Module,
args: Tuple = None,
kwargs: Dict[str, Any] = None,
):
super().__init__(on, module_cls, args, kwargs)