[CUDA][cuBLASLt] Respect allow[FP16/BF16]ReductionCuBLAS in cuBLASLt (#153095)
cuBLASLt matmuls have been silently allowing all reduction types, which meant that e.g. `allow_fp16_reduced_precision_reduction = False` had no effect. In practice, split-K with reduced-precision reductions was unlikely to happen, as the default `CUBLASLT_WORKSPACE_SIZE` of 1MiB tends to prevent it. However, this isn't guaranteed, and we are on the path to increasing the default workspace size following #151163.

This setting is effectively already tested in e.g. `test_cublas_addmm_size_100_cuda_float16` and `test_cublas_addmm_size_100_cuda_bfloat16`, but the backend selection is not deterministic. Running the full `test_matmul_cuda.py` seems to exercise the Lt interface, but running a standalone test does not (apparently due to spurious alignment differences).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153095
Approved by: https://github.com/cyyever, https://github.com/Skylion007
This commit is contained in:
parent e581e1c0f4
commit 6ae0c42278
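For context before the diff: the user-facing knob that this change makes effective is the matmul reduced-precision flag, and the backend under test is selected via the preferred BLAS library setting. A minimal sketch, assuming a CUDA build of PyTorch:

```python
import torch

# Force the cuBLASLt backend so the Lt path (the one fixed here) is exercised.
torch.backends.cuda.preferred_blas_library("cublaslt")

# Disallow reduced-precision (FP16/BF16) accumulation in split-K reductions.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

a = torch.randn(64, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 64, device="cuda", dtype=torch.float16)
# With the fix, cuBLASLt may now only pick algorithms whose reductions stay
# in FP32 (or that skip split-K entirely).
out = a @ b
```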
@@ -391,6 +391,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
   cudaDataType_t cType = CUDA_R_32F;
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
+  CuBlasLtMatmulPreference preference;
 #ifndef USE_ROCM
   at::Half halpha;
   at::Half hbeta;
@@ -429,9 +430,21 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
 #endif
     abType = CUDA_R_16F;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
+#ifndef USE_ROCM
+    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
+      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
+                              CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    }
+#endif
   } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
     abType = CUDA_R_16BF;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
+#ifndef USE_ROCM
+    if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
+                              CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    }
+#endif
   } else {
     static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
   }
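The `CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK` attribute added above limits the algorithms the cuBLASLt heuristic may return to those that either perform no split-K reduction at all (`CUBLASLT_REDUCTION_SCHEME_NONE`) or carry the reduction out in the compute type (`CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE`, which is FP32 here since `computeType` is `CUBLAS_COMPUTE_32F`). Partial sums are therefore never accumulated in FP16/BF16 once the user has disallowed it. A toy sketch of the hazard being masked off, using plain PyTorch on CPU (no cuBLASLt involved) and hypothetical values chosen so the exact answer is 0:

```python
import torch

# Products that cancel exactly: +65504, +65504, -65504, -65504 (65504 = fp16 max).
vals = torch.tensor([65504.0, 65504.0, -65504.0, -65504.0], dtype=torch.float16)

# Accumulate the partial sums in fp16: the running sum overflows to inf
# after the second term and never recovers.
acc_fp16 = torch.zeros((), dtype=torch.float16)
for v in vals:
    acc_fp16 = acc_fp16 + v

# Accumulate in fp32 (the compute type): the result is exactly 0.
acc_fp32 = vals.to(torch.float32).sum()

print(acc_fp16.item(), acc_fp32.item())  # inf 0.0
```

The new test added to `test_matmul_cuda.py` (last hunk below) builds the same cancellation pattern into an `addmm` so that any FP16 accumulation shows up as a nonzero result.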
@@ -467,8 +480,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
   }
 
-  CuBlasLtMatmulPreference preference;
-
 #ifndef USE_ROCM
   uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
   uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
@@ -1562,6 +1573,7 @@ bool gemm_and_bias(
   cudaDataType_t cType = CUDA_R_32F;
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
+  CuBlasLtMatmulPreference preference;
   void * alpha_ptr = &alpha_val;
   void * beta_ptr = &beta_val;
 #ifndef USE_ROCM
@@ -1591,9 +1603,21 @@ bool gemm_and_bias(
 #endif
     abType = CUDA_R_16F;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16F;
+#ifndef USE_ROCM
+    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
+      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
+                              CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    }
+#endif
   } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
     abType = CUDA_R_16BF;
     cType = (std::is_same_v<C_Dtype, float>) ? CUDA_R_32F : CUDA_R_16BF;
+#ifndef USE_ROCM
+    if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+      preference.setAttribute(CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
+                              CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE);
+    }
+#endif
   }
 
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
@@ -1627,7 +1651,6 @@ bool gemm_and_bias(
   CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
   CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
 
-  CuBlasLtMatmulPreference preference;
   auto ltworkspace = CublasLtWorkspace();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
 
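The `CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES` attribute above is where the workspace mentioned in the commit message comes in: with the default 1MiB `CUBLASLT_WORKSPACE_SIZE`, split-K algorithms (the ones that need a reduction scheme at all) rarely fit, which is why the missing mask went largely unnoticed. A hedged sketch of enlarging the workspace, assuming the environment variable is read lazily and interpreted in KiB as in current builds (details may vary by version):

```python
import os

# Set before the first CUDA matmul so torch picks it up when it sizes the
# cuBLASLt workspace; assumed to be interpreted in KiB in current builds
# (1024 KiB corresponds to the 1MiB default mentioned above).
os.environ["CUBLASLT_WORKSPACE_SIZE"] = str(32 * 1024)  # request ~32 MiB

import torch

torch.backends.cuda.preferred_blas_library("cublaslt")
a = torch.randn(8, 32768, device="cuda", dtype=torch.float16)
b = torch.randn(32768, 8, device="cuda", dtype=torch.float16)
# A larger workspace makes split-K candidates viable, which is exactly the
# situation in which the reduction-scheme mask above matters.
out = a @ b
```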
@@ -154,6 +154,29 @@ class TestMatmulCuda(TestCase):
     def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
         self.cublas_addmm(size, dtype, True)
 
+    @onlyCUDA
+    @skipIfRocmVersionLessThan((5, 2))
+    @dtypes(torch.float16)
+    # m == 4 chooses OUTPUT_TYPE reduction on H200
+    # m == 8 chooses OUTPUT_TYPE reduction on A100
+    @parametrize("small_size", [4, 8])
+    @parametrize("size", [32768])
+    @parametrize("backend", ["cublaslt", "cublas"])
+    def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dtype: torch.dtype, backend):
+        # TODO(eqy): replace with contextlib once that is merged
+        orig = torch.backends.cuda.preferred_blas_library()
+        torch.backends.cuda.preferred_blas_library(backend)
+        orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+        m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
+        m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
+        m2[size // 2:, :] = -1.0
+        b = torch.zeros((small_size,), dtype=dtype, device='cuda')
+        out = torch.addmm(b, m1, m2, beta=1.0)
+        self.assertEqual(out.sum().item(), 0.0)
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
+        torch.backends.cuda.preferred_blas_library(orig)
+
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
     # imported 'tol' as 'xtol' to avoid aliasing in code above
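The `# TODO(eqy)` in the new test refers to wrapping the save/set/restore of the backend and the precision flag in a context manager. A hypothetical helper, not part of this PR (the name and signature are invented here), showing what the test body could use once such a utility exists:

```python
import contextlib
import torch


@contextlib.contextmanager
def blas_library_and_fp16_reduction(backend, allow_reduced_precision):
    # Hypothetical helper: save, override, and restore the preferred BLAS
    # backend and the FP16 reduced-precision-reduction flag around a block.
    orig_backend = torch.backends.cuda.preferred_blas_library()
    orig_flag = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    torch.backends.cuda.preferred_blas_library(backend)
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_reduced_precision
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_flag
        torch.backends.cuda.preferred_blas_library(orig_backend)


# Usage mirroring the new test above:
with blas_library_and_fp16_reduction("cublaslt", False):
    m1 = torch.full((4, 32768), 65504.0, dtype=torch.float16, device="cuda")
    m2 = torch.ones((32768, 4), dtype=torch.float16, device="cuda")
    m2[32768 // 2:, :] = -1.0
    b = torch.zeros((4,), dtype=torch.float16, device="cuda")
    out = torch.addmm(b, m1, m2, beta=1.0)
    assert out.sum().item() == 0.0
```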