mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Extend autograd functional benchmarking to run vectorized tasks (#67045)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67045 To run: `python benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py --gpu -1 --model-filter=ppl _robust_reg --num-iter 100` ``` Results for model ppl_robust_reg on task vjp: 0.0012262486852705479s (var: 2.2107682351446556e-10) Results for model ppl_robust_reg on task vhp: 0.002099371049553156s (var: 6.906406557760647e-10) Results for model ppl_robust_reg on task jvp: 0.001860950025729835s (var: 1.1251884146634694e-10) Results for model ppl_robust_reg on task hvp: 0.003481731517240405s (var: 2.2713633751614282e-10) Results for model ppl_robust_reg on task jacobian: 0.0012128615053370595s (var: 1.3687526667638394e-09) Results for model ppl_robust_reg on task hessian: 0.009885427542030811s (var: 9.366265096844018e-09) Results for model ppl_robust_reg on task hessian_fwdrev: 0.005268776323646307s (var: 2.4293791422991262e-09) Results for model ppl_robust_reg on task hessian_revrev: 0.002561321249231696s (var: 7.557877101938004e-10) Results for model ppl_robust_reg on task jacfwd: 0.002619938924908638s (var: 5.109343503839625e-10) Results for model ppl_robust_reg on task jacrev: 0.0013469004770740867s (var: 3.1857563254078514e-09) ``` Notes: - We go through batched fallback for both - ppl_robust_reg takes 3 tensor inputs and returns a single scalar output - this means that jacobian is equivalent to doing vjp and vmap would not help us - we expect jacfwd to be slower than jacrev Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D33265947 Pulled By: soulitzer fbshipit-source-id: 14f537a1376dea7e5afbe0c8e97f94731479b018
This commit is contained in:
parent
82c5f298ed
commit
21c6de9fdc
|
|
@ -12,6 +12,30 @@ import audio_text_models
|
|||
|
||||
from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType
|
||||
|
||||
def get_task_func(task: str) -> Callable:
|
||||
def hessian_fwdrev(model, inp, strict=None):
|
||||
return functional.hessian(model, inp, strict=False, vectorize=True, outer_jacobian_strategy="forward-mode")
|
||||
|
||||
def hessian_revrev(model, inp, strict=None):
|
||||
return functional.hessian(model, inp, strict=False, vectorize=True)
|
||||
|
||||
def jacfwd(model, inp, strict=None):
|
||||
return functional.jacobian(model, inp, strict=False, vectorize=True, strategy="forward-mode")
|
||||
|
||||
def jacrev(model, inp, strict=None):
|
||||
return functional.jacobian(model, inp, strict=False, vectorize=True)
|
||||
|
||||
if task == "hessian_fwdrev":
|
||||
return hessian_fwdrev
|
||||
elif task == "hessian_revrev":
|
||||
return hessian_revrev
|
||||
elif task == "jacfwd":
|
||||
return jacfwd
|
||||
elif task == "jacrev":
|
||||
return jacrev
|
||||
else:
|
||||
return getattr(functional, task)
|
||||
|
||||
# Listing of the different tasks
|
||||
FAST_TASKS_NO_DOUBLE_BACK = [
|
||||
"vjp",
|
||||
|
|
@ -22,7 +46,7 @@ FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
|
|||
"jvp",
|
||||
]
|
||||
|
||||
ALL_TASKS = FAST_TASKS + [
|
||||
ALL_TASKS_NON_VECTORIZED = FAST_TASKS + [
|
||||
"hvp",
|
||||
"jacobian",
|
||||
"hessian"
|
||||
|
|
@ -30,6 +54,10 @@ ALL_TASKS = FAST_TASKS + [
|
|||
|
||||
DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]
|
||||
|
||||
VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]
|
||||
|
||||
ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS
|
||||
|
||||
# Model definition which contains:
|
||||
# - name: a string with the model name.
|
||||
# - getter: a function to get the model. It takes as input the device on which the model
|
||||
|
|
@ -72,7 +100,7 @@ def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
|
|||
return v
|
||||
|
||||
def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
|
||||
func = getattr(functional, task)
|
||||
func = get_task_func(task)
|
||||
|
||||
if v is not None:
|
||||
res = func(model, inp, v=v, strict=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user