mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377)
Part of #109802 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111377 Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
This commit is contained in:
parent
cce5016653
commit
fd209543d5
|
|
@ -65,6 +65,14 @@ void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
|
||||||
_deterministic_algorithms_warn_only = warn_only;
|
_deterministic_algorithms_warn_only = warn_only;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Context::deterministicFillUninitializedMemory() const {
|
||||||
|
return _deterministic_fill_uninitialized_memory;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Context::setDeterministicFillUninitializedMemory(bool b) {
|
||||||
|
_deterministic_fill_uninitialized_memory = b;
|
||||||
|
}
|
||||||
|
|
||||||
void Context::alertNotDeterministic(c10::string_view const& caller) {
|
void Context::alertNotDeterministic(c10::string_view const& caller) {
|
||||||
if (globalContext().deterministicAlgorithms()) {
|
if (globalContext().deterministicAlgorithms()) {
|
||||||
if (globalContext().deterministicAlgorithmsWarnOnly()) {
|
if (globalContext().deterministicAlgorithmsWarnOnly()) {
|
||||||
|
|
|
||||||
|
|
@ -205,6 +205,8 @@ class TORCH_API Context {
|
||||||
bool deterministicAlgorithms() const;
|
bool deterministicAlgorithms() const;
|
||||||
bool deterministicAlgorithmsWarnOnly() const;
|
bool deterministicAlgorithmsWarnOnly() const;
|
||||||
void setDeterministicAlgorithms(bool, bool);
|
void setDeterministicAlgorithms(bool, bool);
|
||||||
|
bool deterministicFillUninitializedMemory() const;
|
||||||
|
void setDeterministicFillUninitializedMemory(bool);
|
||||||
|
|
||||||
// Note [Writing Nondeterministic Operations]
|
// Note [Writing Nondeterministic Operations]
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
@ -301,6 +303,7 @@ class TORCH_API Context {
|
||||||
bool deterministic_cudnn = false;
|
bool deterministic_cudnn = false;
|
||||||
bool _deterministic_algorithms = false;
|
bool _deterministic_algorithms = false;
|
||||||
bool _deterministic_algorithms_warn_only = false;
|
bool _deterministic_algorithms_warn_only = false;
|
||||||
|
bool _deterministic_fill_uninitialized_memory = true;
|
||||||
bool enabled_flashSDP = true;
|
bool enabled_flashSDP = true;
|
||||||
bool enabled_mem_efficientSDP = true;
|
bool enabled_mem_efficientSDP = true;
|
||||||
bool enabled_mathSDP = true;
|
bool enabled_mathSDP = true;
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ TensorBase empty_mps(
|
||||||
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
|
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
|
||||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
at::native::fill_empty_deterministic_(tensor);
|
at::native::fill_empty_deterministic_(tensor);
|
||||||
}
|
}
|
||||||
return tensor;
|
return tensor;
|
||||||
|
|
@ -107,7 +107,7 @@ TensorBase empty_strided_mps(
|
||||||
Tensor result = at::detail::empty_strided_generic(
|
Tensor result = at::detail::empty_strided_generic(
|
||||||
size, stride, allocator, mps_dks, dtype);
|
size, stride, allocator, mps_dks, dtype);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
at::native::fill_empty_deterministic_(result);
|
at::native::fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
||||||
|
|
@ -252,7 +252,7 @@ const Tensor& _resize_(
|
||||||
self_->empty_tensor_restride(memory_format);
|
self_->empty_tensor_restride(memory_format);
|
||||||
}
|
}
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
||||||
}
|
}
|
||||||
return self;
|
return self;
|
||||||
|
|
|
||||||
|
|
@ -255,7 +255,7 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::opt
|
||||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||||
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
fill_empty_deterministic_(result);
|
fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -327,7 +327,7 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
|
||||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||||
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
fill_empty_deterministic_(result);
|
fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -348,7 +348,7 @@ Tensor& empty_out(IntArrayRef size,
|
||||||
result.resize_(size);
|
result.resize_(size);
|
||||||
}
|
}
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
fill_empty_deterministic_(result);
|
fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ const Tensor& resize_cuda_(
|
||||||
self_->empty_tensor_restride(memory_format);
|
self_->empty_tensor_restride(memory_format);
|
||||||
}
|
}
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
||||||
}
|
}
|
||||||
return self;
|
return self;
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
|
||||||
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||||
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
fill_empty_deterministic_(result);
|
fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -80,7 +80,7 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
|
||||||
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||||
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
fill_empty_deterministic_(result);
|
fill_empty_deterministic_(result);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,7 @@ const Tensor& resize_mps_(
|
||||||
self_->empty_tensor_restride(memory_format);
|
self_->empty_tensor_restride(memory_format);
|
||||||
}
|
}
|
||||||
// See Note [Enabling Deterministic Operations]
|
// See Note [Enabling Deterministic Operations]
|
||||||
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
|
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
|
||||||
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
|
||||||
}
|
}
|
||||||
return self;
|
return self;
|
||||||
|
|
|
||||||
|
|
@ -392,4 +392,4 @@ regardless of whether autocast is enabled.
|
||||||
.. py:module:: torch.cpu.amp.autocast_mode
|
.. py:module:: torch.cpu.amp.autocast_mode
|
||||||
.. py:module:: torch.cuda.amp.autocast_mode
|
.. py:module:: torch.cuda.amp.autocast_mode
|
||||||
.. py:module:: torch.cuda.amp.common
|
.. py:module:: torch.cuda.amp.common
|
||||||
.. py:module:: torch.cuda.amp.grad_scaler
|
.. py:module:: torch.cuda.amp.grad_scaler
|
||||||
|
|
|
||||||
28
docs/source/deterministic.rst
Normal file
28
docs/source/deterministic.rst
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
torch.utils.deterministic
|
||||||
|
=========================
|
||||||
|
.. py:module:: torch.utils.deterministic
|
||||||
|
.. currentmodule:: torch.utils.deterministic
|
||||||
|
|
||||||
|
.. attribute:: fill_uninitialized_memory
|
||||||
|
|
||||||
|
A :class:`bool` that, if True, causes uninitialized memory to be filled with
|
||||||
|
a known value when :meth:`torch.use_deterministic_algorithms()` is set to
|
||||||
|
``True``. Floating point and complex values are set to NaN, and integer
|
||||||
|
values are set to the maximum value.
|
||||||
|
|
||||||
|
Default: ``True``
|
||||||
|
|
||||||
|
Filling uninitialized memory is detrimental to performance. So if your
|
||||||
|
program is valid and does not use uninitialized memory as the input to an
|
||||||
|
operation, then this setting can be turned off for better performance and
|
||||||
|
still be deterministic.
|
||||||
|
|
||||||
|
The following operations will fill uninitialized memory when this setting is
|
||||||
|
turned on:
|
||||||
|
|
||||||
|
* :func:`torch.Tensor.resize_` when called with a tensor that is not
|
||||||
|
quantized
|
||||||
|
* :func:`torch.empty`
|
||||||
|
* :func:`torch.empty_strided`
|
||||||
|
* :func:`torch.empty_permuted`
|
||||||
|
* :func:`torch.empty_like`
|
||||||
|
|
@ -112,6 +112,7 @@ Features described in this documentation are classified by release status:
|
||||||
torch.utils.checkpoint <checkpoint>
|
torch.utils.checkpoint <checkpoint>
|
||||||
torch.utils.cpp_extension <cpp_extension>
|
torch.utils.cpp_extension <cpp_extension>
|
||||||
torch.utils.data <data>
|
torch.utils.data <data>
|
||||||
|
torch.utils.deterministic <deterministic>
|
||||||
torch.utils.jit <jit_utils>
|
torch.utils.jit <jit_utils>
|
||||||
torch.utils.dlpack <dlpack>
|
torch.utils.dlpack <dlpack>
|
||||||
torch.utils.mobile_optimizer <mobile_optimizer>
|
torch.utils.mobile_optimizer <mobile_optimizer>
|
||||||
|
|
|
||||||
|
|
@ -144,6 +144,22 @@ CUDA RNN and LSTM
|
||||||
In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
|
In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
|
||||||
See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.
|
See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.
|
||||||
|
|
||||||
|
Filling uninitialized memory
|
||||||
|
----------------------------
|
||||||
|
Operations like :meth:`torch.empty` and :meth:`torch.Tensor.resize_` can return
|
||||||
|
tensors with uninitialized memory that contain undefined values. Using such a
|
||||||
|
tensor as an input to another operation is invalid if determinism is required,
|
||||||
|
because the output will be nondeterministic. But there is nothing to actually
|
||||||
|
prevent such invalid code from being run. So for safety,
|
||||||
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` is set to ``True``
|
||||||
|
by default, which will fill the uninitialized memory with a known value if
|
||||||
|
:code:`torch.use_deterministic_algorithms(True)` is set. This will prevent
|
||||||
|
the possibility of this kind of nondeterministic behavior.
|
||||||
|
|
||||||
|
However, filling uninitialized memory is detrimental to performance. So if your
|
||||||
|
program is valid and does not use uninitialized memory as the input to an
|
||||||
|
operation, then this setting can be turned off for better performance.
|
||||||
|
|
||||||
DataLoader
|
DataLoader
|
||||||
..........
|
..........
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1292,7 +1292,7 @@ else:
|
||||||
else:
|
else:
|
||||||
a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
|
a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
|
||||||
old_storage = a.untyped_storage().clone()
|
old_storage = a.untyped_storage().clone()
|
||||||
with DeterministicGuard(True):
|
with DeterministicGuard(True, fill_uninitialized_memory=True):
|
||||||
a.resize_(resize_size)
|
a.resize_(resize_size)
|
||||||
|
|
||||||
new_storage = a.untyped_storage()
|
new_storage = a.untyped_storage()
|
||||||
|
|
@ -1336,7 +1336,7 @@ else:
|
||||||
]
|
]
|
||||||
|
|
||||||
for gen_fn in gen_fns:
|
for gen_fn in gen_fns:
|
||||||
with DeterministicGuard(True):
|
with DeterministicGuard(True, fill_uninitialized_memory=True):
|
||||||
res = gen_fn()
|
res = gen_fn()
|
||||||
|
|
||||||
if dtype.is_floating_point or dtype.is_complex:
|
if dtype.is_floating_point or dtype.is_complex:
|
||||||
|
|
@ -8689,6 +8689,38 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||||
r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
|
r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
|
||||||
torch.use_deterministic_algorithms(False, warn_only=1)
|
torch.use_deterministic_algorithms(False, warn_only=1)
|
||||||
|
|
||||||
|
# Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected
|
||||||
|
def test_deterministic_fill_uninitialized_memory(self):
|
||||||
|
with DeterministicGuard(True, fill_uninitialized_memory=False):
|
||||||
|
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
with DeterministicGuard(True, fill_uninitialized_memory=True):
|
||||||
|
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
torch.utils.deterministic.fill_uninitialized_memory = False
|
||||||
|
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
torch.utils.deterministic.fill_uninitialized_memory = True
|
||||||
|
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
torch._C._set_deterministic_fill_uninitialized_memory(False)
|
||||||
|
self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
torch._C._set_deterministic_fill_uninitialized_memory(True)
|
||||||
|
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
|
||||||
|
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"):
|
||||||
|
torch.utils.deterministic.fill_uninitialized_memory = 1
|
||||||
|
|
||||||
def test_type_conversion_via_dtype_name(self):
|
def test_type_conversion_via_dtype_name(self):
|
||||||
x = torch.tensor([1])
|
x = torch.tensor([1])
|
||||||
self.assertEqual(x.byte().dtype, torch.uint8)
|
self.assertEqual(x.byte().dtype, torch.uint8)
|
||||||
|
|
|
||||||
|
|
@ -1098,6 +1098,8 @@ def _set_deterministic_algorithms(
|
||||||
*,
|
*,
|
||||||
warn_only: _bool = ...,
|
warn_only: _bool = ...,
|
||||||
) -> None: ... # THPModule_setDeterministicAlgorithms
|
) -> None: ... # THPModule_setDeterministicAlgorithms
|
||||||
|
def _get_deterministic_fill_uninitialized_memory() -> _bool: ... # THPModule_deterministicFillUninitializedMemory
|
||||||
|
def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ... # THPModule_setDeterministicFillUninitializedMemory
|
||||||
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
|
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
|
||||||
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
|
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
|
||||||
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
|
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
|
||||||
|
|
|
||||||
|
|
@ -729,14 +729,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
|
||||||
* :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
|
* :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
|
||||||
* :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
|
* :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
|
||||||
* :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
|
* :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
|
||||||
* :func:`torch.Tensor.resize_`, when called with a tensor that is not
|
|
||||||
quantized, sets new elements to a known value. Floating point or
|
|
||||||
complex values are set to NaN. Integer values are set to the maximum
|
|
||||||
value.
|
|
||||||
* :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
|
|
||||||
and :func:`torch.empty_permuted` will fill the output tensor with a known
|
|
||||||
value. Floating point or complex dtype tensors are filled with NaN. Integer
|
|
||||||
dtype tensors are filled with the maximum value.
|
|
||||||
|
|
||||||
The following normally-nondeterministic operations will throw a
|
The following normally-nondeterministic operations will throw a
|
||||||
:class:`RuntimeError` when ``mode=True``:
|
:class:`RuntimeError` when ``mode=True``:
|
||||||
|
|
@ -781,6 +773,11 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.boo
|
||||||
* :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
|
* :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
|
||||||
* :func:`torch.Tensor.resize_` when called with a quantized tensor
|
* :func:`torch.Tensor.resize_` when called with a quantized tensor
|
||||||
|
|
||||||
|
In addition, several operations fill uninitialized memory when this setting
|
||||||
|
is turned on and when
|
||||||
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
|
||||||
|
See the documentation for that attribute for more information.
|
||||||
|
|
||||||
A handful of CUDA operations are nondeterministic if the CUDA version is
|
A handful of CUDA operations are nondeterministic if the CUDA version is
|
||||||
10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
|
10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
|
||||||
or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
|
or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
|
||||||
|
|
|
||||||
|
|
@ -4217,10 +4217,12 @@ memory is uninitialized.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, new
|
If :func:`torch.use_deterministic_algorithms()` and
|
||||||
elements are initialized to prevent nondeterministic behavior from using
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
|
||||||
the result as an input to an operation. Floating point and complex values
|
``True``, new elements are initialized to prevent nondeterministic behavior
|
||||||
are set to NaN, and integer values are set to the maximum value.
|
from using the result as an input to an operation. Floating point and
|
||||||
|
complex values are set to NaN, and integer values are set to the maximum
|
||||||
|
value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sizes (torch.Size or int...): the desired size
|
sizes (torch.Size or int...): the desired size
|
||||||
|
|
|
||||||
|
|
@ -12335,11 +12335,12 @@ Returns a tensor filled with uninitialized data. The shape of the tensor is
|
||||||
defined by the variable argument :attr:`size`.
|
defined by the variable argument :attr:`size`.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
If :func:`torch.use_deterministic_algorithms()` and
|
||||||
output tensor is initialized to prevent any possible nondeterministic
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
|
||||||
behavior from using the data as an input to an operation. Floating point
|
``True``, the output tensor is initialized to prevent any possible
|
||||||
and complex tensors are filled with NaN, and integer tensors are filled
|
nondeterministic behavior from using the data as an input to an operation.
|
||||||
with the maximum value.
|
Floating point and complex tensors are filled with NaN, and integer tensors
|
||||||
|
are filled with the maximum value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
size (int...): a sequence of integers defining the shape of the output tensor.
|
size (int...): a sequence of integers defining the shape of the output tensor.
|
||||||
|
|
@ -12374,11 +12375,12 @@ Returns an uninitialized tensor with the same size as :attr:`input`.
|
||||||
``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
|
``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
If :func:`torch.use_deterministic_algorithms()` and
|
||||||
output tensor is initialized to prevent any possible nondeterministic
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
|
||||||
behavior from using the data as an input to an operation. Floating point
|
``True``, the output tensor is initialized to prevent any possible
|
||||||
and complex tensors are filled with NaN, and integer tensors are filled
|
nondeterministic behavior from using the data as an input to an operation.
|
||||||
with the maximum value.
|
Floating point and complex tensors are filled with NaN, and integer tensors
|
||||||
|
are filled with the maximum value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
{input}
|
{input}
|
||||||
|
|
@ -12413,11 +12415,12 @@ Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled w
|
||||||
in memory) its behavior is undefined.
|
in memory) its behavior is undefined.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
If :func:`torch.use_deterministic_algorithms()` and
|
||||||
output tensor is initialized to prevent any possible nondeterministic
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
|
||||||
behavior from using the data as an input to an operation. Floating point
|
``True``, the output tensor is initialized to prevent any possible
|
||||||
and complex tensors are filled with NaN, and integer tensors are filled
|
nondeterministic behavior from using the data as an input to an operation.
|
||||||
with the maximum value.
|
Floating point and complex tensors are filled with NaN, and integer tensors
|
||||||
|
are filled with the maximum value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
size (tuple of int): the shape of the output tensor
|
size (tuple of int): the shape of the output tensor
|
||||||
|
|
@ -12465,11 +12468,12 @@ tensor with no overlaps. If possible, prefer using this function over
|
||||||
:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
|
:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
|
If :func:`torch.use_deterministic_algorithms()` and
|
||||||
output tensor is initialized to prevent any possible nondeterministic
|
:attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
|
||||||
behavior from using the data as an input to an operation. Floating point
|
``True``, the output tensor is initialized to prevent any possible
|
||||||
and complex tensors are filled with NaN, and integer tensors are filled
|
nondeterministic behavior from using the data as an input to an operation.
|
||||||
with the maximum value.
|
Floating point and complex tensors are filled with NaN, and integer tensors
|
||||||
|
are filled with the maximum value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
size (tuple of int): the shape of the output tensor
|
size (tuple of int): the shape of the output tensor
|
||||||
|
|
|
||||||
|
|
@ -685,6 +685,26 @@ PyObject* THPModule_deterministicAlgorithmsWarnOnly(
|
||||||
Py_RETURN_FALSE;
|
Py_RETURN_FALSE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* THPModule_setDeterministicFillUninitializedMemory(
|
||||||
|
PyObject* _unused,
|
||||||
|
PyObject* arg) {
|
||||||
|
HANDLE_TH_ERRORS
|
||||||
|
THPUtils_assert(
|
||||||
|
PyBool_Check(arg), "expected a bool, but got %s", THPUtils_typename(arg));
|
||||||
|
at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True);
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
END_HANDLE_TH_ERRORS
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* THPModule_deterministicFillUninitializedMemory(
|
||||||
|
PyObject* _unused,
|
||||||
|
PyObject* noargs) {
|
||||||
|
if (at::globalContext().deterministicFillUninitializedMemory())
|
||||||
|
Py_RETURN_TRUE;
|
||||||
|
else
|
||||||
|
Py_RETURN_FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
|
PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
|
||||||
THPUtils_assert(
|
THPUtils_assert(
|
||||||
PyBool_Check(arg),
|
PyBool_Check(arg),
|
||||||
|
|
@ -1118,6 +1138,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
|
||||||
castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
|
castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
|
||||||
METH_VARARGS | METH_KEYWORDS,
|
METH_VARARGS | METH_KEYWORDS,
|
||||||
nullptr},
|
nullptr},
|
||||||
|
{"_get_deterministic_fill_uninitialized_memory",
|
||||||
|
THPModule_deterministicFillUninitializedMemory,
|
||||||
|
METH_NOARGS,
|
||||||
|
nullptr},
|
||||||
|
{"_set_deterministic_fill_uninitialized_memory",
|
||||||
|
THPModule_setDeterministicFillUninitializedMemory,
|
||||||
|
METH_O,
|
||||||
|
nullptr},
|
||||||
{"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
|
{"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
|
||||||
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
|
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
|
||||||
{"_warn", THPModule_warn, METH_NOARGS, nullptr},
|
{"_warn", THPModule_warn, METH_NOARGS, nullptr},
|
||||||
|
|
|
||||||
|
|
@ -1458,21 +1458,25 @@ def setLinalgBackendsToDefaultFinally(fn):
|
||||||
# Context manager for setting deterministic flag and automatically
|
# Context manager for setting deterministic flag and automatically
|
||||||
# resetting it to its original value
|
# resetting it to its original value
|
||||||
class DeterministicGuard:
|
class DeterministicGuard:
|
||||||
def __init__(self, deterministic, *, warn_only=False):
|
def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
|
||||||
self.deterministic = deterministic
|
self.deterministic = deterministic
|
||||||
self.warn_only = warn_only
|
self.warn_only = warn_only
|
||||||
|
self.fill_uninitialized_memory = fill_uninitialized_memory
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
|
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
|
||||||
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
|
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
|
||||||
|
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
|
||||||
torch.use_deterministic_algorithms(
|
torch.use_deterministic_algorithms(
|
||||||
self.deterministic,
|
self.deterministic,
|
||||||
warn_only=self.warn_only)
|
warn_only=self.warn_only)
|
||||||
|
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
|
||||||
|
|
||||||
def __exit__(self, exception_type, exception_value, traceback):
|
def __exit__(self, exception_type, exception_value, traceback):
|
||||||
torch.use_deterministic_algorithms(
|
torch.use_deterministic_algorithms(
|
||||||
self.deterministic_restore,
|
self.deterministic_restore,
|
||||||
warn_only=self.warn_only_restore)
|
warn_only=self.warn_only_restore)
|
||||||
|
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
|
||||||
|
|
||||||
class AlwaysWarnTypedStorageRemoval:
|
class AlwaysWarnTypedStorageRemoval:
|
||||||
def __init__(self, always_warn):
|
def __init__(self, always_warn):
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
from .throughput_benchmark import ThroughputBenchmark
|
from .throughput_benchmark import ThroughputBenchmark
|
||||||
from .cpp_backtrace import get_cpp_backtrace
|
from .cpp_backtrace import get_cpp_backtrace
|
||||||
from .backend_registration import rename_privateuse1_backend, generate_methods_for_privateuse1_backend
|
from .backend_registration import rename_privateuse1_backend, generate_methods_for_privateuse1_backend
|
||||||
|
from . import deterministic
|
||||||
from . import collect_env
|
from . import collect_env
|
||||||
|
|
||||||
def set_module(obj, mod):
|
def set_module(obj, mod):
|
||||||
|
|
|
||||||
21
torch/utils/deterministic.py
Normal file
21
torch/utils/deterministic.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class _Deterministic(types.ModuleType):
|
||||||
|
@property
|
||||||
|
def fill_uninitialized_memory(self):
|
||||||
|
"""
|
||||||
|
Whether to fill uninitialized memory with a known value when
|
||||||
|
:meth:`torch.use_deterministic_algorithms()` is set to ``True``.
|
||||||
|
"""
|
||||||
|
return torch._C._get_deterministic_fill_uninitialized_memory()
|
||||||
|
|
||||||
|
@fill_uninitialized_memory.setter
|
||||||
|
def fill_uninitialized_memory(self, mode):
|
||||||
|
return torch._C._set_deterministic_fill_uninitialized_memory(mode)
|
||||||
|
|
||||||
|
|
||||||
|
sys.modules[__name__].__class__ = _Deterministic
|
||||||
Loading…
Reference in New Issue
Block a user