[c10d][gloo] Enable using c10::Half for gloo (#153862)
Testing with https://github.com/pytorch/gloo/pull/446 shows that the numerical issues reported in https://github.com/pytorch/pytorch/issues/152300 are indeed resolved, and a unit test has been added for them. Also update the gloo submodule to reflect the corresponding change on the gloo side.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153862
Approved by: https://github.com/d4l3k, https://github.com/clee2000, https://github.com/malfet
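For context on the new unit test below: it scales random inputs by 65504 / world_size before casting to float16. Here is a minimal standalone sketch (an illustration, not part of this commit) of why that scaling matters; 65504 is the largest finite float16 value, so an unscaled SUM across ranks can overflow:

import torch

# 65504 is the largest finite float16 value.
big = torch.tensor([60000.0, 60000.0], dtype=torch.float16)

# Summing values near the fp16 maximum overflows to inf...
print(big.sum())  # tensor(inf, dtype=torch.float16)

# ...so dividing each rank's contribution by world_size keeps an
# all-reduce SUM across world_size ranks inside the finite range.
world_size = 4
scaled = big / world_size
print(scaled.sum())  # tensor(30000., dtype=torch.float16)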
This commit is contained in:
parent
9eb7e67727
commit
956716880f
@@ -480,6 +480,30 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             result[0],
         )
 
+        # Test fp16 numerical correctness for all-reduce SUM.
+        torch.manual_seed(self.rank)
+        # TODO: when creating larger tensors, numerical instability will be observed.
+        # We need to investigate the root cause and ensure it is fixed.
+        tensor = (
+            (torch.rand(200, 1, dtype=torch.float32) * 2 - 1) * 65504 / self.world_size
+        )
+        opts = c10d.AllreduceOptions()
+        tensor = tensor.to(torch.float16)
+        output = [[torch.zeros_like(tensor) for _ in range(self.world_size)]]
+        # allgather all local tensors first and then sum up.
+        fut = pg.allgather(output, [tensor]).get_future()
+        fut.wait()
+        ag_result = fut.value()
+        total = torch.stack(ag_result, dim=0).sum(dim=0)
+
+        # Result from fp16 all-reduce.
+        fut = pg.allreduce([tensor], opts).get_future()
+        fut.wait()
+        result_fp16 = fut.value()
+        # float16 has only ~11 bits of mantissa and is sensitive to accumulation
+        # order and rounding errors, so we use a larger tolerance.
+        self.assertEqual(total, result_fp16[0], rtol=1e-2, atol=1e-3)
+
     @requires_gloo()
     def test_allreduce_basics(self):
         self._test_allreduce_basics(lambda t: t.clone())
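The test above verifies the fp16 all-reduce against an allgather-then-sum reference with a loose tolerance. As an illustrative single-process analogue (world_size, shapes, and the pairwise order here are assumptions, not the commit's code), different reduction orders over the same fp16 data really do round differently:

import torch

world_size = 4
torch.manual_seed(0)
# One fp16 tensor per simulated rank, scaled so the total stays in fp16 range.
ranks = [
    ((torch.rand(200, 1) * 2 - 1) * 65504 / world_size).to(torch.float16)
    for _ in range(world_size)
]

# Reference, as in the test: "allgather" every rank's tensor, then reduce once.
reference = torch.stack(ranks, dim=0).sum(dim=0)

# Stand-in for an allreduce that happens to combine ranks pairwise instead.
pairwise = (ranks[0] + ranks[1]) + (ranks[2] + ranks[3])

# The two orders round differently, so exact equality is too strict for fp16;
# hence the test's loose rtol/atol rather than a bitwise comparison.
print((reference - pairwise).abs().max().item())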
third_party/gloo (vendored)
@@ -1 +1 @@
-Subproject commit fe67c4bea940a117ff539d23f4110efc19404edb
+Subproject commit c7b7b022c124d9643957d9bd55f57ac59fce8fa2
@@ -26,7 +26,7 @@
     func<double>(__VA_ARGS__);        \
     break;                            \
   case ::at::ScalarType::Half:        \
-    func<gloo::float16>(__VA_ARGS__); \
+    func<c10::Half>(__VA_ARGS__);     \
     break;                            \
   case ::at::ScalarType::BFloat16:    \
     func<c10::BFloat16>(__VA_ARGS__); \
@@ -59,7 +59,7 @@
     func<double>(args);            \
     break;                         \
   case ::at::ScalarType::Half:     \
-    func<gloo::float16>(args);     \
+    func<c10::Half>(args);         \
     break;                         \
   case ::at::ScalarType::BFloat16: \
     func<c10::BFloat16>(args);     \
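Both hunks edit the Half branch of a switch that instantiates a templated functor per at::ScalarType, swapping gloo::float16 for c10::Half. As a rough, language-shifted sketch of that dispatch pattern (the names and mapping here are illustrative, not the PyTorch source):

import torch

# Illustrative analogue of the GENERATE_ALL_TYPES-style switch being edited:
# map a runtime dtype to the concrete C++ element type a kernel would be
# instantiated with. The Half entry is what this commit changes.
DTYPE_TO_CPP_TYPE = {
    torch.float32: "float",
    torch.float64: "double",
    torch.float16: "c10::Half",       # previously "gloo::float16"
    torch.bfloat16: "c10::BFloat16",
}

def dispatch(kernel, tensor, *args):
    cpp_type = DTYPE_TO_CPP_TYPE.get(tensor.dtype)
    if cpp_type is None:
        raise TypeError(f"unsupported dtype: {tensor.dtype}")
    # In the C++ macro this is func<cpp_type>(args...); here we just tag it.
    return kernel(cpp_type, tensor, *args)

Presumably, instantiating with c10::Half keeps the half-precision element type consistent with the rest of ATen, which the linked gloo PR makes possible on the gloo side.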