mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
[DDP] Separate error messages for unused params in forward and not all outputs (#52391)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52391 There are 2 ways DDP can throw the exception refactored here - 1) Unused params in the forward pass. We provide `find_unused_parameters=True` for this. 2) All params used in fwd pass, but not all outputs used in loss computation. There are a few workarounds for this but we do not provide native support. Previously, these 2 issues were combined into 1 error message but that has historically resulted in confusion, with users reporting getting this error even when they enable `find_unused_parameters=True` (which they expect to fix this error). As a result there is additional churn to debug these issues because the true cause (1) vs (2) is not known. This commit helps to fix the issue by separating out the 2 error messages depending on if we ran with unused parameter detection or not. Hopefully this should make the error message much more clear and actionable. error msg with `find_unused_params=True`: ``` RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. Since `find_unused_parameters=True` is enabled, this likely means that not all `forward` outputs participate in computing loss. You can fix this by making sure all `forward` function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable). ``` error msg without `find_unused_params` specified: ``` RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss. If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable). ``` ghstack-source-id: 122097900 Test Plan: CI Reviewed By: zhaojuanmao Differential Revision: D26496688 fbshipit-source-id: 4a9eeeda10293da13d94a692d10cb954e4506d7c
This commit is contained in:
parent
a3e693789f
commit
ef8d17e112
|
|
@ -1505,25 +1505,38 @@ void Reducer::ensure_prior_reduction_finished() {
|
|||
// The variable `require_finalize_` is true until all gradients
|
||||
// have been computed and reduction of all buckets has been kicked off.
|
||||
if (require_finalize_) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Expected to have finished reduction in the prior iteration before ",
|
||||
"starting a new one. ",
|
||||
"",
|
||||
"This error indicates that your module has parameters that were ",
|
||||
"not used in producing loss. ",
|
||||
"",
|
||||
"You can enable unused parameter detection by (1) passing the keyword "
|
||||
"argument `find_unused_parameters=True` to ",
|
||||
"`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
|
||||
"`forward` function outputs participate in calculating loss. "
|
||||
"",
|
||||
"If you already have done the above two steps, then the distributed ",
|
||||
"data parallel module wasn't able to locate the output tensors in the ",
|
||||
"return value of your module's `forward` function. ",
|
||||
"Please include the loss function and the structure of the return ",
|
||||
"value of `forward` of your module when reporting this issue (e.g. ",
|
||||
"list, dict, iterable).");
|
||||
std::string kBaseErrorMsg = "Expected to have finished reduction in the prior iteration before "
|
||||
"starting a new one. "
|
||||
""
|
||||
"This error indicates that your module has parameters that were "
|
||||
"not used in producing loss. ";
|
||||
std::string kOutputsNotUsedInLossErrorMsg = "making sure all "
|
||||
"`forward` function outputs participate in calculating loss. ";
|
||||
std::string kDDPBugErrorMsg = "\nIf you already have done the above, then the distributed "
|
||||
"data parallel module wasn't able to locate the output tensors in the "
|
||||
"return value of your module's `forward` function. "
|
||||
"Please include the loss function and the structure of the return "
|
||||
"value of `forward` of your module when reporting this issue (e.g. "
|
||||
"list, dict, iterable).";
|
||||
|
||||
if (!find_unused_parameters_) {
|
||||
// Parameters may have been unused in forward pass, or not all outputs
|
||||
// were used in producing loss.
|
||||
kBaseErrorMsg += "You can enable unused parameter detection by passing the "
|
||||
"keyword argument `find_unused_parameters=True` to "
|
||||
"`torch.nn.parallel.DistributedDataParallel`, and by \n";
|
||||
kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
|
||||
kBaseErrorMsg += kDDPBugErrorMsg;
|
||||
} else {
|
||||
// Note that it does not really matter whether unused_parameters_.empty(),
|
||||
// since user may have enabled detection but this particular iteration
|
||||
// could have used or not used all parameters.
|
||||
kBaseErrorMsg += "Since `find_unused_parameters=True` is enabled, this likely "
|
||||
" means that not all `forward` outputs participate in computing loss. You can fix this by ";
|
||||
kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
|
||||
kBaseErrorMsg += kDDPBugErrorMsg;
|
||||
}
|
||||
TORCH_CHECK(false, kBaseErrorMsg);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -115,6 +115,17 @@ def get_profiling_event(postfix, profiler):
|
|||
event for event in profiler.function_events if event.name.endswith(postfix)
|
||||
]
|
||||
|
||||
# Base error message substring on unfinished reductions.
|
||||
ddp_prev_reduction_unfinished_str = "Expected to have finished reduction in the prior iteration"
|
||||
# Error message substring when find_unused_parameters=True has not been passed
|
||||
ddp_recommend_find_unused_params_str = "passing the keyword argument `find_unused_parameters=True`"
|
||||
# Error message substring when find_unused_parameters=True is enabled
|
||||
ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
|
||||
# Error message substring for possibility of not all model outputs being used
|
||||
# in loss computation
|
||||
ddp_outputs_not_used_in_loss_str = "`forward` function outputs participate in calculating loss"
|
||||
|
||||
|
||||
|
||||
class _FC2(nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -4407,11 +4418,23 @@ class DistributedTest:
|
|||
# On 2nd iteration, this will fail during rebuild_buckets,
|
||||
# but we should report an error regarding unused parameters
|
||||
# since that is the underlying root cause.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration",
|
||||
):
|
||||
try:
|
||||
ddp(inp).sum().backward()
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
expected_strs = [
|
||||
ddp_prev_reduction_unfinished_str,
|
||||
ddp_recommend_find_unused_params_str,
|
||||
ddp_outputs_not_used_in_loss_str
|
||||
]
|
||||
for s in expected_strs:
|
||||
self.assertTrue(
|
||||
s in msg,
|
||||
f"Expected {s} to be in {msg}"
|
||||
)
|
||||
self.assertFalse(ddp_find_unused_params_enabled_str in msg)
|
||||
else:
|
||||
self.assertFalse(True, "DDP unused parameters error not raised.")
|
||||
else:
|
||||
ddp(inp).sum().backward()
|
||||
|
||||
|
|
@ -4649,12 +4672,28 @@ class DistributedTest:
|
|||
find_unused_parameters=False,
|
||||
)
|
||||
for i in range(2):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration before starting a new one",
|
||||
) if i == 1 else suppress():
|
||||
if i == 0:
|
||||
loss = model(random_input).sum()
|
||||
loss.backward()
|
||||
else:
|
||||
try:
|
||||
loss = model(random_input).sum()
|
||||
loss.backward()
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
expected_strs = [
|
||||
ddp_prev_reduction_unfinished_str,
|
||||
ddp_recommend_find_unused_params_str,
|
||||
ddp_outputs_not_used_in_loss_str
|
||||
]
|
||||
for s in expected_strs:
|
||||
self.assertTrue(
|
||||
s in msg,
|
||||
f"Expected {s} to be in {msg}"
|
||||
)
|
||||
self.assertFalse(ddp_find_unused_params_enabled_str in msg)
|
||||
else:
|
||||
self.assertFalse(True, "DDP error not raised")
|
||||
|
||||
@require_backend({"gloo", "nccl"})
|
||||
@require_backends_available({"gloo", "nccl"})
|
||||
|
|
@ -4723,12 +4762,28 @@ class DistributedTest:
|
|||
find_unused_parameters=False,
|
||||
)
|
||||
for i in range(2):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration before starting a new one",
|
||||
) if i == 1 else suppress():
|
||||
if i == 0:
|
||||
loss = model(random_input).sum()
|
||||
loss.backward()
|
||||
else:
|
||||
try:
|
||||
loss = model(random_input).sum()
|
||||
loss.backward()
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
expected_strs = [
|
||||
ddp_prev_reduction_unfinished_str,
|
||||
ddp_recommend_find_unused_params_str,
|
||||
ddp_outputs_not_used_in_loss_str
|
||||
]
|
||||
for s in expected_strs:
|
||||
self.assertTrue(
|
||||
s in msg,
|
||||
f"Expected {s} to be in {msg}"
|
||||
)
|
||||
self.assertFalse(ddp_find_unused_params_enabled_str in msg)
|
||||
else:
|
||||
self.assertFalse(True, "DDP error not raised")
|
||||
|
||||
@require_backend({"gloo"})
|
||||
@unittest.skipIf(BACKEND == "nccl", "NCCL does not support scatter")
|
||||
|
|
@ -4778,15 +4833,44 @@ class DistributedTest:
|
|||
|
||||
for ddp in [net, net_with_find_unused]:
|
||||
for i in range(2):
|
||||
ctx = (
|
||||
suppress()
|
||||
if i == 0
|
||||
else self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration",
|
||||
)
|
||||
)
|
||||
with ctx:
|
||||
if i == 0:
|
||||
a, b = ddp(inp)
|
||||
loss = b.sum()
|
||||
loss.backward()
|
||||
else:
|
||||
try:
|
||||
a, b = ddp(inp)
|
||||
loss = b.sum()
|
||||
loss.backward()
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
if ddp == net:
|
||||
expected_strs = [
|
||||
ddp_prev_reduction_unfinished_str,
|
||||
ddp_recommend_find_unused_params_str,
|
||||
ddp_outputs_not_used_in_loss_str,
|
||||
]
|
||||
unexpected_strs = [
|
||||
ddp_find_unused_params_enabled_str,
|
||||
]
|
||||
elif ddp == net_with_find_unused:
|
||||
expected_strs = [
|
||||
ddp_prev_reduction_unfinished_str,
|
||||
ddp_outputs_not_used_in_loss_str,
|
||||
ddp_find_unused_params_enabled_str,
|
||||
]
|
||||
unexpected_strs = [
|
||||
ddp_recommend_find_unused_params_str,
|
||||
]
|
||||
for s in expected_strs:
|
||||
self.assertTrue(
|
||||
s in msg,
|
||||
f"Expected {s} to be in {msg}"
|
||||
)
|
||||
for s in unexpected_strs:
|
||||
self.assertFalse(
|
||||
s in msg,
|
||||
f"Expected {s} not to be in {msg}"
|
||||
)
|
||||
else:
|
||||
self.assertFalse(True, "DDP error not raised")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user