Mirror of https://github.com/zebrajr/pytorch.git (last synced 2025-12-06 12:20:52 +01:00)
Summary: TrieCache lets us look up an IR node before we actually create it. If the lookup hits in TrieCache, we reuse the existing node and advance TrieCache's current pointer to that node; if the lookup misses, we create a new node and insert it into TrieCache. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76542 Approved by: https://github.com/wconstab, https://github.com/JackCaoG
89 lines · 2.6 KiB · C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <torch/csrc/lazy/core/config.h>
|
|
#include <torch/csrc/lazy/core/ir.h>
|
|
#include <torch/csrc/lazy/core/ir_builder.h>
|
|
#include <torch/csrc/lazy/core/ir_metadata.h>
|
|
#include <torch/csrc/lazy/core/ir_util.h>
|
|
#include <memory>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
class TrieCacheNode : public Node {
|
|
public:
|
|
explicit TrieCacheNode(size_t id)
|
|
: Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
|
|
~TrieCacheNode() override = default;
|
|
|
|
bool Equal(size_t id) const {
|
|
return (id_ == id);
|
|
}
|
|
|
|
void AddOperand(Value v) {
|
|
if (!v.node) {
|
|
return;
|
|
}
|
|
operands_as_outputs_.emplace_back(v.node.get(), v.index);
|
|
operands_.push_back(std::move(v.node));
|
|
}
|
|
|
|
hash_t hash() const override { return hash_; }
|
|
hash_t shapeHash() const override { return hash_; }
|
|
private:
|
|
size_t id_;
|
|
hash_t hash_;
|
|
};
|
|
|
|
// Creates nodes with ids 0, 1, 2 in order (one path in the trie) and marks a
// step. It then replays the same sequence through ReuseOrMakeNode and expects
// every lookup to hit the cache and return the originally created node.
TEST(TrieCacheTest, TestSinglePath) {
  // FLAGS_torch_lazy_reuse_ir and the TrieCache singleton are process-wide.
  // Record the flag on entry, and on exit drop the cached nodes and restore
  // the flag, so this test leaks no state into other tests in the same binary.
  // The guard is declared first, so it is destroyed last (after a, b, c).
  struct ReuseIrStateGuard {
    bool saved_flag = FLAGS_torch_lazy_reuse_ir;
    ~ReuseIrStateGuard() {
      TrieCache::Get()->Clear();
      FLAGS_torch_lazy_reuse_ir = saved_flag;
    }
  } state_guard;

  FLAGS_torch_lazy_reuse_ir = true;
  TrieCache::Get()->Clear();

  NodePtr a = MakeNode<TrieCacheNode>(0);
  NodePtr b = MakeNode<TrieCacheNode>(1);
  NodePtr c = MakeNode<TrieCacheNode>(2);
  TrieCache::Get()->ResetCurrent(); // MarkStep

  // Same id sequence as above: every lookup should be a cache hit.
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
  EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
  TrieCache::Get()->ResetCurrent(); // MarkStep
}
|
|
|
|
/*
|
|
* 0
|
|
* |
|
|
* 1
|
|
* / \
|
|
* 2 3
|
|
*/
|
|
TEST(TrieCacheTest, TestTwoPaths) {
|
|
FLAGS_torch_lazy_reuse_ir = true;
|
|
TrieCache::Get()->Clear();
|
|
|
|
NodePtr a = MakeNode<TrieCacheNode>(0);
|
|
NodePtr b = MakeNode<TrieCacheNode>(1);
|
|
NodePtr c = MakeNode<TrieCacheNode>(2);
|
|
TrieCache::Get()->ResetCurrent(); // MarkStep
|
|
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
|
|
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
|
|
EXPECT_NE(d.get(), c.get());
|
|
TrieCache::Get()->ResetCurrent(); // MarkStep
|
|
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
|
|
TrieCache::Get()->ResetCurrent(); // MarkStep
|
|
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
|
|
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
|
|
TrieCache::Get()->ResetCurrent(); // MarkStep
|
|
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|