Revert "[cutlass backend] fix assertion that prevent self multiplication (#148233)"

This reverts commit 4aeca28137.

Reverted https://github.com/pytorch/pytorch/pull/148233 on behalf of https://github.com/henrylhtsang due to a mistake in the PR ([comment](https://github.com/pytorch/pytorch/pull/148233#issuecomment-2704534995))
This commit is contained in:
PyTorch MergeBot 2025-03-06 17:45:49 +00:00
parent 3cde4c3069
commit 28b68b46bc
2 changed files with 2 additions and 34 deletions

View File

@@ -1139,26 +1139,6 @@ class TestCutlassBackend(TestCase):
num_ops = int(match.group(1))
self.assertTrue(num_ops > 0, "The number of ops should be greater than 0")
# NOTE(review): the lines below are the *removed* side of this revert's diff —
# the rendering has stripped Python indentation, so they are shown flat.
# The test compiled torch.mm with the CUTLASS max-autotune backend and
# compared the result of multiplying a tensor by its own transpose against
# the eager A @ A.t() result.
@unittest.skipIf(not SM90OrLater, "need sm_90")
# Drop sccache from PATH so autotuning compiles are not cache-served
# (presumably to keep profiling results honest — confirm against the helper).
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_matmul_same_tensor(self):
"""Autotune a square matmul (A @ A.t()) through the CUTLASS backend only."""
max_autotune_gemm_backends = "CUTLASS"
# 128x128 half-precision CUDA tensor; both operands of the mm share storage.
M = 128
A = torch.randn(M, M).cuda().half()
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
# Limit profiling to 2 configs to keep the test fast.
"cuda.cutlass_max_profiling_configs": 2,
# Fail loudly if no CUTLASS kernel is usable instead of falling back.
"autotune_fallback_to_aten": False,
}
):
compiled = torch.compile(torch.mm)
# Compiled result must match eager for the self-multiplication case.
torch.testing.assert_close(A @ A.t(), compiled(A, A.t()))
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

View File

@@ -86,20 +86,8 @@ class CUDAKernel(Kernel):
matches = [
arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
]
if len(matches) > 1:
# Verify all matches have the same node, attribute, and dimension
# And if they come from the same node, whichever symbol we use is fine.
# if in runtime the logic changes, this would trigger guard
first_match = matches[0]
if not all(
match.node == first_match.node
and match.attr == first_match.attr
and match.dim == first_match.dim
for match in matches
):
raise AssertionError("All matching layout args should be identical")
return first_match
return None
assert len(matches) <= 1, matches
return None if len(matches) == 0 else matches[0]
def add_layout_arg(
self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int