[inductor] fix linear_add_bias path (#127597)
Previously, the `linear_add_bias` path did not work. This PR fixes it and adds more unit tests for it.

**Test Plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_add_bias
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127597
Approved by: https://github.com/jgong5, https://github.com/jansel
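For context, here is a minimal sketch (my illustration, not code from the PR) of the pattern this path targets: a bias-free `Linear` followed by adding a constant bias tensor and a unary op. Under `torch.compile` on CPU, Inductor's oneDNN passes can fold the add back into the linear so the usual linear+unary fusion still applies:

```python
# Hedged sketch, not PR code: the module shape the linear_add_bias pass rewrites.
import torch

class LinearAddBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Linear created without a bias; the bias lives as a separate constant.
        self.linear = torch.nn.Linear(10, 64, bias=False)
        self.bias = torch.randn(64)

    def forward(self, x):
        # linear -> add constant bias -> unary op: the fixed path folds the
        # add into the linear so the linear+unary fusion can fire.
        return torch.relu(self.linear(x) + self.bias)

mod = LinearAddBias().eval()
with torch.no_grad():
    out = torch.compile(mod)(torch.randn(2, 10))
```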
parent b42cfcabc4
commit dbf39a6e63
```diff
@@ -396,6 +396,39 @@ class TestPatternMatcher(TestPatternMatcherBase):
                 matcher_nodes = 1
             self._test_common(mod, (v,), matcher_count, matcher_nodes)
 
+    def test_linear_add_bias(self):
+        class M(torch.nn.Module):
+            def __init__(self, dtype, unary_fn):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 64, bias=False)
+                self.bias = torch.randn(64).to(dtype=dtype)
+                self.unary_fn = unary_fn
+
+            def forward(self, x):
+                x = self.linear(x) + self.bias
+                return self.unary_fn(x)
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(unary_list, dtypes)
+        for unary_fn, dtype in options:
+            metrics.reset()
+            mod = M(dtype, unary_fn).eval()
+            v = torch.randn(2, 10)
+            matcher_count = 3
+            # Add 1 for weight packing pass, add 2 for bias folding pass.
+            matcher_nodes = unary_list[unary_fn] + 3
+            if self._check_unary_is_decomposed(unary_fn):
+                # Has extra dtype conversion nodes for autocast.
+                matcher_nodes += 2
+            self._test_common(
+                mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
```
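To reproduce the same accounting outside the test harness, here is a hedged sketch (assuming the `M` module from the test above and a CPU build with bf16 support) that mirrors the test's final assertion via Inductor's kernel-count metric:

```python
# Sketch only: mirrors assertEqual(metrics.generated_kernel_count, 1) above.
import torch
from torch._inductor import metrics

metrics.reset()
mod = M(torch.bfloat16, torch.relu).eval()  # M as defined in the new test
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    torch.compile(mod)(torch.randn(2, 10))
# A single generated kernel indicates the whole pattern fused.
print(metrics.generated_kernel_count)
```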
```diff
@@ -788,14 +788,22 @@ if torch._C._has_mkldnn:
     def is_linear_add_bias(match):
         add_node = match.output_node()
         linear_node = add_node.args[0]
-        weight_meta = linear_node.args[1].meta.get("val")
+        packed_weight_node = linear_node.args[1]
+        assert packed_weight_node.name == "_reorder_linear_weight"
+        transpose_weight_node = packed_weight_node.args[0]
+        assert transpose_weight_node.name == "permute_default"
+        weight_meta = transpose_weight_node.args[0].meta.get("val")
+        bias_node = add_node.args[1]
+        if isinstance(bias_node, int):
+            # we only fold the bias if it is a constant
+            return False
         bias_meta = add_node.args[1].meta.get("val")
         if weight_meta is None or bias_meta is None:
             return False
         return (
             linear_node.args[2] is None
             and bias_meta.dim() == 1
-            and bias_meta.size(0) == weight_meta.size(1)
+            and bias_meta.size(0) == weight_meta.size(0)
         )
 
     # convert linear+bias to a single linear for applying fusion path.
```
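Why the size check changed: the matcher now reads the weight metadata from the node feeding `permute_default`, i.e. the original `(out_features, in_features)` weight rather than the packed node after the `(in_features, out_features)` transpose, so the bias length must be compared against dim 0 instead of dim 1. A standalone sketch of the shape relationship (my illustration, not PR code), using the shapes from the new test's `Linear(10, 64, bias=False)`:

```python
# Both checks identify out_features; only the node the meta is read from changed.
import torch

weight = torch.randn(64, 10)     # nn.Linear weight: (out_features, in_features)
permuted = weight.permute(1, 0)  # what "permute_default" yields: (in, out)
bias = torch.randn(64)

# Old check (against the transposed weight): bias.size(0) == permuted.size(1)
# New check (against the original weight):   bias.size(0) == weight.size(0)
assert bias.dim() == 1
assert bias.size(0) == weight.size(0) == permuted.size(1) == 64
```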