#include #include #include #include #include namespace torch { namespace lazy { class CacheNode : public Node { public: explicit CacheNode(const std::string& str) : Node(OpKind(), /* num_outputs */ 1, /* hash_seed */ Hash(str)), str_(str) {} ~CacheNode() override = default; const std::vector& operands() const override { TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node"); } const Output& operand(size_t i) const override { TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node"); } private: std::string str_; }; TEST(CacheTest, BasicTest) { std::shared_ptr a = std::make_shared("a"); std::shared_ptr b = std::make_shared("b"); std::shared_ptr c = std::make_shared("c"); Cache cache(2); cache.Add(a->node_hash(), a); EXPECT_EQ(cache.Get(a->node_hash()), a); EXPECT_EQ(cache.Get(b->node_hash()), nullptr); EXPECT_EQ(cache.Get(c->node_hash()), nullptr); cache.Add(b->node_hash(), b); EXPECT_EQ(cache.Get(a->node_hash()), a); EXPECT_EQ(cache.Get(b->node_hash()), b); EXPECT_EQ(cache.Get(c->node_hash()), nullptr); cache.Add(c->node_hash(), c); EXPECT_EQ(cache.Get(a->node_hash()), nullptr); // a has been evicted EXPECT_EQ(cache.Get(b->node_hash()), b); EXPECT_EQ(cache.Get(c->node_hash()), c); cache.Erase(c->node_hash()); EXPECT_EQ(cache.Get(a->node_hash()), nullptr); EXPECT_EQ(cache.Get(b->node_hash()), b); EXPECT_EQ(cache.Get(c->node_hash()), nullptr); // c has been removed cache.Clear(); EXPECT_EQ(cache.Get(a->node_hash()), nullptr); EXPECT_EQ(cache.Get(b->node_hash()), nullptr); EXPECT_EQ(cache.Get(c->node_hash()), nullptr); } } // namespace lazy } // namespace torch