Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

[RESUBMIT] Standardize on error types for distributed errors. (#108191)
We have a plethora of error types for the various errors raised from c10d, including `RuntimeError`, `TimeoutError`, `SocketError`, and `DistBackendError`. This results in messy error-handling code along these lines:

```
if "NCCL" in exception_str:
    ...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
    ...
if "The client socket has timed out after" in exception_str:
    ...
if "Broken pipe" in exception_str:
    ...
if "Connection reset by peer" in exception_str:
    ...
```

To address this issue, this PR establishes the following error types:

1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and refers to process group (PG) backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang
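For illustration only (the sketch below is not part of the commit, and the handler bodies are placeholders), error handling can then dispatch on exception types rather than on substrings of the message:

```
import torch.distributed as dist

def all_reduce_with_handling(tensor):
    try:
        dist.all_reduce(tensor)
    except dist.DistNetworkError:
        ...  # socket-level failures, e.g. "Connection reset by peer"
    except dist.DistStoreError:
        ...  # store failures, e.g. a TCPStore timeout during a barrier
    except dist.DistBackendError:
        ...  # backend-specific failures, e.g. NCCL errors
    except dist.DistError:
        ...  # any other distributed error (DistError is the common base)
```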
This commit is contained in:
parent 6dacf52f88, commit 704b0b3c67

@@ -534,7 +534,6 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/comm.cpp",
    "torch/csrc/distributed/c10d/debug.cpp",
    "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
    "torch/csrc/distributed/c10d/exception.cpp",
    "torch/csrc/distributed/c10d/logger.cpp",
    "torch/csrc/distributed/c10d/logging.cpp",
    "torch/csrc/distributed/c10d/quantization/quantization.cpp",

@@ -272,10 +272,28 @@ class C10_API OutOfMemoryError : public Error {
  using Error::Error;
};

// Base error type for all distributed errors.
// These turn into DistError when they cross into Python.
class C10_API DistError : public Error {
  using Error::Error;
};

// Used for collective communication library errors from the distributed module.
// These turn into DistBackendError when they cross into Python.
class C10_API DistBackendError : public Error {
  using Error::Error;
class C10_API DistBackendError : public DistError {
  using DistError::DistError;
};

// Used for errors originating from the store.
// These turn into DistStoreError when they cross into Python.
class C10_API DistStoreError : public DistError {
  using DistError::DistError;
};

// Used for errors originating from the TCP/IP stack and not from collective
// libraries. These turn into DistNetworkError when they cross into Python.
class C10_API DistNetworkError : public DistError {
  using DistError::DistError;
};

// A utility function to return an exception std::string by prepending its

@@ -842,10 +842,20 @@ following matrix shows how the log level can be adjusted via the combination of
| ``INFO`` | ``DETAIL`` | Trace (a.k.a. All) |
+-------------------------+-----------------------------+------------------------+

Distributed has a custom Exception type derived from `RuntimeError` called `torch.distributed.DistBackendError`. This exception is thrown when a backend-specific error occurs. For example, if
the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library.
Distributed has custom Exception types derived from `RuntimeError`:

- `torch.distributed.DistError`: This is the base type of all distributed exceptions.
- `torch.distributed.DistBackendError`: This exception is thrown when a backend-specific error occurs, such as when
  the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library.
- `torch.distributed.DistNetworkError`: This exception is thrown when networking
  libraries encounter errors (e.g., connection reset by peer).
- `torch.distributed.DistStoreError`: This exception is thrown when the Store encounters
  an error (e.g., a TCPStore timeout).

.. autoclass:: torch.distributed.DistError
.. autoclass:: torch.distributed.DistBackendError
.. autoclass:: torch.distributed.DistNetworkError
.. autoclass:: torch.distributed.DistStoreError

.. warning::
    The DistBackendError exception type is an experimental feature and is subject to change.
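As a usage sketch of the documented hierarchy (illustrative only, not part of the diff; it mirrors the rendezvous DNS-timeout test updated later in this change):

```
from datetime import timedelta
import torch.distributed as dist

try:
    # Rendezvous against a host that does not resolve; the client socket
    # eventually times out.
    gen = dist.rendezvous(
        "tcp://dnsnotexist:23456?world_size=2&rank=0",
        timeout=timedelta(seconds=1),
    )
    next(gen)
except dist.DistNetworkError as exc:
    # Network-level failures surface as DistNetworkError, which is also an
    # instance of the DistError base class (and, transitively, RuntimeError).
    assert isinstance(exc, dist.DistError)
    assert isinstance(exc, RuntimeError)
```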
@@ -39,7 +39,7 @@ void testGetSet(std::string prefix = "") {
  EXPECT_FALSE(delFailure);
  auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
  store.setTimeout(timeout);
  EXPECT_THROW(store.get("key0"), std::runtime_error);
  EXPECT_THROW(store.get("key0"), c10::DistStoreError);
}

// get() waits up to timeout_.

@@ -195,7 +195,7 @@ TEST(TCPStoreTest, testCleanShutdown) {
  clientTCPStore->get("key");

  auto clientThread = std::thread([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), std::system_error);
    EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
  });

  // start server shutdown during a client request

@ -14,6 +14,7 @@ import sys
|
|||
import unittest
|
||||
from contextlib import closing
|
||||
|
||||
from torch.distributed import DistNetworkError
|
||||
from torch.distributed.elastic.utils.distributed import (
|
||||
create_c10d_store,
|
||||
get_socket_with_port,
|
||||
|
|
@ -108,7 +109,7 @@ class DistributedUtilTest(TestCase):
|
|||
)
|
||||
|
||||
def test_create_store_timeout_on_worker(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
with self.assertRaises(DistNetworkError):
|
||||
# use any available port (port 0) since timeout is expected
|
||||
create_c10d_store(
|
||||
is_server=False,
|
||||
|
|
@ -142,7 +143,7 @@ class DistributedUtilTest(TestCase):
|
|||
port = sock.getsockname()[1]
|
||||
# on the worker port conflict shouldn't matter, it should just timeout
|
||||
# since we never created a server
|
||||
with self.assertRaises(TimeoutError):
|
||||
with self.assertRaises(DistNetworkError):
|
||||
create_c10d_store(
|
||||
is_server=False,
|
||||
server_addr=socket.gethostname(),
|
||||
|
|
|
|||
|
|
@ -155,7 +155,8 @@ class TimeoutTest(TestCase):
|
|||
timeout=timeout,
|
||||
logging_interval=timeout / 2
|
||||
)
|
||||
except RuntimeError as e:
|
||||
except torch.distributed.DistStoreError as e:
|
||||
self.assertTrue(isinstance(e, torch.distributed.DistError))
|
||||
error_list.append(e)
|
||||
|
||||
world_size = 4
|
||||
|
|
@ -1248,15 +1249,15 @@ class AbstractCommTest:
|
|||
|
||||
group = dist.new_group(ranks=[1])
|
||||
self.assertEqual(dist.get_group_rank(group, 1), 0)
|
||||
with self.assertRaisesRegex(RuntimeError, "not part of group"):
|
||||
with self.assertRaisesRegex(ValueError, "not part of group"):
|
||||
dist.get_group_rank(group, 0)
|
||||
with self.assertRaisesRegex(RuntimeError, "not registered"):
|
||||
with self.assertRaisesRegex(ValueError, "not registered"):
|
||||
dist.get_group_rank(DummyProcessGroup(self.rank, self.world_size), 0)
|
||||
|
||||
self.assertEqual(dist.get_global_rank(group, 0), 1)
|
||||
with self.assertRaisesRegex(RuntimeError, "not part of group"):
|
||||
with self.assertRaisesRegex(ValueError, "not part of group"):
|
||||
dist.get_global_rank(group, 1)
|
||||
with self.assertRaisesRegex(RuntimeError, "not registered"):
|
||||
with self.assertRaisesRegex(ValueError, "not registered"):
|
||||
dist.get_global_rank(DummyProcessGroup(self.rank, self.world_size), 0)
|
||||
|
||||
self.assertEqual(dist.get_process_group_ranks(group), [1])
|
||||
|
|
@ -1277,44 +1278,44 @@ class AbstractCommTest:
|
|||
tensor_list_h[1] = tensor_list_h[1].half()
|
||||
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_gather(tensor_list_h, tensor)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_gather(tensor_list, tensor_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_gather_coalesced([tensor_list_h], tensor_list)
|
||||
dist.all_gather_coalesced([tensor_list], tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_reduce_coalesced(tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.reduce_scatter(tensor, tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.reduce_scatter(tensor_h, tensor_list)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_to_all_single(tensor_h, tensor)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_to_all(tensor_list_h, tensor_list)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.all_to_all(tensor_list, tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.scatter(tensor, tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.gather(tensor_h, tensor_list)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.gather(tensor, tensor_list_h)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
|
||||
with self.assertRaisesRegex(ValueError, "tensors with different dtypes"):
|
||||
dist.scatter(tensor_h, tensor_list)
|
||||
|
||||
def _test_tensor_dtype_complex(self, backend):
|
||||
|
|
@ -1506,7 +1507,7 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):
|
|||
|
||||
for mode in invalid_debug_modes:
|
||||
os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
|
||||
with self.assertRaisesRegex(RuntimeError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
|
||||
with self.assertRaisesRegex(ValueError, "The value of TORCH_DISTRIBUTED_DEBUG must"):
|
||||
dist.set_debug_level_from_env()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1203,7 +1203,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
|
|||
# Output is not a list of lists.
|
||||
dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Invalid function argument.*output_tensor_lists"
|
||||
TypeError, "Invalid function argument.*output_tensor_lists"
|
||||
):
|
||||
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
|
||||
|
||||
|
|
|
|||
|
|
@ -1100,6 +1100,7 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
|||
# Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage
|
||||
with self.assertRaises(dist.DistBackendError) as cm:
|
||||
dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)
|
||||
self.assertTrue(isinstance(cm.exception, dist.DistError))
|
||||
|
||||
self.assertIsInstance(cm.exception, RuntimeError)
|
||||
|
||||
|
|
@ -3064,7 +3065,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
|
|||
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
|
||||
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
|
||||
c10d.barrier(device_ids=self.rank)
|
||||
|
||||
@requires_nccl()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from sys import platform
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch.distributed import DistNetworkError, DistError
|
||||
from torch.testing._internal.common_distributed import MultiThreadedTestCase
|
||||
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
|
||||
|
||||
|
|
@ -558,12 +559,13 @@ class RendezvousTCPTest(TestCase):
|
|||
next(gen)
|
||||
|
||||
def test_dns_timeout(self):
|
||||
with self.assertRaisesRegex(TimeoutError, "client socket has timed out after.*dnsnotexist"):
|
||||
with self.assertRaisesRegex(DistNetworkError, "client socket has timed out after.*dnsnotexist") as manager:
|
||||
gen = dist.rendezvous(
|
||||
"tcp://dnsnotexist:23456?world_size=2&rank=0",
|
||||
timeout=timedelta(seconds=1),
|
||||
)
|
||||
next(gen)
|
||||
self.assertTrue(isinstance(manager.exception, DistError))
|
||||
|
||||
@retry_on_connect_failures
|
||||
def test_nominal(self):
|
||||
|
|
|
|||
|
|
@ -1926,7 +1926,10 @@ def _current_autograd_node() -> _Node: ...
|
|||
|
||||
# Defined in torch/csrc/Exceptions.cpp
|
||||
class _OutOfMemoryError(RuntimeError): ...
|
||||
class _DistError(RuntimeError): ...
|
||||
class _DistBackendError(RuntimeError): ...
|
||||
class _DistStoreError(RuntimeError): ...
|
||||
class _DistNetworkError(RuntimeError): ...
|
||||
|
||||
# Defined in torch/csrc/profiler/init.cpp
|
||||
class CapturedTraceback:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@
|
|||
#include <c10/util/StringUtil.h>
|
||||
|
||||
PyObject *THPException_FatalError, *THPException_LinAlgError,
|
||||
*THPException_OutOfMemoryError, *THPException_DistBackendError;
|
||||
*THPException_OutOfMemoryError, *THPException_DistError,
|
||||
*THPException_DistBackendError, *THPException_DistNetworkError,
|
||||
*THPException_DistStoreError;
|
||||
|
||||
#define ASSERT_TRUE(cond) \
|
||||
if (!(cond)) \
|
||||
|
|
@ -62,16 +64,45 @@ could not be completed because the input matrix is singular.",
|
|||
PyModule_AddObject(
|
||||
module, "_OutOfMemoryError", THPException_OutOfMemoryError) == 0);
|
||||
|
||||
ASSERT_TRUE(
|
||||
THPException_DistError = PyErr_NewExceptionWithDoc(
|
||||
"torch.distributed.DistError",
|
||||
"Exception raised when an error occurs in the distributed library",
|
||||
PyExc_RuntimeError,
|
||||
nullptr));
|
||||
ASSERT_TRUE(
|
||||
PyModule_AddObject(module, "_DistError", THPException_DistError) == 0);
|
||||
|
||||
ASSERT_TRUE(
|
||||
THPException_DistBackendError = PyErr_NewExceptionWithDoc(
|
||||
"torch.distributed.DistBackendError",
|
||||
"Exception raised when a backend error occurs in distributed",
|
||||
PyExc_RuntimeError,
|
||||
THPException_DistError,
|
||||
nullptr));
|
||||
ASSERT_TRUE(
|
||||
PyModule_AddObject(
|
||||
module, "_DistBackendError", THPException_DistBackendError) == 0);
|
||||
|
||||
ASSERT_TRUE(
|
||||
THPException_DistNetworkError = PyErr_NewExceptionWithDoc(
|
||||
"torch.distributed.DistNetworkError",
|
||||
"Exception raised when a network error occurs in distributed",
|
||||
THPException_DistError,
|
||||
nullptr));
|
||||
ASSERT_TRUE(
|
||||
PyModule_AddObject(
|
||||
module, "_DistNetworkError", THPException_DistNetworkError) == 0);
|
||||
|
||||
ASSERT_TRUE(
|
||||
THPException_DistStoreError = PyErr_NewExceptionWithDoc(
|
||||
"torch.distributed.DistStoreError",
|
||||
"Exception raised when an error occurs in the distributed store",
|
||||
THPException_DistError,
|
||||
nullptr));
|
||||
ASSERT_TRUE(
|
||||
PyModule_AddObject(
|
||||
module, "_DistStoreError", THPException_DistStoreError) == 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -63,51 +63,37 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
|||
}
|
||||
|
||||
// Only catch torch-specific exceptions
|
||||
#define CATCH_CORE_ERRORS(retstmnt) \
|
||||
catch (python_error & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
} \
|
||||
_CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
DistBackendError, THPException_DistBackendError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
|
||||
catch (torch::PyTorchError & e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(e.python_type(), msg); \
|
||||
retstmnt; \
|
||||
#define CATCH_CORE_ERRORS(retstmnt) \
|
||||
catch (python_error & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
} \
|
||||
_CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
DistBackendError, THPException_DistBackendError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
DistNetworkError, THPException_DistNetworkError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
|
||||
catch (torch::PyTorchError & e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(e.python_type(), msg); \
|
||||
retstmnt; \
|
||||
}
|
||||
|
||||
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
|
||||
#define CATCH_C10D_ERRORS(retstmnt) \
|
||||
catch (const c10d::TimeoutError& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(PyExc_TimeoutError, msg); \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (const c10d::C10dError& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(PyExc_RuntimeError, msg); \
|
||||
retstmnt; \
|
||||
}
|
||||
#else
|
||||
#define CATCH_C10D_ERRORS(retstmnt)
|
||||
#endif
|
||||
|
||||
#define CATCH_TH_ERRORS(retstmnt) \
|
||||
CATCH_CORE_ERRORS(retstmnt) \
|
||||
CATCH_C10D_ERRORS(retstmnt)
|
||||
#define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
|
||||
|
||||
#define CATCH_ALL_ERRORS(retstmnt) \
|
||||
CATCH_TH_ERRORS(retstmnt) \
|
||||
|
|
@ -153,7 +139,9 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
|||
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
|
||||
|
||||
extern PyObject *THPException_FatalError, *THPException_LinAlgError,
|
||||
*THPException_OutOfMemoryError, *THPException_DistBackendError;
|
||||
*THPException_OutOfMemoryError, *THPException_DistError,
|
||||
*THPException_DistBackendError, *THPException_DistNetworkError,
|
||||
*THPException_DistStoreError;
|
||||
|
||||
// Throwing this exception means that the python error flags have been already
|
||||
// set and control should be immediately returned to the interpreter.
|
||||
|
|
|
|||
|
|
@ -27,9 +27,9 @@
|
|||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#define SYSASSERT(rv, ...) \
|
||||
if ((rv) < 0) { \
|
||||
throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
|
||||
#define SYSASSERT(rv, ...) \
|
||||
if ((rv) < 0) { \
|
||||
C10_THROW_ERROR(DistStoreError, std::strerror(errno)); \
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
|
|
|||
|
|
@ -51,8 +51,7 @@ std::vector<uint8_t> HashStore::get(const std::string& key) {
|
|||
cv_.wait(lock, pred);
|
||||
} else {
|
||||
if (!cv_.wait_for(lock, timeout_, pred)) {
|
||||
throw std::system_error(
|
||||
ETIMEDOUT, std::system_category(), "Wait timeout");
|
||||
C10_THROW_ERROR(DistStoreError, "Wait timeout");
|
||||
}
|
||||
}
|
||||
return map_[key];
|
||||
|
|
@ -78,8 +77,7 @@ void HashStore::wait(
|
|||
cv_.wait(lock, pred);
|
||||
} else {
|
||||
if (!cv_.wait_until(lock, end, pred)) {
|
||||
throw std::system_error(
|
||||
ETIMEDOUT, std::system_category(), "Wait timeout");
|
||||
C10_THROW_ERROR(DistStoreError, "Wait timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -151,8 +149,7 @@ std::vector<std::vector<uint8_t>> HashStore::multiGet(
|
|||
cv_.wait(lock, pred);
|
||||
} else {
|
||||
if (!cv_.wait_until(lock, deadline, pred)) {
|
||||
throw std::system_error(
|
||||
ETIMEDOUT, std::system_category(), "Wait timeout");
|
||||
C10_THROW_ERROR(DistStoreError, "Wait timeout");
|
||||
}
|
||||
}
|
||||
res.emplace_back(map_[key]);
|
||||
|
|
|
|||
|
|
@ -683,7 +683,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
|
|||
std::array<char, HOST_NAME_MAX> hostname{};
|
||||
auto rv = gethostname(hostname.data(), HOST_NAME_MAX);
|
||||
if (rv != 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
C10_THROW_ERROR(DistBackendError, std::strerror(errno));
|
||||
}
|
||||
|
||||
// Use this machine's hostname if it resolves to an address.
|
||||
|
|
@ -710,7 +710,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
|
|||
auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]);
|
||||
auto rv = gethostname(hostname.get(), hostNameMax);
|
||||
if (rv != 0) {
|
||||
throw std::system_error(errno, std::system_category());
|
||||
C10_THROW_ERROR(DistBackendError, std::strerror(errno));
|
||||
}
|
||||
|
||||
// Use this machine's hostname if it resolves to an address.
|
||||
|
|
|
|||
|
|
@ -186,22 +186,22 @@ void read_config() {
|
|||
|
||||
void check_device(c10::Device dev1, c10::Device dev2) {
|
||||
if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
|
||||
throw std::runtime_error("ProcessGroupUCC multidevice is not supported");
|
||||
throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
void check_tensor(const std::vector<at::Tensor>& tensors) {
|
||||
if (tensors.size() != 1) {
|
||||
throw std::runtime_error(
|
||||
throw std::invalid_argument(
|
||||
"ProcessGroupUCC takes 1 tensor. Got " +
|
||||
std::to_string(tensors.size()) + ". ");
|
||||
}
|
||||
if (!tensors[0].is_contiguous()) {
|
||||
throw std::runtime_error(
|
||||
throw std::invalid_argument(
|
||||
"ProcessGroupUCC input tensor has to be contiguous");
|
||||
}
|
||||
if (tensors[0].is_sparse()) {
|
||||
throw std::runtime_error("ProcessGroupUCC input tensor has to be dense");
|
||||
throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
|
||||
}
|
||||
// TODO: check cuda case
|
||||
}
|
||||
|
|
@ -401,7 +401,7 @@ std::shared_ptr<Comm> Comm::get_comm(
|
|||
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
|
||||
"ucc communicator was initialized with different cuda device,"
|
||||
"multi device is not supported");
|
||||
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
}
|
||||
shared_comm->cuda_device_index = dev.index();
|
||||
}
|
||||
|
|
@ -825,7 +825,7 @@ c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
|
|||
default: {
|
||||
TORCH_UCC_LOG_ERROR(
|
||||
TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
|
||||
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1009,7 +1009,7 @@ c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
|
|||
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
|
||||
std::vector<at::Tensor>& /* unused */,
|
||||
const AllreduceCoalescedOptions& /* unused */) {
|
||||
throw std::runtime_error(
|
||||
throw std::invalid_argument(
|
||||
"ProcessGroupUCC does not support allreduce_coalesced");
|
||||
}
|
||||
|
||||
|
|
@ -1610,7 +1610,7 @@ void ProcessGroupUCC::initComm(c10::Device dev) {
|
|||
TORCH_UCC_INIT,
|
||||
"ucc communicator was initialized with different cuda device,"
|
||||
"multi device is not supported");
|
||||
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
|
||||
}
|
||||
comm->cuda_device_index = dev.index();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -442,7 +442,7 @@ void TCPStore::doWait(
|
|||
TORCH_CHECK(false, "wait_canceled response is expected");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(false, "Socket Timeout");
|
||||
C10_THROW_ERROR(DistStoreError, "Socket Timeout");
|
||||
}
|
||||
|
||||
void TCPStore::append(
|
||||
|
|
|
|||
|
|
@ -473,9 +473,8 @@ void TCPStoreMasterDaemon::run() {
|
|||
// accept new connections.
|
||||
if (fds[0].revents != 0) {
|
||||
if (!(fds[0].revents & POLLIN)) {
|
||||
throw std::system_error(
|
||||
ECONNABORTED,
|
||||
std::system_category(),
|
||||
C10_THROW_ERROR(
|
||||
DistStoreError,
|
||||
"Unexpected poll revent on the master's listening socket: " +
|
||||
std::to_string(fds[0].revents));
|
||||
}
|
||||
|
|
@ -515,9 +514,8 @@ void TCPStoreMasterDaemon::run() {
|
|||
// accept new connections.
|
||||
if (fds[0].revents != 0) {
|
||||
if (fds[0].revents ^ POLLIN) {
|
||||
throw std::system_error(
|
||||
ECONNABORTED,
|
||||
std::system_category(),
|
||||
C10_THROW_ERROR(
|
||||
DistStoreError,
|
||||
"Unexpected poll revent on the master's listening socket: " +
|
||||
std::to_string(fds[0].revents));
|
||||
}
|
||||
|
|
@ -532,9 +530,8 @@ void TCPStoreMasterDaemon::run() {
|
|||
// The main thread will write a byte to the pipe then close it before
|
||||
// joining the background thread
|
||||
if (fds[1].revents & ~(POLLIN | POLLHUP)) {
|
||||
throw std::system_error(
|
||||
ECONNABORTED,
|
||||
std::system_category(),
|
||||
C10_THROW_ERROR(
|
||||
DistStoreError,
|
||||
"Unexpected poll revent on the control pipe's reading fd: " +
|
||||
std::to_string(fds[1].revents));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/distributed/c10d/Types.hpp>
|
||||
|
|
@ -521,30 +522,30 @@ using SizeType = uint64_t;
|
|||
continue; \
|
||||
} else if ( \
|
||||
errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
|
||||
TORCH_CHECK(false, "Socket Timeout"); \
|
||||
C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
|
||||
} else { \
|
||||
throw std::system_error(errno_local, std::system_category()); \
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(errno_local)); \
|
||||
} \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define SYSCHECK(expr, success_cond) \
|
||||
while (true) { \
|
||||
auto __output = (expr); \
|
||||
(void)__output; \
|
||||
if (!(success_cond)) { \
|
||||
if (errno == EINTR) { \
|
||||
continue; \
|
||||
} else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
|
||||
TORCH_CHECK(false, "Socket Timeout"); \
|
||||
} else { \
|
||||
throw std::system_error(errno, std::system_category()); \
|
||||
} \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
#define SYSCHECK(expr, success_cond) \
|
||||
while (true) { \
|
||||
auto __output = (expr); \
|
||||
(void)__output; \
|
||||
if (!(success_cond)) { \
|
||||
if (errno == EINTR) { \
|
||||
continue; \
|
||||
} else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
|
||||
C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
|
||||
} else { \
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(errno)); \
|
||||
} \
|
||||
} else { \
|
||||
break; \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -589,7 +590,7 @@ void sendBytes(
|
|||
bytesSent =
|
||||
::send(socket, (const char*)currentBytes, bytesToSend, flags))
|
||||
if (bytesSent == 0) {
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
|
||||
}
|
||||
|
||||
bytesToSend -= bytesSent;
|
||||
|
|
@ -612,7 +613,7 @@ void recvBytes(int socket, T* buffer, size_t length) {
|
|||
SYSCHECK_ERR_RETURN_NEG1(
|
||||
bytesReceived = recv(socket, (char*)currentBytes, bytesToReceive, 0))
|
||||
if (bytesReceived == 0) {
|
||||
throw std::system_error(ECONNRESET, std::system_category());
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
|
||||
}
|
||||
|
||||
bytesToReceive -= bytesReceived;
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ DebugLevel loadDebugLevelFromEnvironment() {
|
|||
} else if (level_str == "DETAIL") {
|
||||
level = DebugLevel::Detail;
|
||||
} else {
|
||||
throw C10dError{
|
||||
"The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL."};
|
||||
throw std::invalid_argument(
|
||||
"The value of TORCH_DISTRIBUTED_DEBUG must be OFF, INFO, or DETAIL.");
|
||||
}
|
||||
|
||||
C10D_INFO("The debug level is set to {}.", level_str);
|
||||
|
|
|
|||
|
|
@ -1,9 +0,0 @@
|
|||
#include <torch/csrc/distributed/c10d/exception.h>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
C10dError::~C10dError() = default;
|
||||
|
||||
TimeoutError::~TimeoutError() = default;
|
||||
|
||||
} // namespace c10d
|
||||
|
|
@ -9,37 +9,25 @@
|
|||
#include <stdexcept>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
// Utility macro similar to C10_THROW_ERROR, the major difference is that this
|
||||
// macro handles exception types defined in the c10d namespace, whereas
|
||||
// C10_THROW_ERROR requires an exception to be defined in the c10 namespace.
|
||||
#define C10D_THROW_ERROR(err_type, msg) \
|
||||
throw ::c10d::err_type( \
|
||||
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, msg)
|
||||
|
||||
namespace c10d {
|
||||
|
||||
class TORCH_API C10dError : public std::runtime_error {
|
||||
public:
|
||||
using std::runtime_error::runtime_error;
|
||||
using c10::DistNetworkError;
|
||||
|
||||
C10dError(const C10dError&) = default;
|
||||
|
||||
C10dError& operator=(const C10dError&) = default;
|
||||
|
||||
C10dError(C10dError&&) = default;
|
||||
|
||||
C10dError& operator=(C10dError&&) = default;
|
||||
|
||||
~C10dError() override;
|
||||
class TORCH_API SocketError : public DistNetworkError {
|
||||
using DistNetworkError::DistNetworkError;
|
||||
};
|
||||
|
||||
class TORCH_API TimeoutError : public C10dError {
|
||||
public:
|
||||
using C10dError::C10dError;
|
||||
|
||||
TimeoutError(const TimeoutError&) = default;
|
||||
|
||||
TimeoutError& operator=(const TimeoutError&) = default;
|
||||
|
||||
TimeoutError(TimeoutError&&) = default;
|
||||
|
||||
TimeoutError& operator=(TimeoutError&&) = default;
|
||||
|
||||
~TimeoutError() override;
|
||||
class TORCH_API TimeoutError : public DistNetworkError {
|
||||
using DistNetworkError::DistNetworkError;
|
||||
};
|
||||
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ void delay(std::chrono::seconds d) {
|
|||
// We don't care about error conditions other than EINTR since a failure
|
||||
// here is not critical.
|
||||
if (err == std::errc::interrupted) {
|
||||
throw std::system_error{err};
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
@ -271,7 +271,7 @@ std::unique_ptr<SocketImpl> SocketImpl::accept() const {
|
|||
if (hnd == invalid_socket) {
|
||||
std::error_code err = getSocketError();
|
||||
if (err == std::errc::interrupted) {
|
||||
throw std::system_error{err};
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
|
||||
}
|
||||
|
||||
std::string msg{};
|
||||
|
|
@ -287,7 +287,7 @@ std::unique_ptr<SocketImpl> SocketImpl::accept() const {
|
|||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{msg};
|
||||
C10D_THROW_ERROR(SocketError, msg);
|
||||
}
|
||||
|
||||
::addrinfo addr{};
|
||||
|
|
@ -333,7 +333,8 @@ void SocketImpl::enableNonBlocking() {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
throw SocketError{"The socket cannot be switched to non-blocking mode."};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, "The socket cannot be switched to non-blocking mode.");
|
||||
}
|
||||
|
||||
// TODO: Remove once we migrate everything to non-blocking mode.
|
||||
|
|
@ -351,7 +352,8 @@ void SocketImpl::disableNonBlocking() {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
throw SocketError{"The socket cannot be switched to blocking mode."};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, "The socket cannot be switched to blocking mode.");
|
||||
}
|
||||
|
||||
bool SocketImpl::enableNoDelay() noexcept {
|
||||
|
|
@ -381,7 +383,8 @@ std::uint16_t SocketImpl::getPort() const {
|
|||
|
||||
if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) !=
|
||||
0) {
|
||||
throw SocketError{"The port number of the socket cannot be retrieved."};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, "The port number of the socket cannot be retrieved.");
|
||||
}
|
||||
|
||||
if (addr_s.ss_family == AF_INET) {
|
||||
|
|
@ -471,7 +474,8 @@ std::unique_ptr<SocketImpl> SocketListenOp::run() {
|
|||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
|
||||
}
|
||||
|
||||
bool SocketListenOp::tryListen(int family) {
|
||||
|
|
@ -599,26 +603,31 @@ std::unique_ptr<SocketImpl> SocketListenFromFdOp::run() const {
|
|||
::socklen_t addr_len = sizeof(addr_storage);
|
||||
if (::getsockname(
|
||||
fd_, reinterpret_cast<::sockaddr*>(&addr_storage), &addr_len) < 0) {
|
||||
throw SocketError{
|
||||
fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError())};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError,
|
||||
fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError()));
|
||||
}
|
||||
|
||||
auto socket = std::make_unique<SocketImpl>(fd_);
|
||||
const auto port = socket->getPort();
|
||||
|
||||
if (port != expected_port_) {
|
||||
throw SocketError{fmt::format(
|
||||
"listen fd {} is bound to port {}, expected to be bound to port {}",
|
||||
fd_,
|
||||
port,
|
||||
expected_port_)};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError,
|
||||
fmt::format(
|
||||
"listen fd {} is bound to port {}, expected to be bound to port {}",
|
||||
fd_,
|
||||
port,
|
||||
expected_port_));
|
||||
}
|
||||
|
||||
if (::listen(socket->handle(), -1 /* backlog */) != 0) {
|
||||
throw SocketError{fmt::format(
|
||||
"Failed to listen on socket initialized from fd {}: {}.",
|
||||
socket->handle(),
|
||||
getSocketError())};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError,
|
||||
fmt::format(
|
||||
"Failed to listen on socket initialized from fd {}: {}.",
|
||||
socket->handle(),
|
||||
getSocketError()));
|
||||
}
|
||||
|
||||
socket->closeOnExec();
|
||||
|
|
@ -718,7 +727,8 @@ std::unique_ptr<SocketImpl> SocketConnectOp::run() {
|
|||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
|
||||
}
|
||||
|
||||
bool SocketConnectOp::tryConnect(int family) {
|
||||
|
|
@ -821,7 +831,7 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(
|
|||
if (cr == ConnectResult::Error) {
|
||||
std::error_code err = getSocketError();
|
||||
if (err == std::errc::interrupted) {
|
||||
throw std::system_error{err};
|
||||
C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
|
||||
}
|
||||
|
||||
// Retry if the server is not yet listening or if its backlog is exhausted.
|
||||
|
|
@ -921,7 +931,7 @@ void SocketConnectOp::throwTimeoutError() const {
|
|||
|
||||
C10D_ERROR(msg);
|
||||
|
||||
throw TimeoutError{msg};
|
||||
C10D_THROW_ERROR(TimeoutError, msg);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
@ -935,7 +945,8 @@ void Socket::initialize() {
|
|||
c10::call_once(init_flag, []() {
|
||||
WSADATA data{};
|
||||
if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) {
|
||||
throw SocketError{"The initialization of Winsock has failed."};
|
||||
C10D_THROW_ERROR(
|
||||
SocketError, "The initialization of Winsock has failed.");
|
||||
}
|
||||
});
|
||||
#endif
|
||||
|
|
@ -973,7 +984,7 @@ Socket Socket::accept() const {
|
|||
return Socket{impl_->accept()};
|
||||
}
|
||||
|
||||
throw SocketError{"The socket is not initialized."};
|
||||
C10D_THROW_ERROR(SocketError, "The socket is not initialized.");
|
||||
}
|
||||
|
||||
int Socket::handle() const noexcept {
|
||||
|
|
@ -999,6 +1010,4 @@ bool Socket::waitForInput(std::chrono::milliseconds timeout) {
|
|||
|
||||
} // namespace detail
|
||||
|
||||
SocketError::~SocketError() = default;
|
||||
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <string>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/distributed/c10d/exception.h>
|
||||
|
||||
namespace c10d {
|
||||
|
|
@ -89,19 +90,4 @@ class Socket {
|
|||
|
||||
} // namespace detail
|
||||
|
||||
class TORCH_API SocketError : public C10dError {
|
||||
public:
|
||||
using C10dError::C10dError;
|
||||
|
||||
SocketError(const SocketError&) = default;
|
||||
|
||||
SocketError& operator=(const SocketError&) = default;
|
||||
|
||||
SocketError(SocketError&&) = default;
|
||||
|
||||
SocketError& operator=(SocketError&&) = default;
|
||||
|
||||
~SocketError() override;
|
||||
};
|
||||
|
||||
} // namespace c10d
|
||||
|
|
|
|||
|
|
@@ -20,7 +20,10 @@ if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistError = torch._C._DistError
DistBackendError = torch._C._DistBackendError
DistNetworkError = torch._C._DistNetworkError
DistStoreError = torch._C._DistStoreError

if is_available():
    from torch._C._distributed_c10d import (
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from torch._C._distributed_c10d import (
|
|||
from .constants import default_pg_timeout
|
||||
from .c10d_logger import _exception_logger, _time_logger
|
||||
from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
|
||||
DistStoreError = torch._C._DistStoreError
|
||||
|
||||
__all__ = [
|
||||
'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
|
||||
|
|
@ -675,7 +676,7 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log
|
|||
)
|
||||
|
||||
if timedelta(seconds=(time.time() - start)) > timeout:
|
||||
raise RuntimeError(
|
||||
raise DistStoreError(
|
||||
"Timed out initializing process group in store based barrier on "
|
||||
"rank {}, for key: {} (world_size={}, num_workers_joined={}, timeout={})".format(
|
||||
rank, store_key, world_size, worker_count, timeout
|
||||
|
|
@ -722,10 +723,10 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
|
|||
if group is GroupMember.WORLD:
|
||||
return global_rank
|
||||
if group not in _world.pg_group_ranks:
|
||||
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
|
||||
raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
|
||||
group_ranks = _world.pg_group_ranks[group]
|
||||
if global_rank not in group_ranks:
|
||||
raise RuntimeError(f"Global rank {global_rank} is not part of group {group}")
|
||||
raise ValueError(f"Global rank {global_rank} is not part of group {group}")
|
||||
|
||||
return group_ranks[global_rank]
|
||||
|
||||
|
|
@ -747,11 +748,11 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
|
|||
if group is GroupMember.WORLD:
|
||||
return group_rank
|
||||
if group not in _world.pg_group_ranks:
|
||||
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
|
||||
raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
|
||||
for rank, grp_rank in _world.pg_group_ranks[group].items():
|
||||
if grp_rank == group_rank:
|
||||
return rank
|
||||
raise RuntimeError(f"Group rank {group_rank} is not part of group {group}")
|
||||
raise ValueError(f"Group rank {group_rank} is not part of group {group}")
|
||||
|
||||
# TODO: remove this once the ecosystem moves away from it.
|
||||
def _get_global_rank(group, rank):
|
||||
|
|
@ -792,7 +793,7 @@ def _check_single_tensor(param, param_name):
|
|||
Helper to check that the parameter ``param_name`` is a single tensor.
|
||||
"""
|
||||
if not isinstance(param, torch.Tensor):
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
f"Invalid function argument. Expected parameter `{param_name}` to be of type torch.Tensor."
|
||||
)
|
||||
|
||||
|
|
@ -804,7 +805,7 @@ def _check_tensor_list(param, param_name):
|
|||
if not isinstance(param, list) or not all(
|
||||
isinstance(p, torch.Tensor) for p in param
|
||||
):
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
f"Invalid function argument. Expected parameter `{param_name}` to be of type List[torch.Tensor]."
|
||||
)
|
||||
|
||||
|
|
@ -823,7 +824,7 @@ def _ensure_all_tensors_same_dtype(*tensors) -> None:
|
|||
last_dtype = tensor_dtype
|
||||
else:
|
||||
if last_dtype != tensor_dtype:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Invalid usage of tensors with different dtypes"
|
||||
f"Found {last_dtype} and {tensor.dtype}"
|
||||
)
|
||||
|
|
@ -834,7 +835,7 @@ def _check_op(op):
|
|||
Helper to check that the ``op`` is either isend or irecv.
|
||||
"""
|
||||
if op not in [isend, irecv]:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Invalid ``op``. Expected ``op`` "
|
||||
"to be of type ``torch.distributed.isend`` or "
|
||||
"``torch.distributed.irecv``."
|
||||
|
|
@ -849,14 +850,14 @@ def _check_p2p_op_list(p2p_op_list):
|
|||
if not isinstance(p2p_op_list, list) or not all(
|
||||
isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
|
||||
):
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Invalid ``p2p_op_list``. Each op is expected to "
|
||||
"to be of type ``torch.distributed.P2POp``."
|
||||
)
|
||||
|
||||
group = p2p_op_list[0].group
|
||||
if not all(group == p2p_op.group for p2p_op in p2p_op_list):
|
||||
raise RuntimeError("All ops need to use the same group.")
|
||||
raise ValueError("All ops need to use the same group.")
|
||||
|
||||
|
||||
def is_mpi_available() -> bool:
|
||||
|
|
@ -937,7 +938,7 @@ def _get_default_group():
|
|||
Getting the default process group created by init_process_group
|
||||
"""
|
||||
if not is_initialized():
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Default process group has not been initialized, "
|
||||
"please make sure to call init_process_group."
|
||||
)
|
||||
|
|
@ -949,7 +950,7 @@ def _get_default_store():
|
|||
Getting the default store created by init_process_group
|
||||
"""
|
||||
if not is_initialized():
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Default process group has not been initialized, "
|
||||
"please make sure to call init_process_group."
|
||||
)
|
||||
|
|
@ -967,7 +968,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
|
|||
else:
|
||||
pg = group
|
||||
if _rank_not_in_group(pg):
|
||||
raise RuntimeError("Invalid process group specified")
|
||||
raise ValueError("Invalid process group specified")
|
||||
backend_config = _world.pg_backend_config.get(pg)
|
||||
assert backend_config is not None
|
||||
return str(backend_config)
|
||||
|
|
@ -990,7 +991,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str:
|
|||
else:
|
||||
pg = group
|
||||
if _rank_not_in_group(pg):
|
||||
raise RuntimeError("Invalid process group specified")
|
||||
raise ValueError("Invalid process group specified")
|
||||
pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
|
||||
assert pg_store is not None
|
||||
return pg_store[0]
|
||||
|
|
@ -1092,12 +1093,12 @@ def init_process_group(
|
|||
global _default_pg_init_method
|
||||
|
||||
if not isinstance(timeout, timedelta):
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
"Expected timeout argument to be of type datetime.timedelta"
|
||||
)
|
||||
|
||||
if GroupMember.WORLD is not None:
|
||||
raise RuntimeError("trying to initialize the default process group twice!")
|
||||
raise ValueError("trying to initialize the default process group twice!")
|
||||
|
||||
assert (store is None) or (
|
||||
init_method is None
|
||||
|
|
@ -1206,13 +1207,13 @@ def _new_process_group_helper(
|
|||
global _world
|
||||
|
||||
if group_name in _world.pg_names.values():
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"The specified group name has already been "
|
||||
"created, please use a different group name"
|
||||
)
|
||||
|
||||
if not isinstance(timeout, timedelta):
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
"Expected timeout argument to be of type datetime.timedelta"
|
||||
)
|
||||
|
||||
|
|
@ -1387,7 +1388,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
|
|||
|
||||
assert pg is not None
|
||||
if _world.pg_map.get(pg, None) is None:
|
||||
raise RuntimeError("Invalid process group specified")
|
||||
raise ValueError("Invalid process group specified")
|
||||
|
||||
# When users register Python onCompletion hooks, those hooks will run on a
|
||||
# different thread than the main thread. Today, the ProcessGroup dtor does
|
||||
|
|
@ -1647,7 +1648,7 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces
|
|||
class _IllegalWork(Work):
|
||||
def __getattribute__(self, name):
|
||||
if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]:
|
||||
raise RuntimeError(f"Illegal to call {name} on IllegalWork object")
|
||||
raise ValueError(f"Illegal to call {name} on IllegalWork object")
|
||||
|
||||
|
||||
class _CoalescingManager:
|
||||
|
|
@ -1699,7 +1700,7 @@ def _coalescing_manager(
|
|||
group = group or _get_default_group()
|
||||
op_list = _world.pg_coalesce_state.setdefault(group, [])
|
||||
if op_list:
|
||||
raise RuntimeError("ProcessGroup has non-empty op list at the start of coalescing")
|
||||
raise ValueError("ProcessGroup has non-empty op list at the start of coalescing")
|
||||
if device:
|
||||
group._start_coalescing(device)
|
||||
cm = _CoalescingManager()
|
||||
|
|
@ -2030,7 +2031,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
|||
|
||||
if tensor.is_complex():
|
||||
if not supports_complex(op):
|
||||
raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
|
||||
raise ValueError(f"all_reduce does not support {op} on complex tensors")
|
||||
tensor = torch.view_as_real(tensor)
|
||||
|
||||
opts = AllreduceOptions()
|
||||
|
|
@ -2100,7 +2101,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
|
|||
return
|
||||
|
||||
if any(t.is_complex() for t in tensors) and not supports_complex(op):
|
||||
raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
|
||||
raise ValueError(f"all_reduce does not support {op} on complex tensors")
|
||||
|
||||
tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
|
||||
|
||||
|
|
@ -2695,7 +2696,7 @@ def scatter_object_list(
|
|||
not isinstance(scatter_object_output_list, list)
|
||||
or len(scatter_object_output_list) < 1
|
||||
):
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"Expected argument scatter_object_output_list to be a list of size at least 1."
|
||||
)
|
||||
|
||||
|
|
@ -2992,7 +2993,7 @@ def all_gather_coalesced(
|
|||
_check_tensor_list(input_tensor_list, "input_tensor_list")
|
||||
_ensure_all_tensors_same_dtype(input_tensor_list)
|
||||
if not isinstance(output_tensor_lists, list):
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
"Invalid function argument: output_tensor_lists should be a list"
|
||||
)
|
||||
for output_tensor_list in output_tensor_lists:
|
||||
|
|
@ -3687,7 +3688,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
|
|||
if isinstance(device_ids, list):
|
||||
opts.device_ids = device_ids
|
||||
else:
|
||||
raise RuntimeError(
|
||||
raise TypeError(
|
||||
"Invalid function argument: device_ids type should be List[int]"
|
||||
)
|
||||
|
||||
|
|
@ -3760,7 +3761,7 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
|
|||
return
|
||||
|
||||
if get_backend(group) != Backend.GLOO:
|
||||
raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
|
||||
raise ValueError("monitored_barrier is only implemented for GLOO backend.")
|
||||
|
||||
if timeout is None:
|
||||
timeout = default_pg_timeout
|
||||
|
|
@ -3909,7 +3910,7 @@ def _new_group_with_tag(
|
|||
if use_local_synchronization:
|
||||
# MPI backend doesn't have have a way for us to perform a partial sync
|
||||
if backend == Backend.MPI:
|
||||
raise RuntimeError("MPI backend doesn't support use_local_synchronization=True")
|
||||
raise ValueError("MPI backend doesn't support use_local_synchronization=True")
|
||||
if ranks is not None and get_rank() not in ranks:
|
||||
return None
|
||||
|
||||
|
|
@ -3918,7 +3919,7 @@ def _new_group_with_tag(
|
|||
ranks = sorted(ranks)
|
||||
group_world_size = len(ranks)
|
||||
if group_world_size > global_world_size:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"the new group's world size should be less or "
|
||||
"equal to the world size set by "
|
||||
"init_process_group"
|
||||
|
|
@ -3926,7 +3927,7 @@ def _new_group_with_tag(
|
|||
# check ranks' sanity
|
||||
for rank in ranks:
|
||||
if rank < 0 or rank >= global_world_size:
|
||||
raise RuntimeError(
|
||||
raise ValueError(
|
||||
"The new group's rank should be within the "
|
||||
"the world_size set by init_process_group"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -713,7 +713,7 @@ class DistributedTest:
|
|||
self.assertEqual(dist.get_backend(group_id), backend_str)
|
||||
else:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Invalid process group specified"
|
||||
ValueError, "Invalid process group specified"
|
||||
):
|
||||
dist.get_backend(group_id)
|
||||
|
||||
|
|
@ -970,7 +970,7 @@ class DistributedTest:
|
|||
group, group_id, rank = self._init_global_test()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
"The new group's rank should be within the the world_size set by init_process_group",
|
||||
):
|
||||
dist.new_subgroups_by_enumeration(
|
||||
|
|
@ -1492,7 +1492,7 @@ class DistributedTest:
|
|||
if rank == 0:
|
||||
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
|
||||
device_id = rank_to_GPU[rank][0]
|
||||
with self.assertRaisesRegex(RuntimeError, "^Invalid ``op``"):
|
||||
with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
|
||||
send_tensor = _build_tensor(rank + 1, device_id=device_id)
|
||||
send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
|
||||
dist.batch_isend_irecv([send_op])
|
||||
|
|
@ -1504,7 +1504,7 @@ class DistributedTest:
|
|||
self._barrier()
|
||||
rank = dist.get_rank()
|
||||
if rank == 0:
|
||||
with self.assertRaisesRegex(RuntimeError, "^Invalid ``p2p_op_list``"):
|
||||
with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
|
||||
dist.batch_isend_irecv([1, 2])
|
||||
|
||||
# NCCL Batch SEND RECV Mixed Backend Error
|
||||
|
|
@ -1519,7 +1519,7 @@ class DistributedTest:
|
|||
group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
|
||||
if rank == 0:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "All ops need to use the same group"
|
||||
ValueError, "All ops need to use the same group"
|
||||
):
|
||||
send_tensor = _build_tensor(rank + 1)
|
||||
send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
|
||||
|
|
@ -2772,7 +2772,7 @@ class DistributedTest:
|
|||
group, group_id, rank = self._init_global_test()
|
||||
for unsupported_op in unsupported_ops:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "all_reduce does not support"
|
||||
ValueError, "all_reduce does not support"
|
||||
):
|
||||
dist.all_reduce(
|
||||
_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
|
||||
|
|
@ -3004,7 +3004,7 @@ class DistributedTest:
|
|||
)
|
||||
def test_all_reduce_coalesced_max_complex_unsupported(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
|
||||
with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
|
||||
dist.all_reduce_coalesced(
|
||||
[_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
|
||||
)
|
||||
|
|
@ -8306,7 +8306,7 @@ class DistributedTest:
|
|||
)
|
||||
# Ensure errors are raised upon incorrect arguments.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
"Expected argument scatter_object_output_list to be a list of size at least 1.",
|
||||
):
|
||||
dist.scatter_object_list([], scatter_list, src=src_rank)