mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Mostly in torch/csrc/jit/runtime and in `ATen/cuda/` Pull Request resolved: https://github.com/pytorch/pytorch/pull/110314 Approved by: https://github.com/seemethere
19 lines
573 B
C++
19 lines
573 B
C++
#pragma once
|
|
// This file is temporary until native_functions.yaml and derivatives.yaml are
// merged. Ideally, everything declared here should move into
// native_functions.yaml.
|
|
|
|
#include <c10/util/Optional.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/csrc/jit/api/module.h>
|
|
|
|
namespace torch::jit {
|
|
struct GradientPair {
|
|
std::shared_ptr<Graph> forward;
|
|
std::shared_ptr<Graph> backward;
|
|
};
|
|
|
|
TORCH_API c10::optional<GradientPair> gradientInfoForSchema(
|
|
const FunctionSchema& schema);
|
|
TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
|
|
} // namespace torch::jit
|