diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py
index b26e5568edb..42ae19aa4a6 100644
--- a/test/test_xnnpack_integration.py
+++ b/test/test_xnnpack_integration.py
@@ -5,12 +5,14 @@
 import unittest
 import torch
 import torch.backends.xnnpack
 from torch.nn import functional as F
+from torch.utils.mobile_optimizer import optimize_for_mobile
 from torch.testing import FileCheck
 import torch.testing._internal.hypothesis_utils as hu
 from torch.testing._internal.common_utils import TestCase, run_tests
 from hypothesis import given, assume
 from hypothesis import strategies as st
 import io
+import itertools
 
 @unittest.skipUnless(torch.backends.xnnpack.enabled,
@@ -704,6 +706,174 @@ class TestXNNPACKRewritePass(TestCase):
         TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
         TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
 
+@unittest.skipUnless(torch.backends.xnnpack.enabled,
+                     " XNNPACK must be enabled for these tests."
+                     " Please build with USE_XNNPACK=1.")
+class TestXNNPACKConv1dTransformPass(TestCase):
+    @staticmethod
+    def validate_transform_conv1d_to_conv2d(
+            self,
+            pattern_count_transformed_map,
+            pattern_count_optimized_map,
+            data_shape):
+        module_instance = self
+        scripted_model = torch.jit.script(module_instance)
+        scripted_model.eval()
+        input_data = torch.normal(1, 20, size=data_shape)
+        ref_result = scripted_model(input_data)
+        torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
+        optimized_scripted_model = optimize_for_mobile(scripted_model)
+
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_model, buffer)
+        buffer.seek(0)
+        deserialized_scripted_model = torch.jit.load(buffer)
+
+        for pattern, v in pattern_count_transformed_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+        transformed_result = deserialized_scripted_model(input_data)
+        torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
+
+        optimized_buffer = io.BytesIO()
+        torch.jit.save(optimized_scripted_model, optimized_buffer)
+        optimized_buffer.seek(0)
+        deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
+
+        for pattern, v in pattern_count_optimized_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
+        xnnpack_result = deserialized_optimized_scripted_model(input_data)
+        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+
+    def test_conv1d_basic(self):
+        batch_size_list = range(1, 3)
+        input_channels_per_group_list = range(10, 12)
+        width_list = range(10, 12)
+        output_channels_per_group_list = range(10, 12)
+        groups_list = range(1, 3)
+        kernel_list = range(1, 4)
+        stride_list = range(1, 3)
+        padding_list = range(0, 3)
+        dilation_list = range(1, 3)
+
+        for hparams in itertools.product(batch_size_list,
+                                         input_channels_per_group_list,
+                                         width_list,
+                                         output_channels_per_group_list,
+                                         groups_list,
+                                         kernel_list,
+                                         stride_list,
+                                         padding_list,
+                                         dilation_list):
+            batch_size, input_channels_per_group, width, output_channels_per_group, \
+                groups, kernel, stride, padding, dilation = hparams
+
+            input_channels = input_channels_per_group * groups
+            output_channels = output_channels_per_group * groups
+            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
+            conv_bias_shape = (output_channels)
+
+            class Conv1D(torch.nn.Module):
+                def __init__(self):
+                    super(Conv1D, self).__init__()
+                    self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                    self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                    self.stride = stride
+                    self.padding = padding
+                    self.dilation = dilation
+                    self.groups = groups
+
+                def forward(self, x):
+                    return F.conv1d(x, self.weight, self.bias,
+                                    self.stride, self.padding, self.dilation, self.groups)
+
+            data_shape = (batch_size, input_channels, width)
+            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
+                                             "Tensor = aten::conv2d": 1}
+            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
+                                           "Tensor = aten::conv2d": -1,
+                                           "prepacked::conv2d_clamp_prepack" : -1,
+                                           "prepacked::conv2d_clamp_run": 1}
+
+            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
+                                                                               pattern_count_transformed_map,
+                                                                               pattern_count_optimized_map,
+                                                                               data_shape)
+
+    def test_conv1d_with_relu_fc(self):
+        batch_size_list = range(1, 3)
+        input_channels_per_group_list = range(10, 12)
+        width_list = range(10, 12)
+        output_channels_per_group_list = range(10, 12)
+        groups_list = range(1, 3)
+        kernel_list = range(1, 4)
+        stride_list = range(1, 3)
+        padding_list = range(0, 3)
+        dilation_list = range(1, 3)
+        output_features_list = range(1, 3)
+
+        for hparams in itertools.product(batch_size_list,
+                                         input_channels_per_group_list,
+                                         width_list,
+                                         output_channels_per_group_list,
+                                         groups_list,
+                                         kernel_list,
+                                         stride_list,
+                                         padding_list,
+                                         dilation_list,
+                                         output_features_list):
+            batch_size, input_channels_per_group, width, output_channels_per_group, \
+                groups, kernel, stride, padding, dilation, output_features = hparams
+
+            input_channels = input_channels_per_group * groups
+            output_channels = output_channels_per_group * groups
+            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
+            conv_bias_shape = (output_channels)
+            conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
+            fc_weight_shape = (output_features, output_channels * conv_output_width)
+            fc_bias_shape = (output_features)
+
+            class Net(torch.nn.Module):
+                def __init__(self):
+                    super(Net, self).__init__()
+                    self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                    self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                    self.stride = stride
+                    self.padding = padding
+                    self.dilation = dilation
+                    self.groups = groups
+
+                    self.fc_weight = torch.nn.Parameter(torch.Tensor(torch.rand(fc_weight_shape)), requires_grad=False)
+                    self.fc_bias = torch.nn.Parameter(torch.Tensor(torch.rand(fc_bias_shape)), requires_grad=False)
+
+                def forward(self, x):
+                    x = F.conv1d(x, self.conv_weight, self.conv_bias,
+                                 self.stride, self.padding, self.dilation, self.groups)
+                    x = F.relu(x)
+                    x = x.view(x.size(0), -1)
+                    x = F.linear(x, self.fc_weight, self.fc_bias)
+                    return x
+
+            data_shape = (batch_size, input_channels, width)
+            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
+                                             "Tensor = aten::conv2d": 1}
+            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
+                                           "Tensor = aten::conv2d": -1,
+                                           "prepacked::conv2d_clamp_prepack" : -1,
+                                           "prepacked::conv2d_clamp_run": 1}
+            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
+                                                                               pattern_count_transformed_map,
+                                                                               pattern_count_optimized_map,
+                                                                               data_shape)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 31769d48894..83ba6ea5166 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -20,6 +20,55 @@
 namespace torch {
 namespace jit {
 
+namespace {
+
+void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
+  std::string conv_1d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
+  std::string conv_2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %zero : int = prim::Constant[value=0]()
+        %one : int = prim::Constant[value=1]()
+        %stride_w : int = prim::ListUnpack(%stride)
+        %stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
+        %padding_w : int = prim::ListUnpack(%padding)
+        %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
+        %dilation_w : int = prim::ListUnpack(%dilation)
+        %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
+        %two : int = prim::Constant[value=2]()
+        %input_2d : Tensor = aten::unsqueeze(%input, %two)
+        %weight_2d : Tensor = aten::unsqueeze(%weight, %two)
+        %output_2d = aten::conv2d(
+            %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
+        %output : Tensor = aten::squeeze(%output_2d, %two)
+        return (%output) )";
+
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
+  rewriter.runOnGraph(graph);
+}
+
+} // namespace
+
+void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
+  // Replace _convolution with conv1d and conv2d
+  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
+  replaceConv1dWithConv2d(graph);
+}
+
+void transformConv1dToConv2d(script::Module& module) {
+  for (auto& method : module.get_methods()) {
+    auto graph = method.graph();
+    transformConv1dToConv2d(graph);
+  }
+  for (script::Module m : module.children()) {
+    transformConv1dToConv2d(m);
+  }
+}
+
 #ifdef USE_XNNPACK
 
 namespace {
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h
index c5df200cf6a..2b1b675bc06 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.h
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -14,6 +14,8 @@ enum class MobileOptimizerType : int8_t {
   HOIST_CONV_PACKED_PARAMS,
 };
 
+TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
+TORCH_API void transformConv1dToConv2d(script::Module& module);
 TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
 TORCH_API void insertPrePackedOps(script::Module& module);
 TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index c8674a515c7..4efb1d3ce50 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -562,6 +562,16 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_remove_dropout",
          [](script::Module& module) { return removeDropout(module); })
+      .def(
+          "_jit_pass_transform_conv1d_to_conv2d",
+          [](std::shared_ptr<Graph>& graph) {
+            return transformConv1dToConv2d(graph);
+          })
+      .def(
+          "_jit_pass_transform_conv1d_to_conv2d",
+          [](script::Module& module) {
+            return transformConv1dToConv2d(module);
+          })
       .def(
           "_jit_pass_insert_prepacked_ops",
          [](std::shared_ptr<Graph>& graph) {