mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add some missing disabled functions (#103662)
Disable Adadelta, rprop, multitensor, and the fused optimizers Fixes https://github.com/pytorch/benchmark/actions/runs/5167132765/jobs/9307817625 Pull Request resolved: https://github.com/pytorch/pytorch/pull/103662 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
ecf4ce7a0e
commit
38f35b4fc3
|
|
@ -1237,10 +1237,35 @@ class TorchPatcher:
|
|||
DistributedDataParallel._inside_ddp_forward, recursive=False
|
||||
)
|
||||
|
||||
from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd
|
||||
# Note: this excludes the optimizers that are unsupported in excluded_opts below
|
||||
from ..optim import (
|
||||
adadelta,
|
||||
adagrad,
|
||||
adam,
|
||||
adamax,
|
||||
adamw,
|
||||
asgd,
|
||||
nadam,
|
||||
rmsprop,
|
||||
rprop,
|
||||
sgd,
|
||||
)
|
||||
|
||||
for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd:
|
||||
multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}"
|
||||
for opt_mod in (
|
||||
adadelta,
|
||||
adagrad,
|
||||
adam,
|
||||
adamax,
|
||||
adamw,
|
||||
asgd,
|
||||
nadam,
|
||||
rmsprop,
|
||||
rprop,
|
||||
sgd,
|
||||
):
|
||||
opt_name = opt_mod.__name__.split(".")[-1]
|
||||
multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
|
||||
fused_fn_name = f"_fused_{opt_name}"
|
||||
if hasattr(opt_mod, multi_tensor_fn_name):
|
||||
setattr(
|
||||
opt_mod,
|
||||
|
|
@ -1248,6 +1273,12 @@ class TorchPatcher:
|
|||
disable(getattr(opt_mod, multi_tensor_fn_name)),
|
||||
)
|
||||
|
||||
if hasattr(opt_mod, fused_fn_name):
|
||||
setattr(
|
||||
opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
|
||||
)
|
||||
|
||||
# Note: we don't support sparsity, data-dependent control, or tracing through backwards
|
||||
excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
|
||||
for opt in optimizers:
|
||||
if opt in excluded_opts:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user