mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72 As part of the effort to open source TorchNativeRuntime (or what we call Sigmoid), we are moving the kernel implementations to torch/: fbcode/sigmoid/kernels -> fbcode/caffe2/torch/nativert/kernels. Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/cpp/nativert:c10_kernel_test ``` Differential Revision: D76825830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156208 Approved by: https://github.com/zhxchen17
77 lines
1.9 KiB
C++
77 lines
1.9 KiB
C++
#include <iterator>
#include <string_view>

#include <ATen/core/op_registration/op_registration.h>
#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/torch.h>

namespace torch::nativert {
|
|
|
|
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
|
|
return a + b;
|
|
}
|
|
|
|
TEST(C10KernelTest, computeInternal) {
|
|
auto registrar = c10::RegisterOperators().op(
|
|
"test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);
|
|
|
|
static constexpr std::string_view source =
|
|
R"(graph(%a, %b):
|
|
%x = test.foo.default(a=%a, b=%b)
|
|
return (%x)
|
|
)";
|
|
|
|
auto graph = stringToGraph(source);
|
|
const auto& nodes = graph->nodes();
|
|
auto it = nodes.begin();
|
|
std::advance(it, 1);
|
|
const Node& node = *it;
|
|
|
|
c10::Device device = torch::Device(torch::kCPU, 0);
|
|
|
|
auto a = at::randn({6, 6, 6});
|
|
auto b = at::randn({6, 6, 6});
|
|
|
|
auto frame = ExecutionFrame(*graph);
|
|
frame.setIValue(graph->getValue("a")->id(), a);
|
|
frame.setIValue(graph->getValue("b")->id(), b);
|
|
|
|
auto kernel = C10Kernel(&node, device);
|
|
|
|
kernel.computeInternal(frame);
|
|
|
|
at::Tensor expected = a + b;
|
|
EXPECT_TRUE(
|
|
torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
|
|
}
|
|
|
|
TEST(ScalarBinaryOpKernelTest, computeInternal) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%a, %b):
|
|
%x = _operator.add(a=%a, b=%b)
|
|
return (%x)
|
|
)";
|
|
|
|
auto graph = stringToGraph(source);
|
|
const auto& nodes = graph->nodes();
|
|
auto it = nodes.begin();
|
|
std::advance(it, 1);
|
|
const Node& node = *it;
|
|
|
|
auto a = 1;
|
|
auto b = 2;
|
|
|
|
auto frame = ExecutionFrame(*graph);
|
|
frame.setIValue(graph->getValue("a")->id(), a);
|
|
frame.setIValue(graph->getValue("b")->id(), b);
|
|
|
|
auto kernel = ScalarBinaryOpKernel(&node);
|
|
|
|
kernel.computeInternal(frame);
|
|
|
|
auto expected = a + b;
|
|
EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected);
|
|
}
|
|
|
|
} // namespace torch::nativert
|