#pragma once
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/graph_executor.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/argument_spec.h"
#include "torch/csrc/jit/function_schema.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/named_value.h"
#include "torch/csrc/jit/source_range.h"

#include <torch/csrc/api/include/torch/detail/ordered_dict.h>

#include <ATen/ATen.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
// function calls.

namespace torch { namespace jit { namespace script {

// A method in a module, e.g. f in:
//
// class M(ScriptModule):
//   @script_method
//   def f(self, x):
//     ...
//
// Note: because Method/Module are exposed to Python, these
// classes use Python method naming conventions.
struct Method {
  Method(
      std::string name,
      bool optimize,
      std::shared_ptr<Graph> graph,
      std::vector<at::Tensor*> initial_members,
      std::function<void(Method&)> method_creator)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        optimize(optimize),
        member_inputs(std::move(initial_members)),
        method_creator(std::move(method_creator)) {
    JIT_ASSERT(graph_->inputs().size() >= member_inputs.size());
    int i = graph_->inputs().size() - member_inputs.size();
    for (at::Tensor* member : member_inputs) {
      member_input_index[member] = i++;
    }
  }

  void run(Stack& stack) {
    for (at::Tensor* tp : member_inputs) {
      stack.push_back(*tp);
    }
    get_executor().run(stack);
  }

  IValue operator()(std::vector<IValue> stack) {
    checkInputsAgainstSchema(stack);
    run(stack);
    if (stack.size() != 1) {
      return Tuple::create(std::move(stack));
    }
    return stack.front();
  }
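
  // Calling a Method, as a minimal sketch (the tensors and the two-argument
  // schema here are illustrative assumptions, not part of this header):
  //
  // @code
  //   Method& m = ...; // obtained from a Module
  //   IValue out = m({IValue(torch::ones({2, 2})), IValue(torch::zeros({2, 2}))});
  //   // inputs are validated against getSchema(); member tensors are
  //   // appended automatically by run() before execution
  // @endcode
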
  std::shared_ptr<Graph> graph_for(const Stack& inputs) {
    return get_executor().graphFor(inputs);
  }

  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  const std::string& name() const {
    return name_;
  }

  // emit a function call by inlining the callee's Graph into this one,
  // adding any extra parameters necessary to do this call.
  // defined here to keep details of member_input handling confined to this class.
  std::vector<Value*> emit_call_to(
      SourceRange loc,
      Method& callee,
      ArrayRef<NamedValue> args,
      ArrayRef<NamedValue> kwargs);

  // if this isn't yet defined, run its method_creator function
  void ensure_defined();

  size_t num_inputs() const {
    return graph()->inputs().size() - member_inputs.size();
  }

  Value* get_or_add_parameter(at::Tensor* slot) {
    auto it = member_input_index.find(slot);
    if (it != member_input_index.end()) {
      return graph()->inputs().at(it->second);
    }
    // add it as a new parameter
    member_inputs.push_back(slot);
    member_input_index[slot] = graph()->inputs().size();
    return graph()->addInput();
  }

  std::shared_ptr<Graph> propagate_shapes(
      std::vector<at::Tensor> inputs,
      bool with_grad = false) {
    auto retval = graph_->copy();
    Stack stack;
    stack.reserve(inputs.size() + member_inputs.size());
    for (at::Tensor& i : inputs) {
      stack.emplace_back(std::move(i));
    }
    for (at::Tensor* inp : member_inputs) {
      stack.push_back(*inp);
    }
    // take the size before moving the stack; argument evaluation order
    // within a call expression is unspecified
    const size_t num_flat_inputs = stack.size();
    setInputTypes(*retval, ArgumentSpec(with_grad, std::move(stack), num_flat_inputs));
    PropagateInputShapes(*retval);
    return retval;
  }

  std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
      std::vector<at::Tensor> inputs,
      std::vector<at::Tensor> outputs,
      bool with_grad = false,
      bool propagate = true) {
    auto retval = graph_->copy();
    for (auto inp : member_inputs) {
      inputs.push_back(*inp);
    }
    if (propagate) {
      setInputTypes(*retval, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
      PropagateInputShapes(*retval);
    }
    JIT_ASSERT(retval->inputs().size() == inputs.size());
    for (size_t i = 0; i < retval->inputs().size(); ++i) {
      auto scalar_type = inputs[i].type().scalarType();
      auto sizes = inputs[i].sizes();
      auto type = torch::jit::CompleteTensorType::create(scalar_type, -1, sizes);
      retval->inputs()[i]->setType(type);
    }
    JIT_ASSERT(retval->outputs().size() == outputs.size());
    for (size_t i = 0; i < retval->outputs().size(); ++i) {
      auto scalar_type = outputs[i].type().scalarType();
      auto sizes = outputs[i].sizes();
      auto type = torch::jit::CompleteTensorType::create(scalar_type, -1, sizes);
      retval->outputs()[i]->setType(type);
    }
    return retval;
  }
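
  // Shape propagation, as a minimal sketch (assumes `m` is a defined Method
  // taking a single tensor argument; illustrative only):
  //
  // @code
  //   std::vector<at::Tensor> example_inputs = {torch::ones({4, 4})};
  //   std::shared_ptr<Graph> typed = m.propagate_shapes(example_inputs);
  //   // `typed` is a copy of m.graph() whose values carry inferred
  //   // shape/type information; m.graph() itself is left untouched
  // @endcode
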
  std::vector<at::Tensor*> params() {
    return member_inputs;
  }

  Method& setSchema(FunctionSchema schema_) {
    schema.reset(new FunctionSchema(std::move(schema_)));
    return *this;
  }

  const FunctionSchema& getSchema() const;

  std::string pretty_print_schema() const {
    JIT_ASSERT(schema);
    std::stringstream ss;
    ss << *schema;
    return ss.str();
  }

  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  void debugDisableAutodiffSubgraphInlining() {
    return get_executor().debugDisableAutodiffSubgraphInlining();
  }

  bool is_optimized() {
    return optimize;
  }

 private:
  std::string name_;
  std::shared_ptr<Graph> graph_; // for debugging and for inlining
  bool optimize;

  GraphExecutor& get_executor() {
    std::call_once(executor_init, [&] {
      executor = GraphExecutor(graph(), optimize);
    });
    return executor;
  }

  void checkInputsAgainstSchema(std::vector<IValue>& inputs) {
    const auto& schema = getSchema();

    // Do we have more inputs than the schema accepts?
    AT_CHECK(
        inputs.size() <= schema.arguments.size(),
        "Expected at most ", schema.arguments.size(),
        " argument(s) for operator '", schema.name, "', but received ",
        inputs.size(), " argument(s). Declaration: ", schema);

    for (size_t pos = 0; pos < schema.arguments.size(); ++pos) {
      const auto& argument = schema.arguments[pos];
      if (pos < inputs.size()) {
        const TypePtr inputType = inferTypeFrom(inputs[pos]);
        AT_CHECK(
            inputType->isSubtypeOf(argument.type),
            "Expected value of type ", *argument.type,
            " for argument '", argument.name,
            "' in position ", pos,
            ", but instead got value of type ", *inputType,
            ". Declaration: ", schema);
      } else if (argument.default_value) {
        inputs.push_back(*argument.default_value);
      } else {
        AT_ERROR(
            schema.name, "() is missing value for argument '",
            argument.name, "'. Declaration: ", schema);
      }
    }
  }

  GraphExecutor executor; // for execution

  // member_inputs are a list of additional arguments appended to graph that
  // are inputs coming from the members of the Module or its submodules.
  // each is a pointer to a slot in the module that owns this parameter.
  // parameters and submodules can only be _added_ to script Modules to ensure
  // these pointers always stay valid.
  std::vector<at::Tensor*> member_inputs;

  // map from an at::Tensor* in member_inputs to the offset it appears at
  // in graph. used to accelerate get_or_add_parameter
  std::unordered_map<at::Tensor*, size_t> member_input_index;

  // TODO: support the case where we allow _writes_ to parameters from
  // compiled functions. This requires more sophisticated tracking of SSA
  // values in Graphs so that stores to all modules can be lifted to the end
  // of a graph execution. It also adds more complexity to adding actual
  // module invocations to the executor, so currently it is not done.
  // std::vector<at::Tensor*> member_outputs;

  std::once_flag executor_init;

  // an optional function that actually creates the method when
  // emit_call_to(this, ...) is first called. this is used by the compiler
  // so that it can construct methods out of order.
  std::function<void(Method&)> method_creator;

  // if absent, then we generate a default schema based on the graph.
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema.
  mutable std::unique_ptr<FunctionSchema> schema;
};

struct Module;

struct NamedModule {
  std::string name;
  std::shared_ptr<Module> module;
};

struct NamedParameter {
  NamedParameter(std::string name, at::Tensor tensor, bool is_buffer)
      : name(std::move(name)),
        is_buffer(is_buffer),
        parameter(new at::Tensor(std::move(tensor))) {}

  const std::string name;
  bool is_buffer; // buffers are part of the module state but
                  // are not modified by optimizers during SGD

  at::Tensor* slot() const {
    return parameter.get();
  }

 private:
  // the extra level of indirection allows Methods to safely store pointers
  // to the slots where parameters are kept while also allowing parameters
  // to be reassigned
  std::unique_ptr<at::Tensor> parameter;
};
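
// A sketch of why NamedParameter stores its tensor behind a pointer
// (illustrative, not part of this header): reassignment writes through the
// slot, so any at::Tensor* held by a Method remains valid.
//
// @code
//   NamedParameter p("weight", torch::ones({2}), /*is_buffer=*/false);
//   at::Tensor* slot = p.slot();
//   *slot = torch::zeros({2}); // Methods observe the new value via the same pointer
// @endcode
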
struct Module {
  TH_DISALLOW_COPY_AND_ASSIGN(Module);

  Module()
      : modules("Module"),
        parameters("Parameter"),
        methods("Method"),
        optimize(true) {}

  // note this doesn't change the flags of existing methods, just ones
  // added afterward.
  void set_optimized(bool o) {
    optimize = o;
  }

  IValue forward(std::vector<IValue> inputs) {
    return get_method("forward")(std::move(inputs));
  }

  void register_parameter(const std::string& name, autograd::Variable v, bool is_buffer) {
    if (auto p = parameters.find(name)) {
      *p->slot() = v;
      p->is_buffer = is_buffer;
      return;
    }
    parameters.insert(name, NamedParameter(name, std::move(v), is_buffer));
  }

  void register_module(const std::string& name, std::shared_ptr<Module> module) {
    modules.insert(name, {name, std::move(module)});
  }

  Method& create_method(
      const std::string& name,
      std::shared_ptr<Graph> graph,
      std::vector<at::Tensor*> member_inputs) {
    JIT_ASSERT(graph);
    std::unique_ptr<Method> method(new Method(
        name, optimize, std::move(graph), std::move(member_inputs), nullptr));
    return *methods.insert(name, std::move(method));
  }

  Method& create_method(const std::string& name, std::function<void(Method&)> creator) {
    std::unique_ptr<Method> method(new Method(
        name, optimize, std::make_shared<Graph>(), {}, std::move(creator)));
    return *methods.insert(name, std::move(method));
  }

  at::Tensor* parameter_slot(const std::string& name) const {
    return parameters.get(name).slot();
  }

  void set_parameter(const std::string& name, at::Tensor v) {
    *parameter_slot(name) = std::move(v);
  }

  autograd::Variable get_parameter(const std::string& name) const {
    return autograd::as_variable_ref(*parameter_slot(name));
  }

  // each module owns its methods. the reference returned here
  // is guaranteed to stay valid until this module has been destroyed
  Method& get_method(const std::string& name) const {
    return *methods.get(name);
  }

  std::shared_ptr<Module> get_module(const std::string& name) const {
    return modules.get(name).module;
  }

  const torch::detail::OrderedDict<std::string, NamedModule>& get_modules() const {
    return modules;
  }
  const torch::detail::OrderedDict<std::string, NamedParameter>& get_parameters() const {
    return parameters;
  }
  const torch::detail::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods() const {
    return methods;
  }

  NamedParameter* find_parameter(const std::string& name) {
    return parameters.find(name);
  }
  NamedModule* find_module(const std::string& name) {
    return modules.find(name);
  }
  Method* find_method(const std::string& name) {
    if (auto* pm = methods.find(name)) {
      return pm->get();
    }
    return nullptr;
  }

  /// Run a method from this module.
  ///
  /// For example:
  /// @code
  ///   IValue output = module->run_method("relu_script", a, b);
  /// @endcode
  ///
  /// To compile a module from a source string, see torch::jit::compile
  ///
  /// @param method_name The name of the method to run
  /// @param args Arguments to be passed to the method
  /// @return An IValue containing the return value (or values if it is a tuple)
  /// from the method
  template <typename... Types>
  IValue run_method(const std::string& method_name, Types&&... args) {
    return get_method(method_name)({IValue(std::forward<Types>(args))...});
  }

  void save(std::ostream& out);
  void save(const std::string& filename);

 private:
  // invariant: to ensure member_inputs of Methods stay valid,
  // it is only legal to _add_ new modules and parameters.
  // removing them would allow member_inputs to point to invalid parameters.
  // no such restriction exists for methods.
  torch::detail::OrderedDict<std::string, NamedModule> modules;
  torch::detail::OrderedDict<std::string, NamedParameter> parameters;
  torch::detail::OrderedDict<std::string, std::unique_ptr<Method>> methods;
  bool optimize;
};

}}} // namespace torch::jit::script
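
// End-to-end usage, as a minimal sketch (assumes a method named "forward"
// has already been created against the module, e.g. by the script compiler;
// the identifiers below are illustrative, not part of this header):
//
// @code
//   using namespace torch::jit::script;
//   auto mod = std::make_shared<Module>();
//   mod->register_parameter("weight", torch::ones({2, 2}), /*is_buffer=*/false);
//   // ... define "forward" via create_method or the script compiler ...
//   if (Method* m = mod->find_method("forward")) {
//     IValue out = mod->run_method("forward", torch::ones({2, 2}));
//   }
// @endcode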