[functorch] [BC-breaking] Update make_functional* (pytorch/functorch#52)
Updates make_functional* to the new, improved variants. The new variants are superior in every way, so they replace the previous ones outright. The older variants remain available under these names:
- make_functional_deprecated_v1
- make_functional_with_buffers_deprecated_v1
This commit is contained in:
parent fdcc680c9d
commit b29e666ade
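For orientation, here is a minimal sketch of the renamed API after this commit, mirroring the docstring examples in the diff below; the tensors and shapes are illustrative:

```py
import torch
from functorch import make_functional, grad

# New-style make_functional (formerly make_functional_v2):
# returns (func, params) rather than (weights, func, weight_names).
model = torch.nn.Linear(3, 3)
func, params = make_functional(model)

x = torch.randn(4, 3)
t = torch.randn(4, 3)

def compute_loss(params, x, t):
    y = func(params, x)
    return torch.mean((y - t) ** 2)

# Transforms such as grad operate over the extracted parameters.
grad_weights = grad(compute_loss)(params, x, t)

# The previous behavior survives as make_functional_deprecated_v1,
# which still returns (weights, func, weight_names).
```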
@@ -95,8 +95,8 @@ Right now, we support the following transforms:
 - `vmap`
 
 Furthermore, we have some utilities for working with PyTorch modules.
-- `make_functional_v2(model)`
-- `make_functional_with_buffers_v2(model)`
+- `make_functional(model)`
+- `make_functional_with_buffers(model)`
 
 ### vmap
 

@@ -225,7 +225,7 @@ We can also try compiling it with NNC (even more experimental)!.
 
 Check `examples/nnc` for some example benchmarks.
 
-### Working with NN modules: make_functional_v2 and friends
+### Working with NN modules: make_functional and friends
 
 Sometimes you may want to perform a transform with respect to the parameters
 and/or buffers of an nn.Module. This can happen for example in:

@@ -237,9 +237,9 @@ of the loss with respect to the model parameters
 Our solution to this right now is an API that, given an nn.Module, creates a
 stateless version of it that can be called like a function.
 
-- `make_functional_v2(model)` returns a functional version of `model` and the
+- `make_functional(model)` returns a functional version of `model` and the
   `model.parameters()`
-- `make_functional_with_buffers_v2(model)` returns a functional version of
+- `make_functional_with_buffers(model)` returns a functional version of
   `model` and the `model.parameters()` and `model.buffers()`.
 
 Here's an example where we compute per-sample-gradients using an nn.Linear

@@ -247,13 +247,13 @@ layer:
 
 ```py
 import torch
-from functorch import make_functional_v2, vmap, grad
+from functorch import make_functional, vmap, grad
 
 model = torch.nn.Linear(3, 3)
 data = torch.randn(64, 3)
 targets = torch.randn(64, 3)
 
-func_model, params = make_functional_v2(model)
+func_model, params = make_functional(model)
 
 def compute_loss(params, data, targets):
     preds = func_model(params, data)
@@ -27,7 +27,7 @@ from tqdm import tqdm
 from functools import partial
 import functorch
 from functorch import vmap, grad_and_value
-from functorch import make_functional_v2
+from functorch import make_functional
 
 # disable warning spam
 functorch._C._set_vmap_fallback_warning_enabled(False)

@@ -89,7 +89,7 @@ def train(args, model, train_loader, optimizer, epoch, device):
 
     # In order to use functional vmap+grad, we need to be able to
    # pass the weights to a model.
-    func_model, weights = make_functional_v2(model)
+    func_model, weights = make_functional(model)
 
     # To use vmap+grad to compute per-sample-grads, the forward pass
     # must be re-formulated on a single example.
@@ -2,7 +2,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functorch import make_functional_v2, grad_and_value, vmap, combine_state_for_ensemble
+from functorch import make_functional, grad_and_value, vmap, combine_state_for_ensemble
 
 # Adapted from http://willwhitney.com/parallel-training-jax.html
 # The original code comes with the following citation:

@@ -58,7 +58,7 @@ loss_fn = nn.NLLLoss()
 # Step 3: Make the model functional(!!) and define a training function.
 # NB: this mechanism doesn't exist in PyTorch today, but we want it to:
 # https://github.com/pytorch/pytorch/issues/49171
-func_model, weights = make_functional_v2(MLPClassifier().to(DEVICE))
+func_model, weights = make_functional(MLPClassifier().to(DEVICE))
 
 def train_step_fn(weights, batch, targets, lr=0.2):
     def compute_loss(weights, batch, targets):
@@ -42,7 +42,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 import torch.optim as optim
-from functorch import make_functional_with_buffers_v2, vmap, grad
+from functorch import make_functional_with_buffers, vmap, grad
 
 import higher
 

@@ -105,7 +105,7 @@ def main():
                     nn.Linear(64, args.n_way)).to(device)
 
     net.train()
-    fnet, params, buffers = make_functional_with_buffers_v2(net)
+    fnet, params, buffers = make_functional_with_buffers(net)
 
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
@@ -45,7 +45,7 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 import functorch
-from functorch import make_functional_with_buffers_v2, vmap, grad
+from functorch import make_functional_with_buffers, vmap, grad
 
 # Squash the warning spam
 functorch._C._set_vmap_fallback_warning_enabled(False)

@@ -114,7 +114,7 @@ def main():
     # Given this module we've created, rip out the parameters and buffers
     # and return a functional version of the module. `fnet` is stateless
     # and can be called with `fnet(params, buffers, args, kwargs)`
-    fnet, params, buffers = make_functional_with_buffers_v2(net)
+    fnet, params, buffers = make_functional_with_buffers(net)
 
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
@@ -20,7 +20,7 @@ import matplotlib as mpl
 mpl.use('Agg')
 import matplotlib.pyplot as plt
 
-from functorch import grad, vmap, make_functional_v2
+from functorch import grad, vmap, make_functional
 
 class ThreeLayerNet(nn.Module):
     def __init__(self):

@@ -43,7 +43,7 @@ class ThreeLayerNet(nn.Module):
 def mse_loss(x, y):
     return torch.mean((x - y) ** 2)
 
-net, params = make_functional_v2(ThreeLayerNet())
+net, params = make_functional(ThreeLayerNet())
 opt = torch.optim.Adam(params, lr=1e-3)
 alpha = 0.1
 
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from functorch import grad, vmap, pythonkey_trace, wrap_key, make_fx, nnc_jit, make_functional_v2, grad_and_value
+from functorch import grad, vmap, pythonkey_trace, wrap_key, make_fx, nnc_jit, make_functional, grad_and_value
 import torch
 import torch.fx as fx
 import torch.nn as nn

@@ -41,7 +41,7 @@ mod = Foo(num_layers, features)
 
 jit_mod = torch.jit.script(mod)
 
-func_model, weights = make_functional_v2(mod)
+func_model, weights = make_functional(mod)
 lr =1.0
 
 def functional_step(x, weights):
@@ -10,10 +10,10 @@ from . import _C
 
 from ._src.vmap import vmap
 from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
-from ._src.make_functional import make_functional, make_functional_with_buffers, load_state
+from ._src.make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
 from ._src.make_functional import (
-    make_functional_with_buffers_v2,
-    make_functional_v2,
+    make_functional_with_buffers,
+    make_functional,
     combine_state_for_ensemble,
 )
 from ._src.make_functional import functional_init, functional_init_with_buffers
@@ -15,7 +15,7 @@ from collections import namedtuple
 import gc
 
 from .vmap import vmap
-from .make_functional import make_functional, make_functional_with_buffers
+from .make_functional import make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1
 
 from functorch._C import (
     _wrap_for_grad,
@@ -9,6 +9,7 @@ import torch.nn as nn
 from torch import Tensor
 from typing import List, Tuple
 import copy
+import warnings
 
 # Utilities to make nn.Module "functional"
 # In particular the goal is to be able to provide a function that takes as input
@@ -94,7 +95,7 @@ def load_state(
     """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
 
     load_state takes `weights` and `buffers` and assigns them to the model.
-    This is the inverse operation of `make_functional`.
+    This is the inverse operation of `make_functional_deprecated_v1`.
     """
     assert len(weight_names) == len(weights)
     load_weights(model, weight_names, weights)
@@ -104,10 +105,10 @@ def load_state(
     return model
 
 
-def make_functional(model: nn.Module):
-    """make_functional(model) -> weights, func, weight_names
+def make_functional_deprecated_v1(model: nn.Module):
+    """make_functional_deprecated_v1(model) -> weights, func, weight_names
 
-    Given an nn.Module, make_functional extracts the state (weights)
+    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
     and returns a functional version of the model, `func`. This makes
     it so that it is possible use transforms over the parameters of
     `model`.
@@ -116,7 +117,7 @@ def make_functional(model: nn.Module):
     ```
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    weights, func, _ = make_functional(model)
+    weights, func, _ = make_functional_deprecated_v1(model)
     func(weights, (x,))
     ```
 
@@ -124,7 +125,7 @@ def make_functional(model: nn.Module):
     ```
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    weights, _, func = make_functional(model)
+    weights, _, func = make_functional_deprecated_v1(model)
     grad_weights = grad(func)(weights, (x,))
     ```
 
@@ -132,8 +133,8 @@ def make_functional(model: nn.Module):
     """
     buffers = list(model.buffers())
     if len(buffers) > 0:
-        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
-                           'make_functional_with_buffers(model) instead.')
+        raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
+                           'make_functional_with_buffers_deprecated_v1(model) instead.')
     weights, descriptors = extract_weights(model)
 
     def fun(weights, data):
@@ -144,17 +145,17 @@ def make_functional(model: nn.Module):
     return weights, fun, descriptors
 
 
-def make_functional_with_buffers(model: nn.Module):
-    """make_functional_with_buffers(model) -> weights, buffers, func, weight_names, buffer_names
+def make_functional_with_buffers_deprecated_v1(model: nn.Module):
+    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names
 
-    Given an nn.Module, make_functional_with_buffers extracts the state (weights and buffers)
+    Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
     and returns a functional version of the model, `func`.
 
     `func` can be invoked as follows:
     ```
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    weights, buffers, func, _, _ = make_functional_with_buffers(model)
+    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
     func(weights, buffers, (x,))
     ```
 
@@ -162,7 +163,7 @@ def make_functional_with_buffers(model: nn.Module):
     ```
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    weights, buffers, func, _, _ = make_functional_with_buffers(model)
+    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
     func(weights, buffers, (x,))
     grad_weights = grad(func)(weights, buffers, (x,))
     ```
@@ -234,10 +235,10 @@ class FunctionalModule(nn.Module):
         return stateful_model(*args, **kwargs)
 
 
-def make_functional_v2(model: nn.Module):
-    """make_functional_v2(model) -> func, weights
+def make_functional(model: nn.Module):
+    """make_functional(model) -> func, weights
 
-    Given an nn.Module, make_functional_v2 extracts the state (weights)
+    Given an nn.Module, make_functional extracts the state (weights)
     and returns a functional version of the model, `func`. This makes
     it so that it is possible use transforms over the parameters of
     `model`.
@@ -246,11 +247,11 @@ def make_functional_v2(model: nn.Module):
     ```
     import torch
     import torch.nn as nn
-    from functorch import make_functional_v2
+    from functorch import make_functional
 
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    func, params = make_functional_v2(model)
+    func, params = make_functional(model)
     func(params, x)
     ```
 
@@ -258,12 +259,12 @@ def make_functional_v2(model: nn.Module):
     ```
     import torch
     import torch.nn as nn
-    from functorch import make_functional_v2, grad
+    from functorch import make_functional, grad
 
     x = torch.randn(4, 3)
     t = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    func, params = make_functional_v2(model)
+    func, params = make_functional(model)
 
     def compute_loss(params, x, t):
         y = func(params, x)
@@ -272,17 +273,22 @@ def make_functional_v2(model: nn.Module):
     grad_weights = grad(compute_loss)(params, x, t)
     ```
     """
+    warnings.warn('If this is your first time using make_functional, please '
+                  'ignore this warning. Otherwise, we recently made a '
+                  'backwards incompatible change to make_functional: '
+                  'please try make_functional_deprecated_v1 if you want the '
+                  'previous behavior.', stacklevel=2)
     buffers = list(model.buffers())
     if len(buffers) > 0:
-        raise RuntimeError('make_functional_v2(model): `model` has buffers. Please use '
-                           'make_functional_with_buffers_v2(model) instead.')
+        raise RuntimeError('make_functional(model): `model` has buffers. Please use '
+                           'make_functional_with_buffers(model) instead.')
     return FunctionalModule._create_from(model)
 
 
-def make_functional_with_buffers_v2(model: nn.Module):
-    """make_functional_with_buffers_v2(model) -> func, params, buffers
+def make_functional_with_buffers(model: nn.Module):
+    """make_functional_with_buffers(model) -> func, params, buffers
 
-    Given an nn.Module, make_functional_with_buffers_v2 extracts the state
+    Given an nn.Module, make_functional_with_buffers extracts the state
     (params and buffers) and returns a functional version of the model `func`
     that can be invoked like a function.
 
@@ -290,11 +296,11 @@ def make_functional_with_buffers_v2(model: nn.Module):
     ```
     import torch
     import torch.nn as nn
-    from functorch import make_functional_with_buffers_v2
+    from functorch import make_functional_with_buffers
 
     x = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    func, params, buffers = make_functional_with_buffers_v2(model)
+    func, params, buffers = make_functional_with_buffers(model)
     func(params, buffers, x)
     ```
 
@@ -302,12 +308,12 @@ def make_functional_with_buffers_v2(model: nn.Module):
     ```
     import torch
     import torch.nn as nn
-    from functorch import make_functional_with_buffers_v2, grad
+    from functorch import make_functional_with_buffers, grad
 
     x = torch.randn(4, 3)
     t = torch.randn(4, 3)
     model = nn.Linear(3, 3)
-    func, params, buffers = make_functional_with_buffers_v2(model)
+    func, params, buffers = make_functional_with_buffers(model)
 
     def compute_loss(params, buffers, x, t):
         y = func(params, buffers, x)
@@ -316,6 +322,11 @@ def make_functional_with_buffers_v2(model: nn.Module):
     grad_weights = grad(compute_loss)(params, buffers, x, t)
     ```
     """
+    warnings.warn('If this is your first time using make_functional_with_buffers, please '
+                  'ignore this warning. Otherwise, we recently made a '
+                  'backwards incompatible change to make_functional_with_buffers: '
+                  'please try make_functional_with_buffers_deprecated_v1 if you want the '
+                  'previous behavior.', stacklevel=2)
     return FunctionalModuleWithBuffers._create_from(model)
 
 
@@ -338,7 +349,7 @@ def combine_state_for_ensemble(models):
     `func(params, buffers, *args, **kwargs)` directly, you probably want to
     use vmap(func, ...)(params, buffers, *args, **kwargs)
     """
-    funcs, params, buffers = zip(*[make_functional_with_buffers_v2(model)
+    funcs, params, buffers = zip(*[make_functional_with_buffers(model)
                                    for model in models])
     params = transpose_stack(params)
     buffers = transpose_stack(buffers)
@@ -368,15 +379,15 @@ def functional_init(model_class, ensemble_shape=(), device='cpu'):
         raise ValueError('NYI: ensemble_shape with more than 1 element')
     if len(ensemble_shape) == 0:
         model = model_class(*args, **kwargs).to(device)
-        return make_functional(model)
+        return make_functional_deprecated_v1(model)
     num_models = ensemble_shape[0]
     if num_models <= 0:
         raise ValueError(f"num_models {num_models} should be > 0")
     # NB: Not very efficient, more of a POC
     models = tuple(model_class(*args, **kwargs).to(device)
                    for _ in range(num_models))
-    _, fn, names = make_functional(model_class(*args, **kwargs))
-    weights = tuple(make_functional(model)[0] for model in models)
+    _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
+    weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
     weights = tuple(zip(*weights))
     weights = tuple(torch.stack(shards).detach() for shards in weights)
     return weights, fn, names
@@ -389,7 +400,7 @@ def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
         raise ValueError('NYI: ensemble_shape with more than 1 element')
     if len(ensemble_shape) == 0:
         model = model_class(*args, **kwargs).to(device)
-        return make_functional(model)
+        return make_functional_deprecated_v1(model)
     num_models = ensemble_shape[0]
     if num_models <= 0:
         raise ValueError(f"num_models {num_models} should be > 0")
@@ -397,8 +408,8 @@ def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
     models = tuple(model_class(*args, **kwargs).to(device)
                    for _ in range(num_models))
     _, _, fn, weight_names, buffer_names = \
-        make_functional_with_buffers(model_class(*args, **kwargs))
-    weights, buffers = zip(*tuple(make_functional_with_buffers(model)[:2]
+        make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
+    weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2]
                                   for model in models))
     weights = tuple(zip(*weights))
     weights = tuple(torch.stack(shards).detach() for shards in weights)
@@ -22,7 +22,7 @@ from functools import partial
 import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
-    make_functional_v2, make_functional_with_buffers_v2,
+    make_functional, make_functional_with_buffers,
     functional_init, functional_init_with_buffers,
 )
 
@@ -553,7 +553,7 @@ class TestVmapOfGrad(TestCase):
         net = SampleNet(vocab_size).to(device=device)
         criterion = nn.CrossEntropyLoss()
 
-        net_func, weights = make_functional_v2(net)
+        net_func, weights = make_functional(net)
 
         def compute_loss(weights, data, target):
             output = net_func(weights, data)
@@ -721,7 +721,7 @@ class TestExamplesCorrectness(TestCase):
         def mse_loss(x, y):
             return torch.mean((x - y) ** 2)
 
-        net, params = make_functional_v2(ThreeLayerNet().to(device))
+        net, params = make_functional(ThreeLayerNet().to(device))
         K = 20
         losses = []
         num_tasks = 4
@@ -811,7 +811,7 @@ class TestExamplesCorrectness(TestCase):
                             Flatten(),
                             nn.Linear(64, n_way)).to(device).to(dtype)
 
-        fnet, params, buffers = make_functional_with_buffers_v2(net)
+        fnet, params, buffers = make_functional_with_buffers(net)
         net = (params, buffers, fnet)
 
         def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
@@ -955,7 +955,7 @@ class TestExamplesCorrectness(TestCase):
 
         loss_fn = nn.NLLLoss()
 
-        func_model, weights = make_functional_v2(MLPClassifier().to(device))
+        func_model, weights = make_functional(MLPClassifier().to(device))
 
         def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
             def compute_loss(weights, batch, targets):
@@ -981,7 +981,7 @@ class TestExamplesCorrectness(TestCase):
 
         def init_fn(num_models):
             models = tuple(MLPClassifier().to(device) for _ in range(num_models))
-            weights = tuple(make_functional_v2(model)[1] for model in models)
+            weights = tuple(make_functional(model)[1] for model in models)
             weights = tuple(zip(*weights))
             weights = tuple(torch.stack(shards).detach() for shards in weights)
             return weights
@@ -1045,7 +1045,7 @@ class TestExamplesCorrectness(TestCase):
         model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
         criterion = nn.CrossEntropyLoss()
 
-        func_model, weights = make_functional_v2(model)
+        func_model, weights = make_functional(model)
 
         def compute_loss(weights, image, target):
             images = image.unsqueeze(0)
@@ -22,7 +22,7 @@ from functools import partial
 import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
-    make_functional, make_functional_with_buffers, make_fx, nnc_jit
+    make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit
 )
 
 # NB: numpy is a testing dependency!
@@ -23,7 +23,7 @@ from common_utils import (
 )
 import types
 
-from functorch import vmap, functional_init_with_buffers, make_functional_with_buffers
+from functorch import vmap, functional_init_with_buffers, make_functional_with_buffers_deprecated_v1
 from functorch._C import reshape_dim_into, reshape_dim_outof
 
 
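Finally, a companion sketch of the buffer-aware variant (formerly make_functional_with_buffers_v2). The BatchNorm model here is an illustrative assumption, chosen because BatchNorm registers buffers for its running statistics:

```py
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers

# A module with buffers: BatchNorm1d registers running_mean/running_var.
model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
func, params, buffers = make_functional_with_buffers(model)

x = torch.randn(4, 3)
out = func(params, buffers, x)

# Plain make_functional raises a RuntimeError on models with buffers
# and points you at make_functional_with_buffers instead.
```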