#include #include #include #include #include #include #include #include namespace torch { namespace jit { namespace script { struct RecursiveMethodCallError : public std::exception {}; void placeholderCreator(Function&) { throw RecursiveMethodCallError(); } void Function::ensure_defined() { try { if (function_creator_) { auto creator = function_creator_; function_creator_ = placeholderCreator; creator(*this); function_creator_ = nullptr; } } catch (RecursiveMethodCallError&) { throw ErrorReport() // TODO: once lower_first_class methods is removed // re-establish callsite info for debugging << " method '" << name() << "' is called recursively. " << "Recursive calls are not supported"; } } Value* Function::try_emit_call( Graph& graph, const SourceRange& loc, c10::optional self, ArrayRef args, ArrayRef kwargs, std::stringstream& failure_messages, bool conv_tensors_to_nums) { ensure_defined(); auto fn = this->graph(); auto matched_schema = tryMatchSchema( getSchema(), loc, graph, std::move(self), args, kwargs, failure_messages, conv_tensors_to_nums); if (!matched_schema) return nullptr; check_single_output(); return inlineCallTo(graph, *fn, matched_schema->inputs).at(0); } Value* Function::emit_call( Graph& graph, const SourceRange& loc, ArrayRef args, ArrayRef kwargs) { std::stringstream failure_messages; if (auto result = try_emit_call( graph, loc, c10::nullopt, args, kwargs, failure_messages, /*conv_tensors_to_nums=*/true)) { return result; } throw ErrorReport(loc) << failure_messages.str(); } void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) { to_impl(device, dtype, non_blocking); } void Module::to(at::ScalarType dtype, bool non_blocking) { to_impl(/*device=*/c10::nullopt, dtype, non_blocking); } void Module::to(at::Device device, bool non_blocking) { to_impl(device, /*dtype=*/c10::nullopt, non_blocking); } void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) { ExportModule(*this, out, extra_files); } void Module::save( const std::string& filename, const ExtraFilesMap& extra_files) { ExportModule(*this, filename, extra_files); } void Module::to_impl( const c10::optional& device, const c10::optional& dtype, bool non_blocking) { // First call `to()` on every child module. for (auto& child : get_modules()) { child->to_impl(device, dtype, non_blocking); } // Then convert every of our parameters. for (auto& parameter : get_parameters()) { // Need to access the `at::Tensor` as a `Variable` here. autograd::Variable variable = parameter.value().toTensor(); at::Tensor data = variable.data(); // Use the data's original device or dtype if not supplied here. auto new_data = data.to( device.value_or(data.device()), dtype.value_or(data.scalar_type()), non_blocking); variable.set_data(new_data); } } // lower_first_class_method and lift_lowered_method are transitionary functions // used to translate between module-as-first-class code generation, // and module-as-special execution. Once module-as-first-class execution is // debugged, then we can remove both and remove the lowered_functions_ table. // remove the first module argument, replacing any access of its // parameters/attributes with extra_ivalue input Slots that hold what value to // pass into the graph std::pair, std::vector> lower_graph( const ModulePtr& self, Graph& g_, size_t self_offset = 0) { std::shared_ptr g = g_.copy(); std::vector extra_ivalues; std::unordered_map slot_to_offset; struct ToScan { ModulePtr mod; Node* n; size_t offset; }; std::vector to_scan; std::vector to_clean; // nodes that should be dead at the end auto getOrAddSlot = [&](const Slot& slot) -> Value* { auto it = slot_to_offset.find(slot); if (it != slot_to_offset.end()) { size_t ivalues_start = g->inputs().size() - extra_ivalues.size(); return g->inputs().at(ivalues_start + it->second); } extra_ivalues.emplace_back(slot); slot_to_offset[slot] = extra_ivalues.size() - 1; return g->addInput()->setType(slot.type()); }; auto self_value = g->inputs().at(self_offset); for (Use use : self_value->uses()) { to_scan.emplace_back(ToScan{self, use.user, use.offset}); } while (to_scan.size() > 0) { auto e = to_scan.back(); to_scan.pop_back(); // when we lambda lift forks, first-class modules may be passed across // forks. This code recursively lowers the module in the fork call. if (e.n->kind() == prim::fork) { auto subgraph = e.n->g(attr::Subgraph); std::vector new_slots; std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset); e.n->g_(attr::Subgraph, subgraph); for (const Slot& slot : new_slots) { e.n->addInput(getOrAddSlot(slot)); } e.n->removeInput(e.offset); continue; } if (e.n->kind() != prim::GetAttr) { throw ErrorReport(e.n->getSourceLocation()) << "temporary: the only valid use of a module is looking up an attribute"; } Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name))); if (ClassTypePtr c = e.n->output()->type()->cast()) { if (c->name() == "Module") { auto obj = slot.value().toObject(); for (Use use : e.n->output()->uses()) { to_scan.emplace_back(ToScan{obj, use.user, use.offset}); } to_clean.emplace_back(e.n); continue; } } e.n->output()->replaceAllUsesWith(getOrAddSlot(slot)); e.n->destroy(); } while (to_clean.size() > 0) { Node* n = to_clean.back(); AT_ASSERT(!n->hasUses()); n->destroy(); to_clean.pop_back(); } AT_ASSERT(!self_value->hasUses()); g->eraseInput(self_offset); return std::make_pair(std::move(g), std::move(extra_ivalues)); } Method& Module::lower_first_class_method(Function* fn) { fn->ensure_defined(); auto lowered = lower_graph(module_object(), *fn->graph()); Function& new_func = lowered_methods_.create_function(fn->name(), lowered.first); // generate the new schema // slice away the self argument std::vector args( fn->getSchema().arguments().begin() + 1, fn->getSchema().arguments().end()); size_t id = 0; for (const Slot& slot : lowered.second) { std::ostringstream ss; ss << "slot" << id++; args.emplace_back(ss.str(), slot.type()); } new_func.setSchema(fn->getSchema().cloneWithArguments(std::move(args))); return _create_lowered_method(&new_func, std::move(lowered.second)); } static void createFirstClassValues( Module* module, Value* self, std::unordered_map& result) { auto& g = *self->owningGraph(); std::vector created; struct ToScan { Module* mod; Value* v; // value representing module in the graph }; std::vector to_scan = {{module, self}}; while (!to_scan.empty()) { auto s = to_scan.back(); to_scan.pop_back(); size_t offset = 0; for (const std::string& name : s.mod->module_object()->type()->attributeNames()) { Value* v = g.insertGetAttr(s.v, name); result[Slot(s.mod->module_object(), offset++)] = v; if (std::shared_ptr sub = s.mod->find_module(name)) { to_scan.emplace_back(ToScan{sub.get(), v}); } } } } void Module::lift_lowered_method(Method& m) { auto graph = m.graph()->copy(); Value* self = graph->insertInput(0, "self")->setType(module_object()->type()); std::unordered_map slot_to_value; if (!m.initial_ivalues().empty()) { WithInsertPoint guard(*graph->nodes().begin()); createFirstClassValues(this, self, slot_to_value); } size_t orig_graph_inputs_size = graph->inputs().size(); for (size_t i = 0; i < m.initial_ivalues().size(); ++i) { size_t input_offset = orig_graph_inputs_size - i - 1; size_t ivalue_offset = m.initial_ivalues().size() - i - 1; graph->inputs() .at(input_offset) ->replaceAllUsesWith( slot_to_value.at(m.initial_ivalues().at(ivalue_offset))); graph->eraseInput(input_offset); } if (!m.initial_ivalues().empty()) { // we added _all_ the submodules as first-class values but maybe did not use // them. So remove any dead attribute lookups EliminateDeadCode(graph); } Function& new_fn = class_cu().create_function(m.name(), std::move(graph)); // created lifted schema // self argument is named '$self' to prevent accidental name collisions // with another input that the user named 'self' std::vector new_args = {Argument("$self", module_object()->type())}; const auto& lowered_args = m.function().getSchema().arguments(); new_args.insert( new_args.end(), lowered_args.begin(), lowered_args.begin() + m.num_inputs()); new_fn.setSchema(m.function().getSchema().cloneWithArguments(std::move(new_args))); } Method& Module::_create_lowered_method( Function* func, std::vector member_inputs) { std::unique_ptr m(new Method(this, func, std::move(member_inputs))); return *insert(func->name(), methods_, EntityType::METHOD, std::move(m)); } void Module::lift_lowered_methods(size_t start) { for (size_t i = start; i < lowered_methods_.get_functions().size(); ++i) { Method& m = _create_lowered_method( lowered_methods_.get_functions().at(i).get(), {}); lift_lowered_method(m); } } void Module::_define_lowered( const std::vector& definitions, const std::vector& resolvers) { size_t start = lowered_methods_.get_functions().size(); lowered_methods_.define(definitions, resolvers, nullptr); lift_lowered_methods(start); // call lift_lowered_method for each definition } void Module::_define_lowered( const std::string& src, const ResolverPtr& resolver) { size_t start = lowered_methods_.get_functions().size(); lowered_methods_.define(src, resolver, nullptr); lift_lowered_methods(start); } Method& Module::_define_lowered( std::string name, std::shared_ptr graph, std::vector slots) { Method& m = _create_lowered_method( &lowered_methods_.create_function(std::move(name), std::move(graph)), std::move(slots)); lift_lowered_method(m); return m; } void Module::define(const std::string& src, const ResolverPtr& resolver) { class_cu().define( src, resolver ? resolver : script::nativeResolver(), simpleSelf(module_object()->type())); } } // namespace script } // namespace jit } // namespace torch