diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h
index 58cbbfc4df3..d7a4d6483f6 100644
--- a/aten/src/ATen/native/TensorFactories.h
+++ b/aten/src/ATen/native/TensorFactories.h
@@ -103,10 +103,10 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
 // with max value if it is integer type
 inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
   if (tensor.is_floating_point() || tensor.is_complex()) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+    AT_DISPATCH_V2(
+        tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
           tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
-        });
+        }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
   } else {
     AT_DISPATCH_V2(
         tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp
index a24de0b48f9..43a562306e3 100644
--- a/aten/src/ATen/native/cpu/FillKernel.cpp
+++ b/aten/src/ATen/native/cpu/FillKernel.cpp
@@ -43,6 +43,14 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::BFloat16>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::ComplexHalf) {
     fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fn) {
+    fill_non_native_type<at::Float8_e4m3fn>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2) {
+    fill_non_native_type<at::Float8_e5m2>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) {
+    fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
+    fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {
diff --git a/c10/util/Float8_e5m2-inl.h b/c10/util/Float8_e5m2-inl.h
index 7800ceb2992..5a5c1a5fc9b 100644
--- a/c10/util/Float8_e5m2-inl.h
+++ b/c10/util/Float8_e5m2-inl.h
@@ -235,7 +235,7 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr bool is_specialized = true;
   static constexpr bool is_exact = false;
   static constexpr bool has_infinity = true;
-  static constexpr bool has_quiet_NaN = false;
+  static constexpr bool has_quiet_NaN = true;
   static constexpr bool has_signaling_NaN = false;
   static constexpr auto has_denorm = true;
   static constexpr auto has_denorm_loss = true;
@@ -273,6 +273,9 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr c10::Float8_e5m2 infinity() {
     return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
   }
+  static constexpr c10::Float8_e5m2 quiet_NaN() {
+    return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
+  }
   static constexpr c10::Float8_e5m2 denorm_min() {
     return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
   }
diff --git a/c10/util/Float8_e5m2fnuz-inl.h b/c10/util/Float8_e5m2fnuz-inl.h
index 3af233a87b8..d81054cbee3 100644
--- a/c10/util/Float8_e5m2fnuz-inl.h
+++ b/c10/util/Float8_e5m2fnuz-inl.h
@@ -270,6 +270,11 @@ class numeric_limits<c10::Float8_e5m2fnuz> {
   static constexpr c10::Float8_e5m2fnuz infinity() {
     return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
   }
+  // TODO(future): we are mapping neg_zero to both inf and NaN, this is
+  // surprising and we should figure out what to do about it.
+  static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
+    return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
+  }
   static constexpr c10::Float8_e5m2fnuz denorm_min() {
     return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
   }
diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py
index 1f735f29e32..feb14e2cbad 100644
--- a/test/quantization/core/experimental/test_float8.py
+++ b/test/quantization/core/experimental/test_float8.py
@@ -9,6 +9,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
 )
 from torch.testing._internal.common_utils import (
+    DeterministicGuard,
     IS_WINDOWS,
     parametrize,
     run_tests,
@@ -259,6 +260,14 @@ class TestFloat8Dtype(TestCase):
         ):
             x + y
 
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_empty(self, dtype, device):
+        with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
+            for use_deterministic in (True, False):
+                torch.use_deterministic_algorithms(use_deterministic)
+                x = torch.empty(4, 4, device=device, dtype=dtype)
+
 
 instantiate_device_type_tests(TestFloat8Dtype, globals())
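
For reference, a minimal sketch of the user-facing behavior this change enables (not part of the diff; assumes a build that includes it): with deterministic algorithms turned on, torch.empty on a float8 dtype now fills the uninitialized memory with NaN, matching the existing policy for other floating-point dtypes, instead of failing to dispatch.

    import torch

    # Illustrative only; upcast to float32 before isnan, since elementwise
    # ops on float8 tensors are limited (an assumption, not from the diff).
    torch.use_deterministic_algorithms(True)
    x = torch.empty(4, 4, dtype=torch.float8_e5m2)
    print(torch.isnan(x.float()).all())  # expected: tensor(True)
    torch.use_deterministic_algorithms(False)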