mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytorch] remove boilerplate setQEngine() from PyTorch mobile predictors (#34556)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34556 According to https://github.com/pytorch/pytorch/pull/34012#discussion_r388581548, this `at::globalContext().setQEngine(at::QEngine::QNNPACK);` call isn't really necessary for mobile. In Context.cpp it selects the last available QEngine if the engine isn't set explicitly. For OSS mobile prebuild it should only include QNNPACK engine so the default behavior should already be desired behavior. It makes difference only when USE_FBGEMM is set - but it should be off for both OSS mobile build and internal mobile build. Test Plan: Imported from OSS Differential Revision: D20374522 Pulled By: ljk53 fbshipit-source-id: d4e437a03c6d4f939edccb5c84f02609633a0698
This commit is contained in:
parent
2ce9513b0c
commit
7aca9afdfb
|
|
@ -104,11 +104,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||
}();
|
||||
((void)once);
|
||||
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
|
||||
qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
#ifdef TRACE_ENABLED
|
||||
torch::autograd::profiler::pushCallback(
|
||||
&onFunctionEnter,
|
||||
|
|
|
|||
|
|
@ -31,11 +31,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||
}
|
||||
|
||||
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
|
||||
qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -134,10 +134,6 @@ int main(int argc, char** argv) {
|
|||
inputs[0].push_back(stensor);
|
||||
}
|
||||
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
|
||||
auto module = torch::jit::load(FLAGS_model);
|
||||
|
|
|
|||
|
|
@ -65,10 +65,6 @@ static int iter = 10;
|
|||
}
|
||||
}
|
||||
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
torch::jit::GraphOptimizerEnabledGuard opguard(false);
|
||||
auto module = torch::jit::load(model);
|
||||
|
|
|
|||
|
|
@ -12,10 +12,6 @@
|
|||
|
||||
+ (void)setUp {
|
||||
[super setUp];
|
||||
auto qengines = at::globalContext().supportedQEngines();
|
||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
}
|
||||
|
||||
- (void)setUp {
|
||||
|
|
|
|||
|
|
@ -24,10 +24,6 @@ struct MobileCallGuard {
|
|||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||
};
|
||||
|
||||
void init() {
|
||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
|
||||
torch::jit::script::Module loadModel(const std::string& path) {
|
||||
MobileCallGuard guard;
|
||||
auto module = torch::jit::load(path);
|
||||
|
|
@ -42,7 +38,6 @@ int main(int argc, const char* argv[]) {
|
|||
std::cerr << "Usage: " << argv[0] << " <model_path>\n";
|
||||
return 1;
|
||||
}
|
||||
init();
|
||||
auto module = loadModel(argv[1]);
|
||||
auto input = torch::ones({1, 3, 224, 224});
|
||||
auto output = [&]() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user