[NCCL] Patch bfloat16 support (#67843)

Summary:
Patch bfloat16 support in NCCL. PR https://github.com/pytorch/pytorch/issues/63260 added initial bfloat16 support, but it is not yet sufficient to enable bfloat16 allreduce in end-to-end training.

This patch does the following:
* raise the minimum NCCL version required for bfloat16 from 2.9.7 to 2.10, since
  NCCL added bf16 support in v2.10.3-1 (commit 7e51592)
* update the bfloat16 datatype flag in `csrc/cuda/nccl.cpp` so that NCCL
  operations such as allreduce can use it
* enable unit tests for the bfloat16 datatype where it is supported (the
  availability check is sketched below)
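
For reference, a minimal sketch of the availability check this patch keys on. This is not part of the patch itself; the helper name `nccl_bf16_available` is made up for illustration, and ROCm builds (which gate on HIP 3.1+ instead) are not handled here.

```python
import torch
import torch.cuda.nccl
import torch.distributed as dist


def nccl_bf16_available() -> bool:
    # bf16 collectives need a CUDA build, the NCCL backend, CUDA 11+, and NCCL >= 2.10.
    if not (torch.cuda.is_available() and dist.is_available() and dist.is_nccl_available()):
        return False
    if torch.version.cuda is None or int(torch.version.cuda.split(".")[0]) < 11:
        return False
    return torch.cuda.nccl.version() >= (2, 10)
```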

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67843

Reviewed By: H-Huang

Differential Revision: D32248132

Pulled By: mrshenli

fbshipit-source-id: 081e96e725af3b933dd65ec157c5ad11c6873525
Author: Yifan Xiong, 2021-11-09 13:44:11 -08:00 (committed by Facebook GitHub Bot)
Parent: 45ac6f2b65
Commit: c7eaec86f0
7 changed files with 25 additions and 19 deletions


@@ -1626,7 +1626,7 @@ class DistributedDataParallelTest(
             not TEST_WITH_ROCM
             and BFLOAT16_AVAILABLE
             and c10d.is_nccl_available()
-            and torch.cuda.nccl.version() >= (2, 9, 7)
+            and torch.cuda.nccl.version() >= (2, 10)
         ):
             hook_options.append(default.bf16_compress_hook)
         for hook in hook_options:
@@ -1797,7 +1797,7 @@ class DistributedDataParallelTest(
         self._test_fp16_compress_wrapper()
 
     @requires_nccl()
-    @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS")
+    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
     @sandcastle_skip_if(
         not BFLOAT16_AVAILABLE,
         "BFloat16 is only supported by CUDA 11+",
@@ -1907,7 +1907,7 @@ class DistributedDataParallelTest(
         self._test_fp16_compress_wrapper(gradient_as_bucket_view=True)
 
     @requires_nccl()
-    @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS")
+    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
     @sandcastle_skip_if(
         not BFLOAT16_AVAILABLE,
         "BFloat16 is only supported by CUDA 11+",


@@ -4,12 +4,13 @@ import sys
 import torch
 import torch.cuda.nccl as nccl
 import torch.cuda
+import torch.distributed as c10d
 from torch.testing._internal.common_utils import (TestCase, run_tests,
                                                   IS_WINDOWS, load_tests,
                                                   TEST_WITH_ROCM,
                                                   sandcastle_skip_if)
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import CUDA11OrLater, TEST_CUDA, TEST_MULTIGPU
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
 import re
 
 HIP_VERSION = 0.0 if torch.version.hip is None else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
@@ -24,7 +25,9 @@ if not TEST_CUDA:
     TestCase = object  # noqa: F811
 
-datatypes = [torch.float, torch.bfloat16] if TEST_WITH_ROCM else [torch.float]
+datatypes = [torch.float]
+if (TEST_CUDA and CUDA11OrLater and c10d.is_nccl_available() and nccl.version() >= (2, 10)) or TEST_WITH_ROCM:
+    datatypes.append(torch.bfloat16)
 
 
 class TestNCCL(TestCase):


@@ -90,7 +90,7 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
       return ncclDataType_t::ncclUint8;
     case at::kBool:
       return ncclDataType_t::ncclUint8;
-#if defined(USE_ROCM) && TORCH_HIP_VERSION >= 301
+#if HAS_NCCL_BF16_DATATYPE
     case at::kBFloat16:
       return ncclDataType_t::ncclBfloat16;
 #endif


@@ -8,6 +8,16 @@
 #include <cstddef>
 #include <vector>
 
+// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for HIP 3.1+
+#if defined(__CUDA_BF16_TYPES_EXIST__)
+#define HAS_NCCL_BF16_DATATYPE \
+  ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
+#elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
+#define HAS_NCCL_BF16_DATATYPE 1
+#else
+#define HAS_NCCL_BF16_DATATYPE 0
+#endif
+
 namespace torch {
 namespace cuda {
 namespace nccl {
@@ -52,7 +62,9 @@ enum class ncclDataType {
   Float16 = 6, Half = 6,
   Float32 = 7, Float = 7,
   Float64 = 8, Double = 8,
-  numTypes = 9 };
+  Bfloat16 = 9,
+  NumTypes = 10
+};


@@ -56,15 +56,6 @@ const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::
 #define ENABLE_NCCL_P2P_SUPPORT
 #endif
 
-// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.9.7+
-#if (defined(__CUDA_BF16_TYPES_EXIST__) && \
-    defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
-    (defined(NCCL_MINOR) && ((NCCL_MINOR > 9) || \
-    ((NCCL_MINOR == 9) && defined(NCCL_PATCH) && (NCCL_PATCH >= 7))))) || \
-    (defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301))
-#define ENABLE_NCCL_BF16_DATATYPE
-#endif
-
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do { \


@@ -77,7 +77,7 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
     {at::kBool, ncclUint8},
-#if defined(ENABLE_NCCL_BF16_DATATYPE)
+#if HAS_NCCL_BF16_DATATYPE
     {at::kBFloat16, ncclBfloat16},
 #endif
 };
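
With `at::kBFloat16` mapped to `ncclBfloat16`, a plain bfloat16 allreduce goes straight through NCCL. A minimal sketch, not part of this diff, assuming a NCCL process group is initialized and NCCL >= 2.10 with CUDA 11+ is available:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already been called.
t = torch.ones(1024, dtype=torch.bfloat16, device="cuda")
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # reduced natively in bfloat16
```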


@@ -1545,10 +1545,10 @@ class DistributedDataParallel(Module, Joinable):
                 or int(torch.version.cuda.split('.')[0]) < 11
                 or not dist.is_available()
                 or not dist.is_nccl_available()
-                or torch.cuda.nccl.version() < (2, 9, 7)
+                or torch.cuda.nccl.version() < (2, 10)
             )
         ):
-            self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.9.7+.")
+            self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.")
 
     @property
     def _distributed_rank(self):