add _amp_foreach_non_finite_check_and_unscale_cpu_ and _amp_update_scale_cpu_ kernels on CPU (#109281)

Step 1 of https://github.com/pytorch/pytorch/issues/111559.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109281
Approved by: https://github.com/jgong5, https://github.com/ezyang
Authored by CaoE on 2024-01-08 23:38:03 -08:00; committed by PyTorch MergeBot
parent 0fa6ee44d9
commit 29516bd2a0
17 changed files with 457 additions and 22 deletions
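With the CPU dispatch entries added below, the two private AMP ops that GradScaler builds on can be called directly on CPU tensors, mirroring the new tests in test_torch.py. A minimal sketch (shapes and constants are illustrative only):

import torch

# Gradients were produced under scale = 4.0; unscale them with inv_scale = 1/4.
inv_scale = torch.full((1,), 0.25, dtype=torch.float)
found_inf = torch.zeros(1, dtype=torch.float)
grads = [torch.full((8, 8), 4.0), torch.full((8, 8), 4.0)]
grads[1][0, 0] = float('inf')  # one non-finite element

# In-place unscale; found_inf is set to 1.0 because of the inf element.
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)

# Scale update: found_inf > 0, so the scale backs off and the tracker resets.
scale = torch.full((1,), 4.0, dtype=torch.float)
growth_tracker = torch.zeros(1, dtype=torch.int32)
torch._amp_update_scale_(scale, growth_tracker, found_inf,
                         2.0,   # growth_factor
                         0.5,   # backoff_factor
                         2000)  # growth_interval
print(scale)  # tensor([2.])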

View File

@ -100,6 +100,10 @@ public:
Vectorized<double> isnan() const {
return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
}
bool has_inf_nan() const {
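// values - values is 0 for finite elements and a quiet NaN for Inf/NaN
// elements; the masked byte-wise movemask of the difference is therefore
// non-zero exactly when at least one element is non-finite.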
__m256d self_sub = _mm256_sub_pd(values, values);
return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != 0;
}
Vectorized<double> map(double (*const f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);

View File

@ -106,6 +106,12 @@ public:
Vectorized<float> isnan() const {
return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
}
bool has_inf_nan() const {
__m256 self_sub = _mm256_sub_ps(values, values);
return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != 0;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);

View File

@ -307,6 +307,16 @@ public:
}
return loadu(res);
};
bool has_inf_nan() const {
__at_align__ float tmp[size()];
store(tmp);
for (const auto i : c10::irange(size())) {
if(_isnan(tmp[i]) || _isinf(tmp[i])) {
return true;
}
}
return false;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);

View File

@ -383,6 +383,19 @@ class Vectorized<double> {
auto ret = (x == x);
return ret._nor();
}
bool has_inf_nan() const {
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
return true;
}
}
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
return true;
}
}
return false;
}
DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)

View File

@ -239,6 +239,20 @@ class Vectorized<float> {
return (x == v_inf) | (x == v_minus_inf);
}
bool has_inf_nan() const {
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
return true;
}
}
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
return true;
}
}
return false;
}
int zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit

View File

@ -875,6 +875,20 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
return ret._not();
}
bool has_inf_nan() const {
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
return true;
}
}
for (const auto i : c10::irange(size()/2)) {
if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
return true;
}
}
return false;
}
template <
typename U = T,
std::enable_if_t<std::is_floating_point<U>::value, int> = 0>

View File

@ -106,6 +106,10 @@ public:
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
0xFFFFFFFFFFFFFFFF));
}
bool has_inf_nan() const {
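// Same trick as the AVX2 variant: values - values is 0 for finite elements
// and a quiet NaN for Inf/NaN elements, so the masked byte mask of the
// difference is non-zero exactly when some element is non-finite.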
__m512d self_sub = _mm512_sub_pd(values, values);
return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & 0x7777777777777777) != 0;
}
Vectorized<double> map(double (*const f)(double)) const {
__at_align__ double tmp[size()];
store(tmp);

View File

@ -125,6 +125,10 @@ public:
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
0xFFFFFFFF));
}
bool has_inf_nan() const {
__m512 self_sub = _mm512_sub_ps(values, values);
return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0;
}
Vectorized<float> map(float (*const f)(float)) const {
__at_align__ float tmp[size()];
store(tmp);

View File

@ -255,6 +255,14 @@ public:
}
return vector;
}
bool has_inf_nan() const {
for (int64_t i = 0; i != size(); i++) {
if(_isnan(values[i]) || _isinf(values[i])) {
return true;
}
}
return false;
}
Vectorized<T> map(T (*const f)(T)) const {
Vectorized<T> ret;
for (int64_t i = 0; i != size(); i++) {

View File

@ -0,0 +1,41 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/AmpKernels.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
#include <ATen/ops/_amp_update_scale.h>
#include <ATen/ops/_amp_update_scale_native.h>
#endif
namespace at::native {
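// Thin wrappers that forward to the CPU-capability-specific kernels
// (registered from cpu/AmpGradScalerKernels.cpp) via the dispatch stubs
// defined below.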
void _amp_foreach_non_finite_check_and_unscale_cpu_(
TensorList scaled_grads,
at::Tensor& found_inf,
const at::Tensor& inv_scale) {
_amp_foreach_non_finite_check_and_unscale_cpu_stub(
found_inf.device().type(), scaled_grads, found_inf, inv_scale);
}
at::Tensor& _amp_update_scale_cpu_ (
at::Tensor& current_scale,
at::Tensor& growth_tracker,
const at::Tensor& found_inf,
double growth_factor,
double backoff_factor,
int64_t growth_interval) {
return _amp_update_scale_cpu_stub(
growth_tracker.device().type(), current_scale, growth_tracker,
found_inf, growth_factor, backoff_factor, growth_interval);
}
DEFINE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub);
DEFINE_DISPATCH(_amp_update_scale_cpu_stub);
} // namespace at::native

View File

@ -0,0 +1,28 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/core/ATen_fwd.h>
namespace at {
class Tensor;
namespace native {
using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
TensorList,
Tensor&,
const Tensor&);
using _amp_update_scale_cpu__fn = Tensor& (*)(
Tensor&,
Tensor&,
const Tensor&,
double,
double,
int64_t);
DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
} // namespace native
} // namespace at

View File

@ -0,0 +1,199 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/AmpKernels.h>
#include <math.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {
namespace {
// Follows the CUDA implementation.
// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf
// to 1.0.
//
// Args:
// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or
//     NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if
//     any gradient contains infs/nans. Pre-zeroing found_inf, if appropriate,
//     is the responsibility of the caller.
// inv_scale: The inverse of the scale factor by which scaled_grads are
//     currently multiplied.
void _amp_foreach_non_finite_check_and_unscale_cpu_kernel(
TensorList scaled_grads,
at::Tensor& found_inf,
const at::Tensor& inv_scale) {
if (scaled_grads.size() == 0) {
return;
}
TORCH_CHECK(inv_scale.is_cpu(), "inv_scale must be a CPU tensor.");
TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
TORCH_CHECK(
inv_scale.scalar_type() == at::ScalarType::Float,
"inv_scale must be a float tensor.");
TORCH_CHECK(
found_inf.scalar_type() == at::ScalarType::Float,
"found_inf must be a float tensor.");
// Ensures client code (GradScaler) filtered scaled_grads by dtype.
at::native::check_foreach_api_restrictions(scaled_grads);
for (const at::Tensor& t : scaled_grads) {
TORCH_CHECK(t.is_cpu(), "one of scaled_grads was not a CPU tensor.");
TORCH_CHECK(
t.layout() == at::kStrided,
"one of scaled_grads was not a strided tensor.");
auto iter = at::TensorIterator::unary_op(
const_cast<at::Tensor&>(t), const_cast<at::Tensor&>(t));
if (at::isReducedFloatingType(iter.dtype())) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
iter.dtype(),
"_amp_foreach_non_finite_check_and_unscale_cpu",
[&iter, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();
using opmath_t = at::opmath_type<scalar_t>;
at::native::cpu_kernel_vec(
iter,
[found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
auto val = static_cast<opmath_t>(val_in);
if (!std::isfinite(val)) {
*found_inf_ptr = 1.f;
}
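// Multiple TensorIterator threads may store 1.f here concurrently; all
// writers store the same value, so the race is benign.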
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<scalar_t>(
inv_scale_val == 1.f ? val : val * inv_scale_val);
},
[found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t>{
Vectorized<opmath_t> val_vec0, val_vec1;
std::tie(val_vec0, val_vec1) = convert_to_float<scalar_t>(val_vec);
if (val_vec0.has_inf_nan() || val_vec1.has_inf_nan()) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
val_vec0 = inv_scale_val == 1.f ? val_vec0 : val_vec0 * Vectorized<opmath_t>(inv_scale_val);
val_vec1 = inv_scale_val == 1.f ? val_vec1 : val_vec1 * Vectorized<opmath_t>(inv_scale_val);
return convert_from_float<scalar_t>(val_vec0, val_vec1);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES(
iter.dtype(),
"_amp_foreach_non_finite_check_and_unscale_cpu",
[&iter, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();
at::native::cpu_kernel_vec(
iter,
[found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
if (!std::isfinite(val_in)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<scalar_t>(
inv_scale_val == 1.f ? val_in : val_in * inv_scale_val);
},
[found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t>{
if (val_vec.has_inf_nan()) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return inv_scale_val == 1.f ? val_vec : val_vec * Vectorized<scalar_t>(inv_scale_val);
});
});
}
}
}
// _amp_update_scale_cpu updates the scale tensor in place.
//
// Args:
// current_scale: A one-element float tensor containing the scale value.
// growth_tracker: A one-element IntTensor containing the number of recent
//     consecutive unskipped steps.
// found_inf: A one-element float tensor. If > 0, indicates that infs/nans
//     were found by the relevant prior
//     _amp_foreach_non_finite_check_and_unscale_cpu_ call; 0 if no infs/nans
//     were found.
// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval: Number of consecutive unskipped steps that must occur for
//     current_scale to be multiplied by growth_factor.
//
// Returns:
// current_scale
at::Tensor& _amp_update_scale_cpu_kernel(
at::Tensor& current_scale,
at::Tensor& growth_tracker,
const at::Tensor& found_inf,
double growth_factor,
double backoff_factor,
int64_t growth_interval) {
TORCH_CHECK(growth_tracker.is_cpu(), "growth_tracker must be a CPU tensor.");
TORCH_CHECK(current_scale.is_cpu(), "current_scale must be a CPU tensor.");
TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
TORCH_CHECK(
growth_tracker.numel() == 1,
"growth_tracker must be a 1-element tensor.");
TORCH_CHECK(
current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
TORCH_CHECK(
growth_tracker.scalar_type() == at::ScalarType::Int,
"growth_tracker must be an int tensor.");
TORCH_CHECK(
current_scale.scalar_type() == at::ScalarType::Float,
"current_scale must be a float tensor.");
TORCH_CHECK(
found_inf.scalar_type() == at::ScalarType::Float,
"found_inf must be a float tensor.");
float* current_scale_ptr = current_scale.data_ptr<float>();
int* growth_tracker_ptr = growth_tracker.data_ptr<int>();
float* found_inf_ptr = found_inf.data_ptr<float>();
if (*found_inf_ptr) {
*current_scale_ptr = (*current_scale_ptr) * backoff_factor;
*growth_tracker_ptr = 0;
} else {
// Entering this branch means we just carried out a successful step,
// so growth_tracker is incremented before comparing to growth_interval.
auto successful = (*growth_tracker_ptr) + 1;
if (successful == growth_interval) {
auto new_scale = static_cast<float>((*current_scale_ptr) * growth_factor);
// Do not grow the scale past fp32 bounds to inf.
if (std::isfinite(new_scale)) {
*current_scale_ptr = new_scale;
}
*growth_tracker_ptr = 0;
} else {
*growth_tracker_ptr = successful;
}
}
return current_scale;
}
} // namespace
REGISTER_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub, &_amp_foreach_non_finite_check_and_unscale_cpu_kernel);
REGISTER_DISPATCH(_amp_update_scale_cpu_stub, &_amp_update_scale_cpu_kernel);
} // namespace at::native

View File

@ -10138,12 +10138,14 @@
variants: function
dispatch:
CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
CPU: _amp_foreach_non_finite_check_and_unscale_cpu_
autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out
- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
variants: function
dispatch:
CUDA: _amp_update_scale_cuda_
CPU: _amp_update_scale_cpu_
autogen: _amp_update_scale, _amp_update_scale.out
#- func: _cat(Tensor[] tensors, int dim=0) -> Tensor

View File

@ -66,6 +66,8 @@ namespace {
class FunctionalTests : public ::testing::Test {};
template <typename T>
class FunctionalTestsReducedFloat : public ::testing::Test {};
template <typename T>
class InfiniteTests : public ::testing::Test {};
using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
using ALLTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vlong, vint, vshort, vqint8, vquint8, vqint>;
@ -106,6 +108,7 @@ namespace {
TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatTestedTypes);
TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes);
TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes);
TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes);
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
TYPED_TEST_SUITE(
Quantization8BitWithTailTests,
@ -1587,6 +1590,36 @@ namespace {
<< "Test failed for uint16 to float " << u16 << "\n";
}
}
TYPED_TEST(InfiniteTests, HasInfNan) {
using vec = TypeParam;
using VT = UholdType<TypeParam>;
auto vec_size = vec::size();
VT values[20];
for (const auto i : c10::irange(20)) {
values[i] = i + 0.3;
}
auto vec_val = vec::loadu(values);
auto seed = TestSeed();
ValueGen<int> generator(int(0), int(vec_size - 1), seed);
int index = generator.get();
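// Overwrite one randomly chosen lane with NaN, +Inf and -Inf bit patterns in turn.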
int nanBits = 0x7FC00000;
VT v_nan = static_cast<VT>(*(float *)&nanBits);
values[index] = v_nan;
auto vec_nan = vec::loadu(values);
int infBits = 0x7F800000;
VT v_pinf = static_cast<VT>(*(float *)&infBits);
values[index] = v_pinf;
auto vec_pinf = vec::loadu(values);
int negInfBits = 0xFF800000;
VT v_ninf = static_cast<VT>(*(float *)&negInfBits);
values[index] = v_ninf;
auto vec_ninf = vec::loadu(values);
ASSERT_TRUE(!(vec_val.has_inf_nan())) << "Test failed for normal value\n";
ASSERT_TRUE(vec_nan.has_inf_nan()) << "Test failed for NAN\n";
ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
}
#else
#error GTEST does not have TYPED_TEST

View File

@ -1142,6 +1142,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
"aten/src/ATen/native/cpu/group_norm_kernel.cpp",
"aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
"aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp",
"aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
"aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
"aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
@ -1342,6 +1343,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/WeightNorm.cpp",
"aten/src/ATen/native/group_norm.cpp",
"aten/src/ATen/native/layer_norm.cpp",
"aten/src/ATen/native/AmpKernels.cpp",
"aten/src/ATen/native/mkl/LinearAlgebra.cpp",
"aten/src/ATen/native/mkl/SparseBlasImpl.cpp",
"aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp",

View File

@ -1098,28 +1098,6 @@ torch.cuda.synchronize()
self.assertTrue(r != 0)
def test_grad_scaling_update_scale(self, device="cuda", dtype=torch.float):
growth = 2.0
backoff = 0.25
growth_interval = 2
scale = torch.full((1,), 4.0, dtype=dtype, device=device)
growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0")
# Simulates 2 consecutive unskipped iterations
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 1)
self.assertEqual(scale, 4.0)
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 0)
self.assertEqual(scale, 8.0)
# Simulates a skipped iteration
found_inf.fill_(1.0)
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 0)
self.assertEqual(scale, 2.0)
def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float):
scaler = torch.cuda.amp.GradScaler()

View File

@ -5568,6 +5568,81 @@ else:
self._test_multinomial_empty(device, False, 1)
self._test_multinomial_empty(device, False, 2)
@onlyCPU
@dtypes(torch.float, torch.double)
def test_grad_scaling_unscale(self, device, dtype):
inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device)
found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)
size = 20
g = torch.full((size, size), 4.0, dtype=dtype, device=device)
ginf = g.clone()
ginf[2, 2] = float('inf')
gnan = g.clone()
gnan[2, 2] = float('nan')
# Tries selected combinations of
# - contiguous grads
# - g.clone().t() which is not contiguous but still non-overlapping and dense
# - variants of g.clone()[:, :5] which are not non-overlapping and dense
# Non-overlapping and dense grads route into a multi-tensor-apply kernel,
# others use a fallback per-tensor kernel, so we should try both.
cases = (
([g.clone(), g.clone()], False),
([g.clone(), g.clone().t()], False),
([g.clone(), g.clone()[:, :5]], False),
([g.clone()[:, :5], g.clone()[:, :5]], False),
([g.clone(), ginf.clone()], True),
([g.clone(), gnan.clone()], True),
([g.clone(), ginf.clone()[:, :5]], True),
([g.clone(), gnan.clone()[:, :5]], True),
([ginf.clone(), g.clone()[:, :5]], True),
([ginf.clone()[:, :5], g.clone()[:, :5]], True),
)
for grads, has_inf in cases:
found_inf.zero_()
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
if has_inf:
self.assertEqual(found_inf, 1.0)
else:
self.assertEqual(found_inf, 0.0)
for grad in grads:
self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
# When passing lists with mismatched dtypes to a raw
# _amp_foreach_non_finite_check_and_unscale_ call,
# it's expected to fall back to the single-tensor TensorIterator kernel.
grads = [g.clone(), g.to(dtype=torch.float16)]
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
for grad in grads:
self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)
@skipMeta
@onlyNativeDeviceTypes
@dtypes(torch.float)
def test_grad_scaling_update_scale(self, device, dtype):
growth = 2.0
backoff = 0.25
growth_interval = 2
scale = torch.full((1,), 4.0, dtype=dtype, device=device)
growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)
# Simulates 2 consecutive unskipped iterations
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 1)
self.assertEqual(scale, 4.0)
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 0)
self.assertEqual(scale, 8.0)
# Simulates a skipped iteration
found_inf.fill_(1.0)
torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
self.assertEqual(growth_tracker, 0)
self.assertEqual(scale, 2.0)
@dtypesIfCUDA(torch.float, torch.double, torch.half)
@dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
@dtypes(torch.float, torch.double)