#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/_resize_output.h>
#include <ATen/ops/_resize_output_native.h>
#endif

#include <c10/util/overflows.h>

namespace at::native {

// Returns true if resize is necessary
template <typename T>
static bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
  // Tests for resizing of tensors with one or more elements
  if (at::symint::sizes<T>(output).equals(shape)) {
    return false;
  }
  if (at::symint::numel<T>(output) != 0) {
    TORCH_WARN(
        "An output with one or more elements was resized since it had ",
        "shape ", at::symint::sizes<T>(output), ", which does not match the required ",
        "output shape ", shape, ". ",
        "This behavior is deprecated, and in a future PyTorch release outputs ",
        "will not be resized unless they have zero elements. You can explicitly ",
        "reuse an out tensor t by resizing it, inplace, to zero elements with ",
        "t.resize_(0).");
  }
  return true;
}

bool resize_output_check(const Tensor& output, IntArrayRef shape) {
  return _resize_output_check(output, shape);
}

bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) {
  return _resize_output_check(output, shape);
}

static void native_resize_(const Tensor& output, IntArrayRef shape) {
  native::resize_(output, shape);
}

static void native_resize_(const Tensor& output, SymIntArrayRef shape) {
  native::resize__symint(output, shape);
}

template <typename T>
static bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
  if (_resize_output_check<T>(output, shape)) {
    // avoid a redispatch for cpu and cuda.
    // TODO: when resize_cuda_ is re-written to be unified with resize_,
    // we can provide the same benefit for cuda.
    //
    // TODO(#61485): functorch wrapped tensors should not go through the
    // fast path. This is a hack, longer term solutions are in the issue
    if (output.is_cpu() && !isTensorSubclassLike(output)) {
      native_resize_(output, shape);
    } else {
      at::symint::resize_<T>(output, shape);
    }
    return true;
  } else {
    return false;
  }
}

bool resize_output(const Tensor& output, IntArrayRef shape) {
  return _resize_output(output, shape);
}

bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) {
  return _resize_output(output, shape);
}
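
// A minimal usage sketch (illustrative only; the kernel below is hypothetical):
// out= kernels typically call resize_output() before writing, so zero-element
// outputs are resized silently, while mismatched non-empty outputs trigger the
// deprecation warning above.
//
//   Tensor& my_add_out(const Tensor& a, const Tensor& b, Tensor& out) {
//     at::native::resize_output(out, a.sizes());  // returns true if `out` was resized
//     // ... compute a + b directly into `out` ...
//     return out;
//   }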

const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) {
  TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set");
  at::native::resize_output(self, shape);
  return self;
}

void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");

  at::DataPtr new_data;
  if (size_bytes != 0) {
    new_data = storage->allocator()->allocate(size_bytes);
  }
  const at::DataPtr& old_data = storage->data_ptr();
  const auto old_capacity = storage->nbytes();
  const auto copy_capacity = std::min(size_bytes, old_capacity);
  if (old_data != nullptr && copy_capacity > 0) {
    memcpy(new_data.get(), old_data.get(), copy_capacity);
  }
  storage->set_data_ptr_noswap(std::move(new_data));
  storage->set_nbytes(size_bytes);
}
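
// Illustrative note (hypothetical values): growing a CPU storage with
// resize_bytes_cpu preserves the existing bytes and leaves the new tail
// uninitialized; shrinking keeps only the leading bytes, since only
// std::min(size_bytes, old_capacity) bytes are copied.
//
//   // storage currently holds 16 bytes
//   resize_bytes_cpu(storage_impl, 32);  // first 16 bytes preserved, rest uninitialized
//   resize_bytes_cpu(storage_impl, 8);   // only the first 8 bytes survive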

// TODO(VitalyFedyunin): Move it to HTML docs.
//
// The strides of the output tensor of the `resize_as_` operator are defined
// by the input tensor's strides and the value of the memory_format argument.
//
// If memory_format is omitted and the input tensor has the same shape as the
// output tensor, the strides of the output will remain unchanged. If the
// shapes differ, the strides will be set to contiguous.
//
// If memory_format is equal to MemoryFormat::Contiguous (torch.contiguous_format),
// the output tensor will have contiguous strides.
//
// If memory_format is equal to MemoryFormat::ChannelsLast (torch.channels_last)
// and the input tensor is 4D, the output tensor will have a channels-last
// memory layout.
//
// If memory_format is equal to MemoryFormat::Preserve (torch.preserve_format),
// the strides of the output tensor will be defined by the strides of the
// input tensor, following the memory format preservation rule:
//
// - If the input tensor's strides are in channels-last format, the output
//   tensor will have a channels-last memory layout.
//
// - Otherwise, the output tensor will have a contiguous memory layout.
//
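// For example (an illustrative sketch with hypothetical tensors):
//
//   auto src = at::empty({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
//   auto dst = at::empty({0});
//
//   dst.resize_as_(src);                              // shapes differ -> contiguous strides
//   dst.resize_as_(src, at::MemoryFormat::Preserve);  // follows src -> channels-last strides
//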
const Tensor& resize_as_(
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.is_sparse() && the_template.is_sparse()) {
    TORCH_CHECK(
        !optional_memory_format.has_value(),
        "Unsupported memory format for sparse tensor resize_as_ :",
        optional_memory_format.value());
    return at::native::resize_as_sparse_(self, the_template);
  }
  const Tensor& result = self.resize_(the_template.sizes());
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    if (memory_format == MemoryFormat::Preserve) {
      memory_format = the_template.suggest_memory_format();
    }
    self.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
  }
  namedinference::propagate_names(result, the_template);
  return result;
}

void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
  storage->set_nbytes(std::move(size_bytes));
}

static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in Resize.h)
  if (self->sym_numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2");
  } else if (new_size_bytes > storage.sym_nbytes()) {
    resize_bytes_meta(storage.unsafeGetStorageImpl(), std::move(new_size_bytes));
  }
}

static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) {
  maybe_resize_storage_cpu(self, new_size_bytes);
}

static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) {
  if (self->is_cpu()) {
    maybe_resize_storage_cpu(self, new_size_bytes.expect_int());
    return;
  }
  TORCH_INTERNAL_ASSERT(self->is_meta());
  maybe_resize_storage_meta(self, std::move(new_size_bytes));
}

template <typename T>
static TensorImpl* _resize_impl_(
    TensorImpl* self,
    ArrayRef<T> size,
    at::OptionalArrayRef<T> stride,
    bool resize_storage) {
  if (self->generic_sizes<T>() == size && (!stride || self->generic_strides<T>() == stride.value())) {
    return self;
  }

  const auto itemsize = self->dtype().itemsize();
  const auto storage_offset = self->generic_storage_offset<T>();
  T storage_size = T(1);
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
    storage_size = at::detail::computeStorageNbytes(
        size, *stride, itemsize, storage_offset);
  } else {
    self->generic_set_sizes_contiguous(size);
    storage_size = at::detail::computeStorageNbytesContiguous(
        size, itemsize, storage_offset);
  }

  if (resize_storage) {
    _maybe_resize_storage(self, std::move(storage_size));
  }

  return self;
}

TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage) {
  return _resize_impl_(self, size, stride, resize_storage);
}
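
// A worked example of the storage-size computation above (hypothetical
// numbers; sketch only). With explicit strides, the required storage is
//
//   itemsize * (storage_offset + 1 + sum_i (size[i] - 1) * stride[i])
//
// bytes (and 0 if any dimension has size 0). E.g. resizing a float tensor
// (itemsize 4, storage_offset 0) to size {2, 3} with contiguous strides
// {3, 1} needs 4 * (0 + 1 + 1 * 3 + 2 * 1) = 24 bytes, which matches
// numel * itemsize for the contiguous case.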

template <typename T>
static const Tensor& _resize_(
    const Tensor& self,
    ArrayRef<T> size,
    std::optional<MemoryFormat> optional_memory_format) {
  auto* self_ = self.unsafeGetTensorImpl();
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().sym_nbytes().maybe_as_int().value_or(-1) : 0;
  _resize_impl_<T>(self_, size, /*stride=*/std::nullopt, true);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory() && old_storage_nbytes != -1)) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}
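
// Illustrative sketch of the deterministic path above (hypothetical usage):
// when deterministic algorithms and deterministic filling of uninitialized
// memory are both enabled, bytes added by a resize are filled with a known
// value (NaN for floating-point types, the maximum value for integral types)
// instead of being left uninitialized.
//
//   at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);
//   at::globalContext().setDeterministicFillUninitializedMemory(true);
//   auto t = at::empty({0});
//   t.resize_({4});  // the newly allocated elements are filled, not garbage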

const Tensor& resize_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  return _resize_(self, size, optional_memory_format);
}

const Tensor& resize__symint(
    const Tensor& self,
    c10::SymIntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_INTERNAL_ASSERT(!self.has_names());
  return _resize_(self, size, optional_memory_format);
}

void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) {
  // handles all devices except cuda (which needs to be in a different .so)
  c10::DeviceType device_type = storage.device_type();
  if (device_type == at::kCPU) {
    at::native::resize_bytes_cpu(storage.unsafeGetStorageImpl(), newsize.expect_int());
  } else if (device_type == at::kMeta) {
    at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize);
  } else if (device_type == at::kPrivateUse1) {
    at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes(
        storage, newsize.expect_int());
  } else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) {
    ptrdiff_t size_bytes_i = newsize.expect_int();
    TORCH_CHECK(
        !c10::overflows<int64_t>(size_bytes_i),
        "Requested storage size (",
        size_bytes_i,
        ") cannot be represented as an int64_t");
    const auto size_bytes = static_cast<int64_t>(size_bytes_i);
    void* original_data_ptr = storage.data_ptr().get();

    auto src_option =
        c10::TensorOptions().device(storage.device()).dtype(at::kByte);
    auto src_tensor = at::empty({0}, src_option).set_(storage);
    src_tensor.resize_({size_bytes});

    // When using resize_ to replace resize_bytes_xxx, in some cases the
    // original data_ptr is still returned, which is inconsistent with the
    // behavior of resize_bytes_xxx. For these cases, an additional memory
    // copy and storage update are required.
    if (original_data_ptr == src_tensor.storage().data_ptr().get()) {
      auto new_tensor = at::empty(src_tensor.sizes(), src_tensor.options());
      new_tensor.copy_(src_tensor);
      storage.set_data_ptr_noswap(
          std::move(new_tensor.storage().mutable_data_ptr()));
      storage.unsafeGetStorageImpl()->set_allocator(
          new_tensor.storage().unsafeGetStorageImpl()->allocator());
      storage.set_nbytes(new_tensor.storage().nbytes());
    }
  } else {
    TORCH_CHECK(
        false,
        "UntypedStorage.resize_: got unexpected device type ",
        device_type);
  }
}

} // namespace at::native