diff --git a/lib/align/aligned_face.py b/lib/align/aligned_face.py index 4043730..6f41cab 100644 --- a/lib/align/aligned_face.py +++ b/lib/align/aligned_face.py @@ -9,7 +9,7 @@ from threading import Lock import cv2 import numpy as np -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) # pylint:disable=invalid-name CenteringType = T.Literal["face", "head", "legacy"] _MEAN_FACE = np.array([[0.010086, 0.106454], [0.085135, 0.038915], [0.191003, 0.018748], @@ -810,6 +810,25 @@ class AlignedFace(): self._cache.cropped_roi[centering] = roi return self._cache.cropped_roi[centering] + def split_mask(self) -> np.ndarray: + """ Remove the mask from the alpha channel of :attr:`face` and return the mask + + Returns + ------- + :class:`numpy.ndarray` + The mask that was stored in the :attr:`face`'s alpha channel + + Raises + ------ + AssertionError + If :attr:`face` does not contain a mask in the alpha channel + """ + assert self._face is not None + assert self._face.shape[-1] == 4, "No mask stored in the alpha channel" + mask = self._face[..., 3] + self._face = self._face[..., :3] + return mask + def _umeyama(source: np.ndarray, destination: np.ndarray, estimate_scale: bool) -> np.ndarray: """Estimate N-D similarity transformation with or without scaling. diff --git a/plugins/extract/mask/_base.py b/plugins/extract/mask/_base.py index 837b681..82a6e07 100644 --- a/plugins/extract/mask/_base.py +++ b/plugins/extract/mask/_base.py @@ -166,8 +166,7 @@ class Masker(Extractor): # pylint:disable=abstract-method assert feed_face.face is not None if not item.is_aligned: # Split roi mask from feed face alpha channel - roi_mask = feed_face.face[..., 3] - feed_face._face = feed_face.face[..., :3] # pylint:disable=protected-access + roi_mask = feed_face.split_mask() else: # We have to do the warp here as AlignedFace did not perform it roi_mask = transform_image(roi,