[static runtime] add more _out variants (#48260)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48260

supporting a couple more operators

Test Plan:
use Ansha's test framework for e2e test

```
numactl -m 0 -C 3 ./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench --pred_net=/home/bwasti/adindexer/precomputation_merge_net.pb --c2_inputs=/home/bwasti/adindexer/c2_inputs_precomputation_bs1.pb --c2_weights=/home/bwasti/adindexer/c2_weights_precomputation.pb --scripted_model=/home/bwasti/adindexer/traced_precomputation_partial_dper_fixes.pt --pt_inputs=/home/bwasti/adindexer/container_precomputation_bs1.pt --iters=30000 --warmup_iters=10000 --num_threads=1 --pt_enable_static_runtime=true --pt_cleanup_activations=true --pt_enable_out_variant=true --eps 1e-2
```

Reviewed By: hlu1

Differential Revision: D24767322

fbshipit-source-id: dce7f9bc0427632129f263bad509f0f00a21ccf3
This commit is contained in:
Bram Wasti 2020-11-20 17:01:02 -08:00 committed by Facebook GitHub Bot
parent 87bfb2ff08
commit 0984d3123a
2 changed files with 58 additions and 2 deletions

View File

@ -736,6 +736,10 @@ ProcessedNode::ProcessedNode(
std::ostringstream ss;
node->print(ss, 0, nullptr, false);
VLOG(1) << "Switch to native impl for node: " << ss.str();
} else {
std::ostringstream ss;
node->print(ss, 0, nullptr, false);
VLOG(1) << "Fallback interpreter for node: " << ss.str();
}
}

View File

@ -19,8 +19,10 @@ bool canRunOutOfPlace(Node* n) {
"aten::bmm",
"aten::cat",
"aten::clamp",
"aten::clone",
"aten::leaky_relu",
"aten::logit",
"aten::logit",
"aten::mul",
"aten::nan_to_num",
"aten::relu",
@ -37,12 +39,22 @@ bool canRunNatively(Node* n) {
// In alphabetical order
const static std::unordered_set<std::string> native_nodes{
"aten::flatten",
"aten::permute",
"aten::reshape",
"aten::slice",
"aten::transpose",
"aten::to",
"prim::ListConstruct",
"prim::ListUnpack",
"prim::TupleConstruct"};
auto str = std::string(n->kind().toQualString());
return native_nodes.count(str) > 0;
if (!native_nodes.count(str)) {
return false;
}
if (str == "aten::to") {
return n->inputs().size() == 5;
}
return true;
}
std::function<void(const ProcessedNode*, std::vector<IValue>&)>
@ -217,7 +229,9 @@ getOutOfPlaceOperation(Node* n) {
} else if (n->kind() == c10::Symbol::fromQualString("aten::logit")) {
return [](const ProcessedNode* p_node, std::vector<IValue>& reg) {
auto in0_t = p_node->Input(0, reg).toTensor();
auto in1_d = p_node->Input(1, reg).toDouble();
double in1_d = p_node->input_regs().size() > 1
? p_node->Input(1, reg).toDouble()
: -1.0;
if (p_node->Output(0, reg).isNone()) {
p_node->Output(0, reg) = create_empty_from(in0_t);
}
@ -310,6 +324,44 @@ getNativeOperation(Node* n) {
p_node->Output(i, reg) = std::move(stack[i]);
}
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
auto in0_t = p_node->Input(0, reg).toTensor();
auto in1_iv = p_node->Input(1, reg).toIntVector();
p_node->Output(0, reg) = at::native::permute(in0_t, in1_iv);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
auto in0_t = p_node->Input(0, reg).toTensor();
auto in1_iv = p_node->Input(1, reg).toIntVector();
p_node->Output(0, reg) = at::native::reshape(in0_t, in1_iv);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
auto in0_t = p_node->Input(0, reg).toTensor();
auto in1_i = p_node->Input(1, reg).toInt();
auto in2_i = p_node->Input(2, reg).toInt();
auto in3_i = p_node->Input(3, reg).toInt();
auto in4_i = p_node->Input(4, reg).toInt();
p_node->Output(0, reg) =
at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
};
} else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
return [=](const ProcessedNode* p_node, std::vector<IValue>& reg) {
DCHECK(p_node->input_regs().size() == 5);
auto in0_t = p_node->Input(0, reg).toTensor();
auto in1_i = p_node->Input(1, reg).toScalarType();
auto in2_i = p_node->Input(2, reg).toBool();
auto in3_i = p_node->Input(3, reg).toBool();
if (p_node->Input(4, reg).isNone()) {
p_node->Output(0, reg) =
at::native::to(in0_t, in1_i, in2_i, in3_i, c10::nullopt);
} else {
auto in4_o = p_node->Input(4, reg).toMemoryFormat();
p_node->Output(0, reg) =
at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
}
};
}
return [](const ProcessedNode*, std::vector<IValue>&) { TORCH_CHECK(0); };
}