mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Now that there will be two types of Python function prehooks, I prefer to have the PyFunction hook that takes all grad_outputs and returns all grad_inputs be the more "canonical" one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83225 Approved by: https://github.com/albanD
29 lines
744 B
C++
29 lines
744 B
C++
#pragma once
|
|
|
|
#include <torch/csrc/autograd/function_hook.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
struct PyFunctionTensorPreHook : public FunctionPreHook {
|
|
PyFunctionTensorPreHook(PyObject* dict, int value_idx);
|
|
~PyFunctionTensorPreHook() override;
|
|
variable_list operator()(const variable_list& values) override;
|
|
PyObject* dict;
|
|
int value_idx;
|
|
};
|
|
|
|
struct PyFunctionPostHook : public FunctionPostHook {
|
|
PyFunctionPostHook(PyObject* dict);
|
|
~PyFunctionPostHook() override;
|
|
variable_list operator()(
|
|
const variable_list& outputs,
|
|
const variable_list& inputs) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|