Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: **Review last commit only.** Stacked on top of #10949. This commit fixes a number of issues connected to caching the differentiability status of graphs inside graph executors, and changes the rules for optimizing differentiable subgraphs. Previously, every one of those was instantiated as a separate graph executor; now they are simply more heavily optimized graph regions, and graph executors are instantiated only for their backward. zdevito

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10977
Differential Revision: D9600626
Pulled By: apaszke
fbshipit-source-id: dad09a0f586e396afbd5406319c1cd54fbb8a3d3
60 lines · 1.5 KiB · C++
#pragma once
#include <memory>
#include <vector>
#include "ATen/core/optional.h"

#include "torch/csrc/WindowsTorchApiMacro.h"

namespace at {
struct Tensor;
}
namespace torch { namespace jit {

// The interpreter runs Graphs with Tensor inputs and Tensor outputs.
// A separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.

struct Node;
struct GraphExecutor;
struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;
struct Node;
struct IValue;
using Stack = std::vector<IValue>;

struct TORCH_API Code {
  Code()
    : pImpl(nullptr) {}
  Code(std::shared_ptr<Graph>& graph);
  ~Code();

  const std::vector<GraphExecutor*>& grad_executors();

  explicit operator bool() const {
    return pImpl != nullptr;
  }

private:
  std::shared_ptr<CodeImpl> pImpl;
  friend struct InterpreterStateImpl;
  friend std::ostream & operator<<(std::ostream & out, const Code & code);
};

struct InterpreterState {
  InterpreterState(const Code & code);
  // Advance the interpreter state by running one stage, returning the
  // outputs for that stage and suspending the computation. Calling this
  // function again continues the computation where it left off.
  void runOneStage(Stack & stack);
  ~InterpreterState();
  // Create a copy of the InterpreterState with its current state;
  // used when retain_graph=True so that stages can be re-run.
  InterpreterState clone() const;
private:
  InterpreterState(InterpreterStateImpl * pImpl);
  std::shared_ptr<InterpreterStateImpl> pImpl;
};

}}
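
For context, here is a minimal usage sketch of the declarations above. It is not part of the header: it assumes the header is included, that a torch::jit::Graph has already been built elsewhere, and that the Stack has been filled with the graph's input IValues. runForwardOnce is a hypothetical helper written only to illustrate the Code / InterpreterState / runOneStage flow described in the comments, not an API provided by this file.

// Hypothetical illustration only; assumes this header is included and that
// `graph` was constructed elsewhere.
using namespace torch::jit;

void runForwardOnce(std::shared_ptr<Graph> graph, Stack& stack) {
  Code code(graph);              // compile the Graph for the interpreter
  if (!code) return;             // operator bool(): Code holds no implementation
  InterpreterState state(code);  // per-run interpreter state

  // On entry `stack` holds the graph inputs; runOneStage consumes them,
  // runs one stage, and leaves that stage's outputs on the stack with the
  // computation suspended so a later stage can resume on the same state.
  state.runOneStage(stack);

  // If the same stages may need to be re-run (retain_graph=True semantics),
  // a copy of the suspended state could be taken via state.clone().
}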