[c10d][gloo] Enable using c10::Half for gloo (#153862)

Testing with https://github.com/pytorch/gloo/pull/446 shows that the numerical issues reported in https://github.com/pytorch/pytorch/issues/152300 are indeed resolved, and we add a unit test for it. Also update the gloo submodule to pick up the change on the gloo side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153862
Approved by: https://github.com/d4l3k, https://github.com/clee2000, https://github.com/malfet
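For context, a minimal sketch (not part of this PR) of the kind of fp16 all-reduce on the gloo backend that the referenced issue is about; the launcher invocation, process count, and script name are assumptions for illustration.

```python
# Minimal sketch, not from this PR: an fp16 all-reduce on the gloo backend.
# Launch with e.g. `torchrun --nproc_per_node=2 fp16_allreduce_demo.py`
# (the script name and process count are illustrative).
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo")  # gloo handles the CPU path
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank contributes a half-precision tensor scaled so the summed
    # result stays inside the fp16 range (max finite value is 65504).
    torch.manual_seed(rank)
    t = ((torch.rand(200, 1) * 2 - 1) * 65504 / world_size).to(torch.float16)

    # Before the gloo-side change (pytorch/gloo#446), this reduction could
    # show the numerical issues reported in pytorch/pytorch#152300.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: reduced tensor norm = {t.float().norm().item():.3f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```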
fduwjj 2025-06-03 20:54:37 -07:00 committed by PyTorch MergeBot
parent 9eb7e67727
commit 956716880f
3 changed files with 27 additions and 3 deletions


@@ -480,6 +480,30 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             result[0],
         )
+
+        # Test fp16 numerical correctness for all-reduce SUM.
+        torch.manual_seed(self.rank)
+        # TODO: when creating larger tensors, numerical instability is observed.
+        # We need to investigate the root cause and ensure it is fixed.
+        tensor = (
+            (torch.rand(200, 1, dtype=torch.float32) * 2 - 1) * 65504 / self.world_size
+        )
+        opts = c10d.AllreduceOptions()
+        tensor = tensor.to(torch.float16)
+        output = [[torch.zeros_like(tensor) for _ in range(self.world_size)]]
+        # All-gather all local tensors first and then sum them up.
+        fut = pg.allgather(output, [tensor]).get_future()
+        fut.wait()
+        ag_result = fut.value()
+        total = torch.stack(ag_result, dim=0).sum(dim=0)
+        # Result from the fp16 all-reduce.
+        fut = pg.allreduce([tensor], opts).get_future()
+        fut.wait()
+        result_fp16 = fut.value()
+        # float16 has only ~11 bits of mantissa and is sensitive to accumulation
+        # order and rounding errors, so we use a larger tolerance.
+        self.assertEqual(total, result_fp16[0], rtol=1e-2, atol=1e-3)
 
     @requires_gloo()
     def test_allreduce_basics(self):
         self._test_allreduce_basics(lambda t: t.clone())
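The wide tolerance in the new test reflects how few significand bits float16 has (~11), so the value of a sum shifts with accumulation order and rounding. A standalone sketch (not part of the PR; the shapes and simulated world size are assumptions) that shows the effect without any process group:

```python
# Standalone sketch, not from this PR: fp16 sums depend on accumulation
# order and rounding, which is why the test above compares against an
# allgather-then-sum reference with a loose tolerance.
import torch

torch.manual_seed(0)
world_size = 4
# Simulate each rank's contribution, scaled so the final sum stays within
# the fp16 range (max finite value is 65504).
shards = [
    ((torch.rand(200, 1) * 2 - 1) * 65504 / world_size).to(torch.float16)
    for _ in range(world_size)
]

# Reference: accumulate in float32, then round to fp16 once at the end.
ref = torch.stack([s.float() for s in shards]).sum(dim=0).to(torch.float16)

# Sequential fp16 accumulation, similar in spirit to reducing rank by rank.
acc = shards[0].clone()
for s in shards[1:]:
    acc += s  # each addition rounds to the nearest fp16 value

max_err = (acc.float() - ref.float()).abs().max().item()
print(f"max abs difference between fp16 and fp32-accumulated sums: {max_err}")
```

On typical runs the per-element error is small but nonzero, which is why the test compares against the allgather-then-sum reference with rtol=1e-2 and atol=1e-3 rather than exact equality.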

third_party/gloo vendored (2 changes)

@@ -1 +1 @@
-Subproject commit fe67c4bea940a117ff539d23f4110efc19404edb
+Subproject commit c7b7b022c124d9643957d9bd55f57ac59fce8fa2


@@ -26,7 +26,7 @@
       func<double>(__VA_ARGS__);        \
       break;                            \
     case ::at::ScalarType::Half:        \
-      func<gloo::float16>(__VA_ARGS__); \
+      func<c10::Half>(__VA_ARGS__);     \
       break;                            \
     case ::at::ScalarType::BFloat16:    \
       func<c10::BFloat16>(__VA_ARGS__); \
@@ -59,7 +59,7 @@
       func<double>(args);               \
       break;                            \
     case ::at::ScalarType::Half:        \
-      func<gloo::float16>(args);        \
+      func<c10::Half>(args);            \
       break;                            \
     case ::at::ScalarType::BFloat16:    \
       func<c10::BFloat16>(args);        \