mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Replace Conv1d with Conv2d (#42867)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42867 Test Plan: Imported from OSS Reviewed By: kimishpatel Differential Revision: D23177916 Pulled By: kimishpatel fbshipit-source-id: 68cc40cf42d03e5b8432dc08f9933a4409c76e25
This commit is contained in:
parent
e8139624f2
commit
665da61d2b
|
|
@ -5,12 +5,14 @@ import unittest
|
|||
import torch
|
||||
import torch.backends.xnnpack
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
from torch.testing import FileCheck
|
||||
import torch.testing._internal.hypothesis_utils as hu
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from hypothesis import given, assume
|
||||
from hypothesis import strategies as st
|
||||
import io
|
||||
import itertools
|
||||
|
||||
|
||||
@unittest.skipUnless(torch.backends.xnnpack.enabled,
|
||||
|
|
@ -704,6 +706,174 @@ class TestXNNPACKRewritePass(TestCase):
|
|||
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
|
||||
TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
|
||||
|
||||
@unittest.skipUnless(torch.backends.xnnpack.enabled,
|
||||
" XNNPACK must be enabled for these tests."
|
||||
" Please build with USE_XNNPACK=1.")
|
||||
class TestXNNPACKConv1dTransformPass(TestCase):
|
||||
@staticmethod
|
||||
def validate_transform_conv1d_to_conv2d(
|
||||
self,
|
||||
pattern_count_transformed_map,
|
||||
pattern_count_optimized_map,
|
||||
data_shape):
|
||||
module_instance = self
|
||||
scripted_model = torch.jit.script(module_instance)
|
||||
scripted_model.eval()
|
||||
input_data = torch.normal(1, 20, size=data_shape)
|
||||
ref_result = scripted_model(input_data)
|
||||
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
|
||||
optimized_scripted_model = optimize_for_mobile(scripted_model)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_model, buffer)
|
||||
buffer.seek(0)
|
||||
deserialized_scripted_model = torch.jit.load(buffer)
|
||||
|
||||
for pattern, v in pattern_count_transformed_map.items():
|
||||
if (v == 0):
|
||||
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
|
||||
elif (v == -1):
|
||||
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
|
||||
else:
|
||||
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
|
||||
transformed_result = deserialized_scripted_model(input_data)
|
||||
torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
optimized_buffer = io.BytesIO()
|
||||
torch.jit.save(optimized_scripted_model, optimized_buffer)
|
||||
optimized_buffer.seek(0)
|
||||
deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
|
||||
|
||||
for pattern, v in pattern_count_optimized_map.items():
|
||||
if (v == 0):
|
||||
FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
|
||||
elif (v == -1):
|
||||
FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
|
||||
else:
|
||||
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
|
||||
xnnpack_result = deserialized_optimized_scripted_model(input_data)
|
||||
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
def test_conv1d_basic(self):
|
||||
batch_size_list = range(1, 3)
|
||||
input_channels_per_group_list = range(10, 12)
|
||||
width_list = range(10, 12)
|
||||
output_channels_per_group_list = range(10, 12)
|
||||
groups_list = range(1, 3)
|
||||
kernel_list = range(1, 4)
|
||||
stride_list = range(1, 3)
|
||||
padding_list = range(0, 3)
|
||||
dilation_list = range(1, 3)
|
||||
|
||||
for hparams in itertools.product(batch_size_list,
|
||||
input_channels_per_group_list,
|
||||
width_list,
|
||||
output_channels_per_group_list,
|
||||
groups_list,
|
||||
kernel_list,
|
||||
stride_list,
|
||||
padding_list,
|
||||
dilation_list):
|
||||
batch_size, input_channels_per_group, width, output_channels_per_group, \
|
||||
groups, kernel, stride, padding, dilation = hparams
|
||||
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
|
||||
conv_bias_shape = (output_channels)
|
||||
|
||||
class Conv1D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Conv1D, self).__init__()
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv1d(x, self.weight, self.bias,
|
||||
self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
data_shape = (batch_size, input_channels, width)
|
||||
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
|
||||
"Tensor = aten::conv2d": 1}
|
||||
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
|
||||
"Tensor = aten::conv2d": -1,
|
||||
"prepacked::conv2d_clamp_prepack" : -1,
|
||||
"prepacked::conv2d_clamp_run": 1}
|
||||
|
||||
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
|
||||
pattern_count_transformed_map,
|
||||
pattern_count_optimized_map,
|
||||
data_shape)
|
||||
|
||||
def test_conv1d_with_relu_fc(self):
|
||||
batch_size_list = range(1, 3)
|
||||
input_channels_per_group_list = range(10, 12)
|
||||
width_list = range(10, 12)
|
||||
output_channels_per_group_list = range(10, 12)
|
||||
groups_list = range(1, 3)
|
||||
kernel_list = range(1, 4)
|
||||
stride_list = range(1, 3)
|
||||
padding_list = range(0, 3)
|
||||
dilation_list = range(1, 3)
|
||||
output_features_list = range(1, 3)
|
||||
|
||||
for hparams in itertools.product(batch_size_list,
|
||||
input_channels_per_group_list,
|
||||
width_list,
|
||||
output_channels_per_group_list,
|
||||
groups_list,
|
||||
kernel_list,
|
||||
stride_list,
|
||||
padding_list,
|
||||
dilation_list,
|
||||
output_features_list):
|
||||
batch_size, input_channels_per_group, width, output_channels_per_group, \
|
||||
groups, kernel, stride, padding, dilation, output_features = hparams
|
||||
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
conv_weight_shape = (output_channels, input_channels_per_group, kernel)
|
||||
conv_bias_shape = (output_channels)
|
||||
conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
|
||||
fc_weight_shape = (output_features, output_channels * conv_output_width)
|
||||
fc_bias_shape = (output_features)
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
|
||||
self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
|
||||
self.fc_weight = torch.nn.Parameter(torch.Tensor(torch.rand(fc_weight_shape)), requires_grad=False)
|
||||
self.fc_bias = torch.nn.Parameter(torch.Tensor(torch.rand(fc_bias_shape)), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.conv1d(x, self.conv_weight, self.conv_bias,
|
||||
self.stride, self.padding, self.dilation, self.groups)
|
||||
x = F.relu(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = F.linear(x, self.fc_weight, self.fc_bias)
|
||||
return x
|
||||
|
||||
data_shape = (batch_size, input_channels, width)
|
||||
pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
|
||||
"Tensor = aten::conv2d": 1}
|
||||
pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
|
||||
"Tensor = aten::conv2d": -1,
|
||||
"prepacked::conv2d_clamp_prepack" : -1,
|
||||
"prepacked::conv2d_clamp_run": 1}
|
||||
TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
|
||||
pattern_count_transformed_map,
|
||||
pattern_count_optimized_map,
|
||||
data_shape)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,55 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
|
||||
std::string conv_1d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
|
||||
%r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv_2d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
|
||||
%zero : int = prim::Constant[value=0]()
|
||||
%one : int = prim::Constant[value=1]()
|
||||
%stride_w : int = prim::ListUnpack(%stride)
|
||||
%stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
|
||||
%padding_w : int = prim::ListUnpack(%padding)
|
||||
%padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
|
||||
%dilation_w : int = prim::ListUnpack(%dilation)
|
||||
%dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
|
||||
%two : int = prim::Constant[value=2]()
|
||||
%input_2d : Tensor = aten::unsqueeze(%input, %two)
|
||||
%weight_2d : Tensor = aten::unsqueeze(%weight, %two)
|
||||
%output_2d = aten::conv2d(
|
||||
%input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
|
||||
%output : Tensor = aten::squeeze(%output_2d, %two)
|
||||
return (%output) )";
|
||||
|
||||
SubgraphRewriter rewriter;
|
||||
rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
|
||||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
|
||||
// Replace _convolution with conv1d and conv2d
|
||||
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
|
||||
replaceConv1dWithConv2d(graph);
|
||||
}
|
||||
|
||||
void transformConv1dToConv2d(script::Module& module) {
|
||||
for (auto& method : module.get_methods()) {
|
||||
auto graph = method.graph();
|
||||
transformConv1dToConv2d(graph);
|
||||
}
|
||||
for (script::Module m : module.children()) {
|
||||
transformConv1dToConv2d(m);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
namespace {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ enum class MobileOptimizerType : int8_t {
|
|||
HOIST_CONV_PACKED_PARAMS,
|
||||
};
|
||||
|
||||
TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void transformConv1dToConv2d(script::Module& module);
|
||||
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void insertPrePackedOps(script::Module& module);
|
||||
TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
|
||||
|
|
|
|||
|
|
@ -562,6 +562,16 @@ void initJITBindings(PyObject* module) {
|
|||
.def(
|
||||
"_jit_pass_remove_dropout",
|
||||
[](script::Module& module) { return removeDropout(module); })
|
||||
.def(
|
||||
"_jit_pass_transform_conv1d_to_conv2d",
|
||||
[](std::shared_ptr<Graph>& graph) {
|
||||
return transformConv1dToConv2d(graph);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_transform_conv1d_to_conv2d",
|
||||
[](script::Module& module) {
|
||||
return transformConv1dToConv2d(module);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_insert_prepacked_ops",
|
||||
[](std::shared_ptr<Graph>& graph) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user