[NCCL] Patch bfloat16 support (#67843)
Summary: Patch bfloat16 support in NCCL. PR https://github.com/pytorch/pytorch/issues/63260 added bfloat16 support, but it is still not complete enough to enable bfloat16 allreduce in end-to-end training. This patch does the following:

* Raise the minimum NCCL version for bfloat16 from 2.9.7 to 2.10; NCCL added bf16 support in v2.10.3-1 (commit 7e51592).
* Update the bfloat16 datatype flag in `csrc/cuda/nccl.cpp` so that NCCL operations such as allreduce can use it.
* Enable unit tests for the bfloat16 datatype where possible.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67843
Reviewed By: H-Huang
Differential Revision: D32248132
Pulled By: mrshenli
fbshipit-source-id: 081e96e725af3b933dd65ec157c5ad11c6873525
Parent: 45ac6f2b65
Commit: c7eaec86f0
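For context, the capability gate this patch standardizes on (CUDA 11+, an NCCL build at 2.10 or newer, NCCL backend available) can be probed from Python roughly as in the sketch below. This is illustrative only; the helper name `bf16_allreduce_supported` is invented here and is not part of the patch.

```python
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist


def bf16_allreduce_supported() -> bool:
    """Hypothetical helper mirroring the gating conditions this patch converges on."""
    # Need a CUDA device and the NCCL backend at all.
    if not (torch.cuda.is_available() and dist.is_available() and dist.is_nccl_available()):
        return False
    # bf16 needs CUDA 11+. (ROCm builds are gated separately in the patch via
    # TEST_WITH_ROCM / TORCH_HIP_VERSION.)
    if torch.version.cuda is None or int(torch.version.cuda.split(".")[0]) < 11:
        return False
    # NCCL gained bf16 support in 2.10; version() returns a tuple of ints here.
    return nccl.version() >= (2, 10)


if __name__ == "__main__":
    print(bf16_allreduce_supported())
```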
@@ -1626,7 +1626,7 @@ class DistributedDataParallelTest(
             not TEST_WITH_ROCM
             and BFLOAT16_AVAILABLE
             and c10d.is_nccl_available()
-            and torch.cuda.nccl.version() >= (2, 9, 7)
+            and torch.cuda.nccl.version() >= (2, 10)
         ):
             hook_options.append(default.bf16_compress_hook)
         for hook in hook_options:
@@ -1797,7 +1797,7 @@ class DistributedDataParallelTest(
         self._test_fp16_compress_wrapper()

     @requires_nccl()
-    @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS")
+    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
     @sandcastle_skip_if(
         not BFLOAT16_AVAILABLE,
         "BFloat16 is only supported by CUDA 11+",
@@ -1907,7 +1907,7 @@ class DistributedDataParallelTest(
         self._test_fp16_compress_wrapper(gradient_as_bucket_view=True)

     @requires_nccl()
-    @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS")
+    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
     @sandcastle_skip_if(
         not BFLOAT16_AVAILABLE,
         "BFloat16 is only supported by CUDA 11+",
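`requires_nccl_version` and `sandcastle_skip_if` above are PyTorch-internal test decorators. Outside the PyTorch test suite, the same kind of gate can be written with plain `unittest`, as in this rough sketch (the test body is a placeholder, not code from this patch):

```python
import unittest

import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist

# Rough stand-in for the BFLOAT16_AVAILABLE / requires_nccl_version gating above.
# The `and` chain short-circuits, so nccl.version() is only called on NCCL builds.
BF16_NCCL_OK = (
    torch.cuda.is_available()
    and torch.version.cuda is not None
    and int(torch.version.cuda.split(".")[0]) >= 11
    and dist.is_available()
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)


class BF16CompressGateTest(unittest.TestCase):
    @unittest.skipIf(not BF16_NCCL_OK, "Need CUDA 11+ and NCCL 2.10+ for BF16_COMPRESS")
    def test_bf16_requirements(self):
        # Placeholder body: just re-asserts the gating condition.
        self.assertGreaterEqual(nccl.version(), (2, 10))


if __name__ == "__main__":
    unittest.main()
```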
@@ -4,12 +4,13 @@ import sys
 import torch
 import torch.cuda.nccl as nccl
 import torch.cuda
+import torch.distributed as c10d

 from torch.testing._internal.common_utils import (TestCase, run_tests,
                                                   IS_WINDOWS, load_tests,
                                                   TEST_WITH_ROCM,
                                                   sandcastle_skip_if)
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import CUDA11OrLater, TEST_CUDA, TEST_MULTIGPU
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
 import re
 HIP_VERSION = 0.0 if torch.version.hip is None else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
@@ -24,7 +25,9 @@ if not TEST_CUDA:
     TestCase = object  # noqa: F811


-datatypes = [torch.float, torch.bfloat16] if TEST_WITH_ROCM else [torch.float]
+datatypes = [torch.float]
+if (TEST_CUDA and CUDA11OrLater and c10d.is_nccl_available() and nccl.version() >= (2, 10)) or TEST_WITH_ROCM:
+    datatypes.append(torch.bfloat16)

 class TestNCCL(TestCase):

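With the datatype list now gated at runtime as above, the single-process multi-GPU path these tests exercise reduces to something like the following sketch. It assumes a machine with at least two CUDA devices and an NCCL build at 2.10 or newer, and is not part of the patch itself:

```python
import torch
import torch.cuda.nccl as nccl

# One bf16 tensor per device; torch.cuda.nccl collectives operate on such lists.
tensors = [
    torch.full((1024,), float(i + 1), dtype=torch.bfloat16, device=f"cuda:{i}")
    for i in range(2)
]

# In-place sum-allreduce across the two devices from a single process.
nccl.all_reduce(tensors)

# Both tensors now hold 1.0 + 2.0 = 3.0 everywhere.
print(tensors[0][:4], tensors[1][:4])
```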
@@ -90,7 +90,7 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
       return ncclDataType_t::ncclUint8;
     case at::kBool:
       return ncclDataType_t::ncclUint8;
-#if defined(USE_ROCM) && TORCH_HIP_VERSION >= 301
+#if HAS_NCCL_BF16_DATATYPE
     case at::kBFloat16:
       return ncclDataType_t::ncclBfloat16;
 #endif
@@ -8,6 +8,16 @@
 #include <cstddef>
 #include <vector>

+// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for HIP 3.1+
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+#define HAS_NCCL_BF16_DATATYPE \
+  ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
+#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
+#define HAS_NCCL_BF16_DATATYPE 1
+#else
+#define HAS_NCCL_BF16_DATATYPE 0
+#endif
+
 namespace torch {
 namespace cuda {
 namespace nccl {
@@ -52,7 +62,9 @@ enum class ncclDataType {
   Float16 = 6, Half = 6,
   Float32 = 7, Float = 7,
   Float64 = 8, Double = 8,
-  numTypes = 9 };
+  Bfloat16 = 9,
+  NumTypes = 10
+};



@@ -56,15 +56,6 @@ const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::
 #define ENABLE_NCCL_P2P_SUPPORT
 #endif

-// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.9.7+
-#if (defined(__CUDA_BF16_TYPES_EXIST__) && \
-    defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
-    (defined(NCCL_MINOR) && ((NCCL_MINOR > 9) || \
-    ((NCCL_MINOR == 9) && defined(NCCL_PATCH) && (NCCL_PATCH >= 7))))) || \
-    (defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301))
-#define ENABLE_NCCL_BF16_DATATYPE
-#endif
-
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do { \
@@ -77,7 +77,7 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
     {at::kBool, ncclUint8},
-#if defined(ENABLE_NCCL_BF16_DATATYPE)
+#if HAS_NCCL_BF16_DATATYPE
     {at::kBFloat16, ncclBfloat16},
 #endif
 };
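On the c10d side, this map entry is what lets an ordinary `dist.all_reduce` accept bf16 tensors on the NCCL backend. A minimal sketch, assuming it is launched with `torchrun --nproc_per_node=2` on a machine with one GPU per rank; not part of the patch itself:

```python
import torch
import torch.distributed as dist


def main() -> None:
    # torchrun sets the rendezvous environment variables for us.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Each rank contributes rank + 1; after the sum-allreduce every rank holds the total.
    t = torch.full((8,), float(rank + 1), dtype=torch.bfloat16, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: {t[0].item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```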
@@ -1545,10 +1545,10 @@ class DistributedDataParallel(Module, Joinable):
                 or int(torch.version.cuda.split('.')[0]) < 11
                 or not dist.is_available()
                 or not dist.is_nccl_available()
-                or torch.cuda.nccl.version() < (2, 9, 7)
+                or torch.cuda.nccl.version() < (2, 10)
             )
         ):
-            self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.9.7+.")
+            self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.")

     @property
     def _distributed_rank(self):
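This is the user-facing guard: registering the bf16 compression hook on a DDP model only works when the requirements above hold, otherwise `_check_comm_hook` raises the TypeError shown in the diff. A rough usage sketch, again assuming a `torchrun` launch with one GPU per rank; the simplified version guard here is illustrative, not the exact check DDP performs:

```python
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel as DDP


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = DDP(nn.Linear(32, 32).cuda(), device_ids=[rank])

    # Simplified guard: CUDA 11+ and NCCL 2.10+, the same thresholds this patch enforces.
    if (
        torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 11
        and nccl.version() >= (2, 10)
    ):
        # Gradients are compressed to bf16 before the allreduce and decompressed after.
        model.register_comm_hook(state=None, hook=default.bf16_compress_hook)

    out = model(torch.randn(16, 32, device="cuda"))
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```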