[static runtime] Add JIT prim ops: aten::cpu, aten::list, aten::numel, aten::__range_length (#79111)
Summary: This adds the missing JIT prim ops that appear in the non-ads models for the c2->pt mitigation: aten::cpu, aten::list, aten::numel, aten::__range_length

Test Plan: static runtime unit tests

Differential Revision: D36984960

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79111
Approved by: https://github.com/davidberard98
This commit is contained in:
parent 7d17e3b884
commit 0545c85f74
@@ -2532,6 +2532,87 @@ TEST(StaticRuntime, Tensor_Split) {
  testStaticRuntime(tensor_split_str3, args3);
}

TEST(StaticRuntime, JIT_Aten_Cpu) {
  const std::string script = R"IR(
    graph(%a: Tensor):
        %1 : int = prim::Constant[value=0]()
        %aa: Tensor = aten::add(%a, %a, %1)
        %ret: Tensor = aten::cpu(%aa)
        return (%ret)
  )IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  vmap.reserve(0);
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  auto a = at::randn({2, 4});
  std::vector<IValue> args0{a};

  testStaticRuntime(script, args0);
}

TEST(StaticRuntime, JIT_Aten_Numel) {
  const std::string script = R"IR(
    graph(%a: Tensor):
        %1 : int = prim::Constant[value=0]()
        %aa: Tensor = aten::add(%a, %a, %1)
        %ret: int = aten::numel(%aa)
        return (%ret)
  )IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  vmap.reserve(0);
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  auto a = at::randn({2, 4});
  std::vector<IValue> args0{a};

  testStaticRuntime(script, args0);
}

TEST(StaticRuntime, JIT_Aten_List) {
  const std::string script = R"IR(
    graph(%a: str):
        %1 : int = prim::Constant[value=0]()
        %ret: str[] = aten::list(%a)
        return (%ret)
  )IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  vmap.reserve(0);
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  std::string a = "abcd";
  std::vector<IValue> args0{a};

  testStaticRuntime(script, args0);
}

TEST(StaticRuntime, JIT_Aten_Range_Length) {
  const std::string script = R"IR(
    graph(%lo: int, %hi: int, %step: int):
        %1 : int = prim::Constant[value=0]()
        %ret: int = aten::__range_length(%lo, %hi, %step)
        return (%ret)
  )IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  vmap.reserve(0);
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  std::vector<IValue> args0{0, 10, 2};

  testStaticRuntime(script, args0);
}

TEST(StaticRuntime, Cat) {
  const std::string cat_script = R"IR(
    graph(%a: Tensor, %b: Tensor, %dim: int):
@@ -200,6 +200,63 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::list,
    aten_list,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto str = p_node->Input(0).toStringRef();
        c10::List<std::string> chars;
        chars.reserve(str.size());
        for (auto c : str) {
          chars.emplace_back(1, c);
        }
        p_node->Output(0) = std::move(chars);
      };
    });

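For reference, this functor mirrors Python's list(str) semantics: the string is split into single-character strings. A minimal standalone sketch of the same splitting logic, using std::vector instead of c10::List (illustrative only, not part of the patch):

#include <cassert>
#include <string>
#include <vector>

int main() {
  const std::string str = "abcd";
  std::vector<std::string> chars;
  chars.reserve(str.size());
  for (char c : str) {
    chars.emplace_back(1, c);  // each element is a one-character string
  }
  // Equivalent to list("abcd") == ["a", "b", "c", "d"] in TorchScript/Python.
  assert(chars.size() == 4);
  assert(chars.front() == "a" && chars.back() == "d");
  return 0;
}
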
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::numel,
    aten_numel,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.numel();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::cpu,
    aten_cpu,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.cpu();
      };
    });

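As a quick sanity check on the tensor semantics these two functors delegate to, a standalone sketch using the public ATen API (illustrative only, not part of the patch):

#include <ATen/ATen.h>
#include <cassert>

int main() {
  auto a = at::randn({2, 4});
  // numel() is the total number of elements, i.e. the product of the sizes.
  assert(a.numel() == 8);
  // cpu() returns a tensor on the CPU device; for a tensor that is already
  // on CPU this is effectively a no-op.
  assert(a.cpu().device().is_cpu());
  return 0;
}
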
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__range_length,
    aten_range_length,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        auto lo = p_node->Input(0).toInt();
        auto hi = p_node->Input(1).toInt();
        auto step = p_node->Input(2).toInt();
        // error handling when step_val == 0 during runtime
        if (step == 0) {
          throw std::runtime_error("range() arg 3 must not be zero");
        }
        if (step > 0 && lo < hi) {
          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
        } else if (step < 0 && lo > hi) {
          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
        } else {
          p_node->Output(0) = 0;
        }
      };
    });

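The trip-count formula above matches Python's len(range(lo, hi, step)). A standalone sketch of the same arithmetic with a few spot checks (range_length is a hypothetical helper name, not part of the patch):

#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: same arithmetic as the functor above, i.e. the number
// of elements produced by range(lo, hi, step).
int64_t range_length(int64_t lo, int64_t hi, int64_t step) {
  if (step == 0) {
    throw std::runtime_error("range() arg 3 must not be zero");
  }
  if (step > 0 && lo < hi) {
    return 1 + (hi - 1 - lo) / step;
  } else if (step < 0 && lo > hi) {
    return 1 + (lo - 1 - hi) / (0 - step);
  }
  return 0;
}

int main() {
  assert(range_length(0, 10, 2) == 5);   // range(0, 10, 2) -> 0, 2, 4, 6, 8
  assert(range_length(10, 0, -3) == 4);  // range(10, 0, -3) -> 10, 7, 4, 1
  assert(range_length(5, 5, 1) == 0);    // empty range
  return 0;
}
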
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::index_put,
    aten_index_put,