pytorch/torch/csrc/jit/script/compilation_unit.h
Edward Yang 97e1f07ffc Replace AT_CHECK with TORCH_CHECK [shard 10/10]
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20436

Reviewed By: jerryzh168

Differential Revision: D15318926

fbshipit-source-id: 71a43070cc50cc174f703ebc595f1d87c6fc1e91
2019-05-15 07:35:37 -07:00

#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace script {
struct Def;
struct SugaredValue;
struct Function;
struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>;
using Self = std::function<std::shared_ptr<SugaredValue>(Value*)>;
using Kwargs = std::unordered_map<std::string, IValue>;
// A Function is a pure Graph with no implicit `self` object bound.
// It contains schema information, and the executor that manages the
// execution of the function. script::Method is a wrapper around an
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
Function(
std::string name,
bool optimize,
std::shared_ptr<Graph> graph,
std::function<void(Function&)> function_creator)
: name_(std::move(name)),
graph_(std::move(graph)),
optimize_(optimize),
function_creator_(std::move(function_creator)) {}
void run(Stack& stack) {
get_executor().run(stack);
}
void run(Stack&& stack) {
run(stack);
}
IValue operator()(
std::vector<IValue> stack,
const Kwargs& kwargs = Kwargs()) {
getSchema().checkAndNormalizeInputs(stack, kwargs);
run(stack);
return stack.front();
}
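// Example (sketch; the tensor argument and kwarg name are illustrative):
// operator() normalizes positional and keyword arguments against the schema
// and then runs the GraphExecutor on the resulting stack:
//
//   Function& fn = cu.get_function("my_fn");
//   IValue out = fn({torch::ones({2, 2})}, {{"alpha", IValue(1.0)}});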
std::shared_ptr<Graph> graph() const {
return graph_;
}
const std::string& name() const {
return name_;
}
// if this isn't yet defined, run its function_creator to define it
void ensure_defined();
size_t num_inputs() const {
return graph()->inputs().size();
}
Function& setSchema(FunctionSchema schema) {
schema_ = make_unique<FunctionSchema>(std::move(schema));
return *this;
}
const FunctionSchema& getSchema() const {
if (schema_ == nullptr) {
schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
}
return *schema_;
}
std::string pretty_print_schema() const {
AT_ASSERT(schema_);
std::stringstream ss;
ss << *schema_;
return ss.str();
}
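// Example (sketch; assumes a schema parser such as parseSchema() from
// script/schema_parser.h is available, and the schema string is illustrative):
// attaching an explicit schema instead of relying on the generated default:
//
//   fn.setSchema(parseSchema("my_fn(Tensor x, Scalar alpha) -> Tensor"));
//   std::cout << fn.pretty_print_schema() << "\n";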
GraphExecutorState getDebugState() {
return get_executor().getDebugState();
}
bool is_optimized() const {
return optimize_;
}
void check_single_output() {
TORCH_CHECK(
graph()->outputs().size() == 1,
"Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
}
GraphExecutor& get_executor() {
std::call_once(executor_init_, [&] {
check_single_output();
executor_ = GraphExecutor(graph(), optimize_);
});
return executor_;
}
// returns nullptr and fills in failure_messages if the callee does not
// match the function's schema
// TODO: defined in module.cpp, move to compilation_unit.cpp
Value* try_emit_call(
Graph& graph,
const SourceRange& loc,
c10::optional<NamedValue> self,
ArrayRef<NamedValue> args,
ArrayRef<NamedValue> kwargs,
std::stringstream& failure_messages,
bool conv_tensors_to_nums);
Value* emit_call(
Graph& graph,
const SourceRange& loc,
ArrayRef<NamedValue> args,
ArrayRef<NamedValue> kwargs);
private:
static FunctionSchema defaultSchemaFor(const Function& function) {
std::vector<Argument> args;
std::vector<Argument> returns;
Graph& g = *function.graph();
size_t num_inputs = function.num_inputs();
for (size_t i = 0; i < num_inputs; ++i) {
const Value* v = g.inputs().at(i);
std::string name = v->hasUniqueName() ? v->uniqueNameBase()
: ("argument_" + std::to_string(i));
args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
}
for (size_t i = 0; i < g.outputs().size(); ++i) {
returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
}
return {function.name(), "", std::move(args), std::move(returns)};
}
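// For example (illustrative, assuming Tensor-typed inputs): given a graph
// with a named input `x`, an unnamed second input, and a single output, the
// generated default schema would print roughly as:
//
//   my_fn(Tensor x, Tensor argument_1) -> Tensor
//
// All types are unshaped because shape information is dropped via
// unshapedType().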
std::string name_;
std::shared_ptr<Graph> graph_; // for debugging and for inlining
bool optimize_;
GraphExecutor executor_; // for execution
std::once_flag executor_init_;
// an optional function that actually creates this Function the first time
// it is needed (see ensure_defined()). this is used by the compiler so
// that it can construct functions out of order
std::function<void(Function&)> function_creator_;
// if absent, then we generate a default schema based on the graph
// mutable because getSchema caches the default schema if one is requested
// before a call to setSchema
mutable std::unique_ptr<FunctionSchema> schema_;
};
// A CompilationUnit is a list of named script::Functions
// with helper methods to iterate the list, or invoke the function.
// Classes have a CompilationUnit holding the class methods,
// and Modules also have a CompilationUnit holding the Functions that
// are used to implement their Methods.
struct TORCH_API CompilationUnit {
// constructor that takes a set of functions to compile using the native resolver
explicit CompilationUnit(const std::string& source);
CompilationUnit() = default;
std::shared_ptr<Function> find_function(const std::string& name) const {
auto it = dict_.find(name);
if (it == dict_.end())
return nullptr;
return functions_[it->second];
}
Function& get_function(const std::string& name) const {
if (auto r = find_function(name))
return *r;
AT_ERROR("attempted to get undefined function ", name);
}
void set_optimized(bool o) {
optimized_ = o;
}
bool is_optimized() const {
return optimized_;
}
// for historic reasons, these are defined in compiler.cpp
void define(
const std::vector<Def>& definitions,
const std::vector<ResolverPtr>&
resolvers, /* determine how we handle free
variables in each definition */
// if non-null, the first argument to each def is bound to this value
const Self& self);
// same as above but parse the definitions from source
void define(
const std::string& source,
const ResolverPtr& resolver,
const Self& self);
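// Example (sketch; the resolver construction and source are illustrative):
// defining a function from a source string with an explicit resolver and no
// bound `self`:
//
//   ResolverPtr resolver = /* e.g. a resolver over the caller's globals */;
//   cu.define("def f(x):\n    return x + 1\n", resolver, nullptr);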
std::shared_ptr<Function> create_function(
std::string name,
std::shared_ptr<Graph> graph) {
auto fn = std::make_shared<Function>(
std::move(name), is_optimized(), std::move(graph), nullptr);
register_function(fn);
return fn;
}
const std::vector<std::shared_ptr<Function>>& get_functions() const {
return functions_;
}
/// Run a method from this compilation.
///
/// For example:
/// @code
/// IValue output = cu.run_method("relu_script", a, b);
/// @endcode
///
/// To compile a module from a source string, see torch::jit::compile
///
/// @param method_name The name of the method to run
/// @param args Arguments to be passed to the method
/// @return An IValue containing the return value (or values if it is a tuple)
/// from the method
template <typename... Types>
IValue run_method(const std::string& method_name, Types&&... args) {
return get_function(method_name)({IValue(std::forward<Types>(args))...});
}
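// Example (sketch; the TorchScript source and argument values are
// illustrative): compiling with the native resolver via the source-string
// constructor, then invoking the function by name:
//
//   CompilationUnit cu(R"JIT(
//     def add(a, b):
//         return a + b
//   )JIT");
//   IValue out = cu.run_method("add", torch::ones({2}), torch::ones({2}));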
void drop_all_functions() {
dict_.clear();
functions_.clear();
}
/**
* Register a class as being owned by this compilation unit.
*/
void register_class(ClassTypePtr classType) {
classes_.push_back(std::move(classType));
}
ClassTypePtr get_class(const c10::QualifiedName& name) const {
for (const auto& cls : classes_) {
if (cls->qualname() == name.qualifiedName()) {
return cls;
}
}
return nullptr;
}
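// Example (sketch; the class type and qualified name are illustrative):
//
//   cu.register_class(myClassType);
//   ClassTypePtr cls = cu.get_class(c10::QualifiedName("__torch__.MyClass"));
//   TORCH_CHECK(cls != nullptr);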
/**
* Python compilation unit methods
*
* Right now there is a single compilation unit that owns all ScriptClasses
* defined in Python. Below are accessor methods for it.
*/
static const CompilationUnit& _get_python_cu_const() {
return _get_python_cu();
}
static CompilationUnit& _get_python_cu() {
static CompilationUnit pyCu;
return pyCu;
}
// For testing: clear all Python-defined classes to ensure that unit tests
// have isolation.
static void _clear_python_cu() {
_get_python_cu().classes_.clear();
}
private:
Function& register_function(std::shared_ptr<Function> fn) {
TORCH_CHECK(
0 == dict_.count(fn->name()),
"method '",
fn->name(),
"' already defined.");
functions_.emplace_back(std::move(fn));
dict_[functions_.back()->name()] = functions_.size() - 1;
return *functions_.back();
}
std::vector<std::shared_ptr<Function>> functions_;
// for fast lookup
std::unordered_map<std::string, size_t> dict_;
bool optimized_ = true;
// [class ownership] Right now there are two relationships between classes
// and compilation units:
// 1. Classes have compilation units internally that hold their methods.
// 2. On load, the TypePtrs of any imported classes are owned by the main
// module's compilation unit.
std::vector<ClassTypePtr> classes_;
};
} // namespace script
} // namespace jit
} // namespace torch