Add some missing disabled functions (#103662)

Disable Adadelta, rprop, multitensor, and the fused optimizers

Fixes https://github.com/pytorch/benchmark/actions/runs/5167132765/jobs/9307817625

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103662
Approved by: https://github.com/janeyx99
This commit is contained in:
Michael Lazos 2023-06-16 00:11:09 +00:00 committed by PyTorch MergeBot
parent ecf4ce7a0e
commit 38f35b4fc3

View File

@ -1237,10 +1237,35 @@ class TorchPatcher:
DistributedDataParallel._inside_ddp_forward, recursive=False DistributedDataParallel._inside_ddp_forward, recursive=False
) )
from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd # Note: this excludes the optimizers that are unsupported in excluded_opts below
from ..optim import (
adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
nadam,
rmsprop,
rprop,
sgd,
)
for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd: for opt_mod in (
multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}" adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
nadam,
rmsprop,
rprop,
sgd,
):
opt_name = opt_mod.__name__.split(".")[-1]
multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
fused_fn_name = f"_fused_{opt_name}"
if hasattr(opt_mod, multi_tensor_fn_name): if hasattr(opt_mod, multi_tensor_fn_name):
setattr( setattr(
opt_mod, opt_mod,
@ -1248,6 +1273,12 @@ class TorchPatcher:
disable(getattr(opt_mod, multi_tensor_fn_name)), disable(getattr(opt_mod, multi_tensor_fn_name)),
) )
if hasattr(opt_mod, fused_fn_name):
setattr(
opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
)
# Note: we don't support sparsity, data-dependent control, or tracing through backwards
excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS} excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
for opt in optimizers: for opt in optimizers:
if opt in excluded_opts: if opt in excluded_opts: