pytorch/torch/nn/utils/_stateless.py
Emilio Castillo fa38e93fe9 Add lightweight reparametrization for _stateless calls (#68969)
Summary:
https://github.com/pytorch/pytorch/issues/61447 introduced a mechanism for performing functional calls in a model using the reparametrization API. However, the overhead introduced in a single call was too large.
I tried to address this by modifying the reparametrization code to support plain tensors, but the changes needed were too large: the type checking and several parts of the code expect actual `nn.Module` objects, so this option was not feasible.

The benchmark instantiates a resnet50 and calls `functional_call` with a parameters dict covering 0, 25, 50, 75, and 100% of the model's total parameters.

Used script:
https://gist.github.com/emcastillo/f344a58638bd71d130c71c45f86f0c3a
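
For orientation, here is a rough sketch of what such a benchmark could look like (this is not the linked gist; the model constructor, input shape, and the way the partial parameters dict is built are assumptions):

```python
import torch
import torchvision
from torch.nn.utils import _stateless

model = torchvision.models.resnet50()
x = torch.randn(1, 3, 224, 224)
named = list(model.named_parameters())

for fraction in (0.0, 0.25, 0.5, 0.75, 1.0):
    k = int(len(named) * fraction)
    # Pass only the first k parameters; the rest keep the module's own values.
    params = {name: p.detach().clone() for name, p in named[:k]}
    out = _stateless.functional_call(model, params, (x,))
```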

| % of parameters passed | CPU Time (us) | GPU Time (us) |
|------------------------|---------------|---------------|
| regular call           | 5539          | 184909        |
| 0                      | 5561          | 184843        |
| 25                     | 11363         | 189236        |
| 50                     | 18716         | 195378        |
| 75                     | 22851         | 198641        |
| 100                    | 27441         | 202281        |

This PR just swaps the `__getattribute__` of the submodules so that, when called, it first looks into a dict holding only the replacement parameters, greatly reducing the burden of having to instantiate custom modules and call their forward just to retrieve a tensor. A minimal sketch of the technique is shown below.
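
A minimal sketch of the class-swap trick, outside of any PyTorch internals (`MyModule`, the replacement dict, and the tensor values are made up for illustration):

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x * self.weight

module = MyModule()
replacement = {"weight": torch.zeros(3)}

def _getattribute(self, name):
    # Consult the replacement dict first, then fall back to the original class.
    if name in replacement:
        return replacement[name]
    return MyModule.__getattribute__(self, name)

# Dynamically create a subclass whose __getattribute__ checks the dict, and
# swap the instance's class to it for the duration of the call.
module.__class__ = type("StatelessReplacerMyModule", (MyModule,), {"__getattribute__": _getattribute})
print(module(torch.ones(3)))  # tensor([0., 0., 0.]) -- the swapped-in weight is used
module.__class__ = MyModule   # restore the original class
```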

The execution times now are as follows:

| % of parameters passed     | CPU Time (us) | GPU Time (us) |
|----------------------------|---------------|---------------|
| regular call               | 5939          | 187533        |
| 0                          | 5899          | 187570        |
| 25                         | 8541          | 188953        |
| 50                         | 10045         | 189826        |
| 75                         | 11049         | 190344        |
| 100                        | 11911         | 190800        |
| functorch with 100% params | 14014         | 191727        |

Now the CPU time overhead is greatly reduced, and the GPU time barely increases because the extra CPU work overlaps effectively with the GPU execution.

cc albanD zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68969

Reviewed By: george-qi

Differential Revision: D33836360

Pulled By: albanD

fbshipit-source-id: 532561f64b18ca14c6ae2d77dcacb339397a589d
(cherry picked from commit fd4b6bdfbf)
2022-01-28 14:38:45 +00:00

import contextlib
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch import Tensor


# We avoid typing the module here because module attributes are declared as
# Union[Parameter, Tensor] by default and using other types causes mypy errors
def _change_class(module) -> None:
    cls = module.__class__
    func_params: Dict[str, Tensor] = module._functional_parameters

    def _getattribute(self, name: str) -> Any:
        if name in func_params:
            return func_params[name]
        return cls.__getattribute__(self, name)

    param_cls = type(
        f"StatelessReplacer{cls.__name__}",
        (cls,),
        {
            "__getattribute__": _getattribute,
        },
    )

    module.__class__ = param_cls
    module._orig_class = cls


def _swap_parameters(module, tensor_name: str, tensor: Tensor) -> None:
    # Changes the module class to get a new __getattribute__ dunder method
    # that looks for the reparametrized tensor
    if hasattr(module, "_functional_parameters"):
        module._functional_parameters[tensor_name] = tensor
    else:
        module._functional_parameters = {}
        module._functional_parameters[tensor_name] = tensor
        _change_class(module)


def _remove_swap(module, name: str) -> None:
    if hasattr(module, "_orig_class"):
        module.__class__ = module._orig_class
        delattr(module, "_orig_class")
        delattr(module, "_functional_parameters")


@contextlib.contextmanager
def reparametrize_module(
    module: torch.nn.Module,
    parameters_and_buffers: Dict[str, Tensor],
) -> Iterator[None]:
    for name, tensor in parameters_and_buffers.items():
        _apply_func_submodules(
            _swap_parameters,
            module, name.split("."), (tensor,))
    try:
        yield
    finally:
        # Undo the class swaps even if the wrapped call raises.
        for name in parameters_and_buffers:
            _apply_func_submodules(
                _remove_swap,
                module, name.split("."), ())


def _apply_func_submodules(
    func: Callable[..., None],
    module: torch.nn.Module,
    path: List[str],
    args: Tuple,
):
    if len(path) == 1:
        func(module, path[0], *args)
    else:
        _apply_func_submodules(func, getattr(module, path[0]), path[1:], args)


def functional_call(
    module: torch.nn.Module,
    parameters_and_buffers: Dict[str, Tensor],
    args: Tuple,
    kwargs: Optional[Dict[str, Any]] = None,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        ``parameters_and_buffers`` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (tuple): arguments to be passed to the module call
        kwargs (dict): keyword arguments to be passed to the module call

    Returns:
        Any: the result of calling ``module``.
    """
    # TODO allow kwargs such as unsafe and others for parametrization
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(module, (
            torch.jit.RecursiveScriptModule,
            torch.jit.ScriptModule,
            torch.jit.ScriptFunction)
        )
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if kwargs is None:
        kwargs = {}
    with reparametrize_module(module, parameters_and_buffers):
        if isinstance(args, tuple):
            out = module(*args, **kwargs)
        else:
            out = module(args, **kwargs)
    return out
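

# A minimal usage sketch (illustrative: the Linear module, shapes, and values
# below are assumptions for demonstration, not part of the original file).
if __name__ == "__main__":
    linear = torch.nn.Linear(3, 3)
    x = torch.randn(1, 3)
    # Replace only ``weight``; ``bias`` keeps the module's own value.
    out = functional_call(linear, {"weight": torch.zeros(3, 3)}, (x,))
    # With a zero weight the output reduces to the bias for every row.
    assert torch.allclose(out, linear.bias.expand_as(out))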