#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "caffe2/core/net.h"
#include "caffe2/utils/proto_utils.h"

#include "compiler.h"
#include "parser.h"

namespace caffe2 {
namespace script {

namespace {

// Operator types whose arguments may carry nested NetDefs (the `n` field of
// an Argument). When a function is inlined, ops inside those nested nets
// must be renamed too (see DefCompiler::renameOp).
static std::unordered_set<std::string> ops_containing_nets = {
    "If",
    "While",
    "RecurrentNetwork",
};

// Record of a defined function: the NetDef it compiles to plus metadata.
// A function is either compiled from script source (`tree` is set) or
// registered as an extern NetDef (`tree` is null).
struct FunctionDefinition {
  // Script-defined function; `net_def` starts empty and is filled in by
  // DefCompiler::run(), which also populates `inputs`/`outputs`.
  explicit FunctionDefinition(Def tree)
      : tree(new Def(tree)), net_def(new NetDef()) {}

  // Extern function wrapping an already-built NetDef.
  explicit FunctionDefinition(std::unique_ptr<NetDef> def)
      : tree(nullptr), net_def(std::move(def)) {
    // we co-opt external_input/external_output to be the inputs/outputs of
    // this net as a function,
    // but we _don't_ keep them on the net created in the workspace
    // because they require the net to have valid inputs/outputs
    inputs.assign(
        net_def->external_input().begin(), net_def->external_input().end());
    outputs.assign(
        net_def->external_output().begin(), net_def->external_output().end());
    net_def->clear_external_output();
    net_def->clear_external_input();
  }

  // True when this entry came from defineExtern rather than script source.
  bool isExtern() const {
    return tree == nullptr;
  }

  std::unique_ptr<Def> tree;
  std::unique_ptr<NetDef> net_def;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

} // namespace

// Map from function name to its definition. Shared between all DefCompilers
// of a CompilationUnit so that later functions can call earlier ones.
using SymbolTable = std::unordered_map<std::string, FunctionDefinition>;

// Compiles one script function (a Def) into the NetDef of its
// FunctionDefinition.
// Control flow (If/While) emits into nested NetDefs. net_def_stack tracks
// which net is currently receiving ops, and cur() returns its top.
struct DefCompiler {
  DefCompiler(FunctionDefinition& def, SymbolTable& symbol_table)
      : def(def), net_def_stack({def.net_def.get()}), symbol_table(symbol_table) {}

  // Compiles def.tree into def.net_def.
  // Records the parameter and return names in def.inputs / def.outputs.
  // Each of those names is bound to itself in the environment.
  void run() {
    auto& tree = *def.tree;
    cur().set_name(tree.name().name());
    for (auto input : tree.params()) {
      auto& name = input.ident().name();
      map(name, name);
      def.inputs.push_back(name);
    }
    for (auto output : tree.returns()) {
      auto& name = output.ident().name();
      map(name, name);
      def.outputs.push_back(name);
    }
    emitStatements(tree.statements());
  }

  void emitExpressionStatement(TreeRef stmt) {
    // expression with no used outputs
    emit(stmt, {});
  }

  // Emits every statement of a block into cur().
  void emitStatements(const ListView<TreeRef>& statements) {
    for (auto stmt : statements) {
      switch (stmt->kind()) {
        case TK_IF:
          emitIf(If(stmt));
          break;
        case TK_WHILE:
          emitWhile(While(stmt));
          break;
        case TK_ASSIGN:
          emitAssignment(Assign(stmt));
          break;
        case TK_GLOBAL:
          // `global a, b` makes workspace blobs visible under their own name
          for (auto ident : stmt->trees()) {
            auto name = Ident(ident).name();
            map(name, name);
          }
          break;
        default:
          emitExpressionStatement(stmt);
          break;
      }
    }
  }

  // Binds script-level `name` to the NetDef blob `value`.
  void map(const std::string& name, const std::string& value) {
    env[name] = value;
  }

  // Returns the NetDef blob bound to `ident`.
  // Throws ErrorReport if the identifier is unbound.
  const std::string& lookup(const Ident& ident) {
    auto it = env.find(ident.name());
    if (it == env.end())
      throw ErrorReport(ident) << "undefined value " << ident.name();
    return it->second;
  }

