disable JIT optimizer in Android wrapper for mobile custom build (#30285)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30285

PR #30144 introduced custom build script to tailor build to specific
models. It requires a list of all potentially used ops at build time.

Some JIT optimization passes can transform the IR by replacing
operators; e.g., the decompose pass can replace aten::addmm with
aten::mm when the coefficients are 1.

Disabling these optimization passes ensures that the list of ops we dump
from the model matches the list of ops that are actually needed at runtime.

Test Plan: - rerun the test on PR #30144 to verify the raw list without aten::mm works.

Differential Revision: D18652777

Pulled By: ljk53

fbshipit-source-id: 084751cb9a9ee16d8df7e743e9e5782ffd8bc4e3
This commit is contained in:
Jiakai Liu 2019-11-22 00:23:06 -08:00 committed by Facebook Github Bot
parent 1690feba9f
commit f5ef3a6fb6
2 changed files with 16 additions and 2 deletions

View File

@ -13,6 +13,18 @@
namespace pytorch_jni {
namespace {
// RAII guard bundling the JIT-related mode switches needed for mobile
// inference; instantiate it for the duration of model load / forward /
// runMethod calls.
struct JITCallGuard {
// Autograd is not needed for inference, so it is disabled for mobile by
// default.
torch::autograd::AutoGradMode no_autograd_guard{false};
// Disable the graph optimizer so that IR-rewriting passes (e.g. decompose
// replacing aten::addmm with aten::mm) cannot change the set of ops the
// model actually calls — the custom mobile build relies on the op list
// dumped from the model being exactly the ops needed at runtime.
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
@ -51,6 +63,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
/* need_inputs */ false,
/* sampled */ false);
#endif
JITCallGuard guard;
module_ = torch::jit::load(std::move(modelPath->toStdString()));
module_.eval();
}
@ -76,7 +89,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
torch::autograd::AutoGradMode guard(false);
JITCallGuard guard;
return module_.forward(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
@ -98,7 +111,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
torch::autograd::AutoGradMode guard(false);
JITCallGuard guard;
return (*method)(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);

View File

@ -116,6 +116,7 @@ int main(int argc, char** argv) {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
}
torch::autograd::AutoGradMode guard(false);
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
auto module = torch::jit::load(FLAGS_model);
module.eval();