Add some missing disabled functions (#103662)
Disable Adadelta, rprop, multitensor, and the fused optimizers.

Fixes https://github.com/pytorch/benchmark/actions/runs/5167132765/jobs/9307817625
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103662
Approved by: https://github.com/janeyx99
This commit is contained in:
parent
ecf4ce7a0e
commit
38f35b4fc3
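The diff below wraps each optimizer module's private `_multi_tensor_*` (foreach) and `_fused_*` step implementations in dynamo's `disable`, so `torch.compile` falls back to eager for those functions instead of trying to trace them. The following is a minimal sketch of that pattern, not the actual TorchPatcher code: it uses `torch._dynamo.disable` directly and covers only two of the newly added optimizer modules for illustration.

# Sketch (assumption: illustrative only, not the PyTorch source) of the
# disable-by-name pattern this PR extends to adadelta, rprop, rmsprop,
# and the fused step functions.
from torch._dynamo import disable
from torch.optim import adadelta, rprop  # two of the newly covered modules

for opt_mod in (adadelta, rprop):
    opt_name = opt_mod.__name__.split(".")[-1]  # e.g. "adadelta"
    for prefix in ("_multi_tensor_", "_fused_"):
        fn_name = f"{prefix}{opt_name}"  # e.g. "_multi_tensor_adadelta"
        # not every optimizer has a fused implementation, hence the hasattr guard
        if hasattr(opt_mod, fn_name):
            setattr(opt_mod, fn_name, disable(getattr(opt_mod, fn_name)))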
@@ -1237,10 +1237,35 @@ class TorchPatcher:
             DistributedDataParallel._inside_ddp_forward, recursive=False
         )
 
-        from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd
+        # Note: this excludes the optimizers that are unsupported in excluded_opts below
+        from ..optim import (
+            adadelta,
+            adagrad,
+            adam,
+            adamax,
+            adamw,
+            asgd,
+            nadam,
+            rmsprop,
+            rprop,
+            sgd,
+        )
 
-        for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd:
-            multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}"
+        for opt_mod in (
+            adadelta,
+            adagrad,
+            adam,
+            adamax,
+            adamw,
+            asgd,
+            nadam,
+            rmsprop,
+            rprop,
+            sgd,
+        ):
+            opt_name = opt_mod.__name__.split(".")[-1]
+            multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
+            fused_fn_name = f"_fused_{opt_name}"
             if hasattr(opt_mod, multi_tensor_fn_name):
                 setattr(
                     opt_mod,
@@ -1248,6 +1273,12 @@ class TorchPatcher:
                     disable(getattr(opt_mod, multi_tensor_fn_name)),
                 )
+            if hasattr(opt_mod, fused_fn_name):
+                setattr(
+                    opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
+                )
 
+        # Note: we don't support sparsity, data-dependent control, or tracing through backwards
         excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
         for opt in optimizers:
             if opt in excluded_opts:
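With the patch in place, hitting one of the newly covered code paths under `torch.compile` should run the optimizer step eagerly rather than failing during tracing, which is the failure mode in the linked benchmark run. The snippet below is a hypothetical repro sketch, not taken from the PR or the benchmark suite, showing the foreach (multi-tensor) Adadelta path inside a compiled training step.

# Hypothetical repro sketch: after this change, _multi_tensor_adadelta is
# wrapped in dynamo's disable(), so it executes in eager mode instead of
# being traced when the surrounding function is compiled.
import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.Adadelta(model.parameters(), foreach=True)  # multi-tensor path

@torch.compile
def train_step(x):
    loss = model(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss

train_step(torch.randn(4, 8))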