pytorch/test/test_mobile_optimizer.py
Yanan Cao bdcf320bed Support custom exception message (#41907)
Summary:
`raise` and `assert` used to have a hard-coded error message, "Exception"; the user-provided error message was ignored. This PR adds support for representing the user's error message in TorchScript.

This breaks backward compatibility because we now need to script the user's error message, which may contain unscriptable expressions. Such programs can fail to script, but previously saved models continue to work.

Increased an expected op count in test_mobile_optimizer.py because aten::format is now needed to build the actual exception message.

This builds upon a WIP PR by driazati: https://github.com/pytorch/pytorch/pull/34112

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41907

Reviewed By: ngimel

Differential Revision: D22778301

Pulled By: gmagogsfm

fbshipit-source-id: 2b94f0db4ae9fe70c4cd03f4048e519ea96323ad
2020-08-01 13:03:45 -07:00

212 lines
9.8 KiB
Python

import unittest
import torch
import torch.backends.xnnpack
import torch.utils.bundled_inputs
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
from torch.utils.mobile_optimizer import *
from torch.nn import functional as F
from torch._C import MobileOptimizerType
# TorchScript's graph-text pattern checker, bound at module level so the
# tests can write FileCheck().check_...().run(graph).
from torch._C import FileCheck
class TestOptimizer(unittest.TestCase):
    """Tests for ``torch.utils.mobile_optimizer``.

    Covers ``optimize_for_mobile`` and ``generate_mobile_module_lints``.
    ``optimize_for_mobile`` is exercised for prepacking, add/relu fusion,
    conv/batch-norm folding, optimization blacklists and preserved methods.
    """

    @unittest.skipUnless(torch.backends.xnnpack.enabled,
                         " XNNPACK must be enabled for these tests."
                         " Please build with USE_XNNPACK=1.")
    def test_optimize_for_mobile(self):
        """Check graph rewrites and numerics produced by ``optimize_for_mobile``."""
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)
        # MyTestModule permutes the conv output so channels are the last dim
        # before F.linear. The linear layer's input features are therefore the
        # conv's output channels. weight_output_dim must equal output_channels
        # so that `o + x` in forward() is shape-compatible.
        weight_output_dim = 24
        linear_input_shape = output_channels
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)

        # Default optimization checks:
        # - conv and linear run through prepacked::*_clamp_run ops.
        # - The prepack ops themselves are folded away.
        # - add followed by relu is fused into aten::add_relu.
        optimized_scripted_model = optimize_for_mobile(scripted_model)
        optimized_result = optimized_scripted_model(input_data)
        FileCheck().check_not("Tensor = aten::conv2d") \
            .check_not("Tensor = prim::CallFunction") \
            .check_not("prepacked::conv2d_clamp_prepack") \
            .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
            .check_not("prepacked::linear_clamp_prepack") \
            .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
            .check_not("aten::add(") \
            .check_not("aten::relu(") \
            .check_count("aten::add_relu(", 1, exactly=True) \
            .run(optimized_scripted_model.graph)
        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        # Blacklisting INSERT_FOLD_PREPACK_OPS keeps aten::conv2d and
        # introduces no prepacked ops.
        optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blacklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)
        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
            .check_not("prepacked::linear_clamp_run") \
            .check_not("prepacked::conv2d_clamp_run") \
            .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)

        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()
        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
            .run(str(get_forward(bn_scripted_module._c).graph))

        # With prepack insertion still blacklisted, conv+BN folding should
        # leave one forward CallMethod and a single exported op.
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
            .run(str(get_forward_graph(bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        # Blacklisting CONV_BN_FUSION keeps the aten::batch_norm call.
        optimization_blacklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
            .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        # Exported methods are dropped by optimize_for_mobile unless they are
        # listed in preserved_methods.
        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertIsNone(no_preserveThis)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertIsNotNone(preserveThis)

    def test_generate_mobile_module_lints(self):
        """Check the lints reported by ``generate_mobile_module_lints``."""
        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                return self.bn(inputs)

        class MyBundledInputModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            """Count lints in ``module_lint_list`` whose 'name' equals ``lint_type.name``."""
            return sum(1 for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name)

        # The expected REQUIRES_GRAD count of 2 presumably corresponds to the
        # two parameters (weight and bias) of fc / bn.
        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        # A parameterless module that carries bundled inputs should be lint-free.
        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)
# Allow running this test file directly with `python test_mobile_optimizer.py`.
if __name__ == "__main__":
    unittest.main()