Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40199

Mobile custom selective build is already covered by `test/mobile/custom_build/build.sh`, which builds a CLI binary with the host toolchain and runs it on the host machine to check that the result is correct. That test does not cover the Android/Gradle build, however, so it cannot be used to measure and track the in-APK size of the custom-built library.

This PR therefore adds selective build test coverage for the Android NDK build, and integrates it with CI so that the custom build size is uploaded to Scuba.

TODO: Ideally, the test should build android/test_app and measure the in-APK size. But test_app is not covered by any CI yet and is currently broken, so this PR builds and measures the AAR instead (which may become inaccurate, as we plan to pack C++ header files into the AAR soon).

Sample result: https://fburl.com/scuba/pytorch_binary_size/skxwb1gh

```
+---------------------+-------------+-------------------+-----------+----------+
| build_mode          | arch        | lib               | Build Num | Size     |
+---------------------+-------------+-------------------+-----------+----------+
| custom-build-single | armeabi-v7a | libpytorch_jni.so | 5901579   | 3.68 MiB |
| prebuild            | armeabi-v7a | libpytorch_jni.so | 5901014   | 6.23 MiB |
| prebuild            | x86_64      | libpytorch_jni.so | 5901014   | 7.67 MiB |
+---------------------+-------------+-------------------+-----------+----------+
```

Test Plan: Imported from OSS

Differential Revision: D22111115

Pulled By: ljk53

fbshipit-source-id: 11d24efbc49a85f851ecd0e481d14123f405b3a9
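The PR measures `libpytorch_jni.so` inside the AAR rather than inside an APK. As a rough, hypothetical sketch of such a measurement (not the CI script the PR adds, which is not shown here), the helper below opens the AAR as a zip archive, assumes the standard AAR layout in which native libraries live under `jni/<abi>/`, and formats each ABI's library size in MiB like the sample table. The names `native_lib_sizes` and `format_mib` are illustrative only. The sketch reports uncompressed sizes; the PR does not say whether the CI numbers are compressed or uncompressed.

```python
import zipfile

MIB = 1024 * 1024  # bytes per mebibyte


def native_lib_sizes(aar_path, lib_name="libpytorch_jni.so"):
    """Return {abi: uncompressed size in bytes} for lib_name inside an AAR.

    Assumes the standard AAR layout, where native libraries are stored
    as jni/<abi>/<lib_name>.
    """
    sizes = {}
    with zipfile.ZipFile(aar_path) as aar:
        for info in aar.infolist():
            parts = info.filename.split("/")
            # Keep only entries of the exact form jni/<abi>/<lib_name>.
            if len(parts) == 3 and parts[0] == "jni" and parts[2] == lib_name:
                sizes[parts[1]] = info.file_size  # uncompressed size
    return sizes


def format_mib(num_bytes):
    """Format a byte count in the style of the sample table, e.g. '3.68 MiB'."""
    return f"{num_bytes / MIB:.2f} MiB"
```

For example, `{abi: format_mib(n) for abi, n in native_lib_sizes("pytorch_android.aar").items()}` would produce one row per ABI, where `pytorch_android.aar` is a placeholder path for whatever AAR the build produced.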
26 lines
806 B
Python
```python
"""
This is a script for the PyTorch Android custom selective build test. It
prepares a MobileNetV2 TorchScript model and dumps the root ops used by the
model, so that the custom build script can create a tailored build that
contains only these ops.
"""

import torch
import torchvision
import yaml

# Download and trace the model.
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced_script_module = torch.jit.trace(model, example)

# Save traced TorchScript model.
traced_script_module.save("MobileNetV2.pt")

# Dump root ops used by the model (for custom build optimization).
ops = torch.jit.export_opnames(traced_script_module)

# Write the op names as a YAML list for the custom build script to consume.
with open('MobileNetV2.yaml', 'w') as output:
    yaml.dump(ops, output)
```
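The script dumps only the *root* ops, that is, the ops the traced model calls directly. Presumably a tailored build must also keep any ops that those root ops call internally. The sketch below illustrates that expansion as a breadth-first transitive closure, assuming a hypothetical dependency graph given as a dict that maps each op name to the ops it calls; it is not PyTorch's actual build tooling, and `transitive_ops` and the example graph are made up for illustration.

```python
from collections import deque


def transitive_ops(root_ops, dep_graph):
    """Return every op reachable from root_ops through dep_graph.

    dep_graph is a hypothetical mapping {op_name: [ops it calls]}; ops that
    are missing from the mapping are treated as having no dependencies.
    """
    selected = set(root_ops)
    queue = deque(root_ops)
    while queue:
        op = queue.popleft()
        for dep in dep_graph.get(op, ()):
            if dep not in selected:  # visit each op only once, even with cycles
                selected.add(dep)
                queue.append(dep)
    return selected


# Illustrative placeholder graph, not real PyTorch op dependencies.
example_graph = {"op_a": ["op_b"], "op_b": ["op_c", "op_a"], "op_d": ["op_e"]}
print(sorted(transitive_ops(["op_a"], example_graph)))  # ['op_a', 'op_b', 'op_c']
```

With the script's output, the root list could be loaded with `yaml.safe_load(open("MobileNetV2.yaml"))` and passed as `root_ops`. A set plus a queue keeps the traversal linear in the size of the graph and makes cyclic dependencies harmless.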