#include #include #include #include #include #include #include #include #include #include namespace torch::nativert { namespace { // Workaround for MSVC bug: "std" ambiguous symbol. template constexpr bool is_same_v = std::is_same_v; bool isBlank(char n) { return std::isspace(n); } size_t consumeWhitespaceImpl(std::string_view source, size_t curPos) { while (isBlank(source.at(curPos))) { curPos++; } return curPos; } size_t expectImpl( std::string_view source, std::string_view expected, size_t curPos) { curPos = consumeWhitespaceImpl(source, curPos); const auto actual = source.substr(curPos, expected.size()); TORCH_CHECK( expected == actual, fmt::format( "Parser error: expected '{}' at position {}, but found '{}'.", expected, curPos, actual)); curPos += expected.size(); return curPos; } size_t expectImpl(std::string_view source, char expected, size_t curPos) { curPos = consumeWhitespaceImpl(source, curPos); while (isBlank(source.at(curPos))) { curPos++; } TORCH_CHECK( expected == source[curPos], "Parser error: expected '{}' at position {}, but found '{}'.", expected, curPos, source[curPos]); curPos++; return curPos; } } // namespace bool operator==(const Type& left, const Type& right) { if (left.kind() != right.kind()) { return false; } if (std::holds_alternative(left.kind_) && std::holds_alternative(right.kind_)) { return std::get(left.kind_).classFqn == std::get(right.kind_).classFqn; } return true; } Graph::Graph() : insertBefore_(nodes_.end()), inputNode_(insertNode("prim.Input", {})), outputNode_(insertNode("prim.Output", {})) { // Set the insertion point to append to the graph insertBefore_ = nodes_.iterator_to(*outputNode_); } std::string Graph::getUniqueValueName() { auto name = fmt::format("v{}", uniqueValueName_); while (values_.find(name) != values_.end()) { name = fmt::format("v{}", uniqueValueName_++); } return name; } // If `name` is null, create a unique value name Value* Graph::addValue( const std::optional& name, const Type& type, Node* node) { const auto valueName = name.value_or(getUniqueValueName()); ValueId valueId = getNextValueId(); const auto [it, success] = values_.insert( {valueName, std::make_unique(valueId, valueName, type, node)}); TORCH_CHECK( success, fmt::format( "Tried to create Value with name: '{}', but it already existed", valueName)); return it->second.get(); } Value* Graph::addInput(std::string_view name, const Type& type) { return inputNode_->addOutput(name, type); } void Graph::addInput() { inputNode_->addOutput(); } Value* Graph::addOutput(Value* v) { outputNode_->addInput({std::string(v->name()), v}); return v; } void Graph::addConstantOutput(Constant c) { constantOutputs_.push_back(std::move(c)); } // Create a node without inserting it into the execution graph. Node* Graph::createNode( std::string target, std::vector inputs, std::unordered_map metadata) { auto& node = nodesOwner_.emplace_back(std::make_unique( this, std::move(target), std::move(inputs), std::move(metadata))); return node.get(); } Node* Graph::insertBefore(Node* toInsert, Node* insertionPoint) { TORCH_CHECK(insertionPoint != inputNode_, "can't insert before prim.Input"); TORCH_CHECK( !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); TORCH_CHECK( insertionPoint->is_linked(), "expected node to be linked: ", *insertionPoint); auto it = nodes_.insert(nodes_.iterator_to(*insertionPoint), *toInsert); return &*it; } Node* Graph::insert(Node* toInsert) { TORCH_CHECK( !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); nodes_.insert(insertBefore_, *toInsert); return toInsert; } Node* Graph::insertAfter(Node* toInsert, Node* insertionPoint) { TORCH_CHECK(insertionPoint != outputNode_, "can't insert after prim.Output"); TORCH_CHECK( !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert); TORCH_CHECK( insertionPoint->is_linked(), "expected node to be linked: ", *insertionPoint); auto insertIt = nodes_.iterator_to(*insertionPoint); // Increment once because we want to insert after the insertion point ++insertIt; auto it = nodes_.insert(insertIt, *toInsert); return &*it; } Node* Graph::insertNode( std::string target, std::vector inputs, std::unordered_map metadata) { auto node = createNode(std::move(target), std::move(inputs), std::move(metadata)); nodes_.insert(insertBefore_, *node); return node; } std::ostream& operator<<(std::ostream& out, const Type& ty) { std::visit( [&out](auto&& arg) { using T = std::decay_t; if constexpr (is_same_v) { switch (arg) { case Type::Kind::None: out << "None"; break; case Type::Kind::Tensor: out << "Tensor"; break; case Type::Kind::TensorList: out << "TensorList"; break; case Type::Kind::OptionalTensorList: out << "OptionalTensorList"; break; case Type::Kind::SymInt: out << "SymInt"; break; case Type::Kind::SymFloat: out << "SymFloat"; break; case Type::Kind::SymIntList: out << "SymIntList"; break; case Type::Kind::CustomObj: out << "CustomObj"; break; default: TORCH_CHECK(false, "Unhandled type"); } } else if constexpr (is_same_v) { out << "CustomObj: " << arg.classFqn; } }, ty.kind_); return out; } const NamedArgument* Node::tryGetInput(std::string_view name) const { // Just do a scan over the inputs. We expect there to always be a very small // number of elements, so it shouldn't be slow. This allows us to avoid a // second datastructure for lookups. // Drop a debug check here, just to make sure :) TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs_.size() < 1000); for (const auto& input : inputs_) { if (input.name == name) { return &input; } } return nullptr; } const NamedArgument& Node::getInput(std::string_view name) const { const auto ret = tryGetInput(name); if (ret == nullptr) { TORCH_CHECK( false, fmt::format( "Expected input '{}' on node: '{}' to exist, but it does not.", name, fmt::streamed(*this))); } return *ret; } const Attribute* Node::tryGetAttribute(std::string_view name) const { // Just do a scan over the inputs. We expect there to always be a very small // number of elements, so it shouldn't be slow. This allows us to avoid a // second datastructure for lookups. // Drop a debug check here, just to make sure :) TORCH_INTERNAL_ASSERT_DEBUG_ONLY(attributes_.size() < 1000); for (const auto& attribute : attributes_) { if (attribute.name == name) { return &attribute; } } return nullptr; } const Attribute& Node::getAttribute(std::string_view name) const { const auto ret = tryGetAttribute(name); if (ret == nullptr) { TORCH_CHECK( false, fmt::format( "Expected attribute '{}' on node: '{}' to exist, but it does not.", name, fmt::streamed(*this))); } return *ret; } void Node::applyDevicePlacement(const Placement& placement) { for (auto& attribute : attributes_) { if (std::holds_alternative(attribute.value)) { auto device = std::get(attribute.value); auto targetDevice = placement.getMappedDevice(std::get(attribute.value)); if (!isSameDevice(targetDevice, device)) { LOG(INFO) << "Overriding " << device.str() << " to " << targetDevice.str() << " for node " << *this; attribute.value = targetDevice; } } } } Node* Node::next() { return owningGraph()->nodeAfter(this); } const Node* Node::next() const { return owningGraph()->nodeAfter(this); } Node* Node::prev() { return owningGraph()->nodeBefore(this); } const Node* Node::prev() const { return owningGraph()->nodeBefore(this); } bool Node::isBefore(const Node* n) const { if (this == n) { return false; } for (const Node* cursor = this->next(); cursor != nullptr; cursor = cursor->next()) { if (cursor == n) { return true; } } // Reached the end without finding n return false; } std::vector Node::producers() const { std::vector ret; if (this->prev() == nullptr /* prim.Input */) { return ret; } if (this->next() == nullptr /* prim.Output */) { for (auto& node : owningGraph_->nodes()) { if (node.next() == nullptr /* prim.Output */ || node.prev() == nullptr /* prim.Input */) { continue; } for (auto* dep : node.users()) { if (dep == this /* prim.Output */) { ret.push_back(&node); } } } } else { std::unordered_set seen; for (const auto& input : inputs()) { auto* n = input.value->producer(); if (n == nullptr) { continue; } if (const auto [_, inserted] = seen.insert(n); inserted) { ret.push_back(n); } } if (ret.empty()) { ret.push_back(owningGraph_->inputNode()); } } return ret; } std::vector Node::users() const { std::vector ret; if (this->next() == nullptr /* prim.Output */) { return ret; } if (this->prev() == nullptr /* prim.Input */) { for (auto& node : owningGraph_->nodes()) { if (node.prev() == nullptr /* prim.Input */ || node.next() == nullptr /* prim.Output */) { continue; } for (auto* dep : node.producers()) { if (dep == this /* prim.Input */) { ret.push_back(&node); } } } } else { std::unordered_set seen; for (const auto* output : outputs()) { for (auto* n : output->users()) { if (const auto [_, inserted] = seen.insert(n); inserted) { ret.push_back(n); } } } if (ret.empty()) { ret.push_back(owningGraph_->outputNode()); } } return ret; } Node* Graph::createListPack(std::vector inputs, const Type& inputType) { std::vector nodeInputs; nodeInputs.reserve(inputs.size()); for (auto [i, input] : c10::enumerate(inputs)) { nodeInputs.push_back({fmt::format("l{}", i), input}); } // Create a new named value for this auto name = getUniqueValueName(); auto node = createNode("prim.ListPack", std::move(nodeInputs)); // Make sure all inputs are the same type for (auto& input : inputs) { TORCH_CHECK(input->type() == inputType); } if (inputType == Type::Kind::Tensor) { node->addOutput(name, Type::Kind::TensorList); } else if (inputType == Type::Kind::SymInt) { node->addOutput(name, Type::Kind::SymIntList); } return node; } Node* Graph::createOptionalListPack(std::vector inputs) { std::vector nodeInputs; nodeInputs.reserve(inputs.size()); for (auto [i, input] : c10::enumerate(inputs)) { nodeInputs.push_back({fmt::format("l{}", i), input}); } // Create a new named value for this auto name = getUniqueValueName(); auto node = createNode("prim.ListPack", std::move(nodeInputs)); // Make sure all inputs are either None or Tensor for (auto& input : inputs) { TORCH_CHECK( input->type() == Type::Kind::None || input->type() == Type::Kind::Tensor); } node->addOutput(name, Type::Kind::OptionalTensorList); return node; } Value* Graph::createConstantSymIntValue(int value) { auto valueName = getUniqueValueName(); ValueId valueId = getNextValueId(); const auto [it, success] = values_.insert( {valueName, std::make_unique( valueId, valueName, Type::Kind::SymInt, nullptr)}); TORCH_CHECK( success, fmt::format( "Tried to create constant SymInt Value with name: '{}', but it already existed", valueName)); constantSymIntValues_[valueId] = value; return it->second.get(); } Value* Graph::getValue(std::string_view name) const { // TODO: can eliminate this string copy by enabling heterogeneous lookup for // the container return values_.at(std::string(name)).get(); } Value* Graph::tryGetValue(std::string_view name) const { // TODO: can eliminate this string copy by enabling heterogeneous lookup for // the container const auto key = std::string(name); if (values_.find(key) != values_.end()) { return values_.at(key).get(); } return nullptr; } void Graph::renumberValues() { std::vector currentValues; currentValues.reserve(values_.size()); for (auto& kv : values_) { currentValues.push_back(kv.second.get()); } // Sort values in creation order (by value ids) std::sort(currentValues.begin(), currentValues.end(), [](Value* a, Value* b) { return a->id() < b->id(); }); // Build a new id map with all ids < values_.size() std::unordered_map oldToNew; oldToNew.reserve(currentValues.size()); ValueId newId = 0; for (Value* v : currentValues) { oldToNew[v->id()] = newId; v->setId(newId); newId++; } std::unordered_map newSymIntMap; for (auto& [oldId, symIntVal] : constantSymIntValues_) { auto it = oldToNew.find(oldId); if (it != oldToNew.end()) { ValueId updatedId = it->second; newSymIntMap[updatedId] = symIntVal; } } constantSymIntValues_ = std::move(newSymIntMap); uniqueValueId_ = newId; } bool Graph::cleanupDeadNodes() { std::unordered_set visited; std::vector visitStack; // Mark reachable nodes from output visitStack.push_back(outputNode_); visited.insert(outputNode_); while (!visitStack.empty()) { const Node* current = visitStack.back(); visitStack.pop_back(); for (auto& namedArg : current->inputs()) { Value* val = namedArg.value; Node* producer = val->producer(); if (!producer) { continue; } if (!visited.count(producer)) { visited.insert(producer); visitStack.push_back(producer); } } } // Remove all nodes not in visited (other than input/outputs) std::vector toRemove; for (auto& n : nodes()) { if (n.target() == "prim.Input" || n.target() == "prim.Output" || visited.count(&n)) { continue; } toRemove.push_back(&n); } const bool mutated = !toRemove.empty(); // Remove nodes in reverse order to handle input/output dependencies for (auto it = toRemove.rbegin(); it != toRemove.rend(); ++it) { removeNode(*it); } renumberValues(); lint(); return mutated; } void Graph::lint() const { // Check that every value has a producer marked. for (const auto& [name, value] : values_) { // Some constant symint and None don't have producer nodes if (value->type().kind() != Type::Kind::SymInt && value->type().kind() != Type::Kind::None) { TORCH_CHECK(value->isFolded() || value->producer() != nullptr); } } for (const auto& node : nodes()) { TORCH_CHECK(node.owningGraph() == this); } // Check that every list type is either produced by a prim.ListPack or // immediately consumed by a prim.ListUnpack. We make use of this invariant // to retrieve list elements in `getListElements`. for (const auto& [_, value] : values_) { if (value->type().kind() != Type::Kind::TensorList) { continue; } const bool producedByListPack = value->producer(/* resolve_folded = */ true)->target() == "prim.ListPack"; const bool consumedByListUnpack = value->users().size() == 1 && value->users()[0]->target() == "prim.ListUnpack"; TORCH_CHECK(producedByListPack || consumedByListUnpack); } auto getNames = [](const auto& values) { c10::FastSet names; for (const auto* value : values) { if (value) { names.emplace(value->name()); } } return names; }; signature_.lint(getNames(inputs()), getNames(outputs())); } void Graph::finalize() { // build userOutputs_ view userOutputs_.clear(); size_t constantIndex = 0; for (auto& outputName : signature_.userOutputs()) { if (outputName.has_value()) { userOutputs_.emplace_back(getValue(*outputName)); } else { if (constantIndex < constantOutputs_.size()) { userOutputs_.emplace_back(std::move(constantOutputs_[constantIndex])); constantIndex++; } else { TORCH_CHECK(false, "No more constant outputs available"); } } } } namespace { // Scan through a node's inputs, replacing ALL instances of `old` with // `replacement`. Returns true if a replacement occurred, otherwise false. bool replace(Node* node, Value* old, Value* replacement) { bool replacementOccurred = false; for (auto& input : node->inputs()) { if (input.value == old) { input.value = replacement; replacementOccurred = true; } } return replacementOccurred; } } // namespace void Graph::replaceAllUses(Value* old, Value* replacement) { for (auto user : old->users()) { // Find this use in the input list and replace it auto replaced = replace(user, old, replacement); TORCH_CHECK(replaced); replacement->addUser(user); } old->eraseAllUsers(); signature_.replaceAllUses(old->name(), replacement->name()); } void Graph::replaceAllUsesAfterNode( Value* old, Value* replacement, Node* afterThis) { auto it = nodes_.iterator_to(*afterThis); // Don't search `afterThis` ++it; // Scan through all node inputs linearly and replace uses for (; it != nodes_.end(); ++it) { Node* node = &*it; const bool replaced = replace(node, old, replacement); if (replaced) { old->eraseUser(node); replacement->addUser(node); } } signature_.replaceAllUses(old->name(), replacement->name()); } void Graph::applyDevicePlacement(const Placement& placement) { TORCH_CHECK( !placementApplied_, "placement has been applied to the graph! placement must be applied once and once only."); placementApplied_ = true; // inplace override node's device-typed attributes according to placement for (auto& node : nodes_) { node.applyDevicePlacement(placement); } // inplace override weightMeta_'s device according to placement for (auto& [_, weightMeta] : weightsMeta_) { weightMeta.applyDevicePlacement(placement); } // inplace override tensorValuesMeta_'s device according to placement for (auto& [_, tensorMeta] : tensorValuesMeta_) { tensorMeta.applyDevicePlacement(placement); } } Node* Graph::nodeAfter(Node* n) { TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } auto it = nodes_.iterator_to(*n); return &*(++it); } const Node* Graph::nodeAfter(const Node* n) const { TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } auto it = nodes_.iterator_to(*n); return &*(++it); } Node* Graph::nodeBefore(Node* n) { TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } auto it = nodes_.iterator_to(*n); return &*(--it); } const Node* Graph::nodeBefore(const Node* n) const { TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } auto it = nodes_.iterator_to(*n); return &*(--it); } void Graph::removeNode(Node* n) { TORCH_CHECK(n->owningGraph() == this, "Node does not belong to this graph!"); for (auto* outputVal : n->outputs()) { TORCH_CHECK( outputVal->users().empty(), "Trying to erase a node that still has users: ", outputVal->name()); outputVal->eraseAllUsers(); removeValue(outputVal); } for (const auto& input : n->inputs()) { input.value->eraseUser(n); } TORCH_CHECK(n->is_linked(), "Node is not linked to the graph!"); n->unlink(); auto it = std::find_if( nodesOwner_.begin(), nodesOwner_.end(), [n](const std::unique_ptr& ptr) { return ptr.get() == n; }); TORCH_CHECK(it != nodesOwner_.end(), "Node not found in nodesOwner_!"); nodesOwner_.erase(it); } void Graph::removeValue(Value* value) { // TODO: assuming not removing from constantSymIntValues_ TORCH_CHECK(value->users().empty(), "Cannot erase a value with users."); auto it = values_.find(std::string(value->name())); TORCH_CHECK( it != values_.end(), "Attempted to erase a value not in graph ", value->name()); values_.erase(it); } std::vector Graph::insertGraph( const Graph& subgraph, std::vector inputs, std::unordered_map& valueMap) { TORCH_CHECK(subgraph.inputs().size() == inputs.size(), "Input size mismatch"); for (auto i : c10::irange(subgraph.inputs().size())) { valueMap[subgraph.inputs()[i]] = inputs[i]; } // Clone each node from subgraph for (const auto& n : subgraph.nodes()) { if (n.target() == "prim.Input" || n.target() == "prim.Output") { continue; } std::vector clonedInputs; auto inputs = n.inputs(); clonedInputs.reserve(inputs.size()); for (auto& inp : inputs) { auto it = valueMap.find(inp.value); TORCH_CHECK(it != valueMap.end(), "Missing input value in subgraph"); clonedInputs.push_back({inp.name, it->second}); } Node* newNode = insertNode( std::string(n.target()), std::move(clonedInputs), n.metadata()); for (const auto& attr : n.attributes()) { Attribute newAttr; newAttr.name = attr.name; std::visit( [&](auto&& val) -> void { // Workaround for MSVC bug: "std" ambiguous symbol. using std::unique_ptr; using std::move; using T = std::decay_t; if constexpr (is_same_v>) { LOG(ERROR) << "Graph attributes are not supported yet. Skipping attribute: " << attr.name; } else { newAttr.value = val; #ifdef __clang__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunknown-warning-option" #pragma GCC diagnostic ignored "-Wunqualified-std-cast-call" #endif newNode->addAttribute(move(newAttr)); #ifdef __clang__ #pragma GCC diagnostic pop #endif } }, attr.value); } for (const auto* outVal : n.outputs()) { const auto& uniqueName = getUniqueValueName(); Value* newOut = newNode->addOutput(uniqueName, outVal->type()); valueMap[outVal] = newOut; } } auto subgraphOutputs = subgraph.outputs(); std::vector outputValues; outputValues.reserve(subgraphOutputs.size()); for (auto* outputValue : subgraphOutputs) { outputValues.emplace_back(valueMap[outputValue]); } lint(); return outputValues; } Node::Node( Graph* owningGraph, std::string target, std::vector inputs, std::unordered_map metadata) : owningGraph_(owningGraph), target_(std::move(target)), inputs_(std::move(inputs)), metadata_(std::move(metadata)) { for (const auto& input : inputs_) { input.value->addUser(this); } } Value* Node::addInput(NamedArgument input) { inputs_.push_back(std::move(input)); auto val = inputs_.back().value; val->addUser(this); return val; } void Node::addInputs(const std::vector& inputs) { for (const auto& input : inputs) { addInput(input); } } void Node::addAttribute(Attribute attr) { attributes_.push_back(std::move(attr)); } void Node::addOutput() { outputs_.push_back(nullptr); } Value* Node::addOutput(const Type& type) { TORCH_CHECK(type == Type::Kind::None); Value* v = owningGraph_->addValue(std::nullopt, type, this); outputs_.push_back(v); return v; } Value* Node::addOutput(std::string_view name, const Type& type) { Value* v = owningGraph_->addValue(std::string(name), type, this); outputs_.push_back(v); return v; } void Node::destroy() { owningGraph_->removeNode(this); } void Value::addUser(Node* node) { for (const auto* user : users_) { if (user == node) { return; } } users_.push_back(node); } void Value::eraseUser(Node* node) { users_.erase( std::remove_if( users_.begin(), users_.end(), [&](Node* el) { return el == node; }), users_.end()); } std::vector Value::getListElements() const { std::vector ret; if (auto p = producer(); p && p->target() == "prim.ListPack") { for (const auto& tv : p->inputs()) { ret.push_back(tv.value); } } else { TORCH_CHECK(users().size() == 1); const auto listUnpack = users()[0]; TORCH_CHECK(listUnpack->target() == "prim.ListUnpack"); for (const auto v : listUnpack->outputs()) { ret.push_back(v); } } return ret; } template [[maybe_unused]] inline constexpr bool AlwaysFalse = false; c10::IValue constantToIValue(const Constant& constant) { // Workaround for MSVC bug: "std" ambiguous symbol. using std::string; using std::unique_ptr; using std::vector; return std::visit( [](auto&& arg) -> c10::IValue { using T = std::decay_t; if constexpr (is_same_v) { return c10::IValue(); } else if constexpr (std::is_convertible_v) { return arg; } else if constexpr (is_same_v>) { TORCH_CHECK( false, "subgraph arguments cannot be turned into ivalues!"); } else { static_assert(AlwaysFalse, "non-exhaustive visitor!"); } }, constant); } namespace { template [[maybe_unused]] inline constexpr bool always_false_v = false; void printDouble(std::ostream& out, double arg) { fmt::print(out, "{}", arg); } template std::ostream& printList( std::ostream& out, bool encloseInSquareBrackets, const T& list, F formatter) { if (encloseInSquareBrackets) { out << '['; } for (const auto& [idx, el] : c10::enumerate(list)) { if (idx > 0) { out << ", "; } formatter(out, el); } if (encloseInSquareBrackets) { out << ']'; } return out; } std::ostream& operator<<(std::ostream& out, const Constant& constant) { // Workaround for MSVC bug: "std" ambiguous symbol. using std::quoted; using std::string; using std::unique_ptr; using std::vector; std::visit( [&](auto&& arg) { using T = std::decay_t; if constexpr (is_same_v) { out << "None"; } else if constexpr (is_same_v || is_same_v) { out << arg; } else if constexpr ( is_same_v> || is_same_v>) { out << fmt::format("{}", fmt::streamed(arg)); } else if constexpr (is_same_v) { printDouble(out, arg); } else if constexpr (is_same_v>) { printList(out, true, arg, printDouble); } else if constexpr (is_same_v) { out << quoted(arg); } else if constexpr (is_same_v) { out << kScalarTypePrefix << arg; } else if constexpr (is_same_v) { out << kMemoryFormatPrefix << arg; } else if constexpr (is_same_v) { out << kLayoutPrefix << arg; } else if constexpr (is_same_v) { out << kDevicePrefix << "{" << arg << "}"; } else if constexpr (is_same_v>) { out << fmt::format("[{}]", fmt::join(arg, ",")); } else if constexpr (is_same_v>) { out << fmt::format(""); VLOG(0) << "Subgraph pretty print is not implemented"; } else { static_assert(always_false_v, "non-exhaustive visitor!"); } }, constant); return out; } void printValue(std::ostream& out, const Value* v) { if (!v) { out << ""; return; } out << *v; } void printNamedArgument(std::ostream& out, const NamedArgument& nv) { out << nv.name << "=" << *nv.value; } void printAttribute(std::ostream& out, const Attribute& nv) { out << nv.name << "=" << nv.value; } } // namespace std::ostream& operator<<(std::ostream& out, const Value& v) { out << "%" << v.name(); // If a list, distinguish it by adding a [] // Looks like %my_list[] if (v.type() == Type::Kind::TensorList) { out << "[]"; } return out; } std::ostream& operator<<(std::ostream& out, const Node& node) { // special casing for inputs and outputs if (node.target() == "prim.Input") { out << "graph("; printList(out, false, node.outputs(), printValue); out << "):"; return out; } if (node.target() == "prim.Output") { out << "return("; printList(out, false, node.inputs(), [](std::ostream& out, const auto& nv) { out << *nv.value; }); out << ")"; return out; } printList(out, false, node.outputs_, printValue); out << " = "; out << node.target_ << "("; printList(out, false, node.inputs_, printNamedArgument); if (!node.inputs_.empty() && !node.attributes_.empty()) { // Emit a connective ',' between inputs and attributes. out << ", "; } printList(out, false, node.attributes_, printAttribute); out << ")"; return out; } std::ostream& operator<<(std::ostream& out, const Graph& graph) { for (const auto& node : graph.nodes_) { out << node << "\n"; } return out; } c10::Device convertDevice(std::string_view symbol) { // Symbol looks like `Device{cuda:1}` const auto typeStart = symbol.find('{') + 1; TORCH_CHECK(typeStart < symbol.size()); const auto typeEnd = symbol.find(':'); TORCH_CHECK(typeEnd != std::string_view::npos); const auto type = symbol.substr(typeStart, typeEnd - typeStart); const auto indexStart = typeEnd + 1; TORCH_CHECK(indexStart < symbol.size()); const auto indexEnd = symbol.find('}'); TORCH_CHECK(indexEnd != std::string_view::npos); const auto index = symbol.substr(indexStart, indexEnd - indexStart); c10::Device device((std::string(type))); auto indexValue = c10::tryToNumber(std::string{index}); TORCH_CHECK(indexValue.has_value(), "Invalid device index format"); int64_t deviceIndex = indexValue.value(); TORCH_CHECK( deviceIndex >= std::numeric_limits::min() && deviceIndex <= std::numeric_limits::max(), "Device index out of range for int8_t"); device.set_index(static_cast(deviceIndex)); return device; } Constant convertAtomicConstant(std::string_view symbol) { if (c10::starts_with(symbol, "\"")) { // chop off the outer quotes and return the string TORCH_CHECK(symbol.size() >= 2); symbol.remove_prefix(1); symbol.remove_suffix(1); return std::string(symbol); } else if (symbol == "None") { return None(); } else if (symbol == "true") { return true; } else if (symbol == "false") { return false; } else if (c10::starts_with(symbol, kMemoryFormatPrefix)) { torch::_export::MemoryFormat value = torch::_export::MemoryFormat::Unknown; symbol.remove_prefix(kMemoryFormatPrefix.length()); torch::_export::parseEnum(symbol, value); return convertJsonMemoryFormat(value); } else if (c10::starts_with(symbol, kLayoutPrefix)) { torch::_export::Layout value = torch::_export::Layout::Unknown; symbol.remove_prefix(kLayoutPrefix.length()); torch::_export::parseEnum(symbol, value); return convertJsonLayout(value); } else if (c10::starts_with(symbol, kDevicePrefix)) { return convertDevice(symbol); } else if (c10::starts_with(symbol, kScalarTypePrefix)) { torch::_export::ScalarType value = torch::_export::ScalarType::UNKNOWN; symbol.remove_prefix(kScalarTypePrefix.length()); torch::_export::parseEnum(symbol, value); return convertJsonScalarType(value); } // match number // We need to disambiguate between int and float constants const auto maybeInt = c10::tryToNumber(std::string{symbol}); // Libraries may happily convert "5.0" to an int 5, but we want that to // become a float. So add an extra check for whether a '.' is in the string // to guard against that. bool hasDecimalSeparator = symbol.find('.') != std::string_view::npos; if (maybeInt.has_value() && !hasDecimalSeparator) { return maybeInt.value(); } const auto maybeDouble = c10::tryToNumber(std::string{symbol}); if (maybeDouble.has_value()) { return maybeDouble.value(); } TORCH_CHECK(false, "unhandled symbol: ", symbol); } Constant convertListConstant(std::string_view source) { std::vector values; size_t curPos = 0; Constant type = None(); // This basically the same as parseValueList, it's probably better to refactor curPos = expectImpl(source, '[', curPos); while (true) { curPos = consumeWhitespaceImpl(source, curPos); size_t start = curPos; while (source.at(curPos) != ',' && source.at(curPos) != ']') { curPos++; } auto symbol = source.substr(start, curPos - start); auto val = convertAtomicConstant(symbol); if (std::holds_alternative(type)) { // First time around; initialize our type sentinel with the first value. // We will use this on subsequent iterations to check that all types are // the same. if (auto intPtr = std::get_if(&val)) { type = *intPtr; } else if (auto doublePtr = std::get_if(&val)) { type = *doublePtr; } else if (auto boolPtr = std::get_if(&val)) { type = *boolPtr; } else { TORCH_CHECK(false, "constant lists only support int, float, bool"); } } else { TORCH_CHECK( type.index() == val.index(), "lists must have all the same type"); } values.push_back(std::move(val)); if (source.at(curPos) == ']') { break; } curPos = expectImpl(source, ',', curPos); } expectImpl(source, ']', curPos); // Some annoying unwrapping // std::vector> --> // Constant> // Do it the dumb way. if (std::holds_alternative(type)) { std::vector inner; inner.reserve(values.size()); for (const auto& el : values) { inner.push_back(std::get(el)); } return inner; } else if (std::holds_alternative(type)) { std::vector inner; inner.reserve(values.size()); for (const auto& el : values) { inner.push_back(std::get(el)); } return inner; } else if (std::holds_alternative(type)) { std::vector inner; inner.reserve(values.size()); for (const auto& el : values) { inner.push_back(std::get(el)); } return inner; } TORCH_CHECK(false, "constant lists only support int, float, bool"); } namespace { /** * Deserialization for graphs: parse the output produced by operator<<(Graph). * This parser really only expects the exact output generated by well-formed * Graph objects, so it is not very permissive and does not give good error * messages. */ class Parser { public: explicit Parser(std::string_view source) : source_(source), graph_(Graph::createGraph()) {} std::unique_ptr parse(); private: template std::vector parseList( char open, char close, const std::function& parseFn); std::string_view parseUntil( const std::function& fn, bool includeEnd = false); void expect(std::string_view expected); void expect(char expected); bool nextEquals(std::string_view expected) const; bool nextIf(std::string_view expected); bool nextIf(char expected); void consumeWhitespace(); bool validIdent(char n); char cur(); void parseReturn(); void parseNode(); std::pair parseOutput(); void parseGraphInputs(); std::string_view parseString(); std::variant parseArgument(); std::variant parseNamedArgument(); Value* parseSymbolicArgument(); // Symbols look like %v109, with the same valid ident rules as Python // This returns the symbol *without* the % at the front. std::string_view parseAtomicSymbol(); size_t curPos_ = 0; std::string_view source_; std::unique_ptr graph_; torch::_export::GraphSignature signature_; }; std::unique_ptr Parser::parse() { parseGraphInputs(); while (true) { consumeWhitespace(); if (nextEquals("return")) { parseReturn(); break; } parseNode(); } // For graph textual format, it should be safe to assume all // inputs/outputs are from users. graph_->setSignature(GraphSignature{signature_}); graph_->finalize(); graph_->lint(); // TODO: Might have some source left over, should check it if so. return std::move(graph_); } bool Parser::nextIf(std::string_view expected) { if (nextEquals(expected)) { curPos_ += expected.size(); return true; } return false; } bool Parser::nextIf(char expected) { if (cur() == expected) { curPos_++; return true; } return false; } void Parser::parseGraphInputs() { TORCH_CHECK(curPos_ == 0); expect("graph"); const auto inputs = parseList( '(', ')', [&]() { return parseAtomicSymbol(); }); std::vector inputSpecs; inputSpecs.reserve(inputs.size()); for (const auto& input : inputs) { graph_->addInput(input, Type::Kind::Tensor); torch::_export::TensorArgument inputTensorArg; inputTensorArg.set_name(std::string{input}); torch::_export::Argument inputArg; inputArg.set_as_tensor(std::move(inputTensorArg)); torch::_export::UserInputSpec userInput; userInput.set_arg(std::move(inputArg)); torch::_export::InputSpec inputSpec; inputSpec.set_user_input(std::move(userInput)); inputSpecs.push_back(std::move(inputSpec)); } signature_.set_input_specs(std::move(inputSpecs)); // TODO populate graphinputs expect(":"); } template std::vector Parser::parseList( char open, char close, const std::function& parseFn) { std::vector ret; expect(open); // Handle empty list if (nextIf(close)) { return ret; } while (true) { ret.push_back(parseFn()); if (cur() == close) { break; } expect(','); } expect(close); return ret; } // Parse until `fn` returns true, returning the segment of the source that was // consumed. If `includeEnd` is true, the returned segment will also include // final character, which caused `fn` to return true. std::string_view Parser::parseUntil( const std::function& fn, bool includeEnd) { size_t start = curPos_; while (!fn()) { curPos_++; } if (includeEnd) { curPos_++; } return source_.substr(start, curPos_ - start); } // Parse a string, including the outer quotes std::string_view Parser::parseString() { size_t start = curPos_; expect('"'); while (cur() != '"') { // Handle escaped characters by skipping the next char when we see a // backslash if (cur() == '\\') { curPos_++; } curPos_++; } // Consume final quote curPos_++; auto ret = source_.substr(start, curPos_ - start); return ret; } bool Parser::validIdent(char n) { return std::isalpha(n) || n == '_' || std::isdigit(n); } // Symbols look like %v109, with the same valid ident rules as Python // This returns the symbol *without* the % at the front. std::string_view Parser::parseAtomicSymbol() { expect("%"); return parseUntil([&]() { return !validIdent(cur()); }); } char Parser::cur() { return source_.at(curPos_); } void Parser::consumeWhitespace() { while (isBlank(cur())) { curPos_++; } } void Parser::expect(std::string_view expected) { curPos_ = expectImpl(source_, expected, curPos_); } void Parser::expect(char expected) { curPos_ = expectImpl(source_, expected, curPos_); } bool Parser::nextEquals(std::string_view expected) const { const auto actual = source_.substr(curPos_, expected.size()); return expected == actual; } // %a, %b = aten.foo.default(input=%foo, foo=[7616], blah=%lol) void Parser::parseNode() { std::vector> outputs; outputs.push_back(parseOutput()); while (nextIf(",")) { outputs.push_back(parseOutput()); } expect("="); consumeWhitespace(); // parse target name const auto target = parseUntil([&]() { return cur() == '('; }); Node* node = graph_->insertNode(std::string(target)); for (auto& [name, var] : outputs) { node->addOutput(name, var); } auto arguments = parseList>( '(', ')', [&]() { return parseNamedArgument(); }); // Split the arguments into symbolic inputs and constant attributes for (auto& arg : arguments) { if (std::holds_alternative(arg)) { node->addInput(std::get(arg)); } else { node->addAttribute(std::get(std::move(arg))); } } } void Parser::parseReturn() { expect("return"); const auto returns = parseList('(', ')', [&]() { return parseSymbolicArgument(); }); std::vector outputSpecs; outputSpecs.reserve(returns.size()); for (const auto ret : returns) { graph_->addOutput(ret); torch::_export::TensorArgument retTensorArg; retTensorArg.set_name(std::string{ret->name()}); torch::_export::Argument retArg; retArg.set_as_tensor(std::move(retTensorArg)); torch::_export::UserOutputSpec userOutput; userOutput.set_arg(std::move(retArg)); torch::_export::OutputSpec outputSpec; outputSpec.set_user_output(std::move(userOutput)); outputSpecs.push_back(std::move(outputSpec)); } signature_.set_output_specs(std::move(outputSpecs)); } std::variant Parser::parseNamedArgument() { consumeWhitespace(); // Parse name const auto symbol = parseUntil([&]() { return cur() == '='; }); expect('='); // Parse value auto value = parseArgument(); if (std::holds_alternative(value)) { return NamedArgument{std::string(symbol), std::get(value)}; } else { return Attribute{std::string(symbol), std::get(std::move(value))}; } } std::pair Parser::parseOutput() { consumeWhitespace(); TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur())); auto symbol = parseAtomicSymbol(); if (nextIf('[')) { expect(']'); return {symbol, Type::Kind::TensorList}; } else { return {symbol, Type::Kind::Tensor}; } } Value* Parser::parseSymbolicArgument() { consumeWhitespace(); TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur())); auto symbol = parseAtomicSymbol(); std::vector listElements; if (cur() == '[') { listElements = parseList( '[', ']', [&]() { return graph_->getValue(parseAtomicSymbol()); }); } return graph_->getValue(symbol); } std::variant Parser::parseArgument() { consumeWhitespace(); // match symbol if (cur() == '%') { return parseSymbolicArgument(); } // match list if (cur() == '[') { const auto symbol = parseUntil([&]() { return cur() == ']'; }, /*includeEnd=*/true); return convertListConstant(symbol); } // match string if (cur() == '"') { return convertAtomicConstant(parseString()); } // otherwise parse this as a value const auto symbol = parseUntil([&]() { return cur() == ',' || cur() == ')'; }); return convertAtomicConstant(symbol); } } // namespace std::unique_ptr stringToGraph(std::string_view source) { return Parser(source).parse(); } std::string graphToString(const Graph& g, bool include_signature) { std::stringstream ss; ss << g; if (include_signature) { ss << "\nGraphSignature\n"; ss << g.signature(); } return ss.str(); } } // namespace torch::nativert