diff --git a/lib/training/lr_finder.py b/lib/training/lr_finder.py index d7c3298..2ed81a3 100644 --- a/lib/training/lr_finder.py +++ b/lib/training/lr_finder.py @@ -132,7 +132,8 @@ class LearningRateFinder: for idx in pbar: model_inputs, model_targets = self._feeder.get_batch() loss: list[float] = self._model.model.train_on_batch(model_inputs, y=model_targets) - if np.isnan(loss[0]): + if any(np.isnan(x) for x in loss): + logger.warning("NaN detected! Exiting early") break self._on_batch_end(idx, loss[0]) self._update_description(pbar)