  // Emits `lhs... = rhs` or a single-target reduction such as `a += b`.
  // The reduction is desugared to `a = a + b`.
  void emitAssignment(const Assign& stmt) {
    std::vector<std::string> outputs;
    for (auto lhs : stmt.lhs()) {
      std::string name = getLHS(lhs);
      // use of "_" gets renamed in Caffe2 graphs so that two uses
      // don't unintentionally interfere with each other
      if (name == "_") {
        name = fresh();
      }
      outputs.push_back(name);
    }
    if (stmt.reduction() != '=') {
      if (stmt.lhs().size() != 1) {
        throw ErrorReport(stmt)
            << "reductions are only allowed when there is a single variable "
            << "on the left-hand side.";
      }
      auto lhs = stmt.lhs()[0];
      auto expr =
          Compound::create(stmt.reduction(), stmt.range(), {lhs, stmt.rhs()});
      emit(expr, outputs);
    } else {
      emit(stmt.rhs(), outputs);
    }
    // Only plain identifiers get (re)bound; `a.b = ...` writes the blob
    // "a/b" directly and needs no environment entry.
    int i = 0;
    for (auto ident : stmt.lhs()) {
      if (ident->kind() == TK_IDENT)
        map(Ident(ident).name(), outputs.at(i));
      i++;
    }
  }

  // Emits a Caffe2 "If" op.
  // The condition is evaluated in the current net. The branches are compiled
  // into the then_net / else_net arguments.
  void emitIf(const If& stmt) {
    auto cond = getValue(stmt.cond());
    auto op = cur().add_op();
    op->set_type("If");
    op->add_input(cond);
    auto true_branch = op->add_arg();
    true_branch->set_name("then_net");
    auto nd = true_branch->mutable_n();
    net_def_stack.push_back(nd);
    emitStatements(stmt.trueBranch());
    net_def_stack.pop_back();
    if (stmt.falseBranch().size() > 0) {
      auto false_branch = op->add_arg();
      false_branch->set_name("else_net");
      auto nd = false_branch->mutable_n();
      net_def_stack.push_back(nd);
      emitStatements(stmt.falseBranch());
      net_def_stack.pop_back();
    }
  }

  // Emits a Caffe2 "While" op.
  // The condition is compiled into cond_net, which writes a fresh blob, and
  // the body goes into loop_net.
  void emitWhile(const While& stmt) {
    std::string loop_var = fresh();
    emitConst(0, loop_var, "i"); // it needs a definition before loop
    auto op = cur().add_op();
    op->set_type("While");
    auto cond = op->add_arg();
    cond->set_name("cond_net");
    auto cond_net = cond->mutable_n();
    net_def_stack.push_back(cond_net);
    emit(stmt.cond(), {loop_var});
    net_def_stack.pop_back();
    op->add_input(loop_var);
    auto body = op->add_arg();
    body->set_name("loop_net");
    auto body_net = body->mutable_n();
    net_def_stack.push_back(body_net);
    emitStatements(stmt.body());
    net_def_stack.pop_back();
  }

  // Returns the blob name an assignment target writes to.
  std::string getLHS(const TreeRef& tree) {
    switch (tree->kind()) {
      case TK_IDENT: {
        return Ident(tree).name();
      } break;
      case '.': {
        auto sel = Select(tree);
        std::string lhs = getValue(sel.value());
        // TODO: check whether this subname exists in object lhs
        return lhs + "/" + sel.selector().name();
      } break;
      default: {
        throw ErrorReport(tree)
            << "This expression cannot appear on the left-hand side of an assignment";
      } break;
    }
  }

  // Returns a blob holding the value of `tree`. Identifiers and selections
  // resolve to existing blobs. Any other expression is emitted into a fresh
  // temporary.
  std::string getValue(const TreeRef& tree) {
    switch (tree->kind()) {
      case TK_IDENT: {
        return lookup(Ident(tree));
      } break;
      case '.': {
        auto sel = Select(tree);
        std::string lhs = getValue(sel.value());
        // TODO: check whether this subname exists in object lhs
        return lhs + "/" + sel.selector().name();
      } break;
      default: {
        std::string name = fresh();
        emit(tree, {name});
        return name;
      } break;
    }
  }

  // Returns a compiler-unique name: `prefix` followed by a counter.
  std::string fresh(std::string prefix = "$t") {
    return prefix + caffe2::to_string(next_fresh++);
  }

