bugfix: train - fix occasional mem leak in preview

This commit is contained in:
torzdf 2023-08-06 13:16:32 +01:00
parent 81e3bf5cbe
commit 0bba6ffd8b

View File

@ -693,8 +693,11 @@ class _Samples(): # pylint:disable=too-few-public-methods
"""
logger.debug("Getting Predictions")
preds: dict[str, np.ndarray] = {}
standard = self._model.model.predict([feed_a, feed_b], verbose=0)
swapped = self._model.model.predict([feed_b, feed_a], verbose=0)
# Calling model.predict() can lead to both VRAM and system memory leaks, so call model
# directly
standard = [t.numpy() for t in self._model.model([feed_a, feed_b])]
swapped = [t.numpy() for t in self._model.model([feed_b, feed_a])]
if self._model.config["learn_mask"]: # Add mask to 4th channel of final output
standard = [np.concatenate(side[-2:], axis=-1) for side in standard]