#include <gtest/gtest.h>

// NOTE: the original #include targets were missing; the headers below are a
// best-effort reconstruction of what this test depends on (GoogleTest, ATen /
// torch, and the nativert graph/executor/kernel APIs). Exact paths may differ.
#include <c10/core/Device.h>
#include <torch/torch.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>

namespace torch::nativert {

// Simple custom kernel used to exercise C10Kernel dispatch.
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

TEST(C10KernelTest, computeInternal) {
  // Register the custom op so the graph node `test.foo` resolves to foo_kernel.
  auto registrar = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);

  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);

  // Skip the graph's input node and grab the `test.foo` call node.
  const auto& nodes = graph->nodes();
  auto it = nodes.begin();
  std::advance(it, 1);
  const Node& node = *it;

  c10::Device device = torch::Device(torch::kCPU, 0);

  auto a = at::randn({6, 6, 6});
  auto b = at::randn({6, 6, 6});

  // Bind the graph inputs in an execution frame, run the kernel, and compare
  // the result against the eager computation.
  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = C10Kernel(&node, device);
  kernel.computeInternal(frame);

  at::Tensor expected = a + b;
  EXPECT_TRUE(
      torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}

TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);

  // Skip the graph's input node and grab the `_operator.add` node.
  const auto& nodes = graph->nodes();
  auto it = nodes.begin();
  std::advance(it, 1);
  const Node& node = *it;

  auto a = 1;
  auto b = 2;

  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = ScalarBinaryOpKernel(&node);
  kernel.computeInternal(frame);

  auto expected = a + b;
  EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected);
}

} // namespace torch::nativert