mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Added an extra step to **always** preserve the bundled inputs methods if they are present in the input module. Also added a check to see if all the methods in the `preserved_methods` exist. If not, we will now throw an exception. This can hopefully stop hard-to-debug inputs from getting into downstream functions. ~~Add an optional argument `preserve_bundled_inputs_methods=False` to the `optimize_for_mobile` function. If set to True, the function will now add three additional functions related to bundled inputs to be preserved: `get_all_bundled_inputs`, `get_num_bundled_inputs` and `run_on_bundled_input`.~~ Test Plan: `buck test mode/dev //caffe2/test:mobile -- 'test_preserve_bundled_inputs_methods \(test_mobile_optimizer\.TestOptimizer\)'` or `buck test caffe2/test:mobile` to run some other related tests as well. Reviewed By: dhruvbird Differential Revision: D25433268 fbshipit-source-id: 0bf9b4afe64b79ed1684a3db4c0baea40ed3cdd5
103 lines
4.5 KiB
Python
103 lines
4.5 KiB
Python
"""
|
|
This module contains utility methods for mobile model optimization and linting.
|
|
"""
|
|
|
|
from enum import Enum
from typing import AnyStr, List, Optional, Set

import torch
from torch._C import MobileOptimizerType
|
|
|
|
class LintCode(Enum):
    """Identifiers for the lints reported by ``generate_mobile_module_lints``.

    The member ``name`` is what appears under the ``"name"`` key of each lint.
    """

    # The module has no ``_generate_bundled_inputs`` attribute.
    BUNDLED_INPUT = 1
    # A parameter of the module still has ``requires_grad`` set.
    REQUIRES_GRAD = 2
    # An exported operator name contains "dropout".
    DROPOUT = 3
    # An exported operator name contains "batch_norm".
    BATCHNORM = 4
|
|
|
|
def optimize_for_mobile(
|
|
script_module,
|
|
optimization_blocklist: Set[MobileOptimizerType] = None,
|
|
preserved_methods: List[AnyStr] = None,
|
|
backend: str = 'CPU'):
|
|
"""
|
|
Args:
|
|
script_module: An instance of torch script module with type of ScriptModule.
|
|
optimization_blocklist: A set with type of MobileOptimizerType. When set is not passed,
|
|
optimization method will run all the optimizer pass; otherwise, optimizer
|
|
method will run the optimization pass that is not included inside optimization_blocklist.
|
|
perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked
|
|
backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal').
|
|
Returns:
|
|
A new optimized torch script module
|
|
"""
|
|
if not isinstance(script_module, torch.jit.ScriptModule):
|
|
raise TypeError(
|
|
'Got {}, but ScriptModule is expected.'.format(type(script_module)))
|
|
|
|
if optimization_blocklist is None:
|
|
optimization_blocklist = set()
|
|
|
|
if preserved_methods is None:
|
|
preserved_methods = []
|
|
|
|
bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input']
|
|
if all([hasattr(script_module, method) for method in bundled_inputs_methods]):
|
|
preserved_methods = list(set(preserved_methods + bundled_inputs_methods))
|
|
|
|
non_exist_methods = []
|
|
for method in preserved_methods:
|
|
if not hasattr(script_module, method):
|
|
non_exist_methods.append(method)
|
|
if non_exist_methods:
|
|
raise AttributeError(
|
|
'The following methods to preserve do not exist in script_module: {}'.format(', '.join(non_exist_methods)))
|
|
|
|
backend = backend.lower()
|
|
if backend == 'cpu':
|
|
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
|
|
elif backend == 'vulkan':
|
|
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
|
|
elif backend == 'metal':
|
|
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods)
|
|
else:
|
|
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")
|
|
|
|
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
|
|
|
|
|
|
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Collect lints describing issues in a module that affect mobile deployment.

    Args:
        script_module: An instance of torch script module with type of ScriptModule

    Returns:
        lint_list: A list of dictionaries, one per lint. Each has a ``"name"`` key
        (the name of a ``LintCode`` member) and a ``"message"`` key describing
        the issue and how to address it.

    Raises:
        TypeError: if ``script_module`` is not a ``torch.jit.ScriptModule``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    lint_list = []

    # The absence of ``_generate_bundled_inputs`` is treated as "no bundled inputs".
    if not hasattr(script_module, "_generate_bundled_inputs"):
        lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before "
                          "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})

    # One lint per parameter that still requires grad.
    for name, param in script_module.named_parameters():
        if param.requires_grad:
            lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, "
                              "please set torch.no_grad() to reduce memory usage and improve computation speed during "
                              "inference phase.".format(name)})

    # Inspect the operator names the module exports; each matching operator
    # produces its own lint entry.
    op_names = torch.jit.export_opnames(script_module)
    for op_name in op_names:
        if "dropout" in op_name:
            lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before "
                              "saving the module and call torch.utils.mobile_optimizer.optimize_for_mobile to drop dropout "
                              "operator.".format(op_name)})
        if "batch_norm" in op_name:
            lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before "
                              "saving the module and call torch.utils.mobile_optimizer.optimize_for_mobile to drop batch_norm "
                              "operator.".format(op_name)})

    return lint_list
|