#include #include #include #include #include #include #include #include namespace torch { namespace lazy { class TrieCacheNode : public Node { public: static OpKind ClassOpKind() { return OpKind(); } explicit TrieCacheNode(size_t id) : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {} ~TrieCacheNode() override = default; bool CanBeReused(size_t id) const { return (id_ == id); } void AddOperand(Value v) { if (!v.node) { return; } operands_as_outputs_.emplace_back(v.node.get(), v.index); operands_.push_back(std::move(v.node)); } hash_t hash() const override { return hash_; } hash_t shapeHash() const override { return hash_; } private: size_t id_; hash_t hash_; }; TEST(TrieCacheTest, TestSinglePath) { FLAGS_torch_lazy_reuse_ir = true; TrieCache::Get()->Clear(); NodePtr a = ReuseOrMakeNode(0); NodePtr b = ReuseOrMakeNode(1); NodePtr c = ReuseOrMakeNode(2); TrieCache::Get()->ResetCurrent(); // MarkStep EXPECT_EQ(ReuseOrMakeNode(0).get(), a.get()); EXPECT_EQ(ReuseOrMakeNode(1).get(), b.get()); EXPECT_EQ(ReuseOrMakeNode(2).get(), c.get()); TrieCache::Get()->ResetCurrent(); // MarkStep } /* * 0 * | * 1 * / \ * 2 3 */ TEST(TrieCacheTest, TestTwoPaths) { FLAGS_torch_lazy_reuse_ir = true; TrieCache::Get()->Clear(); NodePtr a = ReuseOrMakeNode(0); NodePtr b = ReuseOrMakeNode(1); NodePtr c = ReuseOrMakeNode(2); TrieCache::Get()->ResetCurrent(); // MarkStep EXPECT_EQ(ReuseOrMakeNode(0).get(), a.get()); EXPECT_EQ(ReuseOrMakeNode(1).get(), b.get()); NodePtr d = ReuseOrMakeNode(3); EXPECT_NE(d.get(), c.get()); TrieCache::Get()->ResetCurrent(); // MarkStep EXPECT_EQ(ReuseOrMakeNode(0).get(), a.get()); EXPECT_EQ(ReuseOrMakeNode(1).get(), b.get()); EXPECT_EQ(ReuseOrMakeNode(3).get(), d.get()); TrieCache::Get()->ResetCurrent(); // MarkStep EXPECT_EQ(ReuseOrMakeNode(0).get(), a.get()); EXPECT_EQ(ReuseOrMakeNode(1).get(), b.get()); EXPECT_EQ(ReuseOrMakeNode(2).get(), c.get()); TrieCache::Get()->ResetCurrent(); // MarkStep } } // namespace lazy } // namespace torch