[CUDA][CUDA Graphs] Move cuda graphs test to subprocess to avoid polluting mempool tests (#159305)
Otherwise the mempool test will fail, because the previous graph capture failed but its state in the caching allocator was never fully cleaned up. See also #159301.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159305
Approved by: https://github.com/eellison, https://github.com/BoyuanFeng, https://github.com/naromero77amd
This commit is contained in:
parent de7376537f
commit 25c3a7e317
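For illustration, a minimal sketch of the isolation pattern this change relies on, assuming a CUDA-capable build; the child script body and variable names here are illustrative, and the actual test code is in the diff below. Running the capture that is expected to fail inside a child interpreter keeps any half-torn-down graph/allocator state out of the parent process, so later mempool tests in the same process start from a clean caching allocator.

# Sketch (assumed setup): run a deliberately failing CUDA graph capture in a
# child process so the parent's CUDA caching allocator is not polluted.
import subprocess
import sys

child_script = """\
import torch

a = torch.rand(500, 500, device="cuda")
b = torch.rand(500, 500, device="cuda")
perm = torch.randint(0, 500, (500,), device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    idx = perm[0]
    a[0] *= b[idx]  # indexing with a device tensor synchronizes -> capture fails
"""

try:
    # check_output raises CalledProcessError when the child exits non-zero,
    # i.e. when the capture error above is raised in the child interpreter.
    subprocess.check_output(
        [sys.executable, "-c", child_script], stderr=subprocess.STDOUT
    )
except subprocess.CalledProcessError as e:
    # The failure and any leftover allocator state are confined to the child.
    print(e.output.decode())

Because subprocess.check_output raises CalledProcessError on a non-zero exit status, the test below can assert on that exception instead of catching a RuntimeError in-process.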
@@ -3702,22 +3702,34 @@ exit(2)
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
     def test_cuda_graph_tensor_item_not_allowed(self):
-        # Tesnor.item() calls a synchronize which is not allowed in a cudagraph
-        # Valid for CUDA and ROCm
-        def my_func(a: torch.Tensor, b: torch.Tensor, perm: torch.Tensor):
-            idx = perm[0]
-            a[0] *= b[idx]  # should raise an error during capture
-            return a
+        test_script = """\
+import torch
+import sys
+# Tensor.item() calls a synchronize which is not allowed in a cudagraph
+# Valid for CUDA and ROCm
+def my_func(a: torch.Tensor, b: torch.Tensor, perm: torch.Tensor):
+    idx = perm[0]
+    a[0] *= b[idx]  # should raise an error during capture
+    return a

-        a = torch.rand(500, 500, device="cuda")
-        b = torch.rand(500, 500, device="cuda")
-        perm = torch.randint(0, 500, (500,), device="cuda")
+a = torch.rand(500, 500, device="cuda")
+b = torch.rand(500, 500, device="cuda")
+perm = torch.randint(0, 500, (500,), device="cuda")

-        g = torch.cuda.CUDAGraph()
+g = torch.cuda.CUDAGraph()

-        with self.assertRaises(RuntimeError):
-            with torch.cuda.graph(g):
-                output = my_func(a, b, perm)
+with torch.cuda.graph(g):
+    output = my_func(a, b, perm)
+"""
+        with self.assertRaisesRegex(
+            subprocess.CalledProcessError,
+            "calls a synchronize which is not allowed in a cudagraph",
+        ):
+            r = (
+                subprocess.check_output([sys.executable, "-c", test_script])
+                .decode("ascii")
+                .strip()
+            )

     def test_batch_norm_gather_stats(self):
         input = torch.randn(1, 3, 3, 3, device="cuda")