Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377)

Part of #109802

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111377
Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
This commit is contained in:
Kurt Mohler 2023-11-01 16:10:09 +00:00 committed by PyTorch MergeBot
parent cce5016653
commit fd209543d5
21 changed files with 193 additions and 46 deletions

View File

@ -65,6 +65,14 @@ void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
_deterministic_algorithms_warn_only = warn_only;
}
bool Context::deterministicFillUninitializedMemory() const {
return _deterministic_fill_uninitialized_memory;
}
void Context::setDeterministicFillUninitializedMemory(bool b) {
_deterministic_fill_uninitialized_memory = b;
}
void Context::alertNotDeterministic(c10::string_view const& caller) {
if (globalContext().deterministicAlgorithms()) {
if (globalContext().deterministicAlgorithmsWarnOnly()) {

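Every fill site touched by this commit gates on the conjunction of the two context flags shown above. The Python-visible equivalent of that check is the following sketch (uses only public APIs; not code from this commit):

import torch

# Mirrors the C++ condition
# `deterministicAlgorithms() && deterministicFillUninitializedMemory()`.
should_fill = (torch.are_deterministic_algorithms_enabled()
               and torch.utils.deterministic.fill_uninitialized_memory)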
View File

@ -205,6 +205,8 @@ class TORCH_API Context {
bool deterministicAlgorithms() const;
bool deterministicAlgorithmsWarnOnly() const;
void setDeterministicAlgorithms(bool, bool);
bool deterministicFillUninitializedMemory() const;
void setDeterministicFillUninitializedMemory(bool);
// Note [Writing Nondeterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -301,6 +303,7 @@ class TORCH_API Context {
bool deterministic_cudnn = false;
bool _deterministic_algorithms = false;
bool _deterministic_algorithms_warn_only = false;
bool _deterministic_fill_uninitialized_memory = true;
bool enabled_flashSDP = true;
bool enabled_mem_efficientSDP = true;
bool enabled_mathSDP = true;

View File

@ -64,7 +64,7 @@ TensorBase empty_mps(
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
at::native::fill_empty_deterministic_(tensor);
}
return tensor;
@ -107,7 +107,7 @@ TensorBase empty_strided_mps(
Tensor result = at::detail::empty_strided_generic(
size, stride, allocator, mps_dks, dtype);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
at::native::fill_empty_deterministic_(result);
}
return result;

View File

@ -252,7 +252,7 @@ const Tensor& _resize_(
self_->empty_tensor_restride(memory_format);
}
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
}
return self;

View File

@ -255,7 +255,7 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::opt
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
fill_empty_deterministic_(result);
}
return result;
@ -327,7 +327,7 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
fill_empty_deterministic_(result);
}
return result;
@ -348,7 +348,7 @@ Tensor& empty_out(IntArrayRef size,
result.resize_(size);
}
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
fill_empty_deterministic_(result);
}
return result;

View File

@ -65,7 +65,7 @@ const Tensor& resize_cuda_(
self_->empty_tensor_restride(memory_format);
}
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
}
return self;

View File

@ -54,7 +54,7 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
fill_empty_deterministic_(result);
}
return result;
@ -80,7 +80,7 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
fill_empty_deterministic_(result);
}
return result;

View File

@ -120,7 +120,7 @@ const Tensor& resize_mps_(
self_->empty_tensor_restride(memory_format);
}
// See Note [Enabling Deterministic Operations]
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
}
return self;

View File

@ -392,4 +392,4 @@ regardless of whether autocast is enabled.
.. py:module:: torch.cpu.amp.autocast_mode
.. py:module:: torch.cuda.amp.autocast_mode
.. py:module:: torch.cuda.amp.common
.. py:module:: torch.cuda.amp.grad_scaler
.. py:module:: torch.cuda.amp.grad_scaler

View File

@ -0,0 +1,28 @@
torch.utils.deterministic
=========================
.. py:module:: torch.utils.deterministic
.. currentmodule:: torch.utils.deterministic
.. attribute:: fill_uninitialized_memory
A :class:`bool` that, if True, causes uninitialized memory to be filled with
a known value when :meth:`torch.use_deterministic_algorithms()` is set to
``True``. Floating point and complex values are set to NaN, and integer
values are set to the maximum value.
Default: ``True``
Filling uninitialized memory is detrimental to performance, so if your
program is valid and does not use uninitialized memory as the input to an
operation, this setting can be turned off for better performance while the
program remains deterministic.
The following operations will fill uninitialized memory when this setting is
turned on:
* :func:`torch.Tensor.resize_` when called with a tensor that is not
quantized
* :func:`torch.empty`
* :func:`torch.empty_strided`
* :func:`torch.empty_permuted`
* :func:`torch.empty_like`
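A short usage sketch of the attribute documented above (illustrative only; the fill values follow the description: NaN for floating point and complex dtypes, the maximum value for integer dtypes):

import torch

torch.use_deterministic_algorithms(True)

# Default: uninitialized memory is filled, so torch.empty output is fully defined.
print(torch.utils.deterministic.fill_uninitialized_memory)  # True
print(torch.empty(3))  # tensor([nan, nan, nan])

# If the program never reads uninitialized values, skip the fill for speed.
torch.utils.deterministic.fill_uninitialized_memory = False
print(torch.empty(3))  # contents are undefined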

View File

@ -112,6 +112,7 @@ Features described in this documentation are classified by release status:
torch.utils.checkpoint <checkpoint>
torch.utils.cpp_extension <cpp_extension>
torch.utils.data <data>
torch.utils.deterministic <deterministic>
torch.utils.jit <jit_utils>
torch.utils.dlpack <dlpack>
torch.utils.mobile_optimizer <mobile_optimizer>

View File

@ -144,6 +144,22 @@ CUDA RNN and LSTM
In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.
Filling uninitialized memory
----------------------------
Operations like :meth:`torch.empty` and :meth:`torch.Tensor.resize_` can return
tensors with uninitialized memory that contain undefined values. Using such a
tensor as an input to another operation is invalid if determinism is required,
because the output will be nondeterministic. But there is nothing to actually
prevent such invalid code from being run. So for safety,
:attr:`torch.utils.deterministic.fill_uninitialized_memory` is set to ``True``
by default, which will fill the uninitialized memory with a known value if
:code:`torch.use_deterministic_algorithms(True)` is set. This prevents the
possibility of this kind of nondeterministic behavior.
However, filling uninitialized memory is detrimental to performance, so if your
program is valid and does not use uninitialized memory as the input to an
operation, this setting can be turned off for better performance.
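As a concrete illustration of the note above (a sketch based on the documented fill values, not a snippet from this commit):

import torch

torch.use_deterministic_algorithms(True)

t = torch.ones(2)
t.resize_(6)
# With fill_uninitialized_memory at its default of True, the four new elements
# are set to NaN, so any later use of `t` is reproducible.

torch.utils.deterministic.fill_uninitialized_memory = False
u = torch.empty(4, dtype=torch.int64)
# `u` now holds arbitrary values; this is only valid if the program overwrites
# them before reading.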
DataLoader
..........

View File

@ -1292,7 +1292,7 @@ else:
else:
a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
old_storage = a.untyped_storage().clone()
with DeterministicGuard(True):
with DeterministicGuard(True, fill_uninitialized_memory=True):
a.resize_(resize_size)
new_storage = a.untyped_storage()
@ -1336,7 +1336,7 @@ else:
]
for gen_fn in gen_fns:
with DeterministicGuard(True):
with DeterministicGuard(True, fill_uninitialized_memory=True):
res = gen_fn()
if dtype.is_floating_point or dtype.is_complex:
@ -8689,6 +8689,38 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
torch.use_deterministic_algorithms(False, warn_only=1)
# Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected
def test_deterministic_fill_uninitialized_memory(self):
with DeterministicGuard(True, fill_uninitialized_memory=False):
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
with DeterministicGuard(True, fill_uninitialized_memory=True):
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
torch.utils.deterministic.fill_uninitialized_memory = False
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
torch.utils.deterministic.fill_uninitialized_memory = True
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
torch._C._set_deterministic_fill_uninitialized_memory(False)
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
torch._C._set_deterministic_fill_uninitialized_memory(True)
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"):
torch.utils.deterministic.fill_uninitialized_memory = 1
def test_type_conversion_via_dtype_name(self):
x = torch.tensor([1])
self.assertEqual(x.byte().dtype, torch.uint8)

View File

@ -1098,6 +1098,8 @@ def _set_deterministic_algorithms(
*,
warn_only: _bool = ...,
) -> None: ... # THPModule_setDeterministicAlgorithms
def _get_deterministic_fill_uninitialized_memory() -> _bool: ... # THPModule_deterministicFillUninitializedMemory
def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ... # THPModule_setDeterministicFillUninitializedMemory
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN

View File

@ -729,14 +729,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
* :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
* :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
* :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
* :func:`torch.Tensor.resize_`, when called with a tensor that is not
quantized, sets new elements to a known value. Floating point or
complex values are set to NaN. Integer values are set to the maximum
value.
* :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
and :func:`torch.empty_permuted` will fill the output tensor with a known
value. Floating point or complex dtype tensors are filled with NaN. Integer
dtype tensors are filled with the maximum value.
The following normally-nondeterministic operations will throw a
:class:`RuntimeError` when ``mode=True``:
@ -781,6 +773,11 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
* :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
* :func:`torch.Tensor.resize_` when called with a quantized tensor
In addition, several operations fill uninitialized memory when this setting
and :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both
turned on.
See the documentation for that attribute for more information.
A handful of CUDA operations are nondeterministic if the CUDA version is
10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more

View File

@ -4217,10 +4217,12 @@ memory is uninitialized.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, new
elements are initialized to prevent nondeterministic behavior from using
the result as an input to an operation. Floating point and complex values
are set to NaN, and integer values are set to the maximum value.
If :func:`torch.use_deterministic_algorithms()` and
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
``True``, new elements are initialized to prevent nondeterministic behavior
from using the result as an input to an operation. Floating point and
complex values are set to NaN, and integer values are set to the maximum
value.
Args:
sizes (torch.Size or int...): the desired size

View File

@ -12335,11 +12335,12 @@ Returns a tensor filled with uninitialized data. The shape of the tensor is
defined by the variable argument :attr:`size`.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
If :func:`torch.use_deterministic_algorithms()` and
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
``True``, the output tensor is initialized to prevent any possible
nondeterministic behavior from using the data as an input to an operation.
Floating point and complex tensors are filled with NaN, and integer tensors
are filled with the maximum value.
Args:
size (int...): a sequence of integers defining the shape of the output tensor.
@ -12374,11 +12375,12 @@ Returns an uninitialized tensor with the same size as :attr:`input`.
``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
If :func:`torch.use_deterministic_algorithms()` and
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
``True``, the output tensor is initialized to prevent any possible
nondeterministic behavior from using the data as an input to an operation.
Floating point and complex tensors are filled with NaN, and integer tensors
are filled with the maximum value.
Args:
{input}
@ -12413,11 +12415,12 @@ Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled w
in memory) its behavior is undefined.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
If :func:`torch.use_deterministic_algorithms()` and
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
``True``, the output tensor is initialized to prevent any possible
nondeterministic behavior from using the data as an input to an operation.
Floating point and complex tensors are filled with NaN, and integer tensors
are filled with the maximum value.
Args:
size (tuple of int): the shape of the output tensor
@ -12465,11 +12468,12 @@ tensor with no overlaps. If possible, prefer using this function over
:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
.. note::
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
output tensor is initialized to prevent any possible nondeterministic
behavior from using the data as an input to an operation. Floating point
and complex tensors are filled with NaN, and integer tensors are filled
with the maximum value.
If :func:`torch.use_deterministic_algorithms()` and
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
``True``, the output tensor is initialized to prevent any possible
nondeterministic behavior from using the data as an input to an operation.
Floating point and complex tensors are filled with NaN, and integer tensors
are filled with the maximum value.
Args:
size (tuple of int): the shape of the output tensor

View File

@ -685,6 +685,26 @@ PyObject* THPModule_deterministicAlgorithmsWarnOnly(
Py_RETURN_FALSE;
}
PyObject* THPModule_setDeterministicFillUninitializedMemory(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
THPUtils_assert(
PyBool_Check(arg), "expected a bool, but got %s", THPUtils_typename(arg));
at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THPModule_deterministicFillUninitializedMemory(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().deterministicFillUninitializedMemory())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
@ -1118,6 +1138,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_get_deterministic_fill_uninitialized_memory",
THPModule_deterministicFillUninitializedMemory,
METH_NOARGS,
nullptr},
{"_set_deterministic_fill_uninitialized_memory",
THPModule_setDeterministicFillUninitializedMemory,
METH_O,
nullptr},
{"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
{"_warn", THPModule_warn, METH_NOARGS, nullptr},

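The two private bindings registered here back the public Python attribute, and the new test above exercises the same round trip. A hypothetical interactive session (sketch):

import torch

torch._C._set_deterministic_fill_uninitialized_memory(False)
assert torch._C._get_deterministic_fill_uninitialized_memory() is False
assert torch.utils.deterministic.fill_uninitialized_memory is False

# Non-bool arguments are rejected with "expected a bool, but got ...":
# torch._C._set_deterministic_fill_uninitialized_memory(1)  # raises RuntimeError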
View File

@ -1458,21 +1458,25 @@ def setLinalgBackendsToDefaultFinally(fn):
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
def __init__(self, deterministic, *, warn_only=False):
def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
self.deterministic = deterministic
self.warn_only = warn_only
self.fill_uninitialized_memory = fill_uninitialized_memory
def __enter__(self):
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
torch.use_deterministic_algorithms(
self.deterministic,
warn_only=self.warn_only)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(
self.deterministic_restore,
warn_only=self.warn_only_restore)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
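A hypothetical use of the extended guard in a test (sketch; `run_test_body` is a stand-in name, not part of this commit):

with DeterministicGuard(True, fill_uninitialized_memory=False):
    # Deterministic algorithms are enforced, but empty/resize_ outputs are left
    # uninitialized inside this block.
    run_test_body()
# On exit, all three flags are restored to their previous global values.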
class AlwaysWarnTypedStorageRemoval:
def __init__(self, always_warn):

View File

@ -4,6 +4,7 @@ import torch
from .throughput_benchmark import ThroughputBenchmark
from .cpp_backtrace import get_cpp_backtrace
from .backend_registration import rename_privateuse1_backend, generate_methods_for_privateuse1_backend
from . import deterministic
from . import collect_env
def set_module(obj, mod):

View File

@ -0,0 +1,21 @@
import sys
import types
import torch
class _Deterministic(types.ModuleType):
@property
def fill_uninitialized_memory(self):
"""
Whether to fill uninitialized memory with a known value when
:meth:`torch.use_deterministic_algorithms()` is set to ``True``.
"""
return torch._C._get_deterministic_fill_uninitialized_memory()
@fill_uninitialized_memory.setter
def fill_uninitialized_memory(self, mode):
return torch._C._set_deterministic_fill_uninitialized_memory(mode)
sys.modules[__name__].__class__ = _Deterministic
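The final line swaps the module's class so that ``fill_uninitialized_memory`` behaves like a property of the module itself. A minimal standalone sketch of the same pattern (illustrative module and attribute names only):

# toggles.py
import sys
import types

_flag = True

class _Toggles(types.ModuleType):
    @property
    def flag(self):
        # Runs on every `toggles.flag` read.
        return _flag

    @flag.setter
    def flag(self, value):
        global _flag
        _flag = bool(value)

# Replacing the module's class makes the property apply to attribute access on
# the module object, e.g. `import toggles; toggles.flag = False`.
sys.modules[__name__].__class__ = _Toggles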