Add api to enable/disable NaN detector per-PG (#151723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151723
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
This commit is contained in:
Will Constable 2025-04-18 16:40:31 -07:00 committed by PyTorch MergeBot
parent 414ce713fb
commit 2673ea4131
4 changed files with 22 additions and 0 deletions

View File

@ -499,6 +499,8 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
backend = pg._get_backend(torch.device("cuda"))
device = self.rank_to_GPU[self.rank][0]
# Cover different buffer sizes
if type == torch.float64:
@ -526,6 +528,12 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
nan_tensor = nan_tensor.to(type)
output = torch.empty(self.world_size, *size, dtype=type, device=device)
# confirm enable/disable flag works
backend._set_enable_nan_check(False)
pg.allreduce(nan_tensor)
backend._set_enable_nan_check(True)
with self.assertRaises(RuntimeError):
# Note: using all-gather here bc FP8 types do not support reduce ops
# at the moment

View File

@ -1669,6 +1669,10 @@ std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg(
".");
}
void ProcessGroupNCCL::setEnableNanCheck(bool enableNanCheck) {
enableNanCheck_ = enableNanCheck;
}
void ProcessGroupNCCL::heartbeatMonitor() {
c10::setThreadName("pt_nccl_heartbt");

View File

@ -857,6 +857,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
const c10::intrusive_ptr<Work>& work,
const std::chrono::milliseconds& timeout);
void setEnableNanCheck(bool enableNanCheck);
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(

View File

@ -3148,6 +3148,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def(
"get_error",
&::c10d::ProcessGroupNCCL::getError,
py::call_guard<py::gil_scoped_release>())
.def(
"_set_enable_nan_check",
[](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
bool enable_nan_check) {
self->setEnableNanCheck(enable_nan_check);
},
py::arg("enable_nan_check"),
py::call_guard<py::gil_scoped_release>());
module.def(