mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147891 Approved by: https://github.com/jansel ghstack dependencies: #147242, #147796, #147804
30 lines
923 B
C++
30 lines
923 B
C++
#include <torch/csrc/autograd/engine.h>
|
|
#include <torch/csrc/dynamo/compiled_autograd.h>
|
|
|
|
namespace torch::dynamo::autograd {
|
|
|
|
std::unique_ptr<PyCompilerInterface> kActivePyCompilerInterface;
|
|
|
|
const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface() {
|
|
TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
|
|
return kActivePyCompilerInterface;
|
|
}
|
|
|
|
PyCompilerGuard::PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
kActivePyCompilerInterface == nullptr && impl != nullptr);
|
|
kActivePyCompilerInterface = std::move(impl);
|
|
}
|
|
|
|
PyCompilerGuard::~PyCompilerGuard() {
|
|
TORCH_INTERNAL_ASSERT(kActivePyCompilerInterface != nullptr);
|
|
kActivePyCompilerInterface.reset();
|
|
}
|
|
|
|
std::vector<std::optional<InputMetadata>> get_input_metadata(
|
|
const edge_list& edges) {
|
|
return torch::autograd::collect_input_metadata(edges);
|
|
}
|
|
|
|
} // namespace torch::dynamo::autograd
|