pytorch/torch/csrc/autograd/function_hook.h
soulitzer 1bc60c6b31 [reland] Improve hooks ordering behavior (#92559)
This reverts commit e525f433e1.

Original PR:  #85849
Fixes #ISSUE_NUMBER

In addition to reverting the revert, this PR:
- defines the virtual destructor of FunctionPreHook in the header. Why? Presumably the internal build includes the header from somewhere but does not build or link function_hooks.cpp (where the virtual destructor was previously defined) into the same target.
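The pattern behind this fix can be shown with a minimal, self-contained sketch. All names below are hypothetical, not PyTorch's. The abstract hook base defaults its virtual destructor inside the class body, as `FunctionPreHook` now does. A function defined in-class is implicitly inline, so every translation unit that includes the header can emit the destructor without linking a separate `.cpp`. Under the Itanium C++ ABI, an out-of-line virtual destructor may also serve as the class's "key function", which would place the vtable only in that `.cpp`'s object file; that is a plausible reading of the link failure, not something the commit confirms. The virtual destructor also ensures that deleting a derived hook through a base pointer runs the derived destructor.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a hook base class declared in a header.
// The virtual destructor is defaulted in-class, so its definition is
// implicitly inline and available in every translation unit that
// includes the header; no separate .cpp file must be linked for it.
struct HookBase {
  virtual ~HookBase() = default;
  virtual int operator()(int grad) = 0;
};

// Hypothetical derived hook that counts how often it is destroyed.
struct CountingHook : HookBase {
  explicit CountingHook(int* destroyed) : destroyed_(destroyed) {}
  ~CountingHook() override { ++*destroyed_; }
  int operator()(int grad) override { return grad * 2; }

 private:
  int* destroyed_;
};

// Owns the hook only through a base-class pointer, the way code that sees
// just the abstract interface would. Because ~HookBase is virtual, the
// unique_ptr's destruction also runs ~CountingHook.
int run_and_destroy(int grad, int* destroyed) {
  std::unique_ptr<HookBase> hook = std::make_unique<CountingHook>(destroyed);
  return (*hook)(grad);
}
```

Calling `run_and_destroy(3, &counter)` returns 6 and leaves `counter` incremented once, showing that the derived destructor ran even though the owner only knew `HookBase`.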

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92559
Approved by: https://github.com/albanD
2023-01-19 08:17:32 +00:00


#pragma once

#include <ATen/Tensor.h>
#include <torch/csrc/Export.h>
#include <vector>

// A hook that's called on gradients

namespace torch {
namespace autograd {

using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

struct TORCH_API FunctionPreHook {
  virtual ~FunctionPreHook() = default;
  virtual variable_list operator()(const variable_list& grads) = 0;
};

struct TORCH_API FunctionPostHook {
  virtual ~FunctionPostHook() = default;
  virtual variable_list operator()(
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) = 0;
};

} // namespace autograd
} // namespace torch
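The sketch below shows how the two interfaces might be implemented. It copies the struct shapes from the header but, to compile without ATen, assumes `Variable` is aliased to `double` and drops `TORCH_API`. `ScaleGradsHook`, `ClampGradsHook`, and `run_node` are hypothetical names. `run_node` is a toy stand-in for a backward node. It calls pre hooks on the incoming gradients before the node runs, then post hooks on `(grad_inputs, grad_outputs)` afterwards, following the parameter comments in the header. It is not PyTorch's actual engine code.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Assumed stand-ins: in the real header, Variable is at::Tensor.
using Variable = double;
using variable_list = std::vector<Variable>;

// Same shape as the interfaces in function_hook.h (TORCH_API omitted).
struct FunctionPreHook {
  virtual ~FunctionPreHook() = default;
  virtual variable_list operator()(const variable_list& grads) = 0;
};

struct FunctionPostHook {
  virtual ~FunctionPostHook() = default;
  virtual variable_list operator()(
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) = 0;
};

// Hypothetical pre hook: scales every incoming gradient by a constant.
struct ScaleGradsHook : FunctionPreHook {
  explicit ScaleGradsHook(double factor) : factor_(factor) {}
  variable_list operator()(const variable_list& grads) override {
    variable_list out;
    out.reserve(grads.size());
    for (Variable g : grads) out.push_back(g * factor_);
    return out;
  }

 private:
  double factor_;
};

// Hypothetical post hook: clamps each computed gradient to [-limit, limit]
// and ignores the incoming gradients it is also given.
struct ClampGradsHook : FunctionPostHook {
  explicit ClampGradsHook(double limit) : limit_(limit) {}
  variable_list operator()(const variable_list& outputs,
                           const variable_list& /*inputs*/) override {
    variable_list out = outputs;
    for (Variable& g : out) {
      if (g > limit_) g = limit_;
      if (g < -limit_) g = -limit_;
    }
    return out;
  }

 private:
  double limit_;
};

// Hypothetical driver: runs pre hooks in order (each sees the previous
// hook's result), applies a toy backward function (negation), then runs
// post hooks in order. The ordering is an assumption for illustration.
variable_list run_node(
    variable_list grad_outputs,
    const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks,
    const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks) {
  for (const auto& hook : pre_hooks) grad_outputs = (*hook)(grad_outputs);
  variable_list grad_inputs;
  for (Variable g : grad_outputs) grad_inputs.push_back(-g);  // toy backward
  for (const auto& hook : post_hooks) {
    grad_inputs = (*hook)(grad_inputs, grad_outputs);
  }
  return grad_inputs;
}
```

With a `ScaleGradsHook(2.0)` pre hook and a `ClampGradsHook(3.0)` post hook, gradients `{1.0, -5.0}` become `{2.0, -10.0}`, then `{-2.0, 10.0}` after the toy backward, then `{-2.0, 3.0}` after clamping. Because each hook takes its inputs by `const&` and returns a fresh list, hooks compose by simply feeding one result into the next.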