  // Maps an expression token kind (and arity, for unary minus) to the
  // Caffe2 operator type that implements it.
  const char* operatorName(int kind, int ninputs) {
    switch (kind) {
      case '+':
        return "Add";
      case '-':
        if (ninputs == 1)
          return "Negative";
        else
          return "Sub";
      case '*':
        return "Mul";
      case '/':
        return "Div";
      case TK_NE:
        return "NE";
      case TK_EQ:
        return "EQ";
      case '<':
        return "LT";
      case '>':
        return "GT";
      case TK_LE:
        return "LE";
      case TK_GE:
        return "GE";
      case TK_IF_EXPR:
        return "Conditional";
      case TK_AND:
        return "And";
      case TK_OR:
        return "Or";
      case TK_NOT:
        return "Not";
      default:
        throw std::runtime_error("unknown kind " + caffe2::to_string(kind));
    }
  }

  // Fills a Caffe2 Argument from a script attribute `name=value`.
  // A constant sets f (suffix "f") or i. A list appends to floats or ints
  // per element.
  // Throws ErrorReport for any other value form instead of leaving the
  // argument without a value.
  void fillArg(Argument* arg, const Attribute& attr) {
    std::string name = attr.name().name();
    arg->set_name(name);
    auto value = attr.value();
    // TODO: handle non-float attributes
    switch (value->kind()) {
      case TK_CONST: {
        auto v = value->tree(0)->doubleValue();
        auto f = value->tree(1)->stringValue();
        if (f == "f")
          arg->set_f(v);
        else
          arg->set_i(v);
      } break;
      case TK_LIST:
        for (auto t : value->trees()) {
          auto v = t->tree(0)->doubleValue();
          auto f = t->tree(1)->stringValue();
          if (f == "f")
            arg->add_floats(v);
          else
            arg->add_ints(v);
        }
        break;
      default:
        throw ErrorReport(value)
            << "unsupported value for attribute '" << name << "'";
    }
  }

  // Evaluates each tree via getValue, in order, and returns the blob names.
  template <typename Trees>
  std::vector<std::string> getValues(const Trees& trees) {
    std::vector<std::string> result;
    for (const auto& tree : trees) {
      result.push_back(getValue(tree));
    }
    return result;
  }

  // Looks up `name` in rename_map and writes the result to `rename`.
  // Also handles sub-names: with an entry a => b, "a/foo/bar" becomes
  // "b/foo/bar".
  // Returns false (leaving `rename` untouched) when no mapping applies.
  bool renameLookup(
      const std::unordered_map<std::string, std::string>& rename_map,
      const std::string& name,
      std::string& rename) {
    // first look for name in the map directly
    auto it = rename_map.find(name);
    if (it != rename_map.end()) {
      rename = it->second;
      return true;
    }
    // otherwise if we have a rename entry like a => b and a name "a/foo/bar"
    // then replace it with "b/foo/bar"
    auto p = name.find('/');
    if (p == std::string::npos)
      return false;
    it = rename_map.find(name.substr(0, p));
    if (it != rename_map.end()) {
      rename = it->second + name.substr(p);
      return true;
    }
    return false;
  }

  // Rewrites the inputs and outputs of an inlined op, recursing into nets
  // nested in control-flow ops.
  // - Inputs must already be mapped. The exception is extern functions,
  //   which may read global blobs; those keep their own name.
  // - Outputs not yet mapped get `prefix` prepended and are recorded, so
  //   later ops see the same temporary.
  void renameOp(
      std::unordered_map<std::string, std::string>& rename_map,
      const Apply& apply,
      const std::string& prefix,
      bool isExtern,
      OperatorDef* new_op) {
    for (int i = 0; i < new_op->input_size(); i++) {
      const std::string& name = new_op->input(i);
      std::string renamed;
      bool defined = renameLookup(rename_map, name, renamed);
      if (!isExtern && !defined) {
        throw ErrorReport(apply)
            << " unexpected undefined name '" << name
            << "' while attempting to inline '" << apply.name().name() << "'";
      } else if (!defined) {
        // extern function using a global name, assign it an identity mapping
        // (and keep the input name instead of blanking it out)
        rename_map[name] = name;
        renamed = name;
      }
      new_op->set_input(i, renamed);
    }
    for (int i = 0; i < new_op->output_size(); i++) {
      const std::string& name = new_op->output(i);
      std::string renamed;
      if (!renameLookup(rename_map, name, renamed)) {
        renamed = prefix + name;
        rename_map[name] = renamed;
      }
      new_op->set_output(i, renamed);
    }
    // handle control flow inside the op as well
    if (ops_containing_nets.count(new_op->type()) > 0) {
      for (int i = 0; i < new_op->arg_size(); i++) {
        auto* arg = new_op->mutable_arg(i);
        if (arg->has_n()) {
          auto* n = arg->mutable_n();
          for (int j = 0; j < n->op_size(); j++) {
            renameOp(rename_map, apply, prefix, isExtern, n->mutable_op(j));
          }
        }
      }
    }
  }

