[RESUBMIT] Standardize on error types for distributed errors. (#108191)

We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.

This results in messy error-handling code like the following:
```python
if "NCCL" in exception_str:
  ...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
  ...
if "The client socket has timed out after" in exception_str:
  ...
if "Broken pipe" in exception_str:
  ...
if "Connection reset by peer" in exception_str:
  ...
```

To address this issue, in this PR I've added the following error types:

1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and refers to process group (PG) backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
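
With this hierarchy, the string matching shown above can be replaced by ordinary `except` clauses. Below is a minimal sketch (not code from this PR; `run_with_dist_error_handling` is a hypothetical helper name) of how callers can dispatch on the new types:

```python
# Minimal sketch (hypothetical helper, not from this PR) of dispatching on the
# new exception hierarchy instead of matching substrings of the error message.
import logging

import torch.distributed as dist

log = logging.getLogger(__name__)


def run_with_dist_error_handling(fn, *args, **kwargs):
    try:
        return fn(*args, **kwargs)
    except dist.DistStoreError as e:
        # Store failures, e.g. TCPStore timeouts or the store-based barrier timing out.
        log.error("store error: %s", e)
        raise
    except dist.DistNetworkError as e:
        # Socket-level failures, e.g. connection reset, broken pipe, connect timeouts.
        log.error("network error: %s", e)
        raise
    except dist.DistBackendError as e:
        # Collective backend failures, e.g. NCCL errors surfaced by the process group.
        log.error("backend error: %s", e)
        raise
    except dist.DistError as e:
        # DistError is the common base class, so this catches any remaining
        # distributed error; it still derives from RuntimeError.
        log.error("distributed error: %s", e)
        raise
```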

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang
Pritam Damania 2023-08-30 21:47:35 +00:00 committed by PyTorch MergeBot
parent 6dacf52f88
commit 704b0b3c67
28 changed files with 267 additions and 240 deletions

View File

@ -534,7 +534,6 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/comm.cpp",
"torch/csrc/distributed/c10d/debug.cpp",
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
"torch/csrc/distributed/c10d/exception.cpp",
"torch/csrc/distributed/c10d/logger.cpp",
"torch/csrc/distributed/c10d/logging.cpp",
"torch/csrc/distributed/c10d/quantization/quantization.cpp",

View File

@ -272,10 +272,28 @@ class C10_API OutOfMemoryError : public Error {
using Error::Error;
};
// Base error type for all distributed errors.
// These turn into DistError when they cross into Python.
class C10_API DistError : public Error {
using Error::Error;
};
// Used for collective communication library errors from the distributed module.
// These turn into DistBackendError when they cross into Python.
class C10_API DistBackendError : public Error {
using Error::Error;
class C10_API DistBackendError : public DistError {
using DistError::DistError;
};
// Used for errors originating from the store.
// These turn into DistStoreError when they cross into Python.
class C10_API DistStoreError : public DistError {
using DistError::DistError;
};
// Used for errors originating from the TCP/IP stack and not from collective
// libraries. These turn into DistNetworkError when they cross into Python.
class C10_API DistNetworkError : public DistError {
using DistError::DistError;
};
// A utility function to return an exception std::string by prepending its

View File

@ -842,10 +842,20 @@ following matrix shows how the log level can be adjusted via the combination of
| ``INFO`` | ``DETAIL`` | Trace (a.k.a. All) |
+-------------------------+-----------------------------+------------------------+
Distributed has a custom Exception type derived from `RuntimeError` called `torch.distributed.DistBackendError`. This exception is thrown when a backend-specific error occurs. For example, if
the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library.
Distributed has custom Exception types derived from `RuntimeError`:
- `torch.distributed.DistError`: This is the base type of all distributed exceptions.
- `torch.distributed.DistBackendError`: This exception is thrown when a backend-specific error occurs. For example, if
the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library.
- `torch.distributed.DistNetworkError`: This exception is thrown when networking
libraries encounter errors (ex: Connection reset by peer)
- `torch.distributed.DistStoreError`: This exception is thrown when the Store encounters
an error (ex: TCPStore timeout)
.. autoclass:: torch.distributed.DistError
.. autoclass:: torch.distributed.DistBackendError
.. autoclass:: torch.distributed.DistNetworkError
.. autoclass:: torch.distributed.DistStoreError
.. warning::
The DistBackendError exception type is an experimental feature and is subject to change.
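
Taken together, the hierarchy documented above means that `except DistError` covers every distributed failure, and because `DistError` itself derives from `RuntimeError`, existing handlers that catch `RuntimeError` continue to work. A small sketch (not part of the diff) illustrating the relationships:

```python
# Sketch (not part of the diff) showing the documented hierarchy: every new type
# is a subclass of DistError, and DistError itself derives from RuntimeError, so
# existing handlers that catch RuntimeError keep catching these errors.
import torch.distributed as dist

for exc in (dist.DistBackendError, dist.DistNetworkError, dist.DistStoreError):
    assert issubclass(exc, dist.DistError)
    assert issubclass(exc, RuntimeError)
```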

View File

@ -39,7 +39,7 @@ void testGetSet(std::string prefix = "") {
EXPECT_FALSE(delFailure);
auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
store.setTimeout(timeout);
EXPECT_THROW(store.get("key0"), std::runtime_error);
EXPECT_THROW(store.get("key0"), c10::DistStoreError);
}
// get() waits up to timeout_.

View File

@ -195,7 +195,7 @@ TEST(TCPStoreTest, testCleanShutdown) {
clientTCPStore->get("key");
auto clientThread = std::thread([&clientTCPStore] {
EXPECT_THROW(clientTCPStore->get("invalid_key"), std::system_error);
EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
});
// start server shutdown during a client request

View File

@ -14,6 +14,7 @@ import sys
import unittest
from contextlib import closing
from torch.distributed import DistNetworkError
from torch.distributed.elastic.utils.distributed import (
create_c10d_store,
get_socket_with_port,
@ -108,7 +109,7 @@ class DistributedUtilTest(TestCase):
)
def test_create_store_timeout_on_worker(self):
with self.assertRaises(TimeoutError):
with self.assertRaises(DistNetworkError):
# use any available port (port 0) since timeout is expected
create_c10d_store(
is_server=False,
@ -142,7 +143,7 @@ class DistributedUtilTest(TestCase):
port = sock.getsockname()[1]
# on the worker port conflict shouldn't matter, it should just timeout
# since we never created a server
with self.assertRaises(TimeoutError):
with self.assertRaises(DistNetworkError):
create_c10d_store(
is_server=False,
server_addr=socket.gethostname(),

View File

@ -155,7 +155,8 @@ class TimeoutTest(TestCase):
timeout=timeout,
logging_interval=timeout / 2
)
except RuntimeError as e:
except torch.distributed.DistStoreError as e:
self.assertTrue(isinstance(e, torch.distributed.DistError))
error_list.append(e)
world_size = 4
@ -1248,15 +1249,15 @@ class AbstractCommTest:
group = dist.new_group(ranks=[1])
self.assertEqual(dist.get_group_rank(group, 1), 0)
with self.assertRaisesRegex(RuntimeError, "not part of group"):
with self.assertRaisesRegex(ValueError, "not part of group"):
dist.get_group_rank(group, 0)
with self.assertRaisesRegex(RuntimeError, "not registered"):
with self.assertRaisesRegex(ValueError, "not registered"):
dist.get_group_rank(DummyProcessGroup(self.rank, self.world_size), 0)
self.assertEqual(dist.get_global_rank(group, 0), 1)
with self.assertRaisesRegex(RuntimeError, "not part of group"):
with self.assertRaisesRegex(ValueError, "not part of group"):
dist.get_global_rank(group, 1)
with self.assertRaisesRegex(RuntimeError, "not registered"):
with self.assertRaisesRegex(ValueError, "not registered"):
dist.get_global_rank(DummyProcessGroup(self.rank, self.world_size), 0)
self.assertEqual(dist.get_process_group_ranks(group), [1])
@ -1277,44 +1278,44 @@ class AbstractCommTest:
tensor_list_h[1] = tensor_list_h[1].half()
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_gather(tensor_list_h, tensor)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_gather(tensor_list, tensor_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_gather_coalesced([tensor_list_h], tensor_list)
dist.all_gather_coalesced([tensor_list], tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_reduce_coalesced(tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.reduce_scatter(tensor, tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.reduce_scatter(tensor_h, tensor_list)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_to_all_single(tensor_h, tensor)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_to_all(tensor_list_h, tensor_list)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.all_to_all(tensor_list, tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.scatter(tensor, tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.gather(tensor_h, tensor_list)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.gather(tensor, tensor_list_h)
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
dist.scatter(tensor_h, tensor_list)
def _test_tensor_dtype_complex(self, backend):
@ -1506,7 +1507,7 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):
for mode in invalid_debug_modes:
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
with self.assertRaisesRegex(ValueError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
dist.set_debug_level_from_env()

View File

@ -1203,7 +1203,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
# Output is not a list of lists.
dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
with self.assertRaisesRegex(
RuntimeError, "Invalid function argument.*output_tensor_lists"
TypeError, "Invalid function argument.*output_tensor_lists"
):
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)

View File

@ -1100,6 +1100,7 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
# Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage
with self.assertRaises(dist.DistBackendError) as cm:
dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)
self.assertTrue(isinstance(cm.exception, dist.DistError))
self.assertIsInstance(cm.exception, RuntimeError)
@ -3064,7 +3065,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
)
with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
c10d.barrier(device_ids=self.rank)
@requires_nccl()

View File

@ -11,6 +11,7 @@ from sys import platform
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed import DistNetworkError, DistError
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
@ -558,12 +559,13 @@ class RendezvousTCPTest(TestCase):
next(gen)
def test_dns_timeout(self):
with self.assertRaisesRegex(TimeoutError, "client socket has timed out after.*dnsnotexist"):
with self.assertRaisesRegex(DistNetworkError, "client socket has timed out after.*dnsnotexist") as manager:
gen = dist.rendezvous(
"tcp://dnsnotexist:23456?world_size=2&rank=0",
timeout=timedelta(seconds=1),
)
next(gen)
self.assertTrue(isinstance(manager.exception, DistError))
@retry_on_connect_failures
def test_nominal(self):

View File

@ -1926,7 +1926,10 @@ def _current_autograd_node() -> _Node: ...
# Defined in torch/csrc/Exceptions.cpp
class _OutOfMemoryError(RuntimeError): ...
class _DistError(RuntimeError): ...
class _DistBackendError(RuntimeError): ...
class _DistStoreError(RuntimeError): ...
class _DistNetworkError(RuntimeError): ...
# Defined in torch/csrc/profiler/init.cpp
class CapturedTraceback:

View File

@ -12,7 +12,9 @@
#include <c10/util/StringUtil.h>
PyObject *THPException_FatalError, *THPException_LinAlgError,
*THPException_OutOfMemoryError, *THPException_DistBackendError;
*THPException_OutOfMemoryError, *THPException_DistError,
*THPException_DistBackendError, *THPException_DistNetworkError,
*THPException_DistStoreError;
#define ASSERT_TRUE(cond) \
if (!(cond)) \
@ -62,16 +64,45 @@ could not be completed because the input matrix is singular.",
PyModule_AddObject(
module, "_OutOfMemoryError", THPException_OutOfMemoryError) == 0);
ASSERT_TRUE(
THPException_DistError = PyErr_NewExceptionWithDoc(
"torch.distributed.DistError",
"Exception raised when an error occurs in the distributed library",
PyExc_RuntimeError,
nullptr));
ASSERT_TRUE(
PyModule_AddObject(module, "_DistError", THPException_DistError) == 0);
ASSERT_TRUE(
THPException_DistBackendError = PyErr_NewExceptionWithDoc(
"torch.distributed.DistBackendError",
"Exception raised when a backend error occurs in distributed",
PyExc_RuntimeError,
THPException_DistError,
nullptr));
ASSERT_TRUE(
PyModule_AddObject(
module, "_DistBackendError", THPException_DistBackendError) == 0);
ASSERT_TRUE(
THPException_DistNetworkError = PyErr_NewExceptionWithDoc(
"torch.distributed.DistNetworkError",
"Exception raised when a network error occurs in distributed",
THPException_DistError,
nullptr));
ASSERT_TRUE(
PyModule_AddObject(
module, "_DistNetworkError", THPException_DistNetworkError) == 0);
ASSERT_TRUE(
THPException_DistStoreError = PyErr_NewExceptionWithDoc(
"torch.distributed.DistStoreError",
"Exception raised when an error occurs in the distributed store",
THPException_DistError,
nullptr));
ASSERT_TRUE(
PyModule_AddObject(
module, "_DistStoreError", THPException_DistStoreError) == 0);
return true;
}

View File

@ -63,51 +63,37 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
}
// Only catch torch-specific exceptions
#define CATCH_CORE_ERRORS(retstmnt) \
catch (python_error & e) { \
e.restore(); \
retstmnt; \
} \
catch (py::error_already_set & e) { \
e.restore(); \
retstmnt; \
} \
_CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
_CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
_CATCH_GENERIC_ERROR( \
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
_CATCH_GENERIC_ERROR( \
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
_CATCH_GENERIC_ERROR( \
DistBackendError, THPException_DistBackendError, retstmnt) \
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
catch (torch::PyTorchError & e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg); \
retstmnt; \
#define CATCH_CORE_ERRORS(retstmnt) \
catch (python_error & e) { \
e.restore(); \
retstmnt; \
} \
catch (py::error_already_set & e) { \
e.restore(); \
retstmnt; \
} \
_CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
_CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
_CATCH_GENERIC_ERROR( \
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
_CATCH_GENERIC_ERROR( \
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
_CATCH_GENERIC_ERROR( \
DistBackendError, THPException_DistBackendError, retstmnt) \
_CATCH_GENERIC_ERROR( \
DistNetworkError, THPException_DistNetworkError, retstmnt) \
_CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
_CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt) \
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
catch (torch::PyTorchError & e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg); \
retstmnt; \
}
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#define CATCH_C10D_ERRORS(retstmnt) \
catch (const c10d::TimeoutError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_TimeoutError, msg); \
retstmnt; \
} \
catch (const c10d::C10dError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_RuntimeError, msg); \
retstmnt; \
}
#else
#define CATCH_C10D_ERRORS(retstmnt)
#endif
#define CATCH_TH_ERRORS(retstmnt) \
CATCH_CORE_ERRORS(retstmnt) \
CATCH_C10D_ERRORS(retstmnt)
#define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
#define CATCH_ALL_ERRORS(retstmnt) \
CATCH_TH_ERRORS(retstmnt) \
@ -153,7 +139,9 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
extern PyObject *THPException_FatalError, *THPException_LinAlgError,
*THPException_OutOfMemoryError, *THPException_DistBackendError;
*THPException_OutOfMemoryError, *THPException_DistError,
*THPException_DistBackendError, *THPException_DistNetworkError,
*THPException_DistStoreError;
// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.

View File

@ -27,9 +27,9 @@
#include <c10/util/Exception.h>
#define SYSASSERT(rv, ...) \
if ((rv) < 0) { \
throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
#define SYSASSERT(rv, ...) \
if ((rv) < 0) { \
C10_THROW_ERROR(DistStoreError, std::strerror(errno)); \
}
#ifdef _WIN32

View File

@ -51,8 +51,7 @@ std::vector<uint8_t> HashStore::get(const std::string& key) {
cv_.wait(lock, pred);
} else {
if (!cv_.wait_for(lock, timeout_, pred)) {
throw std::system_error(
ETIMEDOUT, std::system_category(), "Wait timeout");
C10_THROW_ERROR(DistStoreError, "Wait timeout");
}
}
return map_[key];
@ -78,8 +77,7 @@ void HashStore::wait(
cv_.wait(lock, pred);
} else {
if (!cv_.wait_until(lock, end, pred)) {
throw std::system_error(
ETIMEDOUT, std::system_category(), "Wait timeout");
C10_THROW_ERROR(DistStoreError, "Wait timeout");
}
}
}
@ -151,8 +149,7 @@ std::vector<std::vector<uint8_t>> HashStore::multiGet(
cv_.wait(lock, pred);
} else {
if (!cv_.wait_until(lock, deadline, pred)) {
throw std::system_error(
ETIMEDOUT, std::system_category(), "Wait timeout");
C10_THROW_ERROR(DistStoreError, "Wait timeout");
}
}
res.emplace_back(map_[key]);

View File

@ -683,7 +683,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
std::array<char, HOST_NAME_MAX> hostname{};
auto rv = gethostname(hostname.data(), HOST_NAME_MAX);
if (rv != 0) {
throw std::system_error(errno, std::system_category());
C10_THROW_ERROR(DistBackendError, std::strerror(errno));
}
// Use this machine's hostname if it resolves to an address.
@ -710,7 +710,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]);
auto rv = gethostname(hostname.get(), hostNameMax);
if (rv != 0) {
throw std::system_error(errno, std::system_category());
C10_THROW_ERROR(DistBackendError, std::strerror(errno));
}
// Use this machine's hostname if it resolves to an address.

View File

@ -186,22 +186,22 @@ void read_config() {
void check_device(c10::Device dev1, c10::Device dev2) {
if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
throw std::runtime_error("ProcessGroupUCC multidevice is not supported");
throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
}
}
void check_tensor(const std::vector<at::Tensor>& tensors) {
if (tensors.size() != 1) {
throw std::runtime_error(
throw std::invalid_argument(
"ProcessGroupUCC takes 1 tensor. Got " +
std::to_string(tensors.size()) + ". ");
}
if (!tensors[0].is_contiguous()) {
throw std::runtime_error(
throw std::invalid_argument(
"ProcessGroupUCC input tensor has to be contiguous");
}
if (tensors[0].is_sparse()) {
throw std::runtime_error("ProcessGroupUCC input tensor has to be dense");
throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
}
// TODO: check cuda case
}
@ -401,7 +401,7 @@ std::shared_ptr<Comm> Comm::get_comm(
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
"ucc communicator was initialized with different cuda device,"
"multi device is not supported");
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
shared_comm->cuda_device_index = dev.index();
}
@ -825,7 +825,7 @@ c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
default: {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
}
}
@ -1009,7 +1009,7 @@ c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
std::vector<at::Tensor>& /* unused */,
const AllreduceCoalescedOptions& /* unused */) {
throw std::runtime_error(
throw std::invalid_argument(
"ProcessGroupUCC does not support allreduce_coalesced");
}
@ -1610,7 +1610,7 @@ void ProcessGroupUCC::initComm(c10::Device dev) {
TORCH_UCC_INIT,
"ucc communicator was initialized with different cuda device,"
"multi device is not supported");
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
comm->cuda_device_index = dev.index();
}

View File

@ -442,7 +442,7 @@ void TCPStore::doWait(
TORCH_CHECK(false, "wait_canceled response is expected");
}
}
TORCH_CHECK(false, "Socket Timeout");
C10_THROW_ERROR(DistStoreError, "Socket Timeout");
}
void TCPStore::append(

View File

@ -473,9 +473,8 @@ void TCPStoreMasterDaemon::run() {
// accept new connections.
if (fds[0].revents != 0) {
if (!(fds[0].revents & POLLIN)) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
C10_THROW_ERROR(
DistStoreError,
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
@ -515,9 +514,8 @@ void TCPStoreMasterDaemon::run() {
// accept new connections.
if (fds[0].revents != 0) {
if (fds[0].revents ^ POLLIN) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
C10_THROW_ERROR(
DistStoreError,
"Unexpected poll revent on the master's listening socket: " +
std::to_string(fds[0].revents));
}
@ -532,9 +530,8 @@ void TCPStoreMasterDaemon::run() {
// The main thread will write a byte to the pipe then close it before
// joining the background thread
if (fds[1].revents & ~(POLLIN | POLLHUP)) {
throw std::system_error(
ECONNABORTED,
std::system_category(),
C10_THROW_ERROR(
DistStoreError,
"Unexpected poll revent on the control pipe's reading fd: " +
std::to_string(fds[1].revents));
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/Types.hpp>
@ -521,30 +522,30 @@ using SizeType = uint64_t;
continue; \
} else if ( \
errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
TORCH_CHECK(false, "Socket Timeout"); \
C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
} else { \
throw std::system_error(errno_local, std::system_category()); \
C10_THROW_ERROR(DistNetworkError, std::strerror(errno_local)); \
} \
} else { \
break; \
} \
}
#else
#define SYSCHECK(expr, success_cond) \
while (true) { \
auto __output = (expr); \
(void)__output; \
if (!(success_cond)) { \
if (errno == EINTR) { \
continue; \
} else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
TORCH_CHECK(false, "Socket Timeout"); \
} else { \
throw std::system_error(errno, std::system_category()); \
} \
} else { \
break; \
} \
#define SYSCHECK(expr, success_cond) \
while (true) { \
auto __output = (expr); \
(void)__output; \
if (!(success_cond)) { \
if (errno == EINTR) { \
continue; \
} else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
} else { \
C10_THROW_ERROR(DistNetworkError, std::strerror(errno)); \
} \
} else { \
break; \
} \
}
#endif
@ -589,7 +590,7 @@ void sendBytes(
bytesSent =
::send(socket, (const char*)currentBytes, bytesToSend, flags))
if (bytesSent == 0) {
throw std::system_error(ECONNRESET, std::system_category());
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
}
bytesToSend -= bytesSent;
@ -612,7 +613,7 @@ void recvBytes(int socket, T* buffer, size_t length) {
SYSCHECK_ERR_RETURN_NEG1(
bytesReceived = recv(socket, (char*)currentBytes, bytesToReceive, 0))
if (bytesReceived == 0) {
throw std::system_error(ECONNRESET, std::system_category());
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
}
bytesToReceive -= bytesReceived;

View File

@ -42,8 +42,8 @@ DebugLevel loadDebugLevelFromEnvironment() {
} else if (level_str == "DETAIL") {
level = DebugLevel::Detail;
} else {
throw C10dError{
"The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL."};
throw std::invalid_argument(
"The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL.");
}
C10D_INFO("The debug level is set to {}.", level_str);

View File

@ -1,9 +0,0 @@
#include <torch/csrc/distributed/c10d/exception.h>
namespace c10d {
C10dError::~C10dError() = default;
TimeoutError::~TimeoutError() = default;
} // namespace c10d

View File

@ -9,37 +9,25 @@
#include <stdexcept>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
// Utility macro similar to C10_THROW_ERROR, the major difference is that this
// macro handles exception types defined in the c10d namespace, whereas
// C10_THROW_ERROR requires an exception to be defined in the c10 namespace.
#define C10D_THROW_ERROR(err_type, msg) \
throw ::c10d::err_type( \
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
namespace c10d {
class TORCH_API C10dError : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
using c10::DistNetworkError;
C10dError(const C10dError&) = default;
C10dError& operator=(const C10dError&) = default;
C10dError(C10dError&&) = default;
C10dError& operator=(C10dError&&) = default;
~C10dError() override;
class TORCH_API SocketError : public DistNetworkError {
using DistNetworkError::DistNetworkError;
};
class TORCH_API TimeoutError : public C10dError {
public:
using C10dError::C10dError;
TimeoutError(const TimeoutError&) = default;
TimeoutError& operator=(const TimeoutError&) = default;
TimeoutError(TimeoutError&&) = default;
TimeoutError& operator=(TimeoutError&&) = default;
~TimeoutError() override;
class TORCH_API TimeoutError : public DistNetworkError {
using DistNetworkError::DistNetworkError;
};
} // namespace c10d

View File

@ -109,7 +109,7 @@ void delay(std::chrono::seconds d) {
// We don't care about error conditions other than EINTR since a failure
// here is not critical.
if (err == std::errc::interrupted) {
throw std::system_error{err};
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
}
}
#endif
@ -271,7 +271,7 @@ std::unique_ptr<SocketImpl> SocketImpl::accept() const {
if (hnd == invalid_socket) {
std::error_code err = getSocketError();
if (err == std::errc::interrupted) {
throw std::system_error{err};
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
}
std::string msg{};
@ -287,7 +287,7 @@ std::unique_ptr<SocketImpl> SocketImpl::accept() const {
C10D_ERROR(msg);
throw SocketError{msg};
C10D_THROW_ERROR(SocketError, msg);
}
::addrinfo addr{};
@ -333,7 +333,8 @@ void SocketImpl::enableNonBlocking() {
}
}
#endif
throw SocketError{"The socket cannot be switched to non-blocking mode."};
C10D_THROW_ERROR(
SocketError, "The socket cannot be switched to non-blocking mode.");
}
// TODO: Remove once we migrate everything to non-blocking mode.
@ -351,7 +352,8 @@ void SocketImpl::disableNonBlocking() {
}
}
#endif
throw SocketError{"The socket cannot be switched to blocking mode."};
C10D_THROW_ERROR(
SocketError, "The socket cannot be switched to blocking mode.");
}
bool SocketImpl::enableNoDelay() noexcept {
@ -381,7 +383,8 @@ std::uint16_t SocketImpl::getPort() const {
if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) !=
0) {
throw SocketError{"The port number of the socket cannot be retrieved."};
C10D_THROW_ERROR(
SocketError, "The port number of the socket cannot be retrieved.");
}
if (addr_s.ss_family == AF_INET) {
@ -471,7 +474,8 @@ std::unique_ptr<SocketImpl> SocketListenOp::run() {
C10D_ERROR(msg);
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
C10D_THROW_ERROR(
SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
}
bool SocketListenOp::tryListen(int family) {
@ -599,26 +603,31 @@ std::unique_ptr<SocketImpl> SocketListenFromFdOp::run() const {
::socklen_t addr_len = sizeof(addr_storage);
if (::getsockname(
fd_, reinterpret_cast<::sockaddr*>(&addr_storage), &addr_len) < 0) {
throw SocketError{
fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError())};
C10D_THROW_ERROR(
SocketError,
fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError()));
}
auto socket = std::make_unique<SocketImpl>(fd_);
const auto port = socket->getPort();
if (port != expected_port_) {
throw SocketError{fmt::format(
"listen fd {} is bound to port {}, expected to be bound to port {}",
fd_,
port,
expected_port_)};
C10D_THROW_ERROR(
SocketError,
fmt::format(
"listen fd {} is bound to port {}, expected to be bound to port {}",
fd_,
port,
expected_port_));
}
if (::listen(socket->handle(), -1 /* backlog */) != 0) {
throw SocketError{fmt::format(
"Failed to listen on socket initialized from fd {}: {}.",
socket->handle(),
getSocketError())};
C10D_THROW_ERROR(
SocketError,
fmt::format(
"Failed to listen on socket initialized from fd {}: {}.",
socket->handle(),
getSocketError()));
}
socket->closeOnExec();
@ -718,7 +727,8 @@ std::unique_ptr<SocketImpl> SocketConnectOp::run() {
C10D_ERROR(msg);
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
C10D_THROW_ERROR(
SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
}
bool SocketConnectOp::tryConnect(int family) {
@ -821,7 +831,7 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(
if (cr == ConnectResult::Error) {
std::error_code err = getSocketError();
if (err == std::errc::interrupted) {
throw std::system_error{err};
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
}
// Retry if the server is not yet listening or if its backlog is exhausted.
@ -921,7 +931,7 @@ void SocketConnectOp::throwTimeoutError() const {
C10D_ERROR(msg);
throw TimeoutError{msg};
C10D_THROW_ERROR(TimeoutError, msg);
}
} // namespace
@ -935,7 +945,8 @@ void Socket::initialize() {
c10::call_once(init_flag, []() {
WSADATA data{};
if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) {
throw SocketError{"The initialization of Winsock has failed."};
C10D_THROW_ERROR(
SocketError, "The initialization of Winsock has failed.");
}
});
#endif
@ -973,7 +984,7 @@ Socket Socket::accept() const {
return Socket{impl_->accept()};
}
throw SocketError{"The socket is not initialized."};
C10D_THROW_ERROR(SocketError, "The socket is not initialized.");
}
int Socket::handle() const noexcept {
@ -999,6 +1010,4 @@ bool Socket::waitForInput(std::chrono::milliseconds timeout) {
} // namespace detail
SocketError::~SocketError() = default;
} // namespace c10d

View File

@ -12,6 +12,7 @@
#include <string>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/exception.h>
namespace c10d {
@ -89,19 +90,4 @@ class Socket {
} // namespace detail
class TORCH_API SocketError : public C10dError {
public:
using C10dError::C10dError;
SocketError(const SocketError&) = default;
SocketError& operator=(const SocketError&) = default;
SocketError(SocketError&&) = default;
SocketError& operator=(SocketError&&) = default;
~SocketError() override;
};
} // namespace c10d

View File

@ -20,7 +20,10 @@ if is_available() and not torch._C._c10d_init():
raise RuntimeError("Failed to initialize torch.distributed")
# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError
if is_available():
from torch._C._distributed_c10d import (

View File

@ -35,6 +35,7 @@ from torch._C._distributed_c10d import (
from .constants import default_pg_timeout
from .c10d_logger import _exception_logger, _time_logger
from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
DistStoreError = torch._C._DistStoreError
__all__ = [
'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
@ -675,7 +676,7 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log
)
if timedelta(seconds=(time.time() - start)) > timeout:
raise RuntimeError(
raise DistStoreError(
"Timed out initializing process group in store based barrier on "
"rank {}, for key: {} (world_size={}, num_workers_joined={}, timeout={})".format(
rank, store_key, world_size, worker_count, timeout
@ -722,10 +723,10 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
if group is GroupMember.WORLD:
return global_rank
if group not in _world.pg_group_ranks:
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
group_ranks = _world.pg_group_ranks[group]
if global_rank not in group_ranks:
raise RuntimeError(f"Global rank {global_rank} is not part of group {group}")
raise ValueError(f"Global rank {global_rank} is not part of group {group}")
return group_ranks[global_rank]
@ -747,11 +748,11 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
if group is GroupMember.WORLD:
return group_rank
if group not in _world.pg_group_ranks:
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
for rank, grp_rank in _world.pg_group_ranks[group].items():
if grp_rank == group_rank:
return rank
raise RuntimeError(f"Group rank {group_rank} is not part of group {group}")
raise ValueError(f"Group rank {group_rank} is not part of group {group}")
# TODO: remove this once the ecosystem moves away from it.
def _get_global_rank(group, rank):
@ -792,7 +793,7 @@ def _check_single_tensor(param, param_name):
Helper to check that the parameter ``param_name`` is a single tensor.
"""
if not isinstance(param, torch.Tensor):
raise RuntimeError(
raise TypeError(
f"Invalid function argument. Expected parameter `{param_name}` to be of type torch.Tensor."
)
@ -804,7 +805,7 @@ def _check_tensor_list(param, param_name):
if not isinstance(param, list) or not all(
isinstance(p, torch.Tensor) for p in param
):
raise RuntimeError(
raise TypeError(
f"Invalid function argument. Expected parameter `{param_name}` to be of type List[torch.Tensor]."
)
@ -823,7 +824,7 @@ def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = tensor_dtype
else:
if last_dtype != tensor_dtype:
raise RuntimeError(
raise ValueError(
"Invalid usage of tensors with different dtypes"
f"Found {last_dtype} and {tensor.dtype}"
)
@ -834,7 +835,7 @@ def _check_op(op):
Helper to check that the ``op`` is either isend or irecv.
"""
if op not in [isend, irecv]:
raise RuntimeError(
raise ValueError(
"Invalid ``op``. Expected ``op`` "
"to be of type ``torch.distributed.isend`` or "
"``torch.distributed.irecv``."
@ -849,14 +850,14 @@ def _check_p2p_op_list(p2p_op_list):
if not isinstance(p2p_op_list, list) or not all(
isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
):
raise RuntimeError(
raise ValueError(
"Invalid ``p2p_op_list``. Each op is expected to "
"to be of type ``torch.distributed.P2POp``."
)
group = p2p_op_list[0].group
if not all(group == p2p_op.group for p2p_op in p2p_op_list):
raise RuntimeError("All ops need to use the same group.")
raise ValueError("All ops need to use the same group.")
def is_mpi_available() -> bool:
@ -937,7 +938,7 @@ def _get_default_group():
Getting the default process group created by init_process_group
"""
if not is_initialized():
raise RuntimeError(
raise ValueError(
"Default process group has not been initialized, "
"please make sure to call init_process_group."
)
@ -949,7 +950,7 @@ def _get_default_store():
Getting the default store created by init_process_group
"""
if not is_initialized():
raise RuntimeError(
raise ValueError(
"Default process group has not been initialized, "
"please make sure to call init_process_group."
)
@ -967,7 +968,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
else:
pg = group
if _rank_not_in_group(pg):
raise RuntimeError("Invalid process group specified")
raise ValueError("Invalid process group specified")
backend_config = _world.pg_backend_config.get(pg)
assert backend_config is not None
return str(backend_config)
@ -990,7 +991,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str:
else:
pg = group
if _rank_not_in_group(pg):
raise RuntimeError("Invalid process group specified")
raise ValueError("Invalid process group specified")
pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
assert pg_store is not None
return pg_store[0]
@ -1092,12 +1093,12 @@ def init_process_group(
global _default_pg_init_method
if not isinstance(timeout, timedelta):
raise RuntimeError(
raise TypeError(
"Expected timeout argument to be of type datetime.timedelta"
)
if GroupMember.WORLD is not None:
raise RuntimeError("trying to initialize the default process group twice!")
raise ValueError("trying to initialize the default process group twice!")
assert (store is None) or (
init_method is None
@ -1206,13 +1207,13 @@ def _new_process_group_helper(
global _world
if group_name in _world.pg_names.values():
raise RuntimeError(
raise ValueError(
"The specified group name has already been "
"created, please use a different group name"
)
if not isinstance(timeout, timedelta):
raise RuntimeError(
raise TypeError(
"Expected timeout argument to be of type datetime.timedelta"
)
@ -1387,7 +1388,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise RuntimeError("Invalid process group specified")
raise ValueError("Invalid process group specified")
# When users register Python onCompletion hooks, those hooks will run on a
# different thread than the main thread. Today, the ProcessGroup dtor does
@ -1647,7 +1648,7 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces
class _IllegalWork(Work):
def __getattribute__(self, name):
if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]:
raise RuntimeError(f"Illegal to call {name} on IllegalWork object")
raise ValueError(f"Illegal to call {name} on IllegalWork object")
class _CoalescingManager:
@ -1699,7 +1700,7 @@ def _coalescing_manager(
group = group or _get_default_group()
op_list = _world.pg_coalesce_state.setdefault(group, [])
if op_list:
raise RuntimeError("ProcessGroup has non-empty op list at the start of coalescing")
raise ValueError("ProcessGroup has non-empty op list at the start of coalescing")
if device:
group._start_coalescing(device)
cm = _CoalescingManager()
@ -2030,7 +2031,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if tensor.is_complex():
if not supports_complex(op):
raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
raise ValueError(f"all_reduce does not support {op} on complex tensors")
tensor = torch.view_as_real(tensor)
opts = AllreduceOptions()
@ -2100,7 +2101,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
return
if any(t.is_complex() for t in tensors) and not supports_complex(op):
raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
raise ValueError(f"all_reduce does not support {op} on complex tensors")
tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
@ -2695,7 +2696,7 @@ def scatter_object_list(
not isinstance(scatter_object_output_list, list)
or len(scatter_object_output_list) < 1
):
raise RuntimeError(
raise ValueError(
"Expected argument scatter_object_output_list to be a list of size at least 1."
)
@ -2992,7 +2993,7 @@ def all_gather_coalesced(
_check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(input_tensor_list)
if not isinstance(output_tensor_lists, list):
raise RuntimeError(
raise TypeError(
"Invalid function argument: output_tensor_lists should be a list"
)
for output_tensor_list in output_tensor_lists:
@ -3687,7 +3688,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
if isinstance(device_ids, list):
opts.device_ids = device_ids
else:
raise RuntimeError(
raise TypeError(
"Invalid function argument: device_ids type should be List[int]"
)
@ -3760,7 +3761,7 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
return
if get_backend(group) != Backend.GLOO:
raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
raise ValueError("monitored_barrier is only implemented for GLOO backend.")
if timeout is None:
timeout = default_pg_timeout
@ -3909,7 +3910,7 @@ def _new_group_with_tag(
if use_local_synchronization:
# MPI backend doesn't have have a way for us to perform a partial sync
if backend == Backend.MPI:
raise RuntimeError("MPI backend doesn't support use_local_synchronization=True")
raise ValueError("MPI backend doesn't support use_local_synchronization=True")
if ranks is not None and get_rank() not in ranks:
return None
@ -3918,7 +3919,7 @@ def _new_group_with_tag(
ranks = sorted(ranks)
group_world_size = len(ranks)
if group_world_size > global_world_size:
raise RuntimeError(
raise ValueError(
"the new group's world size should be less or "
"equal to the world size set by "
"init_process_group"
@ -3926,7 +3927,7 @@ def _new_group_with_tag(
# check ranks' sanity
for rank in ranks:
if rank < 0 or rank >= global_world_size:
raise RuntimeError(
raise ValueError(
"The new group's rank should be within the "
"the world_size set by init_process_group"
)

View File

@ -713,7 +713,7 @@ class DistributedTest:
self.assertEqual(dist.get_backend(group_id), backend_str)
else:
with self.assertRaisesRegex(
RuntimeError, "Invalid process group specified"
ValueError, "Invalid process group specified"
):
dist.get_backend(group_id)
@ -970,7 +970,7 @@ class DistributedTest:
group, group_id, rank = self._init_global_test()
with self.assertRaisesRegex(
RuntimeError,
ValueError,
"The new group's rank should be within the the world_size set by init_process_group",
):
dist.new_subgroups_by_enumeration(
@ -1492,7 +1492,7 @@ class DistributedTest:
if rank == 0:
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
device_id = rank_to_GPU[rank][0]
with self.assertRaisesRegex(RuntimeError, "^Invalid ``op``"):
with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
send_tensor = _build_tensor(rank + 1, device_id=device_id)
send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
dist.batch_isend_irecv([send_op])
@ -1504,7 +1504,7 @@ class DistributedTest:
self._barrier()
rank = dist.get_rank()
if rank == 0:
with self.assertRaisesRegex(RuntimeError, "^Invalid ``p2p_op_list``"):
with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
dist.batch_isend_irecv([1, 2])
# NCCL Batch SEND RECV Mixed Backend Error
@ -1519,7 +1519,7 @@ class DistributedTest:
group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
if rank == 0:
with self.assertRaisesRegex(
RuntimeError, "All ops need to use the same group"
ValueError, "All ops need to use the same group"
):
send_tensor = _build_tensor(rank + 1)
send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
@ -2772,7 +2772,7 @@ class DistributedTest:
group, group_id, rank = self._init_global_test()
for unsupported_op in unsupported_ops:
with self.assertRaisesRegex(
RuntimeError, "all_reduce does not support"
ValueError, "all_reduce does not support"
):
dist.all_reduce(
_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
@ -3004,7 +3004,7 @@ class DistributedTest:
)
def test_all_reduce_coalesced_max_complex_unsupported(self):
group, group_id, rank = self._init_global_test()
with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
dist.all_reduce_coalesced(
[_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
)
@ -8306,7 +8306,7 @@ class DistributedTest:
)
# Ensure errors are raised upon incorrect arguments.
with self.assertRaisesRegex(
RuntimeError,
ValueError,
"Expected argument scatter_object_output_list to be a list of size at least 1.",
):
dist.scatter_object_list([], scatter_list, src=src_rank)