mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[static runtime] add more _out variants (#48260)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48260 supporting a couple more operators Test Plan: use Ansha's test framework for e2e test ``` numactl -m 0 -C 3 ./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench --pred_net=/home/bwasti/adindexer/precomputation_merge_net.pb --c2_inputs=/home/bwasti/adindexer/c2_inputs_precomputation_bs1.pb --c2_weights=/home/bwasti/adindexer/c2_weights_precomputation.pb --scripted_model=/home/bwasti/adindexer/traced_precomputation_partial_dper_fixes.pt --pt_inputs=/home/bwasti/adindexer/container_precomputation_bs1.pt --iters=30000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true --pt_cleanup_activations=true --pt_enable_out_variant=true --eps 1e-2 ``` Reviewed By: hlu1 Differential Revision: D24767322 fbshipit-source-id: dce7f9bc0427632129f263bad509f0f00a21ccf3
This commit is contained in:
parent
87bfb2ff08
commit
0984d3123a
|
|
@ -736,6 +736,10 @@ ProcessedNode::ProcessedNode(
|
|||
std::ostringstream ss;
|
||||
node->print(ss, 0, nullptr, false);
|
||||
VLOG(1) << "Switch to native impl for node: " << ss.str();
|
||||
} else {
|
||||
std::ostringstream ss;
|
||||
node->print(ss, 0, nullptr, false);
|
||||
VLOG(1) << "Fallback interpreter for node: " << ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@ bool canRunOutOfPlace(Node* n) {
|
|||
"aten::bmm",
|
||||
"aten::cat",
|
||||
"aten::clamp",
|
||||
"aten::clone",
|
||||
"aten::leaky_relu",
|
||||
"aten::logit",
|
||||
"aten::logit",
|
||||
"aten::mul",
|
||||
"aten::nan_to_num",
|
||||
"aten::relu",
|
||||
|
|
@ -37,12 +39,22 @@ bool canRunNatively(Node* n) {
|
|||
// In alphabetical order
|
||||
const static std::unordered_set<std::string> native_nodes{
|
||||
"aten::flatten",
|
||||
"aten::permute",
|
||||
"aten::reshape",
|
||||
"aten::slice",
|
||||
"aten::transpose",
|
||||
"aten::to",
|
||||
"prim::ListConstruct",
|
||||
"prim::ListUnpack",
|
||||
"prim::TupleConstruct"};
|
||||
auto str = std::string(n->kind().toQualString());
|
||||
return native_nodes.count(str) > 0;
|
||||
if (!native_nodes.count(str)) {
|
||||
return false;
|
||||
}
|
||||
if (str == "aten::to") {
|
||||
return n->inputs().size() == 5;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::function<void(const ProcessedNode*, std::vector<IValue>&)>
|
||||
|
|
@ -217,7 +229,9 @@ getOutOfPlaceOperation(Node* n) {
|
|||
} else if (n->kind() == c10::Symbol::fromQualString("aten::logit")) {
|
||||
return [](const ProcessedNode* p_node, std::vector<IValue>& reg) {
|
||||
auto in0_t = p_node->Input(0, reg).toTensor();
|
||||
auto in1_d = p_node->Input(1, reg).toDouble();
|
||||
double in1_d = p_node->input_regs().size() > 1
|
||||
? p_node->Input(1, reg).toDouble()
|
||||
: -1.0;
|
||||
if (p_node->Output(0, reg).isNone()) {
|
||||
p_node->Output(0, reg) = create_empty_from(in0_t);
|
||||
}
|
||||
|
|
@ -310,6 +324,44 @@ getNativeOperation(Node* n) {
|
|||
p_node->Output(i, reg) = std::move(stack[i]);
|
||||
}
|
||||
};
|
||||
} else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
|
||||
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
|
||||
auto in0_t = p_node->Input(0, reg).toTensor();
|
||||
auto in1_iv = p_node->Input(1, reg).toIntVector();
|
||||
p_node->Output(0, reg) = at::native::permute(in0_t, in1_iv);
|
||||
};
|
||||
} else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
|
||||
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
|
||||
auto in0_t = p_node->Input(0, reg).toTensor();
|
||||
auto in1_iv = p_node->Input(1, reg).toIntVector();
|
||||
p_node->Output(0, reg) = at::native::reshape(in0_t, in1_iv);
|
||||
};
|
||||
} else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
|
||||
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
|
||||
auto in0_t = p_node->Input(0, reg).toTensor();
|
||||
auto in1_i = p_node->Input(1, reg).toInt();
|
||||
auto in2_i = p_node->Input(2, reg).toInt();
|
||||
auto in3_i = p_node->Input(3, reg).toInt();
|
||||
auto in4_i = p_node->Input(4, reg).toInt();
|
||||
p_node->Output(0, reg) =
|
||||
at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
|
||||
};
|
||||
} else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
|
||||
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
|
||||
DCHECK(p_node->input_regs().size() == 5);
|
||||
auto in0_t = p_node->Input(0, reg).toTensor();
|
||||
auto in1_i = p_node->Input(1, reg).toScalarType();
|
||||
auto in2_i = p_node->Input(2, reg).toBool();
|
||||
auto in3_i = p_node->Input(3, reg).toBool();
|
||||
if (p_node->Input(4, reg).isNone()) {
|
||||
p_node->Output(0, reg) =
|
||||
at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
|
||||
} else {
|
||||
auto in4_o = p_node->Input(4, reg).toMemoryFormat();
|
||||
p_node->Output(0, reg) =
|
||||
at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
|
||||
}
|
||||
};
|
||||
}
|
||||
return [](const ProcessedNode*, std::vector<IValue>&) { TORCH_CHECK(0); };
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user