[ROCm] enable cudagraph inductor UTs on ROCm (#105662)
These tests can now be enabled following a hipGraph fix that landed in ROCm 5.6.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105662
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
This commit is contained in:
parent 506b55fc29
commit 60e65a70e5
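For readers skimming the diff below, the change reduces to two recurring patterns, sketched here as a minimal, hypothetical test class (the class and test names are placeholders; the decorators and config keys are the ones used in the hunks): per-test ROCm skips are dropped where hipGraph now works, and triton.cudagraphs patches that were conditional on torch.version.hip become unconditional. In the cudagraph-trees test file the direction is reversed: the blanket TEST_WITH_ROCM gate is removed and targeted @skipIfRocm decorators are added to a few tests that remain skipped on ROCm.

import unittest

import torch
import torch._inductor.config as config
from torch.testing._internal.common_utils import skipIfRocm


class CudagraphPatternExample(unittest.TestCase):
    """Hypothetical test class; only the decorators mirror the real diff."""

    # Old pattern: skip outright on ROCm, and only enable cudagraphs
    # when not building against HIP.
    @skipIfRocm
    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
    def test_old_gating(self):
        pass

    # New pattern: no ROCm skip and an unconditional cudagraphs patch,
    # so the same test now exercises hipGraph capture on ROCm.
    @config.patch({"triton.cudagraphs": True})
    def test_new_gating(self):
        pass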
@@ -10,11 +10,7 @@ import torch._dynamo.config
 import torch._dynamo.test_case
 import torch._dynamo.testing
 from torch._dynamo.testing import same
-from torch.testing._internal.common_utils import (
-    skipIfRocm,
-    TEST_CUDA_GRAPH,
-    TEST_WITH_ROCM,
-)
+from torch.testing._internal.common_utils import skipIfRocm, TEST_CUDA_GRAPH


 def composed(*decs):

@@ -49,7 +45,6 @@ def assert_aot_autograd_counter(ok=True):

 def patch_all(ok=True):
     return composed(
-        unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"),
         torch._dynamo.config.patch(
             verify_correctness=True, automatic_dynamic_shapes=True
         ),

@@ -15,7 +15,6 @@ from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     IS_FBCODE,
-    skipIfRocm,
     TEST_WITH_ASAN,
 )

@@ -144,7 +143,6 @@ class CudaReproTests(TestCase):
         compiled = compile_fx_inner(mod, ())
         assert compiled([])[0].device.type == "cuda"

-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_no_device_idx_repro_cudagraphs(self):

@@ -173,7 +171,6 @@ class CudaReproTests(TestCase):

         self.common(Repro(), ())

-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_expanded_inputs_cudagraphs(self):

@@ -187,7 +184,6 @@ class CudaReproTests(TestCase):
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(
         automatic_dynamic_shapes=True,

@@ -236,7 +232,6 @@ class CudaReproTests(TestCase):
         self.assertEqual(real_out, compiled_out)
         torch._dynamo.reset()

-    @skipIfRocm
     @config.patch({"triton.cudagraphs": True, "size_asserts": False})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_expanded_inputs_cudagraphs_no_size_asserts(self):

@@ -250,8 +245,6 @@ class CudaReproTests(TestCase):
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

-    # TODO: enable
-    @skipIfRocm
     @config.patch({"triton.cudagraph_trees": False})
     @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)

@@ -21,9 +21,9 @@ from torch.testing._internal.common_utils import (
     IS_CI,
     IS_LINUX,
     IS_WINDOWS,
+    skipIfRocm,
     TEST_CUDA_GRAPH,
     TEST_WITH_ASAN,
-    TEST_WITH_ROCM,
     TestCase as TorchTestCase,
 )
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -601,6 +601,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.assertFalse(first_node.unaliased_in_all_paths[0])
             self.assertTrue(first_node.cached_tensor_outputs[0] is None)

+        @skipIfRocm
         def test_checkpointing_resets_persistent_refs(self):
             @torch.compile(mode="reduce-overhead")
             def foo(x):

@@ -840,6 +841,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             del x
             self.assertEqual(all_live_block_count(), 0)

+        @skipIfRocm
         @unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
         @torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
         def test_workspace_allocation_error(self):

@@ -1124,6 +1126,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."):
                 out2 + out2

+        @skipIfRocm
         @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
         def test_conv_benchmark(self):
             with torch.backends.cudnn.flags(

@@ -1282,5 +1285,5 @@ if __name__ == "__main__":
             sys.exit(0)
         raise unittest.SkipTest("cuda graph test is skipped")

-    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
+    if HAS_CPU or HAS_CUDA:
         run_tests(needs="filelock")

@@ -5,7 +5,7 @@ import unittest
 import torch
 import torch._logging

-from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, TestCase
+from torch.testing._internal.common_utils import IS_LINUX, TestCase
 from torch.testing._internal.inductor_utils import HAS_CUDA


@@ -39,7 +39,6 @@ class SmokeTest(TestCase):
         # set back to defaults
         torch._logging.set_logs()

-    @skipIfRocm
     @unittest.skipIf(not HAS_CUDA, "Triton is not available")
     def test_compile_decorator(self):
         @torch.compile

@@ -4110,7 +4110,7 @@ class CommonTemplate:
         inputs = (rand_strided((8,), (1,), device=self.device),)
         self.assertTrue(same(fn(*inputs), 2 * inputs[0]))

-    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
+    @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_strided_inputs(self):
         @torch._dynamo.optimize("inductor")

@@ -4123,7 +4123,7 @@ class CommonTemplate:
         )
         self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

-    @config.patch({"triton.cudagraphs": True if not torch.version.hip else False})
+    @config.patch({"triton.cudagraphs": True})
     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_input_mutation1(self):
         def fn(a):

@@ -4962,7 +4962,7 @@ class CommonTemplate:
         def fn(a):
             return torch.nn.functional.dropout(a, 0.55, True)

-        for cg in [False, True] if not torch.version.hip else [False]:
+        for cg in [False, True]:
             with patch.object(config.triton, "cudagraphs", cg):
                 torch._dynamo.reset()

@@ -5963,9 +5963,7 @@ class CommonTemplate:
         else:
             contexts = [
                 contextlib.nullcontext,
-                lambda: config.patch(
-                    {"triton.cudagraphs": True if not torch.version.hip else False}
-                ),
+                lambda: config.patch({"triton.cudagraphs": True}),
             ]

         for context in contexts: