bugfix: Generator for AMD

This commit is contained in:
torzdf 2022-08-29 01:48:20 +01:00
parent f3b88d5626
commit fe8e34f99e

View File

@ -71,7 +71,7 @@ class DataGenerator():
self._batch_size = batch_size
self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes)
self._output_sizes = [shape[1] for shape in model.output_shapes if shape[-1] != 1]
self._output_sizes = self._get_output_sizes(model)
self._coverage_ratio = model.coverage_ratio
self._color_order = model.color_order.lower()
@ -103,6 +103,27 @@ class DataGenerator():
channels += len(mults)
return channels
def _get_output_sizes(self, model: "ModelBase") -> List[int]:
""" Obtain the size of each output tensor for the model.
Parameters
----------
model: :class:`~plugins.train.model.ModelBase`
The model that this data generator is feeding
Returns
-------
list
A list of integers for the model output size for the current side
"""
out_shapes = model.output_shapes
split = len(out_shapes) // 2
side_out = out_shapes[:split] if self._side == "a" else out_shapes[split:]
retval = [shape[1] for shape in side_out if shape[-1] != 1]
logger.debug("side: %s, model output shapes: %s, output sizes: %s",
self._side, model.output_shapes, retval)
return retval
def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]:
""" A Background iterator to return augmented images, samples and targets.
@ -313,7 +334,6 @@ class DataGenerator():
batch = self._buffer()
self._crop_to_coverage(filenames, raw_faces, detected_faces, batch)
self._apply_mask(detected_faces, batch)
return self.process_batch(filenames, raw_faces, detected_faces, batch)
def process_batch(self,