#include #include #include namespace { using torch::autograd::Variable; void check_single_result (Variable value, Variable result, std::string hook_name) { if (!value.defined()) { throw std::runtime_error("can't replace a empty gradient with a non-empty value"); } torch::autograd::check_variable_result(value, result, hook_name); } } namespace torch { namespace autograd { CppFunctionPreHook::CppFunctionPreHook(const std::shared_ptr &hooks, int value_idx) : hooks_(hooks) , value_idx_(value_idx) {} variable_list CppFunctionPreHook::operator()(const variable_list& values) { auto value = values[value_idx_]; for (unsigned i = 0; i < hooks_->size(); ++i) { auto &hook = (*hooks_)[i]; if (!hook) { // hook was removed continue; } auto res = hook(value); if (!res.defined()) { // Don't change gradient continue; } check_single_result(value, res, std::to_string(i)); value = res; } variable_list results(values); results[value_idx_] = value; return results; } }} // namespace torch::autograd