pytorch/torch/csrc/dynamo/compiled_autograd.cpp

30 lines
923 B
C++

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
namespace torch::dynamo::autograd {

// The PyCompilerInterface currently installed by a live PyCompilerGuard, or
// nullptr when none is installed. Ownership lives here; the guard below only
// moves an interface in on construction and resets it on destruction.
// This is a plain global with no locking in this file.
// NOTE(review): callers presumably serialize access (e.g. under the GIL) —
// confirm before touching it from multiple threads.
std::unique_ptr<PyCompilerInterface> kActivePyCompilerInterface;

// Returns the currently installed PyCompilerInterface.
//
// The reference is to the global owner itself. The interface it points to is
// destroyed when the PyCompilerGuard that installed it is destroyed, so do not
// retain the raw pointer beyond that guard's lifetime.
// Fails a TORCH_INTERNAL_ASSERT if no interface is installed.
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
  TORCH_INTERNAL_ASSERT(
      kActivePyCompilerInterface != nullptr,
      "getPyCompilerInterface() called with no active PyCompilerInterface; "
      "one must be installed with a PyCompilerGuard first");
  return kActivePyCompilerInterface;
}

// Installs `impl` as the active PyCompilerInterface for the lifetime of this
// guard, taking ownership of it.
//
// Guards do not nest: constructing one while an interface is already installed
// is an internal error, as is passing a null `impl`. The move into the global
// happens only after both checks pass, so on failure the caller's pointer is
// not consumed.
PyCompilerGuard::PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl) {
  // Checked separately (rather than as a single `&&` condition) so that an
  // assertion failure identifies which precondition was violated.
  TORCH_INTERNAL_ASSERT(
      kActivePyCompilerInterface == nullptr,
      "PyCompilerGuard: a PyCompilerInterface is already active; "
      "nested PyCompilerGuards are not supported");
  TORCH_INTERNAL_ASSERT(
      impl != nullptr, "PyCompilerGuard: impl must not be null");
  kActivePyCompilerInterface = std::move(impl);
}

// Uninstalls and destroys the interface this guard installed.
//
// NOTE(review): TORCH_INTERNAL_ASSERT reports failure by throwing; if this
// destructor is implicitly noexcept (the default unless a member says
// otherwise), a failure here ends in std::terminate rather than propagating.
PyCompilerGuard::~PyCompilerGuard() {
  TORCH_INTERNAL_ASSERT(
      kActivePyCompilerInterface != nullptr,
      "PyCompilerGuard destroyed while no PyCompilerInterface is active");
  kActivePyCompilerInterface.reset();
}

// Collects input metadata for `edges` by forwarding to
// torch::autograd::collect_input_metadata.
// Presumably one entry per edge, with std::nullopt where an edge has no
// metadata — TODO confirm against collect_input_metadata.
std::vector<std::optional<InputMetadata>> get_input_metadata(
    const edge_list& edges) {
  return torch::autograd::collect_input_metadata(edges);
}

} // namespace torch::dynamo::autograd