mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make Tensor.set_ validate storage_offset when sizes/strides are unchanged (#147354)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147354 Approved by: https://github.com/albanD ghstack dependencies: #147352
This commit is contained in:
parent
e64441915f
commit
536bce5a04
|
|
@@ -6,6 +6,7 @@
|
|||
#include <ATen/TensorUtils.h>
|
||||
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
|
||||
#include <utility>
|
||||
|
||||
|
|
@@ -85,17 +86,28 @@ inline void checkInBoundsForStorage(
|
|||
T storage_offset,
|
||||
const caffe2::TypeMeta& data_type,
|
||||
const Storage& new_storage) {
|
||||
T storage_size_bytes =
|
||||
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
||||
if (storage_size_bytes == 0) {
|
||||
T storage_size_bytes, storage_size_plus_offset_bytes;
|
||||
if (stride.data()) {
|
||||
storage_size_bytes =
|
||||
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
||||
storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
|
||||
size, stride, data_type.itemsize(), storage_offset);
|
||||
} else {
|
||||
storage_size_bytes =
|
||||
at::detail::computeStorageNbytesContiguous(size, data_type.itemsize());
|
||||
storage_size_plus_offset_bytes = at::detail::computeStorageNbytesContiguous(
|
||||
size, data_type.itemsize(), storage_offset);
|
||||
}
|
||||
// It's ok to always evaluate to False for this early return for SymInts because
|
||||
// (1) maybe_convert_symint below only installs guard for int64_t case
|
||||
// (2) we check for this condition in the TORCH_MAYBE_SYM_CHECK below
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(storage_size_bytes, 0))) {
|
||||
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
|
||||
return;
|
||||
}
|
||||
T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
|
||||
size, stride, data_type.itemsize(), storage_offset);
|
||||
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
|
||||
TORCH_CHECK(
|
||||
storage_size_plus_offset_bytes <= new_storage_size_bytes,
|
||||
TORCH_MAYBE_SYM_CHECK(
|
||||
sym_eq(storage_size_bytes, 0) || sym_le(storage_size_plus_offset_bytes, new_storage_size_bytes),
|
||||
"setStorage: sizes ",
|
||||
size,
|
||||
", strides ",
|
||||
|
|
@@ -113,7 +125,7 @@ inline void checkInBoundsForStorage(
|
|||
|
||||
template <typename T>
|
||||
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||
ArrayRef<T> size, ArrayRef<T> stride) {
|
||||
ArrayRef<T> size, ArrayRef<T> stride, bool check_offset_in_bounds = true) {
|
||||
// FIXME: stride should be optional
|
||||
if (stride.data()) {
|
||||
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
|
||||
|
|
@@ -124,6 +136,28 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
|||
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
|
||||
#endif
|
||||
|
||||
// storageOffset
|
||||
TORCH_CHECK(
|
||||
storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
||||
|
||||
// set_storage_{device} (except set_storage_meta__symint)
|
||||
// will (unsafely) set the storage offset and then call resize_impl that
|
||||
// handles resizing the storage However, resize_impl will only resize the
|
||||
// storage if the sizes/strides changed. For the case that the sizes/strides
|
||||
// remain unchanged, the storage offset is not properly validated, so we do
|
||||
// that here.
|
||||
if (check_offset_in_bounds) {
|
||||
auto result_tensor_impl = result.unsafeGetTensorImpl();
|
||||
bool size_unchanged = result_tensor_impl->generic_sizes<T>() == size;
|
||||
bool stride_unchanged = stride.data()
|
||||
? result_tensor_impl->generic_strides<T>() == stride
|
||||
: true;
|
||||
if (size_unchanged && stride_unchanged) {
|
||||
checkInBoundsForStorage(
|
||||
size, stride, storage_offset, result.dtype(), storage);
|
||||
}
|
||||
}
|
||||
|
||||
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
|
||||
// function is to set the tensor size to be equal to the size of the storage.
|
||||
if (!result.storage().is_alias_of(storage)) {
|
||||
|
|
@@ -140,9 +174,6 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
|||
"\". This is no longer allowed; the devices must match.");
|
||||
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
|
||||
}
|
||||
|
||||
// storageOffset
|
||||
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@@ -400,7 +400,13 @@ Tensor& set_storage_meta__symint(
|
|||
c10::SymInt storage_offset,
|
||||
c10::SymIntArrayRef size,
|
||||
c10::SymIntArrayRef stride) {
|
||||
checkSetStorage(result, storage, storage_offset, size, stride);
|
||||
checkSetStorage(
|
||||
result,
|
||||
storage,
|
||||
storage_offset,
|
||||
size,
|
||||
stride,
|
||||
/*check_offset_in_bounds=*/false);
|
||||
|
||||
c10::SymDimVector contiguous_strides;
|
||||
if (stride.data() == nullptr) {
|
||||
|
|
|
|||
|
|
@@ -49,6 +49,9 @@ class C10_API SymBool {
|
|||
SymBool operator|(const SymBool& other) const {
|
||||
return sym_or(other);
|
||||
}
|
||||
SymBool operator||(const SymBool& other) const {
|
||||
return sym_or(other);
|
||||
}
|
||||
SymBool operator~() const {
|
||||
return sym_not();
|
||||
}
|
||||
|
|
@@ -89,6 +92,12 @@ C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
|
|||
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
||||
#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) \
|
||||
TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__)
|
||||
#define TORCH_MAYBE_SYM_CHECK(cond, ...) \
|
||||
if constexpr (std::is_same_v<std::decay_t<decltype(cond)>, SymBool>) { \
|
||||
TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) \
|
||||
} else { \
|
||||
TORCH_CHECK((cond), __VA_ARGS__) \
|
||||
}
|
||||
|
||||
inline bool guard_size_oblivious(
|
||||
bool b,
|
||||
|
|
|
|||
|
|
@@ -1240,6 +1240,11 @@ def forward(self, primals_1):
|
|||
# return [sin, copy]""",
|
||||
# )
|
||||
|
||||
# skipped after confirming with @yf225 and @bdhirsh
|
||||
@unittest.skipIf(
|
||||
True,
|
||||
"using set_ unsafely and PT2 FSDP2 no longer uses set_ as used in this test",
|
||||
)
|
||||
def test_input_mutation_storage_resize_down_and_set_(self):
|
||||
# Meant to mimic ppFSDP
|
||||
class TracableCreateParameter(torch.autograd.Function):
|
||||
|
|
|
|||
|
|
@@ -7166,6 +7166,18 @@ class TestTorch(TestCase):
|
|||
f_cpu = torch.randn((2, 3), dtype=torch.float32)
|
||||
d_cpu = torch.randn((2, 3), dtype=torch.float64)
|
||||
|
||||
storage_offset = 0x41414141
|
||||
with self.assertRaisesRegex(RuntimeError, "out of bounds for storage of size"):
|
||||
t = torch.randn(1)
|
||||
t.set_(t.untyped_storage(), storage_offset, t.size())
|
||||
|
||||
# if size changes, set_ will resize the storage inplace
|
||||
t = torch.randn(1)
|
||||
size = torch.Size([2, 3])
|
||||
t.set_(t.untyped_storage(), storage_offset, size)
|
||||
self.assertEqual(t.storage_offset(), storage_offset)
|
||||
self.assertEqual(t.untyped_storage().nbytes(), (storage_offset + size[0] * size[1]) * 4)
|
||||
|
||||
# change dtype
|
||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
|
||||
self.assertRaises(RuntimeError,
|
||||
|
|
|
|||
|
|
@@ -734,7 +734,8 @@ void initTorchFunctions(PyObject* module) {
|
|||
src.storage(),
|
||||
dst.sym_storage_offset(),
|
||||
dst.sym_sizes(),
|
||||
dst.sym_strides());
|
||||
dst.sym_strides(),
|
||||
/*check_offset_in_bounds=*/false);
|
||||
});
|
||||
py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
|
||||
return at::functionalization::impl::isFunctionalTensor(t);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user