[fix] torch.frac : Handle inf correctly (#52678)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/51948
Fixes: https://github.com/pytorch/pytorch/issues/52027
Depends On: https://github.com/pytorch/pytorch/issues/52660

TODO
* [x] Benchmark

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52678

Reviewed By: anjali411

Differential Revision: D27566407

Pulled By: heitorschueroff

fbshipit-source-id: 92c7309558ee41f8b9c641f791e8f84819c333e2
This commit is contained in:
parent bc05867618
commit ece075195d
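Context: the previous implementation computed frac(x) as x - trunc(x) in every code path, which evaluates to inf - inf = NaN for infinite inputs; this patch maps +inf and -inf to a zero carrying the input's sign instead. A minimal standalone C++ sketch of the old and new scalar behaviour (illustration only, not PyTorch code):

#include <cmath>
#include <cstdio>

// Old behaviour: frac(x) = x - trunc(x); for x = +/-inf this is inf - inf = NaN.
static double frac_old(double x) {
  return x - std::trunc(x);
}

// Patched behaviour: infinities are mapped to a zero with the input's sign.
static double frac_new(double x) {
  if (std::isinf(x)) {
    return std::copysign(0.0, x);
  }
  return x - std::trunc(x);
}

int main() {
  const double inputs[] = {2.5, -2.5, INFINITY, -INFINITY};
  for (double x : inputs) {
    std::printf("x=%6.2f  old=%6.2f  new=%6.2f\n", x, frac_old(x), frac_new(x));
  }
  // Expected: old prints nan for +/-inf, new prints 0.00 and -0.00.
  return 0;
}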
@@ -332,7 +332,12 @@ public:
     return map(std::expm1);
   }
   Vec256<T> frac() const {
-    return *this - this->trunc();
+    return map([](const T& x) -> T {
+      if (std::isinf(x)) {
+        return std::copysign(static_cast<T>(0), x);
+      }
+      return x - at::native::trunc_impl(x);
+    });
   }
   template <
     typename U = T,
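A quick check of what the std::copysign guard in the generic fallback above buys over returning a literal zero: copysign(0, x) keeps the sign of x, so frac(-inf) comes out as -0.0 rather than +0.0. A tiny standalone verification (standard C++ only, not PyTorch code):

#include <cassert>
#include <cmath>

int main() {
  const double pos = std::copysign(0.0, INFINITY);
  const double neg = std::copysign(0.0, -INFINITY);
  assert(pos == 0.0 && !std::signbit(pos));
  // -0.0 compares equal to 0.0 but has its sign bit set.
  assert(neg == 0.0 && std::signbit(neg));
  return 0;
}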
@@ -574,7 +574,24 @@ Vec256<BFloat16> Vec256<BFloat16>::le(const Vec256<BFloat16>& other) const {

 // frac. Implement this here so we can use subtraction
 Vec256<BFloat16> Vec256<BFloat16>::frac() const {
-  return *this - this->trunc();
+  __m256 lo, hi;
+  cvtbf16_fp32(values, lo, hi);
+  auto frac_lambda = [](__m256 values){
+    const auto pos_inf_vec = _mm256_set1_ps(INFINITY);
+    const auto pos_zero_vec = _mm256_set1_ps(0.f);
+    const auto neg_inf_vec = _mm256_set1_ps(-INFINITY);
+    const auto neg_zero_vec = _mm256_set1_ps(-0.f);
+    const auto pos_inf_mask = _mm256_cmp_ps(values, pos_inf_vec, _CMP_EQ_OQ);
+    const auto neg_inf_mask = _mm256_cmp_ps(values, neg_inf_vec, _CMP_EQ_OQ);
+    const auto trunc = _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    auto frac = _mm256_sub_ps(values, trunc);
+    frac = _mm256_blendv_ps(frac, pos_zero_vec, pos_inf_mask);
+    frac = _mm256_blendv_ps(frac, neg_zero_vec, neg_inf_mask);
+    return frac;
+  };
+  auto o1 = frac_lambda(lo);
+  auto o2 = frac_lambda(hi);
+  return cvtfp32_bf16(o1, o2);
 }

 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
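The vectorized paths (BFloat16 via fp32 above, and the float and double specializations in the next two hunks) all follow the same pattern: build +/-inf equality masks with a compare intrinsic, compute x - trunc(x) unconditionally, then blend signed zeros over the infinite lanes. A self-contained sketch of that pattern for eight floats (assumes an AVX-capable x86 target and <immintrin.h>; illustration only, not the PyTorch code):

#include <cmath>
#include <cstdio>
#include <immintrin.h>

// Compute frac(x) = x - trunc(x) lane-wise, mapping +/-inf lanes to signed zero.
static __m256 frac_avx(__m256 values) {
  const __m256 pos_inf = _mm256_set1_ps(INFINITY);
  const __m256 neg_inf = _mm256_set1_ps(-INFINITY);
  // Ordered, non-signalling equality: lanes equal to +/-inf become all-ones masks.
  const __m256 pos_inf_mask = _mm256_cmp_ps(values, pos_inf, _CMP_EQ_OQ);
  const __m256 neg_inf_mask = _mm256_cmp_ps(values, neg_inf, _CMP_EQ_OQ);
  const __m256 trunc = _mm256_round_ps(values, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
  __m256 frac = _mm256_sub_ps(values, trunc);  // infinite lanes are NaN here (inf - inf)
  frac = _mm256_blendv_ps(frac, _mm256_set1_ps(0.f), pos_inf_mask);
  frac = _mm256_blendv_ps(frac, _mm256_set1_ps(-0.f), neg_inf_mask);
  return frac;
}

int main() {
  alignas(32) float in[8] = {2.5f, -2.5f, 0.75f, -0.75f, INFINITY, -INFINITY, 100.0f, -1.0f};
  alignas(32) float out[8];
  _mm256_store_ps(out, frac_avx(_mm256_load_ps(in)));
  for (int i = 0; i < 8; ++i) {
    std::printf("frac(%8.2f) = %6.2f\n", in[i], out[i]);
  }
  // Expected: 0.50 -0.50 0.75 -0.75 0.00 -0.00 0.00 0.00
  return 0;
}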
@@ -305,7 +305,16 @@ Vec256<double> inline operator/(const Vec256<double>& a, const Vec256<double>& b) {

 // frac. Implement this here so we can use subtraction.
 Vec256<double> Vec256<double>::frac() const {
-  return *this - this->trunc();
+  const auto pos_inf_vec = _mm256_set1_pd(INFINITY);
+  const auto pos_zero_vec = _mm256_set1_pd(0.f);
+  const auto neg_inf_vec = _mm256_set1_pd(-INFINITY);
+  const auto neg_zero_vec = _mm256_set1_pd(-0.f);
+  const auto pos_inf_mask = _mm256_cmp_pd(values, pos_inf_vec, _CMP_EQ_OQ);
+  const auto neg_inf_mask = _mm256_cmp_pd(values, neg_inf_vec, _CMP_EQ_OQ);
+  auto frac = *this - this->trunc();
+  frac = _mm256_blendv_pd(frac, pos_zero_vec, pos_inf_mask);
+  frac = _mm256_blendv_pd(frac, neg_zero_vec, neg_inf_mask);
+  return frac;
 }

 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
@@ -312,7 +312,16 @@ Vec256<float> inline operator/(const Vec256<float>& a, const Vec256<float>& b) {

 // frac. Implement this here so we can use subtraction
 Vec256<float> Vec256<float>::frac() const {
-  return *this - this->trunc();
+  const auto pos_inf_vec = _mm256_set1_ps(INFINITY);
+  const auto pos_zero_vec = _mm256_set1_ps(0.f);
+  const auto neg_inf_vec = _mm256_set1_ps(-INFINITY);
+  const auto neg_zero_vec = _mm256_set1_ps(-0.f);
+  const auto pos_inf_mask = _mm256_cmp_ps(values, pos_inf_vec, _CMP_EQ_OQ);
+  const auto neg_inf_mask = _mm256_cmp_ps(values, neg_inf_vec, _CMP_EQ_OQ);
+  auto frac = *this - this->trunc();
+  frac = _mm256_blendv_ps(frac, pos_zero_vec, pos_inf_mask);
+  frac = _mm256_blendv_ps(frac, neg_zero_vec, neg_inf_mask);
+  return frac;
 }

 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
@@ -227,12 +227,18 @@ static void bitwise_not_kernel(TensorIterator& iter) {
 }

 static void frac_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
-    cpu_kernel_vec(
-        iter,
-        [=](scalar_t a) -> scalar_t { return a - std::trunc(a); },
-        [=](Vec256<scalar_t> a) { return a.frac(); });
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "frac_cpu", [&]() {
+        cpu_kernel_vec(
+            iter,
+            [=](scalar_t a) -> scalar_t {
+              if (std::isinf(a)) {
+                return std::copysign(static_cast<scalar_t>(0), a);
+              }
+              return a - ::trunc(a);
+            },
+            [=](Vec256<scalar_t> a) { return a.frac(); });
+      });
 }

 static void logical_not_kernel(TensorIterator& iter) {
@@ -29,10 +29,26 @@ void ceil_kernel_cuda(TensorIterator& iter) {
   });
 }

+template <typename T>
+__host__ __device__ static inline T frac_wrapper(T a) {
+  if (::isinf(a)) {
+    return ::copysign(static_cast<T>(0), a);
+  }
+  return a - ::trunc(a);
+}
+
+template <>
+__host__ __device__ inline c10::Half frac_wrapper(c10::Half a) {
+  if (::isinf(static_cast<float>(a))) {
+    return ::copysign(static_cast<c10::Half>(0), a);
+  }
+  return a - ::trunc(a);
+}
+
 void frac_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "frac_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return a - ::trunc(a);
+    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return frac_wrapper(a);
     });
   });
 }
@@ -2772,10 +2772,7 @@ op_db: List[OpInfo] = [
                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
                    dtypesIfCPU=floating_types_and(torch.bfloat16, torch.float16),
                    dtypesIfCUDA=floating_types_and(torch.float16),
-                   assert_autodiffed=True,
-                   # Reference for disabling extremals
-                   # https://github.com/pytorch/pytorch/issues/51948
-                   handles_extremals=False),
+                   assert_autodiffed=True),
     SpectralFuncInfo('fft.fft',
                      aten_name='fft_fft',
                      ref=np.fft.fft,
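Note on the last hunk: with infinities now producing signed zeros instead of NaN, the frac OpInfo entry no longer needs handles_extremals=False, so the extremal-value comparisons that had been disabled because of issue 51948 are re-enabled.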