diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp
index d99ccd0e656..0ababa80512 100644
--- a/torch/lib/c10d/reducer.cpp
+++ b/torch/lib/c10d/reducer.cpp
@@ -1505,25 +1505,38 @@ void Reducer::ensure_prior_reduction_finished() {
   // The variable `require_finalize_` is true until all gradients
   // have been computed and reduction of all buckets has been kicked off.
   if (require_finalize_) {
-    TORCH_CHECK(
-        false,
-        "Expected to have finished reduction in the prior iteration before ",
-        "starting a new one. ",
-        "",
-        "This error indicates that your module has parameters that were ",
-        "not used in producing loss. ",
-        "",
-        "You can enable unused parameter detection by (1) passing the keyword "
-        "argument `find_unused_parameters=True` to ",
-        "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
-        "`forward` function outputs participate in calculating loss. "
-        "",
-        "If you already have done the above two steps, then the distributed ",
-        "data parallel module wasn't able to locate the output tensors in the ",
-        "return value of your module's `forward` function. ",
-        "Please include the loss function and the structure of the return ",
-        "value of `forward` of your module when reporting this issue (e.g. ",
-        "list, dict, iterable).");
+    std::string kBaseErrorMsg = "Expected to have finished reduction in the prior iteration before "
+        "starting a new one. "
+        ""
+        "This error indicates that your module has parameters that were "
+        "not used in producing loss. ";
+    std::string kOutputsNotUsedInLossErrorMsg = "making sure all "
+        "`forward` function outputs participate in calculating loss. ";
+    std::string kDDPBugErrorMsg = "\nIf you have already done the above, then the distributed "
+        "data parallel module wasn't able to locate the output tensors in the "
+        "return value of your module's `forward` function. "
+        "Please include the loss function and the structure of the return "
+        "value of `forward` of your module when reporting this issue (e.g. "
+        "list, dict, iterable).";
+
+    if (!find_unused_parameters_) {
+      // Parameters may have been unused in the forward pass, or not all
+      // outputs were used in producing loss.
+      kBaseErrorMsg += "You can enable unused parameter detection by passing the "
+          "keyword argument `find_unused_parameters=True` to "
+          "`torch.nn.parallel.DistributedDataParallel`, and by \n";
+      kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
+      kBaseErrorMsg += kDDPBugErrorMsg;
+    } else {
+      // Note that it does not really matter whether unused_parameters_.empty(),
+      // since the user may have enabled detection even though this particular
+      // iteration used all parameters.
+      kBaseErrorMsg += "Since `find_unused_parameters=True` is enabled, this likely "
+          "means that not all `forward` outputs participate in computing loss. You can fix this by ";
+      kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
+      kBaseErrorMsg += kDDPBugErrorMsg;
+    }
+    TORCH_CHECK(false, kBaseErrorMsg);
   }
 }
 
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 7f67196ee73..f0bbd7350a1 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -115,6 +115,17 @@ def get_profiling_event(postfix, profiler):
         event for event in profiler.function_events if event.name.endswith(postfix)
     ]
 
+# Base error message substring for unfinished reductions.
+ddp_prev_reduction_unfinished_str = "Expected to have finished reduction in the prior iteration"
+# Error message substring when find_unused_parameters=True has not been passed
+ddp_recommend_find_unused_params_str = "passing the keyword argument `find_unused_parameters=True`"
+# Error message substring when find_unused_parameters=True is enabled
+ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
+# Error message substring for the case where not all model outputs were used
+# in loss computation
+ddp_outputs_not_used_in_loss_str = "`forward` function outputs participate in calculating loss"
+
+
 class _FC2(nn.Module):
     def __init__(self):
@@ -4407,11 +4418,23 @@ class DistributedTest:
                     # On 2nd iteration, this will fail during rebuild_buckets,
                     # but we should report an error regarding unused parameters
                     # since that is the underlying root cause.
-                    with self.assertRaisesRegex(
-                        RuntimeError,
-                        "Expected to have finished reduction in the prior iteration",
-                    ):
+                    try:
                         ddp(inp).sum().backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str
+                        ]
+                        for s in expected_strs:
+                            self.assertTrue(
+                                s in msg,
+                                f"Expected {s} to be in {msg}"
+                            )
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP unused parameters error not raised.")
                 else:
                     ddp(inp).sum().backward()
@@ -4649,12 +4672,28 @@ class DistributedTest:
                 find_unused_parameters=False,
             )
             for i in range(2):
-                with self.assertRaisesRegex(
-                    RuntimeError,
-                    "Expected to have finished reduction in the prior iteration before starting a new one",
-                ) if i == 1 else suppress():
+                if i == 0:
                     loss = model(random_input).sum()
                     loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str
+                        ]
+                        for s in expected_strs:
+                            self.assertTrue(
+                                s in msg,
+                                f"Expected {s} to be in {msg}"
+                            )
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
 
         @require_backend({"gloo", "nccl"})
         @require_backends_available({"gloo", "nccl"})
@@ -4723,12 +4762,28 @@ class DistributedTest:
                 find_unused_parameters=False,
            )
             for i in range(2):
-                with self.assertRaisesRegex(
-                    RuntimeError,
-                    "Expected to have finished reduction in the prior iteration before starting a new one",
-                ) if i == 1 else suppress():
+                if i == 0:
                     loss = model(random_input).sum()
                     loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str
+                        ]
+                        for s in expected_strs:
+                            self.assertTrue(
+                                s in msg,
+                                f"Expected {s} to be in {msg}"
+                            )
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.assertFalse(True, "DDP error not raised")
 
         @require_backend({"gloo"})
         @unittest.skipIf(BACKEND == "nccl", "NCCL does not support scatter")
@@ -4778,15 +4833,44 @@ class DistributedTest:
             for ddp in [net, net_with_find_unused]:
                 for i in range(2):
-                    ctx = (
-                        suppress()
-                        if i == 0
-                        else self.assertRaisesRegex(
-                            RuntimeError,
-                            "Expected to have finished reduction in the prior iteration",
-                        )
-                    )
-                    with ctx:
+                    if i == 0:
                         a, b = ddp(inp)
                         loss = b.sum()
                         loss.backward()
+                    else:
+                        try:
+                            a, b = ddp(inp)
+                            loss = b.sum()
+                            loss.backward()
+                        except RuntimeError as e:
+                            msg = str(e)
+                            if ddp == net:
+                                expected_strs = [
+                                    ddp_prev_reduction_unfinished_str,
+                                    ddp_recommend_find_unused_params_str,
+                                    ddp_outputs_not_used_in_loss_str,
+                                ]
+                                unexpected_strs = [
+                                    ddp_find_unused_params_enabled_str,
+                                ]
+                            elif ddp == net_with_find_unused:
+                                expected_strs = [
+                                    ddp_prev_reduction_unfinished_str,
+                                    ddp_outputs_not_used_in_loss_str,
+                                    ddp_find_unused_params_enabled_str,
+                                ]
+                                unexpected_strs = [
+                                    ddp_recommend_find_unused_params_str,
+                                ]
+                            for s in expected_strs:
+                                self.assertTrue(
+                                    s in msg,
+                                    f"Expected {s} to be in {msg}"
+                                )
+                            for s in unexpected_strs:
+                                self.assertFalse(
+                                    s in msg,
+                                    f"Expected {s} not to be in {msg}"
+                                )
+                        else:
+                            self.assertFalse(True, "DDP error not raised")
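
For reference, a minimal single-process sketch of the failure mode these messages target, assuming the gloo backend is available. The script name, module, and port below are illustrative assumptions and are not part of the patch; the multi-process tests above exercise the same path across ranks.

# repro_ddp_unused_params.py -- illustrative sketch only; names and port are arbitrary.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class PartiallyUsedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(10, 10)
        self.unused = nn.Linear(10, 10)  # never called in forward()

    def forward(self, x):
        return self.used(x)


def main():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    model = DDP(PartiallyUsedNet(), find_unused_parameters=False)
    inp = torch.randn(4, 10)
    try:
        # Iteration 1 leaves the reduction for `unused` unfinished; iteration 2
        # then fails in ensure_prior_reduction_finished() with the new message.
        for _ in range(2):
            model(inp).sum().backward()
    except RuntimeError as e:
        msg = str(e)
        # The same substrings the tests above expect when
        # find_unused_parameters=False was passed.
        assert "Expected to have finished reduction in the prior iteration" in msg
        assert "passing the keyword argument `find_unused_parameters=True`" in msg
        print(msg)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

With find_unused_parameters=True the same script completes, while a model whose outputs are dropped from the loss would instead hit the "Since `find_unused_parameters=True` is enabled" branch checked by the net_with_find_unused case above.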