Revert "Dont decompose aten.baddmm in inductor (#137904)"

This reverts commit 7a117f3b3e.

Reverted https://github.com/pytorch/pytorch/pull/137904 on behalf of https://github.com/clee2000 due to unfortunately the failures on the previous import are still present on the current one D64568703 ([comment](https://github.com/pytorch/pytorch/pull/137904#issuecomment-2422789143))
This commit is contained in:
PyTorch MergeBot 2024-10-18 16:01:01 +00:00
parent 5a81475884
commit af306a392c
3 changed files with 2 additions and 28 deletions

View File

@ -581,32 +581,6 @@ class TestMaxAutotune(TestCase):
def test_empty_conv_input_with_1x1_kernel(self):
self.test_empty_conv_input(kernel_size=1)
@config.patch(max_autotune_gemm_backends="TRITON")
def test_baddmm(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(64, 64, 192, dtype=torch.float16)
)
self.bias = torch.nn.Parameter(
torch.randn(64, 1, 192, dtype=torch.float16)
)
def forward(self, x):
return torch.ops.aten.baddbmm.default(self.bias, x, self.weight)
x = torch.randn(
64, 2048, 64, dtype=torch.float16, requires_grad=False, device="cuda"
)
mod = M().cuda()
m_c = torch.compile(mode="max-autotune")(mod)
out, code = run_and_get_code(m_c, x)
self.assertEqual(out, mod(x))
FileCheck().check("triton_tem_fused_baddbmm").run(code[0])
@config.patch(max_autotune=True)
def test_conv1x1_with_free_symbols(self):
"""

View File

@ -106,7 +106,6 @@ decomps_to_exclude = [
aten.squeeze, # inductor lowers this directly
aten.sum, # inductor lowers this directly
aten.unbind, # inductor lowers this directly
aten.baddbmm, # upcasts to fp32, perf issue
]
remove_decompositions(decompositions, decomps_to_exclude)

View File

@ -185,7 +185,8 @@ def tuned_bmm(mat1, mat2, *, layout=None):
return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
@L.register_lowering(aten.baddbmm)
# Don't register this since it is slower than decomposing it
# @L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)