Make Tensor.set_ validate storage_offset when sizes/strides are unchanged (#147354)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147354
Approved by: https://github.com/albanD
ghstack dependencies: #147352
This commit is contained in:
Mikayla Gawarecki 2025-02-26 19:37:28 -08:00 committed by PyTorch MergeBot
parent e64441915f
commit 536bce5a04
6 changed files with 77 additions and 13 deletions

View File

@@ -6,6 +6,7 @@
#include <ATen/TensorUtils.h> #include <ATen/TensorUtils.h>
#include <c10/core/CPUAllocator.h> #include <c10/core/CPUAllocator.h>
#include <c10/core/SymBool.h>
#include <utility> #include <utility>
@@ -85,17 +86,28 @@ inline void checkInBoundsForStorage(
T storage_offset, T storage_offset,
const caffe2::TypeMeta& data_type, const caffe2::TypeMeta& data_type,
const Storage& new_storage) { const Storage& new_storage) {
T storage_size_bytes = T storage_size_bytes, storage_size_plus_offset_bytes;
if (stride.data()) {
storage_size_bytes =
at::detail::computeStorageNbytes(size, stride, data_type.itemsize()); at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
if (storage_size_bytes == 0) { storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
size, stride, data_type.itemsize(), storage_offset);
} else {
storage_size_bytes =
at::detail::computeStorageNbytesContiguous(size, data_type.itemsize());
storage_size_plus_offset_bytes = at::detail::computeStorageNbytesContiguous(
size, data_type.itemsize(), storage_offset);
}
// It's ok to always evaluate to False for this early return for SymInts because
// (1) maybe_convert_symint below only installs guard for int64_t case
// (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(storage_size_bytes, 0))) {
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel. // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
return; return;
} }
T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
size, stride, data_type.itemsize(), storage_offset);
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes()); T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
TORCH_CHECK( TORCH_MAYBE_SYM_CHECK(
storage_size_plus_offset_bytes <= new_storage_size_bytes, sym_eq(storage_size_bytes, 0) || sym_le(storage_size_plus_offset_bytes, new_storage_size_bytes),
"setStorage: sizes ", "setStorage: sizes ",
size, size,
", strides ", ", strides ",
@@ -113,7 +125,7 @@ inline void checkInBoundsForStorage(
template <typename T> template <typename T>
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
ArrayRef<T> size, ArrayRef<T> stride) { ArrayRef<T> size, ArrayRef<T> stride, bool check_offset_in_bounds = true) {
// FIXME: stride should be optional // FIXME: stride should be optional
if (stride.data()) { if (stride.data()) {
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(), TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
@@ -124,6 +136,28 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX"); TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif #endif
// storageOffset
TORCH_CHECK(
storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
// set_storage_{device} (except set_storage_meta__symint)
// will (unsafely) set the storage offset and then call resize_impl that
// handles resizing the storage. However, resize_impl will only resize the
// storage if the sizes/strides changed. For the case that the sizes/strides
// remain unchanged, the storage offset is not properly validated, so we do
// that here.
if (check_offset_in_bounds) {
auto result_tensor_impl = result.unsafeGetTensorImpl();
bool size_unchanged = result_tensor_impl->generic_sizes<T>() == size;
bool stride_unchanged = stride.data()
? result_tensor_impl->generic_strides<T>() == stride
: true;
if (size_unchanged && stride_unchanged) {
checkInBoundsForStorage(
size, stride, storage_offset, result.dtype(), storage);
}
}
// storage: note this can't be replaced with result.set_(storage) as the semantics of that // storage: note this can't be replaced with result.set_(storage) as the semantics of that
// function is to set the tensor size to be equal to the size of the storage. // function is to set the tensor size to be equal to the size of the storage.
if (!result.storage().is_alias_of(storage)) { if (!result.storage().is_alias_of(storage)) {
@@ -140,9 +174,6 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
"\". This is no longer allowed; the devices must match."); "\". This is no longer allowed; the devices must match.");
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage)); result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
} }
// storageOffset
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
} }
/** /**

View File

@@ -400,7 +400,13 @@ Tensor& set_storage_meta__symint(
c10::SymInt storage_offset, c10::SymInt storage_offset,
c10::SymIntArrayRef size, c10::SymIntArrayRef size,
c10::SymIntArrayRef stride) { c10::SymIntArrayRef stride) {
checkSetStorage(result, storage, storage_offset, size, stride); checkSetStorage(
result,
storage,
storage_offset,
size,
stride,
/*check_offset_in_bounds=*/false);
c10::SymDimVector contiguous_strides; c10::SymDimVector contiguous_strides;
if (stride.data() == nullptr) { if (stride.data() == nullptr) {

View File

@@ -49,6 +49,9 @@ class C10_API SymBool {
SymBool operator|(const SymBool& other) const { SymBool operator|(const SymBool& other) const {
return sym_or(other); return sym_or(other);
} }
SymBool operator||(const SymBool& other) const {
return sym_or(other);
}
SymBool operator~() const { SymBool operator~() const {
return sym_not(); return sym_not();
} }
@@ -89,6 +92,12 @@ C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \ #define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
#define TORCH_MAYBE_SYM_CHECK(cond, ...) \
if constexpr (std::is_same_v<std::decay_t<decltype(cond)>, SymBool>) { \
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) \
} else { \
TORCH_CHECK((cond), __VA_ARGS__) \
}
inline bool guard_size_oblivious( inline bool guard_size_oblivious(
bool b, bool b,

View File

@@ -1240,6 +1240,11 @@ def forward(self, primals_1):
# return [sin, copy]""", # return [sin, copy]""",
# ) # )
# skipped after confirming with @yf225 and @bdhirsh
@unittest.skipIf(
True,
"using set_ unsafely and PT2 FSDP2 no longer uses set_ as used in this test",
)
def test_input_mutation_storage_resize_down_and_set_(self): def test_input_mutation_storage_resize_down_and_set_(self):
# Meant to mimic ppFSDP # Meant to mimic ppFSDP
class TracableCreateParameter(torch.autograd.Function): class TracableCreateParameter(torch.autograd.Function):

View File

@@ -7166,6 +7166,18 @@ class TestTorch(TestCase):
f_cpu = torch.randn((2, 3), dtype=torch.float32) f_cpu = torch.randn((2, 3), dtype=torch.float32)
d_cpu = torch.randn((2, 3), dtype=torch.float64) d_cpu = torch.randn((2, 3), dtype=torch.float64)
storage_offset = 0x41414141
with self.assertRaisesRegex(RuntimeError, "out of bounds for storage of size"):
t = torch.randn(1)
t.set_(t.untyped_storage(), storage_offset, t.size())
# if size changes, set_ will resize the storage inplace
t = torch.randn(1)
size = torch.Size([2, 3])
t.set_(t.untyped_storage(), storage_offset, size)
self.assertEqual(t.storage_offset(), storage_offset)
self.assertEqual(t.untyped_storage().nbytes(), (storage_offset + size[0] * size[1]) * 4)
# change dtype # change dtype
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage())) self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
self.assertRaises(RuntimeError, self.assertRaises(RuntimeError,

View File

@@ -734,7 +734,8 @@ void initTorchFunctions(PyObject* module) {
src.storage(), src.storage(),
dst.sym_storage_offset(), dst.sym_storage_offset(),
dst.sym_sizes(), dst.sym_sizes(),
dst.sym_strides()); dst.sym_strides(),
/*check_offset_in_bounds=*/false);
}); });
py_module.def("_is_functional_tensor", [](const at::Tensor& t) { py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
return at::functionalization::impl::isFunctionalTensor(t); return at::functionalization::impl::isFunctionalTensor(t);