pytorch/torch/utils/mobile_optimizer.py
Jacob Szwejbka 0118dec2e3 [Pytorch] Expanded Bundled Inputs To Any Public Function (#51153)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51153

Enabled bundled inputs for any public function the user wants on a TorchScript module. An important caveat here is that you can't add bundled inputs to functions that were on the nn.Module but weren't caught in the scripting/tracing process that brought the model to TorchScript.

The old API is exactly the same: it still only works on forward, return types are unchanged, etc.

-----------New API-------------

Attachment of inputs:

***augment_model_with_bundled_inputs*** : works the same as before, but adds the option to specify an info argument describing the bundled inputs.

***augment_many_model_functions_with_bundled_inputs*** : similar to the function above, but takes a Dict[Callable, List[<inputs>]] (mapping function references to the bundled inputs for that function) so that bundled inputs can be attached to many functions at once, as sketched below.
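
A minimal sketch of the new attachment API (MyModule and its methods are made up for illustration; each bundled input is a tuple of arguments for one call):

    import torch
    import torch.utils.bundled_inputs

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 1

        @torch.jit.export
        def foo(self, x):
            return x * 2

    model = torch.jit.script(MyModule())
    # Map each function reference to its list of bundled input tuples.
    torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
        model,
        inputs={
            model.forward: [(torch.zeros(1, 2),)],
            model.foo: [(torch.ones(1, 2),)],
        },
    )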

Consumption of inputs:

***get_all_bundled_inputs_for_<function_name>()*** : works exactly like get_all_bundled_inputs does, but can be used for functions other than forward if you know ahead of time what they are called, and if they have bundled inputs.
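
For example, continuing the sketch above (foo is the hypothetical exported function):

    inputs_for_foo = model.get_all_bundled_inputs_for_foo()
    model.foo(*inputs_for_foo[0])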

***get_bundled_inputs_functions_and_info()*** : This is easily the hackiest function. It returns a Dict[str, Dict[str, List[str]]] mapping each function name to a dict of metadata, including the name of its get_all_bundled_inputs_for_<function_name> accessor under the key 'get_inputs_function_name'. A user can then execute the functions specified in the values with something like
    all_info = model.get_bundled_inputs_functions_and_info()
    for func_name in all_info.keys():
        input_func_name = all_info[func_name]['get_inputs_function_name'][0]
        input_func = getattr(model, input_func_name)
        for bundled_input in input_func():
            getattr(model, func_name)(*bundled_input)
The reason it's done this way is that TorchScript doesn't support the 'Any' type yet, meaning I can't return the bundled inputs directly because they could have different types for each function. TorchScript also doesn't support Callable, so I can't return a function reference directly either.
ghstack-source-id: 120768561

Test Plan:
Got a model into TorchScript using the available methods that I'm aware of (tracing, scripting, old scripting method). Not really sure how tracing brings in functions that aren't in the forward call path, though. Attached bundled inputs and info to them successfully. Changes to TorchTest.py on all but the last version of this diff (where it will be/is removed for land) illustrate what I did to test.

Created and ran unit test

Reviewed By: dreiss

Differential Revision: D25931961

fbshipit-source-id: 36e87c9a585554a83a932e4dcf07d1f91a32f046
2021-02-02 10:33:59 -08:00


"""
This module contains utility method for mobile model optimization and lint.
"""
import torch
from enum import Enum
from torch._C import MobileOptimizerType
from typing import Set, List, AnyStr
class LintCode(Enum):
BUNDLED_INPUT = 1
REQUIRES_GRAD = 2
DROPOUT = 3
BATCHNORM = 4
def optimize_for_mobile(
script_module,
optimization_blocklist: Set[MobileOptimizerType] = None,
preserved_methods: List[AnyStr] = None,
backend: str = 'CPU'):
"""
Args:
script_module: An instance of torch script module with type of ScriptModule.
optimization_blocklist: A set with type of MobileOptimizerType. When set is not passed,
optimization method will run all the optimizer pass; otherwise, optimizer
method will run the optimization pass that is not included inside optimization_blocklist.
perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked
backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal').
Returns:
A new optimized torch script module
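
    Example (a minimal usage sketch; ``MyModule`` is a hypothetical module
    scripted by the caller)::

        import torch
        from torch.utils.mobile_optimizer import optimize_for_mobile

        scripted = torch.jit.script(MyModule())
        # Keep 'foo' from being removed by the freeze_module pass.
        optimized = optimize_for_mobile(scripted, preserved_methods=['foo'])
        torch.jit.save(optimized, 'optimized_model.pt')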
"""
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    if optimization_blocklist is None:
        optimization_blocklist = set()

    if preserved_methods is None:
        preserved_methods = []

    # Convert potential byte arrays into strings (if there are any) to pass type checking.
    # A new name is used here because assigning back to preserved_methods would trigger
    # mypy errors (i.e. List[AnyStr] = List[str]).
    preserved_methods_str: List[str] = [str(method) for method in preserved_methods]

    bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input']
    if all(hasattr(script_module, method) for method in bundled_inputs_methods):
        preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_methods))

    non_exist_methods = []
    for method in preserved_methods_str:
        if not hasattr(script_module, method):
            non_exist_methods.append(method)
    if non_exist_methods:
        raise AttributeError(
            'The following methods to preserve do not exist in script_module: {}'
            .format(', '.join(non_exist_methods)))

    backend = backend.lower()
    if backend == 'cpu':
        optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
            script_module._c,
            optimization_blocklist,
            preserved_methods_str)
    elif backend == 'vulkan':
        optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
    elif backend == 'metal':
        optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
    else:
        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")

    return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)


def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Generate lint warnings for a torch script module.

    Args:
        script_module: An instance of a torch script module, of type ScriptModule.

    Returns:
        lint_map: A list of dictionaries, each describing one module lint.
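
    Example (a minimal usage sketch; assumes ``scripted`` is an existing
    ScriptModule)::

        from torch.utils.mobile_optimizer import generate_mobile_module_lints

        for lint in generate_mobile_module_lints(scripted):
            print(lint["name"], "-", lint["message"])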
"""
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    lint_list = []

    if not hasattr(script_module, "_generate_bundled_inputs_for_forward"):
        lint_list.append({"name": LintCode.BUNDLED_INPUT.name,
                          "message": "No bundled input, please add bundled inputs before saving the module using "
                          "torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})

    for name, param in script_module.named_parameters():
        if param.requires_grad:
            lint_list.append({"name": LintCode.REQUIRES_GRAD.name,
                              "message": "Param {} requires grad, please use torch.no_grad() to reduce memory usage "
                              "and improve computation speed during the inference phase.".format(name)})

    op_names = torch.jit.export_opnames(script_module)
    for op_name in op_names:
        if "dropout" in op_name:
            lint_list.append({"name": LintCode.DROPOUT.name,
                              "message": "Operator {} exists, remember to call eval() before saving the module and "
                              "call torch.utils.mobile_optimizer.optimize_for_mobile to drop the dropout "
                              "operator.".format(op_name)})
        if "batch_norm" in op_name:
            lint_list.append({"name": LintCode.BATCHNORM.name,
                              "message": "Operator {} exists, remember to call eval() before saving the module and "
                              "call torch.utils.mobile_optimizer.optimize_for_mobile to drop the batch_norm "
                              "operator.".format(op_name)})

    return lint_list