Apply some more missing moves in aten native (#92983)

Add some additional missing moves to further improve vmap and related operators.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92983
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan 2023-01-26 15:52:16 +00:00 committed by PyTorch MergeBot
parent 7e449e8ba7
commit 3888555fa1
11 changed files with 48 additions and 32 deletions

View File

@ -172,7 +172,7 @@ Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const
Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) {
// Pessimism: we can't reapply views for slice_scatter.
return base.select_scatter_symint(mutated_view, dim, index);
return base.select_scatter_symint(mutated_view, dim, std::move(index));
}
Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {

View File

@ -9,6 +9,8 @@
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <utility>
namespace at { namespace functorch {
template <typename F, F Func, typename... ExtraArgs>
@ -306,7 +308,7 @@ std::tuple<Tensor, optional<int64_t>> log_sigmoid_backward_batch_rule(
}
Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch, prob shouldn't need to be contiguous
}
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {

View File

@ -19,6 +19,8 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/VmapGeneratedPlumbing.h>
#include <utility>
// This file contains helper functions for batching rules.
namespace at { namespace functorch {
@ -339,7 +341,7 @@ inline void boxed_all_tensors_have_optional_bdim(
if (tensor_idx == contig_tensor_index) {
value_ = value_.contiguous();
}
(*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
(*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
continue;
}
TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
@ -347,7 +349,7 @@ inline void boxed_all_tensors_have_optional_bdim(
if (tensor_idx == contig_tensor_index) {
value_ = value_.contiguous();
}
(*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
(*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
}
op.callBoxed(stack);

View File

@ -6,6 +6,7 @@
#include <ATen/functorch/BatchRulesHelper.h>
#include <iostream>
#include <utility>
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
@ -236,7 +237,7 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
}
auto result = self.view(squeezed_sizes);
return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
return std::make_tuple(std::move(result), c10::optional<int64_t>(new_batch_idx));
}
std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
@ -284,13 +285,13 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Ten
std::tuple<Tensor, optional<int64_t>> select_batching_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
if (!bdim) {
return std::make_tuple(self.select_symint(dim, index), nullopt);
return std::make_tuple(self.select_symint(dim, std::move(index)), nullopt);
}
auto _self = moveBatchDimToFront(self, bdim);
auto dim_physical = getPhysicalDim(_self, true, dim);
auto result = _self.select_symint(dim_physical, index);
return std::make_tuple(result, 0);
auto result = _self.select_symint(dim_physical, std::move(index));
return std::make_tuple(std::move(result), 0);
}
std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
@ -359,8 +360,8 @@ std::tuple<Tensor,optional<int64_t>> slice_batch_rule(
auto self_ = moveBatchDimToFront(self, self_bdim);
dim = getPhysicalDim(self, self_bdim.has_value(), dim);
auto result = self_.slice_symint(dim, start, end, step);
return std::make_tuple(result, 0);
auto result = self_.slice_symint(dim, std::move(start), std::move(end), std::move(step));
return std::make_tuple(std::move(result), 0);
}
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
@ -386,7 +387,7 @@ transpose_int_batch_rule(
dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0);
dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1);
auto result = self_.transpose(dim0, dim1);
return std::make_tuple(result, 0);
return std::make_tuple(std::move(result), 0);
}
std::tuple<Tensor, optional<int64_t>> permute_batching_rule(
@ -416,7 +417,7 @@ std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
c10::SymDimVector input_sizes_(input_sizes.size() + 1);
input_sizes_[0] = grad_input_.sym_size(0);
std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, index);
auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, std::move(index));
return std::make_tuple(std::move(result), 0);
}
@ -429,7 +430,7 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
c10::SymDimVector input_sizes_(input_sizes.size() + 1);
input_sizes_[0] = grad_input_.size(0);
std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, start, end, step);
auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, std::move(start), std::move(end), std::move(step));
return std::make_tuple(std::move(result), 0);
}
@ -507,7 +508,7 @@ std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
if (logical_rank==0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, 0);
return std::make_tuple(std::move(result), 0);
}
std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
@ -517,9 +518,9 @@ std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
auto self_ = moveBatchDimToFront(self, self_bdim);
auto logical_rank = rankWithoutBatchDim(self, self_bdim);
dim = maybe_wrap_dim(dim, logical_rank) + 1;
auto result = self_.narrow_copy_symint(dim, start, length);
auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length));
return std::make_tuple(result, 0);
return std::make_tuple(std::move(result), 0);
}
std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
@ -531,8 +532,8 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
auto self_ = moveBatchDimToFront(self, self_bdim);
auto logical_rank = rankWithoutBatchDim(self, self_bdim);
dim = maybe_wrap_dim(dim, logical_rank) + 1;
auto result = self_.unsafe_split_symint(split_size, dim);
return std::make_tuple(result, 0);
auto result = self_.unsafe_split_symint(std::move(split_size), dim);
return std::make_tuple(std::move(result), 0);
}
std::tuple<Tensor, optional<int64_t>> movedim_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef source, IntArrayRef destination) {

View File

@ -6,6 +6,8 @@
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/DynamicLayer.h>
#include <utility>
namespace at { namespace functorch {
static DispatchKeySet get_all_dynlayer_keyset() {
@ -92,7 +94,7 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto result = unwrapIfDead(tensor);
auto* wrapper = maybeGetTensorWrapper(result);
TORCH_INTERNAL_ASSERT(wrapper == nullptr);
auto* batched = maybeGetBatchedImpl(result);
auto* batched = maybeGetBatchedImpl(std::move(result));
TORCH_INTERNAL_ASSERT(batched == nullptr);
return tensor;
});

View File

@ -16,6 +16,8 @@
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <utility>
namespace at {
namespace functorch {
@ -476,7 +478,7 @@ Tensor as_strided_batching_rule(
optional<c10::SymInt> storage_offset) {
if (!participatesInCurrentLevel(tensor)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::as_strided_symint(tensor, sizes, strides, storage_offset);
return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
}
auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
auto num_batch_dims = physical_view.numBatchDims();
@ -511,7 +513,7 @@ Tensor as_strided_batching_rule(
// and creates a tensor y such that each y[i] references the same memory
// locations as zi. See NOTE: [When will the as_strided batching rule fail?]
auto result = physical_view.tensor().as_strided_symint(
physical_sizes, physical_strides, storage_offset);
physical_sizes, physical_strides, std::move(storage_offset));
return physical_view.getPhysicalToLogicalMap().apply(result);
}

View File

@ -8,6 +8,8 @@
#else
#include <ATen/ops/view_as_real_native.h>
#include <ATen/ops/view_as_complex_native.h>
#include <utility>
#endif
// WARNING: this header contains non-inline functions and should be only
@ -47,7 +49,7 @@ Tensor _view_as_real_physical(const Tensor& self) {
auto new_strides = computeStrideForViewAsReal(self.sym_strides());
auto new_storage_offset = self.sym_storage_offset() * 2;
const auto float_type = c10::toRealValueType(self.scalar_type());
auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
return real_tensor;
}

View File

@ -48,8 +48,9 @@
#include <ATen/ops/sqrt.h>
#endif
#include <vector>
#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>
static const int MIOPEN_DIM_MAX = 5;
@ -490,7 +491,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
auto options = input.options().dtype(
at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
auto save_invstd = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);
// don't return view of input, don't return empty tensor because it will break gradient chain
auto out = input.clone();
@ -514,7 +515,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
}
if (bias.defined()) {
check_dims_match_num_input_features("bias", num_features, bias.sym_numel());
check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
}
const bool use_cudnn = (
@ -672,7 +673,7 @@ Tensor instance_norm(
at::alias(running_mean).copy_(running_mean_.view_symint({ b, c }).mean(0, false));
}
if (running_var.defined()) {
at::alias(running_var).copy_(running_var_.view_symint({ b, c }).mean(0, false));
at::alias(running_var).copy_(running_var_.view_symint({ std::move(b), std::move(c) }).mean(0, false));
}
return out.view_symint(input.sym_sizes());

View File

@ -4,6 +4,8 @@
#include <ATen/native/DispatchStub.h>
#include <c10/util/irange.h>
#include <utility>
#pragma once
namespace at {
@ -93,7 +95,7 @@ inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) {
return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), stride, dilation);
}
// AveragePool2d/DilatedMaxPool2d (forward)

View File

@ -7,6 +7,8 @@
#include <c10/core/CPUAllocator.h>
#include <utility>
namespace at { namespace native {
@ -130,7 +132,7 @@ static inline void checkSetStorage(Tensor& result, Storage storage, T storage_of
"Attempted to set the storage of a tensor on device \"", result.storage().device(),
"\" to a storage on different device \"", storage.device(),
"\". This is no longer allowed; the devices must match.");
result.unsafeGetTensorImpl()->set_storage_keep_dtype(storage);
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
}
// storageOffset

View File

@ -1804,7 +1804,7 @@ Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
auto grad_input = at::zeros_symint(input_sizes, grad.options());
grad_input.select_symint(dim, index).copy_(grad);
grad_input.select_symint(dim, std::move(index)).copy_(grad);
return grad_input;
}
@ -3879,7 +3879,7 @@ at::Tensor clone_preserve_strides(const at::Tensor& self) {
auto nbytes = self.storage().sym_nbytes();
TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
auto numel = nbytes / dtype_size;
auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
auto self_full_size = self.as_strided_symint({std::move(numel)}, {1}, 0);
auto clone = self_full_size.clone();
auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
return out;
@ -3896,7 +3896,7 @@ at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t
}
at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
auto output = clone_preserve_strides(self);
auto slice = output.select_symint(dim, index);
auto slice = output.select_symint(dim, std::move(index));
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
return output;
@ -4039,7 +4039,7 @@ at::Tensor& _reshape_alias_copy_out(const at::Tensor & self, at::IntArrayRef siz
at::Tensor& select_copy_symint_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) {
auto tmp = self.select_symint(dim, index);
auto tmp = self.select_symint(dim, std::move(index));
out.copy_(tmp);
return out;
}