mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR introduces some modifications: 1. We identify some const function parameters that can be passed by reference and add the reference. 2. We find more opportunities for passing by value and change them accordingly. 3. Some use-after-move errors are fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95942 Approved by: https://github.com/Skylion007
32 lines
878 B
C++
32 lines
878 B
C++
#pragma once
|
|
#include <torch/csrc/autograd/function_hook.h>

#include <functional>
#include <memory>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
using hooks_list =
|
|
std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
|
|
|
|
struct CppFunctionTensorPreHook : public FunctionPreHook {
|
|
CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks, int value_idx);
|
|
variable_list operator()(const variable_list& values) override;
|
|
|
|
std::shared_ptr<hooks_list> hooks_;
|
|
int value_idx_;
|
|
};
|
|
|
|
struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
|
|
CppFunctionSingleTensorPreHook(
|
|
std::function<at::TensorBase(const at::TensorBase&)> hook,
|
|
int value_idx);
|
|
variable_list operator()(const variable_list& values) override;
|
|
|
|
std::function<at::TensorBase(const at::TensorBase&)> hook_;
|
|
int value_idx_;
|
|
};
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|