mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19921 ghimport-source-id: 12a4553a4a081e8a41f4ed432b4ce3dc14e4699f Differential Revision: D15125017 Pulled By: ZolotukhinM fbshipit-source-id: f7285bd1e0745dadb9cd353a5fa8a09728012a59
21 lines
601 B
C++
21 lines
601 B
C++
#pragma once
|
|
// This file is temporary until native_functions.yaml and derivatives.yaml are
// merged. Ideally this should all go into native_functions.yaml.
|
|
|
|
#include <c10/util/Optional.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/csrc/jit/script/module.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
struct GradientPair {
|
|
std::shared_ptr<Graph> forward;
|
|
std::shared_ptr<Graph> backward;
|
|
};
|
|
|
|
// Looks up a GradientPair keyed on `schema`.
//
// The result is optional, so callers must check that it holds a value before
// dereferencing it.
//
// NOTE(review): presumably the optional is empty when no gradient definition
// is known for `schema` -- confirm against the implementation.
TORCH_API c10::optional<GradientPair> gradientInfoForSchema(
    const FunctionSchema& schema);
|
|
// Boolean query keyed on a FunctionSchema.
// NOTE(review): by its name, presumably true exactly when
// gradientInfoForSchema(schema) would yield a value -- confirm against the
// definition.
TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
|
|
} // namespace jit
|
|
} // namespace torch
|