mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove test since it ooms on CI (#161644)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161644 Approved by: https://github.com/BoyuanFeng
This commit is contained in:
parent
47ecd2042f
commit
443452ca2f
|
|
@ -48,7 +48,6 @@ from torch.testing._internal.common_device_type import (
|
|||
skipCPUIf,
|
||||
skipCUDAIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_FBCODE
|
||||
from torch.utils._triton import has_triton, has_triton_tma_device
|
||||
|
||||
|
||||
|
|
@ -4340,41 +4339,6 @@ class GraphModule(torch.nn.Module):
|
|||
fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag
|
||||
fa._WARNINGS_SHOWN = original_warnings_shown
|
||||
|
||||
@largeTensorTest("38GB", "cuda") # emperically
|
||||
@skip_on_cpu
|
||||
@unittest.skipIf(IS_FBCODE, "Skip large tensor test in fbcode")
|
||||
def test_int64_indexing_large_stride(self, device):
|
||||
B = 1
|
||||
H = 64
|
||||
S = 2**20
|
||||
D = 64
|
||||
dtype = torch.float16
|
||||
|
||||
def _simple_causal(b, h, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
BLOCK_M = 1024
|
||||
BLOCK_N = 1024
|
||||
|
||||
block_mask = torch.compile(create_block_mask)(
|
||||
_simple_causal, B, H, S, S, device=device, BLOCK_SIZE=(BLOCK_M, BLOCK_N)
|
||||
)
|
||||
|
||||
q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
|
||||
k = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
|
||||
v = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
|
||||
|
||||
# Test forward and backward pass
|
||||
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
|
||||
# Basic correctness checks, doing full comapre consumes too much memory :/
|
||||
self.assertEqual(out.shape, (B, H, S, D))
|
||||
self.assertTrue(q.grad is not None)
|
||||
self.assertTrue(k.grad is not None)
|
||||
self.assertTrue(v.grad is not None)
|
||||
|
||||
|
||||
class TestBlockMask(InductorTestCase):
|
||||
def setUp(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user