pytorch/test/cpp/nativert/test_c10_kernel.cpp
Shangdi Yu e4c9f6d9a2 [nativert] Move c10_kernel (#156208)
Summary:
Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72

As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementation (c10_kernel) to torch/:

fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels

Test Plan:
```
buck run fbcode//mode/dev-nosan  //caffe2/test/cpp/nativert:c10_kernel_test
```

Differential Revision: D76825830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156208
Approved by: https://github.com/zhxchen17
2025-06-19 17:36:23 +00:00

77 lines
1.9 KiB
C++

#include <ATen/core/op_registration/op_registration.h>
#include <gtest/gtest.h>
#include <iterator>
#include <string_view>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/torch.h>
namespace torch::nativert {

namespace {

// Kernel backing the test-only `test::foo` operator registered in
// C10KernelTest: returns `a + b`. Internal linkage keeps this helper from
// clashing with a same-named function in another test translation unit.
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

} // namespace

// Builds a one-op graph calling the custom `test::foo` operator, runs it via
// C10Kernel on two random CPU tensors, and checks that the output value
// stored in the ExecutionFrame equals `a + b`.
TEST(C10KernelTest, computeInternal) {
  // The registrar must stay alive for the whole test: RegisterOperators
  // deregisters its operators when it is destroyed.
  auto registrar = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);

  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";
  auto graph = stringToGraph(source);

  // The first node is skipped (presumably the graph-input node); the second
  // node is the `test.foo.default` call under test. Guard the dereference.
  const auto& nodes = graph->nodes();
  ASSERT_GE(std::distance(nodes.begin(), nodes.end()), 2);
  const Node& node = *std::next(nodes.begin());

  c10::Device device = torch::Device(torch::kCPU, 0);
  auto a = at::randn({6, 6, 6});
  auto b = at::randn({6, 6, 6});

  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = C10Kernel(&node, device);
  kernel.computeInternal(frame);

  const at::Tensor expected = a + b;
  EXPECT_TRUE(
      torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}

// Builds a one-op graph calling `_operator.add`, runs it via
// ScalarBinaryOpKernel on two integer inputs, and checks that the frame holds
// an integer equal to their sum.
TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";
  auto graph = stringToGraph(source);

  // Same layout as above: skip the first node and take the `_operator.add`
  // call, asserting it exists before dereferencing.
  const auto& nodes = graph->nodes();
  ASSERT_GE(std::distance(nodes.begin(), nodes.end()), 2);
  const Node& node = *std::next(nodes.begin());

  auto a = 1;
  auto b = 2;
  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = ScalarBinaryOpKernel(&node);
  kernel.computeInternal(frame);

  // Check the tag first so a wrong result type yields a clear failure
  // instead of an exception from toInt().
  const auto& result = frame.getIValue(graph->getValue("x")->id());
  ASSERT_TRUE(result.isInt());
  const auto expected = a + b;
  EXPECT_EQ(result.toInt(), expected);
}

} // namespace torch::nativert