mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make Tensor.set_ validate storage_offset when sizes/strides are unchanged (#147354)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147354 Approved by: https://github.com/albanD ghstack dependencies: #147352
This commit is contained in:
parent
e64441915f
commit
536bce5a04
|
|
@ -6,6 +6,7 @@
|
||||||
#include <ATen/TensorUtils.h>
|
#include <ATen/TensorUtils.h>
|
||||||
|
|
||||||
#include <c10/core/CPUAllocator.h>
|
#include <c10/core/CPUAllocator.h>
|
||||||
|
#include <c10/core/SymBool.h>
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
|
@ -85,17 +86,28 @@ inline void checkInBoundsForStorage(
|
||||||
T storage_offset,
|
T storage_offset,
|
||||||
const caffe2::TypeMeta& data_type,
|
const caffe2::TypeMeta& data_type,
|
||||||
const Storage& new_storage) {
|
const Storage& new_storage) {
|
||||||
T storage_size_bytes =
|
T storage_size_bytes, storage_size_plus_offset_bytes;
|
||||||
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
if (stride.data()) {
|
||||||
if (storage_size_bytes == 0) {
|
storage_size_bytes =
|
||||||
|
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
||||||
|
storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
|
||||||
|
size, stride, data_type.itemsize(), storage_offset);
|
||||||
|
} else {
|
||||||
|
storage_size_bytes =
|
||||||
|
at::detail::computeStorageNbytesContiguous(size, data_type.itemsize());
|
||||||
|
storage_size_plus_offset_bytes = at::detail::computeStorageNbytesContiguous(
|
||||||
|
size, data_type.itemsize(), storage_offset);
|
||||||
|
}
|
||||||
|
// It's ok to always evaluate to False for this early return for SymInts because
|
||||||
|
// (1) maybe_convert_symint below only installs guard for int64_t case
|
||||||
|
// (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below
|
||||||
|
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(storage_size_bytes, 0))) {
|
||||||
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
|
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
|
|
||||||
size, stride, data_type.itemsize(), storage_offset);
|
|
||||||
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
|
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
|
||||||
TORCH_CHECK(
|
TORCH_MAYBE_SYM_CHECK(
|
||||||
storage_size_plus_offset_bytes <= new_storage_size_bytes,
|
sym_eq(storage_size_bytes, 0) || sym_le(storage_size_plus_offset_bytes, new_storage_size_bytes),
|
||||||
"setStorage: sizes ",
|
"setStorage: sizes ",
|
||||||
size,
|
size,
|
||||||
", strides ",
|
", strides ",
|
||||||
|
|
@ -113,7 +125,7 @@ inline void checkInBoundsForStorage(
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||||
ArrayRef<T> size, ArrayRef<T> stride) {
|
ArrayRef<T> size, ArrayRef<T> stride, bool check_offset_in_bounds = true) {
|
||||||
// FIXME: stride should be optional
|
// FIXME: stride should be optional
|
||||||
if (stride.data()) {
|
if (stride.data()) {
|
||||||
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
|
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
|
||||||
|
|
@ -124,6 +136,28 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||||
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
|
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// storageOffset
|
||||||
|
TORCH_CHECK(
|
||||||
|
storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
||||||
|
|
||||||
|
// set_storage_{device} (except set_storage_meta__symint)
|
||||||
|
// will (unsafely) set the storage offset and then call resize_impl that
|
||||||
|
// handles resizing the storage. However, resize_impl will only resize the
|
||||||
|
// storage if the sizes/strides changed. For the case that the sizes/strides
|
||||||
|
// remain unchanged, the storage offset is not properly validated, so we do
|
||||||
|
// that here.
|
||||||
|
if (check_offset_in_bounds) {
|
||||||
|
auto result_tensor_impl = result.unsafeGetTensorImpl();
|
||||||
|
bool size_unchanged = result_tensor_impl->generic_sizes<T>() == size;
|
||||||
|
bool stride_unchanged = stride.data()
|
||||||
|
? result_tensor_impl->generic_strides<T>() == stride
|
||||||
|
: true;
|
||||||
|
if (size_unchanged && stride_unchanged) {
|
||||||
|
checkInBoundsForStorage(
|
||||||
|
size, stride, storage_offset, result.dtype(), storage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
|
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
|
||||||
// function is to set the tensor size to be equal to the size of the storage.
|
// function is to set the tensor size to be equal to the size of the storage.
|
||||||
if (!result.storage().is_alias_of(storage)) {
|
if (!result.storage().is_alias_of(storage)) {
|
||||||
|
|
@ -140,9 +174,6 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||||
"\". This is no longer allowed; the devices must match.");
|
"\". This is no longer allowed; the devices must match.");
|
||||||
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
|
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
|
||||||
}
|
}
|
||||||
|
|
||||||
// storageOffset
|
|
||||||
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -400,7 +400,13 @@ Tensor& set_storage_meta__symint(
|
||||||
c10::SymInt storage_offset,
|
c10::SymInt storage_offset,
|
||||||
c10::SymIntArrayRef size,
|
c10::SymIntArrayRef size,
|
||||||
c10::SymIntArrayRef stride) {
|
c10::SymIntArrayRef stride) {
|
||||||
checkSetStorage(result, storage, storage_offset, size, stride);
|
checkSetStorage(
|
||||||
|
result,
|
||||||
|
storage,
|
||||||
|
storage_offset,
|
||||||
|
size,
|
||||||
|
stride,
|
||||||
|
/*check_offset_in_bounds=*/false);
|
||||||
|
|
||||||
c10::SymDimVector contiguous_strides;
|
c10::SymDimVector contiguous_strides;
|
||||||
if (stride.data() == nullptr) {
|
if (stride.data() == nullptr) {
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,9 @@ class C10_API SymBool {
|
||||||
SymBool operator|(const SymBool& other) const {
|
SymBool operator|(const SymBool& other) const {
|
||||||
return sym_or(other);
|
return sym_or(other);
|
||||||
}
|
}
|
||||||
|
SymBool operator||(const SymBool& other) const {
|
||||||
|
return sym_or(other);
|
||||||
|
}
|
||||||
SymBool operator~() const {
|
SymBool operator~() const {
|
||||||
return sym_not();
|
return sym_not();
|
||||||
}
|
}
|
||||||
|
|
@ -89,6 +92,12 @@ C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
|
||||||
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
||||||
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
|
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
|
||||||
TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
||||||
|
#define TORCH_MAYBE_SYM_CHECK(cond, ...) \
|
||||||
|
if constexpr (std::is_same_v<std::decay_t<decltype(cond)>, SymBool>) { \
|
||||||
|
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK((cond), __VA_ARGS__) \
|
||||||
|
}
|
||||||
|
|
||||||
inline bool guard_size_oblivious(
|
inline bool guard_size_oblivious(
|
||||||
bool b,
|
bool b,
|
||||||
|
|
|
||||||
|
|
@ -1240,6 +1240,11 @@ def forward(self, primals_1):
|
||||||
# return [sin, copy]""",
|
# return [sin, copy]""",
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# skipped after confirming with @yf225 and @bdhirsh
|
||||||
|
@unittest.skipIf(
|
||||||
|
True,
|
||||||
|
"using set_ unsafely and PT2 FSDP2 no longer uses set_ as used in this test",
|
||||||
|
)
|
||||||
def test_input_mutation_storage_resize_down_and_set_(self):
|
def test_input_mutation_storage_resize_down_and_set_(self):
|
||||||
# Meant to mimic ppFSDP
|
# Meant to mimic ppFSDP
|
||||||
class TracableCreateParameter(torch.autograd.Function):
|
class TracableCreateParameter(torch.autograd.Function):
|
||||||
|
|
|
||||||
|
|
@ -7166,6 +7166,18 @@ class TestTorch(TestCase):
|
||||||
f_cpu = torch.randn((2, 3), dtype=torch.float32)
|
f_cpu = torch.randn((2, 3), dtype=torch.float32)
|
||||||
d_cpu = torch.randn((2, 3), dtype=torch.float64)
|
d_cpu = torch.randn((2, 3), dtype=torch.float64)
|
||||||
|
|
||||||
|
storage_offset = 0x41414141
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "out of bounds for storage of size"):
|
||||||
|
t = torch.randn(1)
|
||||||
|
t.set_(t.untyped_storage(), storage_offset, t.size())
|
||||||
|
|
||||||
|
# if size changes, set_ will resize the storage inplace
|
||||||
|
t = torch.randn(1)
|
||||||
|
size = torch.Size([2, 3])
|
||||||
|
t.set_(t.untyped_storage(), storage_offset, size)
|
||||||
|
self.assertEqual(t.storage_offset(), storage_offset)
|
||||||
|
self.assertEqual(t.untyped_storage().nbytes(), (storage_offset + size[0] * size[1]) * 4)
|
||||||
|
|
||||||
# change dtype
|
# change dtype
|
||||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
|
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
|
||||||
self.assertRaises(RuntimeError,
|
self.assertRaises(RuntimeError,
|
||||||
|
|
|
||||||
|
|
@ -734,7 +734,8 @@ void initTorchFunctions(PyObject* module) {
|
||||||
src.storage(),
|
src.storage(),
|
||||||
dst.sym_storage_offset(),
|
dst.sym_storage_offset(),
|
||||||
dst.sym_sizes(),
|
dst.sym_sizes(),
|
||||||
dst.sym_strides());
|
dst.sym_strides(),
|
||||||
|
/*check_offset_in_bounds=*/false);
|
||||||
});
|
});
|
||||||
py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
|
py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
|
||||||
return at::functionalization::impl::isFunctionalTensor(t);
|
return at::functionalization::impl::isFunctionalTensor(t);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user