pytorch/torch/csrc/jit/graph_executor.h
Adam Paszke 00df09b65d Change specialization rules in GraphExecutors (#10977)
Summary:
**Review last commit only.** Stacked on top of #10949.

This commit fixes a number of issues connected to caching the
differentiability status of graphs inside graph executors,
and changes the rules for optimizing differentiable subgraphs.
Previously, every differentiable subgraph was instantiated as a separate
graph executor; now they are simply more heavily optimized graph regions,
and graph executors are only instantiated for their backward passes.
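
To make the change concrete, here is a rough sketch of the new split, hedged and illustrative only (it is not the code this PR adds): autodiff splits a differentiable subgraph into a forward graph `f` and a backward graph `df`, only the backward graph is wrapped in a nested GraphExecutor, and the forward graph stays an inlined, heavier-optimized region of the enclosing graph. The `Gradient` struct and `differentiate()` come from torch/csrc/jit/autodiff.h; the exact signature of `differentiate` shown here is an assumption and may differ.

```cpp
// Sketch only: after autodiff, keep the forward half as part of the outer
// graph and give only the backward half its own GraphExecutor.
Gradient gradient = differentiate(forward_subgraph); // signature is assumed
// gradient.f  : forward graph, optimized in place within the enclosing region
// gradient.df : backward graph, run lazily through a nested executor
GraphExecutor backward_executor(gradient.df, /*optimize=*/true);
```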

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10977

Differential Revision: D9600626

Pulled By: apaszke

fbshipit-source-id: dad09a0f586e396afbd5406319c1cd54fbb8a3d3
2018-08-30 22:11:01 -07:00

#pragma once

#include <memory>
#include <unordered_map>

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/variable_tensor_list.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/argument_spec.h"

namespace torch { namespace jit {

struct GraphExecutorState;
// Notice that these structs don't manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
struct ExecutionPlanState {
  Code* code = nullptr;
  const Graph* graph = nullptr;
};

struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlanState fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
};
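
// Illustrative sketch (not part of this header): given a GraphExecutor
// `executor` (declared below), its debug state can be inspected like this,
// respecting the lifetime rule above by not calling any other GraphExecutor
// function while the state is still in use:
//
//   GraphExecutorState state = executor.getDebugState();
//   for (const auto& entry : state.execution_plans) {
//     std::cout << *entry.second.graph;  // print each specialized graph
//   }
//   // do not touch `state` after any further call on `executor`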

struct GraphExecutorImpl;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
  void run(Stack& inputs);
  explicit operator bool() const {
    return pImpl != nullptr;
  }
  std::shared_ptr<Graph> graph() const;
  std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
  GraphExecutorState getDebugState();

 private:
  std::shared_ptr<GraphExecutorImpl> pImpl;
};
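
// Illustrative usage only (a sketch, not a prescribed pattern): obtain a
// Graph elsewhere (e.g. from the tracer or the script compiler), wrap it in
// a GraphExecutor, and run it on a Stack of inputs. Graph construction is
// elided here.
//
//   std::shared_ptr<Graph> graph = ...;  // produced by tracing or scripting
//   GraphExecutor executor(graph, /*optimize=*/true);
//   Stack stack = {...};                 // push the graph's inputs
//   executor.run(stack);                 // inputs consumed; outputs left on the stack
//
// After at least one run, graphFor(inputs) returns the graph specialized for
// inputs with the same ArgumentSpec, and getDebugState() exposes the cached
// execution plans.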

// These passes need to run before it is valid to pass the graph to the
// interpreter, regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

} // namespace detail

}} // namespace torch::jit