Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35631

Bundling sample inputs with our models through a standardized interface will make it possible to write benchmarking and code-coverage tools that call all models in a uniform way. The intent is to make this a standard for mobile models within Facebook. It lives in torch/utils so the tests can run on GitHub, and because it might be useful to others as well.

`augment_model_with_bundled_inputs` is the primary entry point. See its docstring for usage information and the test for some example uses.

One design question was how much power should be available for automatic deflating and inflating of inputs. The current scheme gives some automatic handling and a reasonable escape hatch ("_bundled_input_inflate_format") for top-level tensor arguments, but no automatic support for (e.g.) tensors in tuples or long strings. For more complex cases, the ultimate escape hatch is to define `_generate_bundled_inputs` directly on the model.

Another design question was whether to add the inputs to the model or to wrap the model in a wrapper module that had these methods and delegated calls to `forward`. Because models can have other exposed methods and attributes, the wrapper approach seemed too onerous.

Test Plan: Unit test.

Differential Revision: D20925013

Pulled By: dreiss

fbshipit-source-id: 4dbbb4cce41e5752133b4ecdb05e1c92bac6b2d5
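As a quick orientation before the test file below, here is a minimal sketch of the workflow the summary describes. It is not taken from the PR itself: the model and variable names are made up for illustration, and the accessor methods (`get_num_bundled_inputs`, `get_all_bundled_inputs`, `run_on_bundled_input`) are the ones exercised by the tests, which call them on a module that has been saved and re-loaded.

    import torch
    import torch.utils.bundled_inputs

    class ExampleModel(torch.nn.Module):  # hypothetical model, for illustration only
        def forward(self, arg):
            return arg

    sm = torch.jit.script(ExampleModel())
    # Attach (deflated) sample inputs to the scripted module.
    torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
        sm, [(torch.zeros(1 << 16),)])

    # A benchmarking or coverage tool can later re-inflate and run them.
    n = sm.get_num_bundled_inputs()        # 1 bundled input in this sketch
    inputs = sm.get_all_bundled_inputs()   # list of argument tuples
    out = sm.run_on_bundled_input(0)       # roughly forward(*inputs[0])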
137 lines
4.9 KiB
Python
#!/usr/bin/env python3
import io

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests


def model_size(sm):
    # Serialized size (in bytes) of a scripted module.
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


def save_and_load(sm):
    # Round-trip a scripted module through an in-memory buffer.
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)

class TestBundledInputs(TestCase):

    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr)
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), 5)
        self.assertEqual(len(inflated), 5)
        self.assertTrue(loaded.run_on_bundled_input(0) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != 4:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[4][0].shape, (1 << 16,))
        self.assertAlmostEqual(inflated[4][0].mean().item(), 0, delta=0.025)
        self.assertAlmostEqual(inflated[4][0].std().item(), 1, delta=0.02)

    def test_large_tensor_with_inflation(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.InflatableArg(
            value=sample_tensor, fmt="{}")
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)

        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)])

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))

        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertTrue(loaded.run_on_bundled_input(0) == "first 1")

if __name__ == '__main__':
    run_tests()