mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72730
This diff contains changes from several PRs landed to lazy_tensor_staging branch.
- generates 'fallback' overrides for each codegenned op, which is useful for debugging
- supports operators that have no aten:: symbol for their op name by using the op's string name instead
- makes the IR class a base class instead of hardcoding the assumption of TS
Test Plan: tested on lazy_tensor_staging branch
Reviewed By: desertfire
Differential Revision: D34178476
fbshipit-source-id: 7190b2e0d82b4eb1f4510c858c24446c6df3f9d0
(cherry picked from commit 6713d3f0ef)
64 lines
1.9 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <torch/csrc/lazy/core/cache.h>
|
|
#include <torch/csrc/lazy/core/hash.h>
|
|
#include <torch/csrc/lazy/core/ir.h>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
class CacheNode : public Node {
|
|
public:
|
|
explicit CacheNode(const std::string& str)
|
|
: Node(OpKind(), /* num_outputs */ 1, /* hash_func */ [&](bool bakeInSizes) -> hash_t { return Hash(str); }),
|
|
str_(str) {}
|
|
~CacheNode() override = default;
|
|
|
|
const std::vector<Output>& operands() const override {
|
|
TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node");
|
|
}
|
|
|
|
const Output& operand(size_t i) const override {
|
|
TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node");
|
|
}
|
|
|
|
private:
|
|
std::string str_;
|
|
};
|
|
|
|
// Walks a Cache constructed with size 2 through insertion, lookup, eviction
// on overflow, Erase and Clear. After every mutation the three keys are looked
// up in a fixed a, b, c order; because lookups may affect an entry's recency,
// that order is kept deliberately.
TEST(CacheTest, BasicTest) {
  auto a = std::make_shared<CacheNode>("a");
  auto b = std::make_shared<CacheNode>("b");
  auto c = std::make_shared<CacheNode>("c");

  const hash_t key_a = a->node_hash();
  const hash_t key_b = b->node_hash();
  const hash_t key_c = c->node_hash();

  Cache<hash_t, CacheNode, HashReducer> cache(2);

  // Only `a` is present after the first insertion.
  cache.Add(key_a, a);
  EXPECT_EQ(cache.Get(key_a), a);
  EXPECT_EQ(cache.Get(key_b), nullptr);
  EXPECT_EQ(cache.Get(key_c), nullptr);

  // Inserting `b` fills the cache; `a` is still retrievable.
  cache.Add(key_b, b);
  EXPECT_EQ(cache.Get(key_a), a);
  EXPECT_EQ(cache.Get(key_b), b);
  EXPECT_EQ(cache.Get(key_c), nullptr);

  // A third insertion overflows the cache and `a` is pushed out.
  cache.Add(key_c, c);
  EXPECT_EQ(cache.Get(key_a), nullptr);
  EXPECT_EQ(cache.Get(key_b), b);
  EXPECT_EQ(cache.Get(key_c), c);

  // Explicitly erasing `c` leaves `b` in place.
  cache.Erase(key_c);
  EXPECT_EQ(cache.Get(key_a), nullptr);
  EXPECT_EQ(cache.Get(key_b), b);
  EXPECT_EQ(cache.Get(key_c), nullptr);

  // Clear empties the cache entirely.
  cache.Clear();
  EXPECT_EQ(cache.Get(key_a), nullptr);
  EXPECT_EQ(cache.Get(key_b), nullptr);
  EXPECT_EQ(cache.Get(key_c), nullptr);
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|