mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301 Approved by: https://github.com/ezyang, https://github.com/r-barnes
19 lines
562 B
C++
19 lines
562 B
C++
#pragma once
|
|
// This file is temporary until native_functions.yaml and derivatives.yaml are
// merged. Ideally, all of this should go into native_functions.yaml.
|
|
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/csrc/jit/api/module.h>
|
|
#include <optional>
|
|
|
|
namespace torch::jit {
|
|
struct GradientPair {
|
|
std::shared_ptr<Graph> forward;
|
|
std::shared_ptr<Graph> backward;
|
|
};
|
|
|
|
TORCH_API std::optional<GradientPair> gradientInfoForSchema(
|
|
const FunctionSchema& schema);
|
|
TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
|
|
} // namespace torch::jit
|