mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit 3f655277d4.
Reverted https://github.com/pytorch/pytorch/pull/107063 on behalf of https://github.com/ZainRizvi due to Diff train weirdness. Need to temporarily revert this PR and will re-land it soon afterwards ([comment](https://github.com/pytorch/pytorch/pull/107063#issuecomment-1690799057))
39 lines
1.2 KiB
C++
39 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/autograd/function_hook.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
struct PyFunctionTensorPreHook : public FunctionPreHook {
|
|
PyFunctionTensorPreHook(PyObject* dict, int value_idx);
|
|
~PyFunctionTensorPreHook() override;
|
|
variable_list operator()(const variable_list& values) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
int value_idx;
|
|
};
|
|
|
|
struct PyFunctionPreHook : public FunctionPreHook {
|
|
PyFunctionPreHook(PyObject* dict);
|
|
~PyFunctionPreHook() override;
|
|
variable_list operator()(const variable_list& values) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
struct PyFunctionPostHook : public FunctionPostHook {
|
|
PyFunctionPostHook(PyObject* dict);
|
|
~PyFunctionPostHook() override;
|
|
variable_list operator()(
|
|
const variable_list& outputs,
|
|
const variable_list& inputs) override;
|
|
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
|
PyObject* dict;
|
|
};
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|