mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add an API to enable/disable the NaN detector per process group (PG) (#151723)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151723 Approved by: https://github.com/kwen2501, https://github.com/fduwjj
This commit is contained in:
parent
414ce713fb
commit
2673ea4131
|
|
@ -499,6 +499,8 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
|||
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
backend = pg._get_backend(torch.device("cuda"))
|
||||
|
||||
device = self.rank_to_GPU[self.rank][0]
|
||||
# Cover different buffer sizes
|
||||
if type == torch.float64:
|
||||
|
|
@ -526,6 +528,12 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
|||
nan_tensor = nan_tensor.to(type)
|
||||
|
||||
output = torch.empty(self.world_size, *size, dtype=type, device=device)
|
||||
|
||||
# confirm enable/disable flag works
|
||||
backend._set_enable_nan_check(False)
|
||||
pg.allreduce(nan_tensor)
|
||||
|
||||
backend._set_enable_nan_check(True)
|
||||
with self.assertRaises(RuntimeError):
|
||||
# Note: using all-gather here because FP8 types do not support reduce ops
|
||||
# at the moment
|
||||
|
|
|
|||
|
|
@ -1669,6 +1669,10 @@ std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg(
|
|||
".");
|
||||
}
|
||||
|
||||
// Turns the NaN check on or off for this ProcessGroupNCCL instance by
// overwriting the enableNanCheck_ member with the caller-supplied flag.
// NOTE(review): this is a plain assignment with no synchronization visible
// here -- confirm whether enableNanCheck_ can be read concurrently by
// in-flight collectives on other threads.
void ProcessGroupNCCL::setEnableNanCheck(bool enableNanCheck) {
|
||||
enableNanCheck_ = enableNanCheck;
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
c10::setThreadName("pt_nccl_heartbt");
|
||||
|
||||
|
|
|
|||
|
|
@ -857,6 +857,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
|||
const c10::intrusive_ptr<Work>& work,
|
||||
const std::chrono::milliseconds& timeout);
|
||||
|
||||
// Enables (true) or disables (false) the NaN check on this process group
// instance; the new value replaces whatever was configured previously.
void setEnableNanCheck(bool enableNanCheck);
|
||||
|
||||
protected:
|
||||
// Helper that broadcasts nccl unique ID to all ranks through the store
|
||||
void broadcastUniqueNCCLID(
|
||||
|
|
|
|||
|
|
@ -3148,6 +3148,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
|
|||
.def(
|
||||
"get_error",
|
||||
&::c10d::ProcessGroupNCCL::getError,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"_set_enable_nan_check",
|
||||
[](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
|
||||
bool enable_nan_check) {
|
||||
self->setEnableNanCheck(enable_nan_check);
|
||||
},
|
||||
py::arg("enable_nan_check"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
module.def(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user