Make torch.empty* deterministic by filling with NaN or max int value (#101849)

Part of #82004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101849
Approved by: https://github.com/lezcano, https://github.com/albanD, https://github.com/kulinseth
This commit is contained in:
Kurt Mohler 2023-06-21 02:53:18 +00:00 committed by PyTorch MergeBot
parent d8352312f9
commit 2642f31e4c
7 changed files with 119 additions and 5 deletions

View File

@ -7,6 +7,7 @@
#include <torch/library.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/mps/Copy.h>
#define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
@ -63,6 +64,10 @@ TensorBase empty_mps(
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
at::native::fill_empty_deterministic_(tensor);
}
return tensor;
} else {
TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
@ -100,8 +105,13 @@ TensorBase empty_strided_mps(
const DeviceGuard device_guard(device);
auto* allocator = at::mps::GetMPSAllocator();
constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
return at::detail::empty_strided_generic(
Tensor result = at::detail::empty_strided_generic(
size, stride, allocator, mps_dks, dtype);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
at::native::fill_empty_deterministic_(result);
}
return result;
} else {
TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
}

View File

@ -253,7 +253,12 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// CPU entry point for `empty`: forwards every argument unchanged to
// at::detail::empty_cpu. When the global deterministic-algorithms flag is set,
// the new tensor is passed through fill_empty_deterministic_ before it is
// returned.
Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
// NOTE(review): this direct return looks like the pre-change line carried over
// from the diff view; if it were really present, everything below it would be
// unreachable. Confirm against the actual file.
return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// Keep the allocation in a local so it can be filled before being returned.
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
fill_empty_deterministic_(result);
}
return result;
}
Tensor empty_names(
@ -320,7 +325,12 @@ Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c
// CPU entry point for `empty_strided`: forwards size, stride and the optional
// tensor options to at::detail::empty_strided_cpu. When the global
// deterministic-algorithms flag is set, the new tensor is passed through
// fill_empty_deterministic_ before it is returned.
Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
// NOTE(review): this direct return looks like the pre-change line carried over
// from the diff view; if it were really present, everything below it would be
// unreachable. Confirm against the actual file.
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// Keep the allocation in a local so it can be filled before being returned.
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
fill_empty_deterministic_(result);
}
return result;
}
Tensor& empty_out(IntArrayRef size,
@ -337,6 +347,10 @@ Tensor& empty_out(IntArrayRef size,
} else {
result.resize_(size);
}
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
fill_empty_deterministic_(result);
}
return result;
}

View File

@ -3,6 +3,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -96,6 +97,24 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
}
}
// Helper for the `empty*` factory functions, used when deterministic
// algorithms are enabled. It overwrites every element of `tensor` in place
// with a fixed value chosen by dtype:
//   * floating point and complex dtypes -> quiet NaN
//   * integer dtypes and bool           -> the largest value of that dtype
// Returns the same tensor reference so the call can be chained.
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
  const auto dtype = tensor.scalar_type();
  const bool wants_nan = tensor.is_floating_point() || tensor.is_complex();

  // Integer-like dtypes (including bool) have no NaN, so use their maximum.
  if (!wants_nan) {
    AT_DISPATCH_INTEGRAL_TYPES_AND(
        kBool, dtype, "fill_empty_deterministic_", [&]() {
          tensor.fill_(std::numeric_limits<scalar_t>::max());
        });
    return tensor;
  }

  // Floating point and complex dtypes: fill with quiet NaN.
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, dtype, "fill_empty_deterministic_", [&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
      });
  return tensor;
}
// The ZeroTensor allocator ignores whatever allocation is requested and always
// gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {

View File

