mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Set quantized engine backend for mobile in speed_benchmark_torch (#26911)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26911 Check if QNNPACK is present as a backend (should always be present on mobile). If it is present then set the backend to QNNPACK Test Plan: Test on mobile ./speed_benchmark_torch --model mobilenet_quantized_scripted.pt --input_dims="1,3,224,224" --input_type=float --warmup=5 --iter 20 --print_output True Imported from OSS Differential Revision: D17613908 fbshipit-source-id: af96722570a0111f13d69c38ccca52416ea5e460
This commit is contained in:
parent
638c4375de
commit
8d5c2aa71c
|
|
@ -82,10 +82,15 @@ int main(int argc, char** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto qengines = at::globalContext().supportedQEngines();
|
||||||
|
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
||||||
|
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||||
|
}
|
||||||
torch::autograd::AutoGradMode guard(false);
|
torch::autograd::AutoGradMode guard(false);
|
||||||
auto module = torch::jit::load(FLAGS_model);
|
auto module = torch::jit::load(FLAGS_model);
|
||||||
|
|
||||||
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
||||||
|
module.eval();
|
||||||
if (FLAGS_print_output) {
|
if (FLAGS_print_output) {
|
||||||
std::cout << module.forward(inputs) << std::endl;
|
std::cout << module.forward(inputs) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user