#include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace lazy { class TestLeafNode : public Node { public: static OpKind ClassOpKind() { return OpKind(); } explicit TestLeafNode(size_t param) : Node(ClassOpKind(), /* num_outputs */ 1), hash_(Hash(param)), param_(param) {} ~TestLeafNode() override = default; const std::vector& operands() const override { TORCH_INTERNAL_ASSERT(false, "Can't access operands of leaf node"); } const Output& operand(size_t i) const override { TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of leaf node"); } hash_t hash() const override { return hash_; } hash_t shapeHash() const override { return hash_; } private: hash_t hash_; size_t param_; }; TEST(IrTest, BasicTest) { NodePtr node1 = MakeNode(1); NodePtr node2 = MakeNode(2); EXPECT_NE(node1->hash(), node2->hash()); EXPECT_EQ(node1->num_outputs(), 1); const TestLeafNode* leafptr = NodeCast(node1.get()); EXPECT_TRUE(leafptr != nullptr); } TEST(IrTest, MetaDataTest) { bool restore_FLAGS_torch_lazy_ir_debug = FLAGS_torch_lazy_ir_debug; FLAGS_torch_lazy_ir_debug = false; NodePtr node = MakeNode(1); auto metaWithoutDebug = node->metadata(); EXPECT_EQ(metaWithoutDebug.scope.size(), 0); EXPECT_EQ(metaWithoutDebug.frame_info.size(), 0); FLAGS_torch_lazy_ir_debug = true; node = MakeNode(1); auto metaWithEmptyDebug = node->metadata(); EXPECT_EQ(metaWithEmptyDebug.scope.size(), 0); EXPECT_EQ(metaWithEmptyDebug.frame_info.size(), 1); { ScopePusher scope("TestScope"); node = MakeNode(1); auto metaWithScope = node->metadata(); EXPECT_EQ(metaWithScope.scope, "TestScope.1"); EXPECT_EQ(metaWithScope.frame_info.size(), 1); } SourceLocation dummySourceLocation; dummySourceLocation.file = "file"; dummySourceLocation.function = "function"; dummySourceLocation.line = 10; GetPythonFramesFunction() = [&]() -> std::vector { return {dummySourceLocation}; }; node = MakeNode(1); auto metaWithSourceLoc = node->metadata(); EXPECT_EQ(metaWithSourceLoc.scope.size(), 0); EXPECT_EQ(metaWithSourceLoc.frame_info.size(), 1); EXPECT_EQ(metaWithSourceLoc.frame_info[0].file, "file"); EXPECT_EQ(metaWithSourceLoc.frame_info[0].function, "function"); EXPECT_EQ(metaWithSourceLoc.frame_info[0].line, 10); FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug; } TEST(IrTest, TsNodeTest) { NodePtr node1 = MakeNode( OpKind(at::aten::view), Shape(), /*num_outputs*/ 1, /*hash_seed*/ kHashSeed); NodePtr node2 = MakeNode( OpKind(at::aten::view), Shape(), /*num_outputs*/ 1, /*hash_seed*/ kHashSeed); EXPECT_EQ(node1->hash(), node2->hash()); EXPECT_EQ(node1->num_outputs(), 1); const TsNode* leafptr = dynamic_cast(node1.get()); EXPECT_TRUE(leafptr != nullptr); } TEST(IrTest, DimensionNodeTest) { const size_t DIM0 = 5; const size_t DIM1 = 8; NodePtr node1 = MakeNode( OpKind(at::aten::view), Shape(c10::kFloat, {DIM0, DIM1}), /*num_outputs*/ 1, /*hash_seed*/ kHashSeed); auto size0 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 0)); auto size1 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 1)); ASSERT_EQ(DIM0, size0->getStaticValue()); ASSERT_EQ(DIM1, size1->getStaticValue()); NodePtr size0_np = size0; auto size0_dn = std::dynamic_pointer_cast(size0_np); ASSERT_EQ(DIM0, size0_dn->getStaticValue()); auto add_dim = std::dynamic_pointer_cast( MakeNode(Value{size0}, Value{size1})); ASSERT_EQ(DIM0 + DIM1, add_dim->getStaticValue()); auto mul_dim = std::dynamic_pointer_cast( MakeNode(Value{size0}, Value{size1})); ASSERT_EQ(DIM0 * DIM1, mul_dim->getStaticValue()); } TEST(IrTest, DimensionIsDynamicTest) { const size_t DIM0 = 5; const size_t DIM1 = 8; const auto shape = Shape(c10::kFloat, {DIM0, DIM1}); NodePtr node1 = MakeNode( OpKind(at::aten::view), shape.with_symbolic_dims(std::vector{true, false}), /*num_outputs*/ 1, /*hash_seed*/ kHashSeed); auto size0 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 0)); auto size1 = std::dynamic_pointer_cast(MakeNode(Value{node1}, 1)); ASSERT_EQ(true, size0->isSymbolic()); ASSERT_EQ(false, size1->isSymbolic()); auto add_dim = std::dynamic_pointer_cast( MakeNode(Value{size0}, Value{size1})); ASSERT_EQ(true, add_dim->isSymbolic()); add_dim = std::dynamic_pointer_cast( MakeNode(Value{size1}, Value{size1})); ASSERT_EQ(false, add_dim->isSymbolic()); auto mul_dim = std::dynamic_pointer_cast( MakeNode(Value{size0}, Value{size0})); ASSERT_EQ(true, mul_dim->isSymbolic()); } } // namespace lazy } // namespace torch