mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23170 Test Plan: Imported from OSS Differential Revision: D16441912 Pulled By: suo fbshipit-source-id: a33485178a329d54e41e364c4f14950f88481c55
141 lines
3.9 KiB
C++
141 lines
3.9 KiB
C++
#pragma once
|
|
#include <torch/csrc/jit/graph_executor.h>
|
|
#include <torch/csrc/jit/ir.h>
|
|
#include <torch/csrc/utils/memory.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Keyword-argument container accepted by Function::operator(): a map from
// std::string keys to IValue values (presumably keyed by argument name --
// confirm against FunctionSchema::checkAndNormalizeInputs).
using Kwargs = std::unordered_map<std::string, IValue>;
|
|
|
|
// A Function is a pure Graph with no implicit `self` object bound.
|
|
// It contains schema information, and the executor that manages the
|
|
// execution of the function. script::Method is a wrapper around a
|
|
// underlying Function that also provides a `self` object.
|
|
struct TORCH_API Function {
|
|
Function(
|
|
c10::QualifiedName name,
|
|
std::shared_ptr<Graph> graph,
|
|
std::function<void(Function&)> function_creator)
|
|
: name_(std::move(name)),
|
|
graph_(std::move(graph)),
|
|
function_creator_(std::move(function_creator)) {}
|
|
|
|
void run(Stack& stack) {
|
|
get_executor().run(stack);
|
|
}
|
|
|
|
void run(Stack&& stack) {
|
|
run(stack);
|
|
}
|
|
|
|
IValue operator()(
|
|
std::vector<IValue> stack,
|
|
const Kwargs& kwargs = Kwargs()) {
|
|
getSchema().checkAndNormalizeInputs(stack, kwargs);
|
|
run(stack);
|
|
return stack.front();
|
|
}
|
|
|
|
std::shared_ptr<Graph> graph() const {
|
|
return graph_;
|
|
}
|
|
|
|
const c10::QualifiedName& qualname() const {
|
|
return name_;
|
|
}
|
|
|
|
const std::string& name() const {
|
|
return name_.name();
|
|
}
|
|
|
|
// if this isn't yet defined, run its method_creator function
|
|
void ensure_defined();
|
|
|
|
size_t num_inputs() const {
|
|
return graph()->inputs().size();
|
|
}
|
|
|
|
Function& setSchema(FunctionSchema schema) {
|
|
schema_ = make_unique<FunctionSchema>(std::move(schema));
|
|
return *this;
|
|
}
|
|
|
|
const FunctionSchema& getSchema() const {
|
|
if (schema_ == nullptr) {
|
|
schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
|
|
}
|
|
return *schema_;
|
|
}
|
|
|
|
std::string pretty_print_schema() const {
|
|
AT_ASSERT(schema_);
|
|
std::stringstream ss;
|
|
ss << *schema_;
|
|
return ss.str();
|
|
}
|
|
|
|
GraphExecutorState getDebugState() {
|
|
return get_executor().getDebugState();
|
|
}
|
|
|
|
bool is_optimized() const {
|
|
AT_WARN(
|
|
"Function::is_optimized() is deprecated and always returns true. "
|
|
"Please use getGraphExecutorOptimize()");
|
|
return true;
|
|
}
|
|
|
|
void check_single_output() {
|
|
TORCH_CHECK(
|
|
graph()->outputs().size() == 1,
|
|
"Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
|
|
}
|
|
|
|
GraphExecutor& get_executor() {
|
|
ensure_defined();
|
|
std::call_once(executor_init_, [&] {
|
|
check_single_output();
|
|
executor_ = GraphExecutor(graph());
|
|
});
|
|
return executor_;
|
|
}
|
|
|
|
private:
|
|
static FunctionSchema defaultSchemaFor(const Function& function) {
|
|
std::vector<Argument> args;
|
|
std::vector<Argument> returns;
|
|
Graph& g = *function.graph();
|
|
size_t num_inputs = function.num_inputs();
|
|
for (size_t i = 0; i < num_inputs; ++i) {
|
|
const Value* v = g.inputs().at(i);
|
|
std::string name = v->hasDebugName() ? v->debugNameBase()
|
|
: ("argument_" + std::to_string(i));
|
|
args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
|
|
}
|
|
for (size_t i = 0; i < g.outputs().size(); ++i) {
|
|
returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
|
|
}
|
|
return {function.name(), "", std::move(args), std::move(returns)};
|
|
}
|
|
|
|
c10::QualifiedName name_;
|
|
std::shared_ptr<Graph> graph_; // for debugging and for inlining
|
|
|
|
GraphExecutor executor_; // for execution
|
|
|
|
std::once_flag executor_init_;
|
|
|
|
// an optional function that actually creates the method when
|
|
// ensure_defined() is called. This is used by the compiler so
|
|
// that it can construct methods out of order
|
|
std::function<void(Function&)> function_creator_;
|
|
|
|
// if absent, then we generate a default schema based on the graph
|
|
// mutable because getSchema caches the default schema if one is requested
|
|
// before a call to setSchema
|
|
mutable std::unique_ptr<FunctionSchema> schema_;
|
|
};
|
|
} // namespace jit
|
|
} // namespace torch
|