[FSDP] Clarify loss dtype check in _test_fsdp_parity (#90251)

A recent PR deprecated `torch.testing.assert_allclose` in favor of `torch.testing.assert_close` and left a `TODO` asking whether `check_dtype=False` was intended. This PR follows up to confirm that we do intend to pass `check_dtype=False`, since the reference DDP loss may not have the same dtype as the FSDP loss when using mixed precision.
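
As a minimal sketch (not code from this PR), assuming a mixed-precision run where the reference DDP loss stays in FP32 while the FSDP loss is FP16: `torch.testing.assert_close` checks dtypes by default, so the comparison would fail even when the values agree, which is why the test passes `check_dtype=False`:

```python
import torch

# Hypothetical stand-ins for the losses compared in `_test_fsdp_parity`:
# under mixed precision the reference DDP loss may remain float32 while the
# FSDP loss is computed in float16.
ref_loss = torch.tensor(2.5, dtype=torch.float32)
fsdp_loss = torch.tensor(2.5, dtype=torch.float16)

# With the default `check_dtype=True`, assert_close raises on the dtype
# mismatch even though the values match:
# torch.testing.assert_close(ref_loss, fsdp_loss)  # AssertionError

# With `check_dtype=False`, the tensors are promoted to a common dtype and
# only the values are compared, which is the intended behavior here.
torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
```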
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90251
Approved by: https://github.com/rohan-varma
Andrew Gu 2022-12-06 15:31:38 +00:00 committed by PyTorch MergeBot
parent 919e09f26a
commit 7436b19eb2


@@ -1003,8 +1003,8 @@ class FSDPTest(MultiProcessTestCase):
 self.assertEqual(param.device, cpu_device)
 fsdp_loss = fsdp_loss.cuda()
 fsdp_unsharded_params = get_full_params(fsdp_model)
-# TODO: Are mismatching dtypes actually ok here or did this pass silently before, because `check_dtype=False`
-# was the default?
+# Do not check dtype since the reference DDP loss may not be the same
+# dtype as the FSDP loss in the case of mixed precision
 torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
 # Do not check for parameter parity if using mixed precision since (1)
 # the DDP parameters are in FP16 (from `half()`) while the FSDP