[inductor] fix linear_add_bias path (#127597)

Previous the `linear_add_bias` path do not work.
This PR is to fix it and add more ut with it.

**TestPlan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_add_bias
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127597
Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
haozhe.zhu 2024-06-02 19:25:42 +08:00 committed by PyTorch MergeBot
parent b42cfcabc4
commit dbf39a6e63
2 changed files with 43 additions and 2 deletions

View File

@ -396,6 +396,39 @@ class TestPatternMatcher(TestPatternMatcherBase):
matcher_nodes = 1 matcher_nodes = 1
self._test_common(mod, (v,), matcher_count, matcher_nodes) self._test_common(mod, (v,), matcher_count, matcher_nodes)
def test_linear_add_bias(self):
class M(torch.nn.Module):
def __init__(self, dtype, unary_fn):
super().__init__()
self.linear = torch.nn.Linear(10, 64, bias=False)
self.bias = torch.randn(64).to(dtype=dtype)
self.unary_fn = unary_fn
def forward(self, x):
x = self.linear(x) + self.bias
return self.unary_fn(x)
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
options = itertools.product(unary_list, dtypes)
for unary_fn, dtype in options:
metrics.reset()
mod = M(dtype, unary_fn).eval()
v = torch.randn(2, 10)
matcher_count = 3
# Add 1 for weight packing pass, add 2 for bias folding pass.
matcher_nodes = unary_list[unary_fn] + 3
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
matcher_nodes += 2
self._test_common(
mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
)
self.assertEqual(metrics.generated_kernel_count, 1)
@skipIfNoDynamoSupport @skipIfNoDynamoSupport
@skipIfNoONEDNN @skipIfNoONEDNN
@skipIfRocm @skipIfRocm

View File

@ -788,14 +788,22 @@ if torch._C._has_mkldnn:
def is_linear_add_bias(match): def is_linear_add_bias(match):
add_node = match.output_node() add_node = match.output_node()
linear_node = add_node.args[0] linear_node = add_node.args[0]
weight_meta = linear_node.args[1].meta.get("val") packed_weight_node = linear_node.args[1]
assert packed_weight_node.name == "_reorder_linear_weight"
transpose_weight_node = packed_weight_node.args[0]
assert transpose_weight_node.name == "permute_default"
weight_meta = transpose_weight_node.args[0].meta.get("val")
bias_node = add_node.args[1]
if isinstance(bias_node, int):
# we only folding bias if it is a constant
return False
bias_meta = add_node.args[1].meta.get("val") bias_meta = add_node.args[1].meta.get("val")
if weight_meta is None or bias_meta is None: if weight_meta is None or bias_meta is None:
return False return False
return ( return (
linear_node.args[2] is None linear_node.args[2] is None
and bias_meta.dim() == 1 and bias_meta.dim() == 1
and bias_meta.size(0) == weight_meta.size(0) and bias_meta.size(0) == weight_meta.size(1)
) )
# convert linear+bias to a single linear for applying fusion path. # convert linear+bias to a single linear for applying fusion path.