[functorch] [BC-breaking] Update make_functional* (pytorch/functorch#52)

Updates make_functional and make_functional_with_buffers to refer to the
new, improved variants. The new variants are superior in every way, so
we're replacing the previous variants with them.

If someone wants the older variants, they are still available as:
- make_functional_with_buffers_deprecated_v1
- make_functional_deprecated_v1
Richard Zou 2021-06-07 17:55:13 -04:00 committed by Jon Janzen
parent fdcc680c9d
commit b29e666ade
13 changed files with 79 additions and 68 deletions

View File

@ -95,8 +95,8 @@ Right now, we support the following transforms:
- `vmap`
Furthermore, we have some utilities for working with PyTorch modules.
- `make_functional_v2(model)`
- `make_functional_with_buffers_v2(model)`
- `make_functional(model)`
- `make_functional_with_buffers(model)`
### vmap
@ -225,7 +225,7 @@ We can also try compiling it with NNC (even more experimental)!.
Check `examples/nnc` for some example benchmarks.
### Working with NN modules: make_functional_v2 and friends
### Working with NN modules: make_functional and friends
Sometimes you may want to perform a transform with respect to the parameters
and/or buffers of an nn.Module. This can happen for example in:
@ -237,9 +237,9 @@ of the loss with respect to the model parameters
Our solution to this right now is an API that, given an nn.Module, creates a
stateless version of it that can be called like a function.
- `make_functional_v2(model)` returns a functional version of `model` and the
- `make_functional(model)` returns a functional version of `model` and the
`model.parameters()`
- `make_functional_with_buffers_v2(model)` returns a functional version of
- `make_functional_with_buffers(model)` returns a functional version of
`model` and the `model.parameters()` and `model.buffers()`.
Here's an example where we compute per-sample-gradients using an nn.Linear
@ -247,13 +247,13 @@ layer:
```py
import torch
from functorch import make_functional_v2, vmap, grad
from functorch import make_functional, vmap, grad
model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)
func_model, params = make_functional_v2(model)
func_model, params = make_functional(model)
def compute_loss(params, data, targets):
preds = func_model(params, data)
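
The hunk cuts the README example off mid-function. For context, here is one plausible completion (not part of this diff), assuming an MSE loss and using the usual grad-then-vmap per-sample-gradient pattern:

```py
# Sketch only, not part of this diff: a plausible completion of the truncated
# README example above, assuming an MSE loss.
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)
func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.nn.functional.mse_loss(preds, targets)

# grad differentiates with respect to params (the first argument); vmap maps
# over the batch dimension of data/targets while holding params fixed,
# yielding one gradient per sample.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)
```
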

View File

@ -27,7 +27,7 @@ from tqdm import tqdm
from functools import partial
import functorch
from functorch import vmap, grad_and_value
from functorch import make_functional_v2
from functorch import make_functional
# disable warning spam
functorch._C._set_vmap_fallback_warning_enabled(False)
@ -89,7 +89,7 @@ def train(args, model, train_loader, optimizer, epoch, device):
# In order to use functional vmap+grad, we need to be able to
# pass the weights to a model.
func_model, weights = make_functional_v2(model)
func_model, weights = make_functional(model)
# To use vmap+grad to compute per-sample-grads, the forward pass
# must be re-formulated on a single example.

View File

@ -2,7 +2,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional_v2, grad_and_value, vmap, combine_state_for_ensemble
from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble
# Adapted from http://willwhitney.com/parallel-training-jax.html
# The original code comes with the following citation:
@ -58,7 +58,7 @@ loss_fn = nn.NLLLoss()
# Step 3: Make the model functional(!!) and define a training function.
# NB: this mechanism doesn't exist in PyTorch today, but we want it to:
# https://github.com/pytorch/pytorch/issues/49171
func_model, weights = make_functional_v2(MLPClassifier().to(DEVICE))
func_model, weights = make_functional(MLPClassifier().to(DEVICE))
def train_step_fn(weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):

View File

@ -42,7 +42,7 @@ import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from functorch import make_functional_with_buffers_v2, vmap, grad
from functorch import make_functional_with_buffers, vmap, grad
import higher
@ -105,7 +105,7 @@ def main():
nn.Linear(64, args.n_way)).to(device)
net.train()
fnet, params, buffers = make_functional_with_buffers_v2(net)
fnet, params, buffers = make_functional_with_buffers(net)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.

View File

@ -45,7 +45,7 @@ import torch.nn.functional as F
import torch.optim as optim
import functorch
from functorch import make_functional_with_buffers_v2, vmap, grad
from functorch import make_functional_with_buffers, vmap, grad
# Squash the warning spam
functorch._C._set_vmap_fallback_warning_enabled(False)
@ -114,7 +114,7 @@ def main():
# Given this module we've created, rip out the parameters and buffers
# and return a functional version of the module. `fnet` is stateless
# and can be called with `fnet(params, buffers, args, kwargs)`
fnet, params, buffers = make_functional_with_buffers_v2(net)
fnet, params, buffers = make_functional_with_buffers(net)
# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.

View File

@ -20,7 +20,7 @@ import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from functorch import grad, vmap, make_functional_v2
from functorch import grad, vmap, make_functional
class ThreeLayerNet(nn.Module):
def __init__(self):
@ -43,7 +43,7 @@ class ThreeLayerNet(nn.Module):
def mse_loss(x, y):
return torch.mean((x - y) ** 2)
net, params = make_functional_v2(ThreeLayerNet())
net, params = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch import grad, vmap, pythonkey_trace, wrap_key, make_fx, nnc_jit, make_functional_v2, grad_and_value
from functorch import grad, vmap, pythonkey_trace, wrap_key, make_fx, nnc_jit, make_functional, grad_and_value
import torch
import torch.fx as fx
import torch.nn as nn
@ -41,7 +41,7 @@ mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)
func_model, weights = make_functional_v2(mod)
func_model, weights = make_functional(mod)
lr =1.0
def functional_step(x, weights):

View File

@ -10,10 +10,10 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
from ._src.make_functional import (
make_functional_with_buffers_v2,
make_functional_v2,
make_functional_with_buffers,
make_functional,
combine_state_for_ensemble,
)
from ._src.make_functional import functional_init, functional_init_with_buffers

View File

@ -15,7 +15,7 @@ from collections import namedtuple
import gc
from .vmap import vmap
from .make_functional import make_functional, make_functional_with_buffers
from .make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
from functorch._C import (
_wrap_for_grad,

View File

@ -9,6 +9,7 @@ import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
import copy
import warnings
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
@ -94,7 +95,7 @@ def load_state(
"""load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
load_state takes `weights` and `buffers` and assigns them to the model.
This is the inverse operation of `make_functional`.
This is the inverse operation of `make_functional_deprecated_v1`.
"""
assert len(weight_names) == len(weights)
load_weights(model, weight_names, weights)
@ -104,10 +105,10 @@ def load_state(
return model
def make_functional(model: nn.Module):
"""make_functional(model) -> weights, func, weight_names
def make_functional_deprecated_v1(model: nn.Module):
"""make_functional_deprecated_v1(model) -> weights, func, weight_names
Given an nn.Module, make_functional extracts the state (weights)
Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
and returns a functional version of the model, `func`. This makes
it so that it is possible to use transforms over the parameters of
`model`.
@ -116,7 +117,7 @@ def make_functional(model: nn.Module):
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, func, _ = make_functional(model)
weights, func, _ = make_functional_deprecated_v1(model)
func(weights, (x,))
```
@ -124,7 +125,7 @@ def make_functional(model: nn.Module):
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, _, func = make_functional(model)
weights, _, func = make_functional_deprecated_v1(model)
grad_weights = grad(func)(weights, (x,))
```
@ -132,8 +133,8 @@ def make_functional(model: nn.Module):
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError('make_functional(model): `model` has buffers. Please use '
'make_functional_with_buffers(model) instead.')
raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
'make_functional_with_buffers_deprecated_v1(model) instead.')
weights, descriptors = extract_weights(model)
def fun(weights, data):
@ -144,17 +145,17 @@ def make_functional(model: nn.Module):
return weights, fun, descriptors
def make_functional_with_buffers(model: nn.Module):
"""make_functional_with_buffers(model) -> weights, buffers, func, weight_names, buffer_names
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
"""make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names
Given an nn.Module, make_functional_with_buffers extracts the state (weights and buffers)
Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
and returns a functional version of the model, `func`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers(model)
weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
func(weights, buffers, (x,))
```
@ -162,7 +163,7 @@ def make_functional_with_buffers(model: nn.Module):
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers(model)
weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
func(weights, buffers, (x,))
grad_weights = grad(func)(weights, buffers, (x,))
```
@ -234,10 +235,10 @@ class FunctionalModule(nn.Module):
return stateful_model(*args, **kwargs)
def make_functional_v2(model: nn.Module):
"""make_functional_v2(model) -> func, weights
def make_functional(model: nn.Module):
"""make_functional(model) -> func, weights
Given an nn.Module, make_functional_v2 extracts the state (weights)
Given an nn.Module, make_functional extracts the state (weights)
and returns a functional version of the model, `func`. This makes
it so that it is possible to use transforms over the parameters of
`model`.
@ -246,11 +247,11 @@ def make_functional_v2(model: nn.Module):
```
import torch
import torch.nn as nn
from functorch import make_functional_v2
from functorch import make_functional
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional_v2(model)
func, params = make_functional(model)
func(params, x)
```
@ -258,12 +259,12 @@ def make_functional_v2(model: nn.Module):
```
import torch
import torch.nn as nn
from functorch import make_functional_v2, grad
from functorch import make_functional, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional_v2(model)
func, params = make_functional(model)
def compute_loss(params, x, t):
y = func(params, x)
@ -272,17 +273,22 @@ def make_functional_v2(model: nn.Module):
grad_weights = grad(compute_loss)(params, x, t)
```
"""
warnings.warn('If this is your first time using make_functional, please '
'ignore this warning. Otherwise, we recently made a '
'backwards incompatible change to make_functional: '
'please try make_functional_deprecated_v1 if you want the '
'previous behavior.', stacklevel=2)
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError('make_functional_v2(model): `model` has buffers. Please use '
'make_functional_with_buffers_v2(model) instead.')
raise RuntimeError('make_functional(model): `model` has buffers. Please use '
'make_functional_with_buffers(model) instead.')
return FunctionalModule._create_from(model)
def make_functional_with_buffers_v2(model: nn.Module):
"""make_functional_with_buffers_v2(model) -> func, params, buffers
def make_functional_with_buffers(model: nn.Module):
"""make_functional_with_buffers(model) -> func, params, buffers
Given an nn.Module, make_functional_with_buffers_v2 extracts the state
Given an nn.Module, make_functional_with_buffers extracts the state
(params and buffers) and returns a functional version of the model `func`
that can be invoked like a function.
@ -290,11 +296,11 @@ def make_functional_with_buffers_v2(model: nn.Module):
```
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers_v2
from functorch import make_functional_with_buffers
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers_v2(model)
func, params, buffers = make_functional_with_buffers(model)
func(params, buffers, x)
```
@ -302,12 +308,12 @@ def make_functional_with_buffers_v2(model: nn.Module):
```
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers_v2, grad
from functorch import make_functional_with_buffers, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers_v2(model)
func, params, buffers = make_functional_with_buffers(model)
def compute_loss(params, buffers, x, t):
y = func(params, buffers, x)
@ -316,6 +322,11 @@ def make_functional_with_buffers_v2(model: nn.Module):
grad_weights = grad(compute_loss)(params, buffers, x, t)
```
"""
warnings.warn('If this is your first time using make_functional_with_buffers, please '
'ignore this warning. Otherwise, we recently made a '
'backwards incompatible change to make_functional_with_buffers: '
'please try make_functional_with_buffers_deprecated_v1 if you want the '
'previous behavior.', stacklevel=2)
return FunctionalModuleWithBuffers._create_from(model)
@ -338,7 +349,7 @@ def combine_state_for_ensemble(models):
`func(params, buffers, *args, **kwargs)` directly, you probably want to
use vmap(func, ...)(params, buffers, *args, **kwargs)
"""
funcs, params, buffers = zip(*[make_functional_with_buffers_v2(model)
funcs, params, buffers = zip(*[make_functional_with_buffers(model)
for model in models])
params = transpose_stack(params)
buffers = transpose_stack(buffers)
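
For context, a usage sketch of `combine_state_for_ensemble` after the rename. This is an assumption-laden sketch, not part of the diff: it assumes the function returns a single callable plus params and buffers stacked along a leading ensemble dimension (as the docstring above implies), and the `in_dims` layout is likewise an assumption:

```py
# Sketch only, not part of this diff. Assumes combine_state_for_ensemble returns
# (func, params, buffers) with params/buffers stacked along a leading ensemble dim.
import torch
import torch.nn as nn
from functorch import vmap, combine_state_for_ensemble

models = [nn.Linear(3, 3) for _ in range(4)]
func, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(64, 3)
# vmap over the ensemble dimension of params/buffers, broadcasting the input x,
# to evaluate all four models in one call.
outputs = vmap(func, in_dims=(0, 0, None))(params, buffers, x)
```
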
@ -368,15 +379,15 @@ def functional_init(model_class, ensemble_shape=(), device='cpu'):
raise ValueError('NYI: ensemble_shape with more than 1 element')
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional(model)
return make_functional_deprecated_v1(model)
num_models = ensemble_shape[0]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
models = tuple(model_class(*args, **kwargs).to(device)
for _ in range(num_models))
_, fn, names = make_functional(model_class(*args, **kwargs))
weights = tuple(make_functional(model)[0] for model in models)
_, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
return weights, fn, names
@ -389,7 +400,7 @@ def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
raise ValueError('NYI: ensemble_shape with more than 1 element')
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional(model)
return make_functional_deprecated_v1(model)
num_models = ensemble_shape[0]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
@ -397,8 +408,8 @@ def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
models = tuple(model_class(*args, **kwargs).to(device)
for _ in range(num_models))
_, _, fn, weight_names, buffer_names = \
make_functional_with_buffers(model_class(*args, **kwargs))
weights, buffers = zip(*tuple(make_functional_with_buffers(model)[:2]
make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2]
for model in models))
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)

View File

@ -22,7 +22,7 @@ from functools import partial
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_and_value,
make_functional_v2, make_functional_with_buffers_v2,
make_functional, make_functional_with_buffers,
functional_init, functional_init_with_buffers,
)
@ -553,7 +553,7 @@ class TestVmapOfGrad(TestCase):
net = SampleNet(vocab_size).to(device=device)
criterion = nn.CrossEntropyLoss()
net_func, weights = make_functional_v2(net)
net_func, weights = make_functional(net)
def compute_loss(weights, data, target):
output = net_func(weights, data)
@ -721,7 +721,7 @@ class TestExamplesCorrectness(TestCase):
def mse_loss(x, y):
return torch.mean((x - y) ** 2)
net, params = make_functional_v2(ThreeLayerNet().to(device))
net, params = make_functional(ThreeLayerNet().to(device))
K = 20
losses = []
num_tasks = 4
@ -811,7 +811,7 @@ class TestExamplesCorrectness(TestCase):
Flatten(),
nn.Linear(64, n_way)).to(device).to(dtype)
fnet, params, buffers = make_functional_with_buffers_v2(net)
fnet, params, buffers = make_functional_with_buffers(net)
net = (params, buffers, fnet)
def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
@ -955,7 +955,7 @@ class TestExamplesCorrectness(TestCase):
loss_fn = nn.NLLLoss()
func_model, weights = make_functional_v2(MLPClassifier().to(device))
func_model, weights = make_functional(MLPClassifier().to(device))
def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
def compute_loss(weights, batch, targets):
@ -981,7 +981,7 @@ class TestExamplesCorrectness(TestCase):
def init_fn(num_models):
models = tuple(MLPClassifier().to(device) for _ in range(num_models))
weights = tuple(make_functional_v2(model)[1] for model in models)
weights = tuple(make_functional(model)[1] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
return weights
@ -1045,7 +1045,7 @@ class TestExamplesCorrectness(TestCase):
model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
criterion = nn.CrossEntropyLoss()
func_model, weights = make_functional_v2(model)
func_model, weights = make_functional(model)
def compute_loss(weights, image, target):
images = image.unsqueeze(0)

View File

@ -22,7 +22,7 @@ from functools import partial
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_and_value,
make_functional, make_functional_with_buffers, make_fx, nnc_jit
make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit
)
# NB: numpy is a testing dependency!

View File

@ -23,7 +23,7 @@ from common_utils import (
)
import types
from functorch import vmap, functional_init_with_buffers, make_functional_with_buffers
from functorch import vmap, functional_init_with_buffers, make_functional_with_buffers_deprecated_v1
from functorch._C import reshape_dim_into, reshape_dim_outof