mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This will be the last disruptive functorch internals change. Why are we moving these files? - As a part of rationalizing functorch we are moving the code in functorch/_src to torch/_functorch - This is so that we can offer the functorch APIs as native PyTorch APIs (coming soon) and resolve some internal build issues. Why are we moving all of these files at once? - It's better to break developers all at once rather than many times Test Plan: - wait for tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091 Approved by: https://github.com/anijain2305, https://github.com/ezyang
544 lines
21 KiB
Python
544 lines
21 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import copy
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from .named_members_polyfill import _named_parameters, _named_buffers
|
|
|
|
# Utilities to make nn.Module "functional"
|
|
# In particular the goal is to be able to provide a function that takes as input
|
|
# the parameters and evaluates the nn.Module using fixed inputs.
|
|
|
|
|
|
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
|
"""
|
|
Deletes the attribute specified by the given list of names.
|
|
For example, to delete the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'])
|
|
"""
|
|
if len(names) == 1:
|
|
delattr(obj, names[0])
|
|
else:
|
|
_del_nested_attr(getattr(obj, names[0]), names[1:])
|
|
|
|
|
|
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
|
|
"""
|
|
Set the attribute specified by the given list of names to value.
|
|
For example, to set the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'], value)
|
|
"""
|
|
if len(names) == 1:
|
|
setattr(obj, names[0], value)
|
|
else:
|
|
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
|
|
|
|
|
|
def _get_nested_attr(obj: nn.Module, names: List[str]) -> None:
|
|
if len(names) == 1:
|
|
return getattr(obj, names[0])
|
|
else:
|
|
return _get_nested_attr(getattr(obj, names[0]), names[1:])
|
|
|
|
|
|
def raise_parameter_tying_error():
    """
    Raise the error used when a model ties (shares) parameters.

    Raises:
        RuntimeError: always. The message asks the user to untie the
            parameters and links to the tracking issue.
    """
    message = (
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or comment your support in "
        "https://github.com/pytorch/functorch/issues/446"
    )
    raise RuntimeError(message)
|
|
|
|
|
|
def create_names_map(named_params, tied_named_params):
    """
    Map each unique tensor's canonical name to every attribute path holding it.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs in which each tensor
            appears once, e.g. ``[('A', A), ('B', B)]``.
        tied_named_params: iterable of ``(name, tensor)`` pairs that may list
            the same tensor under several names (tied / duplicated tensors),
            e.g. ``[('A', A), ('B', B), ('B_tied', B)]``.

    Returns:
        A dict from each name in ``named_params`` to the list of attribute
        paths under which that tensor appears in ``tied_named_params``. Each
        path is the name split on ``'.'``, so the example above yields
        ``{'A': [['A']], 'B': [['B'], ['B_tied']]}``.

    Every name in ``named_params`` must also occur in ``tied_named_params``,
    and every tensor in ``tied_named_params`` must occur in ``named_params``.
    Both conditions are checked with ``assert``.
    """
    named_params = dict(named_params)
    tied_named_params = dict(tied_named_params)

    assert set(named_params).issubset(tied_named_params)

    # Tensor keys are matched by object identity, so every alias of one tied
    # tensor lands in the same (canonical_name, paths) entry.
    tensor_to_mapping = {tensor: (key, []) for key, tensor in named_params.items()}
    for key, tensor in tied_named_params.items():
        assert tensor in tensor_to_mapping
        tensor_to_mapping[tensor][1].append(key.split('.'))
    return dict(tensor_to_mapping.values())
|
|
|
|
|
|
def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
    """
    Strip one family of members (e.g. parameters or buffers) out of ``mod``.

    Args:
        mod: module that is modified in place.
        _named_members: callable invoked as
            ``_named_members(mod, remove_duplicate=False)``; yields
            ``(name, tensor)`` pairs including every tied alias.
        named_members: zero-argument callable yielding the de-duplicated
            ``(name, tensor)`` pairs.
        subclass: callable used to wrap each meta-device placeholder.

    Every attribute path produced by ``_named_members`` is overwritten with
    ``subclass(torch.empty_like(tensor, device='meta'))``. Paths referring to
    the same tensor receive the same placeholder object.

    Returns:
        ``(params, names, names_map)``: a tuple of the original de-duplicated
        tensors, a tuple of their names (both empty when there are none), and
        the mapping built by ``create_names_map``.
    """
    every_member = tuple(_named_members(mod, remove_duplicate=False))
    unique_members = tuple(named_members())
    names_map = create_names_map(unique_members, every_member)

    # Replace each member with a meta placeholder, sharing one per tensor.
    placeholders = {}
    for attr_path, tensor in every_member:
        if tensor not in placeholders:
            placeholders[tensor] = subclass(torch.empty_like(tensor, device='meta'))
        _set_nested_attr(mod, attr_path.split("."), placeholders[tensor])

    names = tuple(name for name, _ in unique_members)
    tensors = tuple(tensor for _, tensor in unique_members)
    return tensors, names, names_map
