pytorch/test/mobile/model_test/torchvision_models.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00

57 lines
1.8 KiB
Python

from torchvision import models
import torch
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile
class MobileNetV2Module:
    """Factory for a traced, mobile-optimized MobileNetV2 with bundled inputs."""

    def getModule(self):
        """Build and return the optimized TorchScript MobileNetV2.

        Steps: load ``models.mobilenet_v2`` with the ``IMAGENET1K_V1`` weights,
        switch it to eval mode, trace it with an all-zeros ``(1, 3, 224, 224)``
        tensor, run ``optimize_for_mobile`` with its default backend, and attach
        that same zeros tensor as a bundled input.

        Returns:
            The optimized scripted module produced by ``optimize_for_mobile``.
        """
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1
        # nn.Module.eval() returns the module itself, so it can be chained.
        net = models.mobilenet_v2(weights=weights).eval()
        sample = torch.zeros(1, 3, 224, 224)

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample))
        augment_model_with_bundled_inputs(mobile_module, [(sample,)])

        # Execute once on the sample input. The output is discarded, so this
        # only checks that the optimized module actually runs.
        mobile_module(sample)
        return mobile_module


class MobileNetV2VulkanModule:
    """Factory for a traced MobileNetV2 optimized for the Vulkan backend."""

    def getModule(self):
        """Build and return the Vulkan-optimized TorchScript MobileNetV2.

        Steps: load ``models.mobilenet_v2`` with the ``IMAGENET1K_V1`` weights,
        switch it to eval mode, trace it with an all-zeros ``(1, 3, 224, 224)``
        tensor, run ``optimize_for_mobile(..., backend="vulkan")``, and attach
        that same zeros tensor as a bundled input.

        NOTE(review): the final smoke call presumably needs a Vulkan-capable
        build to succeed; confirm against the environment this runs in.

        Returns:
            The optimized scripted module produced by ``optimize_for_mobile``.
        """
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1
        # nn.Module.eval() returns the module itself, so it can be chained.
        net = models.mobilenet_v2(weights=weights).eval()
        sample = torch.zeros(1, 3, 224, 224)

        mobile_module = optimize_for_mobile(
            torch.jit.trace(net, sample), backend="vulkan"
        )
        augment_model_with_bundled_inputs(mobile_module, [(sample,)])

        # Execute once on the sample input. The output is discarded, so this
        # only checks that the optimized module actually runs.
        mobile_module(sample)
        return mobile_module


class Resnet18Module:
    """Factory for a traced, mobile-optimized ResNet-18 with bundled inputs."""

    def getModule(self):
        """Build and return the optimized TorchScript ResNet-18.

        Steps: load ``models.resnet18`` with the ``IMAGENET1K_V1`` weights,
        switch it to eval mode, trace it with an all-zeros ``(1, 3, 224, 224)``
        tensor, run ``optimize_for_mobile`` with its default backend, and attach
        that same zeros tensor as a bundled input.

        Returns:
            The optimized scripted module produced by ``optimize_for_mobile``.
        """
        weights = models.ResNet18_Weights.IMAGENET1K_V1
        # nn.Module.eval() returns the module itself, so it can be chained.
        net = models.resnet18(weights=weights).eval()
        sample = torch.zeros(1, 3, 224, 224)

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample))
        augment_model_with_bundled_inputs(mobile_module, [(sample,)])

        # Execute once on the sample input. The output is discarded, so this
        # only checks that the optimized module actually runs.
        mobile_module(sample)
        return mobile_module