mirror of
https://github.com/zebrajr/faceswap.git
synced 2025-12-06 00:20:09 +01:00
bugfix: Generator for AMD
This commit is contained in:
parent
f3b88d5626
commit
fe8e34f99e
|
|
@ -71,7 +71,7 @@ class DataGenerator():
|
|||
self._batch_size = batch_size
|
||||
|
||||
self._process_size = max(img[1] for img in model.input_shapes + model.output_shapes)
|
||||
self._output_sizes = [shape[1] for shape in model.output_shapes if shape[-1] != 1]
|
||||
self._output_sizes = self._get_output_sizes(model)
|
||||
|
||||
self._coverage_ratio = model.coverage_ratio
|
||||
self._color_order = model.color_order.lower()
|
||||
|
|
@ -103,6 +103,27 @@ class DataGenerator():
|
|||
channels += len(mults)
|
||||
return channels
|
||||
|
||||
def _get_output_sizes(self, model: "ModelBase") -> List[int]:
|
||||
""" Obtain the size of each output tensor for the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: :class:`~plugins.train.model.ModelBase`
|
||||
The model that this data generator is feeding
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of integers for the model output size for the current side
|
||||
"""
|
||||
out_shapes = model.output_shapes
|
||||
split = len(out_shapes) // 2
|
||||
side_out = out_shapes[:split] if self._side == "a" else out_shapes[split:]
|
||||
retval = [shape[1] for shape in side_out if shape[-1] != 1]
|
||||
logger.debug("side: %s, model output shapes: %s, output sizes: %s",
|
||||
self._side, model.output_shapes, retval)
|
||||
return retval
|
||||
|
||||
def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]:
|
||||
""" A Background iterator to return augmented images, samples and targets.
|
||||
|
||||
|
|
@ -313,7 +334,6 @@ class DataGenerator():
|
|||
batch = self._buffer()
|
||||
self._crop_to_coverage(filenames, raw_faces, detected_faces, batch)
|
||||
self._apply_mask(detected_faces, batch)
|
||||
|
||||
return self.process_batch(filenames, raw_faces, detected_faces, batch)
|
||||
|
||||
def process_batch(self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user