Enable torch.empty for float8 dtypes + deterministic mode + cpu (#128744)

Summary:

Enables creating empty float8 tensors for:
* CUDA, when `torch.use_deterministic_algorithms(True)` is in effect
* CPU, for all settings of `torch.use_deterministic_algorithms`
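
For illustration, a minimal usage sketch (not part of the PR; assumes a build that contains this change):

```python
import torch

# Deterministic mode fills "empty" floating-point tensors with NaN so that
# uninitialized memory cannot silently leak into results; this PR extends
# that fill to the float8 dtypes.
torch.use_deterministic_algorithms(True)

x = torch.empty(4, 4, dtype=torch.float8_e4m3fn)   # CPU: now supported
print(x.to(torch.float32))                          # expected: all NaN

if torch.cuda.is_available():
    # CUDA + deterministic mode: now supported as well
    y = torch.empty(4, 4, dtype=torch.float8_e5m2, device="cuda")
```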

Context for NaN values of float8_e4m3fn and float8_e5m2: https://arxiv.org/pdf/2209.05433, Section 3, Table 1

Context for NaN values of float8_e4m3fnuz and float8_e5m2fnuz: https://arxiv.org/pdf/2206.02915, Section 3.2, "instead of reserving one exponent field to represent Inf and NaN, we reserve only a single codeword (corresponding to negative zero)"
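
A small sketch of the encodings involved (assumption: raw bytes are reinterpreted via a `uint8` tensor with `Tensor.view(dtype)`; the bit patterns follow the two papers above and the `quiet_NaN()` values added in this PR):

```python
import torch

# float8_e5m2 keeps IEEE-style specials: exponent all-ones + nonzero mantissa is NaN.
nan_e5m2 = torch.tensor([0x7F], dtype=torch.uint8).view(torch.float8_e5m2)  # 0 11111 11
inf_e5m2 = torch.tensor([0x7C], dtype=torch.uint8).view(torch.float8_e5m2)  # 0 11111 00

# The fnuz variants reserve no infinity; the single "negative zero" codeword 0x80 is NaN.
nan_e5m2fnuz = torch.tensor([0x80], dtype=torch.uint8).view(torch.float8_e5m2fnuz)

print(nan_e5m2.to(torch.float32), inf_e5m2.to(torch.float32), nan_e5m2fnuz.to(torch.float32))
# expected: tensor([nan]) tensor([inf]) tensor([nan])
```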

Test Plan:

```
python test/test_quantization.py -k test_empty
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes https://github.com/pytorch/pytorch/issues/128733

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128744
Approved by: https://github.com/malfet, https://github.com/drisspg
Author: vasiliy, 2024-06-15 02:05:27 +00:00 (committed by PyTorch MergeBot)
Commit: 2d01f87737 (parent: 846bb30e13)
5 changed files with 29 additions and 4 deletions


```diff
@@ -103,10 +103,10 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
 // with max value if it is integer type
 inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
   if (tensor.is_floating_point() || tensor.is_complex()) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
-          tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
-        });
+    AT_DISPATCH_V2(
+        tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
+          tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
+        }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
   } else {
     AT_DISPATCH_V2(
         tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
```


```diff
@@ -43,6 +43,14 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::BFloat16>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::ComplexHalf) {
     fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fn) {
+    fill_non_native_type<at::Float8_e4m3fn>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2) {
+    fill_non_native_type<at::Float8_e5m2>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) {
+    fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
+  } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
+    fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {
```
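
As a hedged sanity check of this CPU fill path (not from the PR), a scalar `fill_` on a float8 CPU tensor should now route through the `fill_non_native_type` branches added above:

```python
import torch

# 1.0 is exactly representable in all four float8 formats, so the round-trip is lossless.
x = torch.empty(2, 3, dtype=torch.float8_e5m2)  # CPU tensor
x.fill_(1.0)                                    # scalar fill handled by the branches above
print(x.to(torch.float32))                      # expected: all 1.0
```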


```diff
@@ -235,7 +235,7 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr bool is_specialized = true;
   static constexpr bool is_exact = false;
   static constexpr bool has_infinity = true;
-  static constexpr bool has_quiet_NaN = false;
+  static constexpr bool has_quiet_NaN = true;
   static constexpr bool has_signaling_NaN = false;
   static constexpr auto has_denorm = true;
   static constexpr auto has_denorm_loss = true;
@@ -273,6 +273,9 @@ class numeric_limits<c10::Float8_e5m2> {
   static constexpr c10::Float8_e5m2 infinity() {
     return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
   }
+  static constexpr c10::Float8_e5m2 quiet_NaN() {
+    return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
+  }
   static constexpr c10::Float8_e5m2 denorm_min() {
     return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
   }
```


```diff
@@ -270,6 +270,11 @@ class numeric_limits<c10::Float8_e5m2fnuz> {
   static constexpr c10::Float8_e5m2fnuz infinity() {
     return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
   }
+  // TODO(future): we are mapping neg_zero to both inf and NaN, this is
+  // surprising and we should figure out what to do about it.
+  static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
+    return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
+  }
   static constexpr c10::Float8_e5m2fnuz denorm_min() {
     return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
   }
```


```diff
@@ -9,6 +9,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
 )
 from torch.testing._internal.common_utils import (
+    DeterministicGuard,
     IS_WINDOWS,
     parametrize,
     run_tests,
@@ -259,6 +260,14 @@ class TestFloat8Dtype(TestCase):
         ):
             x + y
 
+    @dtypes(*FLOAT8_DTYPES)
+    @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
+    def test_empty(self, dtype, device):
+        with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
+            for use_deterministic in (True, False):
+                torch.use_deterministic_algorithms(use_deterministic)
+                x = torch.empty(4, 4, device=device, dtype=dtype)
+
 
 instantiate_device_type_tests(TestFloat8Dtype, globals())
```