[fix] torch.frac : Handle inf correctly (#52678)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/51948
Fixes: https://github.com/pytorch/pytorch/issues/52027

Depends on: https://github.com/pytorch/pytorch/issues/52660

TODO
* [x] Benchmark
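
For context: trunc(+/-inf) is +/-inf, so the previous "x - trunc(x)" implementation evaluated inf - inf, which is NaN. The fix returns a signed zero for infinite inputs instead. A minimal standalone sketch of the corrected scalar semantics (plain C++, not the PyTorch sources):

#include <cmath>
#include <cstdio>

// Corrected semantics: frac(+/-inf) is a signed zero rather than the
// NaN that inf - trunc(inf) == inf - inf would otherwise produce.
template <typename T>
T frac_fixed(T x) {
  if (std::isinf(x)) {
    return std::copysign(static_cast<T>(0), x);
  }
  return x - std::trunc(x);
}

int main() {
  double inf = INFINITY;
  std::printf("%f\n", inf - std::trunc(inf)); // nan (the old behavior)
  std::printf("%f\n", frac_fixed(inf));       // 0.000000
  std::printf("%f\n", frac_fixed(-inf));      // -0.000000
  std::printf("%f\n", frac_fixed(2.5));       // 0.500000
}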

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52678

Reviewed By: anjali411

Differential Revision: D27566407

Pulled By: heitorschueroff

fbshipit-source-id: 92c7309558ee41f8b9c641f791e8f84819c333e2
Author: kshitij12345
Date: 2021-04-07 02:25:42 -07:00
Committed by: Facebook GitHub Bot
Parent: bc05867618
Commit: ece075195d
7 changed files with 75 additions and 16 deletions


@@ -332,7 +332,12 @@ public:
     return map(std::expm1);
   }
   Vec256<T> frac() const {
-    return *this - this->trunc();
+    return map([](const T& x) -> T {
+      if (std::isinf(x)) {
+        return std::copysign(static_cast<T>(0), x);
+      }
+      return x - at::native::trunc_impl(x);
+    });
   }
   template <
     typename U = T,

@@ -574,7 +574,24 @@ Vec256<BFloat16> Vec256<BFloat16>::le(const Vec256<BFloat16>& other) const {
 // frac. Implement this here so we can use subtraction
 Vec256<BFloat16> Vec256<BFloat16>::frac() const {
-  return *this - this->trunc();
+  __m256 lo, hi;
+  cvtbf16_fp32(values, lo, hi);
+  auto frac_lambda = [](__m256 values) {
+    const auto pos_inf_vec = _mm256_set1_ps(INFINITY);
+    const auto pos_zero_vec = _mm256_set1_ps(0.f);
+    const auto neg_inf_vec = _mm256_set1_ps(-INFINITY);
+    const auto neg_zero_vec = _mm256_set1_ps(-0.f);
+    const auto pos_inf_mask = _mm256_cmp_ps(values, pos_inf_vec, _CMP_EQ_OQ);
+    const auto neg_inf_mask = _mm256_cmp_ps(values, neg_inf_vec, _CMP_EQ_OQ);
+    const auto trunc = _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    auto frac = _mm256_sub_ps(values, trunc);
+    frac = _mm256_blendv_ps(frac, pos_zero_vec, pos_inf_mask);
+    frac = _mm256_blendv_ps(frac, neg_zero_vec, neg_inf_mask);
+    return frac;
+  };
+  auto o1 = frac_lambda(lo);
+  auto o2 = frac_lambda(hi);
+  return cvtfp32_bf16(o1, o2);
 }
 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
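
BFloat16 has no direct AVX arithmetic here, so the vector is widened into two fp32 registers with cvtbf16_fp32, run through the same frac logic, and narrowed back with cvtfp32_bf16. A scalar sketch of that widen/compute/narrow scheme (not the PyTorch conversion code; the truncating narrow below is a simplification, as the real cvtfp32_bf16 path also handles rounding and NaN):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 fp32, so widening is a shift.
static float bf16_to_fp32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static uint16_t fp32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncating narrow, for brevity
}

int main() {
  uint16_t b = fp32_to_bf16(3.5f);      // 3.5 is exactly representable
  float widened = bf16_to_fp32(b);
  std::printf("%f\n", widened);         // 3.500000
  std::printf("%f\n", widened - 3.0f);  // 0.500000 (frac computed in fp32)
}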


@@ -305,7 +305,16 @@ Vec256<double> inline operator/(const Vec256<double>& a, const Vec256<double>& b
 // frac. Implement this here so we can use subtraction.
 Vec256<double> Vec256<double>::frac() const {
-  return *this - this->trunc();
+  const auto pos_inf_vec = _mm256_set1_pd(INFINITY);
+  const auto pos_zero_vec = _mm256_set1_pd(0.f);
+  const auto neg_inf_vec = _mm256_set1_pd(-INFINITY);
+  const auto neg_zero_vec = _mm256_set1_pd(-0.f);
+  const auto pos_inf_mask = _mm256_cmp_pd(values, pos_inf_vec, _CMP_EQ_OQ);
+  const auto neg_inf_mask = _mm256_cmp_pd(values, neg_inf_vec, _CMP_EQ_OQ);
+  auto frac = *this - this->trunc();
+  frac = _mm256_blendv_pd(frac, pos_zero_vec, pos_inf_mask);
+  frac = _mm256_blendv_pd(frac, neg_zero_vec, neg_inf_mask);
+  return frac;
 }
 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if


@@ -312,7 +312,16 @@ Vec256<float> inline operator/(const Vec256<float>& a, const Vec256<float>& b) {
 // frac. Implement this here so we can use subtraction
 Vec256<float> Vec256<float>::frac() const {
-  return *this - this->trunc();
+  const auto pos_inf_vec = _mm256_set1_ps(INFINITY);
+  const auto pos_zero_vec = _mm256_set1_ps(0.f);
+  const auto neg_inf_vec = _mm256_set1_ps(-INFINITY);
+  const auto neg_zero_vec = _mm256_set1_ps(-0.f);
+  const auto pos_inf_mask = _mm256_cmp_ps(values, pos_inf_vec, _CMP_EQ_OQ);
+  const auto neg_inf_mask = _mm256_cmp_ps(values, neg_inf_vec, _CMP_EQ_OQ);
+  auto frac = *this - this->trunc();
+  frac = _mm256_blendv_ps(frac, pos_zero_vec, pos_inf_mask);
+  frac = _mm256_blendv_ps(frac, neg_zero_vec, neg_inf_mask);
+  return frac;
 }
 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
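
The float and double versions above, and the lambda in the BFloat16 version, share one pattern: compute x - trunc(x) on every lane, build masks for the lanes equal to +inf or -inf, and blend the matching signed zero into exactly those lanes, discarding the NaN the subtraction produced there. A standalone sketch of that compare-and-blend pattern (AVX, compile with -mavx; illustrative, not the PyTorch sources):

#include <immintrin.h>
#include <cmath>
#include <cstdio>

int main() {
  alignas(32) float in[8] = {1.5f, -2.25f, INFINITY, -INFINITY,
                             0.0f, 100.75f, -0.5f, 7.0f};
  alignas(32) float out[8];
  __m256 v = _mm256_load_ps(in);
  // Subtract the truncation on all lanes; inf lanes become inf - inf == NaN.
  __m256 trunc = _mm256_round_ps(v, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
  __m256 frac = _mm256_sub_ps(v, trunc);
  // Mask the lanes that were +/-inf in the input and blend in signed zeros.
  __m256 pos_mask = _mm256_cmp_ps(v, _mm256_set1_ps(INFINITY), _CMP_EQ_OQ);
  __m256 neg_mask = _mm256_cmp_ps(v, _mm256_set1_ps(-INFINITY), _CMP_EQ_OQ);
  frac = _mm256_blendv_ps(frac, _mm256_set1_ps(0.0f), pos_mask);   // +inf -> +0
  frac = _mm256_blendv_ps(frac, _mm256_set1_ps(-0.0f), neg_mask);  // -inf -> -0
  _mm256_store_ps(out, frac);
  for (float f : out) {
    std::printf("%g ", f);  // 0.5 -0.25 0 -0 0 0.75 -0.5 0
  }
  std::printf("\n");
}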


@@ -227,10 +227,16 @@ static void bitwise_not_kernel(TensorIterator& iter) {
 }
 static void frac_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
         cpu_kernel_vec(
             iter,
-            [=](scalar_t a) -> scalar_t { return a - std::trunc(a); },
+            [=](scalar_t a) -> scalar_t {
+              if (std::isinf(a)) {
+                return std::copysign(static_cast<scalar_t>(0), a);
+              }
+              return a - ::trunc(a);
+            },
             [=](Vec256<scalar_t> a) { return a.frac(); });
       });
 }
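
cpu_kernel_vec pairs a scalar lambda with a vectorized one: full Vec256-width chunks go through a.frac(), while leftover tail elements go through the scalar lambda, which is why the fix has to touch both paths so they agree on the +/-inf convention. A rough sketch of that pairing contract, with simplified signatures (apply_unary and its parameters are illustrative, not the TensorIterator API):

#include <cmath>
#include <cstddef>
#include <cstdio>

// Full W-wide chunks run vec_op, leftover tail elements run scalar_op,
// so the two must produce identical results on edge cases like +/-inf.
template <typename T, std::size_t W, typename ScalarOp, typename VecOp>
void apply_unary(T* data, std::size_t n, ScalarOp scalar_op, VecOp vec_op) {
  std::size_t i = 0;
  for (; i + W <= n; i += W) {
    vec_op(data + i);              // stands in for Vec256<T>::frac()
  }
  for (; i < n; ++i) {
    data[i] = scalar_op(data[i]);  // stands in for the scalar lambda
  }
}

int main() {
  float data[6] = {1.5f, -2.25f, INFINITY, 4.75f, -INFINITY, 3.5f};
  auto scalar_frac = [](float a) {
    return std::isinf(a) ? std::copysign(0.0f, a) : a - std::trunc(a);
  };
  auto vec_frac = [&](float* chunk) {  // toy "vector" body, 4 lanes
    for (int k = 0; k < 4; ++k) chunk[k] = scalar_frac(chunk[k]);
  };
  apply_unary<float, 4>(data, 6, scalar_frac, vec_frac);
  for (float f : data) std::printf("%g ", f);  // 0.5 -0.25 0 0.75 -0 0.5
  std::printf("\n");
}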


@@ -29,10 +29,26 @@ void ceil_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+template <typename T>
+__host__ __device__ static inline T frac_wrapper(T a) {
+  if (::isinf(a)) {
+    return ::copysign(static_cast<T>(0), a);
+  }
+  return a - ::trunc(a);
+}
+
+template <>
+__host__ __device__ inline c10::Half frac_wrapper(c10::Half a) {
+  if (::isinf(static_cast<float>(a))) {
+    return ::copysign(static_cast<c10::Half>(0), a);
+  }
+  return a - ::trunc(a);
+}
+
 void frac_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "frac_cuda", [&]() {
     gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return a - ::trunc(a);
+      return frac_wrapper(a);
     });
   });
 }
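
The c10::Half specialization classifies the input in float precision before applying the same signed-zero convention. A host-only sketch of that two-overload shape (HalfLike is a hypothetical stand-in for c10::Half, which likewise converts to and from float; plain C++, not CUDA code):

#include <cmath>
#include <cstdio>

struct HalfLike {
  float v;
  HalfLike(float f) : v(f) {}
  operator float() const { return v; }
};

// Generic version, as used for float and double.
template <typename T>
inline T frac_wrapper_sketch(T a) {
  if (std::isinf(a)) {
    return std::copysign(static_cast<T>(0), a);
  }
  return a - std::trunc(a);
}

// Half-precision version: classify in float, same signed-zero convention.
template <>
inline HalfLike frac_wrapper_sketch(HalfLike a) {
  float f = static_cast<float>(a);
  if (std::isinf(f)) {
    return HalfLike(std::copysign(0.0f, f));
  }
  return HalfLike(f - std::trunc(f));
}

int main() {
  std::printf("%f\n", static_cast<float>(
      frac_wrapper_sketch(HalfLike(-INFINITY))));  // -0.000000
  std::printf("%f\n", static_cast<float>(
      frac_wrapper_sketch(HalfLike(2.5f))));       // 0.500000
}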


@@ -2772,10 +2772,7 @@ op_db: List[OpInfo] = [
                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
                    dtypesIfCPU=floating_types_and(torch.bfloat16, torch.float16),
                    dtypesIfCUDA=floating_types_and(torch.float16),
-                   assert_autodiffed=True,
-                   # Reference for disabling extremals
-                   # https://github.com/pytorch/pytorch/issues/51948
-                   handles_extremals=False),
+                   assert_autodiffed=True),
    SpectralFuncInfo('fft.fft',
                     aten_name='fft_fft',
                     ref=np.fft.fft,