Replace Conv1d with Conv2d (#42867)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42867

Test Plan: Imported from OSS

Reviewed By: kimishpatel

Differential Revision: D23177916

Pulled By: kimishpatel

fbshipit-source-id: 68cc40cf42d03e5b8432dc08f9933a4409c76e25
This commit is contained in:
parent e8139624f2
commit 665da61d2b
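The core idea of this commit: rewrite `aten::conv1d` into an equivalent `aten::conv2d` by unsqueezing a height dimension of size 1, so XNNPACK's prepacked 2-D convolution kernels can serve 1-D models. A minimal sketch of the equivalence the rewrite relies on, in plain eager PyTorch with no JIT pass involved (shapes and tolerances here are illustrative, not from the patch):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch 2, 10 input channels, width 11, kernel 3.
x = torch.randn(2, 10, 11)
w = torch.randn(16, 10, 3)
b = torch.randn(16)
stride, padding, dilation = 2, 1, 1

ref = F.conv1d(x, w, b, stride, padding, dilation)

# The rewrite: insert a height dim of size 1 (dim 2) into input and weight,
# lift stride/padding/dilation to (1, s), (0, p), (1, d), run conv2d, squeeze.
out = F.conv2d(x.unsqueeze(2), w.unsqueeze(2), b,
               stride=(1, stride), padding=(0, padding),
               dilation=(1, dilation)).squeeze(2)

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-5)
```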
test/test_xnnpack_integration.py
@@ -5,12 +5,14 @@ import unittest
 import torch
 import torch.backends.xnnpack
 from torch.nn import functional as F
+from torch.utils.mobile_optimizer import optimize_for_mobile
 from torch.testing import FileCheck
 import torch.testing._internal.hypothesis_utils as hu
 from torch.testing._internal.common_utils import TestCase, run_tests
 from hypothesis import given, assume
 from hypothesis import strategies as st
 import io
+import itertools
 
 
 @unittest.skipUnless(torch.backends.xnnpack.enabled,
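The hunk below adds `TestXNNPACKConv1dTransformPass`. Its pattern-count maps follow a small convention that can be read off the harness: a value of 0 means the pattern must appear at least once, -1 means it must not appear at all, and any other value is an exact occurrence count. A standalone sketch of that convention (the scripted function and patterns here are illustrative, not from the patch):

```python
import torch
from torch.testing import FileCheck

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x + 1

graph = f.graph
# v == 0: the pattern must be present (at least once).
FileCheck().check("aten::add").run(graph)
# v == -1: the pattern must be absent.
FileCheck().check_not("aten::conv1d").run(graph)
# v > 0: the pattern must appear exactly v times.
FileCheck().check_count("aten::add", 1, exactly=True).run(graph)
```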
@@ -704,6 +706,174 @@ class TestXNNPACKRewritePass(TestCase):
         TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmulAdd(), pattern_count_map, data_shape)
         TestXNNPACKRewritePass.validate_transformed_module(DecomposedLinearMatmul(), pattern_count_map, data_shape)
 
+
+@unittest.skipUnless(torch.backends.xnnpack.enabled,
+                     " XNNPACK must be enabled for these tests."
+                     " Please build with USE_XNNPACK=1.")
+class TestXNNPACKConv1dTransformPass(TestCase):
+    @staticmethod
+    def validate_transform_conv1d_to_conv2d(
+            self,
+            pattern_count_transformed_map,
+            pattern_count_optimized_map,
+            data_shape):
+        module_instance = self
+        scripted_model = torch.jit.script(module_instance)
+        scripted_model.eval()
+        input_data = torch.normal(1, 20, size=data_shape)
+        ref_result = scripted_model(input_data)
+        torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
+        optimized_scripted_model = optimize_for_mobile(scripted_model)
+
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_model, buffer)
+        buffer.seek(0)
+        deserialized_scripted_model = torch.jit.load(buffer)
+
+        for pattern, v in pattern_count_transformed_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+        transformed_result = deserialized_scripted_model(input_data)
+        torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)
+
+        optimized_buffer = io.BytesIO()
+        torch.jit.save(optimized_scripted_model, optimized_buffer)
+        optimized_buffer.seek(0)
+        deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)
+
+        for pattern, v in pattern_count_optimized_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_optimized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_optimized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph)
+        xnnpack_result = deserialized_optimized_scripted_model(input_data)
+        torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
+
+    def test_conv1d_basic(self):
+        batch_size_list = range(1, 3)
+        input_channels_per_group_list = range(10, 12)
+        width_list = range(10, 12)
+        output_channels_per_group_list = range(10, 12)
+        groups_list = range(1, 3)
+        kernel_list = range(1, 4)
+        stride_list = range(1, 3)
+        padding_list = range(0, 3)
+        dilation_list = range(1, 3)
+
+        for hparams in itertools.product(batch_size_list,
+                                         input_channels_per_group_list,
+                                         width_list,
+                                         output_channels_per_group_list,
+                                         groups_list,
+                                         kernel_list,
+                                         stride_list,
+                                         padding_list,
+                                         dilation_list):
+            batch_size, input_channels_per_group, width, output_channels_per_group, \
+                groups, kernel, stride, padding, dilation = hparams
+
+            input_channels = input_channels_per_group * groups
+            output_channels = output_channels_per_group * groups
+            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
+            conv_bias_shape = (output_channels)
+
+            class Conv1D(torch.nn.Module):
+                def __init__(self):
+                    super(Conv1D, self).__init__()
+                    self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                    self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                    self.stride = stride
+                    self.padding = padding
+                    self.dilation = dilation
+                    self.groups = groups
+
+                def forward(self, x):
+                    return F.conv1d(x, self.weight, self.bias,
+                                    self.stride, self.padding, self.dilation, self.groups)
+
+            data_shape = (batch_size, input_channels, width)
+            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
+                                             "Tensor = aten::conv2d": 1}
+            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
+                                           "Tensor = aten::conv2d": -1,
+                                           "prepacked::conv2d_clamp_prepack": -1,
+                                           "prepacked::conv2d_clamp_run": 1}
+
+            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Conv1D(),
+                                                                               pattern_count_transformed_map,
+                                                                               pattern_count_optimized_map,
+                                                                               data_shape)
+
+    def test_conv1d_with_relu_fc(self):
+        batch_size_list = range(1, 3)
+        input_channels_per_group_list = range(10, 12)
+        width_list = range(10, 12)
+        output_channels_per_group_list = range(10, 12)
+        groups_list = range(1, 3)
+        kernel_list = range(1, 4)
+        stride_list = range(1, 3)
+        padding_list = range(0, 3)
+        dilation_list = range(1, 3)
+        output_features_list = range(1, 3)
+
+        for hparams in itertools.product(batch_size_list,
+                                         input_channels_per_group_list,
+                                         width_list,
+                                         output_channels_per_group_list,
+                                         groups_list,
+                                         kernel_list,
+                                         stride_list,
+                                         padding_list,
+                                         dilation_list,
+                                         output_features_list):
+            batch_size, input_channels_per_group, width, output_channels_per_group, \
+                groups, kernel, stride, padding, dilation, output_features = hparams
+
+            input_channels = input_channels_per_group * groups
+            output_channels = output_channels_per_group * groups
+            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
+            conv_bias_shape = (output_channels)
+            conv_output_width = int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
+            fc_weight_shape = (output_features, output_channels * conv_output_width)
+            fc_bias_shape = (output_features)
+
+            class Net(torch.nn.Module):
+                def __init__(self):
+                    super(Net, self).__init__()
+                    self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                    self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                    self.stride = stride
+                    self.padding = padding
+                    self.dilation = dilation
+                    self.groups = groups
+
+                    self.fc_weight = torch.nn.Parameter(torch.Tensor(torch.rand(fc_weight_shape)), requires_grad=False)
+                    self.fc_bias = torch.nn.Parameter(torch.Tensor(torch.rand(fc_bias_shape)), requires_grad=False)
+
+                def forward(self, x):
+                    x = F.conv1d(x, self.conv_weight, self.conv_bias,
+                                 self.stride, self.padding, self.dilation, self.groups)
+                    x = F.relu(x)
+                    x = x.view(x.size(0), -1)
+                    x = F.linear(x, self.fc_weight, self.fc_bias)
+                    return x
+
+            data_shape = (batch_size, input_channels, width)
+            pattern_count_transformed_map = {"Tensor = aten::conv1d": -1,
+                                             "Tensor = aten::conv2d": 1}
+            pattern_count_optimized_map = {"Tensor = aten::conv1d": -1,
+                                           "Tensor = aten::conv2d": -1,
+                                           "prepacked::conv2d_clamp_prepack": -1,
+                                           "prepacked::conv2d_clamp_run": 1}
+            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(Net(),
+                                                                               pattern_count_transformed_map,
+                                                                               pattern_count_optimized_map,
+                                                                               data_shape)
+
+
 if __name__ == "__main__":
     run_tests()
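With the bindings in place, the transform the harness exercises can also be driven by hand. A minimal sketch, assuming a build that includes this patch (the module and shapes below are illustrative, not from the patch):

```python
import torch
from torch.nn import functional as F
from torch.testing import FileCheck

class Conv1dModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(16, 10, 3), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.rand(16), requires_grad=False)

    def forward(self, x):
        return F.conv1d(x, self.weight, self.bias, 1, 0, 1, 1)

scripted = torch.jit.script(Conv1dModel())
scripted.eval()
x = torch.randn(2, 10, 11)
ref = scripted(x)

# Run the new pass in place on the underlying C++ module.
torch._C._jit_pass_transform_conv1d_to_conv2d(scripted._c)

# conv1d is gone; an equivalent conv2d (with a height dim of 1) replaces it,
# and numerics are preserved within the test's tolerances.
FileCheck().check_not("aten::conv1d").run(scripted.graph)
FileCheck().check("aten::conv2d").run(scripted.graph)
assert torch.allclose(ref, scripted(x), rtol=1e-2, atol=1e-3)
```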
torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -20,6 +20,55 @@
 namespace torch {
 namespace jit {
 
+namespace {
+
+void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
+  std::string conv_1d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
+        return (%r) )";
+
+  std::string conv_2d_pattern = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
+        %zero : int = prim::Constant[value=0]()
+        %one : int = prim::Constant[value=1]()
+        %stride_w : int = prim::ListUnpack(%stride)
+        %stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
+        %padding_w : int = prim::ListUnpack(%padding)
+        %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
+        %dilation_w : int = prim::ListUnpack(%dilation)
+        %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
+        %two : int = prim::Constant[value=2]()
+        %input_2d : Tensor = aten::unsqueeze(%input, %two)
+        %weight_2d : Tensor = aten::unsqueeze(%weight, %two)
+        %output_2d = aten::conv2d(
+            %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
+        %output : Tensor = aten::squeeze(%output_2d, %two)
+        return (%output) )";
+
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern);
+  rewriter.runOnGraph(graph);
+}
+
+} // namespace
+
+void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
+  // Replace _convolution with conv1d and conv2d
+  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
+  replaceConv1dWithConv2d(graph);
+}
+
+void transformConv1dToConv2d(script::Module& module) {
+  for (auto& method : module.get_methods()) {
+    auto graph = method.graph();
+    transformConv1dToConv2d(graph);
+  }
+  for (script::Module m : module.children()) {
+    transformConv1dToConv2d(m);
+  }
+}
+
 #ifdef USE_XNNPACK
 
 namespace {
torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -14,6 +14,8 @@ enum class MobileOptimizerType : int8_t {
   HOIST_CONV_PACKED_PARAMS,
 };
 
+TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
+TORCH_API void transformConv1dToConv2d(script::Module& module);
 TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
 TORCH_API void insertPrePackedOps(script::Module& module);
 TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
torch/csrc/jit/python/init.cpp
@@ -562,6 +562,16 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_remove_dropout",
          [](script::Module& module) { return removeDropout(module); })
+      .def(
+          "_jit_pass_transform_conv1d_to_conv2d",
+          [](std::shared_ptr<Graph>& graph) {
+            return transformConv1dToConv2d(graph);
+          })
+      .def(
+          "_jit_pass_transform_conv1d_to_conv2d",
+          [](script::Module& module) {
+            return transformConv1dToConv2d(module);
+          })
       .def(
           "_jit_pass_insert_prepacked_ops",
           [](std::shared_ptr<Graph>& graph) {
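Both overloads are registered under the same Python name, so pybind11 dispatches on argument type: pass a Graph to transform a single graph, or a ScriptModule's `_c` to recurse over its methods and children. In the mobile flow the transform runs before optimize_for_mobile, which can then prepack the resulting conv2d for XNNPACK. A hedged end-to-end sketch (module and shapes illustrative; assumes a build with this patch and USE_XNNPACK=1):

```python
import torch
from torch.nn import functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.rand(16, 10, 3), requires_grad=False)
        self.b = torch.nn.Parameter(torch.rand(16), requires_grad=False)

    def forward(self, x):
        return F.conv1d(x, self.w, self.b, 1, 0, 1, 1)

m = torch.jit.script(M())
m.eval()
x = torch.randn(2, 10, 11)
ref = m(x)

# Module overload: recurses over all methods and submodules.
torch._C._jit_pass_transform_conv1d_to_conv2d(m._c)

# optimize_for_mobile can now fuse and prepack the conv2d for XNNPACK
# (prepacked::conv2d_clamp_prepack / prepacked::conv2d_clamp_run).
optimized = optimize_for_mobile(m)
assert torch.allclose(ref, optimized(x), rtol=1e-2, atol=1e-3)
```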