[pytorch] remove boilerplate setQEngine() from PyTorch mobile predictors (#34556)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34556

According to
https://github.com/pytorch/pytorch/pull/34012#discussion_r388581548,
this `at::globalContext().setQEngine(at::QEngine::QNNPACK);` call isn't
really necessary for mobile.

In Context.cpp, the context selects the last available QEngine when the engine isn't
set explicitly. The OSS mobile prebuild should only include the QNNPACK
engine, so the default behavior should already be the desired behavior.

It makes a difference only when USE_FBGEMM is set, but that should be off
for both the OSS mobile build and the internal mobile build.
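Concretely, supportedQEngines() appends engines in priority order, with the
last entry winning as the default. A sketch of the relevant ordering
(paraphrased, not verbatim source; it assumes the mobile QNNPACK engine is
guarded by USE_PYTORCH_QNNPACK):

    // Sketch of the ordering inside Context::supportedQEngines().
    std::vector<at::QEngine> engines;
    #ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
    #endif
    #ifdef USE_FBGEMM
    // Off for both OSS and internal mobile builds; if it were on, FBGEMM
    // would come last and become the default instead of QNNPACK.
    engines.push_back(at::kFBGEMM);
    #endif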

Test Plan: Imported from OSS

Differential Revision: D20374522

Pulled By: ljk53

fbshipit-source-id: d4e437a03c6d4f939edccb5c84f02609633a0698
Jiakai Liu 2020-03-11 00:52:10 -07:00, committed by Facebook GitHub Bot
parent 2ce9513b0c
commit 7aca9afdfb
6 changed files with 0 additions and 27 deletions


@@ -104,11 +104,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     }();
     ((void)once);
 
-    auto qengines = at::globalContext().supportedQEngines();
-    if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
-        qengines.end()) {
-      at::globalContext().setQEngine(at::QEngine::QNNPACK);
-    }
 #ifdef TRACE_ENABLED
     torch::autograd::profiler::pushCallback(
         &onFunctionEnter,


@@ -31,11 +31,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
   }
 
   PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
-    auto qengines = at::globalContext().supportedQEngines();
-    if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
-        qengines.end()) {
-      at::globalContext().setQEngine(at::QEngine::QNNPACK);
-    }
     module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
   }


@@ -134,10 +134,6 @@ int main(int argc, char** argv) {
     inputs[0].push_back(stensor);
   }
 
-  auto qengines = at::globalContext().supportedQEngines();
-  if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
-    at::globalContext().setQEngine(at::QEngine::QNNPACK);
-  }
   torch::autograd::AutoGradMode guard(false);
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
   auto module = torch::jit::load(FLAGS_model);


@@ -65,10 +65,6 @@ static int iter = 10;
     }
   }
 
-  auto qengines = at::globalContext().supportedQEngines();
-  if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
-    at::globalContext().setQEngine(at::QEngine::QNNPACK);
-  }
   torch::autograd::AutoGradMode guard(false);
   torch::jit::GraphOptimizerEnabledGuard opguard(false);
   auto module = torch::jit::load(model);


@@ -12,10 +12,6 @@
 + (void)setUp {
   [super setUp];
-  auto qengines = at::globalContext().supportedQEngines();
-  if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
-    at::globalContext().setQEngine(at::QEngine::QNNPACK);
-  }
 }
 
 - (void)setUp {


@@ -24,10 +24,6 @@ struct MobileCallGuard {
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
 };
 
-void init() {
-  at::globalContext().setQEngine(at::QEngine::QNNPACK);
-}
-
 torch::jit::script::Module loadModel(const std::string& path) {
   MobileCallGuard guard;
   auto module = torch::jit::load(path);
@@ -42,7 +38,6 @@ int main(int argc, const char* argv[]) {
     std::cerr << "Usage: " << argv[0] << " <model_path>\n";
     return 1;
   }
-  init();
   auto module = loadModel(argv[1]);
   auto input = torch::ones({1, 3, 224, 224});
   auto output = [&]() {