diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 8ac544eef3c..3c1c155301c 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1444,11 +1444,11 @@ class DistributedDataParallel(Module, Joinable):
         """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
         return cls._active_ddp_module
 
+    @torch._disable_dynamo(recursive=True)
     # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
     # for the 'module_to_run' underneath
     # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
     @contextmanager
-    @torch._disable_dynamo(recursive=False)
     def _inside_ddp_forward(self):
         DistributedDataParallel._active_ddp_module = self
         try:
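
Note on the reordering: Python applies decorators bottom-up, so in the old order torch._disable_dynamo(recursive=False) wrapped the raw generator function and @contextmanager then wrapped that result; in the new order torch._disable_dynamo is outermost and wraps the context-manager factory that @contextmanager produces (and recursive=True, per torch._dynamo.disable's semantics, also disables Dynamo for everything that callable invokes). A minimal sketch of the ordering difference, using a hypothetical show_target stand-in for torch._disable_dynamo rather than the real API:

    from contextlib import contextmanager

    def show_target(fn):
        # Hypothetical stand-in for torch._disable_dynamo: report which
        # callable it was asked to wrap, then return it unchanged.
        # @contextmanager's factory carries a __wrapped__ attribute (set
        # by functools.wraps); a raw generator function does not.
        print(fn.__name__, "wraps another callable:", hasattr(fn, "__wrapped__"))
        return fn

    # Old order: the stand-in runs first and sees the raw generator function.
    @contextmanager
    @show_target              # prints: old_order wraps another callable: False
    def old_order():
        yield

    # New order: the stand-in runs last and sees the factory that
    # @contextmanager produced from the generator.
    @show_target              # prints: new_order wraps another callable: True
    @contextmanager
    def new_order():
        yield

In other words, after this change the disable wrapper intercepts calls to _inside_ddp_forward() itself, not just the underlying generator body.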