mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Enable torch.empty for float8 dtypes + deterministic mode + cpu (#128744)
Summary: Enables creating empty float8 tensors for: * cuda when `torch.use_deterministic_algorithms` is set to True * cpu for all settings of `torch.use_deterministic_algorithms` Context for NaN values of float8_e4m3fn and float8_e5m2: https://arxiv.org/pdf/2209.05433, Section 3, Table 1 Context for NaN values of float8_e4m3fnuz and float8_e5m2fnuz: https://arxiv.org/pdf/2206.02915, Section 3.2, "instead of reserving one exponent field to represent Inf and NaN, we reserve only a single codeword (corresponding to negative zero)" Test Plan: ``` python test/test_quantization.py -k test_empty ``` Reviewers: Subscribers: Tasks: Tags: Fixes https://github.com/pytorch/pytorch/issues/128733 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128744 Approved by: https://github.com/malfet, https://github.com/drisspg
This commit is contained in:
parent
846bb30e13
commit
2d01f87737
|
|
@ -103,10 +103,10 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
|
|||
// with max value if it is integer type
|
||||
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
|
||||
if (tensor.is_floating_point() || tensor.is_complex()) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
|
||||
AT_DISPATCH_V2(
|
||||
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
|
||||
tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
|
||||
});
|
||||
}), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
|
||||
} else {
|
||||
AT_DISPATCH_V2(
|
||||
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
|
||||
|
|
|
|||
|
|
@ -43,6 +43,14 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
|
|||
fill_non_native_type<at::BFloat16>(iter, value_scalar);
|
||||
} else if (iter.dtype() == ScalarType::ComplexHalf) {
|
||||
fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
|
||||
} else if (iter.dtype() == ScalarType::Float8_e4m3fn) {
|
||||
fill_non_native_type<at::Float8_e4m3fn>(iter, value_scalar);
|
||||
} else if (iter.dtype() == ScalarType::Float8_e5m2) {
|
||||
fill_non_native_type<at::Float8_e5m2>(iter, value_scalar);
|
||||
} else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) {
|
||||
fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
|
||||
} else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
|
||||
fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
|
||||
} else {
|
||||
AT_DISPATCH_V2(
|
||||
iter.dtype(), "fill_cpu", AT_WRAP([&]() {
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ class numeric_limits<c10::Float8_e5m2> {
|
|||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
|
|
@ -273,6 +273,9 @@ class numeric_limits<c10::Float8_e5m2> {
|
|||
static constexpr c10::Float8_e5m2 infinity() {
|
||||
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 quiet_NaN() {
|
||||
return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 denorm_min() {
|
||||
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -270,6 +270,11 @@ class numeric_limits<c10::Float8_e5m2fnuz> {
|
|||
static constexpr c10::Float8_e5m2fnuz infinity() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
// TODO(future): we are mapping neg_zero to both inf and NaN, this is
|
||||
// surprising and we should figure out what to do about it.
|
||||
static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz denorm_min() {
|
||||
return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torch.testing._internal.common_device_type import (
|
|||
instantiate_device_type_tests,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
DeterministicGuard,
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
|
|
@ -259,6 +260,14 @@ class TestFloat8Dtype(TestCase):
|
|||
):
|
||||
x + y
|
||||
|
||||
@dtypes(*FLOAT8_DTYPES)
|
||||
@dtypesIfCUDA(*CUDA_FLOAT8_DTYPES)
|
||||
def test_empty(self, dtype, device):
|
||||
with DeterministicGuard(torch.are_deterministic_algorithms_enabled()):
|
||||
for use_deterministic in (True, False):
|
||||
torch.use_deterministic_algorithms(use_deterministic)
|
||||
x = torch.empty(4, 4, device=device, dtype=dtype)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestFloat8Dtype, globals())
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user