pytorch/test/test_mobile_optimizer.py
Aaron Gokaslan 88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it to keep it that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00

621 lines
26 KiB
Python

# Owner(s): ["oncall: mobile"]
import unittest
import torch
import torch.nn as nn
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import (LintCode,
generate_mobile_module_lints,
optimize_for_mobile,
MobileOptimizerType)
from torch.nn import functional as F
from torch.testing._internal.common_quantized import override_quantized_engine
try:
    import torchvision
except ImportError:
    # torchvision is optional; record its absence so dependent tests can be skipped.
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Short alias for TorchScript's FileCheck, used below to pattern-match graph IR.
FileCheck = torch._C.FileCheck
class TestOptimizer(TestCase):
    """Tests for ``torch.utils.mobile_optimizer`` and the TorchScript
    module-cloning helper it relies on."""

    @skipIfNoXNNPACK
    def test_optimize_for_mobile(self):
        """Check graph rewrites and numerics of ``optimize_for_mobile``.

        Covers prepacked conv/linear ops, add+relu fusion, the optimization
        blocklist, conv-bn folding, the ``mobile_optimized`` tag, and
        ``preserved_methods`` for modules with and without ``forward``.
        """
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        # Trailing comma makes this a real 1-tuple rather than a parenthesized int.
        conv_bias_shape = (output_channels,)

        # Run a reference conv once to derive the input size of the linear layer,
        # which is applied to the channel dimension after permuting to NHWC.
        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand(conv_weight_shape)
        conv_bias = torch.rand(conv_bias_shape)
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            # Same computation as forward, exported so preserved_methods can be tested.
            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        def check_prepacked_and_fused(graph):
            # conv/linear must run through prepacked::*_clamp_run with no
            # *_clamp_prepack op left in the graph, and add+relu must have been
            # fused into exactly one aten::_add_relu.
            (
                FileCheck()
                .check_not("Tensor = aten::conv2d")
                .check_not("Tensor = prim::CallFunction")
                .check_not("prepacked::conv2d_clamp_prepack")
                .check_count("prepacked::conv2d_clamp_run", 1, exactly=True)
                .check_not("prepacked::linear_clamp_prepack")
                .check_count("prepacked::linear_clamp_run", 1, exactly=True)
                .check_not("aten::add(")
                .check_not("aten::relu(")
                .check_count("aten::_add_relu(", 1, exactly=True)
                .run(graph)
            )

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)
        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)

        check_prepacked_and_fused(optimized_scripted_model.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
        check_prepacked_and_fused(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)

        # Blocklisting INSERT_FOLD_PREPACK_OPS leaves the conv as aten::conv2d
        # and introduces no prepacked run ops.
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)
        (
            FileCheck()
            .check_count("Tensor = aten::conv2d", 1, exactly=True)
            .check_not("prepacked::linear_clamp_run")
            .check_not("prepacked::conv2d_clamp_run")
            .run(optimized_scripted_model_no_prepack.graph)
        )
        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)

        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()
        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
        (
            FileCheck()
            .check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True)
            .run(str(get_forward(bn_scripted_module._c).graph))
        )

        # Same no-prepack blocklist as above; the optimized conv+bn module is
        # expected to export a single op name.
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        # Blocklisting CONV_BN_FUSION keeps the batch_norm call in the graph.
        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        (
            FileCheck()
            .check_count("aten::batch_norm", 1, exactly=True)
            .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        )
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        # optimize_for_mobile must tag its output with a truthy `mobile_optimized`.
        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        # Exported methods are dropped unless listed in preserved_methods.
        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertIsNone(no_preserveThis)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertIsNotNone(preserveThis)

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)

        # A module with no forward: only the preserved method is optimized.
        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        (
            FileCheck()
            .check_not("dropout.__")
            .check_count("aten::_add_relu(", 1, exactly=True)
            .run(optimized_scripted_model.foo.graph)
        )
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()
        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        (
            FileCheck()
            .check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True)
            .run(bn_no_forward_scripted_module.foo.graph)
        )

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        """Regression test for optimizing already-quantized conv modules.

        There were ASAN failures when fold_conv_bn was run on already
        quantized conv modules. Verify that this does not happen again.
        """
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            # Report a skip rather than a silent pass when the engine is missing.
            self.skipTest("qnnpack is not a supported quantized engine")

        class Child(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            # Feed one batch through the prepared model before converting it.
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            optimize_for_mobile(model)

    def test_generate_mobile_module_lints(self):
        """Check the lint entries reported by ``generate_mobile_module_lints``."""
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            # Count lint dicts whose 'name' entry equals the LintCode member's name.
            return sum(1 for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name)

        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        # A parameter-free module with bundled inputs produces no lints.
        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)

    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        """Check that bundled-input methods survive ``optimize_for_mobile``
        only when the full set is present or they are explicitly preserved."""
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only one
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):
        """Check that quantized conv packed params become graph constants.

        After optimization the conv submodules are no longer attributes of the
        module, and the graph holds two Conv2dPackedParamsBase constants.
        """
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            # Report a skip rather than a silent pass when the engine is missing.
            self.skipTest("qnnpack is not a supported quantized engine")

        class Standalone(nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        packed_params_constant = "__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant"

        def check_conv_params_hoisted(graph):
            # No GetAttr of conv1 may remain; both convs' packed params are constants.
            (
                FileCheck()
                .check_not("Conv2d = prim::GetAttr[name=\"conv1\"]")
                .check_count(packed_params_constant, 2, exactly=True)
                .run(graph)
            )

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                # Quantize (eager mode), script, and optimize; return both versions.
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case
            m, m_optim = _quant_script_and_optimize(Standalone())
            check_conv_params_hoisted(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            # generic case
            m, m_optim = _quant_script_and_optimize(Parent())
            check_conv_params_hoisted(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        """Run an optimized torchvision MobileNetV3-small several times."""
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # Run forward 3 times: repeated calls used to segfault,
        # see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        for _ in range(3):
            self.assertEqual(m(x).numel(), 1000)

    def test_clone_module_with_class(self):
        """Check ``torch._C._hack_do_not_use_clone_module_with_class``.

        Ignored methods/attributes must be absent from the clone, and cloning
        must raise RuntimeError when a kept method references something that
        was ignored.
        """
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.pqr = torch.Tensor([10., 20., 30.])

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.abc = 23
                self.pqr = torch.Tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr

        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertTrue(hasattr(m, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(m, "dummy_method_cloned"))
        self.assertTrue(hasattr(m, "dummy_method_not_cloned2"))
        self.assertTrue(hasattr(m, "pqr"))

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(cloned, "dummy_method_cloned"))
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned2"))
        self.assertTrue(hasattr(cloned, "pqr"))

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )

        # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
            ["pqr"],
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(cloned, "dummy_method_cloned"))
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned2"))
        self.assertFalse(hasattr(cloned, "dummy_method_ref_attr_pqr"))
        self.assertFalse(hasattr(cloned, "pqr"))

        # Case-3: The statement below will throw since dummy_method_not_cloned2 is preserved,
        # and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],
                ["pqr"],
            )
if __name__ == '__main__':
    # Executed directly as a script: hand off to PyTorch's shared test runner.
    run_tests()