Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Added functorch to functional_autograd_benchmark
Description:

- Following https://github.com/pytorch/functorch/issues/497, this adds an option to run the benchmarks with functorch and compare the results against the original functional autograd ones. Running the benchmark gives the table below:

<details>
<summary> Table </summary>

```
| model | task | mean | var |
| -- | -- | -- | -- |
| resnet18 | vjp | 0.03826599195599556 | 4.3332115637895186e-06 |
| resnet18 | functorch vjp | 0.037201929837465286 | 6.139693198292662e-09 |
| resnet18 | vhp | 0.2202976644039154 | 2.8687209052691287e-08 |
| resnet18 | functorch vhp | 0.22117868065834045 | 4.108771278765744e-08 |
| resnet18 | jvp | 0.18679651618003845 | 1.832455254202614e-08 |
| resnet18 | functorch jvp | 0.05305683612823486 | 1.6690266946284282e-08 |
| fcn_resnet | vjp | 0.6071907877922058 | 7.436695454998699e-07 |
| fcn_resnet | functorch vjp | 0.6115708947181702 | 1.121692207561864e-06 |
| fcn_resnet | vhp | 3.419469118118286 | 0.020633839070796967 |
| fcn_resnet | jvp | 2.5421929359436035 | 3.1765587209520163e-06 |
| fcn_resnet | functorch jvp | 0.7628333568572998 | 1.4555752159139956e-07 |
| detr | vjp | 0.19494840502738953 | 1.9122715457342565e-05 |
| detr | vhp | 1.1664292812347412 | 0.000948643428273499 |
| detr | jvp | 0.9990308880805969 | 1.0214127541985363e-05 |
| ppl_simple_reg | vjp | 0.0007535457843914628 | 6.024204690646684e-09 |
| ppl_simple_reg | functorch vjp | 0.0016954183811321855 | 1.160151974488599e-08 |
| ppl_simple_reg | vhp | 0.0011888503795489669 | 5.93119386937957e-10 |
| ppl_simple_reg | functorch vhp | 0.0026826143730431795 | 1.6787025103326414e-08 |
| ppl_simple_reg | jvp | 0.001067900680936873 | 7.409912128331086e-10 |
| ppl_simple_reg | functorch jvp | 0.002065300941467285 | 9.710328185974504e-08 |
| ppl_simple_reg | hvp | 0.001212477684020996 | 1.974137298077494e-09 |
| ppl_simple_reg | functorch hvp | 0.00482442369684577 | 2.327668653379078e-07 |
| ppl_simple_reg | jacobian | 0.0009108781814575195 | 3.489469158068914e-09 |
| ppl_simple_reg | functorch jacobian | 0.0019866942893713713 | 1.938326299466553e-08 |
| ppl_simple_reg | hessian | 0.005053090862929821 | 3.370298600202659e-07 |
| ppl_simple_reg | functorch hessian | 0.006374978926032782 | 7.556796077778927e-08 |
| ppl_simple_reg | hessian_fwdrev | 0.0036706924438476562 | 1.996075527088692e-09 |
| ppl_simple_reg | functorch hessian_fwdrev | 0.0058908225037157536 | 7.548283775804521e-08 |
| ppl_simple_reg | hessian_revrev | 0.0015769004821777344 | 1.5754418214442012e-08 |
| ppl_simple_reg | functorch hessian_revrev | 0.0041002752259373665 | 6.713568723171193e-08 |
| ppl_simple_reg | jacfwd | 0.0018048763740807772 | 2.7375660849315864e-08 |
| ppl_simple_reg | functorch jacfwd | 0.002047991845756769 | 2.432247070416338e-09 |
| ppl_simple_reg | jacrev | 0.0009733677143231034 | 1.0078769818733235e-08 |
| ppl_simple_reg | functorch jacrev | 0.0021971464157104492 | 1.2729884701911942e-08 |
| ppl_robust_reg | vjp | 0.005820560269057751 | 8.582588151284654e-08 |
| ppl_robust_reg | functorch vjp | 0.00796132069081068 | 9.663100541956737e-09 |
| ppl_robust_reg | vhp | 0.009825301356613636 | 2.0081762386325863e-07 |
| ppl_robust_reg | functorch vhp | 0.014890861697494984 | 4.558066279969353e-07 |
| ppl_robust_reg | jvp | 0.008297419175505638 | 2.9454400873873965e-07 |
| ppl_robust_reg | functorch jvp | 0.008052706718444824 | 7.120377176761394e-08 |
| ppl_robust_reg | hvp | 0.015414690598845482 | 7.42123745567369e-07 |
| ppl_robust_reg | functorch hvp | 0.02699306048452854 | 1.4650488537881756e-06 |
| ppl_robust_reg | jacobian | 0.006207776255905628 | 1.7068457225377642e-07 |
| ppl_robust_reg | functorch jacobian | 0.009173822589218616 | 1.2214455580306094e-07 |
| ppl_robust_reg | hessian | 0.04670915752649307 | 1.4299343092716299e-05 |
| ppl_robust_reg | functorch hessian | 0.02337808534502983 | 3.0397418413485866e-06 |
| ppl_robust_reg | hessian_fwdrev | 0.024229884147644043 | 2.0425247839739313e-06 |
| ppl_robust_reg | functorch hessian_fwdrev | 0.022021746262907982 | 3.512146236062108e-07 |
| ppl_robust_reg | hessian_revrev | 0.012355780228972435 | 7.090877147675201e-07 |
| ppl_robust_reg | functorch hessian_revrev | 0.013960313983261585 | 6.326549737423193e-07 |
| ppl_robust_reg | jacfwd | 0.008112502284348011 | 2.88503088086145e-08 |
| ppl_robust_reg | functorch jacfwd | 0.008947920985519886 | 4.2070990247111695e-08 |
| ppl_robust_reg | jacrev | 0.00635871896520257 | 1.3403841592207755e-07 |
| ppl_robust_reg | functorch jacrev | 0.009123563766479492 | 2.677554675756255e-07 |
| wav2letter | vjp | 0.02078995667397976 | 2.1110793113621185e-06 |
| wav2letter | functorch vjp | 0.019202351570129395 | 9.210506135559626e-09 |
| wav2letter | vhp | 0.05997290462255478 | 8.558587616391833e-09 |
| wav2letter | functorch vhp | 0.06035261228680611 | 1.6448565842708263e-09 |
| wav2letter | jvp | 0.04507789760828018 | 1.5771547401399744e-09 |
| wav2letter | functorch jvp | 0.013057494536042213 | 3.804750292601966e-09 |
| deepspeech | vjp | 0.3648746609687805 | 1.5359055396402255e-05 |
| transformer | vjp | 0.05496881157159805 | 1.242562319703211e-08 |
| transformer | functorch vjp | 0.057835936546325684 | 2.6113376350167528e-08 |
| transformer | vhp | 0.18313491344451904 | 7.226336151688884e-08 |
| transformer | jvp | 0.13924935460090637 | 1.6989159234981344e-07 |
| multiheadattn | vjp | 0.0014708995586261153 | 3.710916729460223e-08 |
| multiheadattn | functorch vjp | 0.002404856728389859 | 2.1910574687922235e-08 |
| multiheadattn | vhp | 0.003382015274837613 | 5.3098595742540056e-08 |
| multiheadattn | functorch vhp | 0.005340623669326305 | 5.897558708056749e-08 |
| multiheadattn | jvp | 0.0027526854537427425 | 3.508620949332908e-08 |
| multiheadattn | functorch jvp | 0.0022981404326856136 | 1.327894807445773e-07 |
```
</details>

<details>
<summary> Stdout </summary>

```
Found functorch: 0.2.0a0+386a541
Results for model resnet18 on task vjp: 0.03826599195599556s (var: 4.3332115637895186e-06)
Results for model resnet18 on task vjp using Functorch: 0.037201929837465286s (var: 6.139693198292662e-09)
Results for model resnet18 on task vhp: 0.2202976644039154s (var: 2.8687209052691287e-08)
Results for model resnet18 on task vhp using Functorch: 0.22117868065834045s (var: 4.108771278765744e-08)
Results for model resnet18 on task jvp: 0.18679651618003845s (var: 1.832455254202614e-08)
Results for model resnet18 on task jvp using Functorch: 0.05305683612823486s (var: 1.6690266946284282e-08)
Results for model fcn_resnet on task vjp: 0.6071907877922058s (var: 7.436695454998699e-07)
Results for model fcn_resnet on task vjp using Functorch: 0.6115708947181702s (var: 1.121692207561864e-06)
Results for model fcn_resnet on task vhp: 3.419469118118286s (var: 0.020633839070796967)
Failed model using Functorch: fcn_resnet, task: vhp, Error message: CUDA out of memory. Tried to allocate 114.00 MiB (GPU 0; 47.46 GiB total capacity; 45.62 GiB already allocated; 5.31 MiB free; 46.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Results for model fcn_resnet on task jvp: 2.5421929359436035s (var: 3.1765587209520163e-06)
Results for model fcn_resnet on task jvp using Functorch: 0.7628333568572998s (var: 1.4555752159139956e-07)
Results for model detr on task vjp: 0.19494840502738953s (var: 1.9122715457342565e-05)
Failed model using Functorch: detr, task: vjp, Error message: Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task vhp: 1.1664292812347412s (var: 0.000948643428273499)
Failed model using Functorch: detr, task: vhp, Error message: Cannot access data pointer of Tensor that doesn't have storage
Results for model detr on task jvp: 0.9990308880805969s (var: 1.0214127541985363e-05)
Failed model using Functorch: detr, task: jvp, Error message: Trying to use forward AD with _cdist_forward that does not support it because it has not been implemented yet. Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model ppl_simple_reg on task vjp: 0.0007535457843914628s (var: 6.024204690646684e-09)
Results for model ppl_simple_reg on task vjp using Functorch: 0.0016954183811321855s (var: 1.160151974488599e-08)
Results for model ppl_simple_reg on task vhp: 0.0011888503795489669s (var: 5.93119386937957e-10)
Results for model ppl_simple_reg on task vhp using Functorch: 0.0026826143730431795s (var: 1.6787025103326414e-08)
Results for model ppl_simple_reg on task jvp: 0.001067900680936873s (var: 7.409912128331086e-10)
Results for model ppl_simple_reg on task jvp using Functorch: 0.002065300941467285s (var: 9.710328185974504e-08)
Results for model ppl_simple_reg on task hvp: 0.001212477684020996s (var: 1.974137298077494e-09)
Results for model ppl_simple_reg on task hvp using Functorch: 0.00482442369684577s (var: 2.327668653379078e-07)
Results for model ppl_simple_reg on task jacobian: 0.0009108781814575195s (var: 3.489469158068914e-09)
Results for model ppl_simple_reg on task jacobian using Functorch: 0.0019866942893713713s (var: 1.938326299466553e-08)
Results for model ppl_simple_reg on task hessian: 0.005053090862929821s (var: 3.370298600202659e-07)
Results for model ppl_simple_reg on task hessian using Functorch: 0.006374978926032782s (var: 7.556796077778927e-08)
Results for model ppl_simple_reg on task hessian_fwdrev: 0.0036706924438476562s (var: 1.996075527088692e-09)
Results for model ppl_simple_reg on task hessian_fwdrev using Functorch: 0.0058908225037157536s (var: 7.548283775804521e-08)
Results for model ppl_simple_reg on task hessian_revrev: 0.0015769004821777344s (var: 1.5754418214442012e-08)
Results for model ppl_simple_reg on task hessian_revrev using Functorch: 0.0041002752259373665s (var: 6.713568723171193e-08)
Results for model ppl_simple_reg on task jacfwd: 0.0018048763740807772s (var: 2.7375660849315864e-08)
Results for model ppl_simple_reg on task jacfwd using Functorch: 0.002047991845756769s (var: 2.432247070416338e-09)
Results for model ppl_simple_reg on task jacrev: 0.0009733677143231034s (var: 1.0078769818733235e-08)
Results for model ppl_simple_reg on task jacrev using Functorch: 0.0021971464157104492s (var: 1.2729884701911942e-08)
Results for model ppl_robust_reg on task vjp: 0.005820560269057751s (var: 8.582588151284654e-08)
Results for model ppl_robust_reg on task vjp using Functorch: 0.00796132069081068s (var: 9.663100541956737e-09)
Results for model ppl_robust_reg on task vhp: 0.009825301356613636s (var: 2.0081762386325863e-07)
Results for model ppl_robust_reg on task vhp using Functorch: 0.014890861697494984s (var: 4.558066279969353e-07)
Results for model ppl_robust_reg on task jvp: 0.008297419175505638s (var: 2.9454400873873965e-07)
Results for model ppl_robust_reg on task jvp using Functorch: 0.008052706718444824s (var: 7.120377176761394e-08)
Results for model ppl_robust_reg on task hvp: 0.015414690598845482s (var: 7.42123745567369e-07)
Results for model ppl_robust_reg on task hvp using Functorch: 0.02699306048452854s (var: 1.4650488537881756e-06)
Results for model ppl_robust_reg on task jacobian: 0.006207776255905628s (var: 1.7068457225377642e-07)
Results for model ppl_robust_reg on task jacobian using Functorch: 0.009173822589218616s (var: 1.2214455580306094e-07)
Results for model ppl_robust_reg on task hessian: 0.04670915752649307s (var: 1.4299343092716299e-05)
Results for model ppl_robust_reg on task hessian using Functorch: 0.02337808534502983s (var: 3.0397418413485866e-06)
Results for model ppl_robust_reg on task hessian_fwdrev: 0.024229884147644043s (var: 2.0425247839739313e-06)
Results for model ppl_robust_reg on task hessian_fwdrev using Functorch: 0.022021746262907982s (var: 3.512146236062108e-07)
Results for model ppl_robust_reg on task hessian_revrev: 0.012355780228972435s (var: 7.090877147675201e-07)
Results for model ppl_robust_reg on task hessian_revrev using Functorch: 0.013960313983261585s (var: 6.326549737423193e-07)
Results for model ppl_robust_reg on task jacfwd: 0.008112502284348011s (var: 2.88503088086145e-08)
Results for model ppl_robust_reg on task jacfwd using Functorch: 0.008947920985519886s (var: 4.2070990247111695e-08)
Results for model ppl_robust_reg on task jacrev: 0.00635871896520257s (var: 1.3403841592207755e-07)
Results for model ppl_robust_reg on task jacrev using Functorch: 0.009123563766479492s (var: 2.677554675756255e-07)
Results for model wav2letter on task vjp: 0.02078995667397976s (var: 2.1110793113621185e-06)
Results for model wav2letter on task vjp using Functorch: 0.019202351570129395s (var: 9.210506135559626e-09)
Results for model wav2letter on task vhp: 0.05997290462255478s (var: 8.558587616391833e-09)
Results for model wav2letter on task vhp using Functorch: 0.06035261228680611s (var: 1.6448565842708263e-09)
Results for model wav2letter on task jvp: 0.04507789760828018s (var: 1.5771547401399744e-09)
Results for model wav2letter on task jvp using Functorch: 0.013057494536042213s (var: 3.804750292601966e-09)
Results for model deepspeech on task vjp: 0.3648746609687805s (var: 1.5359055396402255e-05)
Failed model using Functorch: deepspeech, task: vjp, Error message: Cannot access storage of TensorWrapper
Results for model transformer on task vjp: 0.05496881157159805s (var: 1.242562319703211e-08)
Results for model transformer on task vjp using Functorch: 0.057835936546325684s (var: 2.6113376350167528e-08)
Results for model transformer on task vhp: 0.18313491344451904s (var: 7.226336151688884e-08)
Failed model using Functorch: transformer, task: vhp, Error message: bad optional access
Results for model transformer on task jvp: 0.13924935460090637s (var: 1.6989159234981344e-07)
Failed model using Functorch: transformer, task: jvp, Error message: Trying to use forward AD with embedding that does not support it because it has not been implemented yet. Please file an issue to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml so that we can prioritize its implementation.
Results for model multiheadattn on task vjp: 0.0014708995586261153s (var: 3.710916729460223e-08)
Results for model multiheadattn on task vjp using Functorch: 0.002404856728389859s (var: 2.1910574687922235e-08)
Results for model multiheadattn on task vhp: 0.003382015274837613s (var: 5.3098595742540056e-08)
Results for model multiheadattn on task vhp using Functorch: 0.005340623669326305s (var: 5.897558708056749e-08)
Results for model multiheadattn on task jvp: 0.0027526854537427425s (var: 3.508620949332908e-08)
Results for model multiheadattn on task jvp using Functorch: 0.0022981404326856136s (var: 1.327894807445773e-07)
```
</details>

All functorch errors are reported in its repository.

cc @zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75689
Approved by: https://github.com/zou3519
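(Aside, not part of the PR text.) What the new option automates is essentially computing the same quantity once with `torch.autograd.functional` and once with functorch, checking that the two agree, and then timing each path. A minimal sketch of that comparison on a toy scalar function, assuming functorch is installed (the names below are illustrative, not taken from the benchmark):

```python
# Illustrative only: compare a vjp computed by torch.autograd.functional
# against the same vjp computed by functorch.
import torch
import torch.autograd.functional as AF
import functorch as ft  # assumes functorch is installed


def f(x):
    return (x ** 2).sum()


x = torch.randn(3)
v = torch.ones(())  # cotangent matching the scalar output

out_af, vjp_af = AF.vjp(f, x, v)   # returns (output, vjp) as tensors
out_ft, vjp_fn = ft.vjp(f, x)      # returns (output, vjp function)
vjp_ft = vjp_fn(v)                 # tuple with one tangent per primal

torch.testing.assert_close(out_af, out_ft)
torch.testing.assert_close(vjp_af, vjp_ft[0])
```

The benchmark performs a check of this kind during warmup (see `run_once_functorch` in the diff below, which calls `torch.testing.assert_close` when `maybe_check_consistency=True`) and only then times the functorch path.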
This commit is contained in:
parent b3aa2de5be
commit 6593d293f7
@@ -20,6 +20,10 @@ export OMP_NUM_THREADS=10
 git checkout master
 python setup.py develop

+# Install dependencies:
+# Scipy is required by detr
+pip install scipy
+
 # Run the benchmark for the base
 # This will use the GPU if available.
 pushd benchmarks/functional_autograd_benchmark

@@ -46,3 +50,18 @@ popd
 - `compare.py` is the entry point to run the comparison script that generates a markdown table.
 - `torchaudio_models.py` and `torchvision_models.py` contains code extracted from torchaudio and torchvision to be able to run the models without having a specific version of these libraries installed.
 - `ppl_models.py`, `vision_models.py` and `audio_text_models.py` contain all the getter functions used for the benchmark.
+
+### Benchmarking against `functorch`
+
+```bash
+# Install stable functorch:
+pip install functorch
+# or install from source:
+pip install git+https://github.com/pytorch/functorch
+
+# Run the benchmark for the base
+# This will use the GPU if available.
+pushd benchmarks/functional_autograd_benchmark
+python functional_autograd_benchmark.py --output bench-with-functorch.txt
+```
+

@@ -3,7 +3,11 @@ from torch import nn, Tensor

 import torchaudio_models as models

-from utils import extract_weights, load_weights, GetterReturnType
+from utils import check_for_functorch, extract_weights, load_weights, GetterReturnType

+has_functorch = check_for_functorch()
+
+
 def get_wav2letter(device: torch.device) -> GetterReturnType:
     N = 10

@@ -50,6 +54,12 @@ def get_deepspeech(device: torch.device) -> GetterReturnType:
     model = models.DeepSpeech(rnn_type=nn.LSTM, labels=labels, rnn_hidden_size=1024, nb_layers=5,
                               audio_conf=audio_conf, bidirectional=True)

+    if has_functorch:
+        from functorch.experimental import replace_all_batch_norm_modules_
+
+        replace_all_batch_norm_modules_(model)
+
     model = model.to(device)
     criterion = nn.CTCLoss()
     params, names = extract_weights(model)

@@ -71,6 +81,11 @@ def get_transformer(device: torch.device) -> GetterReturnType:
     ntoken = 50
     model = models.TransformerModel(ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2)
     model.to(device)
+
+    if has_functorch:
+        # disable dropout for consistency checking
+        model.eval()
+
     criterion = nn.NLLLoss()
     params, names = extract_weights(model)

@@ -6,6 +6,13 @@ from argparse import ArgumentParser
 from collections import defaultdict
 from typing import NamedTuple, Callable, List, Any

+try:
+    import functorch as ft
+    has_functorch = True
+    print(f"Found functorch: {ft.__version__}")
+except ImportError:
+    has_functorch = False
+
 import ppl_models
 import vision_models
 import audio_text_models

@@ -36,6 +43,65 @@ def get_task_func(task: str) -> Callable:
     else:
         return getattr(functional, task)

+
+def get_task_functorch(task: str) -> Callable:
+
+    @torch.no_grad()
+    def vjp(model, inp, v=None, strict=None):
+        assert v is not None
+        out, vjpfunc = ft.vjp(model, *inp)
+        return out, vjpfunc(v)
+
+    @torch.no_grad()
+    def jvp(model, inp, v=None, strict=None):
+        assert v is not None
+        return ft.jvp(model, inp, v)
+
+    @torch.no_grad()
+    def vhp(model, inp, v=None, strict=None):
+        assert v is not None
+        argnums = tuple(range(len(inp)))
+        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
+        return aux, vjpfunc(v)
+
+    @torch.no_grad()
+    def hvp(model, inp, v=None, strict=None):
+        assert v is not None
+        argnums = tuple(range(len(inp)))
+        _, hvp_out, aux = ft.jvp(ft.grad_and_value(model, argnums), inp, v, has_aux=True)
+        return aux, hvp_out
+
+    @torch.no_grad()
+    def jacfwd(model, inp, v=None, strict=None):
+        argnums = tuple(range(len(inp)))
+        return ft.jacfwd(model, argnums)(*inp)
+
+    @torch.no_grad()
+    def jacrev(model, inp, v=None, strict=None):
+        argnums = tuple(range(len(inp)))
+        return ft.jacrev(model, argnums)(*inp)
+
+    @torch.no_grad()
+    def hessian(model, inp, v=None, strict=None):
+        argnums = tuple(range(len(inp)))
+        return ft.hessian(model, argnums=argnums)(*inp)
+
+    @torch.no_grad()
+    def hessian_fwdrev(model, inp, v=None, strict=None):
+        argnums = tuple(range(len(inp)))
+        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)
+
+    @torch.no_grad()
+    def hessian_revrev(model, inp, v=None, strict=None):
+        argnums = tuple(range(len(inp)))
+        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)
+
+    if task in locals():
+        return locals()[task]
+    elif task == "jacobian":
+        raise RuntimeError("functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet")
+    else:
+        raise RuntimeError(f"Unsupported task: {task}")
+
+
 # Listing of the different tasks
 FAST_TASKS_NO_DOUBLE_BACK = [
     "vjp",

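(Aside, not part of the diff.) The two Hessian helpers above differ only in how the Jacobian transforms are composed: `hessian_fwdrev` is forward-over-reverse (`jacfwd(jacrev(f))`) while `hessian_revrev` is reverse-over-reverse. A minimal sketch, assuming functorch is installed, checking on a toy function that both compositions produce the same Hessian:

```python
# Illustrative only: forward-over-reverse and reverse-over-reverse Hessians
# of a simple function should agree (and match the analytic answer).
import torch
import functorch as ft  # assumes functorch is installed


def f(x):
    return (x ** 3).sum()


x = torch.randn(4)
hess_fwdrev = ft.jacfwd(ft.jacrev(f))(x)  # forward-mode over reverse-mode
hess_revrev = ft.jacrev(ft.jacrev(f))(x)  # reverse-mode over reverse-mode

# The Hessian of sum(x**3) is diag(6 * x).
torch.testing.assert_close(hess_fwdrev, torch.diag(6 * x))
torch.testing.assert_close(hess_fwdrev, hess_revrev)
```

Forward-over-reverse is a common default for Hessians of scalar-valued functions with many inputs, which is why it appears as a separate task alongside the double-reverse variant.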
@@ -99,7 +165,7 @@ def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:

     return v

-def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
+def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
     func = get_task_func(task)

     if v is not None:

@@ -107,7 +173,24 @@ def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
     else:
         res = func(model, inp, strict=True)

-def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
+def run_once_functorch(model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False) -> None:
+    func = get_task_functorch(task)
+
+    if v is not None:
+        res = func(model, inp, v=v, strict=True)
+    else:
+        res = func(model, inp, strict=True)
+
+    if maybe_check_consistency:
+        af_func = get_task_func(task)
+        if v is not None:
+            expected = af_func(model, inp, v=v, strict=True)
+        else:
+            expected = af_func(model, inp, strict=True)
+        atol = 1e-2 if task == "vhp" else 5e-3
+        torch.testing.assert_close(res, expected, rtol=1e-5, atol=atol, msg=f"Consistency fail for task '{task}'")
+
+def run_model(model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once) -> List[float]:
     if args.gpu == -1:
         device = torch.device("cpu")

@@ -121,14 +204,17 @@ def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
     model, inp = model_getter(device)

     v = get_v_for(model, inp, task)

     # Warmup
-    run_once(model, inp, task, v)
+    # maybe_check_consistency=True checks for consistency between
+    # functorch vs autograd.functional and is done in run_once_functorch only
+    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

     elapsed = []
     for it in range(args.num_iters):
         do_sync()
         start = time.time()
-        run_once(model, inp, task, v)
+        run_once_fn(model, inp, task, v)
         do_sync()
         elapsed.append(time.time() - start)

@@ -173,6 +259,18 @@ def main():
             results[name][task] = (mean.item(), var.item())
             print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

+            if has_functorch:
+                try:
+                    runtimes = run_model(model_getter, args, task, run_once_fn=run_once_functorch)
+                except RuntimeError as e:
+                    print(f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t", e)
+                    continue
+
+                runtimes = torch.tensor(runtimes)
+                mean, var = runtimes.mean(), runtimes.var()
+                results[name][f"functorch {task}"] = (mean.item(), var.item())
+                print("Results for model {} on task {} using Functorch: {}s (var: {})".format(name, task, mean, var))
+
     if args.output:
         with open(args.output, "w") as f:
             f.write(to_markdown_table(results))

@@ -101,3 +101,10 @@ def from_markdown_table(data: str) -> TimingResultType:
         res[model][task] = (float(mean), float(var))

     return res
+
+def check_for_functorch():
+    try:
+        import functorch  # noqa: F401
+        return True
+    except ImportError:
+        return False

@@ -2,13 +2,22 @@ import torch
 from torch import Tensor
 import torchvision_models as models

-from utils import extract_weights, load_weights, GetterReturnType
+from utils import check_for_functorch, extract_weights, load_weights, GetterReturnType

 from typing import cast

+has_functorch = check_for_functorch()
+
+
 def get_resnet18(device: torch.device) -> GetterReturnType:
     N = 32
     model = models.resnet18(pretrained=False)
+
+    if has_functorch:
+        from functorch.experimental import replace_all_batch_norm_modules_
+
+        replace_all_batch_norm_modules_(model)
+
     criterion = torch.nn.CrossEntropyLoss()
     model.to(device)
     params, names = extract_weights(model)

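(Aside, not from the diff.) The `replace_all_batch_norm_modules_` calls are presumably needed because batch norm layers update their `running_mean`/`running_var` buffers in place, which does not play well with functorch's functional transforms; the helper switches every batch norm module in the model to not track running stats. A minimal sketch of the effect on a toy model, assuming functorch is installed:

```python
# Illustrative only: show what replace_all_batch_norm_modules_ does to a model.
import torch
from torch import nn
from functorch.experimental import replace_all_batch_norm_modules_  # assumes functorch is installed

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
replace_all_batch_norm_modules_(net)

bn = net[1]
# Expected: track_running_stats is now False and the running buffers are cleared,
# so transformed calls no longer trigger in-place buffer updates.
print(bn.track_running_stats, bn.running_mean, bn.running_var)
```

Note that this changes the numerics slightly (batch statistics are used instead of running averages), which is acceptable for a benchmark but worth keeping in mind.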
@@ -29,6 +38,14 @@ def get_fcn_resnet(device: torch.device) -> GetterReturnType:
     N = 8
     criterion = torch.nn.MSELoss()
     model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)
+
+    if has_functorch:
+        from functorch.experimental import replace_all_batch_norm_modules_
+
+        replace_all_batch_norm_modules_(model)
+        # disable dropout for consistency checking
+        model.eval()
+
     model.to(device)
     params, names = extract_weights(model)

@@ -56,6 +73,12 @@ def get_detr(device: torch.device) -> GetterReturnType:

     model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, nheads=nheads,
                         num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)

+    if has_functorch:
+        from functorch.experimental import replace_all_batch_norm_modules_
+
+        replace_all_batch_norm_modules_(model)
+
     losses = ['labels', 'boxes', 'cardinality']
     eos_coef = 0.1
     bbox_loss_coef = 5

@@ -74,9 +97,9 @@ def get_detr(device: torch.device) -> GetterReturnType:
     for idx in range(N):
         targets = {}
         n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
-        label = torch.randint(5, 10, size=(n_targets,))
+        label = torch.randint(5, 10, size=(n_targets,), device=device)
         targets["labels"] = label
-        boxes = torch.randint(100, 800, size=(n_targets, 4))
+        boxes = torch.randint(100, 800, size=(n_targets, 4), device=device)
         for t in range(n_targets):
             if boxes[t, 0] > boxes[t, 2]:
                 boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]