#!/usr/bin/env python3
from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union
import textwrap

import torch
from torch._C import TupleType, OptionalType, ListType

T = TypeVar("T")

MAX_RAW_TENSOR_SIZE = 16


class InflatableArg(NamedTuple):
    """An argument to be bundled in "deflated" (compressed) form.

    `value` is the compact object stored in the model.  `fmt` is a format
    string, applied to an expression referencing the stored value, that
    produces the code used to "inflate" it back into a full input at load time.
    """
    value: Any
    fmt: str

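# A minimal sketch of a custom InflatableArg; `my_ops.decode_jpeg` and
# `jpeg_bytes` are hypothetical, shown only to illustrate the mechanism:
#
#   InflatableArg(value=jpeg_bytes, fmt="my_ops.decode_jpeg({})")
#
# "{}" in `fmt` is replaced with an expression referencing the stored value,
# so the full input is rebuilt by running the op when inputs are inflated.
# Using fmt="{}" bundles `value` as-is, ignoring the tensor size limit.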


def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

    `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
        Returns a list of tuples suitable for passing to the model like
        `for inp in model.get_all_bundled_inputs(): model(*inp)`

    `get_num_bundled_inputs() -> int`
        Equivalent to `len(model.get_all_bundled_inputs())`,
        but slightly easier to call from C++.

    `run_on_bundled_input(idx: int) -> Any`
        Run the model on bundled input number `idx`.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`;
        `get_all_bundled_inputs` will simply call this method
        and cache the value.

      - The `inputs` argument to this function can be a list of tuples,
        of the same form that will be returned by `get_all_bundled_inputs`.
        This function will attempt to optimize arguments so that (e.g.)
        arguments like `torch.zeros(1000)` will be represented compactly.
        Only top-level arguments will be optimized;
        tensors in lists or tuples will not be.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    # Register the (possibly compressed) deflated inputs, plus a cache
    # attribute that will hold the inflated inputs once generated.
    # arguments[0] is self, so skip it when collecting forward's arg types.
    forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
    deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types))
    inflated_inputs_type: OptionalType[ListType] = OptionalType(deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)

    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined")
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined")
    else:
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f" {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)

        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)

        model._bundled_inputs_deflated = deflated_inputs
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
                    {}
                ]
            """).format(expr)
        model.define(definition)
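
        # For illustration (a sketch): with inputs=[(torch.zeros(1000),)],
        # the tensor deflates to a one-element tensor expanded to size 1000,
        # and the generated method looks roughly like:
        #
        #   def _generate_bundled_inputs(self):
        #       deflated = self._bundled_inputs_deflated
        #       return [
        #           (
        #            deflated[0][0].contiguous(memory_format=torch.contiguous_format),
        #           ),
        #       ]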

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))


def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]:
    # Allow custom inflation expressions for any object.
    # For example, calling custom image-decoding ops.
    # Or just use "{}" as the format string to ignore size limits.
    if isinstance(arg, InflatableArg):
        return arg.value, arg.fmt.format(ref)

    if isinstance(arg, torch.Tensor):
        # Small-storage tensors can just be saved directly.
        if arg.storage().size() <= MAX_RAW_TENSOR_SIZE:
            return arg, ref
        # Small contiguous tensors can be cloned to have small storage.
        # TODO: Should we do this even for non-contiguous tensors?
        if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE:
            return arg.clone(), ref
        # Example inputs commonly come from torch.zeros, torch.ones, or torch.full.
        # These can be represented compactly.
        for fmt in [torch.contiguous_format, torch.channels_last]:
            if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item():
                return (torch.tensor([arg.flatten()[0]]).expand(*arg.size()),
                        f"{ref}.contiguous(memory_format={fmt})")
        # Prevent big tensors from being bundled by default.
        # TODO: Provide more useful diagnostics.
        raise Exception(
            f"Bundled input argument at position '{ref}' is "
            f"a tensor with storage size {arg.storage().size()}. "
            f"You probably don't want to bundle this as an input. "
        )
    else:
        return arg, ref
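
# Deflation sketch for intuition (not executed here): a constant-filled
# tensor such as torch.ones(4, 512) deflates to
#   torch.tensor([1.0]).expand(4, 512)
# whose storage holds a single element, and inflates at load time via the
# generated expression "<ref>.contiguous(memory_format=torch.contiguous_format)".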


def bundle_randn(*size, dtype=None):
    """Generate a tensor that will be inflated with torch.randn."""
    stub = torch.zeros(1, dtype=dtype).expand(*size)
    return InflatableArg(value=stub, fmt="torch.randn_like({})")
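
# For example (a sketch, reusing the hypothetical `model` from the earlier
# usage comment): bundle_randn(4, 10) stores only a single-element zero stub
# expanded to the full shape; at load time the generated code calls
# torch.randn_like on it to produce a fresh random input.
#
#   augment_model_with_bundled_inputs(model, inputs=[(bundle_randn(4, 10),)])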


def bundle_large_tensor(t):
    """Wrap a tensor to allow bundling regardless of size."""
    return InflatableArg(value=t, fmt="{}")
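
# For example (a sketch): opt in to storing a big tensor verbatim, bypassing
# the MAX_RAW_TENSOR_SIZE guard in _inflate_expr:
#
#   augment_model_with_bundled_inputs(
#       model, inputs=[(bundle_large_tensor(torch.rand(1, 3, 224, 224)),)])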