diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index a9990e23e6c..9bbd387c1bb 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -2532,6 +2532,87 @@ TEST(StaticRuntime, Tensor_Split) {
   testStaticRuntime(tensor_split_str3, args3);
 }
 
+TEST(StaticRuntime, JIT_Aten_Cpu) {
+  const std::string script = R"IR(
+    graph(%a: Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %aa: Tensor = aten::add(%a, %a, %1)
+        %ret: Tensor = aten::cpu(%aa)
+        return (%ret)
+  )IR";
+
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  vmap.reserve(0);
+  parseIR(script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+
+  auto a = at::randn({2, 4});
+  std::vector<IValue> args0{a};
+
+  testStaticRuntime(script, args0);
+}
+
+TEST(StaticRuntime, JIT_Aten_Numel) {
+  const std::string script = R"IR(
+    graph(%a: Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %aa: Tensor = aten::add(%a, %a, %1)
+        %ret: int = aten::numel(%aa)
+        return (%ret)
+  )IR";
+
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  vmap.reserve(0);
+  parseIR(script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+
+  auto a = at::randn({2, 4});
+  std::vector<IValue> args0{a};
+
+  testStaticRuntime(script, args0);
+}
+
+TEST(StaticRuntime, JIT_Aten_List) {
+  const std::string script = R"IR(
+    graph(%a: str):
+        %1 : int = prim::Constant[value=0]()
+        %ret: str[] = aten::list(%a)
+        return (%ret)
+  )IR";
+
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  vmap.reserve(0);
+  parseIR(script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+
+  std::string a = "abcd";
+  std::vector<IValue> args0{a};
+
+  testStaticRuntime(script, args0);
+}
+
+TEST(StaticRuntime, JIT_Aten_Range_Length) {
+  const std::string script = R"IR(
+    graph(%lo: int, %hi: int, %step: int):
+        %1 : int = prim::Constant[value=0]()
+        %ret: int = aten::__range_length(%lo, %hi, %step)
+        return (%ret)
+  )IR";
+
+  auto graph = std::make_shared<Graph>();
+  std::unordered_map<std::string, Value*> vmap;
+  vmap.reserve(0);
+  parseIR(script, graph.get(), vmap);
+  torch::jit::StaticModule smodule(graph);
+
+  std::vector<IValue> args0{0, 10, 2};
+
+  testStaticRuntime(script, args0);
+}
+
 TEST(StaticRuntime, Cat) {
   const std::string cat_script = R"IR(
     graph(%a: Tensor, %b: Tensor, %dim: int):
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
index 9f91eb505a7..5442d481d78 100644
--- a/torch/csrc/jit/runtime/static/native_ops.cpp
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -200,6 +200,63 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
       };
     });
 
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::list,
+    aten_list,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto str = p_node->Input(0).toStringRef();
+        c10::List<std::string> chars;
+        chars.reserve(str.size());
+        for (auto c : str) {
+          chars.emplace_back(1, c);
+        }
+        p_node->Output(0) = std::move(chars);
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::numel,
+    aten_numel,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& arg = p_node->Input(0).toTensor();
+        p_node->Output(0) = arg.numel();
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::cpu,
+    aten_cpu,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& arg = p_node->Input(0).toTensor();
+        p_node->Output(0) = arg.cpu();
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::__range_length,
+    aten_range_length,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        auto lo = p_node->Input(0).toInt();
+        auto hi = p_node->Input(1).toInt();
+        auto step = p_node->Input(2).toInt();
+        // error handling when step_val == 0 during runtime
+        if (step == 0) {
+          throw std::runtime_error("range() arg 3 must not be zero");
+        }
+        if (step > 0 && lo < hi) {
+          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
+        } else if (step < 0 && lo > hi) {
+          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
+        } else {
+          p_node->Output(0) = 0;
+        }
+      };
+    });
+
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
     aten::index_put,
     aten_index_put,
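
Note (not part of the patch): a minimal standalone sketch of the closed-form range-length arithmetic that the aten::__range_length functor above uses, checked against a few hand-computed Python range() sizes. The rangeLength helper name is hypothetical and exists only for this illustration.

#include <cassert>
#include <cstdint>
#include <stdexcept>

// Same arithmetic as the aten::__range_length functor in the diff above.
int64_t rangeLength(int64_t lo, int64_t hi, int64_t step) {
  if (step == 0) {
    throw std::runtime_error("range() arg 3 must not be zero");
  }
  if (step > 0 && lo < hi) {
    return 1 + (hi - 1 - lo) / step; // ascending range
  } else if (step < 0 && lo > hi) {
    return 1 + (lo - 1 - hi) / (0 - step); // descending range
  }
  return 0; // empty range
}

int main() {
  // Mirrors the JIT_Aten_Range_Length test arguments: range(0, 10, 2) -> 0,2,4,6,8
  assert(rangeLength(0, 10, 2) == 5);
  // Descending: range(10, 0, -3) -> 10,7,4,1
  assert(rangeLength(10, 0, -3) == 4);
  // Step never reaches hi, so the range is empty
  assert(rangeLength(5, 0, 1) == 0);
  return 0;
}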