[FSDP] Clarify loss dtype check in _test_fsdp_parity (#90251)
A recent PR deprecated `torch.testing.assert_allclose` in favor of `torch.testing.assert_close` and left a `TODO`. This PR follows up to confirm that we do intend to have `check_dtype=False`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90251
Approved by: https://github.com/rohan-varma
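As a quick illustration of the behavior this PR confirms (a minimal sketch with made-up tensor values, assuming a PyTorch version where `torch.testing.assert_close` is available): `assert_close` checks dtypes by default, so comparing a `float16` loss against a `float32` one needs an explicit `check_dtype=False`.

```python
import torch

# Illustrative stand-ins for the two losses; the values are made up.
ref_loss = torch.tensor(0.5, dtype=torch.float32)   # e.g. reference (DDP) loss
fsdp_loss = torch.tensor(0.5, dtype=torch.float16)  # e.g. FSDP loss under mixed precision

# Default behavior: dtypes are compared too, so this mismatch raises.
try:
    torch.testing.assert_close(ref_loss, fsdp_loss)
except AssertionError as e:
    print(e)  # reports the float32 vs. float16 mismatch

# With check_dtype=False only the values are compared (after dtype promotion).
torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
```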
This commit is contained in:
parent 919e09f26a
commit 7436b19eb2
@@ -1003,8 +1003,8 @@ class FSDPTest(MultiProcessTestCase):
                     self.assertEqual(param.device, cpu_device)
                 fsdp_loss = fsdp_loss.cuda()
         fsdp_unsharded_params = get_full_params(fsdp_model)
-        # TODO: Are mismatching dtypes actually ok here or did this pass silently before, because `check_dtype=False`
-        # was the default?
+        # Do not check dtype since the reference DDP loss may not be the same
+        # dtype as the FSDP loss in the case of mixed precision
         torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
         # Do not check for parameter parity if using mixed precision since (1)
         # the DDP parameters are in FP16 (from `half()`) while the FSDP
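To show why the dtypes can legitimately differ, here is a hypothetical, non-distributed sketch of the situation the new comment describes: under mixed precision the reference DDP model is cast with `half()`, so its loss comes out in `float16`, while the FSDP loss stays `float32`. The loss function, tensors, and tolerances below are illustrative and not taken from `_test_fsdp_parity`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(4, 8)
target = torch.randn(4, 8)

# Stand-in for the FSDP loss: computed in full precision.
fsdp_loss = F.mse_loss(pred, target)  # float32
# Stand-in for the reference DDP loss: the reference model ran in fp16.
ref_loss = fsdp_loss.half()           # float16

# Values agree up to fp16 rounding, dtypes do not; check_dtype=False lets the
# parity assertion focus on numerical closeness only.
torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False, rtol=1e-3, atol=1e-3)
```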