Reland: Make torch.empty* deterministic by filling with NaN or max int (#104995)
Relands #101849 after #104302 reverted it. torchrec PR https://github.com/pytorch/torchrec/pull/1269 fixes the torchrec failure that caused #101849 to be reverted.

Part of #82004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104995
Approved by: https://github.com/albanD
This commit is contained in:
parent 42530c17fc
commit f987d11fa7
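The user-visible effect of the change: under deterministic mode, every torch.empty* variant hands back fully initialized memory. A minimal sketch of the behavior this commit introduces (assuming a build that contains it):

    import torch

    torch.use_deterministic_algorithms(True)

    f = torch.empty(2, 3)                      # floating point -> filled with NaN
    i = torch.empty(2, 3, dtype=torch.int32)   # integer -> filled with the dtype's max

    assert f.isnan().all()
    assert i.eq(torch.iinfo(torch.int32).max).all()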
@@ -7,6 +7,7 @@
 #include <torch/library.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/native/Resize.h>
+#include <ATen/native/TensorFactories.h>
 #include <ATen/native/mps/Copy.h>
 
 #define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
@@ -63,6 +64,10 @@ TensorBase empty_mps(
 
     auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
     tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
+    // See Note [Enabling Deterministic Operations]
+    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+      at::native::fill_empty_deterministic_(tensor);
+    }
     return tensor;
   } else {
     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
@@ -100,8 +105,13 @@ TensorBase empty_strided_mps(
     const DeviceGuard device_guard(device);
     auto* allocator = at::mps::GetMPSAllocator();
     constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
-    return at::detail::empty_strided_generic(
+    Tensor result = at::detail::empty_strided_generic(
         size, stride, allocator, mps_dks, dtype);
+    // See Note [Enabling Deterministic Operations]
+    if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+      at::native::fill_empty_deterministic_(result);
+    }
+    return result;
   } else {
     TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
   }
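The two hunks above wire the deterministic fill into the MPS allocation paths. A hedged usage sketch, guarded since it only applies on an MPS-enabled macOS build:

    import torch

    if torch.backends.mps.is_available():
        torch.use_deterministic_algorithms(True)
        t = torch.empty(4, 4, device="mps")  # allocated via the MPS factory above
        assert t.isnan().all()               # deterministic NaN fill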
@@ -253,7 +253,12 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
                  c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor empty_names(
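Each factory funnels its freshly allocated tensor through fill_empty_deterministic_. A rough Python-level sketch of the fill policy (illustrative only; the actual helper is C++, and its complex-dtype fill may differ in detail):

    import torch

    def fill_empty_deterministic(t: torch.Tensor) -> torch.Tensor:
        # Sketch: NaN for floating point/complex dtypes, max value otherwise.
        if t.is_floating_point() or t.is_complex():
            return t.fill_(float("nan"))
        if t.dtype == torch.bool:
            return t.fill_(True)  # bool's maximum value
        return t.fill_(torch.iinfo(t.dtype).max)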
@@ -320,7 +325,12 @@ Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c
 
 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
                          c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor& empty_out(IntArrayRef size,
@@ -337,6 +347,10 @@ Tensor& empty_out(IntArrayRef size,
   } else {
     result.resize_(size);
   }
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
   return result;
 }
 
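With the empty_out change, the preallocated-output form participates as well: the tensor is resized first, then filled. For example (resizing a non-empty out tensor emits a warning but proceeds):

    import torch

    torch.use_deterministic_algorithms(True)

    out = torch.zeros(1)        # prior size and contents don't matter
    torch.empty(2, 2, out=out)  # resized to (2, 2), then deterministically filled
    assert out.isnan().all()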
@@ -9,6 +9,7 @@
 #include <ATen/native/TensorFactories.h>
 #include <c10/util/accumulate.h>
 #include <c10/util/Exception.h>
+#include <ATen/native/cuda/Loops.cuh>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -51,7 +52,12 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
 }
 
 Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
-  return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 Tensor _efficientzerotensor_cuda(IntArrayRef size,
@@ -72,7 +78,12 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
 
 
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  // See Note [Enabling Deterministic Operations]
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+    fill_empty_deterministic_(result);
+  }
+  return result;
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
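The CUDA factories mirror the CPU change. A guarded sketch, using an integer dtype to show the max-value branch:

    import torch

    if torch.cuda.is_available():
        torch.use_deterministic_algorithms(True)
        t = torch.empty(4, 4, device="cuda", dtype=torch.int16)
        assert t.eq(torch.iinfo(torch.int16).max).all()  # 32767 everywhere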
@@ -1342,6 +1342,34 @@ else:
             else:
                 self.assertEqual(old_tensor, new_tensor)
 
+    # When deterministic algorithms are enabled, `torch.empty` should fill floating
+    # point tensors with NaN and integer tensors with MAX_INT
+    @skipXLA
+    @skipIfTorchInductor("aot-autograd issue")
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_deterministic_empty(self, device, dtype):
+        gen_fns = [
+            lambda: torch.empty(10, 9, device=device, dtype=dtype),
+            lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
+            lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
+            lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
+            lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
+        ]
+
+        for gen_fn in gen_fns:
+            with DeterministicGuard(True):
+                res = gen_fn()
+
+            if dtype.is_floating_point or dtype.is_complex:
+                self.assertTrue(res.isnan().all())
+            else:
+                if dtype == torch.bool:
+                    max_val = True
+                else:
+                    max_val = torch.iinfo(dtype).max
+                self.assertTrue(res.eq(max_val).all())
+
     # FIXME: update OpInfos to support "nondeterministic samples" and port these tests
     # to that architecture
     @skipIfMps
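DeterministicGuard sets the flag on entry and restores it on exit; it lives in torch.testing._internal.common_utils, which is internal and not a stable public API. One case from the test above, reproduced standalone:

    import torch
    from torch.testing._internal.common_utils import DeterministicGuard

    # Scope the deterministic flag to just the allocation.
    with DeterministicGuard(True):
        res = torch.empty(10, 9, dtype=torch.int64)

    assert res.eq(torch.iinfo(torch.int64).max).all()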
@@ -704,6 +704,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
           quantized, sets new elements to a known value. Floating point or
           complex values are set to NaN. Integer values are set to the maximum
           value.
+        * :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
+          and :func:`torch.empty_permuted` will fill the output tensor with a known
+          value. Floating point or complex dtype tensors are filled with NaN. Integer
+          dtype tensors are filled with the maximum value.
 
     The following normally-nondeterministic operations will throw a
     :class:`RuntimeError` when ``mode=True``:
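Since the flag is process-global state, user code typically saves and restores it around the region that needs the deterministic fill; a small sketch:

    import torch

    prev = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        x = torch.empty(4)  # comes back NaN-filled instead of uninitialized
    finally:
        torch.use_deterministic_algorithms(prev)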
@@ -12277,6 +12277,13 @@ memory_format=torch.contiguous_format) -> Tensor
 Returns a tensor filled with uninitialized data. The shape of the tensor is
 defined by the variable argument :attr:`size`.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (int...): a sequence of integers defining the shape of the output tensor.
         Can be a variable number of arguments or a collection like a list or tuple.
@@ -12309,6 +12316,13 @@ Returns an uninitialized tensor with the same size as :attr:`input`.
 ``torch.empty_like(input)`` is equivalent to
 ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     {input}
@@ -12341,6 +12355,13 @@ Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled w
 If the constructed tensor is "overlapped" (with multiple indices referring to the same element
 in memory) its behavior is undefined.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     stride (tuple of int): the strides of the output tensor
@@ -12386,6 +12407,13 @@ Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense
 tensor with no overlaps. If possible, prefer using this function over
 :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
 
+.. note::
+    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
+    output tensor is initialized to prevent any possible nondeterministic
+    behavior from using the data as an input to an operation. Floating point
+    and complex tensors are filled with NaN, and integer tensors are filled
+    with the maximum value.
+
 Args:
     size (tuple of int): the shape of the output tensor
     physical_layout (tuple of int): the ordering of dimensions physically in memory
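The layout-controlling factories get the same note because they take the same fill path. A short sketch with illustrative shapes and strides:

    import torch

    torch.use_deterministic_algorithms(True)

    s = torch.empty_strided((2, 3), (3, 1), dtype=torch.float64)
    p = torch.empty_permuted((2, 3, 5), (1, 0, 2), dtype=torch.int32)

    assert s.isnan().all()
    assert p.eq(torch.iinfo(torch.int32).max).all()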