Enable torch.empty for float8 dtypes + deterministic mode + cpu (#128744)
Summary:

Enables creating empty float8 tensors for:

* cuda when `torch.use_deterministic_algorithms` is set to True
* cpu for all settings of `torch.use_deterministic_algorithms`

Context for NaN values of float8_e4m3fn and float8_e5m2:
https://arxiv.org/pdf/2209.05433, Section 3, Table 1

Context for NaN values of float8_e4m3fnuz and float8_e5m2fnuz:
https://arxiv.org/pdf/2206.02915, Section 3.2, "instead of reserving one
exponent field to represent Inf and NaN, we reserve only a single codeword
(corresponding to negative zero)"

Test Plan:

```
python test/test_quantization.py -k test_empty
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes https://github.com/pytorch/pytorch/issues/128733

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128744
Approved by: https://github.com/malfet, https://github.com/drisspg
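As a rough usage sketch of the behavior described above (assuming a PyTorch build that includes this change): under deterministic mode, `torch.empty` for a float8 dtype comes back filled with the dtype's quiet NaN instead of uninitialized memory.

```
import torch

# Minimal sketch, assuming a build with this change: cpu shown here; the same
# deterministic fill path is taken on cuda when deterministic mode is enabled.
torch.use_deterministic_algorithms(True)
x = torch.empty(4, 4, dtype=torch.float8_e5m2, device="cpu")
# Inspect the raw bytes: every element should be 127 (0x7f, the e5m2 quiet NaN).
print(x.view(torch.uint8))
torch.use_deterministic_algorithms(False)
```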
parent 846bb30e13
commit 2d01f87737
@@ -103,10 +103,10 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
 // with max value if it is integer type
 inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
   if (tensor.is_floating_point() || tensor.is_complex()) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+    AT_DISPATCH_V2(
+        tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
           tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
-        });
+        }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
   } else {
     AT_DISPATCH_V2(
         tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
@@ -43,6 +43,14 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::BFloat16>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::ComplexHalf) {
     fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fn) {
+    fill_non_native_type<at::Float8_e4m3fn>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2) {
+    fill_non_native_type<at::Float8_e5m2>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) {
+    fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
+    fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {
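The new branches above route the CPU fill kernel through `fill_non_native_type` for the four float8 dtypes. A hedged sketch of what that enables from Python (assuming a build with this change):

```
import torch

# Sketch: fill_ on a float8 CPU tensor now hits the new Float8_* branches in
# fill_kernel instead of failing dispatch.
x = torch.empty(3, dtype=torch.float8_e4m3fn, device="cpu")
x.fill_(1.0)  # 1.0 is exactly representable in e4m3fn
print(x.to(torch.float32))  # expected: tensor([1., 1., 1.])
```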
@@ -235,7 +235,7 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr bool is_specialized = true;
   static constexpr bool is_exact = false;
   static constexpr bool has_infinity = true;
-  static constexpr bool has_quiet_NaN = false;
+  static constexpr bool has_quiet_NaN = true;
   static constexpr bool has_signaling_NaN = false;
   static constexpr auto has_denorm = true;
   static constexpr auto has_denorm_loss = true;
@@ -273,6 +273,9 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr c10::Float8_e5m2 infinity() {
     return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
   }
+  static constexpr c10::Float8_e5m2 quiet_NaN() {
+    return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
+  }
   static constexpr c10::Float8_e5m2 denorm_min() {
     return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
   }
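The 0x7C / 0x7F constants above follow the e5m2 encoding from the paper cited in the summary (all-ones exponent with a zero mantissa is Inf, nonzero mantissa is NaN). A small sketch for checking those bit patterns from Python, assuming float8 view/cast support in the installed build:

```
import torch

# Sketch: decode the bit patterns used by numeric_limits<Float8_e5m2> above.
bits = torch.tensor([0x7C, 0x7F], dtype=torch.uint8)
x = bits.view(torch.float8_e5m2).to(torch.float32)
print(torch.isinf(x))  # expected: tensor([ True, False])  (0x7c is +Inf)
print(torch.isnan(x))  # expected: tensor([False,  True])  (0x7f is a NaN)
```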
@@ -270,6 +270,11 @@ class numeric_limits<c10::Float8_e5m2fnuz> {
   static constexpr c10::Float8_e5m2fnuz infinity() {
     return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
   }
+  // TODO(future): we are mapping neg_zero to both inf and NaN, this is
+  // surprising and we should figure out what to do about it.
+  static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
+    return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
+  }
   static constexpr c10::Float8_e5m2fnuz denorm_min() {
     return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
   }
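Per the fnuz paper quoted in the summary, these formats reserve only the negative-zero codeword (0x80) for NaN and have no Inf encoding, which is why both infinity() and quiet_NaN() above return the same bits. A hedged sketch of what that codeword decodes to, assuming float8_e5m2fnuz support in the installed build:

```
import torch

# Sketch: in the fnuz formats the single 0x80 codeword (the would-be negative
# zero) decodes to NaN; there is no Inf encoding at all.
bits = torch.tensor([0x80], dtype=torch.uint8)
x = bits.view(torch.float8_e5m2fnuz).to(torch.float32)
print(torch.isnan(x))  # expected: tensor([True])
```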
@@ -9,6 +9,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
 )
 from torch.testing._internal.common_utils import (
+    DeterministicGuard,
     IS_WINDOWS,
     parametrize,
     run_tests,
@@ -259,6 +260,14 @@ class TestFloat8Dtype(TestCase):
         ):
             x + y
 
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_empty(self, dtype, device):
+        with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
+            for use_deterministic in (True, False):
+                torch.use_deterministic_algorithms(use_deterministic)
+                x = torch.empty(4, 4, device=device, dtype=dtype)
+
 
 instantiate_device_type_tests(TestFloat8Dtype, globals())