For Conv+binary+unary fusion, we only support conv+add+relu, so this PR adds such a check to fix the failing TIMM models. TODO: enable more Conv+binary+unary fusions to improve TIMM models' performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90259 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/jansel
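As an illustration (not part of this file; the module name is hypothetical), the pattern this check targets is a residual-style block where a convolution's output is added to another tensor and then passed through ReLU; only that conv+add+relu combination is currently fused, and other unary ops after the add are left unfused by the pass:

# Hypothetical module, for illustration only.
import torch
import torch.nn as nn

class ConvAddRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, other):
        # conv + add + relu: the only Conv+binary+unary pattern fused today
        return self.relu(self.conv(x) + other)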
1279 lines
42 KiB
Python
import copy
import itertools
import logging
import operator
import random
import weakref
from typing import Optional

import numpy

import torch
import torch.nn as nn
from torch import _prims
from torch._dynamo.utils import fake_mode_from_tensors
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
from torch.overrides import TorchFunctionMode

from . import config

log = logging.getLogger(__name__)


class AutogradMonkeypatch(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if not kwargs:
            kwargs = {}
        if func in replacements:
            return replacements[func](*args, **kwargs)
        return func(*args, **kwargs)


patch_functions = AutogradMonkeypatch


def replace_fx(gm: torch.fx.GraphModule):
    # Sometimes patch_functions() misses things already in the graph
    for node in reversed(list(gm.graph.nodes)):
        if node.op == "call_function" and node.target in replacements:
            if (
                config.fallback_random
                and replacements[node.target] in replacements_using_triton_random
            ):
                continue
            with gm.graph.inserting_before(node):
                node.replace_all_uses_with(
                    gm.graph.call_function(
                        replacements[node.target], node.args, node.kwargs
                    )
                )
            gm.graph.erase_node(node)
    gm.recompile()
    return gm


class UnaryAttr(object):
    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
        super(UnaryAttr, self).__init__()

    def __call__(self, unary_module: nn.Module):
        if type(unary_module) is nn.ReLU6:
            unary_module = nn.Hardtanh(min_val=0, max_val=6)
        assert all(hasattr(unary_module, item) for item in self.scalars_attr)
        scalars = [getattr(unary_module, item) for item in self.scalars_attr]

        algorithm = ""
        if self.algorithm_attr:
            assert hasattr(unary_module, self.algorithm_attr)
            algorithm = getattr(unary_module, self.algorithm_attr)

        return self.op_name, scalars, algorithm
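# Illustrative example (added comment, not executed here): with the `unary_modules_map`
# defined near the bottom of this file, UnaryAttr("leaky_relu", scalars_attr=["negative_slope"])
# called on nn.LeakyReLU(0.2) yields ("leaky_relu", [0.2], ""), while nn.ReLU6() is first
# rewritten to nn.Hardtanh(min_val=0, max_val=6) and yields ("hardtanh", [0, 6], "").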


class ConvUnary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super(ConvUnary2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, unary, input_size)

    def _update_module_params(self, conv, unary, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.attr = "none"
        self.scalars = []
        self.algorithm = ""
        if unary is not None:
            self.attr, self.scalars, self.algorithm = unary_modules_map[
                unary.__class__
            ](unary)
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)


class ConvBinary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super(ConvBinary2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)


class ConvBinaryInplace2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super(ConvBinaryInplace2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise_(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise_(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)


class PackedLinear(nn.Linear):
    def __init__(self, linear: nn.Module, input_size: list):
        super(PackedLinear, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, input_size)

    def _update_module_params(self, linear, input_size):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.batch_size = int(numpy.prod(input_size) / input_size[-1])
        self.packed_weight = torch.nn.Parameter(
            torch.ops.mkl._mkl_reorder_linear_weight(
                self.weight.to_mkldnn(), self.batch_size
            ),
            requires_grad=self.weight.requires_grad,
        )

    def forward(self, input):
        y = torch.ops.mkl._mkl_linear(
            input, self.packed_weight, self.weight, self.bias, self.batch_size
        )
        return y


class LinearUnary(nn.Linear):
    def __init__(
        self,
        linear: nn.Module,
        unary: nn.Module,
    ):
        super(LinearUnary, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, unary)

    def _update_module_params(self, linear, unary):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
            unary
        )

    def forward(self, input):
        y = torch.ops.mkldnn._linear_pointwise(
            input, self.weight, self.bias, self.attr, self.scalars, self.algorithm
        )
        return y


class LinearBinary(nn.Linear):
    def __init__(self, linear: nn.Module, binary_op_name: str):
        super(LinearBinary, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, binary_op_name)

    def _update_module_params(self, linear, binary_op_name):
        self.__dict__ = copy.deepcopy(linear.__dict__)

        self.attr = binary_op_name

    def forward(self, input, other):
        y = torch.ops.mkldnn._linear_pointwise(
            input, other, self.weight, self.bias, self.attr
        )
        return y


def packed_conv_eval(conv: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        None,
        input_size,
    )


def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        unary,
        input_size,
    )


def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinary2d(
        conv,
        binary_op_name,
        input_size,
    )


def fused_conv_binary_inplace_eval(
    conv: nn.Module, binary_op_name: str, input_size: list
):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinaryInplace2d(
        conv,
        binary_op_name,
        input_size,
    )


def fused_conv_binary_unary_eval(
    conv_binary: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_binary.training), "Fusion only for eval!"
    # Reuse the original conv module and just update its unary attr.
    conv_binary._update_unary_params(unary)
    return conv_binary


def is_bfloat16_module(m):
    weight_is_bf16 = m.weight.dtype == torch.bfloat16
    bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
    return weight_is_bf16 and bias_is_bf16


def packed_linear_eval(linear: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return PackedLinear(linear, input_size)


def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return LinearUnary(
        linear,
        unary,
    )


def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    linear_binary = LinearBinary(
        linear,
        attr,
    )
    return linear_binary


def check_node_kind(current_node, modules, node_kind):
    if not isinstance(current_node, torch.fx.Node):
        return False
    if current_node.op != "call_module":
        return False
    if not isinstance(current_node.target, str):
        return False
    if current_node.target not in modules:
        return False
    if type(modules[current_node.target]) is not node_kind:
        return False
    return True


def check_node_is_binary(node):
    return (
        (node.op == "call_function" and node.target in [torch.add, torch.sub])
        or (
            node.op == "call_function"
            and node.target
            in [operator.add, operator.iadd, operator.sub, operator.isub]
        )
        or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"])
    )


def check_binary_op_kwargs_is_default(node):
    # For a binary op, we require the kwargs to take their default values:
    # torch.sub/add(input, other, *, alpha=1, out=None).
    if len(node.args) > 2:
        return False
    if len(node.kwargs) > 0:
        if "out" in node.kwargs and node.kwargs["out"] is not None:
            return False
        if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0:
            return False
    return True


def check_node_is_add_inplace(node):
    return (node.op == "call_function" and node.target in [operator.iadd]) or (
        node.op == "call_method" and node.target in ["add_"]
    )


def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    is_cpu = all(
        example_input.device == torch.device("cpu") for example_input in example_inputs
    )

    fake_mode = fake_mode_from_tensors(example_inputs)

    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        gm = linear_permute_fusion(gm)
        gm = permute_linear_fusion(gm)
        gm = permute_matmul_fusion(gm)

    # make sure autograd is disabled.
    if torch.is_grad_enabled():
        return gm
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    if not is_cpu:
        return gm
    gm = remove_identity(gm)
    gm = fuse_conv_bn(gm)
    # For binary fusion, we need to check the inputs' info to make sure
    # the binary inputs have the same tensor info (device, dtype, and layout).

    ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
    gm = fuse_unary(gm)
    gm = fuse_binary_inplace(gm)
    gm = fuse_binary(gm)
    # Re-run fuse_unary to enable conv+binary+unary fusion,
    # such as conv+add+relu for vision models.
    gm = fuse_unary(gm)
    gm = pack_module(gm)
    return gm
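# Illustrative usage sketch (added comment, an assumption rather than a call that exists in
# this file): fuse_fx expects an FX GraphModule in eval mode plus example inputs, and only
# takes effect for CPU inference with grad disabled and mkldnn available, e.g.:
#
#     gm = torch.fx.symbolic_trace(model.eval())
#     with torch.no_grad():
#         gm = fuse_fx(gm, example_inputs)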


# Check whether the (nn.Module, F.function) pattern is matched.
def matches_module_function_pattern(pattern, node, modules):
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function
    if node.op != "call_function":
        return False
    if node.target != pattern[1]:
        return False
    # make sure the output of node.args[0] is only used by the current node.
    if len(node.args[0].users) > 1:
        return False
    return True


def fetch_attr(target: str, mod):
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def remove_identity(gm: torch.fx.GraphModule):
    """
    Removes all identity layers from the module.
    """

    class IdentityRemover(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            if isinstance(self.submodules[target], nn.Identity):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)

    return IdentityRemover(gm).transform()


def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
    """
    Fuses Convolution/BN layers for inference purposes.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())
    for pattern in modules_patterns:
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    for pattern in module_function_patterns:
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue
                bn_args_is_constant = all(
                    n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue
                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()

    return gm


def fuse_unary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())

    for (unary_module, _), (computation_module, fuse_func,) in itertools.product(
        unary_modules_map.items(), computation_op_unary_op_fusion_map.items()
    ):
        pattern = (computation_module, unary_module)
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if (
                    len(node.args[0].users) > 1
                ):  # Output of computation_node is used by other nodes
                    continue
                computation_node = modules[node.args[0].target]
                unary_node = modules[node.target]
                eval_mode = all(not n.training for n in [computation_node, unary_node])
                if not eval_mode:
                    continue
                # TODO: support padding str input("valid", "same").
                if type(computation_node) in [nn.Conv2d] and isinstance(
                    computation_node.padding, str
                ):
                    continue
                # TODO: support more conv+binary+unary fusion.
                if type(computation_node) in [
                    ConvBinary2d,
                    ConvBinaryInplace2d,
                ] and type(unary_node) not in [nn.ReLU]:
                    continue
                # only fuse for linear when the dtype is bf16
                if type(computation_node) in [nn.Linear] and not is_bfloat16_module(
                    computation_node
                ):
                    continue
                computation_node_input_size = (
                    node.args[0].args[0].meta.get("tensor_meta").shape
                )
                fused_module = fuse_func(
                    computation_node, unary_node, computation_node_input_size
                )
                replace_node_module(node.args[0], modules, fused_module)

                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
                gm.graph.lint()
    gm.recompile()
    return gm


def _philox_rand_like_meta(input, seed, offset):
    return _prims.TensorMeta(input)


def _philox_rand_like(input, seed, offset):
    # placeholder only used in tracing
    return torch.rand_like(input)


class NormalizedLinearNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]
        else:
            return self.node.kwargs["input"]

    def get_weight(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]
        else:
            return self.node.kwargs["weight"]

    def get_bias(self) -> torch.fx.Node:
        if len(self.node.args) > 2:
            return self.node.args[2]
        else:
            return self.node.kwargs["bias"]


class NormalizedMatmulNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]
        else:
            return self.node.kwargs["input"]

    def get_other(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]
        else:
            return self.node.kwargs["other"]


def check_permute(node: torch.fx.Node):
    ranks = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]
    elif (
        "permutation" in node.kwargs
        and node.kwargs["permutation"] is not None
        and len(node.kwargs["permutation"]) > 2
    ):
        permutation = [i % ranks for i in node.kwargs["permutation"]]
    else:
        return False
    allowed_permutation = list(range(ranks))
    allowed_permutation[-1] = ranks - 2
    allowed_permutation[-2] = ranks - 1
    return permutation == allowed_permutation
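# Note (added comment): check_permute only matches a permutation that swaps the last two
# dimensions and leaves the rest in place, e.g. x.permute(0, 2, 1) for a rank-3 tensor;
# any other permutation, or a rank-2 input, falls through to the unfused path.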


def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if (
            node.op == "call_method"
            and node.target == "permute"
            and check_permute(node)
        ):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


# Y1 = X * W^T + bias
# Y2 = Y1.permute(0, 2, 1)
# ---->
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1)


def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if node.op == "call_function" and node.target == torch.nn.functional.linear:
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_method"
                and input_node.target == "permute"
                and check_permute(input_node)
            ):
                normalized = NormalizedLinearNode(node)
                if len(input_node.args) > 0:
                    input = input_node.args[0]
                else:
                    input = input_node.kwargs["input"]
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_linear, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if node.op == "call_function" and (
            node.target == torch.bmm or node.target == torch.matmul
        ):
            normalized = NormalizedMatmulNode(node)
            A = normalized.get_input()
            B = normalized.get_other()
            Atrans = Btrans = False
            if A.op == "call_method" and A.target == "permute" and check_permute(A):
                Atrans = True
                if len(A.args) > 0:
                    A = A.args[0]
                else:
                    A = A.kwargs["input"]

            if B.op == "call_method" and B.target == "permute" and check_permute(B):
                Btrans = True
                if len(B.args) > 0:
                    B = B.args[0]
                else:
                    B = B.kwargs["input"]

            if Atrans or Btrans:
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_matmul,
                        args=(A, B, Atrans, Btrans),
                    )
                    node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


# X1 = X.permute(0, 2, 1)
# Y1 = X1 * W1^T + bias1
# ---->
# Y2 = X1.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias


def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool):
    if Atrans:
        A = A.transpose(-1, -2)
    if Btrans:
        B = B.transpose(-1, -2)
    return torch.matmul(A, B)


def replace_and_fuse_for_binary(
    computation_node, node, fuse_func, attr, modules, index_node, index_pointwise
):
    computation_node_input_size = (
        node.args[index_node].args[0].meta.get("tensor_meta").shape
    )
    fused_module = fuse_func(computation_node, attr, computation_node_input_size)
    replace_node_module(node.args[index_node], modules, fused_module)
    node.args[index_node].args = node.args[index_node].args + (
        node.args[index_pointwise],
    )
    node.replace_all_uses_with(node.args[index_node])


def binary_inputs_meta_is_same(binary_node):
    tensor0_meta = binary_node.args[0].meta.get("tensor_meta")
    tensor1_meta = binary_node.args[1].meta.get("tensor_meta")
    if not tensor0_meta or not tensor1_meta:
        return False
    if (
        tensor0_meta.shape != tensor1_meta.shape
        or tensor0_meta.stride != tensor1_meta.stride
        or tensor0_meta.dtype != tensor1_meta.dtype
    ):
        return False

    return True


def fuse_binary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node):
            for node_kind, fuse_func in computation_op_binary_op_fusion_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                attr = binary_attr[node.target]
                index_list = supported_index_list[attr]
                for index_dict in index_list:
                    index_node = index_dict["index_computation"]
                    index_pointwise = index_dict["index_pointwise"]
                    if check_node_kind(node.args[index_node], modules, node_kind):
                        if len(node.args[index_node].users) > 1:
                            continue
                        computation_node = modules[node.args[index_node].target]
                        # TODO: support padding str input("valid", "same").
                        if type(computation_node) in [nn.Conv2d] and isinstance(
                            computation_node.padding, str
                        ):
                            continue
                        # only fuse for linear when the dtype is bf16
                        if type(computation_node) in [
                            nn.Linear
                        ] and not is_bfloat16_module(computation_node):
                            continue
                        replace_and_fuse_for_binary(
                            computation_node,
                            node,
                            fuse_func,
                            attr if attr != "iadd" else "add",
                            modules,
                            index_node,
                            index_pointwise,
                        )
                        # Make sure the fused node is ordered after the nodes
                        # that produce its inputs.
                        node.append(node.args[index_node])
                        gm.graph.erase_node(node)
                        gm.graph.lint()
                        break

    gm.recompile()
    return gm


def fuse_binary_inplace(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node):
            for (
                node_kind,
                fuse_func,
            ) in computation_op_binary_op_fusion_inplace_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                if check_node_kind(node.args[1], modules, node_kind):
                    if len(node.args[1].users) > 1:
                        continue
                    # make sure the output and input are not the same tensor.
                    if node.args[1].args[0] == node.args[0]:
                        continue
                    computation_node = modules[node.args[1].target]
                    # TODO: support padding str input("valid", "same").
                    if type(computation_node) in [nn.Conv2d] and isinstance(
                        computation_node.padding, str
                    ):
                        continue
                    replace_and_fuse_for_binary(
                        computation_node,
                        node,
                        fuse_func,
                        "add",
                        modules,
                        1,  # conv module index
                        0,  # binary op index
                    )
                    # Make sure the fused node is ordered after the nodes
                    # that produce its inputs.
                    node.append(node.args[1])
                    gm.graph.erase_node(node)
                    gm.graph.lint()
                    break

    gm.recompile()
    return gm


def pack_module(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if node.op == "call_module":
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if type(cur_module) in computation_op_packed_map:
                computation_node_input_meta = node.args[0].meta.get("tensor_meta")
                if computation_node_input_meta.dtype != torch.float32:
                    continue
                if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
                    continue
                computation_node_input_size = computation_node_input_meta.shape
                if type(cur_module) in [nn.Conv2d] and isinstance(
                    cur_module.padding, str
                ):
                    continue
                new_module = computation_op_packed_map[type(cur_module)](
                    cur_module, computation_node_input_size
                )
                assert isinstance(new_module, nn.Module)
                replace_node_module(node, modules, new_module)
    gm.graph.lint()
    gm.recompile()
    return gm


philox_rand_like = _prims._make_prim(
    schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor",
    return_type=_prims.RETURN_TYPE.NEW,
    meta=_philox_rand_like_meta,
    impl_aten=_philox_rand_like,
    doc="",
)


def _philox_seed_like_meta(x):
    return _prims.TensorMeta(_philox_seed_like(x))


def _philox_seed_like(x):
    # we need a tensor input here so AOT autograd properly captures this
    # with just a device input, this becomes a constant
    return torch.tensor(random.randrange(2**31), device=x.device, dtype=torch.int32)


philox_seed_like = _prims._make_prim(
    schema="philox_seed_like(Tensor other) -> Tensor",
    return_type=_prims.RETURN_TYPE.NEW,
    meta=_philox_seed_like_meta,
    impl_aten=_philox_seed_like,
    doc="",
)


def null_ref():
    return None


class PhiloxRandomState:
    next_offset = 0
    seed = {}
    last_tracer_ref = null_ref

    @classmethod
    def reset(cls, tracer=None):
        cls.next_offset = 0
        cls.seed = {}
        cls.last_tracer_ref = weakref.ref(tracer) if tracer is not None else null_ref

    @classmethod
    def get_seed_offset(cls, x):
        modes = torch.fx.experimental.proxy_tensor.get_torch_dispatch_modes()
        proxy_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
        if proxy_modes:
            tracer = proxy_modes[0].tracer
            if cls.last_tracer_ref() is not tracer:
                # tracer changed, need to reset state
                cls.reset(tracer)
        else:
            # no tracer, need to reset state
            cls.reset()

        device = x.device
        if device not in cls.seed:
            # Compute the seed just once per trace so that we pass fewer
            # things from forward to backward
            cls.seed[device] = philox_seed_like(x)

        seed = cls.seed[device]
        offset = cls.next_offset
        cls.next_offset += x.numel()
        return seed, offset


class LowmemDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p):
        ctx.p = p
        scale = float(1.0 / (1.0 - p))
        seed, offset = PhiloxRandomState.get_seed_offset(x)
        ctx.save_for_backward(seed)
        ctx.offset = offset
        bool_mask = philox_rand_like(x, seed, offset) > p
        return bool_mask.to(x.dtype) * x * scale

    @staticmethod
    def backward(ctx, grad_output):
        p = ctx.p
        scale = float(1.0 / (1.0 - p))
        (seed,) = ctx.saved_tensors
        bool_mask = philox_rand_like(grad_output, seed, ctx.offset) > p
        return bool_mask.to(grad_output.dtype) * grad_output * scale, None


@torch.fx.wrap
def lowmem_dropout(input, p=0.5, training=True, inplace=False):
    if isinstance(input, torch.fx.Proxy):
        # double check we don't FX trace this
        return input.tracer.create_proxy(
            "call_function",
            lowmem_dropout,
            (input, p, training),
            {},
        )
    if not training or p == 0:
        return input
    result = LowmemDropout.apply(input, p)
    if inplace:
        input.copy_(result)
    return result


@torch.fx.wrap
def rand_like(x, **kwargs):
    if isinstance(x, torch.fx.Proxy):
        # double check we don't FX trace this
        return x.tracer.create_proxy("call_function", rand_like, (x,), kwargs)
    assert kwargs.get("device", x.device) == x.device
    seed, offset = PhiloxRandomState.get_seed_offset(x)
    return philox_rand_like(x, seed, offset).to(kwargs.get("dtype", torch.float32))


replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like}
# Keep track of any replacement functions that use triton random,
# so they can be avoided when fallback_random is set
replacements_using_triton_random = {lowmem_dropout, rand_like}

computation_op_unary_op_fusion_map = {
    nn.Conv2d: fused_conv_unary_eval,
    nn.Linear: fused_linear_unary_eval,
    ConvBinary2d: fused_conv_binary_unary_eval,
    ConvBinaryInplace2d: fused_conv_binary_unary_eval,
}


unary_modules_map = {
    nn.ReLU: UnaryAttr("relu"),
    nn.Sigmoid: UnaryAttr("sigmoid"),
    nn.Tanh: UnaryAttr("tanh"),
    nn.Hardswish: UnaryAttr("hardswish"),
    nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
    nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
    nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.SiLU: UnaryAttr("swish"),
}


binary_attr = {
    torch.add: "add",  # node.op == "call_function"
    "add": "add",  # node.op == "call_method"
    "add_": "iadd",  # node.op == "call_method"
    operator.add: "add",  # node.op == "call_function"
    operator.iadd: "iadd",  # node.op == "call_function"
    torch.sub: "sub",  # node.op == "call_function"
    "sub": "sub",  # node.op == "call_method"
    "sub_": "sub",  # node.op == "call_method"
    operator.sub: "sub",  # node.op == "call_function"
    operator.isub: "sub",  # node.op == "call_function"
}


computation_op_binary_op_fusion_map = {
    nn.Conv2d: fused_conv_binary_eval,
    nn.Linear: fused_linear_binary_eval,
}


computation_op_binary_op_fusion_inplace_map = {
    nn.Conv2d: fused_conv_binary_inplace_eval,
}


computation_op_packed_map = {
    nn.Linear: packed_linear_eval,
    nn.Conv2d: packed_conv_eval,
}


# For add: we support conv/linear + other and other + conv
# For sub/add_/sub_, we only support conv/linear - other
# or conv/linear +(-)= other
supported_index_list = {
    "add": [
        {"index_computation": 0, "index_pointwise": 1},
        {"index_computation": 1, "index_pointwise": 0},
    ],
    "iadd": [{"index_computation": 0, "index_pointwise": 1}],
    "sub": [{"index_computation": 0, "index_pointwise": 1}],
}