  // True when the call carries `rename=0`. In that case the inlined
  // temporaries keep their original names instead of getting a
  // "func_name<N>/" prefix.
  bool hasBypassRename(const Apply& apply) {
    for (auto attr : apply.attributes()) {
      if (attr.name().name() == "rename") {
        if (attr.value()->kind() != TK_CONST) {
          throw ErrorReport(attr.value()) << "expected a single constant";
        }
        return attr.value()->tree(0)->doubleValue() == 0;
      }
    }
    return false;
  }

  // emit a function call by inlining the function's NetDef into our
  // net def, renaming temporaries func_name/orig_name
  // renaming only happens for values defined by the function
  // that are not marked outputs
  // inputs/outputs are passed by reference
  void emitFunctionCall(Apply& apply, const std::vector<std::string>& outputs) {
    std::string fname = apply.name().name();
    auto& fn = symbol_table.at(fname);
    // A function is registered in the symbol table before its body is
    // compiled. A self-call would therefore inline the very net we are
    // appending to, mutating fn.net_def->op() while iterating over it.
    if (&fn == &def) {
      throw ErrorReport(apply)
          << "recursive call to '" << fname << "' is not supported";
    }
    std::string prefix = fresh(fname) + "/";
    bool isExtern = fn.isExtern();
    auto inputs = getValues(apply.inputs());
    std::unordered_map<std::string, std::string> rename_map;
    if (inputs.size() != fn.inputs.size()) {
      throw ErrorReport(apply) << fname << " expected " << fn.inputs.size()
                               << " values but received " << inputs.size();
    }
    for (size_t i = 0; i < inputs.size(); i++) {
      rename_map[fn.inputs[i]] = inputs[i];
    }
    if (outputs.size() != fn.outputs.size()) {
      throw ErrorReport(apply) << fname << " expected " << fn.outputs.size()
                               << " values but received " << outputs.size();
    }
    for (size_t i = 0; i < outputs.size(); i++) {
      rename_map[fn.outputs[i]] = outputs[i];
    }
    // loop-invariant: decided once per call, not once per inlined op
    if (hasBypassRename(apply)) {
      prefix = "";
    }
    for (auto& op : fn.net_def->op()) {
      auto new_op = cur().add_op();
      new_op->CopyFrom(op);
      renameOp(rename_map, apply, prefix, isExtern, new_op);
    }
  }

  // Throws unless the caller asked for exactly `size` outputs.
  void expectOutputs(
      const TreeRef& tree,
      const std::vector<std::string>& outputs,
      size_t size) {
    if (outputs.size() != size) {
      throw ErrorReport(tree)
          << "expected operator to produce " << outputs.size()
          << " outputs but it produced " << size;
    }
  }

  // Checks the output count, then appends the first `size` outputs to `op`.
  void appendOutputs(
      const TreeRef& tree,
      OperatorDef* op,
      const std::vector<std::string>& outputs,
      size_t size) {
    expectOutputs(tree, outputs, size);
    for (size_t i = 0; i < size; i++) {
      op->add_output(outputs[i]);
    }
  }

