Avoid prematurely casting GEMM parameters alpha, beta to scalar_t (#67633)

Summary:
stas00 uncovered an issue where certain half-precision GEMMs produced outputs that looked like the result of strange rounding behavior (e.g., `10008.` in place of `10000.`). ptrblck suspected that the `alpha`/`beta` parameters were being downcast to the input type (manually applying that downcast reproduces the problematic output). Indeed, the GEMM and BGEMM cuBLAS wrappers currently convert `alpha` and `beta` to `scalar_t` (which is potentially a reduced-precision type) before converting them back to `float`. This PR changes the `CUDABLAS_*_ARGTYPES` wrapper macros to take the accumulation type (`at::opmath_type<Dtype>`) instead and adds a corresponding test.
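
To see how a downcast `alpha` alone explains `10008.`, here is a rough, CPU-only illustration (assumption for illustration only: the products are accumulated exactly and `alpha` is applied afterwards, as in the test added below):

    import torch

    # alpha = 1e-3 is not exactly representable in float16.
    alpha_half = torch.tensor(1e-3, dtype=torch.half).item()  # ~0.0010004043579101562
    acc = 100.0 * 100.0 * 1000   # exact value of each output element's dot product: 1.0e7

    print(acc * 1e-3)            # ~10000.0  -- alpha kept in full precision
    print(acc * alpha_half)      # ~10004.04 -- alpha routed through float16 first
    # Storing that back into a half-precision output rounds to the nearest
    # representable value; near 10000 the float16 spacing is 8, so:
    print(torch.tensor(acc * alpha_half, dtype=torch.half))   # tensor(10008., dtype=torch.float16)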

CC ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67633

Reviewed By: mruberry

Differential Revision: D32076474

Pulled By: ngimel

fbshipit-source-id: 2540d9b9d0195c17d07d1161374fb6a5850779d5
Author: Eddie Yan, 2021-11-03 11:55:10 -07:00 (committed by Facebook GitHub Bot)
Parent: 3f33ada8d5
Commit: a5b57c9433
3 changed files with 27 additions and 16 deletions

View File

@@ -14,6 +14,7 @@
  */
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/OpMathType.h>
 
 namespace at {
 namespace cuda {
@@ -40,9 +41,9 @@ private:
 /* LEVEL 3 BLAS FUNCTIONS */
 
-#define CUDABLAS_GEMM_ARGTYPES(Dtype)                                       \
-  char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha,   \
-  const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, Dtype beta,     \
+#define CUDABLAS_GEMM_ARGTYPES(Dtype)                                                        \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,   \
+  const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,     \
   Dtype *c, int64_t ldc
 
 template <typename Dtype>
@@ -69,11 +70,11 @@ template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 #endif
 
-#define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                      \
-  char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha,   \
-  const Dtype *a, int64_t lda, int64_t stridea,                             \
-  const Dtype *b, int64_t ldb, int64_t strideb,                             \
-  Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
+#define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                                       \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,   \
+  const Dtype *a, int64_t lda, int64_t stridea,                                              \
+  const Dtype *b, int64_t ldb, int64_t strideb,                                              \
+  at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
 
 template <typename Dtype>
 inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {

View File

@@ -171,8 +171,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
-    scalar_t alpha_val = alpha.to<scalar_t>();
-    scalar_t beta_val = beta.to<scalar_t>();
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_t alpha_val = alpha.to<opmath_t>();
+    opmath_t beta_val = beta.to<opmath_t>();
     scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
     scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
@@ -240,8 +241,9 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
-    scalar_t alpha_val = alpha.to<scalar_t>();
-    scalar_t beta_val = beta.to<scalar_t>();
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_t alpha_val = alpha.to<opmath_t>();
+    opmath_t beta_val = beta.to<opmath_t>();
     scalar_t* batch1_ptr = batch1_->data_ptr<scalar_t>();
     scalar_t* batch2_ptr = batch2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
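
A hedged usage sketch of the user-visible effect of the change above (requires a CUDA device; on ROCm the updated test below only checks for the absence of overflow, so exact equality is not guaranteed there):

    import torch

    if torch.cuda.is_available():
        inp = torch.zeros(128, 128, dtype=torch.half, device="cuda")
        mat1 = torch.full((128, 1000), 100., dtype=torch.half, device="cuda")
        mat2 = torch.full((1000, 128), 100., dtype=torch.half, device="cuda")

        out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
        # With alpha/beta kept in the opmath (float) type, this well-conditioned
        # half-precision case should now match a float32 reference exactly.
        ref = torch.addmm(inp.float(), mat1.float(), mat2.float(), alpha=0.001, beta=0.)
        print(torch.equal(out.float(), ref))  # expected: True (previously 10008. vs 10000.)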

View File

@@ -6127,17 +6127,25 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
     @dtypes(torch.half)
     @onlyCUDA
-    def test_addmm_addbmm_overflow(self, device, dtype):
+    def test_addmm_baddbmm_overflow(self, device, dtype):
         inp = torch.zeros(128, 128, dtype=torch.half, device=device)
         mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
         mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
         out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
-        self.assertFalse(out.isinf().any())
+        # just check for no overflow on ROCM
+        if TEST_WITH_ROCM:
+            self.assertFalse(out.isinf().any())
+        else:
+            self.assertTrue((out == 10000.).all())
         inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
         mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
         mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
-        out = torch.addbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
-        self.assertFalse(out.isinf().any())
+        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
+        if TEST_WITH_ROCM:
+            self.assertFalse(out.isinf().any())
+        else:
+            self.assertTrue((out == 10000.).all())
 
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA