Replace Conv1d with Conv2d (#42867)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42867

Test Plan: Imported from OSS

Reviewed By: kimishpatel

Differential Revision: D23177916

Pulled By: kimishpatel

fbshipit-source-id: 68cc40cf42d03e5b8432dc08f9933a4409c76e25
This commit is contained in:
taivu 2020-08-20 21:34:48 -07:00 committed by Facebook GitHub Bot
parent e8139624f2
commit 665da61d2b
4 changed files with 231 additions and 0 deletions

View File

@ -5,12 +5,14 @@ import unittest
import torch
import torch.backends.xnnpack
from torch.nn import functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.testing import FileCheck
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import TestCase, run_tests
from hypothesis import given, assume
from hypothesis import strategies as st
import io
import itertools
@unittest.skipUnless(torch.backends.xnnpack.enabled,
@ -704,6 +706,174 @@ class TestXNNPACKRewritePass(TestCase):
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
class TestXNNPACKConv1dTransformPass(TestCase):
@staticmethod
def validate_transform_conv1d_to_conv2d(
self,
pattern_count_transformed_map,
pattern_count_optimized_map,
data_shape):
module_instance = self
scripted_model = torch.jit.script(module_instance)
scripted_model.eval()
input_data = torch.normal(1, 20, size=data_shape)
ref_result = scripted_model(input_data)
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
optimized_scripted_model = optimize_for_mobile(scripted_model)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
deserialized_scripted_model = torch.jit.load(buffer)
for pattern, v in pattern_count_transformed_map.items():
if (v == 0):
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
elif (v == -1):
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
else:
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
transformed_result = deserialized_scripted_model(input_data)
torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
optimized_buffer = io.BytesIO()
torch.jit.save(optimized_scripted_model, optimized_buffer)
optimized_buffer.seek(0)
deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
for pattern, v in pattern_count_optimized_map.items():
if (v == 0):
FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
elif (v == -1):
FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
else:
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
xnnpack_result = deserialized_optimized_scripted_model(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
def test_conv1d_basic(self):
batch_size_list = range(1, 3)
input_channels_per_group_list = range(10, 12)
width_list = range(10, 12)
output_channels_per_group_list = range(10, 12)
groups_list = range(1, 3)
kernel_list = range(1, 4)
stride_list = range(1, 3)
padding_list = range(0, 3)
dilation_list = range(1, 3)
for hparams in itertools.product(batch_size_list,
input_channels_per_group_list,
width_list,
output_channels_per_group_list,
groups_list,
kernel_list,
stride_list,
padding_list,
dilation_list):
batch_size, input_channels_per_group, width, output_channels_per_group, \
groups, kernel, stride, padding, dilation = hparams
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
conv_bias_shape = (output_channels)
class Conv1D(torch.nn.Module):
def __init__(self):
super(Conv1D, self).__init__()
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
def forward(self, x):
return F.conv1d(x, self.weight, self.bias,
self.stride, self.padding, self.dilation, self.groups)
data_shape = (batch_size, input_channels, width)
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": 1}
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack" : -1,
"prepacked::conv2d_clamp_run": 1}
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
pattern_count_transformed_map,
pattern_count_optimized_map,
data_shape)
def test_conv1d_with_relu_fc(self):
batch_size_list = range(1, 3)
input_channels_per_group_list = range(10, 12)
width_list = range(10, 12)
output_channels_per_group_list = range(10, 12)
groups_list = range(1, 3)
kernel_list = range(1, 4)
stride_list = range(1, 3)
padding_list = range(0, 3)
dilation_list = range(1, 3)
output_features_list = range(1, 3)
for hparams in itertools.product(batch_size_list,
input_channels_per_group_list,
width_list,
output_channels_per_group_list,
groups_list,
kernel_list,
stride_list,
padding_list,
dilation_list,
output_features_list):
batch_size, input_channels_per_group, width, output_channels_per_group, \
groups, kernel, stride, padding, dilation, output_features = hparams
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
conv_bias_shape = (output_channels)
conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
fc_weight_shape = (output_features, output_channels * conv_output_width)
fc_bias_shape = (output_features)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.fc_weight = torch.nn.Parameter(torch.Tensor(torch.rand(fc_weight_shape)), requires_grad=False)
self.fc_bias = torch.nn.Parameter(torch.Tensor(torch.rand(fc_bias_shape)), requires_grad=False)
def forward(self, x):
x = F.conv1d(x, self.conv_weight, self.conv_bias,
self.stride, self.padding, self.dilation, self.groups)
x = F.relu(x)
x = x.view(x.size(0), -1)
x = F.linear(x, self.fc_weight, self.fc_bias)
return x
data_shape = (batch_size, input_channels, width)
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": 1}
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack" : -1,
"prepacked::conv2d_clamp_run": 1}
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
pattern_count_transformed_map,
pattern_count_optimized_map,
data_shape)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
run_tests()

View File

@ -20,6 +20,55 @@
namespace torch {
namespace jit {
namespace {
void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
std::string conv_1d_pattern = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv_2d_pattern = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
%zero : int = prim::Constant[value=0]()
%one : int = prim::Constant[value=1]()
%stride_w : int = prim::ListUnpack(%stride)
%stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
%padding_w : int = prim::ListUnpack(%padding)
%padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
%dilation_w : int = prim::ListUnpack(%dilation)
%dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
%two : int = prim::Constant[value=2]()
%input_2d : Tensor = aten::unsqueeze(%input, %two)
%weight_2d : Tensor = aten::unsqueeze(%weight, %two)
%output_2d = aten::conv2d(
%input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
%output : Tensor = aten::squeeze(%output_2d, %two)
return (%output) )";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
rewriter.runOnGraph(graph);
}
} // namespace
// Graph-level entry point of the conv1d -> conv2d transformation.
// Runs two rewrites on `graph`, in order; the graph is modified in place.
void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
// Replace _convolution with conv1d and conv2d
// (done first so that the conv1d pattern below has something to match on).
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
// Then rewrite each aten::conv1d call as an aten::conv2d call.
replaceConv1dWithConv2d(graph);
}
// Module-level entry point of the conv1d -> conv2d transformation.
// Applies the graph-level overload to the graph of every method of `module`
// and then recurses into each of the module's children.
void transformConv1dToConv2d(script::Module& module) {
  // Methods owned directly by this module.
  for (auto& method : module.get_methods()) {
    // The graph overload takes a non-const reference, so bind to a local.
    auto method_graph = method.graph();
    transformConv1dToConv2d(method_graph);
  }

  // Submodules, handled recursively.
  for (script::Module child : module.children()) {
    transformConv1dToConv2d(child);
  }
}
#ifdef USE_XNNPACK
namespace {

View File

@ -14,6 +14,8 @@ enum class MobileOptimizerType : int8_t {
HOIST_CONV_PACKED_PARAMS,
};
// Rewrites aten::conv1d calls as aten::conv2d calls.
// Graph overload: transforms the given graph.
TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
// Module overload: transforms the graphs of the module's methods and of its
// child modules.
TORCH_API void transformConv1dToConv2d(script::Module& module);
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
TORCH_API void insertPrePackedOps(script::Module& module);
TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);

View File

@ -562,6 +562,16 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_remove_dropout",
[](script::Module& module) { return removeDropout(module); })
.def(
"_jit_pass_transform_conv1d_to_conv2d",
[](std::shared_ptr<Graph>& graph) {
return transformConv1dToConv2d(graph);
})
.def(
"_jit_pass_transform_conv1d_to_conv2d",
[](script::Module& module) {
return transformConv1dToConv2d(module);
})
.def(
"_jit_pass_insert_prepacked_ops",
[](std::shared_ptr<Graph>& graph) {