  // Emits a call to a registered Caffe2 operator.
  // Input and output counts are validated against `schema` with friendly
  // messages, then the full schema check runs on the built OperatorDef.
  void emitOperator(
      const Apply& apply,
      const OpSchema* schema,
      const std::vector<std::string>& outputs) {
    // must be before add_op
    auto values = getValues(apply.inputs());
    const int num_inputs = static_cast<int>(values.size());
    if (num_inputs < schema->min_input() || num_inputs > schema->max_input()) {
      if (schema->min_input() == schema->max_input()) {
        throw ErrorReport(apply) << "operator expects " << schema->min_input()
                                 << " inputs but found " << num_inputs;
      } else {
        throw ErrorReport(apply)
            << "operator takes between " << schema->min_input() << " and "
            << schema->max_input() << " inputs but found " << num_inputs
            << ".";
      }
    }
    auto numActualOutputs = schema->CalculateOutput(num_inputs);
    if (numActualOutputs != kCannotComputeNumOutputs &&
        static_cast<int>(outputs.size()) != numActualOutputs) {
      throw ErrorReport(apply)
          << "operator produces " << numActualOutputs
          << " outputs but matched to " << outputs.size() << " outputs";
    }
    auto op = cur().add_op();
    op->set_type(apply.name().name());
    for (auto& v : values) {
      op->add_input(v);
    }
    // assume 1 output unless matched to more
    appendOutputs(apply, op, outputs, outputs.size());
    for (auto attribute : apply.attributes()) {
      fillArg(op->add_arg(), attribute);
    }
    // Ok, we checked the stuff where we can easily give a friendly error
    // message, now verify against the schema and report the error at the line
    if (!schema->Verify(*op)) {
      throw ErrorReport(apply) << "failed schema checking";
    }
  }

  // Emit an operation, writing results into 'outputs'.
  // This will _always_ compute something, unlike 'getValue' which simply
  // returns an already computed reference if possible.
  // So if 'tree' is an identifier or nested identifier (foo.bar)
  // this will cause it to be _copied_ into outputs.
  void emit(const TreeRef& tree, const std::vector<std::string>& outputs) {
    switch (tree->kind()) {
      case TK_IDENT:
      case '.': {
        auto op = cur().add_op();
        op->set_type("Copy");
        op->add_input(getValue(tree));
        appendOutputs(tree, op, outputs, 1);
      } break;
      case TK_NE:
      case TK_EQ:
      case '<':
      case '>':
      case TK_LE:
      case TK_GE:
      case '-':
      case '*':
      case '/':
      case '+':
      case TK_AND:
      case TK_OR:
      case TK_NOT:
      case TK_IF_EXPR: {
        // must be before add_op
        auto values = getValues(tree->trees());
        auto op = cur().add_op();
        op->set_type(operatorName(tree->kind(), tree->trees().size()));
        for (auto& v : values) {
          op->add_input(v);
        }
        appendOutputs(tree, op, outputs, 1);
        auto broadcast = op->add_arg();
        broadcast->set_name("broadcast");
        broadcast->set_i(1);
      } break;
      case TK_APPLY: {
        auto apply = Apply(tree);
        // Handle built-ins like zeros, ones, etc
        if (builtins.count(apply.name().name()) > 0) {
          builtins[apply.name().name()](this, apply, outputs);
          break;
        }
        if (symbol_table.count(apply.name().name()) > 0) {
          emitFunctionCall(apply, outputs);
          break;
        }
        auto schema = OpSchemaRegistry::Schema(apply.name().name());
        if (schema) {
          emitOperator(apply, schema, outputs);
          break;
        }
        throw ErrorReport(apply)
            << "attempting to call unknown operation or function '"
            << apply.name().name() << "'";
      } break;
      case TK_CAST: {
        auto cast = Cast(tree);
        auto c2type = getType(cast.type());
        auto input = getValue(cast.input());
        auto op = cur().add_op();
        op->set_type("Cast");
        op->add_input(input);
        appendOutputs(tree, op, outputs, 1);
        auto arg = op->add_arg();
        arg->set_name("to");
        arg->set_i(c2type);
      } break;
      case TK_CONST: {
        expectOutputs(tree, outputs, 1);
        emitConst(
            tree->tree(0)->doubleValue(),
            outputs[0],
            tree->tree(1)->stringValue());
      } break;
      case TK_GATHER: {
        const auto gather = Gather(tree);
        desugarAndEmitOperator(
            "Gather",
            gather.range(),
            {gather.value(), gather.indices()},
            outputs);
        break;
      }
      case TK_SLICE: {
        const auto slice = Slice(tree);
        desugarAndEmitOperator(
            "Slice",
            slice.range(),
            {slice.value(), slice.startOr(0), slice.endOr(-1)},
            outputs);
        break;
      }
      default:
        throw ErrorReport(tree) << "NYI: " << tree;
        break;
    }
  }

