Revert "[cutlass backend] fix assertion that prevent self multiplication (#148233)"
This reverts commit 4aeca28137.
Reverted https://github.com/pytorch/pytorch/pull/148233 on behalf of https://github.com/henrylhtsang due to mistake in PR ([comment](https://github.com/pytorch/pytorch/pull/148233#issuecomment-2704534995))
parent 3cde4c3069
commit 28b68b46bc
@@ -1139,26 +1139,6 @@ class TestCutlassBackend(TestCase):
             num_ops = int(match.group(1))
             self.assertTrue(num_ops > 0, "The number of ops should be greater than 0")
 
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_cutlass_backend_matmul_same_tensor(self):
-        max_autotune_gemm_backends = "CUTLASS"
-
-        M = 128
-        A = torch.randn(M, M).cuda().half()
-
-        with config.patch(
-            {
-                "max_autotune": True,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 2,
-                "autotune_fallback_to_aten": False,
-            }
-        ):
-            compiled = torch.compile(torch.mm)
-
-            torch.testing.assert_close(A @ A.t(), compiled(A, A.t()))
-
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu
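Note on the removed test's name: in torch.mm(A, A.t()) both operands are views of one underlying tensor, so the backend sees the same buffer on both sides of the GEMM. A minimal eager-mode sketch of that aliasing (plain PyTorch, no GPU or CUTLASS needed):

import torch

M = 128
A = torch.randn(M, M)

# t() swaps strides but allocates no new storage, so both mm()
# operands alias the same buffer: the "same tensor" case.
assert A.t().data_ptr() == A.data_ptr()

# The removed test compared this eager result against the compiled output.
ref = A @ A.t()
torch.testing.assert_close(ref, torch.mm(A, A.t()))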
@@ -86,20 +86,8 @@ class CUDAKernel(Kernel):
         matches = [
             arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)
         ]
-        if len(matches) > 1:
-            # Verify that all matches refer to the same node, attribute, and dimension;
-            # when they come from the same node, whichever symbol we use is fine.
-            # If the logic changes at runtime, this would trigger a guard.
-            first_match = matches[0]
-            if not all(
-                match.node == first_match.node
-                and match.attr == first_match.attr
-                and match.dim == first_match.dim
-                for match in matches
-            ):
-                raise AssertionError("All matching layout args should be identical")
-            return first_match
-        return None
+        assert len(matches) <= 1, matches
+        return None if len(matches) == 0 else matches[0]
 
     def add_layout_arg(
         self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
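The restored assert len(matches) <= 1 is exactly what self multiplication can trip: when both GEMM operands resolve to the same node, two registered layout args can match one (node, attr, dim) query. Below is a toy, self-contained model of that lookup; LayoutArg, layout_args, and the string keys are hypothetical stand-ins for illustration, not Inductor's real classes:

from dataclasses import dataclass

@dataclass(frozen=True)
class LayoutArg:  # hypothetical stand-in for Inductor's layout-arg record
    node: str
    attr: str
    dim: int

    def matches(self, node: str, attr: str, dim: int) -> bool:
        return (self.node, self.attr, self.dim) == (node, attr, dim)

# With A and A.t() feeding one mm(), both operands can end up
# registering layout symbols against the same node:
layout_args = {
    "lda": LayoutArg("buf_A", "stride", 0),
    "ldb": LayoutArg("buf_A", "stride", 0),  # duplicate for the self-multiply
}

matches = [arg for arg in layout_args.values() if arg.matches("buf_A", "stride", 0)]
try:
    assert len(matches) <= 1, matches  # the restored strict invariant
except AssertionError as exc:
    print("duplicate layout-arg matches:", exc)

The reverted change had relaxed this to accept duplicates that agree on node, attribute, and dimension; the revert reinstates the strict single-match invariant until a corrected fix lands.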