#include #include #include #include #include #include #include #include #include namespace torch { namespace jit { namespace { bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { return lhs.options().type_equal(rhs.options()) && lhs.equal(rhs); } bool tensorListEqual( const std::vector& lhs, const std::vector& rhs) { if (lhs.size() != rhs.size()) return false; return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual); } bool typeListEqual( const std::vector& lhs, const std::vector& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { if (*lhs[i] != *rhs[i]) { return false; } } return true; } // Check whether two nodes have the same attributes in CSE. // This function may be too conservative for general use. // Do NOT support g/gs attributes. bool attributesEqualCSE(const Node* lhs, const Node* rhs) { AT_ASSERT(lhs != nullptr); AT_ASSERT(rhs != nullptr); // One has attributes, the other does not. if (lhs->hasAttributes() != rhs->hasAttributes()) return false; // Neither has attributes. if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true; auto lnames = lhs->attributeNames(); auto rnames = rhs->attributeNames(); std::sort(lnames.begin(), lnames.end()); std::sort(rnames.begin(), rnames.end()); if (lnames != rnames) return false; for (auto name : lnames) { if (lhs->kindOf(name) != rhs->kindOf(name)) return false; #define COMPARE_ATTRIBUTEVALUE(selector) \ case AttributeKind::selector: { \ if (lhs->selector(name) != rhs->selector(name)) \ return false; \ } break; switch (lhs->kindOf(name)) { COMPARE_ATTRIBUTEVALUE(f) COMPARE_ATTRIBUTEVALUE(fs) COMPARE_ATTRIBUTEVALUE(i) COMPARE_ATTRIBUTEVALUE(is) COMPARE_ATTRIBUTEVALUE(s) COMPARE_ATTRIBUTEVALUE(ss) case AttributeKind::t: { if (!tensorEqual(lhs->t(name), rhs->t(name))) return false; break; } case AttributeKind::ts: { if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false; break; } case AttributeKind::ty: if (*lhs->ty(name) != *rhs->ty(name)) { return false; } break; case AttributeKind::tys: if (!typeListEqual(lhs->tys(name), rhs->tys(name))) { return false; } break; case AttributeKind::g: case AttributeKind::gs: return false; } #undef COMPARE_ATTRIBUTEVALUE } return true; } } // anonymous namespace size_t HashNode::operator()(const Node* k) const { AT_ASSERT(k != nullptr); size_t constant_hash = 0; if (k->kind() == prim::Constant) { TypePtr type = k->output()->type(); if (type->isSubtypeOf(NumberType::get()) && k->kindOf(attr::value) == AttributeKind::i) { constant_hash = std::hash{}(k->i(attr::value)); } else if (type->isSubtypeOf(NumberType::get()) && k->kindOf(attr::value) == AttributeKind::f) { constant_hash = std::hash{}(k->f(attr::value)); } else if (type->isSubtypeOf(BoolType::get())) { constant_hash = std::hash{}(k->i(attr::value)); } } return get_hash( k->kind(), fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }), fmap(k->inputs(), [](const Value* v) { return v->unique(); }), constant_hash); }; bool EqualNode::operator()(const Node* lhs, const Node* rhs) const { if (lhs == nullptr && rhs == nullptr) return true; if (lhs == nullptr || rhs == nullptr) return false; if (lhs->kind() != rhs->kind()) return false; // Check whether the output types are the same. auto lhs_outputs = lhs->outputs(); auto rhs_outputs = rhs->outputs(); if (lhs_outputs.size() != rhs_outputs.size()) return false; for (size_t i = 0; i < lhs_outputs.size(); ++i) { if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type()) return false; if (lhs_outputs[i]->type() == CapsuleType::get()) return false; } // Check whether the inputs are the same. auto lhs_inputs = lhs->inputs(); auto rhs_inputs = rhs->inputs(); if (lhs_inputs.size() != rhs_inputs.size()) return false; if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false; if (!attributesEqualCSE(lhs, rhs)) return false; return true; }; } // namespace jit } // namespace torch