  // Desugars constructs that are syntactic sugar and emits the corresponding
  // operator invocation, e.g. tensor[indices] -> tensor.Gather(indices).
  void desugarAndEmitOperator(
      const std::string& operatorName,
      const SourceRange& range,
      TreeList&& inputs,
      const std::vector<std::string>& outputs) {
    const auto applyName = Ident::create(range, operatorName);
    const auto applyInputs =
        Compound::create(TK_LIST, range, std::move(inputs));
    const auto applyAttributes = Compound::create(TK_LIST, range, {});
    const auto apply =
        Apply::create(range, applyName, applyInputs, applyAttributes);
    const auto schema = OpSchemaRegistry::Schema(operatorName);
    // Report a missing schema explicitly. An assert would vanish under
    // NDEBUG and leave a null dereference in emitOperator.
    if (schema == nullptr) {
      throw ErrorReport(apply)
          << "attempting to call unknown operation or function '"
          << operatorName << "'";
    }
    emitOperator(Apply(apply), schema, outputs);
  }

  // Maps a script type token to the corresponding TensorProto data type.
  TensorProto_DataType getType(int type) {
    switch (type) {
      case TK_INT:
        return TensorProto_DataType_INT32;
      case TK_FLOAT:
        return TensorProto_DataType_FLOAT;
      case TK_LONG:
        return TensorProto_DataType_INT64;
      case TK_BOOL:
        return TensorProto_DataType_BOOL;
      default:
        throw std::runtime_error(
            "expected type token: " + caffe2::to_string(type));
    }
  }

  // Emits a ConstantFill of shape [1] writing `v` to `output`.
  // type_ident picks the dtype: "f" float, "LL" int64, "b" bool, "i" int32.
  OperatorDef* emitConst(
      double v,
      const std::string& output,
      const std::string& type_ident) {
    auto op = cur().add_op();
    op->set_type("ConstantFill");
    auto dtype = op->add_arg();
    dtype->set_name("dtype");
    auto value = op->add_arg();
    value->set_name("value");
    if (type_ident == "f") {
      dtype->set_i(TensorProto_DataType_FLOAT);
      value->set_f(v);
    } else if (type_ident == "LL") {
      dtype->set_i(TensorProto_DataType_INT64);
      value->set_i(v);
    } else if (type_ident == "b") {
      dtype->set_i(TensorProto_DataType_BOOL);
      value->set_i(v != 0);
    } else if (type_ident == "i") {
      dtype->set_i(TensorProto_DataType_INT32);
      value->set_i(v);
    } else {
      throw std::runtime_error("unknown type_ident " + type_ident);
    }
    auto shape = op->add_arg();
    shape->set_name("shape");
    shape->add_ints(1);
    op->add_output(output);
    return op;
  }

  // The net currently receiving emitted ops.
  NetDef& cur() {
    return *net_def_stack.back();
  }

  FunctionDefinition& def; // the def being constructed
  std::unordered_map<std::string, std::string>
      env; // map from name in Def to name in NetDef
  std::vector<NetDef*> net_def_stack;
  SymbolTable& symbol_table;
  int next_fresh = 0;

 private:
  // Built-ins zeros / ones / zeros_like / ones_like, lowered to ConstantFill.
  // zeros/ones take either a shape tensor input or a `shape` attribute.
  // The *_like variants take the shape of their single input.
  void emitFillOp(const Apply& apply, const std::vector<std::string>& outputs) {
    auto builtin_type = apply.name().name();
    auto values = getValues(apply.inputs());
    if (values.size() > 1) {
      throw ErrorReport(apply)
          << "Built-in " << builtin_type << " accepts 0 or 1 inputs.";
    }
    bool has_shape = false;
    for (const auto& attribute : apply.attributes()) {
      if (attribute.name().name() == "shape") {
        has_shape = true;
      } else {
        throw ErrorReport(apply)
            << "Unrecognized attribute " << attribute.name().name()
            << " for built-in " << builtin_type;
      }
    }
    if (builtin_type == "zeros" || builtin_type == "ones") {
      if ((values.size() != 1) && !has_shape) {
        throw ErrorReport(apply)
            << "Built-in " << builtin_type
            << " requires either 1 input or 1 shape attribute";
      }
    } else {
      // zeros_like or ones_like
      if (values.size() != 1) {
        throw ErrorReport(apply)
            << "Built-in " << builtin_type << " requires 1 input";
      }
    }

    auto op = cur().add_op();
    op->set_type("ConstantFill");
    if (values.size()) {
      op->add_input(values[0]);
      auto* input_as_shape = op->add_arg();
      input_as_shape->set_name("input_as_shape");
      if (builtin_type.find("_like") != std::string::npos) {
        // zeros_like, ones_like take the shape of the input as constant
        // tensor shape
        input_as_shape->set_i(0);
      } else {
        // zeros, ones take the values in the tensor as constant tensor
        // shape
        input_as_shape->set_i(1);
      }
    } else {
      // no input: the only accepted attribute is `shape`, and it is present
      fillArg(op->add_arg(), apply.attributes()[0]);
    }

    auto value = op->add_arg();
    value->set_name("value");
    if (builtin_type.find("ones") != std::string::npos) {
      value->set_f(1.0f);
    } else {
      value->set_f(0.0f);
    }
    appendOutputs(apply, op, outputs, 1);
  }
  // emitModule doesn't actually do anything except for allow
  // statements like a = Module() to register 'a' as a valid identifier
  // so that a.b = ... will work
  void emitModule(const Apply& apply, const std::vector<std::string>& outputs) {
    expectOutputs(apply, outputs, 1);
  }

