CUDA BFloat div, addcdiv, addcmul, mean, var (#44758)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44758

Reviewed By: mruberry

Differential Revision: D23752317

Pulled By: ngimel

fbshipit-source-id: 77992cf991f4e2b4b6839de73ea7e6ce2e1061c6
Author: Xiang Gao
Date: 2020-09-18 11:47:31 -07:00
Committed by: Facebook GitHub Bot
parent f175830558
commit 7bd8a6913d

3 changed files with 19 additions and 26 deletions
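
For illustration only (not part of this commit): a minimal sketch of what the change enables from Python, assuming a CUDA build of PyTorch that includes this commit. Shapes and values are arbitrary, and the divisor is shifted away from zero only to keep the example well behaved.

    import torch

    # bfloat16 tensors on an NVIDIA GPU; before this commit these ops accepted
    # bfloat16 only on ROCm builds
    a = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
    c = torch.rand(4, 4, device="cuda", dtype=torch.bfloat16) + 0.5  # divisor away from zero

    a / c                              # div
    torch.addcmul(a, b, c, value=2)    # a + 2 * b * c
    torch.addcdiv(a, b, c, value=2)    # a + 2 * (b / c)
    a.mean(), a.var(), a.std()         # reductions covered by the second file below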


@@ -10,24 +10,20 @@ namespace at { namespace native {

 void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcmul_cuda", [&] {
-      auto alpha = value.to<scalar_t>();
-      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-        return a + alpha * b * c;
-      });
-    });
+    auto alpha = value.to<scalar_t>();
+    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+      return a + alpha * b * c;
+    });
   });
 }

 void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "addcdiv_cuda", [&] {
-      auto alpha = value.to<scalar_t>();
-      gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
-        return a + alpha * (b / c);
-      });
-    });
+    auto alpha = value.to<scalar_t>();
+    gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+      return a + alpha * (b / c);
+    });
   });
 }

 void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm) {
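
For reference (an illustration, not part of the diff): the two GPU lambdas above compute a + value * b * c and a + value * (b / c) element-wise. A rough Python check of those formulas against the public ops, assuming a CUDA device and generous tolerances to absorb bfloat16 rounding:

    import torch

    a = torch.randn(8, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(8, device="cuda", dtype=torch.bfloat16)
    c = torch.rand(8, device="cuda", dtype=torch.bfloat16) + 0.5
    alpha = 0.25

    ref_mul = a + alpha * b * c      # mirrors: return a + alpha * b * c;
    ref_div = a + alpha * (b / c)    # mirrors: return a + alpha * (b / c);

    assert torch.allclose(torch.addcmul(a, b, c, value=alpha), ref_mul, rtol=1e-2, atol=5e-2)
    assert torch.allclose(torch.addcdiv(a, b, c, value=alpha), ref_div, rtol=1e-2, atol=5e-2)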


@@ -30,10 +30,8 @@ void std_var_kernel_impl<at::BFloat16>(TensorIterator& iter, bool unbiased, bool take_sqrt) {

 static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "std_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "std_cuda", [&] {
-      std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
-    });
+    std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
 }

 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
@@ -49,14 +47,12 @@ static void mean_kernel_cuda(TensorIterator& iter) {
     // type promotion that does cast and reduction in a single kernel
     return mean_kernel_impl<at::Half, float, float>(iter);
   }
-#ifdef __HIP_PLATFORM_HCC__
   else if(iter.dtype() == kBFloat16) {
     return mean_kernel_impl<at::BFloat16, float>(iter);
   } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return mean_kernel_impl<at::BFloat16, float, float>(iter);
   }
-#endif
   AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() {
     mean_kernel_impl<scalar_t>(iter);
   });
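
For illustration (not part of the diff): a sketch of the reduction paths that are no longer gated behind __HIP_PLATFORM_HCC__, based on my reading of mean_kernel_cuda above and assuming a CUDA device:

    import torch

    x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

    # iter.dtype() == kBFloat16: bfloat16 input, float accumulator, bfloat16 output
    m_bf16 = x.mean()

    # iter.dtype(1) == kBFloat16 and iter.dtype() == kFloat: cast and reduction in a
    # single kernel, requested by asking for a float output dtype
    m_f32 = x.mean(dtype=torch.float)

    # std_var_kernel_cuda path, no longer skipped for bfloat16 on non-ROCm builds
    v = x.var()
    s = x.std()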


@@ -19787,20 +19787,20 @@ tensor_op_tests = [
     ('mul', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2),
     ('mul', 'scalar', _small_0d, lambda t, d: [_small_0d(torch.int32, d)], 1e-2),
     ('div', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('div', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('true_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
         1e-5, 1e-5, _types, _cpu_types, False),
     ('true_divide', 'with_inplace', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('true_divide', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
         1e-5, 1e-5, _types, _cpu_types, False),
     ('true_divide', 'tensor_with_inplace', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1,
-        1e-1, 1e-5, _float_types2),
+        1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1e-5, 1e-5, _types),
     ('floor_divide', 'tensor', _small_3d,
         lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1e-5, 1e-5, _types),
@@ -19834,15 +19834,16 @@ tensor_op_tests = [
     ('addcdiv', '', _small_2d,
         lambda t, d: [_small_2d(t, d),
                       _small_2d(t, d, has_zeros=False)], 1, 1, 1e-3,
-        _float_types2, _cpu_types, True),
+        torch.testing.get_all_fp_dtypes(), _cpu_types, True),
     ('addcdiv', 'scalar', _small_2d,
         lambda t, d: [_number(2.8, 1, t), _small_2d(t, d),
                       _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3,
         _float_types, _cpu_types, True),
-    ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3, _types2),
+    ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-3,
+        torch.testing.get_all_dtypes(include_complex=False, include_bool=False)),
     ('addcmul', 'scalar', _small_3d,
         lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2,
-        1e-1, 1e-5, _types2, _cpu_types, True,
+        1e-1, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, True,
         [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
     ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)], 0, True),
@@ -19957,9 +19958,9 @@ tensor_op_tests = [
         1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('minimum', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],
         1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
-    ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
-    ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
-    ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, _float_types2, _cpu_types, False),
+    ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
+    ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
+    ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-2, 1e-2, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
     # Double here because the CPU result will be wrong otherwise
     ('mean', '64bit_indexing', _giant_1d, lambda t, d: [],
         1e-3, 1e-5, 1e-5, [torch.double], _cpu_types, False, [slowTest]),
@@ -19983,7 +19984,7 @@ tensor_op_tests = [
     ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
     ('var', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
     ('var', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False),
-    ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, _float_types2, _cpu_types, False),
+    ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
     ('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
     ('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
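
Note (illustration, not part of the diff): the test-table edits above swap hand-maintained dtype lists for the torch.testing helpers of the era, so the affected ops are exercised in bfloat16 on CUDA as well; my understanding, stated as an assumption, is that _float_types2 and _types2 added torch.bfloat16 only on ROCm builds. A quick way to see which dtypes the updated entries run over, assuming a PyTorch version where these helpers still exist:

    import torch

    # floating-point dtypes, including torch.bfloat16 regardless of backend
    print(torch.testing.get_all_fp_dtypes())

    # integer and floating-point dtypes, as used by the 'addcmul' entries
    print(torch.testing.get_all_dtypes(include_complex=False, include_bool=False))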