Revert "Dont decompose aten.baddmm in inductor (#137904)"
This reverts commit 7a117f3b3e.
Reverted https://github.com/pytorch/pytorch/pull/137904 on behalf of https://github.com/clee2000 because the failures from the previous import are unfortunately still present on the current one, D64568703 ([comment](https://github.com/pytorch/pytorch/pull/137904#issuecomment-2422789143))
This commit is contained in:
parent 5a81475884
commit af306a392c
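This revert re-enables inductor's decomposition of aten.baddbmm rather than lowering it to a dedicated Triton template (see the three hunks below). For orientation, aten.baddbmm computes beta * input + alpha * (batch1 @ batch2) over a batch of matrices, which is the identity any decomposition of it relies on. A minimal sketch of that identity using only public APIs (the helper name baddbmm_reference is ours, purely illustrative):

import torch

def baddbmm_reference(inp, batch1, batch2, *, beta=1.0, alpha=1.0):
    # baddbmm semantics: beta * inp + alpha * (batch1 @ batch2), per batch.
    return beta * inp + alpha * torch.bmm(batch1, batch2)

inp = torch.randn(4, 1, 8)   # bias broadcastable over the rows, as in the test below
b1 = torch.randn(4, 16, 8)
b2 = torch.randn(4, 8, 8)
torch.testing.assert_close(
    baddbmm_reference(inp, b1, b2, beta=2.0, alpha=0.5),
    torch.baddbmm(inp, b1, b2, beta=2.0, alpha=0.5),
)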
test/inductor/test_max_autotune.py
@@ -581,32 +581,6 @@ class TestMaxAutotune(TestCase):
     def test_empty_conv_input_with_1x1_kernel(self):
         self.test_empty_conv_input(kernel_size=1)
 
-    @config.patch(max_autotune_gemm_backends="TRITON")
-    def test_baddmm(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.weight = torch.nn.Parameter(
-                    torch.randn(64, 64, 192, dtype=torch.float16)
-                )
-                self.bias = torch.nn.Parameter(
-                    torch.randn(64, 1, 192, dtype=torch.float16)
-                )
-
-            def forward(self, x):
-                return torch.ops.aten.baddbmm.default(self.bias, x, self.weight)
-
-        x = torch.randn(
-            64, 2048, 64, dtype=torch.float16, requires_grad=False, device="cuda"
-        )
-        mod = M().cuda()
-
-        m_c = torch.compile(mode="max-autotune")(mod)
-        out, code = run_and_get_code(m_c, x)
-        self.assertEqual(out, mod(x))
-
-        FileCheck().check("triton_tem_fused_baddbmm").run(code[0])
-
     @config.patch(max_autotune=True)
     def test_conv1x1_with_free_symbols(self):
         """
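The removed test leans on two pytorch testing utilities: run_and_get_code (typically imported from torch._inductor.utils), which returns the compiled module's output together with the generated source, and FileCheck, which asserts that patterns appear in a string. A small sketch of the FileCheck pattern used above, run on an illustrative string rather than real inductor output (the generated variable is our stand-in for code[0]):

from torch.testing import FileCheck

# Stand-in for inductor-generated source; in the test above, code[0]
# returned by run_and_get_code plays this role.
generated = "triton_tem_fused_baddbmm_0 = async_compile.triton(...)"
# check() records a pattern; run() scans the string and raises if absent.
FileCheck().check("triton_tem_fused_baddbmm").run(generated)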
torch/_inductor/decomposition.py
@@ -106,7 +106,6 @@ decomps_to_exclude = [
     aten.squeeze,  # inductor lowers this directly
     aten.sum,  # inductor lowers this directly
     aten.unbind,  # inductor lowers this directly
-    aten.baddbmm,  # upcasts to fp32, perf issue
 ]
 
 remove_decompositions(decompositions, decomps_to_exclude)
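The exclusion deleted here had been keeping baddbmm out of the decomposition table because, per its comment, the decomposition upcasts to fp32. A hypothetical sketch of what an upcasting decomposition looks like and why it hurts fp16 performance; this is our illustration of the concern, not the actual code in torch/_decomp:

import torch

def baddbmm_upcast(inp, batch1, batch2, *, beta=1.0, alpha=1.0):
    # Computing in fp32 doubles every intermediate relative to fp16,
    # adding memory traffic that a fused fp16 kernel would avoid.
    out = torch.bmm(batch1.to(torch.float32), batch2.to(torch.float32))
    return (beta * inp.to(torch.float32) + alpha * out).to(inp.dtype)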
torch/_inductor/kernel/bmm.py
@@ -185,7 +185,8 @@ def tuned_bmm(mat1, mat2, *, layout=None):
     return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
 
 
-@L.register_lowering(aten.baddbmm)
+# Don't register this since it is slower than decomposing it
+# @L.register_lowering(aten.baddbmm)
 def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
     m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
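With the registration commented out again, aten.baddbmm no longer reaches this autotuned lowering and instead goes through the decomposition re-enabled above. Whichever path inductor takes, the compiled result must match eager mode; a quick equivalence check using only public APIs (the function name f is ours):

import torch

def f(inp, b1, b2):
    return torch.baddbmm(inp, b1, b2)

inp = torch.randn(4, 1, 8)
b1 = torch.randn(4, 16, 8)
b2 = torch.randn(4, 8, 8)
# torch.compile routes through inductor by default; decomposed and
# template-lowered paths alike must agree numerically with eager.
torch.testing.assert_close(torch.compile(f)(inp, b1, b2), f(inp, b1, b2))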