mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43906 This method returns a list of RRefs of remote parameters that can be fed into the DistributedOptimizer. Original PR issue: RemoteModule enhancements #40550 Test Plan: buck test caffe2/test/distributed/rpc:process_group_agent -- RemoteModule Reviewed By: rohan-varma Differential Revision: D23399586 fbshipit-source-id: 4b0f1ccf2e47c8a9e4f79cb2c8668f3cdbdff820
237 lines
9.3 KiB
Python
237 lines
9.3 KiB
Python
#!/usr/bin/python3
|
|
import types
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.distributed.rpc as rpc
from torch import nn
from torch.distributed.nn.jit import instantiator
from torch.nn.parameter import Parameter
|
|
|
|
|
|
# Generated once, at import time. Every RemoteModule created without a
# TorchScript interface reuses this template's ``_generated_methods``.
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = instantiator.instantiate_non_scriptable_remote_module_template()
|
|
|
|
|
|
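# RPC handlers: the functions below are executed on the destination worker, invoked through rpc.rpc_async / rpc.rpc_sync.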
|
|
def _instantiate_template(module_interface_cls):
    """Instantiate the scriptable RemoteModule template for ``module_interface_cls``.

    Sent to the destination worker with ``rpc.rpc_async`` so the generated
    template also exists on that worker before a scriptable module is created
    there. Returns ``None``; only the instantiation side effect matters.

    Args:
        module_interface_cls: the TorchScript interface type that the
            template is generated for.
    """
    # The generated module object is not needed by the caller, so it is dropped.
    instantiator.instantiate_scriptable_remote_module_template(
        module_interface_cls
    )
|
|
|
|
|
|
def _create_module(module_cls, args, kwargs, module_interface_cls=None):
    """Build ``module_cls(*args, **kwargs)`` and return an RRef that owns it.

    Executed on the destination worker through ``rpc.rpc_sync``.

    Args:
        module_cls: callable (normally an ``nn.Module`` subclass) that is
            invoked to create the module.
        args: positional arguments forwarded to ``module_cls``.
        kwargs: keyword arguments forwarded to ``module_cls``.
        module_interface_cls: optional TorchScript interface type. When it is
            given, the module is compiled with ``torch.jit.script`` and the
            interface is passed along as the RRef's type hint.

    Returns:
        ``rpc.RRef`` wrapping the (possibly scripted) module.

    Raises:
        ValueError: if ``module_cls(*args, **kwargs)`` is not an ``nn.Module``.
    """
    created = module_cls(*args, **kwargs)
    if not isinstance(created, nn.Module):
        raise ValueError(
            "Expect `module_cls(*args, **kwargs)` returns an instance of <class nn.Module>, "
            f"but it returns an instance of {type(created)}."
        )

    # Only scriptable RemoteModules (those with an interface type) need TorchScript.
    payload = created if module_interface_cls is None else torch.jit.script(created)
    return rpc.RRef(payload, module_interface_cls)
|
|
|
|
|
|
def _param_rrefs(module_rref, recurse):
    """Wrap every parameter of the module behind ``module_rref`` in an RRef.

    The module is read through ``module_rref.local_value()``, which is why this
    handler is sent with ``rpc.rpc_sync`` to the worker holding the module.

    Args:
        module_rref: RRef to the module whose parameters are collected.
        recurse: forwarded to ``nn.Module.parameters``; when False only the
            module's direct parameters are included.

    Returns:
        A list with one ``rpc.RRef`` per parameter, in ``parameters()`` order.
    """
    local_module = module_rref.local_value()
    return [rpc.RRef(p) for p in local_module.parameters(recurse)]
|
|
|
|
|
|
class _RemoteModule(nn.Module):
    def __init__(
        self,
        on: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.
        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        For example, if ``module_cls`` returns an instance of ``nn.Linear``,
        that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods in signature of
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        Arguments:
            on (str or WorkerInfo): id or name of the destination worker.
            module_cls (Type[nn.Module]): class of the module to create on the
                destination worker. It is called there as
                ``module_cls(*args, **kwargs)`` and must return an ``nn.Module``.
                For example,
                >>> class MyModule(nn.Module):
                >>>     def forward(self, input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule
            args (Sequence, optional): args to be passed to ``module_cls``.
                Defaults to ``()``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
                Defaults to ``{}``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Attributes set on the instance:
            on: the destination worker, as passed in.
            is_scriptable (bool): True iff ``_module_interface_cls`` was given.
            module_rref: RRef to the module created on the destination worker.

        Raises:
            AssertionError: if no RPC agent is currently set (RPC not initialized).

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Example::
            Run the following code in two different processes:

            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        # Sanity checks.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        self.on = on

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on remote side (asynchronously, overlapped
            # with the local instantiation below).
            fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,))

            # Instantiate template on local side.
            generated_module = instantiator.instantiate_scriptable_remote_module_template(
                _module_interface_cls
            )
            generated_methods = generated_module._generated_methods

            # Block until the remote template exists, so the scriptable module
            # can be created on the remote side below.
            fut.wait()
        else:
            self.is_scriptable = False
            generated_methods = _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods

        # Create the module on the remote side.
        self.module_rref = rpc.rpc_sync(
            on, _create_module, (module_cls, args, kwargs, _module_interface_cls)
        )

        # Install generated methods: mark each one with torch.jit.export and
        # bind it to this instance under its original name.
        for method in generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
        r"""Returns a list of RRefs of remote module parameters.

        This is typically passed to a distributed optimizer.
        This call blocks until the destination worker has replied.

        Args:
            recurse (bool): if True, then returns parameters of the remote module
                and all submodules of the remote module.
                Otherwise, returns only parameters that are direct members of the remote module.

        Returns:
            A list of RRefs to remote module parameters.
        """
        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))
|
|
|
|
|
|
class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.
    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    The arguments of ``forward_async`` and ``forward`` are the same as
    the ``forward`` method of the module returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``,
    that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have 2 methods in signature of
    ``def forward(input: Tensor) -> Tensor:`` and
    ``def forward_async(input: Tensor) -> Future[Tensor]:``.

    Arguments:
        on (str or WorkerInfo): id or name of the destination worker.
        module_cls (Type[nn.Module]): class of the module to create on the
            destination worker. It is called there as
            ``module_cls(*args, **kwargs)`` and must return an ``nn.Module``.
            For example,
            >>> class MyModule(nn.Module):
            >>>     def forward(self, input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule
        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``, it has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """

    def __init__(
        self,
        on: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Public entry point: the experimental ``_module_interface_cls`` is not
        # exposed, so the parent receives None and the result is not scriptable.
        super().__init__(on, module_cls, args, kwargs)
|