mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
disable JIT optimizer in Android wrapper for mobile custom build (#30285)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30285 PR #30144 introduced custom build script to tailor build to specific models. It requires a list of all potentially used ops at build time. Some JIT optimization passes can transform the IR by replacing operators, e.g. decompose pass can replace aten::addmm with aten::mm if coefficients are 1s. Disabling optimization pass can ensure that the list of ops we dump from the model is the list of ops that are needed. Test Plan: - rerun the test on PR #30144 to verify the raw list without aten::mm works. Differential Revision: D18652777 Pulled By: ljk53 fbshipit-source-id: 084751cb9a9ee16d8df7e743e9e5782ffd8bc4e3
This commit is contained in:
parent
1690feba9f
commit
f5ef3a6fb6
|
|
@ -13,6 +13,18 @@
|
|||
|
||||
namespace pytorch_jni {
|
||||
|
||||
namespace {
|
||||
|
||||
struct JITCallGuard {
|
||||
// AutoGrad is disabled for mobile by default.
|
||||
torch::autograd::AutoGradMode no_autograd_guard{false};
|
||||
// Disable graph optimizer to ensure list of unused ops are not changed for
|
||||
// custom mobile build.
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
private:
|
||||
friend HybridBase;
|
||||
|
|
@ -51,6 +63,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||
/* need_inputs */ false,
|
||||
/* sampled */ false);
|
||||
#endif
|
||||
JITCallGuard guard;
|
||||
module_ = torch::jit::load(std::move(modelPath->toStdString()));
|
||||
module_.eval();
|
||||
}
|
||||
|
|
@ -76,7 +89,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
auto output = [&]() {
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
JITCallGuard guard;
|
||||
return module_.forward(std::move(inputs));
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
|
|
@ -98,7 +111,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
|||
}
|
||||
if (auto method = module_.find_method(methodName)) {
|
||||
auto output = [&]() {
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
JITCallGuard guard;
|
||||
return (*method)(std::move(inputs));
|
||||
}();
|
||||
return JIValue::newJIValueFromAtIValue(output);
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ int main(int argc, char** argv) {
|
|||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||
}
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
|
||||
auto module = torch::jit::load(FLAGS_model);
|
||||
|
||||
module.eval();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user