#include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { using namespace torch::indexing; using namespace torch::jit::tensorexpr; void testKernel_1() { KernelScope kernel_scope; const auto graph_string = R"IR( graph(%0 : Float(5:3,3:1), %1 : Float(5:3,3:1)): %2 : Float(5:3,3:1) = aten::mul(%0, %1) %3 : Float(5:3,3:1) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; Stmt* s = k.getCodeGenStmt(); // TODO: verify stmt std::vector stack = fmap(inputs); k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } void testKernel_2() { KernelScope kernel_scope; const auto graph_string = R"IR( graph(%0 : Float(5:3,3:1), %1 : Float(5:1,3:5)): %2 : Float(5:3,3:1) = aten::mul(%0, %1) %3 : Float(5:3,3:1) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto b = at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; Stmt* s = k.getCodeGenStmt(); // TODO: verify stmt std::vector stack = fmap(inputs); k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } void testKernel_3() { KernelScope kernel_scope; const auto graph_string = R"IR( graph(%0 : Float(5:3,3:1), %1 : Float(5:12,3:2)): %2 : Float(5:3,3:1) = aten::mul(%0, %1) %3 : Float(5:3,3:1) = aten::mul(%0, %2) return (%3))IR"; auto graph = std::make_shared(); parseIR(graph_string, &*graph); auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) .index({Slice(None, None, 2), Slice(None, None, 2)}); auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); auto ref = a * (a * b); TensorExprKernel k(graph); std::vector inputs = {a, b}; Stmt* s = k.getCodeGenStmt(); // TODO: verify stmt std::vector stack = fmap(inputs); k.run(stack); o = stack[0].toTensor(); for (size_t i = 0; i < 5 * 3; i++) { CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); } } } // namespace jit } // namespace torch