Avoid prematurely casting GEMM parameters alpha, beta to scalar_t (#67633)
Summary:
stas00 uncovered an issue where certain half-precision GEMMs produced outputs that looked like the result of strange rounding (e.g., `10008.` in place of `10000.`). ptrblck suspected that the parameters were being downcast to the input type, which reproduces the problematic output. Indeed, the GEMM and BGEMM cuBLAS wrappers currently convert the `alpha` and `beta` parameters to `scalar_t` (which is potentially reduced precision) before converting them back to `float`. This PR changes the "ARGTYPES" wrappers to use the accumulation type instead (`at::opmath_type<Dtype>` in the diff below) and adds a corresponding test.

CC ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67633

Reviewed By: mruberry

Differential Revision: D32076474

Pulled By: ngimel

fbshipit-source-id: 2540d9b9d0195c17d07d1161374fb6a5850779d5
parent 3f33ada8d5
commit a5b57c9433
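The rounding described in the summary can be reproduced without touching cuBLAS. The sketch below is my own illustration (not part of this commit): it pushes `alpha=0.001` through fp16 the way the old wrappers effectively did and shows how the GEMM in the test added below lands on `10008.` instead of `10000.`.

import torch

# In the test below, every output element is the sum over k of 100 * 100
# with k = 1000, i.e. 1e7, scaled by alpha = 0.001 -> exactly 10000.
alpha = 0.001
alpha_as_half = torch.tensor(alpha, dtype=torch.half).item()  # ~0.0010004043579101562

exact  = 1e7 * alpha          # 10000.0
skewed = 1e7 * alpha_as_half  # ~10004.04

# fp16 values in [8192, 16384) are spaced 8 apart, so ~10004.04 rounds up to 10008.
print(torch.tensor(exact, dtype=torch.half))   # tensor(10000., dtype=torch.float16)
print(torch.tensor(skewed, dtype=torch.half))  # tensor(10008., dtype=torch.float16)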
aten/src/ATen/cuda/CUDABlas.h

@@ -14,6 +14,7 @@
 */
 
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/OpMathType.h>
 
 namespace at {
 namespace cuda {
@@ -40,9 +41,9 @@ private:
 
 /* LEVEL 3 BLAS FUNCTIONS */
 
-#define CUDABLAS_GEMM_ARGTYPES(Dtype)  \
-  char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \
-  const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, Dtype beta, \
+#define CUDABLAS_GEMM_ARGTYPES(Dtype)  \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
+  const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
   Dtype *c, int64_t ldc
 
 template <typename Dtype>
@@ -69,11 +70,11 @@ template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 #endif
 
-#define CUDABLAS_BGEMM_ARGTYPES(Dtype)  \
-  char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \
-  const Dtype *a, int64_t lda, int64_t stridea, \
-  const Dtype *b, int64_t ldb, int64_t strideb, \
-  Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
+#define CUDABLAS_BGEMM_ARGTYPES(Dtype)  \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
+  const Dtype *a, int64_t lda, int64_t stridea, \
+  const Dtype *b, int64_t ldb, int64_t strideb, \
+  at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
 
 template <typename Dtype>
 inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
aten/src/ATen/native/cuda/Blas.cpp

@@ -171,8 +171,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
-    scalar_t alpha_val = alpha.to<scalar_t>();
-    scalar_t beta_val = beta.to<scalar_t>();
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_t alpha_val = alpha.to<opmath_t>();
+    opmath_t beta_val = beta.to<opmath_t>();
     scalar_t* mat1_ptr = mat1_->data_ptr<scalar_t>();
     scalar_t* mat2_ptr = mat2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
@@ -240,8 +241,9 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
-    scalar_t alpha_val = alpha.to<scalar_t>();
-    scalar_t beta_val = beta.to<scalar_t>();
+    using opmath_t = at::opmath_type<scalar_t>;
+    opmath_t alpha_val = alpha.to<opmath_t>();
+    opmath_t beta_val = beta.to<opmath_t>();
     scalar_t* batch1_ptr = batch1_->data_ptr<scalar_t>();
     scalar_t* batch2_ptr = batch2_->data_ptr<scalar_t>();
     scalar_t* result_ptr = result_->data_ptr<scalar_t>();
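In the Blas.cpp change above, the Scalar parameters are now converted with `alpha.to<opmath_t>()` instead of `alpha.to<scalar_t>()`; for `at::Half` (and `at::BFloat16`) the opmath type is `float`, so the value no longer takes a detour through half precision. A rough Python analogue of the two conversion paths (my own sketch, not the actual C++ code path):

import torch

scalar = 0.001  # the Python-side alpha passed to torch.addmm

# old path, roughly: Scalar -> scalar_t (here half) -> float handed to cuBLAS
old_alpha = float(torch.tensor(scalar, dtype=torch.half))     # 0.0010004043579101562

# new path, roughly: Scalar -> opmath_t (here float) handed to cuBLAS
new_alpha = float(torch.tensor(scalar, dtype=torch.float32))  # ~0.001 (float32 rounding only)

print(old_alpha, new_alpha)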
test/test_linalg.py

@@ -6127,17 +6127,25 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
 
     @dtypes(torch.half)
     @onlyCUDA
-    def test_addmm_addbmm_overflow(self, device, dtype):
+    def test_addmm_baddbmm_overflow(self, device, dtype):
         inp = torch.zeros(128, 128, dtype=torch.half, device=device)
         mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
         mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
         out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
-        self.assertFalse(out.isinf().any())
+        # just check for no overflow on ROCM
+        if TEST_WITH_ROCM:
+            self.assertFalse(out.isinf().any())
+        else:
+            self.assertTrue((out == 10000.).all())
 
         inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
         mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
         mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
-        out = torch.addbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
-        self.assertFalse(out.isinf().any())
+        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
+        if TEST_WITH_ROCM:
+            self.assertFalse(out.isinf().any())
+        else:
+            self.assertTrue((out == 10000.).all())
 
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
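For a quick manual sanity check of the fix (my own sketch mirroring the new test, not an addition to it), the addmm case can be run interactively on a CUDA device; before this change the fp16 output came back as values like `10008.`, and on ROCm the test above only asserts the absence of inf:

import torch

assert torch.cuda.is_available(), "needs a CUDA device"

inp = torch.zeros(128, 128, dtype=torch.half, device="cuda")
mat1 = torch.ones(128, 1000, dtype=torch.half, device="cuda") * 100
mat2 = torch.ones(1000, 128, dtype=torch.half, device="cuda") * 100
out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
print(out[0, 0].item(), bool((out == 10000.).all()))  # expected with this patch: 10000.0 True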