From 3888555fa1affd22eb82220c44d22ec296054ded Mon Sep 17 00:00:00 2001
From: Aaron Gokaslan
Date: Thu, 26 Jan 2023 15:52:16 +0000
Subject: [PATCH] Apply some more missing moves in aten native (#92983)

Add some additional missing moves to further improve vmap and related operators.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92983
Approved by: https://github.com/ezyang
---
 aten/src/ATen/FunctionalInverses.cpp          |  2 +-
 .../ATen/functorch/BatchRulesBinaryOps.cpp    |  4 ++-
 aten/src/ATen/functorch/BatchRulesHelper.h    |  6 ++--
 aten/src/ATen/functorch/BatchRulesViews.cpp   | 29 ++++++++++---------
 aten/src/ATen/functorch/Interpreter.cpp       |  4 ++-
 .../functorch/LegacyBatchingRegistrations.cpp |  6 ++--
 aten/src/ATen/native/ComplexHelper.h          |  4 ++-
 aten/src/ATen/native/Normalization.cpp        |  9 +++---
 aten/src/ATen/native/Pool.h                   |  4 ++-
 aten/src/ATen/native/Resize.h                 |  4 ++-
 aten/src/ATen/native/TensorShape.cpp          |  8 ++---
 11 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 8a68503df32..8f99a5df73c 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -172,7 +172,7 @@ Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const
 
 Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) {
     // Pessimism: we can't reapply views for slice_scatter.
-    return base.select_scatter_symint(mutated_view, dim, index);
+    return base.select_scatter_symint(mutated_view, dim, std::move(index));
 }
 
 Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
index 1c0f98949a5..5a00f7d466c 100644
--- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
@@ -9,6 +9,8 @@
 #include
 #include
 
+#include <utility>
+
 namespace at { namespace functorch {
 
 template
@@ -306,7 +308,7 @@ std::tuple<Tensor, optional<int64_t>> log_sigmoid_backward_batch_rule(
 }
 
 Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
-  return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
+  return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch, prob shouldn't need to be contiguous
 }
 
 TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
diff --git a/aten/src/ATen/functorch/BatchRulesHelper.h b/aten/src/ATen/functorch/BatchRulesHelper.h
index 8e78ba71029..9db1543fd37 100644
--- a/aten/src/ATen/functorch/BatchRulesHelper.h
+++ b/aten/src/ATen/functorch/BatchRulesHelper.h
@@ -19,6 +19,8 @@
 #include
 #include
 
+#include <utility>
+
 // This file contains helper functions for batching rules.
 
 namespace at { namespace functorch {
@@ -339,7 +341,7 @@ inline void boxed_all_tensors_have_optional_bdim(
       if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
-      (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
       continue;
     }
     TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
@@ -347,7 +349,7 @@ inline void boxed_all_tensors_have_optional_bdim(
     if (tensor_idx == contig_tensor_index) {
       value_ = value_.contiguous();
     }
-    (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
   }
 
   op.callBoxed(stack);
diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp
index 9bc67cbe881..5ce01711cae 100644
--- a/aten/src/ATen/functorch/BatchRulesViews.cpp
+++ b/aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -6,6 +6,7 @@
 #include
 #include
+#include <utility>
 
 #include
 #include
 
@@ -236,7 +237,7 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
   }
 
   auto result = self.view(squeezed_sizes);
-  return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
+  return std::make_tuple(std::move(result), c10::optional<int64_t>(new_batch_idx));
 }
 
 std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
@@ -284,13 +285,13 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Ten
 
 std::tuple<Tensor, optional<int64_t>> select_batching_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
   if (!bdim) {
-    return std::make_tuple(self.select_symint(dim, index), nullopt);
+    return std::make_tuple(self.select_symint(dim, std::move(index)), nullopt);
   }
 
   auto _self = moveBatchDimToFront(self, bdim);
   auto dim_physical = getPhysicalDim(_self, true, dim);
-  auto result = _self.select_symint(dim_physical, index);
-  return std::make_tuple(result, 0);
+  auto result = _self.select_symint(dim_physical, std::move(index));
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
@@ -359,8 +360,8 @@ std::tuple<Tensor, optional<int64_t>> slice_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   dim = getPhysicalDim(self, self_bdim.has_value(), dim);
 
-  auto result = self_.slice_symint(dim, start, end, step);
-  return std::make_tuple(result, 0);
+  auto result = self_.slice_symint(dim, std::move(start), std::move(end), std::move(step));
+  return std::make_tuple(std::move(result), 0);
 }
 
 static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
@@ -386,7 +387,7 @@ transpose_int_batch_rule(
   dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0);
   dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1);
   auto result = self_.transpose(dim0, dim1);
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> permute_batching_rule(
@@ -416,7 +417,7 @@ std::tuple<Tensor, optional<int64_t>> select_backward_batch_rule(
   c10::SymDimVector input_sizes_(input_sizes.size() + 1);
   input_sizes_[0] = grad_input_.sym_size(0);
   std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
-  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, index);
+  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, std::move(index));
   return std::make_tuple(std::move(result), 0);
 }
 
@@ -429,7 +430,7 @@ std::tuple<Tensor, optional<int64_t>> slice_backward_batch_rule(
   c10::SymDimVector input_sizes_(input_sizes.size() + 1);
   input_sizes_[0] = grad_input_.size(0);
   std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
-  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, start, end, step);
+  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, std::move(start), std::move(end), std::move(step));
   return std::make_tuple(std::move(result), 0);
 }
 
@@ -507,7 +508,7 @@ std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
   if (logical_rank==0) {
     result = result.squeeze(-1);
   }
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
@@ -517,9 +518,9 @@ std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto logical_rank = rankWithoutBatchDim(self, self_bdim);
   dim = maybe_wrap_dim(dim, logical_rank) + 1;
-  auto result = self_.narrow_copy_symint(dim, start, length);
+  auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length));
 
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
@@ -531,8 +532,8 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto logical_rank = rankWithoutBatchDim(self, self_bdim);
   dim = maybe_wrap_dim(dim, logical_rank) + 1;
-  auto result = self_.unsafe_split_symint(split_size, dim);
-  return std::make_tuple(result, 0);
+  auto result = self_.unsafe_split_symint(std::move(split_size), dim);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> movedim_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef source, IntArrayRef destination) {
diff --git a/aten/src/ATen/functorch/Interpreter.cpp b/aten/src/ATen/functorch/Interpreter.cpp
index 6db36eb3303..b2c4dda1257 100644
--- a/aten/src/ATen/functorch/Interpreter.cpp
+++ b/aten/src/ATen/functorch/Interpreter.cpp
@@ -6,6 +6,8 @@
 #include
 #include
 
+#include <utility>
+
 namespace at { namespace functorch {
 
 static DispatchKeySet get_all_dynlayer_keyset() {
@@ -92,7 +94,7 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
         auto result = unwrapIfDead(tensor);
         auto* wrapper = maybeGetTensorWrapper(result);
         TORCH_INTERNAL_ASSERT(wrapper == nullptr);
-        auto* batched = maybeGetBatchedImpl(result);
+        auto* batched = maybeGetBatchedImpl(std::move(result));
         TORCH_INTERNAL_ASSERT(batched == nullptr);
         return tensor;
       });
diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
index d9f6ed21f13..547c945eda1 100644
--- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
@@ -16,6 +16,8 @@
 #include
 #include
 
+#include <utility>
+
 namespace at {
 namespace functorch {
 
@@ -476,7 +478,7 @@ Tensor as_strided_batching_rule(
     optional<c10::SymInt> storage_offset) {
   if (!participatesInCurrentLevel(tensor)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    return at::as_strided_symint(tensor, sizes, strides, storage_offset);
+    return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
   }
   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
   auto num_batch_dims = physical_view.numBatchDims();
@@ -511,7 +513,7 @@ Tensor as_strided_batching_rule(
   // and creates a tensor y such that each y[i] references the same memory
   // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
   auto result = physical_view.tensor().as_strided_symint(
-      physical_sizes, physical_strides, storage_offset);
+      physical_sizes, physical_strides, std::move(storage_offset));
   return physical_view.getPhysicalToLogicalMap().apply(result);
 }
 
diff --git a/aten/src/ATen/native/ComplexHelper.h b/aten/src/ATen/native/ComplexHelper.h
index 9533115a706..ca5929fb5f4 100644
--- a/aten/src/ATen/native/ComplexHelper.h
+++ b/aten/src/ATen/native/ComplexHelper.h
@@ -8,6 +8,8 @@
 #else
 #include
 #include
+
+#include <utility>
 #endif
 
 // WARNING: this header contains non-inline functions and should be only
@@ -47,7 +49,7 @@ Tensor _view_as_real_physical(const Tensor& self) {
   auto new_strides = computeStrideForViewAsReal(self.sym_strides());
   auto new_storage_offset = self.sym_storage_offset() * 2;
   const auto float_type = c10::toRealValueType(self.scalar_type());
-  auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
+  auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
   return real_tensor;
 }
 
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index ab9094d9b59..a05d669d694 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -48,8 +48,9 @@
 #include
 #endif
 
-#include
 #include
+#include
+#include
 
 static const int MIOPEN_DIM_MAX = 5;
 
@@ -490,7 +491,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   auto options = input.options().dtype(
       at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
   auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
-  auto save_invstd = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
+  auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);
 
   // don't return view of input, don't return empty tensor because it will break gradient chain
   auto out = input.clone();
@@ -514,7 +515,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
   }
   if (bias.defined()) {
-    check_dims_match_num_input_features("bias", num_features, bias.sym_numel());
+    check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
   const bool use_cudnn = (
@@ -672,7 +673,7 @@ Tensor instance_norm(
     at::alias(running_mean).copy_(running_mean_.view_symint({ b, c }).mean(0, false));
   }
   if (running_var.defined()) {
-    at::alias(running_var).copy_(running_var_.view_symint({ b, c }).mean(0, false));
+    at::alias(running_var).copy_(running_var_.view_symint({ std::move(b), std::move(c) }).mean(0, false));
   }
 
   return out.view_symint(input.sym_sizes());
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index 0ff4490086b..15c16d1d7ba 100644
--- a/aten/src/ATen/native/Pool.h
+++ b/aten/src/ATen/native/Pool.h
@@ -4,6 +4,8 @@
 #include
 #include
 
+#include <utility>
+
 #pragma once
 
 namespace at {
@@ -93,7 +95,7 @@ inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
 
 inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
     c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) {
-  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), stride, dilation);
 }
 
 // AveragePool2d/DilatedMaxPool2d (forward)
diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h
index c93e4cbe84b..c328afcfad9 100644
--- a/aten/src/ATen/native/Resize.h
+++ b/aten/src/ATen/native/Resize.h
@@ -7,6 +7,8 @@
 #include
 
+#include <utility>
+
 namespace at { namespace native {
 
@@ -130,7 +132,7 @@ static inline void checkSetStorage(Tensor& result, Storage storage, T storage_of
                 "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                 "\" to a storage on different device \"", storage.device(),
                 "\". This is no longer allowed; the devices must match.");
-    result.unsafeGetTensorImpl()->set_storage_keep_dtype(storage);
+    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
   }
 
   // storageOffset
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 1cd231b6719..2bbfd49128e 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1804,7 +1804,7 @@ Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
 
 Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
   auto grad_input = at::zeros_symint(input_sizes, grad.options());
-  grad_input.select_symint(dim, index).copy_(grad);
+  grad_input.select_symint(dim, std::move(index)).copy_(grad);
   return grad_input;
 }
 
@@ -3879,7 +3879,7 @@ at::Tensor clone_preserve_strides(const at::Tensor& self) {
   auto nbytes = self.storage().sym_nbytes();
   TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
   auto numel = nbytes / dtype_size;
-  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
+  auto self_full_size = self.as_strided_symint({std::move(numel)}, {1}, 0);
   auto clone = self_full_size.clone();
   auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
   return out;
@@ -3896,7 +3896,7 @@ at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t
 }
 at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
     auto output = clone_preserve_strides(self);
-    auto slice = output.select_symint(dim, index);
+    auto slice = output.select_symint(dim, std::move(index));
     TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
     slice.copy_(src);
     return output;
@@ -4039,7 +4039,7 @@ at::Tensor& _reshape_alias_copy_out(const at::Tensor & self, at::IntArrayRef siz
 
 at::Tensor& select_copy_symint_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) {
-  auto tmp = self.select_symint(dim, index);
+  auto tmp = self.select_symint(dim, std::move(index));
   out.copy_(tmp);
   return out;
 }
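
The change pattern is the same in every hunk above: a parameter taken by value (a c10::SymInt, a c10::optional<Generator>, a Storage, and so on) was being copied one more time when forwarded to the next call, and wrapping its last use in std::move turns that copy into a move. For the reference-counted cases (heap-allocated SymInt, Storage, Generator) the avoided copy is typically an atomic refcount increment/decrement pair. The sketch below is not PyTorch code: it uses a hypothetical Handle type with instrumented copy/move constructors purely to illustrate the effect of moving the final use of a by-value argument.

// Minimal sketch of the "missing move" pattern this patch applies.
// `Handle` is a hypothetical stand-in for a cheaply movable but
// non-trivially copyable value type such as c10::SymInt or at::Storage.
#include <cstdio>
#include <utility>

struct Handle {
  Handle() = default;
  Handle(const Handle&) { std::puts("copy (e.g. refcount bump)"); }
  Handle(Handle&&) noexcept { std::puts("move (no refcount traffic)"); }
};

// Callee takes its argument by value, like Tensor::select_symint(int64_t dim, c10::SymInt index).
void callee(Handle h) { (void)h; }

// Before: `h` is copied into the callee even though this is its last use.
void forward_without_move(Handle h) { callee(h); }

// After: the last use is moved, so only the cheap move constructor runs.
void forward_with_move(Handle h) { callee(std::move(h)); }

int main() {
  std::puts("-- without std::move --");
  forward_without_move(Handle{});  // prints one "copy" from inside forward_without_move
  std::puts("-- with std::move --");
  forward_with_move(Handle{});     // prints one "move"; no copy is made
  return 0;
}

std::move only pays off on the final use of a value, so each changed line moves a variable that is not read again afterwards. clang-tidy checks such as performance-unnecessary-value-param and bugprone-use-after-move are a common way to find these spots and to guard against accidentally reading a moved-from value, though the patch itself does not say what tooling, if any, was used.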