Update GAN64 to v2 (#217)
* Clearer requirements for each platform
* Refactoring of old plugins (Model_Original + Extract_Align) + Cleanups
* Adding GAN128
* Update GAN to v2
* Create instance_normalization.py
* Fix decoder output
* Revert "Fix decoder output"
This reverts commit 3a8ecb8957fe65e66282197455d775eb88455a77.
* Fix convert
* Enable all options except perceptual_loss by default
* Disable instance norm
* Update Model.py
* Update Trainer.py
* Match GAN128 to shaoanlu's latest v2
* Add first_order to GAN128
* Disable `use_perceptual_loss`
* Fix call to `self.first_order`
* Switch to average loss in output
* Constrain average to last 100 iterations
* Fix math, constrain average to intervals of 100
* Fix math averaging again
* Remove math and simplify this damn averaging (see the running-average sketch after this list)
* Add gan128 conversion
* Update convert.py
* Use non-warped images in masked preview
* Add K.set_learning_phase(1) to gan64
* Add K.set_learning_phase(1) to gan128
* Add missing keras import
* Use non-warped images in masked preview for gan128
* Exclude deleted faces from conversion
* --input-aligned-dir defaults to "{input_dir}/aligned"
* Simplify map operation
* Port 'face_alignment' from PyTorch to Keras. It runs about 2x faster, but initialization takes ~20 seconds.
2DFAN-4.h5 and mmod_human_face_detector.dat are included in lib\FaceLandmarksExtractor.
Fixed the dlib vs. TensorFlow conflict: dlib must run an op first, then the Keras model is loaded, otherwise a CUDA OOM error occurs.
If the face location is not found by the CNN detector, it falls back to HOG.
Removed this:
- if face.landmarks == None:
-     print("Warning! landmarks not found. Switching to crop!")
-     return cv2.resize(face.image, (size, size))
because DetectedFace always has landmarks.
* Enabled masked converter for GAN models
* Histogram matching, cli option for perceptual loss
* Fix init() positional args error
* Add backwards compatibility for aligned filenames
* Fix masked converter
* Remove GAN converters
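Several of the bullets above iterate on how the reported loss is averaged. The approach the code below settles on accumulates per-iteration sums plus a counter, prints sum/counter, and resets both whenever a preview is drawn. A minimal illustrative sketch (class and names invented here, not taken from the repo):

import time

class LossAverager:
    # Running average of a loss, reset at each preview interval.
    def __init__(self):
        self.err_sum = 0.0
        self.counter = 0

    def update(self, loss):
        self.err_sum += loss
        self.counter += 1
        print("[%s] avg loss: %.5f" % (time.strftime("%H:%M:%S"),
                                       self.err_sum / self.counter), end='\r')

    def reset(self):
        # Mirrors the reset that happens when show_sample() renders a preview.
        self.err_sum = 0.0
        self.counter = 0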
This commit is contained in:
parent 120535eb11
commit 810bd0bce7
@@ -5,7 +5,7 @@ if sys.version_info[0] < 3:
 if sys.version_info[0] == 3 and sys.version_info[1] < 2:
     raise Exception("This program requires at least python3.2")

-from lib.utils import FullHelpArgumentParser
+from lib.cli import FullHelpArgumentParser

 from scripts.extract import ExtractTrainingData
 from scripts.train import TrainingProcessor
@@ -1,91 +0,0 @@ (file deleted)
# AutoEncoder base classes

import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images

encoderH5 = 'encoder.h5'
decoder_AH5 = 'decoder_A.h5'
decoder_BH5 = 'decoder_B.h5'

class ModelAE:
    def __init__(self, model_dir):
        self.model_dir = model_dir

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        (face_A, face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)

        try:
            self.encoder.load_weights(str(self.model_dir / encoderH5))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            print('loaded model weights')
            return True
        except Exception as e:
            print('Failed loading existing training data.')
            print(e)
            return False

    def save_weights(self):
        self.encoder.save_weights(str(self.model_dir / encoderH5))
        self.decoder_A.save_weights(str(self.model_dir / decoder_AH5))
        self.decoder_B.save_weights(str(self.model_dir / decoder_BH5))
        print('saved model weights')

class TrainerAE():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size=64):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        if test_A.shape[0] % 2 == 1:
            figure_A = numpy.concatenate([figure_A, numpy.expand_dims(figure_A[0], 0)])
            figure_B = numpy.concatenate([figure_B, numpy.expand_dims(figure_B[0], 0)])

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        w = 4
        h = int(figure.shape[0] / w)
        figure = figure.reshape((w, h) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')
lib/cli.py (13 lines changed)
@@ -1,5 +1,6 @@
import argparse
import os
import sys
import time

from pathlib import Path
@@ -15,6 +16,16 @@ class FullPaths(argparse.Action):
        setattr(namespace, self.dest, os.path.abspath(
            os.path.expanduser(values)))

class FullHelpArgumentParser(argparse.ArgumentParser):
    """
    Identical to the built-in argument parser, but on error
    it prints full help message instead of just usage information
    """
    def error(self, message):
        self.print_help(sys.stderr)
        args = {'prog': self.prog, 'message': message}
        self.exit(2, '%(prog)s: error: %(message)s\n' % args)

class DirectoryProcessor(object):
    '''
    Abstract class that processes a directory of images
@@ -35,7 +46,6 @@ class DirectoryProcessor(object):
        self.create_parser(subparser, command, description)
        self.parse_arguments(description, subparser, command)

    def process_arguments(self, arguments):
        self.arguments = arguments
        print("Input Directory: {}".format(self.arguments.input_dir))
@@ -58,6 +68,7 @@ class DirectoryProcessor(object):
            pass

        self.output_dir = get_folder(self.arguments.output_dir)

        try:
            try:
                if self.arguments.skip_existing:
@@ -6,9 +6,11 @@ from .utils import BackgroundGenerator
 from .umeyama import umeyama

 class TrainingDataGenerator():
-    def __init__(self, random_transform_args, coverage):
+    def __init__(self, random_transform_args, coverage, scale=5, zoom=1): # TODO: those defaults should stay in the warp function
         self.random_transform_args = random_transform_args
         self.coverage = coverage
+        self.scale = scale
+        self.zoom = zoom

     def minibatchAB(self, images, batchsize):
         batch = BackgroundGenerator(self.minibatch(images, batchsize), 1)
@@ -42,7 +44,7 @@ class TrainingDataGenerator():

         image = cv2.resize(image, (256, 256))
         image = self.random_transform(image, **self.random_transform_args)
-        warped_img, target_img = self.random_warp(image, self.coverage)
+        warped_img, target_img = self.random_warp(image, self.coverage, self.scale, self.zoom)

         return warped_img, target_img
@@ -61,25 +63,25 @@ class TrainingDataGenerator():
         return result

     # get pair of random warped images from aligned face image
-    def random_warp(self, image, coverage):
+    def random_warp(self, image, coverage, scale=5, zoom=1):
         assert image.shape == (256, 256, 3)
         range_ = numpy.linspace(128 - coverage//2, 128 + coverage//2, 5)
         mapx = numpy.broadcast_to(range_, (5, 5))
         mapy = mapx.T

-        mapx = mapx + numpy.random.normal(size=(5, 5), scale=5)
-        mapy = mapy + numpy.random.normal(size=(5, 5), scale=5)
+        mapx = mapx + numpy.random.normal(size=(5, 5), scale=scale)
+        mapy = mapy + numpy.random.normal(size=(5, 5), scale=scale)

-        interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
-        interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
+        interp_mapx = cv2.resize(mapx, (80*zoom, 80*zoom))[8*zoom:72*zoom, 8*zoom:72*zoom].astype('float32')
+        interp_mapy = cv2.resize(mapy, (80*zoom, 80*zoom))[8*zoom:72*zoom, 8*zoom:72*zoom].astype('float32')

         warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)

         src_points = numpy.stack([mapx.ravel(), mapy.ravel()], axis=-1)
-        dst_points = numpy.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
+        dst_points = numpy.mgrid[0:65*zoom:16*zoom, 0:65*zoom:16*zoom].T.reshape(-1, 2)
         mat = umeyama(src_points, dst_points, True)[0:2]

-        target_image = cv2.warpAffine(image, mat, (64, 64))
+        target_image = cv2.warpAffine(image, mat, (64*zoom, 64*zoom))

         return warped_image, target_image
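To see what the new scale and zoom parameters do end to end, here is a self-contained sketch of the warp. Assumptions: only numpy and OpenCV are available, so cv2.estimateAffinePartial2D stands in for the repo's lib.umeyama similarity fit; the function name is ours.

import cv2
import numpy

def random_warp_sketch(image, coverage=160, scale=5, zoom=1):
    # Jitter a 5x5 control grid over the face region, then remap the image
    # through the interpolated grid to get the "warped" training input.
    assert image.shape == (256, 256, 3)
    range_ = numpy.linspace(128 - coverage // 2, 128 + coverage // 2, 5)
    mapx = numpy.broadcast_to(range_, (5, 5)) + numpy.random.normal(size=(5, 5), scale=scale)
    mapy = numpy.broadcast_to(range_, (5, 5)).T + numpy.random.normal(size=(5, 5), scale=scale)

    interp_mapx = cv2.resize(mapx, (80 * zoom, 80 * zoom))[8 * zoom:72 * zoom, 8 * zoom:72 * zoom].astype('float32')
    interp_mapy = cv2.resize(mapy, (80 * zoom, 80 * zoom))[8 * zoom:72 * zoom, 8 * zoom:72 * zoom].astype('float32')
    warped = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)

    # Fit a similarity transform from the jittered grid to a regular grid
    # and use it to cut out the matching undistorted "target" image.
    src = numpy.stack([mapx.ravel(), mapy.ravel()], axis=-1).astype('float32')
    dst = numpy.mgrid[0:65 * zoom:16 * zoom, 0:65 * zoom:16 * zoom].T.reshape(-1, 2).astype('float32')
    mat, _ = cv2.estimateAffinePartial2D(src, dst)
    target = cv2.warpAffine(image, mat, (64 * zoom, 64 * zoom))
    return warped, target

# zoom=1 produces 64x64 pairs (GAN64); zoom=2 produces 128x128 pairs (GAN128).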
lib/utils.py (11 lines changed)
@@ -1,4 +1,3 @@
-import argparse
 import sys
 from os.path import basename, exists
@@ -31,16 +30,6 @@ def get_image_paths(directory, exclude=[], debug=False):

     return dir_contents

-class FullHelpArgumentParser(argparse.ArgumentParser):
-    """
-    Identical to the built-in argument parser, but on error
-    it prints full help message instead of just usage information
-    """
-    def error(self, message):
-        self.print_help(sys.stderr)
-        args = {'prog': self.prog, 'message': message}
-        self.exit(2, '%(prog)s: error: %(message)s\n' % args)

 # From: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
 import threading
 import queue as Queue
@@ -8,18 +8,18 @@ import os
 class Convert(object):
     def __init__(self, encoder, smooth_mask=True, avg_color_adjust=True, **kwargs):
         self.encoder = encoder

         self.use_smooth_mask = smooth_mask
         self.use_avg_color_adjust = avg_color_adjust

-    def patch_image(self, original, face_detected):
+    def patch_image(self, original, face_detected, size):
         #assert image.shape == (256, 256, 3)
         image = cv2.resize(face_detected.image, (256, 256))
         crop = slice(48, 208)
         face = image[crop, crop]
         old_face = face.copy()

-        face = cv2.resize(face, (64, 64))
+        face = cv2.resize(face, (size, size))
         face = numpy.expand_dims(face, 0)
         new_face = self.encoder(face / 255.0)[0]
         new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype)
@@ -1,18 +0,0 @@ (file deleted)
# Based on the https://github.com/shaoanlu/faceswap-GAN repo (master/FaceSwap_GAN_v2_train.ipynb)

import cv2
import numpy

class Convert(object):
    def __init__(self, encoder, **kwargs):
        self.encoder = encoder

    def patch_image(self, original, face_detected):
        face = cv2.resize(face_detected.image, (64, 64))
        face = numpy.expand_dims(face, 0) / 255.0 * 2 - 1
        mask, new_face = self.encoder(face)
        new_face = mask * new_face + (1 - mask) * face
        new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype('uint8')

        original[face_detected.y: face_detected.y + face_detected.h, face_detected.x: face_detected.x + face_detected.w] = cv2.resize(new_face, (face_detected.w, face_detected.h))
        return original
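The deleted converter above shows the core idea the masked converter keeps: the generator predicts an alpha mask alongside the face, and the output is an alpha blend. A tiny numpy sketch of that compositing step (function name is ours):

import numpy

def composite(mask, fake_rgb, original_rgb):
    # mask in [0, 1] with shape (H, W, 1); images with shape (H, W, 3).
    # Wherever the mask is close to 1 the generated face wins; elsewhere
    # the original pixels show through.
    return mask * fake_rgb + (1.0 - mask) * original_rgb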
@@ -6,10 +6,11 @@ import numpy
 from lib.aligner import get_align_mat

 class Convert():
-    def __init__(self, encoder, blur_size=2, seamless_clone=False, mask_type="facehullandrect", erosion_kernel_size=None, **kwargs):
+    def __init__(self, encoder, trainer, blur_size=2, seamless_clone=False, mask_type="facehullandrect", erosion_kernel_size=None, match_histogram=False, **kwargs):
         self.encoder = encoder
+        self.trainer = trainer
         self.erosion_kernel = None
         self.erosion_kernel_size = erosion_kernel_size
         if erosion_kernel_size is not None:
             if erosion_kernel_size > 0:
                 self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erosion_kernel_size, erosion_kernel_size))
@@ -17,10 +18,11 @@ class Convert():
                 self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (abs(erosion_kernel_size), abs(erosion_kernel_size)))
         self.blur_size = blur_size
         self.seamless_clone = seamless_clone
+        self.match_histogram = match_histogram
         self.mask_type = mask_type.lower() # Choose in 'FaceHullAndRect','FaceHull','Rect'

-    def patch_image(self, image, face_detected):
-        size = 64
+    def patch_image(self, image, face_detected, size):

         image_size = image.shape[1], image.shape[0]

         mat = numpy.array(get_align_mat(face_detected)).reshape(2, 3) * size
@@ -40,9 +42,9 @@ class Convert():
         outImage = None
         if self.seamless_clone:
             unitMask = numpy.clip(image_mask * 365, 0, 255).astype(numpy.uint8)

             maxregion = numpy.argwhere(unitMask == 255)

             if maxregion.size > 0:
                 miny, minx = maxregion.min(axis=0)[:2]
                 maxy, maxx = maxregion.max(axis=0)[:2]
@@ -51,21 +53,77 @@ class Convert():
                 masky = int(minx + (lenx // 2))
                 maskx = int(miny + (leny // 2))
                 outimage = cv2.seamlessClone(new_image.astype(numpy.uint8), base_image.astype(numpy.uint8), unitMask, (masky, maskx), cv2.NORMAL_CLONE)

                 return outimage

         foreground = cv2.multiply(image_mask, new_image.astype(float))
         background = cv2.multiply(1.0 - image_mask, base_image.astype(float))
         outimage = cv2.add(foreground, background)

         return outimage

+    def hist_match(self, source, template, mask=None):
+        # Code borrowed from:
+        # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
+        masked_source = source
+        masked_template = template
+
+        if mask is not None:
+            masked_source = source * mask
+            masked_template = template * mask
+
+        oldshape = source.shape
+        source = source.ravel()
+        template = template.ravel()
+        masked_source = masked_source.ravel()
+        masked_template = masked_template.ravel()
+        s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True,
+                                                   return_counts=True)
+        t_values, t_counts = numpy.unique(template, return_counts=True)
+        ms_values, mbin_idx, ms_counts = numpy.unique(source, return_inverse=True,
+                                                      return_counts=True)
+        mt_values, mt_counts = numpy.unique(template, return_counts=True)
+
+        s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64)
+        s_quantiles /= s_quantiles[-1]
+        t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64)
+        t_quantiles /= t_quantiles[-1]
+        interp_t_values = numpy.interp(s_quantiles, t_quantiles, t_values)
+
+        return interp_t_values[bin_idx].reshape(oldshape)
+
+    def color_hist_match(self, src_im, tar_im, mask):
+        matched_R = self.hist_match(src_im[:, :, 0], tar_im[:, :, 0], mask)
+        matched_G = self.hist_match(src_im[:, :, 1], tar_im[:, :, 1], mask)
+        matched_B = self.hist_match(src_im[:, :, 2], tar_im[:, :, 2], mask)
+        matched = numpy.stack((matched_R, matched_G, matched_B), axis=2).astype(src_im.dtype)
+        return matched
+
     def get_new_face(self, image, mat, size):
         face = cv2.warpAffine(image, mat, (size, size))
         face = numpy.expand_dims(face, 0)
-        new_face = self.encoder(face / 255.0)[0]
+        face_clipped = numpy.clip(face[0], 0, 255).astype(image.dtype)
+        new_face = None
+        mask = None

-        return numpy.clip(new_face * 255, 0, 255).astype(image.dtype)
+        if "GAN" not in self.trainer:
+            normalized_face = face / 255.0
+            new_face = self.encoder(normalized_face)[0]
+            new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype)
+        else:
+            normalized_face = face / 255.0 * 2 - 1
+            fake_output = self.encoder(normalized_face)
+            if "128" in self.trainer: # TODO: Another hack to switch between 64 and 128
+                fake_output = fake_output[0]
+            mask = fake_output[:, :, :, :1]
+            new_face = fake_output[:, :, :, 1:]
+            new_face = mask * new_face + (1 - mask) * normalized_face
+            new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype(image.dtype)

+        if self.match_histogram:
+            new_face = self.color_hist_match(new_face, face_clipped, mask)

+        return new_face

     def get_image_mask(self, image, new_face, face_detected, mat, image_size):
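The hist_match method above is quantile (CDF) matching: map each source value to the template value at the same cumulative rank. A compact, runnable sketch of the same numpy.unique/cumsum/interp pattern, applied per colour channel (function name is ours):

import numpy

def hist_match_sketch(source, template):
    oldshape = source.shape
    s_values, bin_idx, s_counts = numpy.unique(source.ravel(), return_inverse=True, return_counts=True)
    t_values, t_counts = numpy.unique(template.ravel(), return_counts=True)
    s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64)
    t_quantiles /= t_quantiles[-1]
    # For each source quantile, look up the template value at that quantile.
    matched = numpy.interp(s_quantiles, t_quantiles, t_values)
    return matched[bin_idx].reshape(oldshape)

# e.g. matched_R = hist_match_sketch(new_face[:, :, 0], old_face[:, :, 0])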
@@ -4,11 +4,11 @@ import cv2

 from lib.aligner import get_align_mat

-class Extract(object):
+class Extract:
     def extract(self, image, face, size):
         alignment = get_align_mat(face)
         return self.transform(image, alignment, size, 48)

     def transform(self, image, mat, size, padding=0):
         mat = mat * (size - 2 * padding)
         mat[:, 2] += padding
plugins/Extract_Align/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-

__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'

from .Extract import Extract
@@ -9,18 +9,36 @@ from keras.applications import *
from keras.optimizers import Adam

from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization

netGAH5 = 'netGA_GAN.h5'
netGBH5 = 'netGB_GAN.h5'
netDAH5 = 'netDA_GAN.h5'
netDBH5 = 'netDB_GAN.h5'

def __conv_init(a):
    print("conv_init", a)
    k = RandomNormal(0, 0.02)(a) # for convolution kernel
    k.conv_weight = True
    return k

#def batchnorm():
#    return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer=gamma_init)

def inst_norm():
    return InstanceNormalization()

conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization

class GANModel():
    img_size = 64
    channels = 3
    img_shape = (img_size, img_size, channels)
    encoded_dim = 1024
    nc_in = 3    # number of input channels of generators
    nc_D_inp = 6 # number of input channels of discriminators

    def __init__(self, model_dir):
        self.model_dir = model_dir
@@ -29,76 +47,41 @@ class GANModel():
         # Build and compile the discriminator
         self.netDA, self.netDB = self.build_discriminator()

-        # For the adversarial_autoencoder model we will only train the generator
-        self.netDA.trainable = False
-        self.netDB.trainable = False
-        self.netDA.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
-        self.netDB.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

         # Build and compile the generator
         self.netGA, self.netGB = self.build_generator()
-        self.netGA.compile(loss=['mae', 'mse'], optimizer=optimizer)
-        self.netGB.compile(loss=['mae', 'mse'], optimizer=optimizer)

-        img = Input(shape=self.img_shape)
-        alphaA, reconstructed_imgA = self.netGA(img)
-        alphaB, reconstructed_imgB = self.netGB(img)

-        def one_minus(x): return 1 - x
-        # masked_img = alpha * reconstructed_img + (1 - alpha) * img
-        masked_imgA = add([multiply([alphaA, reconstructed_imgA]), multiply([Lambda(one_minus)(alphaA), img])])
-        masked_imgB = add([multiply([alphaB, reconstructed_imgB]), multiply([Lambda(one_minus)(alphaB), img])])
-        out_discriminatorA = self.netDA(concatenate([masked_imgA, img], axis=-1))
-        out_discriminatorB = self.netDB(concatenate([masked_imgB, img], axis=-1))

-        # The adversarial_autoencoder model (stacked generator and discriminator) takes
-        # img as input => generates encoded representation and reconstructed image => determines validity
-        self.adversarial_autoencoderA = Model(img, [reconstructed_imgA, out_discriminatorA])
-        self.adversarial_autoencoderB = Model(img, [reconstructed_imgB, out_discriminatorB])
-        self.adversarial_autoencoderA.compile(loss=['mae', 'mse'],
-                                              loss_weights=[1, 0.5],
-                                              optimizer=optimizer)
-        self.adversarial_autoencoderB.compile(loss=['mae', 'mse'],
-                                              loss_weights=[1, 0.5],
-                                              optimizer=optimizer)

     def converter(self, swap):
         predictor = self.netGB if not swap else self.netGA
         return lambda img: predictor.predict(img)

     def build_generator(self):

         def conv_block(input_tensor, f):
             x = input_tensor
-            x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=RandomNormal(0, 0.02),
-                       use_bias=False, padding="same")(x)
-            x = LeakyReLU(alpha=0.2)(x)
+            x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
+            x = Activation("relu")(x)
             return x

         def res_block(input_tensor, f):
             x = input_tensor
-            x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
-                       use_bias=False, padding="same")(x)
+            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
             x = LeakyReLU(alpha=0.2)(x)
-            x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
-                       use_bias=False, padding="same")(x)
+            x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
             x = add([x, input_tensor])
             x = LeakyReLU(alpha=0.2)(x)
             return x

-        def upscale_ps(filters, use_norm=True):
+        def upscale_ps(filters, use_instance_norm=True):
             def block(x):
-                x = Conv2D(filters*4, kernel_size=3, use_bias=False,
-                           kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
+                x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
                 x = LeakyReLU(0.1)(x)
                 x = PixelShuffler()(x)
                 return x
             return block

-        def Encoder(img_shape):
-            inp = Input(shape=img_shape)
-            x = Conv2D(64, kernel_size=5, kernel_initializer=RandomNormal(0, 0.02),
-                       use_bias=False, padding="same")(inp)
+        def Encoder(nc_in=3, input_size=64):
+            inp = Input(shape=(input_size, input_size, nc_in))
+            x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
             x = conv_block(x, 128)
             x = conv_block(x, 256)
             x = conv_block(x, 512)
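upscale_ps quadruples the channel count with a 3x3 convolution and then lets PixelShuffler rearrange depth into space, doubling the resolution. A numpy sketch of that rearrangement (illustrative, not the repo's lib.PixelShuffler):

import numpy

def pixel_shuffle_sketch(x, r=2):
    # (H, W, r*r*C) -> (r*H, r*W, C): each group of r*r channels becomes
    # an r-by-r block of output pixels.
    h, w, c = x.shape
    assert c % (r * r) == 0
    out_c = c // (r * r)
    x = x.reshape(h, w, r, r, out_c)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(h * r, w * r, out_c)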
@@ -109,23 +92,23 @@ class GANModel():
             out = upscale_ps(512)(x)
             return Model(inputs=inp, outputs=out)

-        def Decoder_ps(img_shape):
-            nc_in = 512
-            input_size = img_shape[0]//8
-            inp = Input(shape=(input_size, input_size, nc_in))
-            x = inp
+        def Decoder_ps(nc_in=512, input_size=8):
+            input_ = Input(shape=(input_size, input_size, nc_in))
+            x = input_
             x = upscale_ps(256)(x)
             x = upscale_ps(128)(x)
             x = upscale_ps(64)(x)
             x = res_block(x, 64)
             x = res_block(x, 64)
             #x = Conv2D(4, kernel_size=5, padding='same')(x)
             alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
             rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
-            return Model(inp, [alpha, rgb])
+            out = concatenate([alpha, rgb])
+            return Model(input_, out)

-        encoder = Encoder(self.img_shape)
-        decoder_A = Decoder_ps(self.img_shape)
-        decoder_B = Decoder_ps(self.img_shape)
+        encoder = Encoder()
+        decoder_A = Decoder_ps()
+        decoder_B = Decoder_ps()
         x = Input(shape=self.img_shape)
         netGA = Model(x, decoder_A(encoder(x)))
         netGB = Model(x, decoder_B(encoder(x)))
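Note the v2 change in Decoder_ps: instead of returning [alpha, rgb] as two outputs, it concatenates them into one 4-channel tensor. Downstream code slices that tensor apart again, the same way cycle_variables does in the Trainer below. A minimal sketch:

from keras.layers import Input, Lambda

fake_output = Input(shape=(64, 64, 4))                 # stand-in for netGA's output
alpha = Lambda(lambda x: x[:, :, :, :1])(fake_output)  # sigmoid mask in [0, 1]
rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_output)    # tanh image in [-1, 1]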
@@ -136,26 +119,26 @@ class GANModel():
         except:
             print("Generator weights files not found.")
             pass
-        return netGA, netGB,
+        return netGA, netGB

     def build_discriminator(self):
         def conv_block_d(input_tensor, f, use_instance_norm=True):
             x = input_tensor
-            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=RandomNormal(0, 0.02),
-                       use_bias=False, padding="same")(x)
+            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
             x = LeakyReLU(alpha=0.2)(x)
             return x

-        def Discriminator(img_shape):
-            inp = Input(shape=(img_shape[0], img_shape[1], img_shape[2]*2))
+        def Discriminator(nc_in, input_size=64):
+            inp = Input(shape=(input_size, input_size, nc_in))
             #x = GaussianNoise(0.05)(inp)
             x = conv_block_d(inp, 64, False)
             x = conv_block_d(x, 128, False)
             x = conv_block_d(x, 256, False)
-            out = Conv2D(1, kernel_size=4, kernel_initializer=RandomNormal(0, 0.02),
-                         use_bias=False, padding="same", activation="sigmoid")(x)
+            out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
             return Model(inputs=[inp], outputs=out)

-        netDA = Discriminator(self.img_shape)
-        netDB = Discriminator(self.img_shape)
+        netDA = Discriminator(self.nc_D_inp)
+        netDB = Discriminator(self.nc_D_inp)
         try:
             netDA.load_weights(str(self.model_dir / netDAH5))
             netDB.load_weights(str(self.model_dir / netDBH5))
@@ -173,7 +156,7 @@ class GANModel():

     def save_weights(self):
         self.netGA.save_weights(str(self.model_dir / netGAH5))
         self.netGB.save_weights(str(self.model_dir / netGBH5))
         self.netDA.save_weights(str(self.model_dir / netDAH5))
         self.netDB.save_weights(str(self.model_dir / netDBH5))
         print("Models saved.")
@@ -2,11 +2,17 @@ import time
 import cv2
 import numpy as np

 from keras.layers import *
 from tensorflow.contrib.distributions import Beta
 import tensorflow as tf
 from keras.optimizers import Adam
 from keras import backend as K

 from lib.training_data import TrainingDataGenerator, stack_images

 class GANTrainingDataGenerator(TrainingDataGenerator):
-    def __init__(self, random_transform_args, coverage):
-        super().__init__(random_transform_args, coverage)
+    def __init__(self, random_transform_args, coverage, scale, zoom):
+        super().__init__(random_transform_args, coverage, scale, zoom)

     def color_adjust(self, img):
         return img / 255.0 * 2 - 1
@@ -14,140 +20,241 @@ class GANTrainingDataGenerator(TrainingDataGenerator):
 class Trainer():
     random_transform_args = {
         'rotation_range': 20,
-        'zoom_range': 0.05,
+        'zoom_range': 0.1,
         'shift_range': 0.05,
         'random_flip': 0.5,
     }

-    def __init__(self, model, fn_A, fn_B, batch_size):
+    def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
+        K.set_learning_phase(1)

         assert batch_size % 2 == 0, "batch_size must be an even number"
         self.batch_size = batch_size
         self.model = model

         self.use_lsgan = True
         self.use_mixup = True
         self.mixup_alpha = 0.2
         self.use_perceptual_loss = perceptual_loss
         self.use_instancenorm = False

-        generator = GANTrainingDataGenerator(self.random_transform_args, 220)
         self.lrD = 1e-4 # Discriminator learning rate
         self.lrG = 1e-4 # Generator learning rate

+        generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 1)
         self.train_batchA = generator.minibatchAB(fn_A, batch_size)
         self.train_batchB = generator.minibatchAB(fn_B, batch_size)

+        self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0

         self.setup()
     def setup(self):
         distorted_A, fake_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
         distorted_B, fake_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
         real_A = Input(shape=self.model.img_shape)
         real_B = Input(shape=self.model.img_shape)

         if self.use_lsgan:
             self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
         else:
             self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))

         # ========== Define Perceptual Loss Model ==========
         if self.use_perceptual_loss:
             from keras.models import Model
             from keras_vggface.vggface import VGGFace
             vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
             vggface.trainable = False
             out_size55 = vggface.layers[36].output
             out_size28 = vggface.layers[78].output
             out_size7 = vggface.layers[-2].output
             vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
             vggface_feat.trainable = False
         else:
             vggface_feat = None

         # TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
         loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, distorted_A, vggface_feat)
         loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, distorted_B, vggface_feat)

         loss_GA += 1e-3 * K.mean(K.abs(mask_A))
         loss_GB += 1e-3 * K.mean(K.abs(mask_B))

         w_fo = 0.01
         loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
         loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
         loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
         loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))

         weightsDA = self.model.netDA.trainable_weights
         weightsGA = self.model.netGA.trainable_weights
         weightsDB = self.model.netDB.trainable_weights
         weightsGB = self.model.netGB.trainable_weights

         # Adam(..).get_updates(...)
         training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA, [], loss_DA)
         self.netDA_train = K.function([distorted_A, real_A], [loss_DA], training_updates)
         training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA, [], loss_GA)
         self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)

         training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB, [], loss_DB)
         self.netDB_train = K.function([distorted_B, real_B], [loss_DB], training_updates)
         training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB, [], loss_GB)
         self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)
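Training is no longer wired through model.compile/train_on_batch: each network gets a symbolic loss, the optimizer's update ops, and a K.function that performs one gradient step per call. The same pattern on a toy model (a sketch assuming the Keras 2.0-era get_updates signature used in the calls above):

from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras import backend as K

inp = Input(shape=(8,))
toy = Model(inp, Dense(1)(inp))

target = Input(shape=(1,))
loss = K.mean(K.square(toy.output - target))
updates = Adam(lr=1e-4, beta_1=0.5).get_updates(toy.trainable_weights, [], loss)
train_step = K.function([toy.input, target], [loss], updates)
# Each call runs one optimizer step and returns the loss:
# err = train_step([x_batch, y_batch])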
     def first_order(self, x, axis=1):
         img_nrows = x.shape[1]
         img_ncols = x.shape[2]
         if axis == 1:
             return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
         elif axis == 2:
             return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
         else:
             return None
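first_order returns absolute differences between neighbouring pixels, the building block of the total-variation-style penalties applied to the masks above and to edges in define_loss. In numpy terms:

import numpy as np

x = np.random.rand(1, 64, 64, 1)
dy = np.abs(x[:, :-1, :-1, :] - x[:, 1:, :-1, :])   # axis=1: vertical neighbours
dx = np.abs(x[:, :-1, :-1, :] - x[:, :-1, 1:, :])   # axis=2: horizontal neighbours
tv_like = dy.mean() + dx.mean()                     # large for noisy, non-smooth masks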
     def train_one_step(self, iter, viewer):
-        # ---------------------
-        #  Train Discriminators
-        # ---------------------

-        # Select a random half batch of images
         epoch, warped_A, target_A = next(self.train_batchA)
         epoch, warped_B, target_B = next(self.train_batchB)

-        # Generate a half batch of new images
-        gen_alphasA, gen_imgsA = self.model.netGA.predict(warped_A)
-        gen_alphasB, gen_imgsB = self.model.netGB.predict(warped_B)
-        #gen_masked_imgsA = gen_alphasA * gen_imgsA + (1 - gen_alphasA) * warped_A
-        #gen_masked_imgsB = gen_alphasB * gen_imgsB + (1 - gen_alphasB) * warped_B
-        gen_masked_imgsA = np.array([gen_alphasA[i] * gen_imgsA[i] + (1 - gen_alphasA[i]) * warped_A[i]
-                                     for i in range(self.batch_size)])
-        gen_masked_imgsB = np.array([gen_alphasB[i] * gen_imgsB[i] + (1 - gen_alphasB[i]) * warped_B[i]
-                                     for i in range(self.batch_size)])
+        # Train discriminators for one batch
+        errDA = self.netDA_train([warped_A, target_A])
+        errDB = self.netDB_train([warped_B, target_B])

-        valid = np.ones((self.batch_size,) + self.model.netDA.output_shape[1:])
-        fake = np.zeros((self.batch_size,) + self.model.netDA.output_shape[1:])
+        # Train generators for one batch
+        errGA = self.netGA_train([warped_A, target_A])
+        errGB = self.netGB_train([warped_B, target_B])

-        concat_real_inputA = np.array([np.concatenate([target_A[i], warped_A[i]], axis=-1)
-                                       for i in range(self.batch_size)])
-        concat_real_inputB = np.array([np.concatenate([target_B[i], warped_B[i]], axis=-1)
-                                       for i in range(self.batch_size)])
-        concat_fake_inputA = np.array([np.concatenate([gen_masked_imgsA[i], warped_A[i]], axis=-1)
-                                       for i in range(self.batch_size)])
-        concat_fake_inputB = np.array([np.concatenate([gen_masked_imgsB[i], warped_B[i]], axis=-1)
-                                       for i in range(self.batch_size)])
-        if self.use_mixup:
-            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-            mixup_A = lam * concat_real_inputA + (1 - lam) * concat_fake_inputA
-            mixup_B = lam * concat_real_inputB + (1 - lam) * concat_fake_inputB
+        # For calculating average losses
+        self.errDA_sum += errDA[0]
+        self.errDB_sum += errDB[0]
+        self.errGA_sum += errGA[0]
+        self.errGB_sum += errGB[0]
+        self.avg_counter += 1

-        # Train the discriminators
-        #print ("Train the discriminators.")
-        if self.use_mixup:
-            d_lossA = self.model.netDA.train_on_batch(mixup_A, lam * valid)
-            d_lossB = self.model.netDB.train_on_batch(mixup_B, lam * valid)
-        else:
-            d_lossA = self.model.netDA.train_on_batch(np.concatenate([concat_real_inputA, concat_fake_inputA], axis=0),
-                                                      np.concatenate([valid, fake], axis=0))
-            d_lossB = self.model.netDB.train_on_batch(np.concatenate([concat_real_inputB, concat_fake_inputB], axis=0),
-                                                      np.concatenate([valid, fake], axis=0))

-        # ---------------------
-        #  Train Generators
-        # ---------------------

-        # Train the generators
-        #print ("Train the generators.")
-        g_lossA = self.model.adversarial_autoencoderA.train_on_batch(warped_A, [target_A, valid])
-        g_lossB = self.model.adversarial_autoencoderB.train_on_batch(warped_B, [target_B, valid])

         print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
-              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, d_lossA[0], d_lossB[0], g_lossA[0], g_lossB[0]),
+              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter),
               end='\r')

         if viewer is not None:
             self.show_sample(viewer)
     def cycle_variables(self, netG):
         distorted_input = netG.inputs[0]
         fake_output = netG.outputs[0]
         alpha = Lambda(lambda x: x[:, :, :, :1])(fake_output)
         rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_output)

         masked_fake_output = alpha * rgb + (1 - alpha) * distorted_input

         fn_generate = K.function([distorted_input], [masked_fake_output])
         fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
         fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
         fn_bgr = K.function([distorted_input], [rgb])
         return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr
     def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None):
         alpha = Lambda(lambda x: x[:, :, :, :1])(fake_argb)
         fake_rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_argb)
         fake = alpha * fake_rgb + (1 - alpha) * distorted

         if self.use_mixup:
             dist = Beta(self.mixup_alpha, self.mixup_alpha)
             lam = dist.sample()
             # ==========
             mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
             # ==========
             output_mixup = netD(mixup)
             loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
             output_fake = netD(concatenate([fake, distorted])) # dummy
             loss_G = .5 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
         else:
             output_real = netD(concatenate([real, distorted])) # positive sample
             output_fake = netD(concatenate([fake, distorted])) # negative sample
             loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
             loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
             loss_D = loss_D_real + loss_D_fake
             loss_G = .5 * self.loss_fn(output_fake, K.ones_like(output_fake))
         # ==========
         loss_G += K.mean(K.abs(fake_rgb - real))
         # ==========

         # Edge loss (similar to total variation loss)
         loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=1) - self.first_order(real, axis=1)))
         loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=2) - self.first_order(real, axis=2)))

         # Perceptual Loss
         if vggface_feat is not None:
             def preprocess_vggface(x):
                 x = (x + 1) / 2 * 255 # channel order: BGR
                 #x[..., 0] -= 93.5940
                 #x[..., 1] -= 104.7624
                 #x[..., 2] -= 129.
                 x -= [91.4953, 103.8827, 131.0912]
                 return x
             pl_params = (0.011, 0.11, 0.1919)
             real_sz224 = tf.image.resize_images(real, [224, 224])
             real_sz224 = Lambda(preprocess_vggface)(real_sz224)
             # ==========
             fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
             fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
             # ==========
             real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
             fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
             loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
             loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
             loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))

         return loss_D, loss_G
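The mixup branch in define_loss blends real and fake discriminator inputs with a Beta-distributed weight and trains the discriminator toward that same blended label. A numpy sketch of the construction:

import numpy as np

lam = np.random.beta(0.2, 0.2)            # mixup_alpha = 0.2
real_in = np.random.rand(4, 64, 64, 6)    # stand-in for concatenate([real, distorted])
fake_in = np.random.rand(4, 64, 64, 6)    # stand-in for concatenate([fake, distorted])
mixed = lam * real_in + (1 - lam) * fake_in
# LSGAN targets: the discriminator is pushed toward D(mixed) ~ lam,
# while the generator is pushed toward D(mixed) ~ 1 - lam.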
     def show_sample(self, display_fn):
         _, wA, tA = next(self.train_batchA)
         _, wB, tB = next(self.train_batchB)
-        self.showG(tA, tB, display_fn)

-    def showG(self, test_A, test_B, display_fn):
-        def display_fig(name, figure_A, figure_B):
-            figure = np.concatenate([figure_A, figure_B], axis=0)
-            columns = 4
-            elements = figure.shape[0]
-            figure = figure.reshape((columns, (elements//columns)) + figure.shape[1:])
-            figure = stack_images(figure)
-            figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
-            display_fn(figure, name)

-        out_test_A_netGA = self.model.netGA.predict(test_A)
-        out_test_A_netGB = self.model.netGB.predict(test_A)
-        out_test_B_netGA = self.model.netGA.predict(test_B)
-        out_test_B_netGB = self.model.netGB.predict(test_B)
+        display_fn(self.showG(tA, tB, self.path_A, self.path_B), "raw")
+        display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "masked")
+        display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
+        # Reset the averages
+        self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
+        self.avg_counter = 0

+    def showG(self, test_A, test_B, path_A, path_B):
         figure_A = np.stack([
             test_A,
-            out_test_A_netGA[1],
-            out_test_A_netGB[1],
+            np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
+            np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
         ], axis=1)
         figure_B = np.stack([
             test_B,
-            out_test_B_netGB[1],
-            out_test_B_netGA[1],
+            np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
+            np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
         ], axis=1)

-        display_fig("raw", figure_A, figure_B)

         figure = np.concatenate([figure_A, figure_B], axis=0)
         figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
         figure = stack_images(figure)
         figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
         return figure

     def showG_mask(self, test_A, test_B, path_A, path_B):
         figure_A = np.stack([
             test_A,
-            np.tile(out_test_A_netGA[0], 3) * 2 - 1,
-            np.tile(out_test_A_netGB[0], 3) * 2 - 1,
+            (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
+            (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
         ], axis=1)
         figure_B = np.stack([
             test_B,
-            np.tile(out_test_B_netGB[0], 3) * 2 - 1,
-            np.tile(out_test_B_netGA[0], 3) * 2 - 1,
+            (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
+            (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
         ], axis=1)

-        display_fig("alpha_masks", figure_A, figure_B)

-        figure_A = np.stack([
-            test_A,
-            out_test_A_netGA[0] * out_test_A_netGA[1] + (1 - out_test_A_netGA[0]) * test_A,
-            out_test_A_netGB[0] * out_test_A_netGB[1] + (1 - out_test_A_netGB[0]) * test_A,
-        ], axis=1)
-        figure_B = np.stack([
-            test_B,
-            out_test_B_netGB[0] * out_test_B_netGB[1] + (1 - out_test_B_netGB[0]) * test_B,
-            out_test_B_netGA[0] * out_test_B_netGA[1] + (1 - out_test_B_netGA[0]) * test_B,
-        ], axis=1)

-        display_fig("masked", figure_A, figure_B)
         figure = np.concatenate([figure_A, figure_B], axis=0)
         figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
         figure = stack_images(figure)
         figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
         return figure
plugins/Model_GAN/instance_normalization.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al., 2016; Ulyanov et al., 2016).
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
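A short usage sketch: unlike batch normalization, InstanceNormalization computes mean and standard deviation per sample (over the spatial axes), so statistics never leak across the batch.

from keras.models import Sequential
from keras.layers import Conv2D
from plugins.Model_GAN.instance_normalization import InstanceNormalization

model = Sequential([
    Conv2D(64, 3, padding='same', input_shape=(64, 64, 3)),
    InstanceNormalization(axis=3),  # channels_last: one gamma/beta per channel
])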
plugins/Model_GAN128/Model.py (new file, 177 lines)
@ -0,0 +1,177 @@
|
|||
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
|
||||
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb
|
||||
|
||||
from keras.models import Model
|
||||
from keras.layers import *
|
||||
from keras.layers.advanced_activations import LeakyReLU
|
||||
from keras.activations import relu
|
||||
from keras.initializers import RandomNormal
|
||||
from keras.applications import *
|
||||
from keras.optimizers import Adam
|
||||
|
||||
from lib.PixelShuffler import PixelShuffler
|
||||
from .instance_normalization import InstanceNormalization
|
||||
|
||||
netGAH5 = 'netGA_GAN128.h5'
|
||||
netGBH5 = 'netGB_GAN128.h5'
|
||||
netDAH5 = 'netDA_GAN128.h5'
|
||||
netDBH5 = 'netDB_GAN128.h5'
|
||||
|
||||
def __conv_init(a):
|
||||
print("conv_init", a)
|
||||
k = RandomNormal(0, 0.02)(a) # for convolution kernel
|
||||
k.conv_weight = True
|
||||
return k
|
||||
|
||||
#def batchnorm():
|
||||
# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)
|
||||
|
||||
def inst_norm():
|
||||
return InstanceNormalization()
|
||||
|
||||
conv_init = RandomNormal(0, 0.02)
|
||||
gamma_init = RandomNormal(1., 0.02) # for batch normalization
|
||||
|
||||
class GANModel():
|
||||
img_size = 128
|
||||
channels = 3
|
||||
img_shape = (img_size, img_size, channels)
|
||||
encoded_dim = 1024
|
||||
nc_in = 3 # number of input channels of generators
|
||||
nc_D_inp = 6 # number of input channels of discriminators
|
||||
|
||||
def __init__(self, model_dir):
|
||||
self.model_dir = model_dir
|
||||
|
||||
optimizer = Adam(1e-4, 0.5)
|
||||
|
||||
# Build and compile the discriminator
|
||||
self.netDA, self.netDB = self.build_discriminator()
|
||||
|
||||
# Build and compile the generator
|
||||
self.netGA, self.netGB = self.build_generator()
|
||||
|
||||
def converter(self, swap):
|
||||
predictor = self.netGB if not swap else self.netGA
|
||||
return lambda img: predictor.predict(img)
|
||||
|
||||
def build_generator(self):
|
||||
|
||||
def conv_block(input_tensor, f, use_instance_norm=True):
|
||||
x = input_tensor
|
||||
x = SeparableConv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
|
||||
if use_instance_norm:
|
||||
x = inst_norm()(x)
|
||||
x = Activation("relu")(x)
|
||||
return x
|
||||
|
||||
def res_block(input_tensor, f, dilation=1):
|
||||
x = input_tensor
|
||||
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
|
||||
x = LeakyReLU(alpha=0.2)(x)
|
||||
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
|
||||
x = add([x, input_tensor])
|
||||
#x = LeakyReLU(alpha=0.2)(x)
|
||||
return x
|
||||
|
||||
def upscale_ps(filters, use_instance_norm=True):
|
||||
def block(x, use_instance_norm=use_instance_norm):
|
||||
x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
|
||||
if use_instance_norm:
|
||||
x = inst_norm()(x)
|
||||
x = LeakyReLU(0.1)(x)
|
||||
x = PixelShuffler()(x)
|
||||
return x
|
||||
return block
|
||||
|
||||
def Encoder(nc_in=3, input_size=128):
|
||||
inp = Input(shape=(input_size, input_size, nc_in))
|
||||
x = Conv2D(32, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
|
||||
x = conv_block(x,64, use_instance_norm=False)
|
||||
x = conv_block(x,128)
|
||||
x = conv_block(x,256)
|
||||
x = conv_block(x,512)
|
||||
x = conv_block(x,1024)
|
||||
x = Dense(1024)(Flatten()(x))
|
||||
x = Dense(4*4*1024)(x)
|
||||
x = Reshape((4, 4, 1024))(x)
|
||||
out = upscale_ps(512)(x)
|
||||
return Model(inputs=inp, outputs=out)
|
||||
|
||||
def Decoder_ps(nc_in=512, input_size=8):
|
||||
input_ = Input(shape=(input_size, input_size, nc_in))
|
||||
x = input_
|
||||
x = upscale_ps(256)(x)
|
||||
x = upscale_ps(128)(x)
|
||||
x = upscale_ps(64)(x)
|
||||
x = res_block(x, 64, dilation=2)
|
||||
|
||||
out64 = Conv2D(64, kernel_size=3, padding='same')(x)
|
||||
out64 = LeakyReLU(alpha=0.1)(out64)
|
||||
out64 = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(out64)
|
||||
|
||||
x = upscale_ps(32)(x)
|
||||
x = res_block(x, 32)
|
||||
x = res_block(x, 32)
|
||||
alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
|
||||
rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
|
||||
out = concatenate([alpha, rgb])
|
||||
return Model(input_, [out, out64] )
|
||||
|
||||
encoder = Encoder()
|
||||
decoder_A = Decoder_ps()
|
||||
decoder_B = Decoder_ps()
|
||||
x = Input(shape=self.img_shape)
|
||||
netGA = Model(x, decoder_A(encoder(x)))
|
||||
netGB = Model(x, decoder_B(encoder(x)))
|
||||
try:
|
||||
netGA.load_weights(str(self.model_dir / netGAH5))
|
||||
netGB.load_weights(str(self.model_dir / netGBH5))
|
||||
print ("Generator models loaded.")
|
||||
except:
|
||||
print ("Generator weights files not found.")
|
||||
pass
|
||||
return netGA, netGB

    def build_discriminator(self):
        def conv_block_d(input_tensor, f, use_instance_norm=True):
            x = input_tensor
            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
            if use_instance_norm:
                x = inst_norm()(x)
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def Discriminator(nc_in, input_size=128):
            inp = Input(shape=(input_size, input_size, nc_in))
            #x = GaussianNoise(0.05)(inp)
            x = conv_block_d(inp, 64, False)
            x = conv_block_d(x, 128, True)
            x = conv_block_d(x, 256, True)
            x = conv_block_d(x, 512, True)
            out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
            return Model(inputs=[inp], outputs=out)

        netDA = Discriminator(self.nc_D_inp)
        netDB = Discriminator(self.nc_D_inp)
        try:
            netDA.load_weights(str(self.model_dir / netDAH5))
            netDB.load_weights(str(self.model_dir / netDBH5))
            print("Discriminator models loaded.")
        except:
            print("Discriminator weights files not found.")
            pass
        return netDA, netDB

    def load(self, swapped):
        if swapped:
            print("swapping not supported on GAN")
            # TODO load is done in __init__ => look how to swap if possible
        return True

    def save_weights(self):
        self.netGA.save_weights(str(self.model_dir / netGAH5))
        self.netGB.save_weights(str(self.model_dir / netGBH5))
        self.netDA.save_weights(str(self.model_dir / netDAH5))
        self.netDB.save_weights(str(self.model_dir / netDBH5))
        print("Models saved.")

plugins/Model_GAN128/Trainer.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import time
import cv2
import numpy as np

from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K

from lib.training_data import TrainingDataGenerator, stack_images

class GANTrainingDataGenerator(TrainingDataGenerator):
    def __init__(self, random_transform_args, coverage, scale, zoom):
        super().__init__(random_transform_args, coverage, scale, zoom)

    def color_adjust(self, img):
        return img / 255.0 * 2 - 1
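
    # Editorial note: this maps 8-bit pixels from [0, 255] to [-1, 1], matching the
    # generators' tanh outputs; the preview code inverts it with (figure + 1) * 255 / 2.
    # Quick check: 0 -> -1, 127.5 -> 0, 255 -> 1.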

class Trainer():
    random_transform_args = {
        'rotation_range': 20,
        'zoom_range': 0.1,
        'shift_range': 0.05,
        'random_flip': 0.5,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
        K.set_learning_phase(1)

        assert batch_size % 2 == 0, "batch_size must be an even number"
        self.batch_size = batch_size
        self.model = model

        self.use_lsgan = True
        self.use_mixup = True
        self.mixup_alpha = 0.2
        self.use_perceptual_loss = perceptual_loss

        self.lrD = 1e-4  # Discriminator learning rate
        self.lrG = 1e-4  # Generator learning rate

        generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 2)
        self.train_batchA = generator.minibatchAB(fn_A, batch_size)
        self.train_batchB = generator.minibatchAB(fn_B, batch_size)

        self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0

        self.setup()

    def setup(self):
        distorted_A, fake_A, fake_sz64_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
        distorted_B, fake_B, fake_sz64_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
        real_A = Input(shape=self.model.img_shape)
        real_B = Input(shape=self.model.img_shape)

        if self.use_lsgan:
            self.loss_fn = lambda output, target: K.mean(K.abs(K.square(output - target)))
        else:
            self.loss_fn = lambda output, target: -K.mean(K.log(output + 1e-12) * target + K.log(1 - output + 1e-12) * (1 - target))
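
        # Editorial note: with use_lsgan the adversarial objective is the least-squares
        # (LSGAN) loss; the K.abs around K.square is redundant for real-valued inputs.
        # Worked example with output = [0.9, 0.2] against target 1:
        #     LSGAN: mean((0.9 - 1)^2, (0.2 - 1)^2) = mean(0.01, 0.64) = 0.325
        #     BCE:   -mean(log(0.9), log(0.2)) ~= 0.857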

        # ========== Define Perceptual Loss Model ==========
        if self.use_perceptual_loss:
            from keras.models import Model
            from keras_vggface.vggface import VGGFace
            vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
            vggface.trainable = False
            out_size55 = vggface.layers[36].output
            out_size28 = vggface.layers[78].output
            out_size7 = vggface.layers[-2].output
            vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
            vggface_feat.trainable = False
        else:
            vggface_feat = None

        # TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
        loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, fake_sz64_A, distorted_A, vggface_feat)
        loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, fake_sz64_B, distorted_B, vggface_feat)

        loss_GA += 3e-3 * K.mean(K.abs(mask_A))
        loss_GB += 3e-3 * K.mean(K.abs(mask_B))

        w_fo = 0.01
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
        loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
        loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))

        weightsDA = self.model.netDA.trainable_weights
        weightsGA = self.model.netGA.trainable_weights
        weightsDB = self.model.netDB.trainable_weights
        weightsGB = self.model.netGB.trainable_weights

        # Adam(..).get_updates(...)
        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA, [], loss_DA)
        self.netDA_train = K.function([distorted_A, real_A], [loss_DA], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA, [], loss_GA)
        self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)

        training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB, [], loss_DB)
        self.netDB_train = K.function([distorted_B, real_B], [loss_DB], training_updates)
        training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB, [], loss_GB)
        self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)

    def first_order(self, x, axis=1):
        img_nrows = x.shape[1]
        img_ncols = x.shape[2]
        if axis == 1:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
        elif axis == 2:
            return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
        else:
            return None
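
    # Editorial note: first_order is a total-variation-style smoothness term, the absolute
    # difference between neighbouring rows (axis=1) or columns (axis=2). For a mask row
    # [0, 0, 1, 1] the column differences are [0, 1, 0], so a single hard edge costs
    # little, while a noisy row like [0, 1, 0, 1] costs [1, 1, 1], three times as much;
    # setup() applies this to the alpha masks with weight w_fo = 0.01.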

    def train_one_step(self, iter, viewer):
        # ---------------------
        #  Train Discriminators
        # ---------------------

        # Select a random half batch of images
        epoch, warped_A, target_A = next(self.train_batchA)
        epoch, warped_B, target_B = next(self.train_batchB)

        # Train discriminators for one batch
        errDA = self.netDA_train([warped_A, target_A])
        errDB = self.netDB_train([warped_B, target_B])

        # Train generators for one batch
        errGA = self.netGA_train([warped_A, target_A])
        errGB = self.netGB_train([warped_B, target_B])

        # For calculating average losses
        self.errDA_sum += errDA[0]
        self.errDB_sum += errDB[0]
        self.errGA_sum += errGA[0]
        self.errGB_sum += errGB[0]
        self.avg_counter += 1

        print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
              % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter,
                 self.errDA_sum / self.avg_counter, self.errDB_sum / self.avg_counter,
                 self.errGA_sum / self.avg_counter, self.errGB_sum / self.avg_counter),
              end='\r')

        if viewer is not None:
            self.show_sample(viewer)

    def cycle_variables(self, netG):
        distorted_input = netG.inputs[0]
        fake_output = netG.outputs[0]
        fake_output64 = netG.outputs[1]
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_output)
        rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_output)

        masked_fake_output = alpha * rgb + (1 - alpha) * distorted_input

        fn_generate = K.function([distorted_input], [masked_fake_output])
        fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
        fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
        fn_bgr = K.function([distorted_input], [rgb])
        return distorted_input, fake_output, fake_output64, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr
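
    # Editorial note: the generator's first output carries a one-channel alpha mask plus
    # the BGR face, and masked_fake_output blends per pixel:
    #     output = alpha * rgb + (1 - alpha) * distorted_input
    # so alpha = 1 shows the synthesized face, alpha = 0 keeps the input, and e.g.
    # alpha = 0.5 with rgb = 0.8 and input = -0.4 yields 0.2. The fn_* K.functions expose
    # the blended result, the mask, ABGR and raw BGR views used by converters and previews.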

    def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None):
        alpha = Lambda(lambda x: x[:, :, :, :1])(fake_argb)
        fake_rgb = Lambda(lambda x: x[:, :, :, 1:])(fake_argb)
        fake = alpha * fake_rgb + (1 - alpha) * distorted

        if self.use_mixup:
            dist = Beta(self.mixup_alpha, self.mixup_alpha)
            lam = dist.sample()
            # ==========
            mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
            # ==========
            output_mixup = netD(mixup)
            loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
            #output_fake = netD(concatenate([fake, distorted])) # dummy
            loss_G = 1 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
        else:
            output_real = netD(concatenate([real, distorted])) # positive sample
            output_fake = netD(concatenate([fake, distorted])) # negative sample
            loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
            loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
            loss_D = loss_D_real + loss_D_fake
            loss_G = 1 * self.loss_fn(output_fake, K.ones_like(output_fake))
        # ==========
        loss_G += K.mean(K.abs(fake_rgb - real))
        loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))
        # ==========

        # Perceptual Loss
        if vggface_feat is not None:
            def preprocess_vggface(x):
                x = (x + 1) / 2 * 255  # channel order: BGR
                x -= [93.5940, 104.7624, 129.]
                return x
            pl_params = (0.02, 0.3, 0.5)
            real_sz224 = tf.image.resize_images(real, [224, 224])
            real_sz224 = Lambda(preprocess_vggface)(real_sz224)
            # ==========
            fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
            fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
            # ==========
            real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
            fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
            loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
            loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
            loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))

        return loss_D, loss_G
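
    # Editorial note: the mixup branch blends real and fake discriminator inputs with
    # lam ~ Beta(0.2, 0.2) and trains D to output lam on the blend (and G toward 1 - lam):
    #     mixup = lam * [real, distorted] + (1 - lam) * [fake, distorted]
    # With alpha = 0.2 the Beta distribution is U-shaped, so most draws of lam land near
    # 0 or 1, i.e. most batches are close to purely real or purely fake.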

    def show_sample(self, display_fn):
        _, wA, tA = next(self.train_batchA)
        _, wB, tB = next(self.train_batchB)
        display_fn(self.showG(tA, tB, self.path_A, self.path_B), "raw")
        display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "masked")
        display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
        # Reset the averages
        self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
        self.avg_counter = 0

    def showG(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
            np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        ], axis=1)
        figure_B = np.stack([
            test_B,
            np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
            np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure

    def showG_mask(self, test_A, test_B, path_A, path_B):
        figure_A = np.stack([
            test_A,
            (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])]))) * 2 - 1,
        ], axis=1)
        figure_B = np.stack([
            test_B,
            (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
            (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])]))) * 2 - 1,
        ], axis=1)

        figure = np.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, self.batch_size // 2) + figure.shape[1:])
        figure = stack_images(figure)
        figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
        return figure

plugins/Model_GAN128/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'

from .Model import GANModel as Model
from .Trainer import Trainer

plugins/Model_GAN128/instance_normalization.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

import numpy as np


class InstanceNormalization(Layer):
    """Instance normalization layer (Lei Ba et al., 2016; Ulyanov et al., 2016).

    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
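
# Editorial sketch: the get_custom_objects() registration above lets keras resolve the
# layer by name when deserializing a saved architecture; passing custom_objects to
# load_model is the equivalent explicit form. The .h5 path below is a placeholder,
# not a file produced by this commit:
if __name__ == '__main__':
    from keras.models import load_model
    restored = load_model('saved_gan128_model.h5',
                          custom_objects={'InstanceNormalization': InstanceNormalization})
    restored.summary()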

@@ -6,13 +6,13 @@ from keras.layers.advanced_activations import LeakyReLU
 from keras.layers.convolutional import Conv2D
 from keras.optimizers import Adam

-from lib.ModelAE import ModelAE, TrainerAE
+from .Model_Original import AutoEncoder, Trainer
 from lib.PixelShuffler import PixelShuffler

 IMAGE_SHAPE = (64, 64, 3)
 ENCODER_DIM = 512

-class Model(ModelAE):
+class Model(AutoEncoder):
     def initModel(self):
         optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
         x = Input(shape=IMAGE_SHAPE)

@@ -63,6 +63,3 @@ class Model(ModelAE):
     x = self.upscale(64)(x)
     x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
     return KerasModel(input_, x)
-
-class Trainer(TrainerAE):
-    """Empty inheritance"""

plugins/Model_Original/AutoEncoder.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# AutoEncoder base classes

encoderH5 = 'encoder.h5'
decoder_AH5 = 'decoder_A.h5'
decoder_BH5 = 'decoder_B.h5'

class AutoEncoder:
    def __init__(self, model_dir):
        self.model_dir = model_dir

        self.encoder = self.Encoder()
        self.decoder_A = self.Decoder()
        self.decoder_B = self.Decoder()

        self.initModel()

    def load(self, swapped):
        (face_A, face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)

        try:
            self.encoder.load_weights(str(self.model_dir / encoderH5))
            self.decoder_A.load_weights(str(self.model_dir / face_A))
            self.decoder_B.load_weights(str(self.model_dir / face_B))
            print('loaded model weights')
            return True
        except Exception as e:
            print('Failed loading existing training data.')
            print(e)
            return False

    def save_weights(self):
        self.encoder.save_weights(str(self.model_dir / encoderH5))
        self.decoder_A.save_weights(str(self.model_dir / decoder_AH5))
        self.decoder_B.save_weights(str(self.model_dir / decoder_BH5))
        print('saved model weights')
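
# Editorial sketch of the plugin contract this base class implies: a subclass supplies
# Encoder, Decoder and initModel, and inherits weight loading/saving. The toy layer
# sizes below are invented for illustration and are not part of any real plugin:
if __name__ == '__main__':
    from keras.layers import Dense, Input
    from keras.models import Model as KerasModel

    class ToyModel(AutoEncoder):
        def Encoder(self):
            inp = Input(shape=(16,))
            return KerasModel(inp, Dense(4)(inp))

        def Decoder(self):
            inp = Input(shape=(4,))
            return KerasModel(inp, Dense(16)(inp))

        def initModel(self):
            pass  # a real plugin builds autoencoder_A / autoencoder_B here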

plugins/Model_Original/Model.py
@@ -6,13 +6,13 @@ from keras.layers.advanced_activations import LeakyReLU
 from keras.layers.convolutional import Conv2D
 from keras.optimizers import Adam

-from lib.ModelAE import ModelAE, TrainerAE
+from .AutoEncoder import AutoEncoder
 from lib.PixelShuffler import PixelShuffler

 IMAGE_SHAPE = (64, 64, 3)
 ENCODER_DIM = 1024

-class Model(ModelAE):
+class Model(AutoEncoder):
     def initModel(self):
         optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
         x = Input(shape=IMAGE_SHAPE)

@@ -62,7 +62,4 @@ class Model(ModelAE):
     x = self.upscale(128)(x)
     x = self.upscale(64)(x)
     x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
-    return KerasModel(input_, x)
-
-class Trainer(TrainerAE):
-    """Empty inheritance"""
+    return KerasModel(input_, x)

plugins/Model_Original/Trainer.py (new file, 50 lines)
@@ -0,0 +1,50 @@

import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images

class Trainer():
    random_transform_args = {
        'rotation_range': 10,
        'zoom_range': 0.05,
        'shift_range': 0.05,
        'random_flip': 0.4,
    }

    def __init__(self, model, fn_A, fn_B, batch_size, *args):
        self.batch_size = batch_size
        self.model = model

        generator = TrainingDataGenerator(self.random_transform_args, 160)
        self.images_A = generator.minibatchAB(fn_A, self.batch_size)
        self.images_B = generator.minibatchAB(fn_B, self.batch_size)

    def train_one_step(self, iter, viewer):
        epoch, warped_A, target_A = next(self.images_A)
        epoch, warped_B, target_B = next(self.images_B)

        loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
        loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
        print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
              end='\r')

        if viewer is not None:
            viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")

    def show_sample(self, test_A, test_B):
        figure_A = numpy.stack([
            test_A,
            self.model.autoencoder_A.predict(test_A),
            self.model.autoencoder_B.predict(test_A),
        ], axis=1)
        figure_B = numpy.stack([
            test_B,
            self.model.autoencoder_B.predict(test_B),
            self.model.autoencoder_A.predict(test_B),
        ], axis=1)

        figure = numpy.concatenate([figure_A, figure_B], axis=0)
        figure = figure.reshape((4, 7) + figure.shape[1:])
        figure = stack_images(figure)

        return numpy.clip(figure * 255, 0, 255).astype('uint8')

plugins/Model_Original/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-

__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'

from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder

scripts/convert.py
@@ -1,200 +1,235 @@
import cv2
import re
import os

from pathlib import Path
from tqdm import tqdm

from lib.cli import DirectoryProcessor, FullPaths
from lib.utils import BackgroundGenerator, get_folder

from plugins.PluginLoader import PluginLoader

class ConvertImage(DirectoryProcessor):
    filename = ''
    def create_parser(self, subparser, command, description):
        self.parser = subparser.add_parser(
            command,
            help="Convert a source image to a new one with the face swapped.",
            description=description,
            epilog="Questions and feedback: \
            https://github.com/deepfakes/faceswap-playground"
        )

    def add_optional_arguments(self, parser):
        parser.add_argument('-m', '--model-dir',
                            action=FullPaths,
                            dest="model_dir",
                            default="models",
                            help="Model directory. A directory containing the trained model \
                            you wish to process. Defaults to 'models'")

        parser.add_argument('-t', '--trainer',
                            type=str,
                            choices=("Original", "LowMem", "GAN"), # case sensitive because this is used to load a plug-in.
                            default="Original",
                            help="Select the trainer that was used to create the model.")

        parser.add_argument('-s', '--swap-model',
                            action="store_true",
                            dest="swap_model",
                            default=False,
                            help="Swap the model. Instead of A -> B, swap B -> A.")

        parser.add_argument('-c', '--converter',
                            type=str,
                            choices=("Masked", "Adjust", "GAN"), # case sensitive because this is used to load a plugin.
                            default="Masked",
                            help="Converter to use.")

        parser.add_argument('-D', '--detector',
                            type=str,
                            choices=("hog", "cnn"), # case sensitive because this is used to load a plugin.
                            default="hog",
                            help="Detector to use. 'cnn' detects many more angles but is much more resource intensive and may fail on large files.")

        parser.add_argument('-fr', '--frame-ranges',
                            nargs="+",
                            type=str,
                            help="Frame ranges to apply transfer to, e.g. for frames 10 to 50 and 90 to 100 use --frame-ranges 10-50 90-100. \
                            Files must have the frame number as the last number in the name!"
                            )

        parser.add_argument('-d', '--discard-frames',
                            action="store_true",
                            dest="discard_frames",
                            default=False,
                            help="When used with --frame-ranges, discards frames that are not processed instead of writing them out unchanged."
                            )

        parser.add_argument('-f', '--filter',
                            type=str,
                            dest="filter",
                            default="filter.jpg",
                            help="Reference image for the person you want to process. Should be a front portrait."
                            )

        parser.add_argument('-b', '--blur-size',
                            type=int,
                            default=2,
                            help="Blur size. (Masked converter only)")

        parser.add_argument('-S', '--seamless',
                            action="store_true",
                            dest="seamless_clone",
                            default=False,
                            help="Seamless mode. (Masked converter only)")

        parser.add_argument('-M', '--mask-type',
                            type=str.lower, # lowercase this, because it's just a string later on.
                            dest="mask_type",
                            choices=["rect", "facehull", "facehullandrect"],
                            default="facehullandrect",
                            help="Mask to use to replace faces. (Masked converter only)")

        parser.add_argument('-e', '--erosion-kernel-size',
                            dest="erosion_kernel_size",
                            type=int,
                            default=None,
                            help="Erosion kernel size. (Masked converter only). Positive values apply erosion, which reduces the edge \
                            of the swapped face. Negative values apply dilation, which allows the swapped face to cover more space.")

        parser.add_argument('-sm', '--smooth-mask',
                            action="store_true",
                            dest="smooth_mask",
                            default=True,
                            help="Smooth mask. (Adjust converter only)")

        parser.add_argument('-aca', '--avg-color-adjust',
                            action="store_true",
                            dest="avg_color_adjust",
                            default=True,
                            help="Average color adjust. (Adjust converter only)")
        return parser

    def process(self):
        # Original & LowMem models go with Adjust or Masked converter
        # GAN converter & model must go together
        # Note: GAN prediction outputs a mask + an image, while the others predict only an image
        model_name = self.arguments.trainer
        conv_name = self.arguments.converter

        if conv_name.startswith("GAN"):
            assert model_name.startswith("GAN") is True, "GAN converter can only be used with a GAN model!"
        else:
            assert model_name.startswith("GAN") is False, "GAN model can only be used with the GAN converter!"

        model = PluginLoader.get_model(model_name)(get_folder(self.arguments.model_dir))
        if not model.load(self.arguments.swap_model):
            print('Model Not Found! A valid model must be provided to continue!')
            exit(1)

        converter = PluginLoader.get_converter(conv_name)(model.converter(False),
                                                          blur_size=self.arguments.blur_size,
                                                          seamless_clone=self.arguments.seamless_clone,
                                                          mask_type=self.arguments.mask_type,
                                                          erosion_kernel_size=self.arguments.erosion_kernel_size,
                                                          smooth_mask=self.arguments.smooth_mask,
                                                          avg_color_adjust=self.arguments.avg_color_adjust
                                                          )

        batch = BackgroundGenerator(self.prepare_images(), 1)

        # frame ranges stuff...
        self.frame_ranges = None

        # split out the frame ranges and parse out "min" and "max" values
        minmax = {
            "min": 0, # never any frames less than 0
            "max": float("inf")
        }

        if self.arguments.frame_ranges:
            self.frame_ranges = [tuple(map(lambda q: minmax[q] if q in minmax.keys() else int(q), v.split("-"))) for v in self.arguments.frame_ranges]

        # last-number regex. I know regex is hacky, but it's reliably hacky(tm).
        self.imageidxre = re.compile(r'(\d+)(?!.*\d)')

        for item in batch.iterator():
            self.convert(converter, item)

    def check_skipframe(self, filename):
        try:
            idx = int(self.imageidxre.findall(filename)[0])
            return not any(map(lambda b: b[0] <= idx <= b[1], self.frame_ranges))
        except:
            return False

    def convert(self, converter, item):
        try:
            (filename, image, faces) = item

            skip = self.check_skipframe(filename)
            if self.arguments.discard_frames and skip:
                return

            if not skip: # process as normal
                for idx, face in faces:
                    image = converter.patch_image(image, face)

            output_file = get_folder(self.output_dir) / Path(filename).name
            cv2.imwrite(str(output_file), image)
        except Exception as e:
            print('Failed to convert image: {}. Reason: {}'.format(filename, e))

    def prepare_images(self):
        self.read_alignments()
        is_have_alignments = self.have_alignments()
        for filename in tqdm(self.read_directory()):
            image = cv2.imread(filename)

            if is_have_alignments:
                if self.have_face(filename):
                    faces = self.get_faces_alignments(filename, image)
                else:
                    print('no alignment found for {}, skipping'.format(os.path.basename(filename)))
                    continue
            else:
                faces = self.get_faces(image)
            yield filename, image, faces

import cv2
import re
import os

from pathlib import Path
from tqdm import tqdm

from lib.cli import DirectoryProcessor, FullPaths
from lib.utils import BackgroundGenerator, get_folder, get_image_paths

from plugins.PluginLoader import PluginLoader

class ConvertImage(DirectoryProcessor):
    filename = ''
    def create_parser(self, subparser, command, description):
        self.parser = subparser.add_parser(
            command,
            help="Convert a source image to a new one with the face swapped.",
            description=description,
            epilog="Questions and feedback: \
            https://github.com/deepfakes/faceswap-playground"
        )

    def add_optional_arguments(self, parser):
        parser.add_argument('-m', '--model-dir',
                            action=FullPaths,
                            dest="model_dir",
                            default="models",
                            help="Model directory. A directory containing the trained model \
                            you wish to process. Defaults to 'models'")

        parser.add_argument('-a', '--input-aligned-dir',
                            action=FullPaths,
                            dest="input_aligned_dir",
                            default=None,
                            help="Input \"aligned directory\". A directory that should contain the \
                            aligned faces extracted from the input files. If you delete faces from \
                            this folder, they'll be skipped during conversion. If no aligned dir is \
                            specified, all faces will be converted.")

        parser.add_argument('-t', '--trainer',
                            type=str,
                            choices=("Original", "LowMem", "GAN", "GAN128"), # case sensitive because this is used to load a plug-in.
                            default="Original",
                            help="Select the trainer that was used to create the model.")

        parser.add_argument('-s', '--swap-model',
                            action="store_true",
                            dest="swap_model",
                            default=False,
                            help="Swap the model. Instead of A -> B, swap B -> A.")

        parser.add_argument('-c', '--converter',
                            type=str,
                            choices=("Masked", "Adjust"), # case sensitive because this is used to load a plugin.
                            default="Masked",
                            help="Converter to use.")

        parser.add_argument('-D', '--detector',
                            type=str,
                            choices=("hog", "cnn"), # case sensitive because this is used to load a plugin.
                            default="hog",
                            help="Detector to use. 'cnn' detects many more angles but is much more resource intensive and may fail on large files.")

        parser.add_argument('-fr', '--frame-ranges',
                            nargs="+",
                            type=str,
                            help="Frame ranges to apply transfer to, e.g. for frames 10 to 50 and 90 to 100 use --frame-ranges 10-50 90-100. \
                            Files must have the frame number as the last number in the name!"
                            )

        parser.add_argument('-d', '--discard-frames',
                            action="store_true",
                            dest="discard_frames",
                            default=False,
                            help="When used with --frame-ranges, discards frames that are not processed instead of writing them out unchanged."
                            )

        parser.add_argument('-f', '--filter',
                            type=str,
                            dest="filter",
                            default="filter.jpg",
                            help="Reference image for the person you want to process. Should be a front portrait."
                            )

        parser.add_argument('-b', '--blur-size',
                            type=int,
                            default=2,
                            help="Blur size. (Masked converter only)")

        parser.add_argument('-S', '--seamless',
                            action="store_true",
                            dest="seamless_clone",
                            default=False,
                            help="Use cv2's seamless clone. (Masked converter only)")

        parser.add_argument('-M', '--mask-type',
                            type=str.lower, # lowercase this, because it's just a string later on.
                            dest="mask_type",
                            choices=["rect", "facehull", "facehullandrect"],
                            default="facehullandrect",
                            help="Mask to use to replace faces. (Masked converter only)")

        parser.add_argument('-e', '--erosion-kernel-size',
                            dest="erosion_kernel_size",
                            type=int,
                            default=None,
                            help="Erosion kernel size. (Masked converter only). Positive values apply erosion, which reduces the edge of the swapped face. Negative values apply dilation, which allows the swapped face to cover more space.")

        parser.add_argument('-mh', '--match-histogram',
                            action="store_true",
                            dest="match_histogram",
                            default=False,
                            help="Use histogram matching. (Masked converter only)")

        parser.add_argument('-sm', '--smooth-mask',
                            action="store_true",
                            dest="smooth_mask",
                            default=True,
                            help="Smooth mask. (Adjust converter only)")

        parser.add_argument('-aca', '--avg-color-adjust',
                            action="store_true",
                            dest="avg_color_adjust",
                            default=True,
                            help="Average color adjust. (Adjust converter only)")
        return parser

    def process(self):
        # Original & LowMem models go with Adjust or Masked converter
        # Note: GAN prediction outputs a mask + an image, while the others predict only an image
        model_name = self.arguments.trainer
        conv_name = self.arguments.converter
        self.input_aligned_dir = None

        model = PluginLoader.get_model(model_name)(get_folder(self.arguments.model_dir))
        if not model.load(self.arguments.swap_model):
            print('Model Not Found! A valid model must be provided to continue!')
            exit(1)

        input_aligned_dir = Path(self.arguments.input_dir) / Path('aligned')
        if self.arguments.input_aligned_dir is not None:
            input_aligned_dir = self.arguments.input_aligned_dir
        try:
            self.input_aligned_dir = [Path(path) for path in get_image_paths(input_aligned_dir)]
            if len(self.input_aligned_dir) == 0:
                print('Aligned directory is empty, no faces will be converted!')
            elif len(self.input_aligned_dir) <= len(self.input_dir) / 3:
                print('Aligned directory contains far fewer images than the input, are you sure this is the right directory?')
        except:
            print('Aligned directory not found. All faces listed in the alignments file will be converted.')

        converter = PluginLoader.get_converter(conv_name)(model.converter(False),
                                                          trainer=self.arguments.trainer,
                                                          blur_size=self.arguments.blur_size,
                                                          seamless_clone=self.arguments.seamless_clone,
                                                          mask_type=self.arguments.mask_type,
                                                          erosion_kernel_size=self.arguments.erosion_kernel_size,
                                                          match_histogram=self.arguments.match_histogram,
                                                          smooth_mask=self.arguments.smooth_mask,
                                                          avg_color_adjust=self.arguments.avg_color_adjust
                                                          )

        batch = BackgroundGenerator(self.prepare_images(), 1)

        # frame ranges stuff...
        self.frame_ranges = None

        # split out the frame ranges and parse out "min" and "max" values
        minmax = {
            "min": 0, # never any frames less than 0
            "max": float("inf")
        }

        if self.arguments.frame_ranges:
            self.frame_ranges = [tuple(map(lambda q: minmax[q] if q in minmax.keys() else int(q), v.split("-"))) for v in self.arguments.frame_ranges]
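
        # Editorial note: each range token is split on "-" and mapped through minmax, so
        # the literals "min"/"max" work as open endpoints, e.g.
        #     "10-50"  -> (10, 50)      "min-100" -> (0, 100)
        #     "90-100" -> (90, 100)     "200-max" -> (200, inf)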

        # last-number regex. I know regex is hacky, but it's reliably hacky(tm).
        self.imageidxre = re.compile(r'(\d+)(?!.*\d)')

        for item in batch.iterator():
            self.convert(converter, item)

    def check_skipframe(self, filename):
        try:
            idx = int(self.imageidxre.findall(filename)[0])
            return not any(map(lambda b: b[0] <= idx <= b[1], self.frame_ranges))
        except:
            return False

    def check_skipface(self, filename, face_idx):
        aligned_face_name = '{}_{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
        aligned_face_file = Path(self.arguments.input_aligned_dir) / Path(aligned_face_name)
        # TODO: Remove this temporary fix for backwards compatibility of filenames
        bk_compat_aligned_face_name = '{}{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
        bk_compat_aligned_face_file = Path(self.arguments.input_aligned_dir) / Path(bk_compat_aligned_face_name)
        return aligned_face_file not in self.input_aligned_dir and bk_compat_aligned_face_file not in self.input_aligned_dir
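
    # Editorial note: a face is only skipped when neither naming scheme matches, e.g.
    # frame video-0042.png with face index 0 is looked up as video-0042_0.png (new style)
    # and as video-00420.png (legacy style) inside the aligned directory.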

    def convert(self, converter, item):
        try:
            (filename, image, faces) = item

            skip = self.check_skipframe(filename)
            if self.arguments.discard_frames and skip:
                return

            if not skip: # process frame as normal
                for idx, face in faces:
                    if self.input_aligned_dir is not None and self.check_skipface(filename, idx):
                        print('face {} for frame {} was deleted, skipping'.format(idx, os.path.basename(filename)))
                        continue
                    # TODO: This switch between 64 and 128 is a hack for now. We should have a separate cli option for size.
                    image = converter.patch_image(image, face, 64 if "128" not in self.arguments.trainer else 128)

            output_file = get_folder(self.output_dir) / Path(filename).name
            cv2.imwrite(str(output_file), image)
        except Exception as e:
            print('Failed to convert image: {}. Reason: {}'.format(filename, e))

    def prepare_images(self):
        self.read_alignments()
        is_have_alignments = self.have_alignments()
        for filename in tqdm(self.read_directory()):
            image = cv2.imread(filename)

            if is_have_alignments:
                if self.have_face(filename):
                    faces = self.get_faces_alignments(filename, image)
                else:
                    print('no alignment found for {}, skipping'.format(os.path.basename(filename)))
                    continue
            else:
                faces = self.get_faces(image)
            yield filename, image, faces

scripts/extract.py
@@ -94,7 +94,7 @@ class ExtractTrainingData(DirectoryProcessor):

             resized_image = self.extractor.extract(image, face, 256)
             output_file = get_folder(self.output_dir) / Path(filename).stem
-            cv2.imwrite(str(output_file) + str(idx) + Path(filename).suffix, resized_image)
+            cv2.imwrite('{}_{}{}'.format(str(output_file), str(idx), Path(filename).suffix), resized_image)
             f = {
                 "x": face.x,
                 "w": face.w,

scripts/train.py
@@ -1,194 +1,199 @@
import cv2
import numpy
import time

from threading import Lock
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader

class TrainingProcessor(object):
    arguments = None

    def __init__(self, subparser, command, description='default'):
        self.parse_arguments(description, subparser, command)
        self.lock = Lock()

    def process_arguments(self, arguments):
        self.arguments = arguments
        print("Model A Directory: {}".format(self.arguments.input_A))
        print("Model B Directory: {}".format(self.arguments.input_B))
        print("Training data directory: {}".format(self.arguments.model_dir))

        self.process()

    def parse_arguments(self, description, subparser, command):
        parser = subparser.add_parser(
            command,
            help="This command trains the model for the two faces A and B.",
            description=description,
            epilog="Questions and feedback: \
            https://github.com/deepfakes/faceswap-playground"
        )

        parser.add_argument('-A', '--input-A',
                            action=FullPaths,
                            dest="input_A",
                            default="input_A",
                            help="Input directory. A directory containing training images for face A.\
                            Defaults to 'input'")
        parser.add_argument('-B', '--input-B',
                            action=FullPaths,
                            dest="input_B",
                            default="input_B",
                            help="Input directory. A directory containing training images for face B.\
                            Defaults to 'input'")
        parser.add_argument('-m', '--model-dir',
                            action=FullPaths,
                            dest="model_dir",
                            default="models",
                            help="Model directory. This is where the training data will \
                            be stored. Defaults to 'model'")
        parser.add_argument('-p', '--preview',
                            action="store_true",
                            dest="preview",
                            default=False,
                            help="Show preview output. If not specified, write progress \
                            to file.")
        parser.add_argument('-v', '--verbose',
                            action="store_true",
                            dest="verbose",
                            default=False,
                            help="Show verbose output")
        parser.add_argument('-s', '--save-interval',
                            type=int,
                            dest="save_interval",
                            default=100,
                            help="Sets the number of iterations before saving the model.")
        parser.add_argument('-w', '--write-image',
                            action="store_true",
                            dest="write_image",
                            default=False,
                            help="Writes the training result to a file even in preview mode.")
        parser.add_argument('-t', '--trainer',
                            type=str,
                            choices=("Original", "LowMem", "GAN"),
                            default="Original",
                            help="Select which trainer to use; LowMem is for cards with less than 2 GB of VRAM.")
        parser.add_argument('-bs', '--batch-size',
                            type=int,
                            default=64,
                            help="Batch size, as a power of 2 (64, 128, 256, etc.)")
        parser.add_argument('-ag', '--allow-growth',
                            action="store_true",
                            dest="allow_growth",
                            default=False,
                            help="Sets the allow_growth option of TensorFlow to spare memory on some configurations.")
        parser.add_argument('-ep', '--epochs',
                            type=int,
                            default=1000000,
                            help="Length of training in epochs.")
        parser = self.add_optional_arguments(parser)
        parser.set_defaults(func=self.process_arguments)

    def add_optional_arguments(self, parser):
        # Override this for custom arguments
        return parser

    def process(self):
        import threading
        self.stop = False
        self.save_now = False

        thr = threading.Thread(target=self.processThread, args=(), kwargs={})
        thr.start()

        if self.arguments.preview:
            print('Using live preview')
            while True:
                try:
                    with self.lock:
                        for name, image in self.preview_buffer.items():
                            cv2.imshow(name, image)

                    key = cv2.waitKey(1000)
                    if key == ord('\n') or key == ord('\r'):
                        break
                    if key == ord('s'):
                        self.save_now = True
                except KeyboardInterrupt:
                    break
        else:
            input() # TODO how to catch a specific key instead of Enter?
            # there isn't a good multiplatform solution: https://stackoverflow.com/questions/3523174/raw-input-in-python-without-pressing-enter

        print("Exit requested! The trainer will complete its current cycle, save the models and quit (this can take up to a couple of seconds depending on your training speed). If you want to kill it now, press Ctrl + c")
        self.stop = True
        thr.join() # waits until thread finishes

    def processThread(self):
        if self.arguments.allow_growth:
            self.set_tf_allow_growth()

        print('Loading data, this may take a while...')
        # this is so that you can enter case insensitive values for trainer
        trainer = self.arguments.trainer
        trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
        model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir))
        model.load(swapped=False)

        images_A = get_image_paths(self.arguments.input_A)
        images_B = get_image_paths(self.arguments.input_B)
        trainer = PluginLoader.get_trainer(trainer)
        trainer = trainer(model, images_A, images_B, batch_size=self.arguments.batch_size)

        try:
            print('Starting. Press "Enter" to stop training and save model')

            for epoch in range(0, self.arguments.epochs):

                save_iteration = epoch % self.arguments.save_interval == 0

                trainer.train_one_step(epoch, self.show if (save_iteration or self.save_now) else None)

                if save_iteration:
                    model.save_weights()

                if self.stop:
                    model.save_weights()
                    exit()

                if self.save_now:
                    model.save_weights()
                    self.save_now = False

        except KeyboardInterrupt:
            try:
                model.save_weights()
            except KeyboardInterrupt:
                print('Saving model weights has been cancelled!')
            exit(0)
        except Exception as e:
            print(e)
            exit(1)

    def set_tf_allow_growth(self):
        import tensorflow as tf
        from keras.backend.tensorflow_backend import set_session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
        set_session(tf.Session(config=config))

    preview_buffer = {}

    def show(self, image, name=''):
        try:
            if self.arguments.preview:
                with self.lock:
                    self.preview_buffer[name] = image
            elif self.arguments.write_image:
                cv2.imwrite('_sample_{}.jpg'.format(name), image)
        except Exception as e:
            print("could not preview sample")
            print(e)

import cv2
import numpy
import time

from threading import Lock
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader

class TrainingProcessor(object):
    arguments = None

    def __init__(self, subparser, command, description='default'):
        self.parse_arguments(description, subparser, command)
        self.lock = Lock()

    def process_arguments(self, arguments):
        self.arguments = arguments
        print("Model A Directory: {}".format(self.arguments.input_A))
        print("Model B Directory: {}".format(self.arguments.input_B))
        print("Training data directory: {}".format(self.arguments.model_dir))

        self.process()

    def parse_arguments(self, description, subparser, command):
        parser = subparser.add_parser(
            command,
            help="This command trains the model for the two faces A and B.",
            description=description,
            epilog="Questions and feedback: \
            https://github.com/deepfakes/faceswap-playground"
        )

        parser.add_argument('-A', '--input-A',
                            action=FullPaths,
                            dest="input_A",
                            default="input_A",
                            help="Input directory. A directory containing training images for face A.\
                            Defaults to 'input'")
        parser.add_argument('-B', '--input-B',
                            action=FullPaths,
                            dest="input_B",
                            default="input_B",
                            help="Input directory. A directory containing training images for face B.\
                            Defaults to 'input'")
        parser.add_argument('-m', '--model-dir',
                            action=FullPaths,
                            dest="model_dir",
                            default="models",
                            help="Model directory. This is where the training data will \
                            be stored. Defaults to 'model'")
        parser.add_argument('-p', '--preview',
                            action="store_true",
                            dest="preview",
                            default=False,
                            help="Show preview output. If not specified, write progress \
                            to file.")
        parser.add_argument('-v', '--verbose',
                            action="store_true",
                            dest="verbose",
                            default=False,
                            help="Show verbose output")
        parser.add_argument('-s', '--save-interval',
                            type=int,
                            dest="save_interval",
                            default=100,
                            help="Sets the number of iterations before saving the model.")
        parser.add_argument('-w', '--write-image',
                            action="store_true",
                            dest="write_image",
                            default=False,
                            help="Writes the training result to a file even in preview mode.")
        parser.add_argument('-t', '--trainer',
                            type=str,
                            choices=("Original", "LowMem", "GAN", "GAN128"),
                            default="Original",
                            help="Select which trainer to use; LowMem is for cards with less than 2 GB of VRAM.")
        parser.add_argument('-pl', '--use-perceptual-loss',
                            action="store_true",
                            dest="perceptual_loss",
                            default=False,
                            help="Use perceptual loss while training")
        parser.add_argument('-bs', '--batch-size',
                            type=int,
                            default=64,
                            help="Batch size, as a power of 2 (64, 128, 256, etc.)")
        parser.add_argument('-ag', '--allow-growth',
                            action="store_true",
                            dest="allow_growth",
                            default=False,
                            help="Sets the allow_growth option of TensorFlow to spare memory on some configurations.")
        parser.add_argument('-ep', '--epochs',
                            type=int,
                            default=1000000,
                            help="Length of training in epochs.")
        parser = self.add_optional_arguments(parser)
        parser.set_defaults(func=self.process_arguments)

    def add_optional_arguments(self, parser):
        # Override this for custom arguments
        return parser

    def process(self):
        import threading
        self.stop = False
        self.save_now = False

        thr = threading.Thread(target=self.processThread, args=(), kwargs={})
        thr.start()

        if self.arguments.preview:
            print('Using live preview')
            while True:
                try:
                    with self.lock:
                        for name, image in self.preview_buffer.items():
                            cv2.imshow(name, image)

                    key = cv2.waitKey(1000)
                    if key == ord('\n') or key == ord('\r'):
                        break
                    if key == ord('s'):
                        self.save_now = True
                except KeyboardInterrupt:
                    break
        else:
            input() # TODO how to catch a specific key instead of Enter?
            # there isn't a good multiplatform solution: https://stackoverflow.com/questions/3523174/raw-input-in-python-without-pressing-enter

        print("Exit requested! The trainer will complete its current cycle, save the models and quit (this can take up to a couple of seconds depending on your training speed). If you want to kill it now, press Ctrl + c")
        self.stop = True
        thr.join() # waits until thread finishes

    def processThread(self):
        if self.arguments.allow_growth:
            self.set_tf_allow_growth()

        print('Loading data, this may take a while...')
        # this is so that you can enter case insensitive values for trainer
        trainer = self.arguments.trainer
        trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
        model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir))
        model.load(swapped=False)

        images_A = get_image_paths(self.arguments.input_A)
        images_B = get_image_paths(self.arguments.input_B)
        trainer = PluginLoader.get_trainer(trainer)
        trainer = trainer(model, images_A, images_B, self.arguments.batch_size, self.arguments.perceptual_loss)
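
        # Editorial note: batch_size and perceptual_loss are now passed positionally so one
        # call site works for every trainer. The GAN trainers consume both arguments, while
        # Model_Original's Trainer(model, fn_A, fn_B, batch_size, *args) simply absorbs the
        # extra flag through *args and ignores it.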

        try:
            print('Starting. Press "Enter" to stop training and save model')

            for epoch in range(0, self.arguments.epochs):

                save_iteration = epoch % self.arguments.save_interval == 0

                trainer.train_one_step(epoch, self.show if (save_iteration or self.save_now) else None)

                if save_iteration:
                    model.save_weights()

                if self.stop:
                    model.save_weights()
                    exit()

                if self.save_now:
                    model.save_weights()
                    self.save_now = False

        except KeyboardInterrupt:
            try:
                model.save_weights()
            except KeyboardInterrupt:
                print('Saving model weights has been cancelled!')
            exit(0)
        except Exception as e:
            print(e)
            exit(1)

    def set_tf_allow_growth(self):
        import tensorflow as tf
        from keras.backend.tensorflow_backend import set_session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
        set_session(tf.Session(config=config))

    preview_buffer = {}

    def show(self, image, name=''):
        try:
            if self.arguments.preview:
                with self.lock:
                    self.preview_buffer[name] = image
            elif self.arguments.write_image:
                cv2.imwrite('_sample_{}.jpg'.format(name), image)
        except Exception as e:
            print("could not preview sample")
            print(e)