@ -9,6 +9,7 @@
#include <ATen/native/TensorFactories.h>
#include <c10/util/accumulate.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/Loops.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -51,7 +52,12 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
}
// CUDA entry point for `empty`: forwards every argument unchanged to
// at::detail::empty_cuda. When the global deterministic-algorithms flag is
// set, the new tensor is passed through fill_empty_deterministic_ before it is
// returned.
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
// NOTE(review): this direct return looks like the pre-change line carried over
// from the diff view; if it were really present, everything below it would be
// unreachable. Confirm against the actual file.
return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// Keep the allocation in a local so it can be filled before being returned.
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
fill_empty_deterministic_(result);
}
return result;
}
Tensor _efficientzerotensor_cuda(IntArrayRef size,
@ -72,7 +78,12 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
// CUDA entry point for `empty_strided`: forwards size, stride and the optional
// tensor options to at::detail::empty_strided_cuda. When the global
// deterministic-algorithms flag is set, the new tensor is passed through
// fill_empty_deterministic_ before it is returned.
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
// NOTE(review): this direct return looks like the pre-change line carried over
// from the diff view; if it were really present, everything below it would be
// unreachable. Confirm against the actual file.
return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// Keep the allocation in a local so it can be filled before being returned.
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
fill_empty_deterministic_(result);
}
return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1276,6 +1276,34 @@ else:
f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
+ e.output.decode("utf-8")))
# With deterministic algorithms enabled, every `torch.empty*` factory must
# return a fully initialized tensor: NaN for floating point and complex dtypes,
# the dtype's maximum for integer dtypes, and True for bool.
@skipXLA
@skipIfTorchInductor("aot-autograd issue")
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_deterministic_empty(self, device, dtype):
    def zeros(*shape):
        # Placeholder input / `out=` tensor on the device and dtype under test.
        return torch.zeros(*shape, device=device, dtype=dtype)

    # Each factory is invoked lazily so the allocation happens inside the guard.
    factories = (
        lambda: torch.empty(10, 9, device=device, dtype=dtype),
        lambda: torch.empty(10, 9, out=zeros(1)),
        lambda: torch.empty_like(zeros(10, 9)),
        lambda: torch.empty_like(zeros(10, 9), memory_format=torch.contiguous_format),
        lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
        lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
    )
    expects_nan = dtype.is_floating_point or dtype.is_complex

    for factory in factories:
        with DeterministicGuard(True):
            out = factory()

        if expects_nan:
            self.assertTrue(out.isnan().all())
            continue

        # torch.iinfo does not accept bool, so treat it separately.
        expected = True if dtype == torch.bool else torch.iinfo(dtype).max
        self.assertTrue(out.eq(expected).all())
# FIXME: update OpInfos to support "nondeterministic samples" and port these tests
# to that architecture
@skipIfMps

View File

@ -700,6 +700,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
* :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
* :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
* :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
* :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
and :func:`torch.empty_permuted` will fill the output tensor with a known
value. Floating point or complex dtype tensors are filled with NaN. Integer
dtype tensors are filled with the maximum value.
The following normally-nondeterministic operations will throw a
:class:`RuntimeError` when ``mode=True``:

View File

@ -12274,6 +12274,13 @@ memory_format=torch.contiguous_format) -> Tensor
Returns a tensor filled with uninitialized data. The shape of the tensor is
defined by the variable argument :attr:`size`.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
Args:
size (int...): a sequence of integers defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple.
@ -12306,6 +12313,13 @@ Returns an uninitialized tensor with the same size as :attr:`input`.
``torch.empty_like(input)`` is equivalent to
``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
Args:
{input}
@ -12338,6 +12352,13 @@ Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled w
If the constructed tensor is "overlapped" (with multiple indices referring to the same element
in memory) its behavior is undefined.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
Args:
size (tuple of int): the shape of the output tensor
stride (tuple of int): the strides of the output tensor
@ -12383,6 +12404,13 @@ Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense
tensor with no overlaps. If possible, prefer using this function over
:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
Args:
size (tuple of int): the shape of the output tensor
physical_layout (tuple of int): the ordering of dimensions physically in memory