some reference and move fixes (#95942)

This PR introduces the following changes (short sketches of these patterns appear below):
1. Const function parameters that can be passed by reference are changed to take const references.
2. More opportunities for passing by value (and then moving) are identified, and the parameters are changed accordingly.
3. Some use-after-move errors are fixed.
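
For reference, here is a minimal C++ sketch of patterns 1 and 2; pattern 3 is illustrated further down, next to the hunk that fixes a use-after-move in OnnxifiTransformer. The sketch is not taken from the diff; totalLength and Record are hypothetical names used only to show the idioms.

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// (1) A parameter that is only read can be a const reference, avoiding a copy.
static std::size_t totalLength(const std::vector<std::string>& names) {  // instead of: const std::vector<std::string> names
  std::size_t total = 0;
  for (const auto& n : names) {
    total += n.size();
  }
  return total;
}

// (2) A parameter that is stored into a member is taken by value and moved
//     (the modernize-pass-by-value idiom): lvalue callers pay one copy,
//     rvalue callers pay none.
struct Record {
  explicit Record(std::string name) : name_(std::move(name)) {}  // instead of: const std::string& name
  std::string name_;
};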

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95942
Approved by: https://github.com/Skylion007
cyy 2023-03-10 03:44:09 +00:00 committed by PyTorch MergeBot
parent 6e0359dd42
commit d0e4ca233e
30 changed files with 46 additions and 60 deletions

View File

@@ -476,7 +476,7 @@ void sync(ITensorListRef t_list) {
sync(t);
}
}
-void sync(const c10::List<c10::optional<Tensor>> t_list) {
+void sync(const c10::List<c10::optional<Tensor>>& t_list) {
for (const auto i : c10::irange(t_list.size())) {
sync(t_list[i]);
}

View File

@@ -213,7 +213,7 @@ TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
-TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
+TORCH_API void sync(const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API void sync(ITensorListRef t_list);
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);

View File

@@ -824,8 +824,8 @@ struct TORCH_API TensorType : public SharedType {
TensorType(
c10::optional<at::ScalarType> scalar_type,
c10::optional<Device> device,
-const SymbolicShape& sizes,
-const VaryingShape<Stride>& strides,
+SymbolicShape sizes,
+VaryingShape<Stride> strides,
c10::optional<bool> requires_grad,
c10::optional<bool> undefined = false);

View File

@@ -424,16 +424,15 @@ VaryingShape<int64_t> TensorType::strides() const {
TensorType::TensorType(
c10::optional<at::ScalarType> scalar_type,
c10::optional<Device> device,
-// NOLINTNEXTLINE(modernize-pass-by-value)
-const SymbolicShape& sizes,
-const VaryingShape<Stride>& strides,
+SymbolicShape sizes,
+VaryingShape<Stride> strides,
c10::optional<bool> requires_grad,
c10::optional<bool> undefined)
: SharedType(TypeKind::TensorType),
scalar_type_(scalar_type),
device_(device),
-sizes_(sizes),
-strides_(strides),
+sizes_(std::move(sizes)),
+strides_(std::move(strides)),
requires_grad_(requires_grad),
undefined_(undefined) {}

View File

@@ -29,7 +29,7 @@ inline scalar_t vec_reduce_all(
template <typename scalar_t, typename Op>
struct VecReduceAllSIMD {
-static inline scalar_t apply(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
+static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
}
};
@@ -38,7 +38,7 @@ struct VecReduceAllSIMD {
#if defined(CPU_CAPABILITY_AVX2)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
-static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
+static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
@@ -57,7 +57,7 @@ struct VecReduceAllSIMD<float, Op> {
#if defined(CPU_CAPABILITY_AVX512)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
-static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
+static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 256-bit shuffle
@@ -79,7 +79,7 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
template <typename scalar_t, typename Op>
-inline scalar_t vec_reduce_all(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
+inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
}

View File

@@ -71,7 +71,7 @@ bool isBatchedAtLevel(ITensorListRef tensors, int64_t level) {
return false;
}
-bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>> maybe_tensors, int64_t level) {
+bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>>& maybe_tensors, int64_t level) {
for (const auto idx : c10::irange(0, maybe_tensors.size())) {
const auto& maybe_tensor = maybe_tensors.get(idx);
if (isBatchedAtLevel(maybe_tensor, level)) {

View File

@@ -42,7 +42,7 @@ TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tenso
// Returns True if ANY tensor in tensors is batched at level
TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);
-TORCH_API bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>> maybe_tensors, int64_t level);
+TORCH_API bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>>& maybe_tensors, int64_t level);
TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level);
TORCH_API bool isBatchedAtLevel(const c10::optional<Tensor>& maybe_tensor, int64_t level);

View File

@@ -78,7 +78,7 @@ static inline void checkInBoundsForStorage(
ArrayRef<T> size,
ArrayRef<T> stride,
T storage_offset,
-const caffe2::TypeMeta data_type,
+const caffe2::TypeMeta& data_type,
const Storage& new_storage) {
T storage_size_bytes =
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());

View File

@@ -133,7 +133,7 @@ void _sparse_binary_op_intersection_kernel_impl(
Tensor& res,
const Tensor& x_,
const Tensor& y_,
-const std::vector<int64_t> broadcasted_shape,
+const std::vector<int64_t>& broadcasted_shape,
const bool restrict_indices_to_rhs = false,
const bool distributive_with_sum = true
) {

View File

@@ -6,20 +6,18 @@ QTensorImpl::QTensorImpl(
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
-// NOLINTNEXTLINE(modernize-pass-by-value)
QuantizerPtr quantizer)
-: TensorImpl(std::move(storage), key_set, data_type),
-quantizer_(quantizer) {}
+: TensorImpl(std::move(storage), std::move(key_set), data_type),
+quantizer_(std::move(quantizer)) {}
QTensorImpl::QTensorImpl(
ImplType type,
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
-// NOLINTNEXTLINE(modernize-pass-by-value)
QuantizerPtr quantizer)
-: TensorImpl(type, std::move(storage), key_set, data_type),
-quantizer_(quantizer) {}
+: TensorImpl(type, std::move(storage), std::move(key_set), data_type),
+quantizer_(std::move(quantizer)) {}
const char* QTensorImpl::tensorimpl_type_name() const {
return "QTensorImpl";

View File

@@ -2105,7 +2105,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* If the existing data does not match the desired type, it will be deleted
* and a new storage will be created.
*/
-inline void* raw_mutable_data(const caffe2::TypeMeta meta) {
+inline void* raw_mutable_data(const caffe2::TypeMeta& meta) {
// For 0-size tensors it's fine to return any pointer (including nullptr)
if (data_type_ == meta && storage_initialized()) {
return static_cast<void*>(

View File

@@ -613,7 +613,7 @@ inline TensorOptions dtype() {
return dtype(caffe2::TypeMeta::Make<T>());
}
-inline std::string toString(const TensorOptions options) {
+inline std::string toString(const TensorOptions& options) {
std::ostringstream stream;
stream << options;
return stream.str();
@@ -763,7 +763,7 @@ inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
}
namespace detail {
-inline bool backend_supports_empty_operator(const TensorOptions options) {
+inline bool backend_supports_empty_operator(const TensorOptions& options) {
// Quantized backends don't support at::empty().
// They have separate operators like at::empty_quantized() that take in
// extra information about how to quantize the tensor.

View File

@@ -422,7 +422,7 @@ class C10_API TypeMeta final {
return data().name_;
}
-friend bool operator==(const TypeMeta lhs, const TypeMeta rhs) noexcept;
+friend bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept;
template <typename T>
bool Match() const noexcept {
@@ -597,10 +597,10 @@ C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData<
inline TypeMeta::TypeMeta() noexcept
: index_(_typeMetaData<detail::_Uninitialized>()) {}
-inline bool operator==(const TypeMeta lhs, const TypeMeta rhs) noexcept {
+inline bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept {
return (lhs.index_ == rhs.index_);
}
-inline bool operator!=(const TypeMeta lhs, const TypeMeta rhs) noexcept {
+inline bool operator!=(const TypeMeta& lhs, const TypeMeta& rhs) noexcept {
return !operator==(lhs, rhs);
}

View File

@@ -624,7 +624,6 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) {
}
if (compiledStep->gotFailure) {
LOG(ERROR) << "One of the workers failed.";
-// NOLINTNEXTLINE(bugprone-use-after-move)
if (first_exception) {
first_exception.rethrowException();
}

View File

@@ -888,14 +888,14 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
VLOG(2) << "Adding extra init tensor: " << t.name();
TensorShape shape;
shape.mutable_dims()->CopyFrom(t.dims());
+auto dims_size = shape.dims_size();
auto ret = shape_hints_onnx_.emplace(t.name(), std::move(shape));
shape_hints_max_bs->emplace(
std::piecewise_construct,
std::forward_as_tuple(ret.first->first),
std::forward_as_tuple(
std::vector<TensorBoundShape::DimType>(
-// NOLINTNEXTLINE(bugprone-use-after-move)
-shape.dims_size(), TensorBoundShape_DimType_CONSTANT),
+dims_size, TensorBoundShape_DimType_CONSTANT),
ret.first->second));
// Feed into workspace as CPU Tensors
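
The hunk above is representative of the use-after-move fixes in this PR: the needed value (dims_size) is read from shape before std::move(shape) hands it to emplace. A standalone sketch of the same pattern, with hypothetical names (Shape, insertHint) rather than the actual caffe2 types, might look like:

#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the protobuf TensorShape used in the hunk above.
struct Shape {
  std::vector<long> dims;
  int dims_size() const { return static_cast<int>(dims.size()); }
};

void insertHint(std::map<std::string, Shape>& hints, std::string name, Shape shape) {
  // Read the size before the move; after std::move(shape) the object is in a
  // moved-from state and shape.dims_size() would trigger bugprone-use-after-move.
  const int dims_size = shape.dims_size();
  hints.emplace(std::move(name), std::move(shape));
  // Use the saved value instead of the moved-from object.
  std::vector<int> dim_types(dims_size, /*fill=*/0);
  (void)dim_types;
}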

View File

@@ -329,14 +329,13 @@ class ChunkDataset final
ChunkSampler chunk_sampler,
ExampleSampler example_sampler,
ChunkDatasetOptions options,
-// NOLINTNEXTLINE(modernize-pass-by-value)
std::function<void(UnwrappedBatchType&)> preprocessing_policy =
std::function<void(UnwrappedBatchType&)>())
: chunk_reader_(std::move(chunk_reader)),
chunk_sampler_(std::move(chunk_sampler)),
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
-preprocessing_policy_(preprocessing_policy),
+preprocessing_policy_(std::move(preprocessing_policy)),
quit_worker_(false),
running_preloaders_(0),
load_checkpoint_(false) {}

View File

@@ -22,11 +22,10 @@ void check_single_result(
namespace torch {
namespace autograd {
-// NOLINTNEXTLINE(modernize-pass-by-value)
CppFunctionTensorPreHook::CppFunctionTensorPreHook(
-const std::shared_ptr<hooks_list>& hooks,
+std::shared_ptr<hooks_list> hooks,
int value_idx)
-: hooks_(hooks), value_idx_(value_idx) {}
+: hooks_(std::move(hooks)), value_idx_(value_idx) {}
variable_list CppFunctionTensorPreHook::operator()(
const variable_list& values) {

View File

@@ -10,9 +10,7 @@ using hooks_list =
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
struct CppFunctionTensorPreHook : public FunctionPreHook {
-CppFunctionTensorPreHook(
-const std::shared_ptr<hooks_list>& hooks,
-int value_idx);
+CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks, int value_idx);
variable_list operator()(const variable_list& values) override;
std::shared_ptr<hooks_list> hooks_;

View File

@@ -930,7 +930,6 @@ static variable_list call_function(
});
if (has_post_hooks) {
-// NOLINTNEXTLINE(bugprone-use-after-move)
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;

View File

@@ -109,7 +109,7 @@ inline bool isFwGradDefinedTensorList(const at::ITensorListRef& variables) {
}
inline bool isFwGradDefinedTensorList(
-const c10::List<c10::optional<at::Tensor>> li) {
+const c10::List<c10::optional<at::Tensor>>& li) {
bool ret = false;
for (auto i : c10::irange(li.size())) {
auto t = li.get(i);

View File

@@ -41,7 +41,7 @@ struct InputMetadata {
InputMetadata() = default;
InputMetadata(
-const at::TensorOptions options,
+const at::TensorOptions& options,
MetadataShape input_shape,
bool is_tensor_subclass)
: options_{options},

View File

@@ -160,9 +160,8 @@ class TORCH_API DistEngine {
// Guard to clean up resources once the backward pass is done.
class BackwardPassCleanupGuard {
public:
-// NOLINTNEXTLINE(modernize-pass-by-value)
-explicit BackwardPassCleanupGuard(const ContextPtr& autogradContext)
-: autogradContext_(autogradContext) {}
+explicit BackwardPassCleanupGuard(ContextPtr autogradContext)
+: autogradContext_(std::move(autogradContext)) {}
~BackwardPassCleanupGuard() {
DistEngine::getInstance().cleanupBackwardPass(autogradContext_);

View File

@@ -82,11 +82,9 @@ TORCH_API extern mutexType currentStateStackEntryMutex;
class StateStackEntry {
public:
StateStackEntry(
-// NOLINTNEXTLINE(modernize-pass-by-value)
std::shared_ptr<StateStackEntry> prevPtr,
-// NOLINTNEXTLINE(modernize-pass-by-value)
std::shared_ptr<State> statePtr)
-: prevPtr_(prevPtr), statePtr_(statePtr) {}
+: prevPtr_(std::move(prevPtr)), statePtr_(std::move(statePtr)) {}
static void pushRange(std::shared_ptr<State> profilerProcessGlobalStatePtr);
static std::shared_ptr<State> popRange();

View File

@@ -314,8 +314,8 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
// TODO: To achieve better performance we can have a pipe pool per
// client that can be configured using RpcBackendOptions.
struct ClientPipe {
-// NOLINTNEXTLINE(modernize-pass-by-value)
-explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe) : pipe_(pipe) {}
+explicit ClientPipe(std::shared_ptr<tensorpipe::Pipe> pipe)
+: pipe_(std::move(pipe)) {}
std::shared_ptr<tensorpipe::Pipe> pipe_;
mutable std::mutex mutex_;
bool inError_{false};
@@ -359,11 +359,10 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
struct TimeoutMessageMetadata {
TimeoutMessageMetadata(
uint64_t messageId_,
-// NOLINTNEXTLINE(modernize-pass-by-value)
std::shared_ptr<AtomicJitFuture> responseFuture_,
std::chrono::milliseconds timeout_)
: messageId(messageId_),
-responseFuture(responseFuture_),
+responseFuture(std::move(responseFuture_)),
timeout(timeout_) {}
uint64_t messageId;
std::shared_ptr<AtomicJitFuture> responseFuture;

View File

@@ -62,10 +62,9 @@ class CompilationUnit {
class TORCH_API Module {
public:
Module(
-// NOLINTNEXTLINE(modernize-pass-by-value)
c10::intrusive_ptr<c10::ivalue::Object> object,
std::shared_ptr<CompilationUnit> cu)
-: object_(object), cu_(std::move(cu)) {}
+: object_(std::move(object)), cu_(std::move(cu)) {}
Module() = default;
Method get_method(const std::string& method_name) const;
template <typename... Types>

View File

@@ -75,7 +75,7 @@ c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) {
}
c10::optional<BackendDevice> GetBackendDevice(
-const c10::optional<c10::Device> device) {
+const c10::optional<c10::Device>& device) {
if (device) {
return c10::make_optional(atenDeviceToBackendDevice(*device));
}

View File

@@ -560,7 +560,7 @@ void LazyGraphExecutor::Async::Wait() {
}
}
-bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr tensor) const {
+bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr& tensor) const {
return tensor->GetIrValue()->op() != ltc_not_supported;
}

View File

@@ -348,7 +348,7 @@ class TORCH_API LazyGraphExecutor {
std::vector<BackendDataPtr> parameters_data;
};
-virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const;
+virtual bool ShouldSyncTensor(const LazyTensorPtr& tensor) const;
SyncTensorCollection CollectSyncTensors(
const std::vector<LazyTensorPtr>& tensors,

View File

@@ -54,7 +54,7 @@ static const char* backend_to_string(const at::Backend& backend) {
}
}
-std::string options_to_string(const at::TensorOptions options) {
+std::string options_to_string(const at::TensorOptions& options) {
std::ostringstream ss;
ss << backend_to_string(options.backend()) << "."
<< toString(at::typeMetaToScalarType(options.dtype())) << "Tensor";

View File

@@ -8,7 +8,7 @@
namespace torch {
namespace utils {
-std::string options_to_string(const at::TensorOptions options);
+std::string options_to_string(const at::TensorOptions& options);
std::string type_to_string(const at::DeprecatedTypeProperties& type);
at::TensorOptions options_from_string(const std::string& str);