diff --git a/test/run_test.py b/test/run_test.py index c011e386242..19c5220d63d 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1855,7 +1855,9 @@ def run_tests( ): raise RuntimeError(failure.message + keep_going_message) - os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) + # This is used later to constrain memory per proc on the GPU. On ROCm + # the number of procs is the number of GPUs, so we don't need to do this + os.environ["NUM_PARALLEL_PROCS"] = str(1 if torch.version.hip else NUM_PROCS) # See Note [ROCm parallel CI testing] pool = get_context("spawn").Pool(