import copy
import itertools
import logging
import operator
import random
import weakref
from typing import Optional

import numpy
import torch
import torch.nn as nn
from torch import _prims
from torch._dynamo.utils import fake_mode_from_tensors
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
from torch.overrides import TorchFunctionMode

from . import config

log = logging.getLogger(__name__)


class AutogradMonkeypatch(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if not kwargs:
            kwargs = {}
        if func in replacements:
            return replacements[func](*args, **kwargs)
        return func(*args, **kwargs)


patch_functions = AutogradMonkeypatch


def replace_fx(gm: torch.fx.GraphModule):
    # Sometimes patch_functions() misses things already in the graph
    for node in reversed(list(gm.graph.nodes)):
        if node.op == "call_function" and node.target in replacements:
            if (
                config.fallback_random
                and replacements[node.target] in replacements_using_triton_random
            ):
                continue
            with gm.graph.inserting_before(node):
                node.replace_all_uses_with(
                    gm.graph.call_function(
                        replacements[node.target], node.args, node.kwargs
                    )
                )
            gm.graph.erase_node(node)
    gm.recompile()
    return gm


class UnaryAttr(object):
    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
        super(UnaryAttr, self).__init__()

    def __call__(self, unary_module: nn.Module):
        if type(unary_module) is nn.ReLU6:
            unary_module = nn.Hardtanh(min_val=0, max_val=6)
        assert all(hasattr(unary_module, item) for item in self.scalars_attr)
        scalars = [getattr(unary_module, item) for item in self.scalars_attr]

        algorithm = ""
        if self.algorithm_attr:
            assert hasattr(unary_module, self.algorithm_attr)
            algorithm = getattr(unary_module, self.algorithm_attr)

        return self.op_name, scalars, algorithm


class ConvUnary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super(ConvUnary2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, unary, input_size)

    def _update_module_params(self, conv, unary, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.attr = "none"
        self.scalars = []
        self.algorithm = ""
        if unary is not None:
            self.attr, self.scalars, self.algorithm = unary_modules_map[
                unary.__class__
            ](unary)
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)
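# NOTE (illustrative sketch, not part of the fusion passes themselves): ConvUnary2d is
# built by the passes below from an eval-mode Conv2d and an optional unary module; the
# input_size comes from shape propagation. The shapes here are made up for the example.
#
#   conv = nn.Conv2d(3, 8, kernel_size=3).eval()
#   fused = ConvUnary2d(conv, nn.ReLU(), input_size=[1, 3, 32, 32])
#   out = fused(torch.randn(1, 3, 32, 32))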
class ConvBinary2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super(ConvBinary2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)


class ConvBinaryInplace2d(nn.Conv2d):
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super(ConvBinaryInplace2d, self).__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise_(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise_(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)
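# NOTE (illustrative): unlike ConvUnary2d, the binary variants take a second tensor in
# forward(input, other) and hand it to the fused mkldnn op. ConvBinaryInplace2d calls
# the underscore-suffixed _convolution_pointwise_, which is expected to write the
# result into `other`; fuse_binary_inplace below therefore checks that `other` does
# not alias the convolution's own input before selecting this variant.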
class PackedLinear(nn.Linear):
    def __init__(self, linear: nn.Module, input_size: list):
        super(PackedLinear, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, input_size)

    def _update_module_params(self, linear, input_size):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.batch_size = int(numpy.prod(input_size) / input_size[-1])
        self.packed_weight = torch.nn.Parameter(
            torch.ops.mkl._mkl_reorder_linear_weight(
                self.weight.to_mkldnn(), self.batch_size
            ),
            requires_grad=self.weight.requires_grad,
        )

    def forward(self, input):
        y = torch.ops.mkl._mkl_linear(
            input, self.packed_weight, self.weight, self.bias, self.batch_size
        )
        return y


class LinearUnary(nn.Linear):
    def __init__(
        self,
        linear: nn.Module,
        unary: nn.Module,
    ):
        super(LinearUnary, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, unary)

    def _update_module_params(self, linear, unary):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
            unary
        )

    def forward(self, input):
        y = torch.ops.mkldnn._linear_pointwise(
            input, self.weight, self.bias, self.attr, self.scalars, self.algorithm
        )
        return y


class LinearBinary(nn.Linear):
    def __init__(self, linear: nn.Module, binary_op_name: str):
        super(LinearBinary, self).__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, binary_op_name)

    def _update_module_params(self, linear, binary_op_name):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr = binary_op_name

    def forward(self, input, other):
        y = torch.ops.mkldnn._linear_pointwise(
            input, other, self.weight, self.bias, self.attr
        )
        return y
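# NOTE (illustrative): PackedLinear precomputes
#   batch_size = prod(input_size) / input_size[-1],
# e.g. an input of size [8, 197, 768] gives batch_size = 8 * 197 = 1576.
# LinearUnary resolves its fusion attributes through unary_modules_map, defined at the
# bottom of this file; for example, an eval-mode nn.Linear followed by
# nn.GELU(approximate="tanh") maps to UnaryAttr("gelu", algorithm_attr="approximate"),
# so the fused module calls torch.ops.mkldnn._linear_pointwise(..., "gelu", [], "tanh").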
def packed_conv_eval(conv: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        None,
        input_size,
    )


def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        unary,
        input_size,
    )


def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinary2d(
        conv,
        binary_op_name,
        input_size,
    )


def fused_conv_binary_inplace_eval(
    conv: nn.Module, binary_op_name: str, input_size: list
):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinaryInplace2d(
        conv,
        binary_op_name,
        input_size,
    )


def fused_conv_binary_unary_eval(
    conv_binary: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_binary.training), "Fusion only for eval!"
    # reuse the original conv module and just update its unary attr.
    conv_binary._update_unary_params(unary)
    return conv_binary


def is_bfloat16_module(m):
    weight_is_bf16 = m.weight.dtype == torch.bfloat16
    bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
    return weight_is_bf16 and bias_is_bf16


def packed_linear_eval(linear: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return PackedLinear(linear, input_size)


def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return LinearUnary(
        linear,
        unary,
    )


def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    linear_binary = LinearBinary(
        linear,
        attr,
    )
    return linear_binary


def check_node_kind(current_node, modules, node_kind):
    if not isinstance(current_node, torch.fx.Node):
        return False
    if current_node.op != "call_module":
        return False
    if not isinstance(current_node.target, str):
        return False
    if current_node.target not in modules:
        return False
    if type(modules[current_node.target]) is not node_kind:
        return False
    return True


def check_node_is_binary(node):
    return (
        (node.op == "call_function" and node.target in [torch.add, torch.sub])
        or (
            node.op == "call_function"
            and node.target
            in [operator.add, operator.iadd, operator.sub, operator.isub]
        )
        or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"])
    )


def check_binary_op_kwargs_is_default(node):
    # For binary ops, we expect the kwargs to have their default values:
    # torch.sub(add)(input, other, *, alpha=1, out=None).
    if len(node.args) > 2:
        return False
    if len(node.kwargs) > 0:
        if "out" in node.kwargs and node.kwargs["out"] is not None:
            return False
        if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0:
            return False
    return True


def check_node_is_add_inplace(node):
    return (node.op == "call_function" and node.target in [operator.iadd]) or (
        node.op == "call_method" and node.target in ["add_"]
    )


def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    is_cpu = all(
        example_input.device == torch.device("cpu")
        for example_input in example_inputs
    )
    fake_mode = fake_mode_from_tensors(example_inputs)

    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform the proper permutation/transpose.
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        gm = linear_permute_fusion(gm)
        gm = permute_linear_fusion(gm)
        gm = permute_matmul_fusion(gm)

    # make sure autograd is disabled.
    if torch.is_grad_enabled():
        return gm
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    if not is_cpu:
        return gm
    gm = remove_identity(gm)
    gm = fuse_conv_bn(gm)
    # For binary fusion, we need to check input info to make sure
    # the binary inputs have the same tensor info (device, dtype, and layout).
    ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
    gm = fuse_unary(gm)
    gm = fuse_binary_inplace(gm)
    gm = fuse_binary(gm)
    # Why re-run fuse_unary? We want to enable conv+binary+unary fusion,
    # such as conv+add+relu for vision models.
    gm = fuse_unary(gm)
    gm = pack_module(gm)
    return gm
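# NOTE (illustrative usage sketch, assuming an FX-traceable eval-mode model; the
# torchvision model is a hypothetical example):
#
#   model = torchvision.models.resnet18().eval()
#   gm = torch.fx.symbolic_trace(model)
#   example_inputs = [torch.randn(1, 3, 224, 224)]
#   with torch.no_grad():
#       gm = fuse_fx(gm, example_inputs)
#
# The mkldnn part of the pass is a no-op unless grad is disabled, mkldnn is available,
# and all example inputs are on CPU (see the early returns above).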
# Check that the pattern (nn.Module, F.function) is matched.
def matches_module_function_pattern(pattern, node, modules):
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function
    if node.op != "call_function":
        return False
    if node.target != pattern[1]:
        return False
    # make sure the output of node.args[0] is only used by the current node.
    if len(node.args[0].users) > 1:
        return False
    return True


def fetch_attr(target: str, mod):
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def remove_identity(gm: torch.fx.GraphModule):
    """
    Removes all identity layers from the module.
    """

    class IdentityRemover(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            if isinstance(self.submodules[target], nn.Identity):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)

    return IdentityRemover(gm).transform()


def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
    """
    Fuses Convolution/BN layers for inference purposes.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())
    for pattern in modules_patterns:
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:
                    # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    for pattern in module_function_patterns:
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue
                bn_args_is_constant = all(
                    n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue
                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def fuse_unary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for (unary_module, _), (computation_module, fuse_func,) in itertools.product(
        unary_modules_map.items(), computation_op_unary_op_fusion_map.items()
    ):
        pattern = (computation_module, unary_module)
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if (
                    len(node.args[0].users) > 1
                ):  # Output of computation_node is used by other nodes
                    continue
                computation_node = modules[node.args[0].target]
                unary_node = modules[node.target]
                eval_mode = all(not n.training for n in [computation_node, unary_node])
                if not eval_mode:
                    continue
                # TODO: support padding str input ("valid", "same").
                if type(computation_node) in [nn.Conv2d] and isinstance(
                    computation_node.padding, str
                ):
                    continue
                # TODO: support more conv+binary+unary fusion.
                if type(computation_node) in [
                    ConvBinary2d,
                    ConvBinaryInplace2d,
                ] and type(unary_node) not in [nn.ReLU]:
                    continue
                # only fuse for linear when the dtype is bf16
                if type(computation_node) in [nn.Linear] and not is_bfloat16_module(
                    computation_node
                ):
                    continue
                computation_node_input_size = (
                    node.args[0].args[0].meta.get("tensor_meta").shape
                )
                fused_module = fuse_func(
                    computation_node, unary_node, computation_node_input_size
                )
                replace_node_module(node.args[0], modules, fused_module)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
                gm.graph.lint()
    gm.recompile()
    return gm
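# NOTE (illustrative): fuse_unary rewrites module pairs such as
#   conv2d -> relu        into a single ConvUnary2d call, and
#   conv_binary -> relu   into the same ConvBinary2d with its unary attr updated,
# iterating over the product of unary_modules_map and
# computation_op_unary_op_fusion_map defined at the end of this file.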
def _philox_rand_like_meta(input, seed, offset):
    return _prims.TensorMeta(input)


def _philox_rand_like(input, seed, offset):
    # placeholder only used in tracing
    return torch.rand_like(input)


class NormalizedLinearNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]
        else:
            return self.node.kwargs["input"]

    def get_weight(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]
        else:
            return self.node.kwargs["weight"]

    def get_bias(self) -> torch.fx.Node:
        if len(self.node.args) > 2:
            return self.node.args[2]
        else:
            return self.node.kwargs["bias"]


class NormalizedMatmulNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]
        else:
            return self.node.kwargs["input"]

    def get_other(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]
        else:
            return self.node.kwargs["other"]


def check_permute(node: torch.fx.Node):
    ranks = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]
    elif (
        "permutation" in node.kwargs
        and node.kwargs["permutation"] is not None
        and len(node.kwargs["permutation"]) > 2
    ):
        permutation = [i % ranks for i in node.kwargs["permutation"]]
    else:
        return False
    allowed_permutation = list(range(ranks))
    allowed_permutation[-1] = ranks - 2
    allowed_permutation[-2] = ranks - 1
    return permutation == allowed_permutation


def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if (
            node.op == "call_method"
            and node.target == "permute"
            and check_permute(node)
        ):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


# Y1 = X * W^T + bias
# Y2 = Y1.permute(0, 2, 1)
# ---->
# Y2 = (W * X^T + bias.unsqueeze(-1))^T
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1)
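# Quick numerical check of the identity above (illustrative only, shapes are arbitrary):
#   x = torch.randn(2, 4, 5); w = torch.randn(3, 5); b = torch.randn(3)
#   y1 = torch.nn.functional.linear(x, w, b).permute(0, 2, 1)
#   y2 = linear_transpose(x, w, b)
#   torch.testing.assert_close(y1, y2)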
def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if node.op == "call_function" and node.target == torch.nn.functional.linear:
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_method"
                and input_node.target == "permute"
                and check_permute(input_node)
            ):
                normalized = NormalizedLinearNode(node)
                if len(input_node.args) > 0:
                    input = input_node.args[0]
                else:
                    input = input_node.kwargs["input"]
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_linear, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.nodes:
        if node.op == "call_function" and (
            node.target == torch.bmm or node.target == torch.matmul
        ):
            normalized = NormalizedMatmulNode(node)
            A = normalized.get_input()
            B = normalized.get_other()
            Atrans = Btrans = False
            if A.op == "call_method" and A.target == "permute" and check_permute(A):
                Atrans = True
                if len(A.args) > 0:
                    A = A.args[0]
                else:
                    A = A.kwargs["input"]
            if B.op == "call_method" and B.target == "permute" and check_permute(B):
                Btrans = True
                if len(B.args) > 0:
                    B = B.args[0]
                else:
                    B = B.kwargs["input"]
            if Atrans or Btrans:
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        transpose_matmul,
                        args=(A, B, Atrans, Btrans),
                    )
                node.replace_all_uses_with(fused_node)

    module.graph.lint()
    module.graph.eliminate_dead_code()
    module.recompile()
    return module


# X1 = X.permute(0, 2, 1)
# Y1 = X1 * W1^T + bias1
# ---->
# Y2 = X.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias


def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool):
    if Atrans:
        A = A.transpose(-1, -2)
    if Btrans:
        B = B.transpose(-1, -2)
    return torch.matmul(A, B)
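# Similarly, transpose_linear and transpose_matmul fold a permute of the last two
# dimensions into the matmul itself, e.g. (illustrative check, arbitrary shapes):
#   a = torch.randn(2, 3, 4); b = torch.randn(2, 3, 5)
#   torch.testing.assert_close(
#       torch.matmul(a.permute(0, 2, 1), b), transpose_matmul(a, b, True, False)
#   )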
def replace_and_fuse_for_binary(
    computation_node, node, fuse_func, attr, modules, index_node, index_pointwise
):
    computation_node_input_size = (
        node.args[index_node].args[0].meta.get("tensor_meta").shape
    )
    fused_module = fuse_func(computation_node, attr, computation_node_input_size)
    replace_node_module(node.args[index_node], modules, fused_module)
    node.args[index_node].args = node.args[index_node].args + (
        node.args[index_pointwise],
    )
    node.replace_all_uses_with(node.args[index_node])


def binary_inputs_meta_is_same(binary_node):
    tensor0_meta = binary_node.args[0].meta.get("tensor_meta")
    tensor1_meta = binary_node.args[1].meta.get("tensor_meta")
    if not tensor0_meta or not tensor1_meta:
        return False
    if (
        tensor0_meta.shape != tensor1_meta.shape
        or tensor0_meta.stride != tensor1_meta.stride
        or tensor0_meta.dtype != tensor1_meta.dtype
    ):
        return False
    return True


def fuse_binary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node):
            for node_kind, fuse_func in computation_op_binary_op_fusion_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                attr = binary_attr[node.target]
                index_list = supported_index_list[attr]
                for index_dict in index_list:
                    index_node = index_dict["index_computation"]
                    index_pointwise = index_dict["index_pointwise"]
                    if check_node_kind(node.args[index_node], modules, node_kind):
                        if len(node.args[index_node].users) > 1:
                            continue
                        computation_node = modules[node.args[index_node].target]
                        # TODO: support padding str input ("valid", "same").
                        if type(computation_node) in [nn.Conv2d] and isinstance(
                            computation_node.padding, str
                        ):
                            continue
                        # only fuse for linear when the dtype is bf16
                        if type(computation_node) in [
                            nn.Linear
                        ] and not is_bfloat16_module(computation_node):
                            continue
                        replace_and_fuse_for_binary(
                            computation_node,
                            node,
                            fuse_func,
                            attr if attr != "iadd" else "add",
                            modules,
                            index_node,
                            index_pointwise,
                        )
                        # Make sure the fused node comes after the node's input nodes.
                        node.append(node.args[index_node])
                        gm.graph.erase_node(node)
                        gm.graph.lint()
                        break
    gm.recompile()
    return gm


def fuse_binary_inplace(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node):
            for (
                node_kind,
                fuse_func,
            ) in computation_op_binary_op_fusion_inplace_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                if check_node_kind(node.args[1], modules, node_kind):
                    if len(node.args[1].users) > 1:
                        continue
                    # make sure the output and input are not the same tensor.
                    if node.args[1].args[0] == node.args[0]:
                        continue
                    computation_node = modules[node.args[1].target]
                    # TODO: support padding str input ("valid", "same").
                    if type(computation_node) in [nn.Conv2d] and isinstance(
                        computation_node.padding, str
                    ):
                        continue
                    replace_and_fuse_for_binary(
                        computation_node,
                        node,
                        fuse_func,
                        "add",
                        modules,
                        1,  # conv module index
                        0,  # binary op index
                    )
                    # Make sure the fused node comes after the node's input nodes.
                    node.append(node.args[1])
                    gm.graph.erase_node(node)
                    gm.graph.lint()
                    break
    gm.recompile()
    return gm
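# NOTE (illustrative): taken together, fuse_binary and fuse_binary_inplace cover
# patterns along the lines of
#   conv(x) + y, y + conv(x), conv(x) - y   (out-of-place, see supported_index_list)
#   y += conv(x) / y.add_(conv(x))          (in-place, conv must be node.args[1])
# provided both operands have matching shape, stride, and dtype metadata.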
def pack_module(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if node.op == "call_module":
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if type(cur_module) in computation_op_packed_map:
                computation_node_input_meta = node.args[0].meta.get("tensor_meta")
                if computation_node_input_meta.dtype != torch.float32:
                    continue
                if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
                    continue
                computation_node_input_size = computation_node_input_meta.shape
                if type(cur_module) in [nn.Conv2d] and isinstance(
                    cur_module.padding, str
                ):
                    continue
                new_module = computation_op_packed_map[type(cur_module)](
                    cur_module, computation_node_input_size
                )
                assert isinstance(new_module, nn.Module)
                replace_node_module(node, modules, new_module)
    gm.graph.lint()
    gm.recompile()
    return gm


philox_rand_like = _prims._make_prim(
    schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor",
    return_type=_prims.RETURN_TYPE.NEW,
    meta=_philox_rand_like_meta,
    impl_aten=_philox_rand_like,
    doc="",
)


def _philox_seed_like_meta(x):
    return _prims.TensorMeta(_philox_seed_like(x))


def _philox_seed_like(x):
    # we need a tensor input here so AOT autograd properly captures this;
    # with just a device input, this becomes a constant
    return torch.tensor(random.randrange(2**31), device=x.device, dtype=torch.int32)


philox_seed_like = _prims._make_prim(
    schema="philox_seed_like(Tensor other) -> Tensor",
    return_type=_prims.RETURN_TYPE.NEW,
    meta=_philox_seed_like_meta,
    impl_aten=_philox_seed_like,
    doc="",
)


def null_ref():
    return None


class PhiloxRandomState:
    next_offset = 0
    seed = {}
    last_tracer_ref = null_ref

    @classmethod
    def reset(cls, tracer=None):
        cls.next_offset = 0
        cls.seed = {}
        cls.last_tracer_ref = weakref.ref(tracer) if tracer is not None else null_ref

    @classmethod
    def get_seed_offset(cls, x):
        modes = torch.fx.experimental.proxy_tensor.get_torch_dispatch_modes()
        proxy_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
        if proxy_modes:
            tracer = proxy_modes[0].tracer
            if cls.last_tracer_ref() is not tracer:
                # tracer changed, need to reset state
                cls.reset(tracer)
        else:
            # no tracer, need to reset state
            cls.reset()

        device = x.device
        if device not in cls.seed:
            # Compute the seed just once per trace so that we pass fewer
            # things from forward to backward
            cls.seed[device] = philox_seed_like(x)

        seed = cls.seed[device]
        offset = cls.next_offset
        cls.next_offset += x.numel()
        return seed, offset


class LowmemDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p):
        ctx.p = p
        scale = float(1.0 / (1.0 - p))
        seed, offset = PhiloxRandomState.get_seed_offset(x)
        ctx.save_for_backward(seed)
        ctx.offset = offset
        bool_mask = philox_rand_like(x, seed, offset) > p
        return bool_mask.to(x.dtype) * x * scale

    @staticmethod
    def backward(ctx, grad_output):
        p = ctx.p
        scale = float(1.0 / (1.0 - p))
        (seed,) = ctx.saved_tensors
        bool_mask = philox_rand_like(grad_output, seed, ctx.offset) > p
        return bool_mask.to(grad_output.dtype) * grad_output * scale, None


@torch.fx.wrap
def lowmem_dropout(input, p=0.5, training=True, inplace=False):
    if isinstance(input, torch.fx.Proxy):
        # double check we don't FX trace this
        return input.tracer.create_proxy(
            "call_function",
            lowmem_dropout,
            (input, p, training),
            {},
        )
    if not training or p == 0:
        return input
    result = LowmemDropout.apply(input, p)
    if inplace:
        input.copy_(result)
    return result
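# NOTE (illustrative): LowmemDropout trades compute for memory: instead of saving the
# dropout mask for backward, it saves only the Philox seed tensor and integer offset
# and regenerates the same mask via philox_rand_like in backward(). lowmem_dropout is
# the drop-in replacement for torch.nn.functional.dropout registered in `replacements`
# below.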
@torch.fx.wrap
def rand_like(x, **kwargs):
    if isinstance(x, torch.fx.Proxy):
        # double check we don't FX trace this
        return x.tracer.create_proxy("call_function", rand_like, (x,), kwargs)
    assert kwargs.get("device", x.device) == x.device
    seed, offset = PhiloxRandomState.get_seed_offset(x)
    return philox_rand_like(x, seed, offset).to(kwargs.get("dtype", torch.float32))


replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like}
# Keep track of any replacement functions that use triton random,
# so they can be avoided when fallback_random is set
replacements_using_triton_random = {lowmem_dropout, rand_like}

computation_op_unary_op_fusion_map = {
    nn.Conv2d: fused_conv_unary_eval,
    nn.Linear: fused_linear_unary_eval,
    ConvBinary2d: fused_conv_binary_unary_eval,
    ConvBinaryInplace2d: fused_conv_binary_unary_eval,
}


unary_modules_map = {
    nn.ReLU: UnaryAttr("relu"),
    nn.Sigmoid: UnaryAttr("sigmoid"),
    nn.Tanh: UnaryAttr("tanh"),
    nn.Hardswish: UnaryAttr("hardswish"),
    nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
    nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
    nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.SiLU: UnaryAttr("swish"),
}

binary_attr = {
    torch.add: "add",  # node.op == "call_function"
    "add": "add",  # node.op == "call_method"
    "add_": "iadd",  # node.op == "call_method"
    operator.add: "add",  # node.op == "call_function"
    operator.iadd: "iadd",  # node.op == "call_function"
    torch.sub: "sub",  # node.op == "call_function"
    "sub": "sub",  # node.op == "call_method"
    "sub_": "sub",  # node.op == "call_method"
    operator.sub: "sub",  # node.op == "call_function"
    operator.isub: "sub",  # node.op == "call_function"
}


computation_op_binary_op_fusion_map = {
    nn.Conv2d: fused_conv_binary_eval,
    nn.Linear: fused_linear_binary_eval,
}


computation_op_binary_op_fusion_inplace_map = {
    nn.Conv2d: fused_conv_binary_inplace_eval,
}


computation_op_packed_map = {
    nn.Linear: packed_linear_eval,
    nn.Conv2d: packed_conv_eval,
}


# For "add": we support conv/linear + other and other + conv/linear.
# For "sub"/"add_"/"sub_": we only support conv/linear - other,
# or conv/linear +(-)= other.
supported_index_list = {
    "add": [
        {"index_computation": 0, "index_pointwise": 1},
        {"index_computation": 1, "index_pointwise": 0},
    ],
    "iadd": [{"index_computation": 0, "index_pointwise": 1}],
    "sub": [{"index_computation": 0, "index_pointwise": 1}],
}
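# NOTE (illustrative): `replacements` is consumed in two ways. At trace time,
# `with patch_functions():` (i.e. an AutogradMonkeypatch torch-function mode) reroutes
# torch.nn.functional.dropout and torch.rand_like to lowmem_dropout / rand_like, and
# replace_fx() performs the same substitution on call_function nodes already present
# in an FX graph, skipping the triton-random-based replacements when
# config.fallback_random is set.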