add _amp_foreach_non_finite_check_and_unscale_cpu_ and _amp_update_scale_cpu_ kernels on CPU (#109281)
Step 1 of https://github.com/pytorch/pytorch/issues/111559.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109281
Approved by: https://github.com/jgong5, https://github.com/ezyang
This commit is contained in:
parent 0fa6ee44d9
commit 29516bd2a0
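A minimal sketch of how the new CPU kernels can be driven directly through the private ops, mirroring the tests added at the end of this diff; the tensor shapes and the growth/backoff values here are illustrative only:

import torch

# All tensors live on CPU, so both calls below dispatch to the new CPU kernels.
inv_scale = torch.full((1,), 0.25, dtype=torch.float, device="cpu")
found_inf = torch.zeros(1, dtype=torch.float, device="cpu")
grads = [torch.full((4, 4), 4.0, device="cpu") for _ in range(2)]

# Unscales every grad in place; writes 1.0 into found_inf if any inf/NaN is seen.
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)

scale = torch.full((1,), 65536.0, dtype=torch.float, device="cpu")
growth_tracker = torch.zeros(1, dtype=torch.int32, device="cpu")
# Backs the scale off if found_inf > 0, otherwise grows it every growth_interval steps.
torch._amp_update_scale_(scale, growth_tracker, found_inf, 2.0, 0.5, 2000)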
@@ -100,6 +100,10 @@ public:
  Vectorized<double> isnan() const {
    return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
  }
  bool has_inf_nan() const {
    __m256d self_sub = _mm256_sub_pd(values, values);
    return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != 0;
  }
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
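The vectorized has_inf_nan() above relies on an IEEE-754 property: for any finite x, x - x is exactly 0, while for +/-inf or NaN the subtraction produces NaN, whose set exponent bits make the byte mask of the result nonzero. A scalar illustration of the same check (not part of the patch):

import math

def scalar_has_inf_nan(x: float) -> bool:
    # x - x is 0.0 for every finite x, and NaN for +/-inf or NaN,
    # which is what the vectorized kernels detect via a bit mask.
    return math.isnan(x - x)

assert not scalar_has_inf_nan(1.5)
assert scalar_has_inf_nan(float("inf"))
assert scalar_has_inf_nan(float("-inf"))
assert scalar_has_inf_nan(float("nan"))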
@@ -106,6 +106,12 @@ public:
  Vectorized<float> isnan() const {
    return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
  }

  bool has_inf_nan() const {
    __m256 self_sub = _mm256_sub_ps(values, values);
    return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != 0;
  }

  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
@@ -307,6 +307,16 @@ public:
    }
    return loadu(res);
  };
  bool has_inf_nan() const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if(_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
@@ -383,6 +383,19 @@ class Vectorized<double> {
    auto ret = (x == x);
    return ret._nor();
  }
  bool has_inf_nan() const {
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
        return true;
      }
    }
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
        return true;
      }
    }
    return false;
  }

  DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
  DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)
@@ -239,6 +239,20 @@ class Vectorized<float> {
    return (x == v_inf) | (x == v_minus_inf);
  }

  bool has_inf_nan() const {
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
        return true;
      }
    }
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
        return true;
      }
    }
    return false;
  }

  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit
    // and others are translated to 0-bit
@@ -875,6 +875,20 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
    return ret._not();
  }

  bool has_inf_nan() const {
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
        return true;
      }
    }
    for (const auto i : c10::irange(size()/2)) {
      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
        return true;
      }
    }
    return false;
  }

  template <
      typename U = T,
      std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
@@ -106,6 +106,10 @@ public:
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  bool has_inf_nan() const {
    __m512d self_sub = _mm512_sub_pd(values, values);
    return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & 0x7777777777777777) != 0;
  }
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
@@ -125,6 +125,10 @@ public:
    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
                                                      0xFFFFFFFF));
  }
  bool has_inf_nan() const {
    __m512 self_sub = _mm512_sub_ps(values, values);
    return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0;
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
@@ -255,6 +255,14 @@ public:
    }
    return vector;
  }
  bool has_inf_nan() const {
    for (int64_t i = 0; i != size(); i++) {
      if(_isnan(values[i]) || _isinf(values[i])) {
        return true;
      }
    }
    return false;
  }
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
aten/src/ATen/native/AmpKernels.cpp (new file, 41 lines)
@@ -0,0 +1,41 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/AmpKernels.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
#include <ATen/ops/_amp_update_scale.h>
#include <ATen/ops/_amp_update_scale_native.h>
#endif

namespace at::native {

void _amp_foreach_non_finite_check_and_unscale_cpu_(
    TensorList scaled_grads,
    at::Tensor& found_inf,
    const at::Tensor& inv_scale) {
  _amp_foreach_non_finite_check_and_unscale_cpu_stub(
      found_inf.device().type(), scaled_grads, found_inf, inv_scale);
}

at::Tensor& _amp_update_scale_cpu_(
    at::Tensor& current_scale,
    at::Tensor& growth_tracker,
    const at::Tensor& found_inf,
    double growth_factor,
    double backoff_factor,
    int64_t growth_interval) {
  return _amp_update_scale_cpu_stub(
      growth_tracker.device().type(), current_scale, growth_tracker,
      found_inf, growth_factor, backoff_factor, growth_interval);
}

DEFINE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub);
DEFINE_DISPATCH(_amp_update_scale_cpu_stub);

} // namespace at::native
aten/src/ATen/native/AmpKernels.h (new file, 28 lines)
@@ -0,0 +1,28 @@
#pragma once

#include <ATen/native/DispatchStub.h>
#include <ATen/core/ATen_fwd.h>

namespace at {
class Tensor;

namespace native {

using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
    TensorList,
    Tensor&,
    const Tensor&);

using _amp_update_scale_cpu__fn = Tensor& (*)(
    Tensor&,
    Tensor&,
    const Tensor&,
    double,
    double,
    int64_t);

DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);

} // namespace native
} // namespace at
aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp (new file, 199 lines)
@@ -0,0 +1,199 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/native/AmpKernels.h>
#include <math.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>

namespace at::native {

namespace {
// Follows the CUDA implementation.
// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf
// to 1.0.
//
// Args:
// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or
//   NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if
//   any gradient contains infs/NaNs. Pre-zeroing found_inf, if appropriate,
//   is the responsibility of the caller.
// inv_scale: The inverse of the scale factor by which scaled_grads are
//   currently multiplied.
void _amp_foreach_non_finite_check_and_unscale_cpu_kernel(
    TensorList scaled_grads,
    at::Tensor& found_inf,
    const at::Tensor& inv_scale) {
  if (scaled_grads.size() == 0) {
    return;
  }

  TORCH_CHECK(inv_scale.is_cpu(), "inv_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      inv_scale.scalar_type() == at::ScalarType::Float,
      "inv_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
  at::native::check_foreach_api_restrictions(scaled_grads);
  for (const at::Tensor& t : scaled_grads) {
    TORCH_CHECK(t.is_cpu(), "one of scaled_grads was not a CPU tensor.");
    TORCH_CHECK(
        t.layout() == at::kStrided,
        "one of scaled_grads was not a strided tensor.");
    auto iter = at::TensorIterator::unary_op(
        const_cast<at::Tensor&>(t), const_cast<at::Tensor&>(t));
    if (at::isReducedFloatingType(iter.dtype())) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();

            using opmath_t = at::opmath_type<scalar_t>;

            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  auto val = static_cast<opmath_t>(val_in);
                  if (!std::isfinite(val)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val : val * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t> {
                  Vectorized<opmath_t> val_vec0, val_vec1;
                  std::tie(val_vec0, val_vec1) = convert_to_float<scalar_t>(val_vec);
                  if (val_vec0.has_inf_nan() || val_vec1.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  val_vec0 = inv_scale_val == 1.f ? val_vec0 : val_vec0 * Vectorized<opmath_t>(inv_scale_val);
                  val_vec1 = inv_scale_val == 1.f ? val_vec1 : val_vec1 * Vectorized<opmath_t>(inv_scale_val);
                  return convert_from_float<scalar_t>(val_vec0, val_vec1);
                });
          });
    } else {
      AT_DISPATCH_FLOATING_TYPES(
          iter.dtype(),
          "_amp_foreach_non_finite_check_and_unscale_cpu",
          [&iter, &found_inf, &inv_scale] {
            auto* found_inf_ptr = found_inf.data_ptr<float>();
            auto* inv_scale_ptr = inv_scale.data_ptr<float>();
            at::native::cpu_kernel_vec(
                iter,
                [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t {
                  if (!std::isfinite(val_in)) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return static_cast<scalar_t>(
                      inv_scale_val == 1.f ? val_in : val_in * inv_scale_val);
                },
                [found_inf_ptr, inv_scale_ptr](Vectorized<scalar_t> val_vec) -> Vectorized<scalar_t> {
                  if (val_vec.has_inf_nan()) {
                    *found_inf_ptr = 1.f;
                  }
                  // Every thread accesses inv_scale, but it will hit in cache.
                  const auto inv_scale_val = *inv_scale_ptr;
                  return inv_scale_val == 1.f ? val_vec : val_vec * Vectorized<scalar_t>(inv_scale_val);
                });
          });
    }
  }
}

// _amp_update_scale_cpu updates the scale tensor in place.
//
// Args:
// current_scale: A one-element float tensor containing the scale value.
// growth_tracker: A one-element IntTensor containing the number of recent
//   consecutive unskipped steps.
// found_inf: A one-element float tensor. If > 0, indicates that infs/NaNs
//   were found by the relevant prior _amp_non_finite_check_and_unscale_cpu
//   call, and 0 if no infs/NaNs were found.
// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval: Number of consecutive unskipped steps that must occur for
//   current_scale to be multiplied by growth_factor.
//
// Returns:
// current_scale
at::Tensor& _amp_update_scale_cpu_kernel(
    at::Tensor& current_scale,
    at::Tensor& growth_tracker,
    const at::Tensor& found_inf,
    double growth_factor,
    double backoff_factor,
    int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cpu(), "growth_tracker must be a CPU tensor.");
  TORCH_CHECK(current_scale.is_cpu(), "current_scale must be a CPU tensor.");
  TORCH_CHECK(found_inf.is_cpu(), "found_inf must be a CPU tensor.");
  TORCH_CHECK(
      growth_tracker.numel() == 1,
      "growth_tracker must be a 1-element tensor.");
  TORCH_CHECK(
      current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
  TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
  TORCH_CHECK(
      growth_tracker.scalar_type() == at::ScalarType::Int,
      "growth_tracker must be an int tensor.");
  TORCH_CHECK(
      current_scale.scalar_type() == at::ScalarType::Float,
      "current_scale must be a float tensor.");
  TORCH_CHECK(
      found_inf.scalar_type() == at::ScalarType::Float,
      "found_inf must be a float tensor.");

  float* current_scale_ptr = current_scale.data_ptr<float>();
  int* growth_tracker_ptr = growth_tracker.data_ptr<int>();
  float* found_inf_ptr = found_inf.data_ptr<float>();

  if (*found_inf_ptr) {
    *current_scale_ptr = (*current_scale_ptr) * backoff_factor;
    *growth_tracker_ptr = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker_ptr) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale_ptr) * growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (std::isfinite(new_scale)) {
        *current_scale_ptr = new_scale;
      }
      *growth_tracker_ptr = 0;
    } else {
      *growth_tracker_ptr = successful;
    }
  }

  return current_scale;
}

} // namespace

REGISTER_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu_stub, &_amp_foreach_non_finite_check_and_unscale_cpu_kernel);
REGISTER_DISPATCH(_amp_update_scale_cpu_stub, &_amp_update_scale_cpu_kernel);

} // namespace at::native
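For readability, the scale-update rule implemented by _amp_update_scale_cpu_kernel above can be restated in Python; this is a sketch of the branch logic for illustration, not code from this patch:

import math

def update_scale(scale, growth_tracker, found_inf,
                 growth_factor, backoff_factor, growth_interval):
    # Mirrors the branch structure of _amp_update_scale_cpu_kernel.
    if found_inf:
        scale *= backoff_factor           # skipped step: back the scale off
        growth_tracker = 0
    else:
        growth_tracker += 1               # successful step
        if growth_tracker == growth_interval:
            new_scale = scale * growth_factor
            if math.isfinite(new_scale):  # never grow the scale to inf
                scale = new_scale
            growth_tracker = 0
    return scale, growth_tracker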
@@ -10138,12 +10138,14 @@
  variants: function
  dispatch:
    CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
    CPU: _amp_foreach_non_finite_check_and_unscale_cpu_
  autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out

- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
  variants: function
  dispatch:
    CUDA: _amp_update_scale_cuda_
    CPU: _amp_update_scale_cpu_
  autogen: _amp_update_scale, _amp_update_scale.out

#- func: _cat(Tensor[] tensors, int dim=0) -> Tensor
@@ -66,6 +66,8 @@ namespace {
    class FunctionalTests : public ::testing::Test {};
    template <typename T>
    class FunctionalTestsReducedFloat : public ::testing::Test {};
    template <typename T>
    class InfiniteTests : public ::testing::Test {};
    using RealFloatTestedTypes = ::testing::Types<vfloat, vdouble>;
    using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
    using ALLTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vlong, vint, vshort, vqint8, vquint8, vqint>;
@@ -106,6 +108,7 @@ namespace {
    TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatTestedTypes);
    TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes);
    TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes);
    TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes);
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
    TYPED_TEST_SUITE(
        Quantization8BitWithTailTests,
@@ -1587,6 +1590,36 @@ namespace {
                << "Test failed for uint16 to float " << u16 << "\n";
        }
    }
    TYPED_TEST(InfiniteTests, HasInfNan) {
        using vec = TypeParam;
        using VT = UholdType<TypeParam>;
        auto vec_size = vec::size();
        VT values[20];
        for (const auto i : c10::irange(20)) {
            values[i] = i + 0.3;
        }
        auto vec_val = vec::loadu(values);
        auto seed = TestSeed();
        ValueGen<int> generator(int(0), int(vec_size - 1), seed);
        int index = generator.get();
        int nanBits = 0x7FC00000;
        VT v_nan = static_cast<VT>(*(float *)&nanBits);
        values[index] = v_nan;
        auto vec_nan = vec::loadu(values);
        int infBits = 0x7F800000;
        VT v_pinf = static_cast<VT>(*(float *)&infBits);
        values[index] = v_pinf;
        auto vec_pinf = vec::loadu(values);
        int negInfBits = 0xFF800000;
        VT v_ninf = static_cast<VT>(*(float *)&negInfBits);
        values[index] = v_ninf;
        auto vec_ninf = vec::loadu(values);

        ASSERT_TRUE(!(vec_val.has_inf_nan())) << "Test failed for normal value\n";
        ASSERT_TRUE(vec_nan.has_inf_nan()) << "Test failed for NAN\n";
        ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n";
        ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n";
    }

#else
#error GTEST does not have TYPED_TEST
@@ -1142,6 +1142,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/group_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
    "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
@@ -1342,6 +1343,7 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/WeightNorm.cpp",
    "aten/src/ATen/native/group_norm.cpp",
    "aten/src/ATen/native/layer_norm.cpp",
    "aten/src/ATen/native/AmpKernels.cpp",
    "aten/src/ATen/native/mkl/LinearAlgebra.cpp",
    "aten/src/ATen/native/mkl/SparseBlasImpl.cpp",
    "aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp",
@@ -1098,28 +1098,6 @@ torch.cuda.synchronize()
        self.assertTrue(r != 0)

    def test_grad_scaling_update_scale(self, device="cuda", dtype=torch.float):
        growth = 2.0
        backoff = 0.25
        growth_interval = 2
        scale = torch.full((1,), 4.0, dtype=dtype, device=device)
        growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
        found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0")

        # Simulates 2 consecutive unskipped iterations
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 1)
        self.assertEqual(scale, 4.0)
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 0)
        self.assertEqual(scale, 8.0)

        # Simulates a skipped iteration
        found_inf.fill_(1.0)
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 0)
        self.assertEqual(scale, 2.0)

    def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float):
        scaler = torch.cuda.amp.GradScaler()
@@ -5568,6 +5568,81 @@ else:
        self._test_multinomial_empty(device, False, 1)
        self._test_multinomial_empty(device, False, 2)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_grad_scaling_unscale(self, device, dtype):
        inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device)
        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)

        size = 20
        g = torch.full((size, size), 4.0, dtype=dtype, device=device)
        ginf = g.clone()
        ginf[2, 2] = float('inf')
        gnan = g.clone()
        gnan[2, 2] = float('nan')

        # Tries selected combinations of
        #  - contiguous grads
        #  - g.clone().t() which is not contiguous but still non overlapping and dense
        #  - variants of g.clone()[:, :5] which are not non overlapping and dense
        # Non overlapping and dense grads route into a multi tensor apply kernel,
        # others use a fallback per-tensor kernel, so we should try both.
        cases = (
            ([g.clone(), g.clone()], False),
            ([g.clone(), g.clone().t()], False),
            ([g.clone(), g.clone()[:, :5]], False),
            ([g.clone()[:, :5], g.clone()[:, :5]], False),
            ([g.clone(), ginf.clone()], True),
            ([g.clone(), gnan.clone()], True),
            ([g.clone(), ginf.clone()[:, :5]], True),
            ([g.clone(), gnan.clone()[:, :5]], True),
            ([ginf.clone(), g.clone()[:, :5]], True),
            ([ginf.clone()[:, :5], g.clone()[:, :5]], True),
        )

        for grads, has_inf in cases:
            found_inf.zero_()
            torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
            if has_inf:
                self.assertEqual(found_inf, 1.0)
            else:
                self.assertEqual(found_inf, 0.0)
                for grad in grads:
                    self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)

        # When passing lists with mismatched dtypes to a raw
        # _amp_foreach_non_finite_check_and_unscale_ call,
        # it's expected to fall back to the single-tensor TensorIterator kernel.
        grads = [g.clone(), g.to(dtype=torch.float16)]
        torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)
        for grad in grads:
            self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(torch.float)
    def test_grad_scaling_update_scale(self, device, dtype):
        growth = 2.0
        backoff = 0.25
        growth_interval = 2
        scale = torch.full((1,), 4.0, dtype=dtype, device=device)
        growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device)
        found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device)

        # Simulates 2 consecutive unskipped iterations
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 1)
        self.assertEqual(scale, 4.0)
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 0)
        self.assertEqual(scale, 8.0)

        # Simulates a skipped iteration
        found_inf.fill_(1.0)
        torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval)
        self.assertEqual(growth_tracker, 0)
        self.assertEqual(scale, 2.0)

    @dtypesIfCUDA(torch.float, torch.double, torch.half)
    @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
    @dtypes(torch.float, torch.double)