Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Currently OpKind is stored as an object field called op_ for each IR node, and one usage of op_ is to avoid dynamic_cast in NodeCast when we need to downcast a base-node pointer into a concrete sub-node pointer. As a result, we need to construct and pass in an op when downcasting nodes, which becomes quite annoying once we start to implement the trie-based IR node reusing. More importantly, the op for each subclass should be unique to that subclass, so making it a const static field is a more logical design. In this PR, we still keep the object-level op_ for easier XLA adoption. As future work, we can come back to remove op_, make the op() method virtual, and get rid of OpKind in all the node constructors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76711
Approved by: https://github.com/wconstab, https://github.com/JackCaoG
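The change is easiest to see in the cast helper. Below is a minimal sketch of how NodeCast can key off the static class_op_kind instead of an OpKind passed by the caller; the real implementation lives in torch/csrc/lazy/core/ir.h, and the body here is an illustrative assumption, not the actual code:

// Sketch only (assumed body): with a unique static class_op_kind per
// subclass, the cast helper no longer needs an OpKind argument.
template <typename T>
const T* NodeCast(const Node* node) {
  if (node->op() != T::class_op_kind) {
    return nullptr; // op kinds differ, so node cannot be a T
  }
  // The op kind identifies the concrete subclass, so the downcast is
  // safe without resorting to dynamic_cast.
  return static_cast<const T*>(node);
}

Call sites then shrink from NodeCast<TestLeafNode>(node.get(), OpKind(...)) to NodeCast<TestLeafNode>(node.get()), as exercised in IrTest.BasicTest in the test file below.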
138 lines
4.3 KiB
C++
#include <gtest/gtest.h>

#include <torch/csrc/lazy/generated/LazyIr.h>

#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <memory>
#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>

namespace torch {
namespace lazy {

// Minimal leaf node used to exercise the Node base class. Its op kind is
// exposed as a static field (class_op_kind) so NodeCast can identify the
// class without callers passing an OpKind in.
class TestLeafNode : public Node {
 public:
  static const OpKind class_op_kind;

  explicit TestLeafNode(size_t param)
      : Node(OpKind(), /* num_outputs */ 1),
        hash_(Hash(param)),
        param_(param) {}
  ~TestLeafNode() override = default;

  // Leaf nodes have no operands; accessing them is an error.
  const std::vector<Output>& operands() const override {
    TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node");
  }

  const Output& operand(size_t i) const override {
    TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node");
  }

  hash_t hash() const override { return hash_; }
  hash_t shapeHash() const override { return hash_; }

 private:
  hash_t hash_;
  size_t param_;
};

const OpKind TestLeafNode::class_op_kind = OpKind();

TEST(IrTest, BasicTest) {
  // Nodes built from different parameters must hash differently.
  NodePtr node1 = MakeNode<TestLeafNode>(1);
  NodePtr node2 = MakeNode<TestLeafNode>(2);
  EXPECT_NE(node1->hash(), node2->hash());

  EXPECT_EQ(node1->num_outputs(), 1);

  // The downcast succeeds without an explicit OpKind argument.
  const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
  EXPECT_TRUE(leafptr != nullptr);
}

TEST(IrTest, MetaDataTest) {
  bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug;

  // With IR debugging off, nodes carry no scope or frame metadata.
  FLAGS_torch_lazy_ir_debug = false;
  NodePtr node = MakeNode<TestLeafNode>(1);
  auto metaWithoutDebug = node->metadata();
  EXPECT_EQ(metaWithoutDebug.scope.size(), 0);
  EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0);

  // With IR debugging on, frame info is recorded even without a scope.
  FLAGS_torch_lazy_ir_debug = true;
  node = MakeNode<TestLeafNode>(1);
  auto metaWithEmptyDebug = node->metadata();
  EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0);
  EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1);

  {
    // An active ScopePusher tags new nodes with the scope name.
    ScopePusher scope("TestScope");
    node = MakeNode<TestLeafNode>(1);
    auto metaWithScope = node->metadata();
    EXPECT_EQ(metaWithScope.scope, "TestScope.1");
    EXPECT_EQ(metaWithScope.frame_info.size(), 1);
  }

  // A custom frame-collection function feeds source locations into metadata.
  SourceLocation dummySourceLocation;
  dummySourceLocation.file = "file";
  dummySourceLocation.function = "function";
  dummySourceLocation.line = 10;
  GetPythonFramesFunction() =
      [&]() -> std::vector<SourceLocation> { return {dummySourceLocation}; };
  node = MakeNode<TestLeafNode>(1);
  auto metaWithSourceLoc = node->metadata();
  EXPECT_EQ(metaWithSourceLoc.scope.size(), 0);
  EXPECT_EQ(metaWithSourceLoc.frame_info.size(), 1);
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function");
  EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10);

  FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
}

TEST(IrTest, TsNodeTest) {
  // Two TsNodes built from identical inputs must hash identically.
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      Shape(),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);
  NodePtr node2 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      Shape(),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);
  EXPECT_EQ(node1->hash(), node2->hash());

  EXPECT_EQ(node1->num_outputs(), 1);

  const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
  EXPECT_TRUE(leafptr != nullptr);
}

TEST(IrTest, DimensionNodeTest) {
  const size_t DIM0 = 5;
  const size_t DIM1 = 8;
  NodePtr node1 = MakeNode<TsNode>(
      OpKind(at::aten::view),
      Shape(c10::kFloat, {DIM0, DIM1}),
      /*num_outputs*/ 1,
      /*hash_seed*/ kHashSeed);

  // SizeNode reads the static extent of a given dimension.
  auto size0 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 0));
  auto size1 =
      std::dynamic_pointer_cast<SizeNode>(MakeNode<SizeNode>(Value{node1}, 1));

  ASSERT_EQ(DIM0, size0->getStaticValue());
  ASSERT_EQ(DIM1, size1->getStaticValue());

  // SizeAdd and SizeMul combine dimension nodes arithmetically.
  auto add_dim =
      std::dynamic_pointer_cast<SizeAdd>(MakeNode<SizeAdd>(Value{size0}, Value{size1}));
  ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue());

  auto mul_dim =
      std::dynamic_pointer_cast<SizeMul>(MakeNode<SizeMul>(Value{size0}, Value{size1}));
  ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue());
}

} // namespace lazy
} // namespace torch