diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
index ee3a374c745..117f841afa5 100644
--- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py
+++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -12,7 +12,7 @@ from torch.distributed.nn import RemoteModule
 from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES
 from torch.distributed.nn.api.remote_module import _RemoteModule
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
 from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
     RpcAgentTestFixture,
 )
@@ -613,8 +613,15 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
                 )
             ]

+        if TEST_WITH_ROCM:
+            errorString = (r"HIP error: invalid device ordinal\n"
+                           r"HIP kernel errors might be asynchronously reported at some other API call, "
+                           r"so the stacktrace below might be incorrect.\n"
+                           r"For debugging consider passing AMD_SERIALIZE_KERNEL=3")
+        else:
+            errorString = r"CUDA error: invalid device ordinal"
         with self.assertRaisesRegex(
-            RuntimeError, r"CUDA error: invalid device ordinal"
+            RuntimeError, errorString
         ):
             [
                 m.forward()
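
Note: the sketch below is a minimal, self-contained illustration of the pattern this diff applies: selecting the expected error regex per backend before asserting with assertRaisesRegex. The module-level TEST_WITH_ROCM constant and the _expected_ordinal_error helper here are hypothetical stand-ins for the sketch only; in the real test the flag is imported from torch.testing._internal.common_utils and the error is raised by an actual remote forward() call.

import unittest

# Hypothetical stand-in for torch.testing._internal.common_utils.TEST_WITH_ROCM;
# in PyTorch this flag reflects whether the build targets ROCm/HIP.
TEST_WITH_ROCM = False

def _expected_ordinal_error():
    # Hypothetical helper mirroring the diff: ROCm emits the ordinal error
    # wrapped in extra HIP guidance text, so the expected regex differs per
    # backend. assertRaisesRegex uses re.search, so a partial match suffices.
    if TEST_WITH_ROCM:
        return (r"HIP error: invalid device ordinal\n"
                r"HIP kernel errors might be asynchronously reported at some other API call, "
                r"so the stacktrace below might be incorrect.\n"
                r"For debugging consider passing AMD_SERIALIZE_KERNEL=3")
    return r"CUDA error: invalid device ordinal"

class OrdinalErrorPatternTest(unittest.TestCase):
    def test_regex_matches_backend_message(self):
        # Simulated error text for the CUDA path; a real test would trigger
        # this by constructing a module on an out-of-range device ordinal.
        with self.assertRaisesRegex(RuntimeError, _expected_ordinal_error()):
            raise RuntimeError("CUDA error: invalid device ordinal")

if __name__ == "__main__":
    unittest.main()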