diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 94fe34c64e5..8932fcfc4af 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -396,6 +396,39 @@ class TestPatternMatcher(TestPatternMatcherBase):
             matcher_nodes = 1
             self._test_common(mod, (v,), matcher_count, matcher_nodes)
 
+    def test_linear_add_bias(self):
+        class M(torch.nn.Module):
+            def __init__(self, dtype, unary_fn):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 64, bias=False)
+                self.bias = torch.randn(64).to(dtype=dtype)
+                self.unary_fn = unary_fn
+
+            def forward(self, x):
+                x = self.linear(x) + self.bias
+                return self.unary_fn(x)
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(unary_list, dtypes)
+        for unary_fn, dtype in options:
+            metrics.reset()
+            mod = M(dtype, unary_fn).eval()
+            v = torch.randn(2, 10)
+            matcher_count = 3
+            # Add 1 for the weight packing pass and 2 for the bias folding pass.
+            matcher_nodes = unary_list[unary_fn] + 3
+            if self._check_unary_is_decomposed(unary_fn):
+                # Has extra dtype conversion nodes for autocast.
+                matcher_nodes += 2
+            self._test_common(
+                mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index 5d1a723fa58..be73a09ca64 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -788,14 +788,22 @@ if torch._C._has_mkldnn:
         def is_linear_add_bias(match):
             add_node = match.output_node()
             linear_node = add_node.args[0]
-            weight_meta = linear_node.args[1].meta.get("val")
+            packed_weight_node = linear_node.args[1]
+            assert packed_weight_node.name == "_reorder_linear_weight"
+            transpose_weight_node = packed_weight_node.args[0]
+            assert transpose_weight_node.name == "permute_default"
+            weight_meta = transpose_weight_node.args[0].meta.get("val")
+            bias_node = add_node.args[1]
+            if isinstance(bias_node, int):
+                # Only fold the bias when it is a constant tensor, not a scalar int.
+                return False
             bias_meta = add_node.args[1].meta.get("val")
             if weight_meta is None or bias_meta is None:
                 return False
             return (
                 linear_node.args[2] is None
                 and bias_meta.dim() == 1
-                and bias_meta.size(0) == weight_meta.size(0)
+                and bias_meta.size(0) == weight_meta.size(1)
             )
 
         # convert linear+bias to a single linear for applying fusion path.
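
For reference, below is a minimal sketch (not part of the patch) of the graph shape this change targets: a bias-free `torch.nn.Linear` followed by an add with a constant 1-D tensor, mirroring module `M` from the new test. It assumes a CPU build with oneDNN support and uses inductor freezing, since freezing constant-folds the weight transpose and triggers the `permute_default` -> `_reorder_linear_weight` chain that `is_linear_add_bias` now inspects; the class name `LinearAddBias` is purely illustrative.

```python
import torch
import torch._inductor.config as inductor_config

class LinearAddBias(torch.nn.Module):
    """Bias-free linear plus a constant bias tensor, as in the new test's M."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 64, bias=False)
        # 1-D constant whose size matches out_features; with this patch the
        # fusion pass folds this add into the packed linear's bias argument.
        self.bias = torch.randn(64)

    def forward(self, x):
        return torch.relu(self.linear(x) + self.bias)

mod = LinearAddBias().eval()
# Freezing turns parameters into constants, which enables the mkldnn
# weight-packing pass and, with this patch, the linear+bias folding.
with torch.no_grad(), inductor_config.patch(freezing=True):
    compiled = torch.compile(mod)
    out = compiled(torch.randn(2, 10))
print(out.shape)  # torch.Size([2, 64])
```

Note the corrected shape check: after freezing, the constant weight feeding `permute_default` is stored as `(in_features, out_features)`, so the bias must match `weight_meta.size(1)` (the out_features dimension), not `size(0)` as the old code assumed.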