As this is the oldest gcc that is fully compatible with the C++17 standard.
- Replace a number of conditional version checks with the simpler `if(CMAKE_COMPILER_IS_GNUCXX)` or `append_cxx_flag_if_supported`.
- As the `-Wsuggest-override` condition was hidden behind an incorrect guard, add the missing `override` keywords to `torch::autograd::PyFunctionTensorPostAccGradHooks::apply_with_saved`, `caffe2::python::TensorFeeder::Feed` and `caffe2::NetObserverReporterPrint::report`.

Fixes https://github.com/pytorch/pytorch/issues/101839
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112858
Approved by: https://github.com/Skylion007, https://github.com/albanD
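For readers unfamiliar with the warning: `-Wsuggest-override` flags virtual member functions that override a base-class function without being marked `override`. The sketch below uses made-up types (`BaseHook`, `PyHook`) rather than the real PyTorch classes, but it illustrates the kind of change the second bullet describes for `apply_with_saved`, `Feed` and `report`.

```cpp
// Minimal sketch with hypothetical types; compile with:
//   g++ -std=c++17 -Wsuggest-override -c hooks_example.cpp
struct BaseHook {
  virtual ~BaseHook() = default;
  virtual void apply_with_saved(int& value) = 0;
};

struct PyHook : public BaseHook {
  // Declared without `override`, this still overrides the base method, but
  // gcc warns that it "can be marked override" under -Wsuggest-override, and
  // a later change to the base signature would silently stop it overriding:
  //   void apply_with_saved(int& value);
  //
  // Marked with `override`, the warning disappears and any signature
  // mismatch with BaseHook becomes a hard compile error instead.
  void apply_with_saved(int& value) override {}
};
```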
58 lines
2.0 KiB
C++
#pragma once

#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>

namespace torch::dynamo::autograd {
class SwapSavedVariables;
} // namespace torch::dynamo::autograd

namespace torch {
namespace autograd {

struct PyFunctionTensorPreHook : public FunctionPreHook {
  PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
  ~PyFunctionTensorPreHook() override;
  variable_list operator()(const variable_list& values) override;
  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
  PyObject* dict;
  size_t value_idx;
};

struct PyFunctionPreHook : public FunctionPreHook {
  PyFunctionPreHook(PyObject* dict);
  ~PyFunctionPreHook() override;
  variable_list operator()(const variable_list& values) override;
  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
  PyObject* dict;
};

struct PyFunctionPostHook : public FunctionPostHook {
  PyFunctionPostHook(PyObject* dict);
  ~PyFunctionPostHook() override;
  variable_list operator()(
      const variable_list& outputs,
      const variable_list& inputs) override;
  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
  PyObject* dict;
};

// PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
// and it is understandable if you are confused by why it's a subclass. We are
// simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
// above to easily enroll into existing infrastructure.
struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
  PyFunctionTensorPostAccGradHooks(PyObject* dict);
  ~PyFunctionTensorPostAccGradHooks() override;
  void operator()(const Variable& tensor) override;
  void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
  void apply_with_saved(
      Variable& tensor,
      torch::dynamo::autograd::SwapSavedVariables& saved) override;
  PyObject* dict;
};

} // namespace autograd
} // namespace torch