Enable BFloat support for gemms on arch other than ampere (#50442)

Summary:
Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50442

Reviewed By: bdhirsh

Differential Revision: D26044981

Pulled By: mruberry

fbshipit-source-id: 65c42f2c1de8d24e4852a1b5bd8f4b1735b2230e
Author: Xiang Gao
Committed by: Facebook GitHub Bot
Date: 2021-01-26 11:00:14 -08:00
Parent: 3562ca2da2
Commit: b822aba8ec

4 changed files with 74 additions and 64 deletions
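
In short: the hard "Ampere or later" checks are removed from the cuBLAS BFloat16 gemm/bgemm paths, and test coverage is instead gated on the CUDA toolkit version and the device's compute capability. A rough sketch of the resulting support condition, for illustration only (the helper name is made up; the real flags are added to common_cuda.py later in this diff):

    from distutils.version import LooseVersion

    import torch

    # Conditions the updated tests key off: cuBLAS BFloat16 gemms need
    # CUDA 11+ and a GPU with compute capability >= 5.3; ROCm has its own path.
    cuda11_or_later = torch.version.cuda is not None and LooseVersion(torch.version.cuda) >= "11.0.0"
    sm53_or_later = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)

    def bfloat16_gemm_expected() -> bool:  # illustrative helper, not PyTorch API
        return torch.version.hip is not None or (cuda11_or_later and sm53_or_later)

    print("BFloat16 gemm expected to work:", bfloat16_gemm_expected())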

--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp

@@ -327,7 +327,6 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
-  TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
   TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
                                   opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
@@ -343,7 +342,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
                                    (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                    0, 0, NULL, NULL));
 #else
-  TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
+  TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later");
 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 }
 #endif // __HIP_PLATFORM_HCC__
@@ -550,11 +549,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
   GEMM_CHECK_ARGVALUES(at::BFloat16);
-  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
-  if (prop->major >= 8) {
-    // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-    // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
-    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   TORCH_CUDABLAS_CHECK(cublasGemmEx(
       handle,
       opa,
@@ -575,12 +569,6 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DFALT_TENSOR_OP));
-    // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
-    // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
-    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-  } else {
-    TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
-  }
 }
 #endif
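
With the guards above removed, a BFloat16 gemm on an unsupported CUDA configuration (CUDA 11+ but compute capability below 5.3) now surfaces whatever error cuBLAS reports; the updated tests below therefore match "CUBLAS_STATUS_NOT_SUPPORTED" among other patterns instead of the old "Ampere" message. A hedged sketch of how calling code could probe for this at runtime (the helper is illustrative, not something this PR adds):

    import torch

    def bfloat16_mm_or_none(a: torch.Tensor, b: torch.Tensor):
        # Try the BFloat16 matmul and report None if the local GPU/cuBLAS
        # combination rejects it. The error text is not a stable API, so the
        # match is deliberately broad, as in the updated tests.
        try:
            return torch.mm(a.bfloat16(), b.bfloat16())
        except RuntimeError as e:
            if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e) or "not implemented" in str(e):
                return None
            raise

    if torch.cuda.is_available():
        x = torch.randn(4, 4, device="cuda")
        y = torch.randn(4, 4, device="cuda")
        print("BFloat16 mm supported:", bfloat16_mm_or_none(x, y) is not None)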

--- a/test/test_linalg.py
+++ b/test/test_linalg.py

@@ -24,7 +24,7 @@ from torch.testing._internal.common_device_type import \
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA,
     onlyCUDA)
 from torch.testing import floating_and_complex_types, floating_types, all_types
-from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32
+from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
 from torch.autograd import gradcheck, gradgradcheck

 # Protects against includes accidentally setting the default dtype
@@ -35,9 +35,6 @@ assert torch.get_default_dtype() is torch.float32
 if TEST_SCIPY:
     import scipy

-# TODO: make this common and import it
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
-
 # See #49409, we should remove these if we end up with a global gradcheck setting
 gradcheck = partial(gradcheck, check_batched_grad=True)
 gradgradcheck = partial(gradgradcheck, check_batched_grad=True)
@@ -3842,12 +3839,12 @@ class TestLinalg(TestCase):
         self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])

     @skipCUDAIfRocm
-    @dtypesIfCUDA(*(torch.float, torch.double, torch.cfloat, torch.cdouble) +
-                  # This test is disabled on CUDA 9, due to:
-                  # See: https://github.com/pytorch/pytorch/issues/31006
-                  ((torch.half,) if torch.version.cuda and not torch.version.cuda.startswith('9.') else ()))
+    @dtypesIfCUDA(torch.cfloat, torch.cdouble,
+                  *torch.testing.get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater)))
     @dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool}))
     def test_blas_alpha_beta_empty(self, device, dtype):
+        # This test is disabled on CUDA 9 due to:
+        # See: https://github.com/pytorch/pytorch/issues/31006
         if dtype is torch.bfloat16 and self.device_type == 'xla':
             # TODO (@zasdfgbnm): this causes the following error on test
             # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16:
@@ -4467,8 +4464,8 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
     @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
                         torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(),
-                  *([torch.float32, torch.float64, torch.bfloat16]
-                    if TEST_WITH_ROCM else torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)))
+                  *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)),
+                                                   include_half=(not TEST_WITH_ROCM)))
     @dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_addmv(self, device, dtype):
         # have to use torch.randn(...).to(bfloat16) instead of
@@ -4502,8 +4499,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         for m, v in itertools.product(ms, vs):
             self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

-    @dtypesIfCUDA(*([torch.half, torch.float, torch.double]
-                    + ([torch.bfloat16] if TEST_WITH_ROCM else [])))
+    @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
     @dtypes(torch.float, torch.double)
     def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
         # tests (o, s)*(s). o is output size, s is summed size.
@@ -4534,7 +4530,8 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
-    @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
+    @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(),
+                  *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
     @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
     @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
@@ -4709,19 +4706,25 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
     @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
     @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
+        if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
+            # cuBLAS does not guarantee BFloat16 support on SM < 53.
+            # So in PyTorch, we consider BFloat16 support on SM < 53 as
+            # undefined behavior.
+            return
         num_batches = 10
         M, N, O = 23, 8, 12
         numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

-        if self.device_type == 'cpu':
-            is_supported = True
-        elif self.device_type == 'cuda':
-            is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM
+        is_supported = True
+        if dtype == torch.bfloat16 and self.device_type == 'cuda':
+            is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

         if not is_supported:
             b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
             b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
-            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
+            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
+                                   lambda: torch.bmm(b1, b2))
             return

         def invert_perm(p):
@@ -4884,21 +4887,28 @@ else:
     @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
     @tf32_on_and_off(0.05)
     def test_addbmm(self, device, dtype):
+        if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
+            # cuBLAS does not guarantee BFloat16 support on SM < 53.
+            # So in PyTorch, we consider BFloat16 support on SM < 53 as
+            # undefined behavior.
+            return
         num_batches = 2
         M, N, O = 2, 3, 4

-        if self.device_type == 'cpu':
-            is_supported = True
-            if dtype == torch.bfloat16:
-                self.precision = 1  # 43 vs 43.75
-        else:
-            is_supported = (dtype != torch.bfloat16 or AMPERE_OR_ROCM)
+        is_supported = True
+        if dtype == torch.bfloat16:
+            if self.device_type == 'cpu':
+                self.precision = 1  # 43 vs 43.75
+            else:
+                is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

         if not is_supported:
             b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
             b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
             t = make_tensor((M, O), device, dtype, low=-1, high=1)
-            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.addbmm(t, b1, b2))
+            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
+                                   lambda: torch.addbmm(t, b1, b2))
             return

         def invert_perm(p):
@@ -4950,19 +4960,25 @@ else:
     @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
     @tf32_on_and_off(0.05)
     def test_baddbmm(self, device, dtype):
+        if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
+            # cuBLAS does not guarantee BFloat16 support on SM < 53.
+            # So in PyTorch, we consider BFloat16 support on SM < 53 as
+            # undefined behavior.
+            return
         num_batches = 10
         M, N, O = 12, 8, 5

-        if self.device_type == 'cpu':
-            is_supported = True
-        elif self.device_type == 'cuda':
-            is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM
+        is_supported = True
+        if dtype == torch.bfloat16 and self.device_type == 'cuda':
+            is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

         if not is_supported:
             b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
             b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
             t = make_tensor((num_batches, M, O), device, dtype, low=-1, high=1)
-            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.baddbmm(t, b1, b2))
+            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
+                                   lambda: torch.baddbmm(t, b1, b2))
             return

         def invert_perm(p):
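
The same support condition is now inlined in test_bmm, test_addbmm, and test_baddbmm. A hypothetical helper that captures the shared check (not part of this commit, shown only to summarize the pattern):

    from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
    from torch.testing._internal.common_utils import TEST_WITH_ROCM

    def cuda_bfloat16_gemm_supported() -> bool:
        # Mirrors the inline checks above: ROCm, or CUDA 11+ on SM 5.3 or newer.
        return bool(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))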

--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py

@@ -6,6 +6,7 @@ import torch.cuda
 from torch.testing._internal.common_utils import TEST_NUMBA
 import inspect
 import contextlib
+from distutils.version import LooseVersion

 TEST_CUDA = torch.cuda.is_available()
@@ -15,6 +16,10 @@ CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
 TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
 TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0

+CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0.0"
+CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
+SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
+
 TEST_MAGMA = TEST_CUDA
 if TEST_CUDA:
     torch.ones(1).cuda()  # has_magma shows up after cuda is initialized
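
A short aside (not part of the diff) on why CUDA11OrLater compares via LooseVersion rather than comparing version strings directly:

    from distutils.version import LooseVersion

    # Plain string comparison is lexicographic and misorders versions once the
    # major version gains a digit; LooseVersion compares numeric components.
    print("10.2" >= "9.2")                              # False (wrong answer)
    print(LooseVersion("10.2") >= LooseVersion("9.2"))  # True
    print(LooseVersion("11.1") >= "11.0.0")             # True, as CUDA11OrLater expects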

--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py

@@ -19,7 +19,7 @@ from torch.testing import \
 from torch.testing._internal.common_device_type import \
     (skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl,
      skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride)
-from torch.testing._internal.common_cuda import tf32_is_not_fp32
+from torch.testing._internal.common_cuda import CUDA11OrLater
 from torch.testing._internal.common_utils import \
     (prod_single_zero, random_square_matrix_of_rank,
      random_symmetric_matrix, random_symmetric_psd_matrix,
@@ -1035,8 +1035,9 @@ op_db: List[OpInfo] = [
     OpInfo('addmm',
            dtypes=floating_types(),
            dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           # BFloat16 support on CUDA requires CUDA 11 and SM53
            dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
-                                           *[torch.bfloat16] if tf32_is_not_fp32() else []),
+                                           *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half),
            assert_autodiffed=True,
            autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],