pytorch/test/cpp/lazy/test_trie_cache.cpp
Bin Bao f05710dd40 [LT] Add a trie data structure for caching IR nodes
Summary: TrieCache provides a way to look up an IR node before we
actually create it. If the lookup hits in TrieCache, we reuse the
existing node and move the current pointer in TrieCache to point to that
node; if the lookup misses, we create a new node and insert it into TrieCache.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76542

Approved by: https://github.com/wconstab, https://github.com/JackCaoG
2022-05-04 23:48:03 +00:00

89 lines
2.6 KiB
C++

#include <gtest/gtest.h>
#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <memory>
namespace torch {
namespace lazy {
// Minimal Node subclass used by the TrieCache tests in this file. An instance
// is identified only by the integer id it was constructed with, and its hash
// is derived from that same id.
class TrieCacheNode : public Node {
 public:
  explicit TrieCacheNode(size_t id)
      : Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id)) {}

  ~TrieCacheNode() override = default;

  // Returns true when `id` matches the id this node was built with.
  // NOTE(review): presumably the hook the cache lookup uses to decide whether
  // an existing node can be reused -- confirm against ReuseOrMakeNode.
  bool Equal(size_t id) const {
    return id == id_;
  }

  // Registers `v` as an operand of this node. A Value without a node is
  // silently ignored.
  void AddOperand(Value v) {
    if (v.node) {
      // Record the raw pointer first: the push_back below moves out of v.node.
      operands_as_outputs_.emplace_back(v.node.get(), v.index);
      operands_.push_back(std::move(v.node));
    }
  }

  // Both hashes are the id hash computed once in the constructor.
  hash_t hash() const override {
    return hash_;
  }

  hash_t shapeHash() const override {
    return hash_;
  }

 private:
  size_t id_;
  hash_t hash_;
};
// Creates nodes with ids 0, 1 and 2, resets the cache cursor, and then replays
// the same id sequence through ReuseOrMakeNode, expecting each lookup to hand
// back the node created in the first pass.
TEST(TrieCacheTest, TestSinglePath) {
  // FLAGS_torch_lazy_reuse_ir is process-wide state. Restore whatever value it
  // had on entry when this test exits (including via an exception), so that
  // enabling IR reuse here does not leak into tests that run later in the
  // same binary.
  struct ReuseIrFlagGuard {
    bool saved = FLAGS_torch_lazy_reuse_ir;
    ~ReuseIrFlagGuard() {
      FLAGS_torch_lazy_reuse_ir = saved;
    }
  } reuse_ir_guard;

  FLAGS_torch_lazy_reuse_ir = true;
  // Start from an empty cache so earlier tests cannot influence the lookups.
  TrieCache::Get()->Clear();

  // First pass: create the three nodes.
  NodePtr a = MakeNode<TrieCacheNode>(0);
  NodePtr b = MakeNode<TrieCacheNode>(1);
  NodePtr c = MakeNode<TrieCacheNode>(2);
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Second pass: the same ids in the same order must resolve to the very same
  // node objects.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep
}
/*
 * Trie exercised by TestTwoPaths (labels are TrieCacheNode ids):
 *
 *     0
 *     |
 *     1
 *    / \
 *   2   3
 */
// Creates nodes 0, 1, 2, then replays 0, 1 followed by a new id 3 so that two
// different ids follow node 1. Afterwards both sequences (0,1,3 and 0,1,2) are
// replayed and each lookup is expected to return the node created earlier.
TEST(TrieCacheTest, TestTwoPaths) {
  // FLAGS_torch_lazy_reuse_ir is process-wide state. Restore whatever value it
  // had on entry when this test exits (including via an exception), so that
  // enabling IR reuse here does not leak into tests that run later in the
  // same binary.
  struct ReuseIrFlagGuard {
    bool saved = FLAGS_torch_lazy_reuse_ir;
    ~ReuseIrFlagGuard() {
      FLAGS_torch_lazy_reuse_ir = saved;
    }
  } reuse_ir_guard;

  FLAGS_torch_lazy_reuse_ir = true;
  // Start from an empty cache so earlier tests cannot influence the lookups.
  TrieCache::Get()->Clear();

  // First pass: create nodes 0, 1, 2.
  NodePtr a = MakeNode<TrieCacheNode>(0);
  NodePtr b = MakeNode<TrieCacheNode>(1);
  NodePtr c = MakeNode<TrieCacheNode>(2);
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Second pass: 0 and 1 are reused; id 3 has not been seen before, so the
  // returned node must be distinct from the id-2 node.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
  NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
  EXPECT_NE(d.get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Third pass: the 0, 1, 3 sequence now resolves entirely to existing nodes.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Fourth pass: the original 0, 1, 2 sequence is still retrievable.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep
}
} // namespace lazy
} // namespace torch