#include "ir.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/autograd/function.h" #include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/assertions.h" #include "torch/csrc/jit/script/compiler.h" #include #include #include #include #include #include #include #include namespace torch { namespace jit { // Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char constexpr Symbol PythonOp::Kind; constexpr int max_tensor_display_size = 10; void printValueRef(std::ostream & out, const Value * n) { out << "%" << n->uniqueName(); } template std::ostream& operator<<(std::ostream & out, const std::vector & nodes) { out << at::ArrayRef{nodes}; return out; } template std::ostream& operator<<(std::ostream & out, const at::ArrayRef & nodes) { size_t i = 0; for(auto n : nodes) { if(i++ > 0) out << ", "; printValueRef(out, n); } return out; } struct const_value_list_with_types { const ArrayRef values; bool use_newlines; const_value_list_with_types(ArrayRef values, bool use_newlines = false) : values(values), use_newlines(use_newlines) {} }; std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) { size_t i = 0; size_t prev_stage = 0; for(auto n : l.values) { if(i++ > 0) { if (l.use_newlines) { // TODO: Indent here is hard-coded for "graph(": un-hard-code it out << "\n "; if (n->stage() != prev_stage) { out << "-------- stage " << n->stage() << " --------\n "; prev_stage = n->stage(); } } else { out << ", "; } } printValueRef(out, n); out << " : "; out << *n->type(); } return out; } template void printPrimList(std::ostream & out, const std::vector & items) { out << "["; int i = 0; for(auto & item : items) { if(i++ > 0) out << ", "; out << item; } out << "]"; } std::string escapeString(std::string s) { std::vector search = {'\n', '\t', '\v'}; std::vector replace = {"\\n", "\\t", "\\v"}; for (size_t i = 0; i < search.size(); i++) { size_t pos = s.find(search[i]); while(pos != std::string::npos) { s.replace(pos, 1, replace[i]); pos = s.find(search[i], pos + 1); } } return s; } void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=false) { out << "["; auto names = n->attributeNames(); int i = 0; for(auto name : names) { if (ignore_subgraph && name == attr::Subgraph) continue; if(i++ > 0) out << ", "; // TODO: debugging mode to see the qualifier. We definitely // don't want to print the qualifier since it should always // be attribute, but you might be able to track down a weird // bug by printing it out. out << name.toUnqualString() <<"="; switch(n->kindOf(name)) { case AttributeKind::f: out << n->f(name); break; case AttributeKind::fs: printPrimList(out,n->fs(name)); break; case AttributeKind::i: out << n->i(name); break; case AttributeKind::is: printPrimList(out,n->is(name)); break; case AttributeKind::s: out << escapeString(n->s(name)); break; case AttributeKind::ss: printPrimList(out,n->ss(name)); break; case AttributeKind::t: { at::Tensor t = n->t(name); // 1-elem tensors are usually boxed scalars, so print them like it if (t.numel() == 1) { auto scalar_tensor = t.view({})._local_scalar(); out << "{"; if (scalar_tensor.isFloatingPoint()) { out << scalar_tensor.toDouble(); } else { out << scalar_tensor.toLong(); } out << "}"; } else if (t.numel() <= max_tensor_display_size) { // TODO: This is awful code. Also it doesn't work on Windows. 
          std::ostringstream tensor_ss;
          tensor_ss << t;
          std::string tensor_s{tensor_ss.str()};
          // Remove newlines
          std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
          out << tensor_s;
        } else {
          out << "<Tensor>";
        }
        break;
      }
      case AttributeKind::ts:
        out << "[<Tensors>]";
        break;
      case AttributeKind::g:
        out << "<Graph>";
        break;
      case AttributeKind::gs:
        out << "[<Graphs>]";
        break;
    }
  }
  out << "]";
}

static std::ostream & indent(std::ostream & out, size_t level) {
  for(size_t i = 0; i < level; ++i)
    out << "  ";
  return out;
}

std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::vector<const Node*> * groups) {
  auto outputs = n->outputs();
  indent(out, level) << const_value_list_with_types(outputs);
  out << " = ";
  IR_IFM_CONST(n,PythonOp)
    out << "^" << value->name();
    value->writeScalars(out);
  IR_ELSE()
    if(n->hasAttribute(attr::Subgraph) && groups) {
      out << n->kind().toQualString() << "_" << groups->size();
      if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
        printAttributes(out, n, /*ignore_subgraph=*/true);
      }
      groups->push_back(n);
    } else {
      out << n->kind().toQualString();
      if(n->hasAttributes()) {
        printAttributes(out, n);
      }
    }
  IR_END()
  out << "(" << n->inputs() << ")";
  std::string scopeName = n->scopeName();
  if (scopeName.empty()) {
    out << "\n";
  } else {
    out << ", ";
    out << "scope: " << scopeName << "\n";
  }
  for(size_t i = 0; i < n->blocks().size(); ++i) {
    auto b = n->blocks()[i];
    indent(out, level + 1) << "block" << i << "("
                           << const_value_list_with_types(b->inputs(), false) << ") {\n";
    for(auto n : b->nodes()) {
      printNode(out, level + 2, n, groups);
    }
    indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
    indent(out, level + 1) << "}\n";
  }
  return out;
}

std::ostream& operator<<(std::ostream & out, const Node & n) {
  return printNode(out, 0, &n, nullptr);
}

std::ostream& operator<<(std::ostream & out, const Graph & g) {
  out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n";
  std::vector<const Node*> groups;
  size_t prev_stage = 0;
  for(auto n : g.nodes()) {
    if (n->stage() != prev_stage) {
      out << "  ---------------- stage " << n->stage() << " ----------------\n";
      prev_stage = n->stage();
    }
    printNode(out, 1, n, &groups);
  }
  out << "  return (" << g.outputs() << ");\n}\n";
  size_t i = 0;
  for(auto fg : groups) {
    out << "with " << fg->kind().toQualString() << "_" << i++
        << " = " << *fg->g(attr::Subgraph);
  }
  /*
  // Uncomment this to debug all_nodes issues
  {
    out << "\n";
    out << "all_nodes:\n";
    for (auto& n : g.all_nodes) {
      printNode(out, const_cast<Node*>(n), nullptr);
    }
  }
  */
  return out;
}

static void checkSameDevice(const Node* node) {
  bool has_device = false;
  int device;
  auto checkValue = [&](const Value* v) {
    if(CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
      if(!has_device) {
        has_device = true;
        device = type->device();
      } else {
        JIT_ASSERT(device == type->device());
      }
    }
  };
  for(auto input : node->inputs()) {
    checkValue(input);
  }
  for(auto output : node->outputs()) {
    checkValue(output);
  }
}

using node_set = std::set<const Node*>;
#define ALL_OF(container) container.begin(), container.end()

// These functions purposely operate on the internal members directly, to force
// you to think about how the invariants change if you change the data
// representation (even if the external API does not change.)

// NB: This assert is written to assume you don't have any unattached
// nodes. Unattached nodes can occur while manipulations to the
// graph are occurring.
void Node::lint() const {
  // Node invariants
  // - if node should live in list, nodes_iter is consistent
  // - Inputs are all marked as a use by the nodes they refer to
  // - Stage is consistent (stage is >= all input stages)
  // - Owning graph is non-null and consistent
  // - The "Select" invariant, when the node is MultiReturn
  //
  // The handle invariant:
  //    If a node takes a handle as an input, it is always the
  //    LAST input of the node. There is at most one handle input.

  {
    size_t i = 0;
    for (auto input : inputs_) {
      // WARNING: O(n^2)
      JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) != input->uses_.end());
      JIT_ASSERT(stage_ >= input->stage_);
      JIT_ASSERT(graph_->all_nodes.count(this) == 1);
      i++;
    }
  }

  for(auto o : outputs()) {
    size_t i = 0;
    for (auto use : o->uses()) {
      // Use invariants
      // - Use is consistent with inputs
      // - Every user node is live (checked in Graph)
      JIT_ASSERT(use.user->inputs_[use.offset] == o);
      i++;
    }
  }

  // Node subclass invariants
  // - Return uses is zero
  // - Param inputs is zero
  // - Select inputs is one
  // - Python operator cconv is correct

  IR_IF(this,Constant)
    JIT_ASSERT(inputs_.size() == 0);
  IR_ELSEIF(Return)
    JIT_ASSERT(outputs().size() == 0);
  IR_ELSEIF(Param)
    JIT_ASSERT(inputs_.size() == 0);
  IR_ELSEIFM_CONST(PythonOp)
    size_t n_scalars = 0, n_tensors = 0;
    for (auto c : value->cconv) {
      if (c == 's') {
        n_scalars++;
      } else if (c == 't') {
        n_tensors++;
      } else {
        JIT_ASSERT(0);
      }
      JIT_ASSERT(static_cast<bool>(value->pyobj));
    }
    JIT_ASSERT(n_scalars == value->scalar_args.size());
    JIT_ASSERT(n_tensors == inputs_.size());
  IR_ELSEIF(Eval)
    // TODO: add invariants
  // TODO: It's not good for these ops to be top-level, it makes cases longer.
  IR_ELSEIF(FusionGroup)
    checkSameDevice(value);
    // TODO: Typecheck the parameters
    value->g(attr::Subgraph)->lint();
  IR_END()
}

// TODO: When lint fails, give better indication about which
// instruction triggered the failure.
void Graph::lint() const {
  // Graph invariants

  // Uncomment the following to see the graph
  // std::cout << *const_cast<Graph*>(this);

  // nodes
  // - nodes_ is a valid topological ordering for inputs
  // - No repeated nodes
  // - Params and return do NOT occur in nodes
  // - next_unique_ is greater than all uniques in graph
  // - uniques in all_nodes are unique
  // - every use will occur later in the topsort

  struct LintScope {
    LintScope() = default;
    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
    bool contains(const Value * v) {
      return values.count(v) > 0 || (parent && parent->contains(v));
    }
    bool contains(const Node * n) {
      return nodes.count(n) > 0 || (parent && parent->contains(n));
    }
    void insert(const Value * v) {
      JIT_ASSERT(!contains(v));
      values.insert(v);
    }
    void insert(const Node * n) {
      JIT_ASSERT(!contains(n));
      nodes.insert(n);
    }
    std::unique_ptr<LintScope> parent;
  private:
    std::unordered_set<const Value*> values;
    std::unordered_set<const Node*> nodes;
  };
  // Struct enables mutual recursion in linting methods.
  // Putting it inside Graph::lint enables access to private Graph members
  struct LintImpl {
    LintImpl(const Graph & g)
      : g(g)
      , scope(new LintScope())
      , all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
    const Graph & g;
    std::unique_ptr<LintScope> scope;
    std::unordered_set<size_t> seen_uniques;
    std::unordered_map<const Node*, int64_t> anticipated_uses;
    node_set all_nodes_set;
    node_set sum_set;

    void check_value(const Value* v) {
      scope->insert(v);
      auto b2 = seen_uniques.insert(v->unique());
      JIT_ASSERT(b2.second);  // insertion took place
      JIT_ASSERT(v->unique() < g.next_unique_);

      for (auto use : v->uses()) {
        JIT_ASSERT(!scope->contains(use.user));
        JIT_ASSERT(g.all_nodes.count(use.user) == 1);
        anticipated_uses[use.user]++;  // int default constructs to 0
      }
    }
    void check_node(const Node* n) {
      for (auto input : n->inputs_) {
        if (!scope->contains(input)) {
          JIT_ASSERTM(0, input->unique(), " not in scope");
        }
      }
      JIT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
      anticipated_uses[n] = -1;  // we saw the anticipated user!
      scope->insert(n);
      for(auto block : n->blocks()) {
        std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
        scope = std::move(new_scope);
        check_block(block);
        scope = std::move(scope->parent);
      }
      size_t i = 0;
      for(auto o : n->outputs()) {
        JIT_ASSERT(o->node() == n);
        JIT_ASSERT(i++ == o->offset_);
        check_value(o);
      }
      n->lint();
    }
    void check_block(const Block *b) {
      for (auto input : b->inputs()) {
        check_value(input);
        JIT_ASSERT(input->node()->kind_ == prim::Param);
      }

      for (auto n : b->nodes()) {
        JIT_ASSERT(n->kind_ != prim::Param);
        JIT_ASSERT(n->kind_ != prim::Return);
        check_node(n);
      }

      JIT_ASSERT(b->output_->kind() == prim::Return);
      check_node(b->output_);

      // all_nodes
      // - inputs_, output_ and nodes_ are all included in all_nodes
      // - all_nodes does not contain dead nodes??? (likely to be temporarily
      //   suspended). Weaker: all_nodes contains all inputs and returns
      // - only one return node???
      node_set nodes_set(ALL_OF(b->nodes()));
      node_set inputs_set {b->input_};
      node_set output_set {b->output_};
      // TODO: Make a more type safe std::includes wrapper which disallows use on
      // non-ordered containers
      JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
      JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
      JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));

      sum_set.insert(ALL_OF(nodes_set));
      sum_set.insert(ALL_OF(inputs_set));
      sum_set.insert(ALL_OF(output_set));
    }
    void check_graph() {
      node_set all_nodes_set(ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*

      check_block(g.block_);
      for (auto kv : anticipated_uses) {
        JIT_ASSERT(kv.second == -1);
      }
      // graph->stage() should be equal to max(node.stage for node in graph->nodes())
      if (g.nodes().begin() == g.nodes().end()) {
        JIT_ASSERT(g.stage() == 0);
      } else {
        JIT_ASSERT(g.stage() == g.nodes().rbegin()->stage());
      }
      JIT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
    }
  };
  LintImpl(*this).check_graph();
}

void Graph::dump() const {
  std::cout << *this << "\n";
}

void LintGraph(std::shared_ptr<Graph>& graph) {
  graph->lint();
}

void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) {
  std::unordered_map<Value*, Value*> local_map;
  auto env = [&](Value * v) {
    auto it = local_map.find(v);
    if(it != local_map.end())
      return it->second;
    return value_map(v);
  };
  auto graph = owningGraph();
  for(auto input : src->inputs()) {
    local_map[input] = this->addInput()->copyMetadata(input)->setStage(input->stage());
    graph->setStage(std::max(graph->stage(), input->stage()));
  }
  for(auto node : src->nodes()) {
    auto new_node = this->appendNode(graph->createClone(node, env));
    new_node->setStage(node->stage());
    graph->setStage(std::max(graph->stage(), node->stage()));
    for(size_t i = 0; i < node->outputs().size(); ++i) {
      auto oo = node->outputs()[i];
      auto no = new_node->outputs()[i];
      local_map[oo] = no;
      no->copyMetadata(oo);
      no->setStage(oo->stage());
    }
  }
  for(auto output : src->outputs()) {
    this->registerOutput(env(output));
  }
}

std::shared_ptr<Graph> Graph::copy() {
  auto new_g = std::make_shared<Graph>();
  auto env = [](Value *) -> Value* {
    AT_ERROR("Graph::copy() encountered a use of a value not in scope. Run lint!");
  };
  new_g->block()->cloneFrom(this->block(), env);
  return new_g;
}

Value* Value::setUniqueName(const std::string & name) {
  if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) {
    throw std::runtime_error("names may not be integers: " + name);
  }

  auto & names = node()->owningGraph()->unique_names_;

  // clear any old name from the map
  if(hasUniqueName()) {
    names.erase(unique_name_);
    unique_name_ = "";
  }

  // allow "" to clear the uniquename
  if(name == "")
    return this;

  // if someone else has this name, then rename the other value
  auto old_owner_of_name = names.find(name);
  if(old_owner_of_name != names.end()) {
    size_t suffix = 1;
    std::string name_base = name;
    auto last_dot_pos = name.find_last_of('.');
    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
      if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
        suffix = std::stoll(name.substr(last_dot_pos + 1));
        name_base = name.substr(0, last_dot_pos);
      }
    }
    std::string replacement_name;
    do {
      std::stringstream ss;
      ss << name_base << "."
         << suffix++;
      replacement_name = ss.str();
    } while(names.count(replacement_name) > 0);
    old_owner_of_name->second->setUniqueName(replacement_name);
  }

  names[name] = this;
  unique_name_ = name;
  return this;
}

size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
  auto name_str = name.toUnqualString();
  for (size_t i = 0; i < the_schema.arguments.size(); ++i) {
    const Argument* arg = &the_schema.arguments[i];
    if (arg->name == name_str) {
      return i;
    }
  }
  throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString());
}

at::optional<IValue> Node::get(Symbol name) const {
  return toIValue(namedInput(name));
}

Value* Node::namedInput(Symbol name) const {
  return input(findArgument(schema(), name));
}

bool Node::matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs) const {
  if (!sig(signature_literal).matches(this)) return false;
  for (Symbol s : const_inputs) {
    if (!is_constant(s)) return false;
  }
  return true;
}

void Node::dump() const {
  std::cout << *this << "\n";
}

void Node::findSchema() const {
  schema_ = &getOperatorFor(this).schema();
}

namespace {

const OperatorSet& nondeterminstic_aten_ops() {
  static OperatorSet nondeterministic_ops = {
    "aten::dropout(Tensor input, float p, int train) -> Tensor",
    "aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)",
    "aten::_standard_gamma(Tensor self, Generator generator) -> Tensor",
    "aten::_th_bernoulli(Tensor self, *, Generator generator) -> Tensor",
    "aten::bernoulli(Tensor self) -> Tensor",
    "aten::bernoulli(Tensor self, Tensor p, Generator generator) -> Tensor",
    "aten::bernoulli(Tensor self, float p, Generator generator) -> Tensor",
    "aten::multinomial(Tensor self, int num_samples, int replacement, *, Generator generator) -> Tensor",
    "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
    "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
    "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
    "aten::poisson(Tensor self, Generator generator) -> Tensor",
    "aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
    "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Generator generator) -> Tensor",
    "aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::rand_like(Tensor self) -> Tensor",
    "aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randint(int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randint(int low, int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randint_like(Tensor self, int high) -> Tensor",
    "aten::randint_like(Tensor self, int low, int high) -> Tensor",
    "aten::randint_like(Tensor self, int high, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randn(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randn_like(Tensor self) -> Tensor",
    "aten::randn_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
    "aten::randperm(int n, *, int dtype, int layout, int[] device) -> Tensor"
  };
  return nondeterministic_ops;
}

} // namespace

bool Node::isNondeterministic() const {
  if (nondeterminstic_aten_ops().find(this) == nullptr) {
    return false;
  }
  // Dropout with train = False is deterministic
  if (matches("aten::dropout(Tensor input, float p, int train) -> Tensor") &&
      is_constant(attr::train) && !get<int64_t>(attr::train).value()) {
    return false;
  }
  return true;
}

inline const SourceRange& fakeRange() {
  static SourceRange range(std::make_shared<std::string>(""), 0, 1);
  return range;
}

Value* Graph::insert(Symbol opname, at::ArrayRef<NamedValue> args, at::ArrayRef<NamedValue> kwargs) {
  return script::emitBuiltinCall(fakeRange(), *this, opname, args, kwargs, /*required=*/true);
}

PythonOp* defaultAllocPythonOp(Graph* g) {
  throw std::runtime_error("Trying to allocate a Python object without python bindings loaded");
}
std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op; // patched in when python bindings are loaded

PythonOp* allocPythonOp(Graph* g) {
  return alloc_python_op.load()(g);
}

void setAllocPythonOp(PythonOp* (*v)(Graph* g)) {
  alloc_python_op.store(v);
}

}} // namespace torch::jit
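
// Illustrative sketch only (not part of the original file): the printer defined
// above emits graphs in roughly the following textual form, assuming a
// hypothetical two-input graph with a single aten::mul node and a non-empty
// scope name "MyModule":
//
//   graph(%0 : Double(2)
//         %1 : Double(2)) {
//     %2 : Double(2) = aten::mul(%0, %1), scope: MyModule
//     return (%2);
//   }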