|
|
|
|
|
|
def extract_weights(mod: nn.Module):
    """
    Strip every parameter out of ``mod`` in place and return the originals.

    Each parameter attribute, including every tied alias of a shared
    parameter, is overwritten with an ``nn.Parameter`` wrapping a meta-device
    placeholder (``torch.empty_like(p, device='meta')``). As a result,
    ``mod.parameters()`` afterwards yields those placeholders rather than
    nothing. The module cannot compute anything useful until real tensors are
    put back, e.g. with ``load_weights``.

    Returns:
        ``(params, names, names_map)``:
            params: tuple of the original, de-duplicated parameters.
            names: tuple of their dotted attribute names, in the same order.
            names_map: dict from each name in ``names`` to the list of
                attribute paths (each a list of name components) that held
                that parameter, covering tied aliases.
    """
    return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
|
|
|
|
|
|
def extract_buffers(mod: nn.Module):
    """
    Strip every buffer out of ``mod`` in place and return the originals.

    Each buffer attribute, including tied aliases, is overwritten with a bare
    meta-device placeholder (``torch.empty_like(b, device='meta')``). Unlike
    parameters, the placeholder is not wrapped in any subclass.

    Returns:
        ``(buffers, names, names_map)``: a tuple of the original de-duplicated
        buffers, a tuple of their dotted names, and a dict from each name to
        the list of attribute paths that held that buffer.
    """
    def keep_plain_tensor(placeholder):
        # Buffers stay plain tensors; no wrapping is applied.
        return placeholder

    return _extract_members(mod, _named_buffers, mod.named_buffers, keep_plain_tensor)
|
|
|
|
|
|
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Put tensors back onto ``mod`` so it can run a forward pass again.

    For each ``(name, tensor)`` pair, the existing attribute at the dotted
    path ``name`` is deleted and then re-assigned.

    Args:
        mod: module to modify in place.
        names: dotted attribute names, e.g. ``'conv.weight'``.
        params: tensors to install, aligned with ``names``.
        as_params: when True, each tensor is wrapped in ``nn.Parameter``
            before assignment. When False (default), the tensors, which may
            carry autograd history, are installed as plain tensors and are
            therefore not reported by ``mod.parameters()``.
    """
    for name, tensor in zip(names, params):
        attr_path = name.split(".")
        value = nn.Parameter(tensor) if as_params else tensor
        _del_nested_attr(mod, attr_path)
        _set_nested_attr(mod, attr_path, value)
|
|
|
|
|
|
def _swap_state(mod: nn.Module, names_map: Dict[str, List[List[str]]], elems) -> List[Any]:
    """
    Install ``elems`` on ``mod`` and return what was there before.

    Args:
        mod: module modified in place.
        names_map: dict (as built by ``create_names_map``) from a canonical
            name to every attribute path, each a list of name components,
            holding that tensor. Its iteration order must match ``elems``.
            (The previous ``List[str]`` annotation was wrong: ``.items()`` /
            ``.values()`` is called on it.)
        elems: values to install, one per entry of ``names_map``. Each value
            is assigned to every path of its entry, so tied aliases stay tied.

    Returns:
        For each entry, the value previously stored at its first path. Passing
        this list back in as ``elems`` undoes the swap.
    """
    result = []
    for attr_paths, elem in zip(names_map.values(), elems):
        for i, attr_path in enumerate(attr_paths):
            if i == 0:
                # Tied aliases share one object, so reading the first is enough.
                result.append(_get_nested_attr(mod, attr_path))
            _del_nested_attr(mod, attr_path)
            _set_nested_attr(mod, attr_path, elem)
    return result
|
|
|
|
|
|
def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
    """
    Assign ``buffers`` onto ``mod`` at the dotted attribute paths in ``names``.

    Unlike ``load_weights``, the existing attributes are not deleted first;
    each tensor is assigned directly over whatever is at its path.

    Note: ``as_params`` is accepted but never read by this function; the
    buffers are always installed exactly as given.
    """
    for dotted_name, buf in zip(names, buffers):
        path = dotted_name.split(".")
        _set_nested_attr(mod, path, buf)
|
|
|
|
|
|
def load_state(
        model: nn.Module,
        weights: List[Tensor], weight_names: List[str],
        buffers=(), buffer_names=()):
    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model

    Assign ``weights`` (and, if any are given, ``buffers``) back onto
    ``model`` and return the same ``model`` object. This is the inverse of
    ``make_functional_deprecated_v1``.

    Length mismatches between the tensors and their names are checked with
    ``assert``. ``buffer_names`` is only consulted when ``buffers`` is
    non-empty.
    """
    assert len(weight_names) == len(weights)
    load_weights(model, weight_names, weights)
    if len(buffers) == 0:
        return model
    assert len(buffer_names) == len(buffers)
    load_buffers(model, buffer_names, buffers)
    return model
