TorchDynamo: Add convolution unary fusion for cpu in inference mode (#87063)

cc @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87063
Approved by: https://github.com/jgong5, https://github.com/jansel
XiaobingSuper 2022-10-26 01:46:46 -04:00 committed by PyTorch MergeBot
parent b16b5fb802
commit c36db82e12
5 changed files with 389 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ import contextlib
import dataclasses
import functools
import importlib
import itertools
import os
import random
import sys
@@ -1292,6 +1293,65 @@ class CommonTemplate:
check_lowp=False,
)
# For the GPU path there is an accuracy issue,
# see https://github.com/pytorch/pytorch/issues/87745.
@unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test")
def test_conv2d_unary(self):
def _unary_list():
unary_list = [
torch.nn.ReLU(),
torch.nn.Sigmoid(),
torch.nn.Tanh(),
torch.nn.Hardswish(),
torch.nn.LeakyReLU(0.1, inplace=False),
torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False),
torch.nn.GELU(approximate="none"),
torch.nn.GELU(approximate="tanh"),
]
return unary_list
test_memory_format = [torch.contiguous_format, torch.channels_last]
options = itertools.product(
_unary_list(),
[True, False],
[1, 3],
[1, 2],
[1, 4],
test_memory_format,
)
for (
unary_fn,
bias,
kernel_size,
dilation,
groups,
memory_format,
) in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC, 112, 112)
mod = torch.nn.Sequential(
torch.nn.Conv2d(
iC,
oC,
kernel_size=kernel_size,
dilation=dilation,
groups=groups,
bias=bias,
),
unary_fn,
).eval()
# TODO: add bf16 test for cpu path?
v = torch.randn(x_shape, dtype=torch.float32).to(
memory_format=memory_format
)
self.common(
mod,
(v,),
)
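For context outside the test harness, the pattern this test exercises is an eval-mode Conv2d followed by a unary activation, compiled through inductor. A minimal stand-alone sketch, not part of the PR; the torch._dynamo.optimize("inductor") entry point, the shapes, and the tolerances are illustrative assumptions:

# Hedged sketch: eval-mode conv + unary activation run through the inductor backend.
import torch
import torch._dynamo as dynamo

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, bias=True),
    torch.nn.ReLU(),
).eval()
x = torch.randn(1, 3, 112, 112).to(memory_format=torch.channels_last)

@dynamo.optimize("inductor")
def run(t):
    return mod(t)

with torch.no_grad():
    # Loose tolerances: the fused oneDNN kernel may not match eager bit-for-bit.
    torch.testing.assert_close(run(x), mod(x), rtol=1e-4, atol=1e-4)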
def test_gather1(self):
def fn(a, b):
return (

View File

@@ -340,6 +340,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]
with overrides.patch_functions():
model_ = normalize_ir(model_, example_inputs_)
model_ = overrides.replace_fx(model_)
model_ = overrides.fuse_fx(model_, example_inputs_)
num_example_inputs = len(example_inputs_)
cudagraphs = BoxedBool(config.triton.cudagraphs and not config.dynamic_shapes)

View File

@@ -3295,6 +3295,152 @@ class Convolution(ExternKernelAlloc):
)
def _prepare_convolution_fusion_create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
):
"""
This helper prepares the inputs, layout, and constant args for a convolution
post-op fusion's create function: it decides the output layout (channels first
or channels last) and realizes the inputs with the required strides. Only the
CPU device is supported, since the conv post-op fusion kernel is currently
CPU-only.
"""
x = cls.require_stride1(cls.realize_input(x))
weight = cls.require_stride1(cls.realize_input(weight))
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
stride = tuple(stride_)
padding = tuple(padding_)
dilation = tuple(dilation_)
assert isinstance(groups, int)
weight_shape = [
sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size()
]
out_channels, in_channels1, *kernel_size = weight_shape
in_channels1 = in_channels1 * groups
assert len(x.get_size()) == 2 + len(kernel_size)
batch, in_channels2, *input_size = x.get_size()
output_size = [batch]
V.graph.sizevars.guard_equals(in_channels1, in_channels2)
output_size.append(out_channels)
assert (
len(stride)
== len(padding)
== len(dilation)
== len(kernel_size)
== len(input_size)
)
for i in range(len(stride)):
output_size.append(
IndexingDiv(
input_size[i]
+ 2 * padding[i]
- dilation[i] * (kernel_size[i] - 1)
- 1
+ stride[i],
stride[i],
)
)
output_size[-1] = sympy.Integer(
V.graph.sizevars.guard_static_shape(output_size[-1])
)
output_layout_str = "torch.contiguous_format"
# If x or weight is in channels_last (2d or 3d) format, take the channels_last path,
# which aligns with the aten.convolution path (CPU only supports the 2d case for now).
# TODO: once CPU 3d convolution supports the channels_last path, the size check can be removed.
if len(x.get_size()) == 4 and (
x.get_layout().is_channels_last_stride_ordered()
or weight.get_layout().is_channels_last_stride_ordered()
):
output_layout_str = "torch.channels_last"
if output_layout_str == "torch.channels_last":
stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
if len(stride_order) < len(output_size):
# add batch dim if it exists
stride_order = [len(stride_order)] + stride_order
else:
stride_order = list(reversed(range(len(output_size))))
kernel_layout = FlexibleLayout(
device=inputs[0].get_device(),
dtype=inputs[0].get_dtype(),
size=output_size,
stride_order=stride_order,
)
constant_args = [padding, stride, dilation, groups]
if bias is not None:
inputs.append(bias)
else:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout
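The IndexingDiv expression above is the standard convolution output-size formula, rearranged so a single floor division suffices: floor((in + 2*pad - dilation*(kernel-1) - 1 + stride) / stride) equals the usual floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1. A small numeric sanity check, not part of the PR:

def conv_out_size(in_size, kernel, stride=1, padding=0, dilation=1):
    # Same expression as the IndexingDiv above, on plain Python ints.
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1 + stride) // stride

# 112x112 input, 3x3 kernel, stride 1, no padding -> 110, matching nn.Conv2d.
assert conv_out_size(112, kernel=3) == 110
# With dilation=2 the effective kernel is 5, so the output shrinks to 108.
assert conv_out_size(112, kernel=3, dilation=2) == 108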
class ConvolutionUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise"
def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise",
):
super().__init__(layout, inputs, constant_args)
self.kernel = kernel
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)
@classmethod
def create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
attr,
scalars,
algorithm,
):
kernel = "torch.ops.mkldnn._convolution_pointwise"
(inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
constant_args = constant_args + [attr, scalars, algorithm]
return ConvolutionUnary(
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)
def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.layout.preferred_stride_order)
self.inputs[0] = x
self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)
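What codegen emits for this node is a single call to torch.ops.mkldnn._convolution_pointwise with the realized tensor inputs followed by the constant args. A hedged eager-mode sketch of that call with the same argument order; the shapes and the "relu" attr are illustrative, and it assumes a CPU build with oneDNN enabled:

import torch

x = torch.randn(1, 3, 56, 56).to(memory_format=torch.channels_last)
w = torch.randn(32, 3, 3, 3).to(memory_format=torch.channels_last)
b = torch.randn(32)

y = torch.ops.mkldnn._convolution_pointwise(
    x, w, b,
    [1, 1],   # padding
    [1, 1],   # stride
    [1, 1],   # dilation
    1,        # groups
    "relu",   # attr
    [],       # scalars
    "",       # algorithm
)
assert y.shape == (1, 32, 56, 56)  # fused conv2d + relu output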
@dataclasses.dataclass
class MutableBox(IRNode):
"""

View File

@@ -886,6 +886,44 @@ def bmm(a: TensorBox, b: TensorBox):
return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))
def register_onednn_fusion_ops():
if torch._C.has_mkldnn:
@register_lowering(torch.ops.mkldnn._convolution_pointwise)
def convolution_unary(
x: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
):
return TensorBox.create(
ir.ConvolutionUnary.create(
x,
weight,
bias,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
)
)
else:
pass
register_onednn_fusion_ops()
def fallback_handler(kernel):
fallbacks.add(kernel)

View File

@@ -1,10 +1,19 @@
import copy
import itertools
import logging
import random
import weakref
import torch
import torch.nn as nn
from torch import _prims
from torch.fx.experimental.optimization import (
matches_module_pattern,
replace_node_module,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.overrides import TorchFunctionMode
log = logging.getLogger(__name__)
@@ -37,6 +46,127 @@ def replace_fx(gm: torch.fx.GraphModule):
return gm
class UnaryAttr(object):
def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
self.op_name = op_name
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""
super(UnaryAttr, self).__init__()
def __call__(self, unary_module: nn.Module):
assert all(hasattr(unary_module, item) for item in self.scalars_attr)
scalars = [getattr(unary_module, item) for item in self.scalars_attr]
algorithm = ""
if self.algorithm_attr:
assert hasattr(unary_module, self.algorithm_attr)
algorithm = getattr(unary_module, self.algorithm_attr)
return self.op_name, scalars, algorithm
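As an illustration (not part of the PR), here is what UnaryAttr extracts from a couple of activation modules, following the unary_modules_map entries at the bottom of this file:

import torch.nn as nn

# LeakyReLU carries one scalar attribute and no algorithm string.
assert UnaryAttr("leaky_relu", scalars_attr=["negative_slope"])(
    nn.LeakyReLU(0.1)
) == ("leaky_relu", [0.1], "")

# GELU carries no scalars but does carry an "approximate" algorithm string.
assert UnaryAttr("gelu", algorithm_attr="approximate")(
    nn.GELU(approximate="tanh")
) == ("gelu", [], "tanh")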
class ConvUnary2d(nn.Conv2d):
def __init__(
self,
conv: nn.Module,
unary: nn.Module,
):
super(ConvUnary2d, self).__init__(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
conv.weight.device,
conv.weight.dtype,
)
self._update_module_params(conv, unary)
def _update_module_params(self, conv, unary):
self.__dict__ = copy.deepcopy(conv.__dict__)
self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
unary
)
def _conv_forward(self, input, weight, bias):
if self.padding_mode != "zeros":
return torch.ops.mkldnn._convolution_pointwise(
F.pad(
input, self._reversed_padding_repeated_twice, mode=self.padding_mode
),
weight,
bias,
_pair(0),
self.stride,
self.dilation,
self.groups,
self.attr,
self.scalars,
self.algorithm,
)
return torch.ops.mkldnn._convolution_pointwise(
input,
weight,
bias,
self.padding,
self.stride,
self.dilation,
self.groups,
self.attr,
self.scalars,
self.algorithm,
)
def forward(self, input):
return self._conv_forward(input, self.weight, self.bias)
def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module):
assert not (conv.training), "Fusion only for eval!"
return ConvUnary2d(
conv,
unary,
)
def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
return gm
is_cpu = all(
example_input.device == torch.device("cpu") for example_input in example_inputs
)
if not is_cpu:
return gm
modules = dict(gm.named_modules())
for (unary_module, _), (computation_module, fuse_func,) in itertools.product(
unary_modules_map.items(), computation_op_unary_op_fusion_map.items()
):
pattern = (computation_module, unary_module)
for node in gm.graph.nodes:
if matches_module_pattern(pattern, node, modules):
if (
len(node.args[0].users) > 1
): # Output of computation_node is used by other nodes
continue
conv = modules[node.args[0].target]
unary_node = modules[node.target]
eval_mode = all(not n.training for n in [conv, unary_node])
if not eval_mode:
continue
fused_conv = fuse_func(conv, unary_node)
replace_node_module(node.args[0], modules, fused_conv)
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()
return gm
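A hedged usage sketch of fuse_fx on a hand-traced module; the shapes are invented, torch.fx.symbolic_trace stands in for the graph dynamo would normally hand to compile_fx, and it assumes a CPU build with oneDNN (otherwise fuse_fx returns the module unchanged):

import torch
from torch.fx import symbolic_trace

m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
gm = symbolic_trace(m)
example_inputs = [torch.randn(1, 3, 16, 16)]

fused_gm = fuse_fx(gm, example_inputs)
# The ReLU node is folded away and the conv submodule is now a ConvUnary2d
# whose forward calls torch.ops.mkldnn._convolution_pointwise.
assert any(type(sub).__name__ == "ConvUnary2d" for sub in fused_gm.modules())
# Loose tolerances: the fused oneDNN kernel need not match eager bit-for-bit.
torch.testing.assert_close(fused_gm(*example_inputs), m(*example_inputs),
                           rtol=1e-4, atol=1e-4)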
def _philox_rand_like_meta(input, seed, offset):
return _prims.TensorMeta(input)
@@ -163,3 +293,17 @@ def rand_like(x, **kwargs):
replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like}
computation_op_unary_op_fusion_map = {nn.Conv2d: fused_conv_unary_eval}
unary_modules_map = {
nn.ReLU: UnaryAttr("relu"),
nn.Sigmoid: UnaryAttr("sigmoid"),
nn.Tanh: UnaryAttr("tanh"),
nn.Hardswish: UnaryAttr("hardswish"),
nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
}