#include #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #else #include #include #endif #include #include namespace torch::jit { struct VarWithType; struct ParsedLiteral; class IRParser { friend void parseIR( const std::string& str, torch::jit::Graph* graph, std::unordered_map& vmap, bool parse_tensor_constants); IRParser( const std::string& str, torch::jit::Graph* graph, std::unordered_map& vmap, bool parse_tensor_constants) : L(std::make_shared(str)), g(graph), vmap(vmap), type_parser(L, /*parse_complete_tensor_types*/ true), parse_tensor_constants_(parse_tensor_constants) {} std::string parseVar(); VarWithType parseVarWithType(bool allow_optional = false); ParsedLiteral parseScalarLiteral(Node* n); void parse(); void parseGraphInputs(); void parseReturnOperator(); void parseBlocks(Node* parentNode); void parseBlock(Node* parentNode); void parseBlockInputs(Block* b); void parseBlockOutputs(Block* b); void parseOperatorsList(Block* b); void parseOperator(Block* b); void parseOperatorOutputs(std::vector* outs); std::string parseOperatorName(); void parseOperatorInputs(Node* n); void parseAttrs(Node* n); void parseAttr(Node* n); void parseList( int begin, int sep, int end, const std::function& callback); void bypassTypeAnnotationList(); Value* findValueInVMap(const std::string& name); torch::jit::Lexer L; torch::jit::Graph* g = nullptr; std::unordered_map& vmap; SchemaTypeParser type_parser; bool parse_tensor_constants_; std::vector deferred_tensor_value_initializations_; std::vector deferred_empty_container_initializations_; }; struct ParsedLiteral { ParsedLiteral() = default; AttributeKind k = AttributeKind::t; int64_t i = 0; std::string s = ""; double f = 0.0; c10::complex c = c10::complex(0, 0); TypePtr ty; std::vector is; std::vector ss; std::vector fs; std::vector> cs; std::vector tys; }; struct VarWithType { VarWithType() = default; std::string name; TypePtr type; }; void parseIR( const std::string& str, torch::jit::Graph* graph, std::unordered_map& vmap, bool parse_tensor_constants) { torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants); p.parse(); } void parseIR( const std::string& str, torch::jit::Graph* graph, bool parse_tensor_constants) { std::unordered_map vmap; parseIR(str, graph, vmap, parse_tensor_constants); } VarWithType IRParser::parseVarWithType(bool allow_optional) { VarWithType r; r.name = parseVar(); if (allow_optional) { r.type = nullptr; } else { r.type = TensorType::get(); } if (L.nextIf(':')) { auto type_alias = type_parser.parseType(); AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled"); r.type = type_alias.first; } return r; } std::string IRParser::parseVar() { L.expect('%'); std::string name; bool continue_parsing; do { if (L.cur().kind == TK_IDENT) { name += L.expect(TK_IDENT).text(); } else { name += L.expect(TK_NUMBER).text(); } continue_parsing = false; if (L.nextIf('.')) { continue_parsing = true; name += '.'; } else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') { continue_parsing = true; } } while (continue_parsing); return name; } void IRParser::parseOperatorOutputs(std::vector* outs) { if (L.cur().kind != '%') { return; } parseList(TK_NOTHING, ',', TK_NOTHING, [&] { outs->push_back(parseVarWithType(true)); }); L.expect('='); } // Parse string or numeric literal and return it along with its type. ParsedLiteral IRParser::parseScalarLiteral(Node* n) { auto token = L.cur(); std::string str; std::pair> type_alias; ParsedLiteral r; switch (token.kind) { case TK_STRINGLITERAL: r.k = AttributeKind::s; r.s = parseStringLiteral(token.range, token.text()); L.next(); return r; case '-': str = "-"; L.next(); if (L.cur().kind != TK_NUMBER) { throw ErrorReport(token.range) << "Expected a number after '-' but got:" << token.text(); } [[fallthrough]]; case TK_NUMBER: str += L.cur().text(); if (str.find('j') != std::string::npos) { r.k = AttributeKind::c; double imag = 0.0f; try { imag = std::stod(str.substr(0, str.size() - 1)); } catch (const std::invalid_argument& e) { throw ErrorReport(token.range) << "Number cannot be converted to double"; } catch (const std::out_of_range& e) { throw ErrorReport(token.range) << "Number is too long to be represented in type double"; } r.c = c10::complex(0, imag); } else if ( str.find('.') != std::string::npos || str.find('e') != std::string::npos) { r.k = AttributeKind::f; try { r.f = std::stod(str); } catch (const std::invalid_argument& e) { throw ErrorReport(token.range) << "Number cannot be converted to double"; } catch (const std::out_of_range& e) { throw ErrorReport(token.range) << "Number is too long to be represented in type double"; } } else { r.k = AttributeKind::i; try { r.i = std::stoll(str); } catch (const std::invalid_argument& e) { throw ErrorReport(token.range) << "Number cannot be converted to integer"; } catch (const std::out_of_range& e) { throw ErrorReport(token.range) << "Number is too big"; } } L.next(); return r; case TK_IDENT: // Type literal r.k = AttributeKind::ty; type_alias = type_parser.parseType(); AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled"); r.ty = type_alias.first; return r; case '<': { L.next(); auto text = L.expect(TK_IDENT); if (text.text() != "Tensor") { throw ErrorReport(token.range) << "Could not parse literal" << token.text(); } if (!parse_tensor_constants_) { throw ErrorReport(token.range) << "Tensor constant encountered but `parse_tensor_constants` set to false" << token.text(); } L.expect('>'); // these values will be set with randomly initialized data in // a post processing pass; deferred_tensor_value_initializations_.push_back(n); r.k = AttributeKind::t; return r; } case '{': { L.next(); if (L.cur().kind == '-') { L.next(); } auto text = L.expect(TK_NUMBER); if (!parse_tensor_constants_) { throw ErrorReport(token.range) << "Single-element tensor constant encountered but " << "`parse_tensor_constants` is set to false " << token.text(); } L.expect('}'); deferred_tensor_value_initializations_.push_back(n); r.k = AttributeKind::t; return r; } default: throw ErrorReport(token.range) << "Could not parse literal" << token.text(); } } void IRParser::bypassTypeAnnotationList() { int depth = 0; bool bypassed_list = false; while (depth != 0 || !bypassed_list) { if (L.cur().kind == '[') { bypassed_list = true; depth++; } else if (L.cur().kind == ']') { depth--; } L.next(); } } /** \brief Parse attribute and add it to the node N. * * The function determines the attribute type (string, int, float, complex, list * of strings, list of ints, list of floats, list of complex, and a list of * tensors (currently only for empty lists)). An attribute looks like the * following: AttrName=AttrValue Where AttrValue can be a list or a scalar * literal, e.g.: size = 27 name = "Bob" coefs = [1.2, 3.4, 0.6] */ void IRParser::parseAttr(Node* n) { std::string attrname = L.expect(TK_IDENT).text(); L.expect('='); if (L.cur().kind == '[') { // list AttributeKind k = AttributeKind::ts; c10::List is; c10::List ss; c10::List fs; c10::List> cs; std::vector tys; int elem_num = 0; parseList('[', ',', ']', [&] { ParsedLiteral r = parseScalarLiteral(n); switch (r.k) { case AttributeKind::s: ss.push_back(r.s); AT_ASSERT(!elem_num++ || k == AttributeKind::ss); k = AttributeKind::ss; break; case AttributeKind::i: is.push_back(r.i); AT_ASSERT(!elem_num++ || k == AttributeKind::is); k = AttributeKind::is; break; case AttributeKind::f: fs.push_back(r.f); AT_ASSERT(!elem_num++ || k == AttributeKind::fs); k = AttributeKind::fs; break; case AttributeKind::c: cs.push_back(r.c); AT_ASSERT(!elem_num++ || k == AttributeKind::cs); k = AttributeKind::cs; break; case AttributeKind::ty: tys.push_back(r.ty); AT_ASSERT(!elem_num++ || k == AttributeKind::tys); k = AttributeKind::tys; break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } }); switch (k) { case AttributeKind::ts: n->ival_(Symbol::attr(attrname), IValue()); break; case AttributeKind::ss: n->ival_(Symbol::attr(attrname), IValue(ss)); break; case AttributeKind::fs: n->ival_(Symbol::attr(attrname), IValue(fs)); break; case AttributeKind::cs: n->ival_(Symbol::attr(attrname), IValue(cs)); break; case AttributeKind::is: n->ival_(Symbol::attr(attrname), IValue(is)); break; case AttributeKind::tys: n->tys_(Symbol::attr(attrname), tys); break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } } else if (L.cur().text() == "annotate") { L.next(); L.expect('('); auto type = L.cur().text(); if (type != "List" && type != "Dict") { throw ErrorReport(L.cur().range) << "Unexpected annotation (only List and Dict can be parsed)"; } L.next(); // ignore the annotations on the IValue constants, and instead recover // type from the Node output // Note: we could also use script_type_parser bypassTypeAnnotationList(); L.expect(','); // expect an empty definition (note - this isn't always true) if (type == "Dict") { L.expect('{'); L.expect('}'); } else if (type == "List") { L.expect('['); L.expect(']'); } L.expect(')'); deferred_empty_container_initializations_.push_back(n); } else { // scalar ParsedLiteral r = parseScalarLiteral(n); switch (r.k) { case AttributeKind::s: n->s_(Symbol::attr(attrname), r.s); break; case AttributeKind::i: n->i_(Symbol::attr(attrname), r.i); break; case AttributeKind::f: n->f_(Symbol::attr(attrname), r.f); break; case AttributeKind::c: n->c_(Symbol::attr(attrname), r.c); break; case AttributeKind::ty: n->ty_(Symbol::attr(attrname), r.ty); break; case AttributeKind::t: // initialized with random data later break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; } return; } } void IRParser::parseAttrs(Node* n) { parseList('[', ',', ']', [&] { parseAttr(n); }); } void IRParser::parseOperatorInputs(Node* n) { if (L.cur().kind == '[') { parseAttrs(n); } parseList('(', ',', ')', [&] { std::string var_name = parseVar(); n->addInput(findValueInVMap(var_name)); }); } void IRParser::parseBlocks(Node* parentNode) { L.expect(TK_INDENT); while (L.cur().kind != TK_DEDENT) { parseBlock(parentNode); } L.expect(TK_DEDENT); } void IRParser::parseBlockInputs(Block* b) { parseList('(', ',', ')', [&] { VarWithType v = parseVarWithType(); // If the name isn't valid, don't use it std::string uniq_name = Value::isValidName(v.name) ? v.name : ""; vmap[v.name] = b->addInput(uniq_name); vmap[v.name]->setType(v.type); }); } void IRParser::parseBlockOutputs(Block* b) { L.expect(TK_ARROW); parseList('(', ',', ')', [&] { std::string var_name = parseVar(); b->registerOutput(findValueInVMap(var_name)); }); L.expect(TK_NEWLINE); L.expect(TK_DEDENT); } /** \brief Parse a block. * * It should look like the following: * blockName(input1, input2, input3, ...): * op1 * op2 * ... * opN * -> (output1, output2, output3, ...) */ void IRParser::parseBlock(Node* parentNode) { Block* b = parentNode->addBlock(); L.expect(TK_IDENT).text(); // Block name is not used anywhere. parseBlockInputs(b); L.expect(':'); parseOperatorsList(b); parseBlockOutputs(b); } /** \brief Parse a list of statements. * * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or * TK_ARROW. */ void IRParser::parseOperatorsList(Block* b) { L.expect(TK_INDENT); while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) { parseOperator(b); } } std::string IRParser::parseOperatorName() { std::string name = L.expect(TK_IDENT).text(); L.expect(':'); L.expect(':'); name += "::" + L.expect(TK_IDENT).text(); return name; } /** \brief Parse a statement. * * It should look like the following: * = NodeName[]() * * Outputs, blocks and attributes are optional. */ void IRParser::parseOperator(Block* b) { // Parse lefthand side. std::vector outs; parseOperatorOutputs(&outs); // Parse the name and create the corresponding node in the graph. auto source_range = L.cur().range; std::string name = parseOperatorName(); Node* n = g->create(Symbol::fromQualString(name), {}, outs.size()) ->setSourceRange(source_range); // Parse attributes and inputs. parseOperatorInputs(n); const FunctionSchema* schema = n->maybeSchema(); // Register outputs. unsigned idx = 0; for (const VarWithType& v : outs) { vmap[v.name] = n->outputs()[idx]; if (schema && !schema->is_varret()) { TORCH_CHECK( schema->returns().size() > idx, "Operator parsing error: out of bounds access at ", idx, " to schema->returns() which size is ", schema->returns().size(), " in size"); auto schema_return_type = schema->returns().at(idx).type(); if (!v.type) { vmap[v.name]->setType(schema_return_type); } else { // Don't currently support checking against type variables // TODO: support? if (!schema_return_type->hasFreeVariables() && !v.type->isSubtypeOf(*schema_return_type)) { throw ErrorReport(source_range) << "Annotated type " << v.type->repr_str() << " does not match schema type " << schema_return_type->repr_str() << " for operator " << *schema; } vmap[v.name]->setType(v.type); } } else { vmap[v.name]->setType(v.type ? v.type : TensorType::get()); } idx++; } // Insert the new node into block B. b->appendNode(n); // If the statement has nested blocks, parse them: if (L.cur().kind == TK_INDENT) { parseBlocks(n); } L.nextIf(TK_NEWLINE); } void IRParser::parseGraphInputs() { parseList('(', ',', ')', [&] { VarWithType v = parseVarWithType(); // If the name isn't valid, don't use it std::string uniq_name = Value::isValidName(v.name) ? v.name : ""; vmap[v.name] = g->addInput(uniq_name); vmap[v.name]->setType(v.type); }); } /** \brief Parse return statement. * * It should look like the following: * return (x : TypeX, y : TypeY, z, ...) */ void IRParser::parseReturnOperator() { L.expect(TK_RETURN); // Parse output names and types parseList('(', ',', ')', [&] { std::string var_name = parseVar(); g->registerOutput(findValueInVMap(var_name)); }); // Consume ending tokens if (L.cur().kind != TK_EOF) { L.expect(TK_NEWLINE); L.expect(TK_DEDENT); } } /** \brief Parse entire graph. * * It should look like the following: * graphName (input1, input2, ... inputN): * op1 * op2 * ... * opN * return (output1, output2, ... outputN) */ void IRParser::parse() { // Parse graph definition, it should look like the following: // graphName (input1, input2, ... inputN): std::string graphName = L.expect(TK_IDENT).text(); parseGraphInputs(); L.expect(':'); // After the definition we should have a list of statements, parse it: parseOperatorsList(g->block()); // The last statement should be return, which specifies graph outputs parseReturnOperator(); for (Node* n : deferred_tensor_value_initializations_) { auto type = n->output()->type()->expect(); auto tt = n->output()->type()->cast(); TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n); auto sizes = tt->sizes().concrete_sizes(); TORCH_INTERNAL_ASSERT(sizes); auto strides = tt->strides().concrete_sizes(); TORCH_INTERNAL_ASSERT(strides); auto device = tt->device(); TORCH_INTERNAL_ASSERT(device); auto dtype = tt->scalarType(); TORCH_INTERNAL_ASSERT(dtype); auto options = at::TensorOptions(*device).dtype(*dtype); auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options)); (void)t; } for (Node* n : deferred_empty_container_initializations_) { auto type = n->output()->type(); IValue val; if (type->kind() == TypeKind::ListType) { val = c10::impl::GenericList(type->containedType(0)); } else if (type->kind() == TypeKind::DictType) { val = c10::impl::GenericDict( type->containedType(0), type->containedType(1)); } n->ival_(attr::value, val); } } void IRParser::parseList( int begin, int sep, int end, const std::function& callback) { if (begin != TK_NOTHING) { L.expect(begin); } if (L.cur().kind != end) { do { callback(); } while (L.nextIf(sep)); } if (end != TK_NOTHING) { L.expect(end); } } Value* IRParser::findValueInVMap(const std::string& name) { if (!vmap.count(name)) { throw ErrorReport(L.cur().range) << "Cannot find a variable with name '" << name << "'"; } return vmap.at(name); } } // namespace torch::jit