mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Next stage of breaking up https://github.com/pytorch/pytorch/pull/74710 Move shape cache implementation to the backend interface. Also, clean up some of the hashing logic in the base node class. CC: wconstab JackCaoG henrytwo Partially Fixes https://github.com/pytorch/pytorch/issues/74628 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75324 Reviewed By: anjali411 Differential Revision: D35730823 Pulled By: wconstab fbshipit-source-id: cf6fa326319b9324e5f422a78817b6fb5bf7e9b8 (cherry picked from commit faec5043df56639e2fd23de2d91ae796e4f3df70)
97 lines
2.7 KiB
C++
97 lines
2.7 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <torch/csrc/lazy/core/cache.h>
|
|
#include <torch/csrc/lazy/core/hash.h>
|
|
#include <torch/csrc/lazy/core/ir.h>
|
|
#include <torch/csrc/lazy/core/shape.h>
|
|
#include <torch/csrc/lazy/ts_backend/ts_node.h>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
class CacheNode : public Node {
|
|
public:
|
|
explicit CacheNode(const std::string& str)
|
|
: Node(OpKind(), /* num_outputs */ 1),
|
|
hash_(Hash(str)),
|
|
str_(str) {}
|
|
~CacheNode() override = default;
|
|
|
|
const std::vector<Output>& operands() const override {
|
|
TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node");
|
|
}
|
|
|
|
const Output& operand(size_t i) const override {
|
|
TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node");
|
|
}
|
|
|
|
hash_t hash() const override { return hash_; }
|
|
hash_t shapeHash() const override { return hash_; }
|
|
private:
|
|
hash_t hash_;
|
|
std::string str_;
|
|
};
|
|
|
|
TEST(CacheTest, BasicTest) {
|
|
std::shared_ptr<CacheNode> a = std::make_shared<CacheNode>("a");
|
|
std::shared_ptr<CacheNode> b = std::make_shared<CacheNode>("b");
|
|
std::shared_ptr<CacheNode> c = std::make_shared<CacheNode>("c");
|
|
Cache<hash_t, CacheNode, HashReducer> cache(2);
|
|
|
|
cache.Add(a->hash(), a);
|
|
EXPECT_EQ(cache.Get(a->hash()), a);
|
|
EXPECT_EQ(cache.Get(b->hash()), nullptr);
|
|
EXPECT_EQ(cache.Get(c->hash()), nullptr);
|
|
|
|
cache.Add(b->hash(), b);
|
|
EXPECT_EQ(cache.Get(a->hash()), a);
|
|
EXPECT_EQ(cache.Get(b->hash()), b);
|
|
EXPECT_EQ(cache.Get(c->hash()), nullptr);
|
|
|
|
cache.Add(c->hash(), c);
|
|
EXPECT_EQ(cache.Get(a->hash()), nullptr); // a has been evicted
|
|
EXPECT_EQ(cache.Get(b->hash()), b);
|
|
EXPECT_EQ(cache.Get(c->hash()), c);
|
|
|
|
cache.Erase(c->hash());
|
|
EXPECT_EQ(cache.Get(a->hash()), nullptr);
|
|
EXPECT_EQ(cache.Get(b->hash()), b);
|
|
EXPECT_EQ(cache.Get(c->hash()), nullptr); // c has been removed
|
|
|
|
cache.Clear();
|
|
EXPECT_EQ(cache.Get(a->hash()), nullptr);
|
|
EXPECT_EQ(cache.Get(b->hash()), nullptr);
|
|
EXPECT_EQ(cache.Get(c->hash()), nullptr);
|
|
}
|
|
|
|
class CacheNodeWithShape : public TsNode {
|
|
public:
|
|
explicit CacheNodeWithShape(const Shape& shape)
|
|
: TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0){}
|
|
};
|
|
|
|
TEST(CacheTest, ShapeCacheTestForDynamicShape) {
|
|
// enable dynamic shape
|
|
FLAGS_ltc_enable_dynamic_shapes = true;
|
|
|
|
CacheNodeWithShape nodes[] = {
|
|
CacheNodeWithShape(Shape(c10::kFloat, {2, 4})),
|
|
CacheNodeWithShape(Shape(c10::kFloat, {4, 2})) };
|
|
|
|
/*
|
|
* Make sure the cached shape for node (2, 4) is not used for node (4, 2)
|
|
*/
|
|
for (auto& node : nodes) {
|
|
EXPECT_EQ(node.shape(), node.computeShape([&]() {
|
|
return node.shape();
|
|
}));
|
|
}
|
|
|
|
// reset the flag
|
|
FLAGS_ltc_enable_dynamic_shapes = false;
|
|
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|