mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980 ghstack dependencies: #127122, #127123, #127124, #127125
57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
from torchvision import models
|
|
|
|
import torch
|
|
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
|
|
from torch.utils.mobile_optimizer import optimize_for_mobile
|
|
|
|
|
|
class MobileNetV2Module:
    """Factory for a mobile-optimized TorchScript MobileNetV2 (default backend)."""

    def getModule(self):
        """Return a traced, mobile-optimized MobileNetV2 with one bundled input.

        The network is torchvision's ``mobilenet_v2`` with the
        ``IMAGENET1K_V1`` weights, switched to eval mode, traced on a zero
        tensor of shape ``(1, 3, 224, 224)``, run through
        ``optimize_for_mobile`` and then augmented with that same zero
        tensor as its single bundled input.
        """
        net = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.IMAGENET1K_V1
        ).eval()
        sample = torch.zeros(1, 3, 224, 224)

        mobile_module = optimize_for_mobile(torch.jit.trace(net, sample))
        # One bundled input: a one-element argument tuple holding the sample.
        augment_model_with_bundled_inputs(mobile_module, inputs=[(sample,)])

        # Smoke run whose output is discarded; presumably a sanity check that
        # the optimized module still executes.
        mobile_module(sample)
        return mobile_module
|
|
|
|
|
|
class MobileNetV2VulkanModule:
    """Factory for a TorchScript MobileNetV2 optimized for the Vulkan backend."""

    def getModule(self):
        """Return a traced MobileNetV2 optimized with ``backend="vulkan"``.

        Same pipeline as the default-backend variant: torchvision's
        ``mobilenet_v2`` with ``IMAGENET1K_V1`` weights in eval mode, traced
        on a ``(1, 3, 224, 224)`` zero tensor. The difference is that
        ``optimize_for_mobile`` is asked to target Vulkan. The zero tensor is
        bundled as the module's only input.
        """
        pretrained = models.MobileNet_V2_Weights.IMAGENET1K_V1
        network = models.mobilenet_v2(weights=pretrained)
        network.eval()

        dummy_input = torch.zeros(1, 3, 224, 224)
        traced = torch.jit.trace(network, dummy_input)
        vulkan_module = optimize_for_mobile(traced, backend="vulkan")

        bundled = [(dummy_input,)]
        augment_model_with_bundled_inputs(vulkan_module, inputs=bundled)

        # Result intentionally unused; presumably a check that the optimized
        # module can be invoked.
        vulkan_module(dummy_input)
        return vulkan_module
|
|
|
|
|
|
class Resnet18Module:
    """Factory for a mobile-optimized TorchScript ResNet-18 (default backend)."""

    def getModule(self):
        """Return a traced, mobile-optimized ResNet-18 with one bundled input.

        Loads torchvision's ``resnet18`` with the ``IMAGENET1K_V1`` weights,
        puts it in eval mode, traces it on a ``(1, 3, 224, 224)`` zero tensor,
        applies ``optimize_for_mobile`` and attaches the zero tensor as the
        single bundled input.
        """
        backbone = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1
        ).eval()
        zeros_batch = torch.zeros(1, 3, 224, 224)

        result_module = optimize_for_mobile(
            torch.jit.trace(backbone, zeros_batch)
        )
        augment_model_with_bundled_inputs(
            result_module,
            inputs=[(zeros_batch,)],
        )

        # Output discarded; presumably exercises the optimized module once.
        result_module(zeros_batch)
        return result_module
|