mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make torch.empty* deterministic by filling with NaN or max int value (#101849)
Part of #82004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101849 Approved by: https://github.com/lezcano, https://github.com/albanD, https://github.com/kulinseth
This commit is contained in:
parent
d8352312f9
commit
2642f31e4c
|
|
@ -7,6 +7,7 @@
|
|||
#include <torch/library.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/native/mps/Copy.h>
|
||||
|
||||
#define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled"
|
||||
|
|
@ -63,6 +64,10 @@ TensorBase empty_mps(
|
|||
|
||||
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
at::native::fill_empty_deterministic_(tensor);
|
||||
}
|
||||
return tensor;
|
||||
} else {
|
||||
TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
|
||||
|
|
@ -100,8 +105,13 @@ TensorBase empty_strided_mps(
|
|||
const DeviceGuard device_guard(device);
|
||||
auto* allocator = at::mps::GetMPSAllocator();
|
||||
constexpr c10::DispatchKeySet mps_dks(c10::DispatchKey::MPS);
|
||||
return at::detail::empty_strided_generic(
|
||||
Tensor result = at::detail::empty_strided_generic(
|
||||
size, stride, allocator, mps_dks, dtype);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
at::native::fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
TORCH_CHECK(false, MPS_ERROR_RUNTIME_TOO_LOW)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -253,7 +253,12 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
|
|||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor empty_names(
|
||||
|
|
@ -320,7 +325,12 @@ Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c
|
|||
|
||||
Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor& empty_out(IntArrayRef size,
|
||||
|
|
@ -337,6 +347,10 @@ Tensor& empty_out(IntArrayRef size,
|
|||
} else {
|
||||
result.resize_(size);
|
||||
}
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/EmptyTensor.h>
|
||||
#include <ATen/TensorIterator.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
|
@ -96,6 +97,24 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
|
|||
}
|
||||
}
|
||||
|
||||
// Called by `empty*` functions when deterministic algorithms are enabled to
|
||||
// fill the tensor with NaN if it is floating point or complex type, or fill
|
||||
// with max value if it is integer type
|
||||
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
|
||||
if (tensor.is_floating_point() || tensor.is_complex()) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
|
||||
kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
|
||||
tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_INTEGRAL_TYPES_AND(
|
||||
kBool, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
|
||||
tensor.fill_(std::numeric_limits<scalar_t>::max());
|
||||
});
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// The ZeroTensor allocator ignores whatever allocation is requested and always
|
||||
// gives you nullptr
|
||||
struct ZeroTensorAllocator final : public at::Allocator {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <ATen/native/TensorFactories.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
|
|
@ -51,7 +52,12 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
|
|||
}
|
||||
|
||||
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor _efficientzerotensor_cuda(IntArrayRef size,
|
||||
|
|
@ -72,7 +78,12 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
|
|||
|
||||
|
||||
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
||||
fill_empty_deterministic_(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
|||
|
|
@ -1276,6 +1276,34 @@ else:
|
|||
f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n'
|
||||
+ e.output.decode("utf-8")))
|
||||
|
||||
# When deterministic algorithms are enabled, `torch.empty` should fill floating
|
||||
# point tensors with NaN and integer tensors with MAX_INT
|
||||
@skipXLA
|
||||
@skipIfTorchInductor("aot-autograd issue")
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
def test_deterministic_empty(self, device, dtype):
|
||||
gen_fns = [
|
||||
lambda: torch.empty(10, 9, device=device, dtype=dtype),
|
||||
lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
|
||||
lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
|
||||
lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
|
||||
lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
|
||||
lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
|
||||
]
|
||||
|
||||
for gen_fn in gen_fns:
|
||||
with DeterministicGuard(True):
|
||||
res = gen_fn()
|
||||
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
self.assertTrue(res.isnan().all())
|
||||
else:
|
||||
if dtype == torch.bool:
|
||||
max_val = True
|
||||
else:
|
||||
max_val = torch.iinfo(dtype).max
|
||||
self.assertTrue(res.eq(max_val).all())
|
||||
|
||||
# FIXME: update OpInfos to support "nondeterministic samples" and port these tests
|
||||
# to that architecture
|
||||
@skipIfMps
|
||||
|
|
|
|||
|
|
@ -700,6 +700,10 @@ def use_deterministic_algorithms(mode, *, warn_only=False):
|
|||
* :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
|
||||
* :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
|
||||
* :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
|
||||
* :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
|
||||
and :func:`torch.empty_permuted` will fill the output tensor with a known
|
||||
value. Floating point or complex dtype tensors are filled with NaN. Integer
|
||||
dtype tensors are filled with the maximum value.
|
||||
|
||||
The following normally-nondeterministic operations will throw a
|
||||
:class:`RuntimeError` when ``mode=True``:
|
||||
|
|
|
|||
|
|
@ -12274,6 +12274,13 @@ memory_format=torch.contiguous_format) -> Tensor
|
|||
Returns a tensor filled with uninitialized data. The shape of the tensor is
|
||||
defined by the variable argument :attr:`size`.
|
||||
|
||||
.. note::
|
||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
||||
output tensor is initialized to prevent any possible nondeterministic
|
||||
behavior from using the data as an input to an operation. Floating point
|
||||
and complex tensors are filled with NaN, and integer tensors are filled
|
||||
with the maximum value.
|
||||
|
||||
Args:
|
||||
size (int...): a sequence of integers defining the shape of the output tensor.
|
||||
Can be a variable number of arguments or a collection like a list or tuple.
|
||||
|
|
@ -12306,6 +12313,13 @@ Returns an uninitialized tensor with the same size as :attr:`input`.
|
|||
``torch.empty_like(input)`` is equivalent to
|
||||
``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
|
||||
|
||||
.. note::
|
||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
||||
output tensor is initialized to prevent any possible nondeterministic
|
||||
behavior from using the data as an input to an operation. Floating point
|
||||
and complex tensors are filled with NaN, and integer tensors are filled
|
||||
with the maximum value.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
|
|
@ -12338,6 +12352,13 @@ Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled w
|
|||
If the constructed tensor is "overlapped" (with multiple indices referring to the same element
|
||||
in memory) its behavior is undefined.
|
||||
|
||||
.. note::
|
||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
||||
output tensor is initialized to prevent any possible nondeterministic
|
||||
behavior from using the data as an input to an operation. Floating point
|
||||
and complex tensors are filled with NaN, and integer tensors are filled
|
||||
with the maximum value.
|
||||
|
||||
Args:
|
||||
size (tuple of int): the shape of the output tensor
|
||||
stride (tuple of int): the strides of the output tensor
|
||||
|
|
@ -12383,6 +12404,13 @@ Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense
|
|||
tensor with no overlaps. If possible, prefer using this function over
|
||||
:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
|
||||
|
||||
.. note::
|
||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
||||
output tensor is initialized to prevent any possible nondeterministic
|
||||
behavior from using the data as an input to an operation. Floating point
|
||||
and complex tensors are filled with NaN, and integer tensors are filled
|
||||
with the maximum value.
|
||||
|
||||
Args:
|
||||
size (tuple of int): the shape of the output tensor
|
||||
physical_layout (tuple of int): the ordering of dimensions physically in memory
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user