diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index abb3c9ee009..b107fdd21f9 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -1237,10 +1237,35 @@ class TorchPatcher:
             DistributedDataParallel._inside_ddp_forward, recursive=False
         )
 
-        from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd
+        # Note: this excludes the optimizers that are unsupported in excluded_opts below
+        from ..optim import (
+            adadelta,
+            adagrad,
+            adam,
+            adamax,
+            adamw,
+            asgd,
+            nadam,
+            rmsprop,
+            rprop,
+            sgd,
+        )
 
-        for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd:
-            multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}"
+        for opt_mod in (
+            adadelta,
+            adagrad,
+            adam,
+            adamax,
+            adamw,
+            asgd,
+            nadam,
+            rmsprop,
+            rprop,
+            sgd,
+        ):
+            opt_name = opt_mod.__name__.split(".")[-1]
+            multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
+            fused_fn_name = f"_fused_{opt_name}"
             if hasattr(opt_mod, multi_tensor_fn_name):
                 setattr(
                     opt_mod,
@@ -1248,6 +1273,12 @@ class TorchPatcher:
                     disable(getattr(opt_mod, multi_tensor_fn_name)),
                 )
 
+            if hasattr(opt_mod, fused_fn_name):
+                setattr(
+                    opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
+                )
+
+        # Note: we don't support sparsity, data-dependent control, or tracing through backwards
         excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
         for opt in optimizers:
             if opt in excluded_opts:
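
For illustration only, here is a minimal, self-contained sketch of the pattern the patch applies: each optimizer module's private _multi_tensor_* and _fused_* entry points are looked up by name and, if present, wrapped so dynamo will not trace them. The disable stub and the fake_adam module below are hypothetical stand-ins; the real code uses torch._dynamo.disable and the modules under torch.optim.

import types


def disable(fn):
    # Stand-in for torch._dynamo.disable used in the patch above: the real
    # decorator keeps dynamo from tracing into fn; here we only tag it.
    fn._dynamo_disabled = True
    return fn


# Hypothetical module following the _multi_tensor_* / _fused_* naming convention.
fake_adam = types.ModuleType("fake.optim.adam")
fake_adam._multi_tensor_adam = lambda *args, **kwargs: "multi-tensor step"
fake_adam._fused_adam = lambda *args, **kwargs: "fused step"

for opt_mod in (fake_adam,):
    opt_name = opt_mod.__name__.split(".")[-1]
    for fn_name in (f"_multi_tensor_{opt_name}", f"_fused_{opt_name}"):
        # Only patch implementations the module actually provides.
        if hasattr(opt_mod, fn_name):
            setattr(opt_mod, fn_name, disable(getattr(opt_mod, fn_name)))

print(fake_adam._fused_adam._dynamo_disabled)  # True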