pytorch/torch/csrc/jit/interpreter.h
Adam Paszke 00df09b65d Change specialization rules in GraphExecutors (#10977)
Summary:
**Review last commit only.** Stacked on top of #10949.

This commit fixes a number of issues connected to caching
the differentiability status of graphs inside graph executors,
and changes the rules for optimizing differentiable subgraphs.
Previously, each differentiable subgraph was instantiated as a
separate graph executor; now they are simply more heavily
optimized graph regions, and graph executors are instantiated
only for their backward passes.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10977

Differential Revision: D9600626

Pulled By: apaszke

fbshipit-source-id: dad09a0f586e396afbd5406319c1cd54fbb8a3d3
2018-08-30 22:11:01 -07:00


#pragma once
#include <memory>
#include <vector>
#include "ATen/core/optional.h"
#include "torch/csrc/WindowsTorchApiMacro.h"
namespace at {
struct Tensor;
}
namespace torch { namespace jit {
// The interpreter runs Graphs with Tensor inputs and Tensor outputs.
// A separate component in autograd handles unwrapping and wrapping
// Variable objects for use in the interpreter.
struct Node;
struct GraphExecutor;
struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;
struct IValue;
using Stack = std::vector<IValue>;
struct TORCH_API Code {
  Code()
      : pImpl(nullptr) {}
  Code(std::shared_ptr<Graph>& graph);
  ~Code();

  // Graph executors instantiated for the backward of differentiable
  // subgraphs contained in this code (see the commit message above).
  const std::vector<GraphExecutor*>& grad_executors();

  explicit operator bool() const {
    return pImpl != nullptr;
  }

private:
  std::shared_ptr<CodeImpl> pImpl;
  friend struct InterpreterStateImpl;
  friend std::ostream& operator<<(std::ostream& out, const Code& code);
};
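
// A hedged usage sketch, not part of this header: buildGraph() is a
// hypothetical helper standing in for any code that produces a
// std::shared_ptr<Graph>.
//
//   std::shared_ptr<Graph> graph = buildGraph();
//   Code code(graph);   // compile the graph into interpreter code
//   if (code) {         // operator bool(): true once pImpl is non-null
//     // code is ready to be run by an InterpreterState (see below)
//   }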
struct InterpreterState {
  InterpreterState(const Code& code);
  // Advance the interpreter state by running one stage, leaving that stage's
  // outputs on the stack and suspending the computation.
  // Calling this function again continues the computation where it left off.
  void runOneStage(Stack& stack);
  ~InterpreterState();
  // Create a copy of the InterpreterState with its current state;
  // used when retain_graph=True so that stages can be re-run.
  InterpreterState clone() const;

private:
  InterpreterState(InterpreterStateImpl* pImpl);
  std::shared_ptr<InterpreterStateImpl> pImpl;
};
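
// A hedged usage sketch, not part of this header: `code` is assumed to be a
// valid Code object and `inputs` a std::vector<IValue>; the names are
// illustrative only.
//
//   InterpreterState state(code);
//   Stack stack(inputs.begin(), inputs.end()); // push the stage's inputs
//   state.runOneStage(stack);   // runs one stage, then suspends; the stack
//                               // now holds that stage's outputs
//   InterpreterState saved = state.clone();    // snapshot the current state,
//                               // e.g. so a stage can be re-run when
//                               // retain_graph=True
//   // With the next stage's inputs pushed onto the stack, calling
//   // runOneStage again continues where the computation left off:
//   state.runOneStage(stack);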
}} // namespace torch::jit