[CUDA][CUDA Graphs] Move cuda graphs test to subprocess to avoid polluting mempool tests (#159305)

Otherwise the mempool tests will fail: the graph capture in this test fails (as intended), but its state in the caching allocator is not fully cleaned up afterwards. See also #159301

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159305
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng, https://github.com/naromero77amd
This commit is contained in:
Eddie Yan 2025-07-30 23:31:38 +00:00 committed by PyTorch MergeBot
parent de7376537f
commit 25c3a7e317

View File

@ -3702,22 +3702,34 @@ exit(2)
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_cuda_graph_tensor_item_not_allowed(self):
# Tesnor.item() calls a synchronize which is not allowed in a cudagraph
# Valid for CUDA and ROCm
def my_func(a: torch.Tensor, b: torch.Tensor, perm: torch.Tensor):
idx = perm[0]
a[0] *= b[idx] # should raise an error during capture
return a
test_script = """\
import torch
import sys
# Tensor.item() calls a synchronize which is not allowed in a cudagraph
# Valid for CUDA and ROCm
def my_func(a: torch.Tensor, b: torch.Tensor, perm: torch.Tensor):
idx = perm[0]
a[0] *= b[idx] # should raise an error during capture
return a
a = torch.rand(500, 500, device="cuda")
b = torch.rand(500, 500, device="cuda")
perm = torch.randint(0, 500, (500,), device="cuda")
a = torch.rand(500, 500, device="cuda")
b = torch.rand(500, 500, device="cuda")
perm = torch.randint(0, 500, (500,), device="cuda")
g = torch.cuda.CUDAGraph()
g = torch.cuda.CUDAGraph()
with self.assertRaises(RuntimeError):
with torch.cuda.graph(g):
output = my_func(a, b, perm)
with torch.cuda.graph(g):
output = my_func(a, b, perm)
"""
with self.assertRaisesRegex(
subprocess.CalledProcessError,
"calls a synchronize which is not allowed in a cudagraph",
):
r = (
subprocess.check_output([sys.executable, "-c", test_script])
.decode("ascii")
.strip()
)
def test_batch_norm_gather_stats(self):
input = torch.randn(1, 3, 3, 3, device="cuda")