[ROCm] enable cudagraph inductor UTs on ROCm (#105662)

These tests can now be enabled on ROCm, since a hipGraph fix landed in ROCm 5.6.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105662
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
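For context, the diffs below remove two ROCm-gating patterns from the affected test files: the `skipIfRocm` decorator on individual tests, and `torch.version.hip` conditionals that kept `triton.cudagraphs` disabled on HIP builds. The following sketch is illustrative only (the class and test names are hypothetical, not taken from this PR); it shows the kind of gating that the fix makes unnecessary:

```python
import unittest

import torch
from torch._inductor import config
from torch.testing._internal.common_utils import skipIfRocm


class CudagraphGatingExample(unittest.TestCase):  # hypothetical class, for illustration only
    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA/ROCm device")
    # Gating removed by this PR: skip the whole test on ROCm builds.
    @skipIfRocm
    # Gating removed by this PR: only enable cudagraphs when not running on HIP.
    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
    def test_cudagraphs_gated(self):
        # Compile a trivial function in cudagraph-friendly mode and check the result.
        compiled = torch.compile(lambda t: t + 1, mode="reduce-overhead")
        x = torch.ones(4, device="cuda")
        self.assertTrue(torch.equal(compiled(x), x + 1))


if __name__ == "__main__":
    unittest.main()
```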
Author: Jack Taylor
Date: 2023-08-01 20:55:27 +00:00
Committed by: PyTorch MergeBot
parent 506b55fc29
commit 60e65a70e5
5 changed files with 11 additions and 23 deletions

View File

@@ -10,11 +10,7 @@ import torch._dynamo.config
 import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.testing import same
-from torch.testing._internal.common_utils import (
-    skipIfRocm,
-    TEST_CUDA_GRAPH,
-    TEST_WITH_ROCM,
-)
+from torch.testing._internal.common_utils import skipIfRocm, TEST_CUDA_GRAPH
 def composed(*decs):
@@ -49,7 +45,6 @@ def assert_aot_autograd_counter(ok=True):
 def patch_all(ok=True):
     return composed(
-        unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"),
         torch._dynamo.config.patch(
             verify_correctness=True, automatic_dynamic_shapes=True
         ),

View File

@@ -15,7 +15,6 @@ from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     IS_FBCODE,
-    skipIfRocm,
     TEST_WITH_ASAN,
 )
@@ -144,7 +143,6 @@ class CudaReproTests(TestCase):
         compiled = compile_fx_inner(mod, ())
         assert compiled([])[0].device.type == "cuda"
-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_no_device_idx_repro_cudagraphs(self):
@@ -173,7 +171,6 @@ class CudaReproTests(TestCase):
         self.common(Repro(), ())
-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_expanded_inputs_cudagraphs(self):
@@ -187,7 +184,6 @@ class CudaReproTests(TestCase):
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(
         automatic_dynamic_shapes=True,
@@ -236,7 +232,6 @@ class CudaReproTests(TestCase):
         self.assertEqual(real_out, compiled_out)
         torch._dynamo.reset()
-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True, "size_asserts": False})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_expanded_inputs_cudagraphs_no_size_asserts(self):
@@ -250,8 +245,6 @@ class CudaReproTests(TestCase):
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
-    # TODO: enable
-    @skipIfRocm
     @config.patch({"triton.cudagraph_trees": False})
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)

View File

@@ -21,9 +21,9 @@ from torch.testing._internal.common_utils import (
     IS_CI,
     IS_LINUX,
     IS_WINDOWS,
+    skipIfRocm,
     TEST_CUDA_GRAPH,
     TEST_WITH_ASAN,
-    TEST_WITH_ROCM,
     TestCase as TorchTestCase,
 )
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -601,6 +601,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.assertFalse(first_node.unaliased_in_all_paths[0])
             self.assertTrue(first_node.cached_tensor_outputs[0] is None)
+        @skipIfRocm
         def test_checkpointing_resets_persistent_refs(self):
             @torch.compile(mode="reduce-overhead")
             def foo(x):
@@ -840,6 +841,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             del x
             self.assertEqual(all_live_block_count(), 0)
+        @skipIfRocm
         @unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
         @torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
         def test_workspace_allocation_error(self):
@@ -1124,6 +1126,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
                 out2 + out2
+        @skipIfRocm
         @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
         def test_conv_benchmark(self):
             with torch.backends.cudnn.flags(
@@ -1282,5 +1285,5 @@ if __name__ == "__main__":
             sys.exit(0)
         raise unittest.SkipTest("cuda graph test is skipped")
-    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
+    if HAS_CPU or HAS_CUDA:
         run_tests(needs="filelock")

View File

@@ -5,7 +5,7 @@ import unittest
 import torch
 import torch._logging
-from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, TestCase
+from torch.testing._internal.common_utils import IS_LINUX, TestCase
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -39,7 +39,6 @@ class SmokeTest(TestCase):
         # set back to defaults
         torch._logging.set_logs()
-    @skipIfRocm
     @unittest.skipIf(not HAS_CUDA, "Triton is not available")
     def test_compile_decorator(self):
         @torch.compile

View File

@@ -4110,7 +4110,7 @@ class CommonTemplate:
         inputs = (rand_strided((8,), (1,), device=self.device),)
         self.assertTrue(same(fn(*inputs), 2 * inputs[0]))
-    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
+    @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_strided_inputs(self):
         @torch._dynamo.optimize("inductor")
@@ -4123,7 +4123,7 @@ class CommonTemplate:
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
-    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
+    @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_input_mutation1(self):
         def fn(a):
@@ -4962,7 +4962,7 @@ class CommonTemplate:
         def fn(a):
             return torch.nn.functional.dropout(a, 0.55, True)
-        for cg in [False, True] if not torch.version.hip else [False]:
+        for cg in [False, True]:
             with patch.object(config.triton, "cudagraphs", cg):
                 torch._dynamo.reset()
@@ -5963,9 +5963,7 @@ class CommonTemplate:
         else:
             contexts = [
                 contextlib.nullcontext,
-                lambda: config.patch(
-                    {"triton.cudagraphs": True if not torch.version.hip else False}
-                ),
+                lambda: config.patch({"triton.cudagraphs": True}),
             ]
         for context in contexts: