Remove test since it OOMs on CI (#161644)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161644
Approved by: https://github.com/BoyuanFeng
Author: drisspg
Date: 2025-08-27 18:56:28 +00:00
Committed by: PyTorch MergeBot
parent 47ecd2042f
commit 443452ca2f

@@ -48,7 +48,6 @@ from torch.testing._internal.common_device_type import (
    skipCPUIf,
    skipCUDAIf,
)
from torch.testing._internal.common_utils import IS_FBCODE
from torch.utils._triton import has_triton, has_triton_tma_device
@@ -4340,41 +4339,6 @@ class GraphModule(torch.nn.Module):
        fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
        fa._WARNINGS_SHOWN = original_warnings_shown
    @largeTensorTest("38GB", "cuda")  # empirically
    @skip_on_cpu
    @unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
    def test_int64_indexing_large_stride(self, device):
        B = 1
        H = 64
        S = 2**20
        D = 64
        dtype = torch.float16
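        # Note: B * H * S * D = 2**32 elements per tensor, so the fp16 q/k/v
        # below are 8 GiB each; hence the largeTensorTest requirement above.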
        def _simple_causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        BLOCK_M = 1024
        BLOCK_N = 1024
        block_mask = torch.compile(create_block_mask)(
            _simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
        )

        q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

        # Test forward and backward pass
        out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
        loss = out.sum()
        loss.backward()

        # Basic correctness checks; a full compare consumes too much memory :/
        self.assertEqual(out.shape, (B, H, S, D))
        self.assertTrue(q.grad is not None)
        self.assertTrue(k.grad is not None)
        self.assertTrue(v.grad is not None)
class TestBlockMask(InductorTestCase):
    def setUp(self):