mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit e525f433e1.
Original PR: #85849
Fixes #ISSUE_NUMBER
In addition to reverting the revert, this PR:
- defines the virtual destructor of `FunctionPreHook` inline in the header. Why? Presumably the internal build includes this header but does not compile `function_hooks.cpp` (where the virtual destructor was previously defined) into the same build target, so the destructor's definition has to come from the header itself.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92559
Approved by: https://github.com/albanD
29 lines
664 B
C++
29 lines
664 B
C++
#pragma once
|
|
|
|
#include <ATen/Tensor.h>
|
|
#include <torch/csrc/Export.h>
|
|
#include <vector>
|
|
|
|
// Hooks that are called on gradients.
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
using Variable = at::Tensor;
|
|
using variable_list = std::vector<Variable>;
|
|
|
|
struct TORCH_API FunctionPreHook {
|
|
virtual ~FunctionPreHook() = default;
|
|
virtual variable_list operator()(const variable_list& grads) = 0;
|
|
};
|
|
|
|
struct TORCH_API FunctionPostHook {
|
|
virtual ~FunctionPostHook() = default;
|
|
virtual variable_list operator()(
|
|
const variable_list& outputs /* grad_inputs */,
|
|
const variable_list& inputs /* grad_outputs */) = 0;
|
|
};
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|