|
|
|
|
|
|
def make_functional_deprecated_v1(model: nn.Module):
    """make_functional_deprecated_v1(model) -> weights, func, weight_names

    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes it
    possible to use transforms over the parameters of `model`.

    Note that `model` itself is modified in place: its parameters are
    replaced by placeholders (see `extract_weights`). Keep using `func`, or
    put the state back with `load_state`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    func(weights, (x,))
    ```

    And here is an example of applying the grad transform (which needs a
    scalar output, hence the `.sum()`):
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)

    def compute_loss(weights, data):
        return func(weights, data).sum()

    grad_weights = grad(compute_loss)(weights, (x,))
    ```

    Raises:
        RuntimeError: if `model` has buffers; use
            `make_functional_with_buffers_deprecated_v1` instead.
    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
                           'make_functional_with_buffers_deprecated_v1(model) instead.')
    weights, descriptors, _ = extract_weights(model)

    def fun(weights, data):
        # Work on a fresh copy each call so the stripped template is never mutated.
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors
|
|
|
|
|
|
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names

    Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
    and returns a functional version of the model, `func`.

    Note that `model` itself is modified in place: its parameters and buffers
    are replaced by placeholders (see `extract_weights` / `extract_buffers`).

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    func(weights, buffers, (x,))
    ```

    And here is an example of applying the grad transform (which needs a
    scalar output, hence the `.sum()`):
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)

    def compute_loss(weights, buffers, data):
        return func(weights, buffers, data).sum()

    grad_weights = grad(compute_loss)(weights, buffers, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    weights, weight_descriptors, _ = extract_weights(model)
    buffers, buf_descriptors, _ = extract_buffers(model)

    def fun(weights, buffers, data):
        # Work on a fresh copy each call so the stripped template is never mutated.
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, weight_descriptors, weights)
        load_buffers(mutable_model, buf_descriptors, buffers)
        return mutable_model(*data)

    return weights, buffers, fun, weight_descriptors, buf_descriptors
|
|
|
|
|
|
class FunctionalModuleWithBuffers(nn.Module):
    """
    Callable wrapper returned by :func:`make_functional_with_buffers`.

    Holds a parameter- and buffer-stripped copy of a module
    (``stateless_model``). On every call it temporarily installs the
    caller-supplied params and buffers, runs the module, and then restores
    what was installed before.
    """

    def __init__(self, stateless_model, param_names, buffer_names,
                 param_names_map, buffer_names_map):
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.buffer_names = buffer_names

        # Params first, then buffers; forward() passes values in this order.
        combined = dict(param_names_map)
        combined.update(buffer_names_map)
        self.all_names_map = combined

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(stateless)
        buffers, buffer_names, buffer_names_map = extract_buffers(stateless)
        if disable_autograd_tracking:
            for p in params:
                p.requires_grad_(False)
        wrapper = FunctionalModuleWithBuffers(
            stateless, param_names, buffer_names, param_names_map, buffer_names_map)
        return wrapper, params, buffers

    def forward(self, params, buffers, *args, **kwargs):
        # Install the supplied state for the duration of this call only.
        incoming = [*params, *buffers]
        saved = _swap_state(self.stateless_model, self.all_names_map, incoming)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Restore whatever was installed before the call, even on error.
            _swap_state(self.stateless_model, self.all_names_map, saved)
|
|
|
|
|
|
class FunctionalModule(nn.Module):
    """
    Callable wrapper returned by :func:`make_functional`.

    Holds a parameter-stripped copy of a module (``stateless_model``). On
    every call it temporarily installs the caller-supplied params, runs the
    module, and then restores what was installed before.
    """

    def __init__(self, stateless_model, param_names, names_map):
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.names_map = names_map

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(stateless)
        if disable_autograd_tracking:
            for p in params:
                p.requires_grad_(False)
        wrapper = FunctionalModule(stateless, param_names, names_map)
        return wrapper, params

    def forward(self, params, *args, **kwargs):
        # Install the supplied params for the duration of this call only.
        saved = _swap_state(self.stateless_model, self.names_map, params)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Restore whatever was installed before the call, even on error.
            _swap_state(self.stateless_model, self.names_map, saved)
|
|
|
|
|
|
def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    """make_functional(model, disable_autograd_tracking=False) -> func, params

    Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
    (params) and returns a functional version of the model, ``func``. This
    makes it possible to use transforms over the parameters of ``model``.
    ``model`` itself is not modified; the state is taken from a deep copy.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)
        func(params, x)

    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)

        def compute_loss(params, x, t):
            y = func(params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, x, t)

    If the model has any buffers, please use :func:`make_functional_with_buffers` instead.

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the output parameters.
            The returned params are copies of the model's params (taken via ``copy.deepcopy``).
            If False (default), each returned param keeps the ``requires_grad`` setting it had
            in the original model, so it can be tracked by regular PyTorch autograd.
            If True, every returned param has ``requires_grad=False``.
            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
            Otherwise, if you're only planning on using functorch's gradient transforms,
            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
            history with PyTorch autograd.

    Raises:
        RuntimeError: if ``model`` has any buffers.
    """
    has_buffers = any(True for _ in model.buffers())
    if has_buffers:
        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
                           'make_functional_with_buffers(model) instead.')
    return FunctionalModule._create_from(model, disable_autograd_tracking=disable_autograd_tracking)
|
|
|
|
|
|
def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
    """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers

    Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
    state (params and buffers) and returns a functional version of the model,
    ``func``, that can be invoked like a function. ``model`` itself is not
    modified; the state is taken from a deep copy.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)
        func(params, buffers, x)

    And here is an example of applying the grad transform over the parameters
    of a model:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)

        def compute_loss(params, buffers, x, t):
            y = func(params, buffers, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, buffers, x, t)

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the output parameters.
            The returned params are copies of the model's params (taken via ``copy.deepcopy``).
            If False (default), each returned param keeps the ``requires_grad`` setting it had
            in the original model, so it can be tracked by regular PyTorch autograd.
            If True, every returned param has ``requires_grad=False``. Buffers are returned as-is.
            If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
            ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
            Otherwise, if you're only planning on using functorch's gradient transforms,
            then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
            history with PyTorch autograd.
    """
    factory = FunctionalModuleWithBuffers._create_from
    return factory(model, disable_autograd_tracking=disable_autograd_tracking)
|
|
|
|
|
|
def transpose_stack(tuple_of_tuple_of_tensors):
    """
    Regroup per-model tensor tuples into one stacked tensor per position.

    Given ``((a0, b0), (a1, b1), ...)``, returns
    ``(stack([a0, a1, ...]), stack([b0, b1, ...]))``. Each result is detached
    from the autograd graph. An empty input yields ``()``.
    """
    per_position = zip(*tuple_of_tuple_of_tensors)
    return tuple(torch.stack(group).detach() for group in per_position)
|
|
|
|
|
|
def combine_state_for_ensemble(models):
    """combine_state_for_ensemble(models) -> func, params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
    parameters and buffers together to make ``params`` and ``buffers``.
    Each parameter and buffer in the result has an additional leading
    dimension of size ``M`` and is detached from the autograd graph.

    :func:`combine_state_for_ensemble` also returns ``func``, a functional
    version of the first model in :attr:`models`. Because of the extra
    dimension, ``func(params, buffers, *args, **kwargs)`` cannot be called
    directly; you probably want ``vmap(func, ...)(params, buffers, *args, **kwargs)``.

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        fmodel, params, buffers = combine_state_for_ensemble(models)
        output = vmap(fmodel, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).

        This API is subject to change -- we're investigating better ways to
        create ensembles and would love your feedback on how to improve this.

    Raises:
        RuntimeError: if ``models`` is empty, mixes training and eval modes,
            or contains more than one exact class.
    """
    if len(models) == 0:
        raise RuntimeError('combine_state_for_ensemble: Expected at least one model, got 0.')

    same_mode = all(m.training for m in models) or not any(m.training for m in models)
    if not same_mode:
        raise RuntimeError('combine_state_for_ensemble: Expected all models to '
                           'have the same training/eval mode.')

    # Exact class match is intended here (not isinstance): subclasses may
    # define a different forward.
    first_cls = type(models[0])
    if any(type(m) != first_cls for m in models):
        raise RuntimeError('combine_state_for_ensemble: Expected all models to '
                           'be of the same class.')

    funcs, params, buffers = zip(*(make_functional_with_buffers(m) for m in models))
    return funcs[0], transpose_stack(params), transpose_stack(buffers)
|
|
|
|
|
|
def functional_init(model_class, ensemble_shape=(), device='cpu'):
    """
    Build a factory that instantiates ``model_class`` in functional form.

    Args:
        model_class: callable that constructs the module; the factory forwards
            its ``*args`` / ``**kwargs`` to it.
        ensemble_shape: ``()`` for a single model, or ``(num_models,)`` to
            build an ensemble whose weights are stacked along a new leading
            dimension. More than one dimension is not implemented.
        device: every instance constructed here, including the template used
            to produce ``fn``, is moved there with ``.to(device)``.

    Returns:
        A function ``wrapped(*args, **kwargs)`` returning
        ``(weights, fn, names)`` as produced by
        ``make_functional_deprecated_v1``. For an ensemble, each entry of
        ``weights`` is the detached stack of that weight across all models.

    Raises (from ``wrapped``):
        ValueError: if ``ensemble_shape`` has more than one element or its
            single element is not positive.
    """
    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError('NYI: ensemble_shape with more than 1 element')
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            return make_functional_deprecated_v1(model)
        num_models = ensemble_shape[0]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(model_class(*args, **kwargs).to(device)
                       for _ in range(num_models))
        # A separate template instance supplies ``fn`` and the weight names.
        # Place it on ``device`` too, consistent with every other instance here.
        template = model_class(*args, **kwargs).to(device)
        _, fn, names = make_functional_deprecated_v1(template)
        weights = transpose_stack(
            tuple(make_functional_deprecated_v1(model)[0] for model in models))
        return weights, fn, names
    return wrapped
|
|
|
|
|
|
def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
    """
    Build a factory that instantiates ``model_class`` in functional form,
    extracting both weights and buffers.

    Args:
        model_class: callable that constructs the module; the factory forwards
            its ``*args`` / ``**kwargs`` to it.
        ensemble_shape: ``()`` for a single model, or ``(num_models,)`` to
            build an ensemble whose weights and buffers are stacked along a
            new leading dimension. More than one dimension is not implemented.
        device: every instance constructed here, including the template used
            to produce ``fn``, is moved there with ``.to(device)``.

    Returns:
        A function ``wrapped(*args, **kwargs)`` returning
        ``(weights, buffers, fn, weight_names, buffer_names)`` in both the
        single-model and the ensemble case (as produced by
        ``make_functional_with_buffers_deprecated_v1``). For an ensemble, each
        weight/buffer is the detached stack across all models.

    Raises (from ``wrapped``):
        ValueError: if ``ensemble_shape`` has more than one element or its
            single element is not positive.
    """
    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError('NYI: ensemble_shape with more than 1 element')
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            # Must be the with-buffers variant: make_functional_deprecated_v1
            # rejects models that have buffers and returns only a 3-tuple.
            return make_functional_with_buffers_deprecated_v1(model)
        num_models = ensemble_shape[0]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(model_class(*args, **kwargs).to(device)
                       for _ in range(num_models))
        # A separate template instance supplies ``fn`` and the names.
        # Place it on ``device`` too, consistent with every other instance here.
        template = model_class(*args, **kwargs).to(device)
        _, _, fn, weight_names, buffer_names = \
            make_functional_with_buffers_deprecated_v1(template)
        weights, buffers = zip(*(make_functional_with_buffers_deprecated_v1(model)[:2]
                                 for model in models))
        return transpose_stack(weights), transpose_stack(buffers), fn, weight_names, buffer_names
    return wrapped
|