TorchDynamo: Add convolution unary fusion for cpu in inference mode (#87063)
cc @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87063
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent b16b5fb802
commit c36db82e12
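For orientation, here is a minimal usage sketch (not part of the commit) of the path this PR targets: an eval-mode Conv2d followed by one of the supported unary modules, compiled on CPU through the inductor backend. The torch._dynamo.optimize("inductor") entry point and the concrete shapes are assumptions for illustration; the committed test (test_conv2d_unary below) is the authoritative exercise of the feature.

    # Hypothetical end-to-end sketch: Conv2d + ReLU in eval mode on CPU, the
    # pattern fuse_fx() is expected to rewrite into mkldnn._convolution_pointwise.
    import torch
    import torch._dynamo as dynamo

    mod = torch.nn.Sequential(
        torch.nn.Conv2d(3, 32, kernel_size=3),
        torch.nn.ReLU(),
    ).eval()

    opt_mod = dynamo.optimize("inductor")(mod)
    x = torch.randn(1, 3, 112, 112)
    with torch.no_grad():
        out = opt_mod(x)  # CPU inference; conv + relu should take the fused path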
@@ -3,6 +3,7 @@ import contextlib
import dataclasses
import functools
import importlib
import itertools
import os
import random
import sys
@@ -1292,6 +1293,65 @@ class CommonTemplate:
            check_lowp=False,
        )

    # For the GPU path there is an accuracy issue,
    # see https://github.com/pytorch/pytorch/issues/87745.
    @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test")
    def test_conv2d_unary(self):
        def _unary_list():
            unary_list = [
                torch.nn.ReLU(),
                torch.nn.Sigmoid(),
                torch.nn.Tanh(),
                torch.nn.Hardswish(),
                torch.nn.LeakyReLU(0.1, inplace=False),
                torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False),
                torch.nn.GELU(approximate="none"),
                torch.nn.GELU(approximate="tanh"),
            ]
            return unary_list

        test_memory_format = [torch.contiguous_format, torch.channels_last]
        options = itertools.product(
            _unary_list(),
            [True, False],
            [1, 3],
            [1, 2],
            [1, 4],
            test_memory_format,
        )

        for (
            unary_fn,
            bias,
            kernel_size,
            dilation,
            groups,
            memory_format,
        ) in options:
            oC = 32 * groups
            iC = 3 * groups
            x_shape = (1, iC, 112, 112)
            mod = torch.nn.Sequential(
                torch.nn.Conv2d(
                    iC,
                    oC,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                ),
                unary_fn,
            ).eval()

            # TODO: add bf16 test for cpu path?
            v = torch.randn(x_shape, dtype=torch.float32).to(
                memory_format=memory_format
            )
            self.common(
                mod,
                (v,),
            )

    def test_gather1(self):
        def fn(a, b):
            return (
@@ -340,6 +340,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]
    with overrides.patch_functions():
        model_ = normalize_ir(model_, example_inputs_)
        model_ = overrides.replace_fx(model_)
        model_ = overrides.fuse_fx(model_, example_inputs_)
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(config.triton.cudagraphs and not config.dynamic_shapes)
@@ -3295,6 +3295,152 @@ class Convolution(ExternKernelAlloc):
        )


def _prepare_convolution_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
    padding_: List[int],
    stride_: List[int],
    dilation_: List[int],
    groups: int,
):
    """
    Helper that prepares the inputs, layout, and constant args for the create
    function of a convolution post-op fusion kernel, including deciding the
    output layout (channels first or channels last) and realizing the inputs.
    It only supports the CPU device, since the conv post-op fusion kernel is
    currently CPU-only.
    """
    x = cls.require_stride1(cls.realize_input(x))
    weight = cls.require_stride1(cls.realize_input(weight))
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]
    stride = tuple(stride_)
    padding = tuple(padding_)
    dilation = tuple(dilation_)
    assert isinstance(groups, int)

    weight_shape = [
        sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size()
    ]

    out_channels, in_channels1, *kernel_size = weight_shape
    in_channels1 = in_channels1 * groups
    assert len(x.get_size()) == 2 + len(kernel_size)
    batch, in_channels2, *input_size = x.get_size()
    output_size = [batch]
    V.graph.sizevars.guard_equals(in_channels1, in_channels2)

    output_size.append(out_channels)
    assert (
        len(stride)
        == len(padding)
        == len(dilation)
        == len(kernel_size)
        == len(input_size)
    )
    for i in range(len(stride)):
        output_size.append(
            IndexingDiv(
                input_size[i]
                + 2 * padding[i]
                - dilation[i] * (kernel_size[i] - 1)
                - 1
                + stride[i],
                stride[i],
            )
        )
    output_size[-1] = sympy.Integer(
        V.graph.sizevars.guard_static_shape(output_size[-1])
    )

    output_layout_str = "torch.contiguous_format"
    # If x or weight is in a channels_last (2d or 3d) format, take the channels_last
    # path, which aligns with the aten.convolution path (CPU only supports the 2d
    # case for now).
    # TODO: once CPU 3d convolution supports the channels_last path, the size check
    # can be removed.
    if len(x.get_size()) == 4 and (
        x.get_layout().is_channels_last_stride_ordered()
        or weight.get_layout().is_channels_last_stride_ordered()
    ):
        output_layout_str = "torch.channels_last"

    if output_layout_str == "torch.channels_last":
        stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
        if len(stride_order) < len(output_size):
            # add batch dim if it exists
            stride_order = [len(stride_order)] + stride_order
    else:
        stride_order = list(reversed(range(len(output_size))))

    kernel_layout = FlexibleLayout(
        device=inputs[0].get_device(),
        dtype=inputs[0].get_dtype(),
        size=output_size,
        stride_order=stride_order,
    )
    constant_args = [padding, stride, dilation, groups]

    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout


class ConvolutionUnary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._convolution_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        attr,
        scalars,
        algorithm,
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise"
        (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        constant_args = constant_args + [attr, scalars, algorithm]
        return ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

    def apply_constraint(self):
        x = self.inputs[0]
        # FixedLayout of input
        x = self.require_stride_order(x, self.layout.preferred_stride_order)
        self.inputs[0] = x
        self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)


@dataclasses.dataclass
class MutableBox(IRNode):
    """
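The output-size loop in _prepare_convolution_fusion_create above is the standard convolution shape formula, written with IndexingDiv so it stays symbolic. As a plain-Python restatement (illustrative only, not from the commit), the names below are hypothetical:

    # out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1,
    # which is the same value as the IndexingDiv form used above:
    # (in + 2*pad - dilation*(kernel - 1) - 1 + stride) // stride
    def conv_out_dim(in_size, kernel, stride=1, padding=0, dilation=1):
        return (in_size + 2 * padding - dilation * (kernel - 1) - 1 + stride) // stride

    assert conv_out_dim(112, 3) == 110             # 112x112 test input, 3x3 kernel, no padding
    assert conv_out_dim(112, 3, padding=1) == 112  # "same"-style padding keeps the size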
@@ -886,6 +886,44 @@ def bmm(a: TensorBox, b: TensorBox):
    return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))


def register_onednn_fusion_ops():
    if torch._C.has_mkldnn:

        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

    else:
        pass


register_onednn_fusion_ops()


def fallback_handler(kernel):
    fallbacks.add(kernel)
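The lowering above maps graph calls to torch.ops.mkldnn._convolution_pointwise onto ir.ConvolutionUnary; the argument order matches the eager call emitted by ConvUnary2d (input, weight, bias, padding, stride, dilation, groups, attr, scalars, algorithm). The sketch below is an illustrative eager-mode sanity check, assuming a CPU build with MKLDNN enabled; the shapes and the empty scalars/algorithm values are assumptions, not from the commit.

    # Hypothetical check: the fused op should agree with an unfused conv2d
    # followed by the unary op (here "relu").
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(8, 3, 3, 3)
    b = torch.randn(8)

    fused = torch.ops.mkldnn._convolution_pointwise(
        x, w, b, [1, 1], [1, 1], [1, 1], 1, "relu", [], ""
    )
    ref = F.relu(F.conv2d(x, w, b, stride=1, padding=1, dilation=1, groups=1))
    torch.testing.assert_close(fused, ref)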
@@ -1,10 +1,19 @@
import copy
import itertools
import logging
import random
import weakref

import torch
import torch.nn as nn
from torch import _prims
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.overrides import TorchFunctionMode

log = logging.getLogger(__name__)
@@ -37,6 +46,127 @@ def replace_fx(gm: torch.fx.GraphModule):
    return gm


class UnaryAttr(object):
    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
        super(UnaryAttr, self).__init__()

    def __call__(self, unary_module: nn.Module):
        assert all(hasattr(unary_module, item) for item in self.scalars_attr)
        scalars = [getattr(unary_module, item) for item in self.scalars_attr]

        algorithm = ""
        if self.algorithm_attr:
            assert hasattr(unary_module, self.algorithm_attr)
            algorithm = getattr(unary_module, self.algorithm_attr)

        return self.op_name, scalars, algorithm


class ConvUnary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        unary: nn.Module,
    ):
        super(ConvUnary2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, unary)

    def _update_module_params(self, conv, unary):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
            unary
        )

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)


def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        unary,
    )


def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    is_cpu = all(
        example_input.device == torch.device("cpu") for example_input in example_inputs
    )
    if not is_cpu:
        return gm
    modules = dict(gm.named_modules())

    for (unary_module, _), (computation_module, fuse_func,) in itertools.product(
        unary_modules_map.items(), computation_op_unary_op_fusion_map.items()
    ):
        pattern = (computation_module, unary_module)
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if (
                    len(node.args[0].users) > 1
                ):  # Output of computation_node is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                unary_node = modules[node.target]
                eval_mode = all(not n.training for n in [conv, unary_node])
                if not eval_mode:
                    continue
                fused_conv = fuse_func(conv, unary_node)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def _philox_rand_like_meta(input, seed, offset):
    return _prims.TensorMeta(input)
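To make the UnaryAttr mapping above concrete: each entry in unary_modules_map (defined in the next hunk) turns an nn module instance into the (attr, scalars, algorithm) triple that ConvUnary2d forwards to the fused op. An illustrative check, not part of the commit:

    # How UnaryAttr extracts the fused-op arguments from a unary module.
    import torch.nn as nn

    leaky = UnaryAttr("leaky_relu", scalars_attr=["negative_slope"])
    assert leaky(nn.LeakyReLU(0.1)) == ("leaky_relu", [0.1], "")

    gelu = UnaryAttr("gelu", algorithm_attr="approximate")
    assert gelu(nn.GELU(approximate="tanh")) == ("gelu", [], "tanh")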
@@ -163,3 +293,17 @@ def rand_like(x, **kwargs):


replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like}


computation_op_unary_op_fusion_map = {nn.Conv2d: fused_conv_unary_eval}


unary_modules_map = {
    nn.ReLU: UnaryAttr("relu"),
    nn.Sigmoid: UnaryAttr("sigmoid"),
    nn.Tanh: UnaryAttr("tanh"),
    nn.Hardswish: UnaryAttr("hardswish"),
    nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
    nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
}