  std::unordered_map<
      std::string,
      std::function<void(
          DefCompiler*,
          const Apply&,
          const std::vector<std::string>& outputs)>>
      builtins{{"zeros", &DefCompiler::emitFillOp},
               {"zeros_like", &DefCompiler::emitFillOp},
               {"ones", &DefCompiler::emitFillOp},
               {"ones_like", &DefCompiler::emitFillOp},
               {"Module", &DefCompiler::emitModule}};
};

struct CompilationUnitImpl {
  // Compiles `def` and registers it under its name.
  // If compilation fails, the partially built entry is removed again, so a
  // failed definition neither blocks a later retry nor leaves a broken net
  // behind for createNet.
  void defineFunction(const Def& def) {
    std::string name = def.name().name();
    if (functions.count(name) > 0) {
      throw ErrorReport(def) << name << " already defined.";
    }
    auto& fn = functions.emplace(name, FunctionDefinition(def)).first->second;
    try {
      DefCompiler c(fn, functions);
      c.run();
    } catch (...) {
      functions.erase(name);
      throw;
    }
  }

  // Parses `str` and defines every function it contains, in order.
  void define(const std::string& str) {
    Parser p(str);
    while (p.lexer().cur().kind != TK_EOF) {
      defineFunction(Def(p.parseFunction()));
    }
  }

  // Instantiates the NetDef of function `str` in workspace `ws`.
  std::unique_ptr<NetBase> createNet(Workspace* ws, const std::string& str) {
    auto it = functions.find(str);
    if (it == functions.end())
      throw ErrorReport() << "undefined function: " << str << "\n";
    return caffe2::CreateNet(*it->second.net_def, ws);
  }

  // Registers an existing NetDef as a callable function. Its
  // external_input/external_output become the function's inputs/outputs.
  void defineExtern(const std::string& name, std::unique_ptr<NetDef> net_def) {
    // TODO: unify extern and function namespaces
    if (functions.count(name) > 0) {
      throw ErrorReport() << "function '" << name << "' already defined.";
    }
    functions.emplace(name, FunctionDefinition(std::move(net_def)));
  }

  // Debug text form of a function's NetDef (throws std::out_of_range when
  // the function is unknown).
  std::string getProto(const std::string& functionName) {
    return functions.at(functionName).net_def->DebugString();
  }

 private:
  friend struct DefCompiler;
  SymbolTable functions;
};

CompilationUnit::CompilationUnit() : pImpl(new CompilationUnitImpl()) {}

void CompilationUnit::define(const std::string& str) {
  return pImpl->define(str);
}

void CompilationUnit::defineExtern(
    const std::string& name,
    std::unique_ptr<NetDef> nd) {
  pImpl->defineExtern(name, std::move(nd));
}

std::unique_ptr<NetBase> CompilationUnit::createNet(
    Workspace* ws,
    const std::string& str) {
  return pImpl->createNet(ws, str);
}

std::string CompilationUnit::getProto(const std::string& functionName) const {
  return pImpl->getProto(functionName);
}

CompilationUnit::~CompilationUnit() {}

} // namespace script
} // namespace caffe2