#include #include #include #include #include #include #include namespace torch { namespace lazy { class CacheNode : public Node { public: explicit CacheNode(const std::string& str) : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(str)), str_(str) {} ~CacheNode() override = default; const std::vector& operands() const override { TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node"); } const Output& operand(size_t i) const override { TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node"); } hash_t hash() const override { return hash_; } hash_t shapeHash() const override { return hash_; } private: hash_t hash_; std::string str_; }; TEST(CacheTest, BasicTest) { std::shared_ptr a = std::make_shared("a"); std::shared_ptr b = std::make_shared("b"); std::shared_ptr c = std::make_shared("c"); Cache cache(2); cache.Add(a->hash(), a); EXPECT_EQ(cache.Get(a->hash()), a); EXPECT_EQ(cache.Get(b->hash()), nullptr); EXPECT_EQ(cache.Get(c->hash()), nullptr); cache.Add(b->hash(), b); EXPECT_EQ(cache.Get(a->hash()), a); EXPECT_EQ(cache.Get(b->hash()), b); EXPECT_EQ(cache.Get(c->hash()), nullptr); cache.Add(c->hash(), c); EXPECT_EQ(cache.Get(a->hash()), nullptr); // a has been evicted EXPECT_EQ(cache.Get(b->hash()), b); EXPECT_EQ(cache.Get(c->hash()), c); cache.Erase(c->hash()); EXPECT_EQ(cache.Get(a->hash()), nullptr); EXPECT_EQ(cache.Get(b->hash()), b); EXPECT_EQ(cache.Get(c->hash()), nullptr); // c has been removed cache.Clear(); EXPECT_EQ(cache.Get(a->hash()), nullptr); EXPECT_EQ(cache.Get(b->hash()), nullptr); EXPECT_EQ(cache.Get(c->hash()), nullptr); } class CacheNodeWithShape : public TsNode { public: explicit CacheNodeWithShape(const Shape& shape) : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0) {} }; TEST(CacheTest, ShapeCacheTestForDynamicShape) { // enable dynamic shape FLAGS_ltc_enable_dynamic_shapes = true; CacheNodeWithShape nodes[] = { CacheNodeWithShape(Shape(c10::kFloat, {2, 4})), CacheNodeWithShape(Shape(c10::kFloat, {4, 2}))}; /* * Make sure the cached shape for node (2, 4) is not used for node (4, 2) */ for (auto& node : nodes) { EXPECT_EQ(node.shape(), node.computeShape([&]() { return node.shape(); })); } // reset the flag FLAGS_ltc_enable_dynamic_shapes = false; } } // namespace lazy } // namespace torch