Update GAN64 to v2 (#217)

* Clearer requirements for each platform

* Refactoring of old plugins (Model_Original + Extract_Align) + Cleanups

* Adding GAN128

* Update GAN to v2

* Create instance_normalization.py

* Fix decoder output

* Revert "Fix decoder output"

This reverts commit 3a8ecb8957fe65e66282197455d775eb88455a77.

* Fix convert

* Enable all options except perceptual_loss by default

* Disable instance norm

* Update Model.py

* Update Trainer.py

* Match GAN128 to shaoanlu's latest v2

* Add first_order to GAN128

* Disable `use_perceptual_loss`

* Fix call to `self.first_order`

* Switch to average loss in output

* Constrain average to last 100 iterations

* Fix math, constrain average to intervals of 100

* Fix math averaging again

* Remove math and simplify this damn averaging
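
The simplified approach, visible in Trainer.py further down, just accumulates each loss into a running sum and divides by a counter, resetting both whenever a preview is shown. A minimal sketch of that pattern (names hypothetical, mirroring the errDA_sum/avg_counter fields below):

err_sum = 0.0
avg_counter = 0

def record(loss):
    # Running average of all losses seen since the last reset
    global err_sum, avg_counter
    err_sum += loss
    avg_counter += 1
    return err_sum / avg_counter

def reset():
    global err_sum, avg_counter
    err_sum, avg_counter = 0.0, 0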

* Add gan128 conversion

* Update convert.py

* Use non-warped images in masked preview

* Add K.set_learning_phase(1) to gan64

* Add K.set_learning_phase(1) to gan128

* Add missing keras import

* Use non-warped images in masked preview for gan128

* Exclude deleted faces from conversion

* --input-aligned-dir defaults to "{input_dir}/aligned"
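
One way to express that default with argparse is to leave the option unset and derive it from input_dir after parsing. A hedged sketch (only the --input-aligned-dir flag is taken from the commit; the rest is illustrative):

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input-dir', dest='input_dir', default='input')
parser.add_argument('--input-aligned-dir', dest='input_aligned_dir', default=None)
args = parser.parse_args()

# Fall back to "{input_dir}/aligned" when the option is not given
if args.input_aligned_dir is None:
    args.input_aligned_dir = os.path.join(args.input_dir, 'aligned')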

* Simplify map operation

* Port 'face_alignment' from PyTorch to Keras. It runs about 2x faster, but initialization takes ~20 seconds.

2DFAN-4.h5 and mmod_human_face_detector.dat are included in lib\FaceLandmarksExtractor.

Fixed the dlib vs. TensorFlow conflict: dlib must run an op first, and only then should the Keras model be loaded; otherwise a CUDA OOM error occurs.

If the CNN detector does not find a face location, it falls back to HOG detection.
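
A rough sketch of both detector points using dlib's public API (the model path and wrapper function are illustrative, not the committed code):

import dlib

# dlib must run an op first: doing one detection before the Keras model
# is loaded avoids the CUDA OOM conflict described above.
cnn_detector = dlib.cnn_face_detection_model_v1('mmod_human_face_detector.dat')
hog_detector = dlib.get_frontal_face_detector()

def detect_faces(image):
    faces = cnn_detector(image, 1)        # try the CNN detector first
    if len(faces) > 0:
        return [f.rect for f in faces]    # CNN results wrap a rect
    return hog_detector(image, 1)         # fall back to HOG

# ...only after this does the Keras landmark model (2DFAN-4.h5) get loaded.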

Removed this:
-        if face.landmarks == None:
-            print("Warning! landmarks not found. Switching to crop!")
-            return cv2.resize(face.image, (size, size))
because a DetectedFace always has landmarks.

* Enabled masked converter for GAN models

* Histogram matching, cli option for perceptual loss

* Fix init() positional args error

* Add backwards compatibility for aligned filenames

* Fix masked converter

* Remove GAN converters
Commit 810bd0bce7 (parent 120535eb11), authored by Othniel Cundangan on 2018-03-09 19:43:24 -05:00, committed via GitHub.
25 changed files with 1620 additions and 719 deletions.


@@ -5,7 +5,7 @@ if sys.version_info[0] < 3:
if sys.version_info[0] == 3 and sys.version_info[1] < 2:
raise Exception("This program requires at least python3.2")
from lib.utils import FullHelpArgumentParser
from lib.cli import FullHelpArgumentParser
from scripts.extract import ExtractTrainingData
from scripts.train import TrainingProcessor


@@ -1,91 +0,0 @@
# AutoEncoder base classes
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images
encoderH5 = 'encoder.h5'
decoder_AH5 = 'decoder_A.h5'
decoder_BH5 = 'decoder_B.h5'
class ModelAE:
def __init__(self, model_dir):
self.model_dir = model_dir
self.encoder = self.Encoder()
self.decoder_A = self.Decoder()
self.decoder_B = self.Decoder()
self.initModel()
def load(self, swapped):
(face_A,face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)
try:
self.encoder.load_weights(str(self.model_dir / encoderH5))
self.decoder_A.load_weights(str(self.model_dir / face_A))
self.decoder_B.load_weights(str(self.model_dir / face_B))
print('loaded model weights')
return True
except Exception as e:
print('Failed loading existing training data.')
print(e)
return False
def save_weights(self):
self.encoder.save_weights(str(self.model_dir / encoderH5))
self.decoder_A.save_weights(str(self.model_dir / decoder_AH5))
self.decoder_B.save_weights(str(self.model_dir / decoder_BH5))
print('saved model weights')
class TrainerAE():
random_transform_args = {
'rotation_range': 10,
'zoom_range': 0.05,
'shift_range': 0.05,
'random_flip': 0.4,
}
def __init__(self, model, fn_A, fn_B, batch_size=64):
self.batch_size = batch_size
self.model = model
generator = TrainingDataGenerator(self.random_transform_args, 160)
self.images_A = generator.minibatchAB(fn_A, self.batch_size)
self.images_B = generator.minibatchAB(fn_B, self.batch_size)
def train_one_step(self, iter, viewer):
epoch, warped_A, target_A = next(self.images_A)
epoch, warped_B, target_B = next(self.images_B)
loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
end='\r')
if viewer is not None:
viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")
def show_sample(self, test_A, test_B):
figure_A = numpy.stack([
test_A,
self.model.autoencoder_A.predict(test_A),
self.model.autoencoder_B.predict(test_A),
], axis=1)
figure_B = numpy.stack([
test_B,
self.model.autoencoder_B.predict(test_B),
self.model.autoencoder_A.predict(test_B),
], axis=1)
if test_A.shape[0] % 2 == 1:
figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ])
figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ])
figure = numpy.concatenate([figure_A, figure_B], axis=0)
w = 4
h = int( figure.shape[0] / w)
figure = figure.reshape((w, h) + figure.shape[1:])
figure = stack_images(figure)
return numpy.clip(figure * 255, 0, 255).astype('uint8')


@@ -1,5 +1,6 @@
import argparse
import os
import sys
import time
from pathlib import Path
@@ -15,6 +16,16 @@ class FullPaths(argparse.Action):
setattr(namespace, self.dest, os.path.abspath(
os.path.expanduser(values)))
class FullHelpArgumentParser(argparse.ArgumentParser):
"""
Identical to the built-in argument parser, but on error
it prints full help message instead of just usage information
"""
def error(self, message):
self.print_help(sys.stderr)
args = {'prog': self.prog, 'message': message}
self.exit(2, '%(prog)s: error: %(message)s\n' % args)
class DirectoryProcessor(object):
'''
Abstract class that processes a directory of images
@@ -35,7 +46,6 @@ class DirectoryProcessor(object):
self.create_parser(subparser, command, description)
self.parse_arguments(description, subparser, command)
def process_arguments(self, arguments):
self.arguments = arguments
print("Input Directory: {}".format(self.arguments.input_dir))
@@ -58,6 +68,7 @@ class DirectoryProcessor(object):
pass
self.output_dir = get_folder(self.arguments.output_dir)
try:
try:
if self.arguments.skip_existing:


@@ -6,9 +6,11 @@ from .utils import BackgroundGenerator
from .umeyama import umeyama
class TrainingDataGenerator():
def __init__(self, random_transform_args, coverage):
def __init__(self, random_transform_args, coverage, scale=5, zoom=1): # TODO: those defaults should stay in the warp function
self.random_transform_args = random_transform_args
self.coverage = coverage
self.scale = scale
self.zoom = zoom
def minibatchAB(self, images, batchsize):
batch = BackgroundGenerator(self.minibatch(images, batchsize), 1)
@@ -42,7 +44,7 @@ class TrainingDataGenerator():
image = cv2.resize(image, (256,256))
image = self.random_transform( image, **self.random_transform_args )
warped_img, target_img = self.random_warp( image, self.coverage )
warped_img, target_img = self.random_warp( image, self.coverage, self.scale, self.zoom )
return warped_img, target_img
@@ -61,25 +63,25 @@ class TrainingDataGenerator():
return result
# get pair of random warped images from aligned face image
def random_warp(self, image, coverage):
def random_warp(self, image, coverage, scale = 5, zoom = 1):
assert image.shape == (256, 256, 3)
range_ = numpy.linspace(128 - coverage//2, 128 + coverage//2, 5)
mapx = numpy.broadcast_to(range_, (5, 5))
mapy = mapx.T
mapx = mapx + numpy.random.normal(size=(5, 5), scale=5)
mapy = mapy + numpy.random.normal(size=(5, 5), scale=5)
mapx = mapx + numpy.random.normal(size=(5,5), scale=scale)
mapy = mapy + numpy.random.normal(size=(5,5), scale=scale)
interp_mapx = cv2.resize(mapx, (80, 80))[8:72, 8:72].astype('float32')
interp_mapy = cv2.resize(mapy, (80, 80))[8:72, 8:72].astype('float32')
interp_mapx = cv2.resize(mapx, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
interp_mapy = cv2.resize(mapy, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32')
warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
src_points = numpy.stack([mapx.ravel(), mapy.ravel()], axis=-1)
dst_points = numpy.mgrid[0:65:16, 0:65:16].T.reshape(-1, 2)
src_points = numpy.stack([mapx.ravel(), mapy.ravel() ], axis=-1)
dst_points = numpy.mgrid[0:65*zoom:16*zoom,0:65*zoom:16*zoom].T.reshape(-1,2)
mat = umeyama(src_points, dst_points, True)[0:2]
target_image = cv2.warpAffine(image, mat, (64, 64))
target_image = cv2.warpAffine(image, mat, (64*zoom,64*zoom))
return warped_image, target_image
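
The new scale and zoom parameters preserve the 64px behaviour by default while letting the GAN128 trainer request larger crops; with zoom=2 every hard-coded size doubles. A quick worked check under that assumption:

# zoom=1 (default): maps resized to 80x80, cropped [8:72] -> 64x64 warped,
# and the target is warpAffine'd to (64, 64).
# zoom=2 (GAN128): 160x160 maps cropped [16:144] -> 128x128 warped,
# and the target becomes (128, 128).
zoom = 2
print(80*zoom - 2*8*zoom, 64*zoom)   # 128 128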


@@ -1,4 +1,3 @@
import argparse
import sys
from os.path import basename, exists
@@ -31,16 +30,6 @@ def get_image_paths(directory, exclude=[], debug=False):
return dir_contents
class FullHelpArgumentParser(argparse.ArgumentParser):
"""
Identical to the built-in argument parser, but on error
it prints full help message instead of just usage information
"""
def error(self, message):
self.print_help(sys.stderr)
args = {'prog': self.prog, 'message': message}
self.exit(2, '%(prog)s: error: %(message)s\n' % args)
# From: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
import threading
import queue as Queue


@@ -8,18 +8,18 @@ import os
class Convert(object):
def __init__(self, encoder, smooth_mask=True, avg_color_adjust=True, **kwargs):
self.encoder = encoder
self.use_smooth_mask = smooth_mask
self.use_avg_color_adjust = avg_color_adjust
def patch_image( self, original, face_detected ):
def patch_image( self, original, face_detected, size ):
#assert image.shape == (256, 256, 3)
image = cv2.resize(face_detected.image, (256, 256))
crop = slice(48, 208)
face = image[crop, crop]
old_face = face.copy()
face = cv2.resize(face, (64, 64))
face = cv2.resize(face, (size, size))
face = numpy.expand_dims(face, 0)
new_face = self.encoder(face / 255.0)[0]
new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype)


@@ -1,18 +0,0 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo (master/FaceSwap_GAN_v2_train.ipynb)
import cv2
import numpy
class Convert(object):
def __init__(self, encoder, **kwargs):
self.encoder = encoder
def patch_image( self, original, face_detected ):
face = cv2.resize(face_detected.image, (64, 64))
face = numpy.expand_dims(face, 0) / 255.0 * 2 - 1
mask, new_face = self.encoder(face)
new_face = mask * new_face + (1 - mask) * face
new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype('uint8')
original[face_detected.y: face_detected.y + face_detected.h, face_detected.x: face_detected.x + face_detected.w] = cv2.resize(new_face, (face_detected.w, face_detected.h))
return original


@@ -6,10 +6,11 @@ import numpy
from lib.aligner import get_align_mat
class Convert():
def __init__(self, encoder, blur_size=2, seamless_clone=False, mask_type="facehullandrect", erosion_kernel_size=None, **kwargs):
def __init__(self, encoder, trainer, blur_size=2, seamless_clone=False, mask_type="facehullandrect", erosion_kernel_size=None, match_histogram=False, **kwargs):
self.encoder = encoder
self.trainer = trainer
self.erosion_kernel = None
self.erosion_kernel_size = erosion_kernel_size
self.erosion_kernel_size = erosion_kernel_size
if erosion_kernel_size is not None:
if erosion_kernel_size > 0:
self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(erosion_kernel_size,erosion_kernel_size))
@@ -17,10 +18,11 @@
self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(abs(erosion_kernel_size),abs(erosion_kernel_size)))
self.blur_size = blur_size
self.seamless_clone = seamless_clone
self.match_histogram = match_histogram
self.mask_type = mask_type.lower() # Choose in 'FaceHullAndRect','FaceHull','Rect'
def patch_image( self, image, face_detected ):
size = 64
def patch_image( self, image, face_detected, size ):
image_size = image.shape[1], image.shape[0]
mat = numpy.array(get_align_mat(face_detected)).reshape(2,3) * size
@@ -40,9 +42,9 @@
outImage = None
if self.seamless_clone:
unitMask = numpy.clip( image_mask * 365, 0, 255 ).astype(numpy.uint8)
maxregion = numpy.argwhere(unitMask==255)
if maxregion.size > 0:
miny,minx = maxregion.min(axis=0)[:2]
maxy,maxx = maxregion.max(axis=0)[:2]
@@ -51,21 +53,77 @@
masky = int(minx+(lenx//2))
maskx = int(miny+(leny//2))
outimage = cv2.seamlessClone(new_image.astype(numpy.uint8),base_image.astype(numpy.uint8),unitMask,(masky,maskx) , cv2.NORMAL_CLONE )
return outimage
foreground = cv2.multiply(image_mask, new_image.astype(float))
background = cv2.multiply(1.0 - image_mask, base_image.astype(float))
outimage = cv2.add(foreground, background)
return outimage
def hist_match(self, source, template, mask=None):
# Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
masked_source = source
masked_template = template
if mask is not None:
masked_source = source * mask
masked_template = template * mask
oldshape = source.shape
source = source.ravel()
template = template.ravel()
masked_source = masked_source.ravel()
masked_template = masked_template.ravel()
s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True,
return_counts=True)
t_values, t_counts = numpy.unique(template, return_counts=True)
ms_values, mbin_idx, ms_counts = numpy.unique(source, return_inverse=True,
return_counts=True)
mt_values, mt_counts = numpy.unique(template, return_counts=True)
s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64)
s_quantiles /= s_quantiles[-1]
t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64)
t_quantiles /= t_quantiles[-1]
interp_t_values = numpy.interp(s_quantiles, t_quantiles, t_values)
return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(self, src_im, tar_im, mask):
matched_R = self.hist_match(src_im[:,:,0], tar_im[:,:,0], mask)
matched_G = self.hist_match(src_im[:,:,1], tar_im[:,:,1], mask)
matched_B = self.hist_match(src_im[:,:,2], tar_im[:,:,2], mask)
matched = numpy.stack((matched_R, matched_G, matched_B), axis=2).astype(src_im.dtype)
return matched
def get_new_face(self, image, mat, size):
face = cv2.warpAffine( image, mat, (size,size) )
face = numpy.expand_dims( face, 0 )
new_face = self.encoder( face / 255.0 )[0]
face_clipped = numpy.clip(face[0], 0, 255).astype( image.dtype )
new_face = None
mask = None
return numpy.clip( new_face * 255, 0, 255 ).astype( image.dtype )
if "GAN" not in self.trainer:
normalized_face = face / 255.0
new_face = self.encoder(normalized_face)[0]
new_face = numpy.clip( new_face * 255, 0, 255 ).astype( image.dtype )
else:
normalized_face = face / 255.0 * 2 - 1
fake_output = self.encoder(normalized_face)
if "128" in self.trainer: # TODO: Another hack to switch between 64 and 128
fake_output = fake_output[0]
mask = fake_output[:,:,:, :1]
new_face = fake_output[:,:,:, 1:]
new_face = mask * new_face + (1 - mask) * normalized_face
new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype( image.dtype )
if self.match_histogram:
new_face = self.color_hist_match(new_face, face_clipped, mask)
return new_face
def get_image_mask(self, image, new_face, face_detected, mat, image_size):
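
The hist_match routine above maps each source intensity onto the template value at the same quantile; color_hist_match simply applies it per BGR channel. A standalone numpy check of that idea (toy arrays, not the committed code):

import numpy

source = numpy.array([[0, 50], [100, 150]], dtype=numpy.float64)
template = numpy.array([[100, 150], [200, 250]], dtype=numpy.float64)

s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True, return_counts=True)
t_values, t_counts = numpy.unique(template, return_counts=True)
s_q = numpy.cumsum(s_counts) / s_counts.sum()
t_q = numpy.cumsum(t_counts) / t_counts.sum()
matched = numpy.interp(s_q, t_q, t_values)[bin_idx].reshape(source.shape)
print(matched)   # [[100. 150.] [200. 250.]] -- same quantiles, template's values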


@@ -4,11 +4,11 @@ import cv2
from lib.aligner import get_align_mat
class Extract(object):
class Extract:
def extract(self, image, face, size):
alignment = get_align_mat( face )
return self.transform( image, alignment, size, 48 )
def transform( self, image, mat, size, padding=0 ):
mat = mat * (size - 2 * padding)
mat[:,2] += padding


@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'
from .Extract import Extract


@@ -9,18 +9,36 @@ from keras.applications import *
from keras.optimizers import Adam
from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
netGAH5 = 'netGA_GAN.h5'
netGBH5 = 'netGB_GAN.h5'
netDAH5 = 'netDA_GAN.h5'
netDBH5 = 'netDB_GAN.h5'
def __conv_init(a):
print("conv_init", a)
k = RandomNormal(0, 0.02)(a) # for convolution kernel
k.conv_weight = True
return k
#def batchnorm():
# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)
def inst_norm():
return InstanceNormalization()
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization
class GANModel():
img_size = 64
channels = 3
img_shape = (img_size, img_size, channels)
encoded_dim = 1024
nc_in = 3 # number of input channels of generators
nc_D_inp = 6 # number of input channels of discriminators
def __init__(self, model_dir):
self.model_dir = model_dir
@@ -29,76 +47,41 @@ class GANModel():
# Build and compile the discriminator
self.netDA, self.netDB = self.build_discriminator()
# For the adversarial_autoencoder model we will only train the generator
self.netDA.trainable = False
self.netDB.trainable = False
self.netDA.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
self.netDB.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
# Build and compile the generator
self.netGA, self.netGB = self.build_generator()
self.netGA.compile(loss=['mae', 'mse'], optimizer=optimizer)
self.netGB.compile(loss=['mae', 'mse'], optimizer=optimizer)
img = Input(shape=self.img_shape)
alphaA, reconstructed_imgA = self.netGA(img)
alphaB, reconstructed_imgB = self.netGB(img)
def one_minus(x): return 1 - x
# masked_img = alpha * reconstructed_img + (1 - alpha) * img
masked_imgA = add([multiply([alphaA, reconstructed_imgA]), multiply([Lambda(one_minus)(alphaA), img])])
masked_imgB = add([multiply([alphaB, reconstructed_imgB]), multiply([Lambda(one_minus)(alphaB), img])])
out_discriminatorA = self.netDA(concatenate([masked_imgA, img], axis=-1))
out_discriminatorB = self.netDB(concatenate([masked_imgB, img], axis=-1))
# The adversarial_autoencoder model (stacked generator and discriminator) takes
# img as input => generates encoded representation and reconstructed image => determines validity
self.adversarial_autoencoderA = Model(img, [reconstructed_imgA, out_discriminatorA])
self.adversarial_autoencoderB = Model(img, [reconstructed_imgB, out_discriminatorB])
self.adversarial_autoencoderA.compile(loss=['mae', 'mse'],
loss_weights=[1, 0.5],
optimizer=optimizer)
self.adversarial_autoencoderB.compile(loss=['mae', 'mse'],
loss_weights=[1, 0.5],
optimizer=optimizer)
def converter(self, swap):
predictor = self.netGB if not swap else self.netGA
return lambda img: predictor.predict(img)
def build_generator(self):
def conv_block(input_tensor, f):
x = input_tensor
x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = Activation("relu")(x)
return x
def res_block(input_tensor, f):
x = input_tensor
x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same")(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same")(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = add([x, input_tensor])
x = LeakyReLU(alpha=0.2)(x)
return x
def upscale_ps(filters, use_norm=True):
def upscale_ps(filters, use_instance_norm=True):
def block(x):
x = Conv2D(filters*4, kernel_size=3, use_bias=False,
kernel_initializer=RandomNormal(0, 0.02), padding='same' )(x)
x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def Encoder(img_shape):
inp = Input(shape=img_shape)
x = Conv2D(64, kernel_size=5, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same")(inp)
def Encoder(nc_in=3, input_size=64):
inp = Input(shape=(input_size, input_size, nc_in))
x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
x = conv_block(x,128)
x = conv_block(x,256)
x = conv_block(x,512)
@@ -109,23 +92,23 @@ class GANModel():
out = upscale_ps(512)(x)
return Model(inputs=inp, outputs=out)
def Decoder_ps(img_shape):
nc_in = 512
input_size = img_shape[0]//8
inp = Input(shape=(input_size, input_size, nc_in))
x = inp
def Decoder_ps(nc_in=512, input_size=8):
input_ = Input(shape=(input_size, input_size, nc_in))
x = input_
x = upscale_ps(256)(x)
x = upscale_ps(128)(x)
x = upscale_ps(64)(x)
x = res_block(x, 64)
x = res_block(x, 64)
#x = Conv2D(4, kernel_size=5, padding='same')(x)
alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
return Model(inp, [alpha, rgb])
out = concatenate([alpha, rgb])
return Model(input_, out )
encoder = Encoder(self.img_shape)
decoder_A = Decoder_ps(self.img_shape)
decoder_B = Decoder_ps(self.img_shape)
encoder = Encoder()
decoder_A = Decoder_ps()
decoder_B = Decoder_ps()
x = Input(shape=self.img_shape)
netGA = Model(x, decoder_A(encoder(x)))
netGB = Model(x, decoder_B(encoder(x)))
@@ -136,26 +119,26 @@ class GANModel():
except:
print ("Generator weights files not found.")
pass
return netGA, netGB,
return netGA, netGB
def build_discriminator(self):
def build_discriminator(self):
def conv_block_d(input_tensor, f, use_instance_norm=True):
x = input_tensor
x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same")(x)
x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
x = LeakyReLU(alpha=0.2)(x)
return x
def Discriminator(img_shape):
inp = Input(shape=(img_shape[0], img_shape[1], img_shape[2]*2))
return x
def Discriminator(nc_in, input_size=64):
inp = Input(shape=(input_size, input_size, nc_in))
#x = GaussianNoise(0.05)(inp)
x = conv_block_d(inp, 64, False)
x = conv_block_d(x, 128, False)
x = conv_block_d(x, 256, False)
out = Conv2D(1, kernel_size=4, kernel_initializer=RandomNormal(0, 0.02),
use_bias=False, padding="same", activation="sigmoid")(x)
return Model(inputs=[inp], outputs=out)
out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
return Model(inputs=[inp], outputs=out)
netDA = Discriminator(self.img_shape)
netDB = Discriminator(self.img_shape)
netDA = Discriminator(self.nc_D_inp)
netDB = Discriminator(self.nc_D_inp)
try:
netDA.load_weights(str(self.model_dir / netDAH5))
netDB.load_weights(str(self.model_dir / netDBH5))
@@ -173,7 +156,7 @@ class GANModel():
def save_weights(self):
self.netGA.save_weights(str(self.model_dir / netGAH5))
self.netGB.save_weights(str(self.model_dir / netGBH5))
self.netDA.save_weights(str(self.model_dir / netDAH5))
self.netDB.save_weights(str(self.model_dir / netDBH5))
print ("Models saved.")
self.netGB.save_weights(str(self.model_dir / netGBH5))
self.netDA.save_weights(str(self.model_dir / netDAH5))
self.netDB.save_weights(str(self.model_dir / netDBH5))
print ("Models saved.")


@@ -2,11 +2,17 @@ import time
import cv2
import numpy as np
from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K
from lib.training_data import TrainingDataGenerator, stack_images
class GANTrainingDataGenerator(TrainingDataGenerator):
def __init__(self, random_transform_args, coverage):
super().__init__(random_transform_args, coverage)
def __init__(self, random_transform_args, coverage, scale, zoom):
super().__init__(random_transform_args, coverage, scale, zoom)
def color_adjust(self, img):
return img / 255.0 * 2 - 1
@@ -14,140 +20,241 @@ class GANTrainingDataGenerator(TrainingDataGenerator):
class Trainer():
random_transform_args = {
'rotation_range': 20,
'zoom_range': 0.05,
'zoom_range': 0.1,
'shift_range': 0.05,
'random_flip': 0.5,
}
def __init__(self, model, fn_A, fn_B, batch_size):
def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
K.set_learning_phase(1)
assert batch_size % 2 == 0, "batch_size must be an even number"
self.batch_size = batch_size
self.model = model
self.use_lsgan = True
self.use_mixup = True
self.mixup_alpha = 0.2
self.use_perceptual_loss = perceptual_loss
self.use_instancenorm = False
generator = GANTrainingDataGenerator(self.random_transform_args, 220)
self.lrD = 1e-4 # Discriminator learning rate
self.lrG = 1e-4 # Generator learning rate
generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 1)
self.train_batchA = generator.minibatchAB(fn_A, batch_size)
self.train_batchB = generator.minibatchAB(fn_B, batch_size)
self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.setup()
def setup(self):
distorted_A, fake_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
distorted_B, fake_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
real_A = Input(shape=self.model.img_shape)
real_B = Input(shape=self.model.img_shape)
if self.use_lsgan:
self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
else:
self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))
# ========== Define Perceptual Loss Model==========
if self.use_perceptual_loss:
from keras.models import Model
from keras_vggface.vggface import VGGFace
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
vggface.trainable = False
out_size55 = vggface.layers[36].output
out_size28 = vggface.layers[78].output
out_size7 = vggface.layers[-2].output
vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
vggface_feat.trainable = False
else:
vggface_feat = None
#TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, distorted_A, vggface_feat)
loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, distorted_B, vggface_feat)
loss_GA += 1e-3 * K.mean(K.abs(mask_A))
loss_GB += 1e-3 * K.mean(K.abs(mask_B))
w_fo = 0.01
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))
weightsDA = self.model.netDA.trainable_weights
weightsGA = self.model.netGA.trainable_weights
weightsDB = self.model.netDB.trainable_weights
weightsGB = self.model.netGB.trainable_weights
# Adam(..).get_updates(...)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA)
self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA)
self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB)
self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB)
self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)
def first_order(self, x, axis=1):
img_nrows = x.shape[1]
img_ncols = x.shape[2]
if axis == 1:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
elif axis == 2:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
else:
return None
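
first_order is just the absolute difference of vertically or horizontally adjacent pixels, cropped to a common shape, i.e. the quantity a total-variation regularizer penalizes. A quick numpy check under that reading:

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)    # rows step by 4, cols by 1
vertical = np.abs(x[:, :-1, :-1, :] - x[:, 1:, :-1, :])    # all 4s (axis=1)
horizontal = np.abs(x[:, :-1, :-1, :] - x[:, :-1, 1:, :])  # all 1s (axis=2)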
def train_one_step(self, iter, viewer):
# ---------------------
# Train Discriminators
# ---------------------
# Select a random half batch of images
epoch, warped_A, target_A = next(self.train_batchA)
epoch, warped_B, target_B = next(self.train_batchB)
epoch, warped_A, target_A = next(self.train_batchA)
epoch, warped_B, target_B = next(self.train_batchB)
# Generate a half batch of new images
gen_alphasA, gen_imgsA = self.model.netGA.predict(warped_A)
gen_alphasB, gen_imgsB = self.model.netGB.predict(warped_B)
#gen_masked_imgsA = gen_alphasA * gen_imgsA + (1 - gen_alphasA) * warped_A
#gen_masked_imgsB = gen_alphasB * gen_imgsB + (1 - gen_alphasB) * warped_B
gen_masked_imgsA = np.array([gen_alphasA[i] * gen_imgsA[i] + (1 - gen_alphasA[i]) * warped_A[i]
for i in range(self.batch_size)])
gen_masked_imgsB = np.array([gen_alphasB[i] * gen_imgsB[i] + (1 - gen_alphasB[i]) * warped_B[i]
for i in range (self.batch_size)])
# Train discriminators for one batch
errDA = self.netDA_train([warped_A, target_A])
errDB = self.netDB_train([warped_B, target_B])
valid = np.ones((self.batch_size, ) + self.model.netDA.output_shape[1:])
fake = np.zeros((self.batch_size, ) + self.model.netDA.output_shape[1:])
# Train generators for one batch
errGA = self.netGA_train([warped_A, target_A])
errGB = self.netGB_train([warped_B, target_B])
concat_real_inputA = np.array([np.concatenate([target_A[i], warped_A[i]], axis=-1)
for i in range(self.batch_size)])
concat_real_inputB = np.array([np.concatenate([target_B[i], warped_B[i]], axis=-1)
for i in range(self.batch_size)])
concat_fake_inputA = np.array([np.concatenate([gen_masked_imgsA[i], warped_A[i]], axis=-1)
for i in range(self.batch_size)])
concat_fake_inputB = np.array([np.concatenate([gen_masked_imgsB[i], warped_B[i]], axis=-1)
for i in range(self.batch_size)])
if self.use_mixup:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
mixup_A = lam * concat_real_inputA + (1 - lam) * concat_fake_inputA
mixup_B = lam * concat_real_inputB + (1 - lam) * concat_fake_inputB
# For calculating average losses
self.errDA_sum += errDA[0]
self.errDB_sum += errDB[0]
self.errGA_sum += errGA[0]
self.errGB_sum += errGB[0]
self.avg_counter += 1
# Train the discriminators
#print ("Train the discriminators.")
if self.use_mixup:
d_lossA = self.model.netDA.train_on_batch(mixup_A, lam * valid)
d_lossB = self.model.netDB.train_on_batch(mixup_B, lam * valid)
else:
d_lossA = self.model.netDA.train_on_batch(np.concatenate([concat_real_inputA, concat_fake_inputA], axis=0),
np.concatenate([valid, fake], axis=0))
d_lossB = self.model.netDB.train_on_batch(np.concatenate([concat_real_inputB, concat_fake_inputB], axis=0),
np.concatenate([valid, fake], axis=0))
# ---------------------
# Train Generators
# ---------------------
# Train the generators
#print ("Train the generators.")
g_lossA = self.model.adversarial_autoencoderA.train_on_batch(warped_A, [target_A, valid])
g_lossB = self.model.adversarial_autoencoderB.train_on_batch(warped_B, [target_B, valid])
print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
% (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, d_lossA[0], d_lossB[0], g_lossA[0], g_lossB[0]),
% (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter),
end='\r')
if viewer is not None:
self.show_sample(viewer)
def cycle_variables(self, netG):
distorted_input = netG.inputs[0]
fake_output = netG.outputs[0]
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output)
rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output)
masked_fake_output = alpha * rgb + (1-alpha) * distorted_input
fn_generate = K.function([distorted_input], [masked_fake_output])
fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
fn_bgr = K.function([distorted_input], [rgb])
return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr
def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None):
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb)
fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb)
fake = alpha * fake_rgb + (1-alpha) * distorted
if self.use_mixup:
dist = Beta(self.mixup_alpha, self.mixup_alpha)
lam = dist.sample()
# ==========
mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
# ==========
output_mixup = netD(mixup)
loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
output_fake = netD(concatenate([fake, distorted])) # dummy
loss_G = .5 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
else:
output_real = netD(concatenate([real, distorted])) # positive sample
output_fake = netD(concatenate([fake, distorted])) # negative sample
loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
loss_D = loss_D_real + loss_D_fake
loss_G = .5 * self.loss_fn(output_fake, K.ones_like(output_fake))
# ==========
loss_G += K.mean(K.abs(fake_rgb - real))
# ==========
# Edge loss (similar with total variation loss)
loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=1) - self.first_order(real, axis=1)))
loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=2) - self.first_order(real, axis=2)))
# Perceptual Loss
if not vggface_feat is None:
def preprocess_vggface(x):
x = (x + 1)/2 * 255 # channel order: BGR
#x[..., 0] -= 93.5940
#x[..., 1] -= 104.7624
#x[..., 2] -= 129.
x -= [91.4953, 103.8827, 131.0912]
return x
pl_params = (0.011, 0.11, 0.1919)
real_sz224 = tf.image.resize_images(real, [224, 224])
real_sz224 = Lambda(preprocess_vggface)(real_sz224)
# ==========
fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
# ==========
real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
return loss_D, loss_G
def show_sample(self, display_fn):
_, wA, tA = next(self.train_batchA)
_, wB, tB = next(self.train_batchB)
self.showG(tA, tB, display_fn)
def showG(self, test_A, test_B, display_fn):
def display_fig(name, figure_A, figure_B):
figure = np.concatenate([figure_A, figure_B], axis=0 )
columns = 4
elements = figure.shape[0]
figure = figure.reshape((columns,(elements//columns)) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
display_fn(figure, name)
out_test_A_netGA = self.model.netGA.predict(test_A)
out_test_A_netGB = self.model.netGB.predict(test_A)
out_test_B_netGA = self.model.netGA.predict(test_B)
out_test_B_netGB = self.model.netGB.predict(test_B)
display_fn(self.showG(tA, tB, self.path_A, self.path_B), "raw")
display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "masked")
display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
# Reset the averages
self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.avg_counter = 0
def showG(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
out_test_A_netGA[1],
out_test_A_netGB[1],
np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
], axis=1 )
figure_B = np.stack([
test_B,
out_test_B_netGB[1],
out_test_B_netGA[1],
np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
], axis=1 )
display_fig("raw", figure_A, figure_B)
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure
def showG_mask(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
np.tile(out_test_A_netGA[0],3) * 2 - 1,
np.tile(out_test_A_netGB[0],3) * 2 - 1,
(np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
(np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
], axis=1 )
figure_B = np.stack([
test_B,
np.tile(out_test_B_netGB[0],3) * 2 - 1,
np.tile(out_test_B_netGA[0],3) * 2 - 1,
(np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
(np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
], axis=1 )
display_fig("alpha_masks", figure_A, figure_B)
figure_A = np.stack([
test_A,
out_test_A_netGA[0] * out_test_A_netGA[1] + (1 - out_test_A_netGA[0]) * test_A,
out_test_A_netGB[0] * out_test_A_netGB[1] + (1 - out_test_A_netGB[0]) * test_A,
], axis=1 )
figure_B = np.stack([
test_B,
out_test_B_netGB[0] * out_test_B_netGB[1] + (1 - out_test_B_netGB[0]) * test_B,
out_test_B_netGA[0] * out_test_B_netGA[1] + (1 - out_test_B_netGA[0]) * test_B,
], axis=1 )
display_fig("masked", figure_A, figure_B)
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure
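
The mixup branch above trains each discriminator on a single Beta(0.2, 0.2)-weighted blend of the real and fake pairs, with the soft label lam as the target, rather than on separate real and fake batches. A numpy sketch of the blend (shapes illustrative):

import numpy as np

mixup_alpha = 0.2
lam = np.random.beta(mixup_alpha, mixup_alpha)
real_pair = np.random.rand(4, 64, 64, 6)   # concatenate([real, distorted])
fake_pair = np.random.rand(4, 64, 64, 6)   # concatenate([fake, distorted])
mixup = lam * real_pair + (1 - lam) * fake_pair
# Discriminator target is lam; the generator loss pushes toward 1 - lam.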


@@ -0,0 +1,145 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})
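
Because the layer registers itself via get_custom_objects, models saved with it can be reloaded without passing custom_objects explicitly. A small usage sketch (the import path assumes this module is importable as instance_normalization):

from keras.models import Sequential
from keras.layers import Conv2D
from instance_normalization import InstanceNormalization

model = Sequential([
    Conv2D(8, 3, padding='same', input_shape=(64, 64, 3)),
    InstanceNormalization(axis=-1),   # per-sample, per-channel normalization
])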


@@ -0,0 +1,177 @@
# Based on the https://github.com/shaoanlu/faceswap-GAN repo
# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynb
# (see also temp/faceswap_GAN_keras.ipynb in the same repo)
from keras.models import Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
from keras.optimizers import Adam
from lib.PixelShuffler import PixelShuffler
from .instance_normalization import InstanceNormalization
netGAH5 = 'netGA_GAN128.h5'
netGBH5 = 'netGB_GAN128.h5'
netDAH5 = 'netDA_GAN128.h5'
netDBH5 = 'netDB_GAN128.h5'
def __conv_init(a):
print("conv_init", a)
k = RandomNormal(0, 0.02)(a) # for convolution kernel
k.conv_weight = True
return k
#def batchnorm():
# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)
def inst_norm():
return InstanceNormalization()
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization
class GANModel():
img_size = 128
channels = 3
img_shape = (img_size, img_size, channels)
encoded_dim = 1024
nc_in = 3 # number of input channels of generators
nc_D_inp = 6 # number of input channels of discriminators
def __init__(self, model_dir):
self.model_dir = model_dir
optimizer = Adam(1e-4, 0.5)
# Build and compile the discriminator
self.netDA, self.netDB = self.build_discriminator()
# Build and compile the generator
self.netGA, self.netGB = self.build_generator()
def converter(self, swap):
predictor = self.netGB if not swap else self.netGA
return lambda img: predictor.predict(img)
def build_generator(self):
def conv_block(input_tensor, f, use_instance_norm=True):
x = input_tensor
x = SeparableConv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
if use_instance_norm:
x = inst_norm()(x)
x = Activation("relu")(x)
return x
def res_block(input_tensor, f, dilation=1):
x = input_tensor
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x)
x = add([x, input_tensor])
#x = LeakyReLU(alpha=0.2)(x)
return x
def upscale_ps(filters, use_instance_norm=True):
def block(x, use_instance_norm=use_instance_norm):
x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
if use_instance_norm:
x = inst_norm()(x)
x = LeakyReLU(0.1)(x)
x = PixelShuffler()(x)
return x
return block
def Encoder(nc_in=3, input_size=128):
inp = Input(shape=(input_size, input_size, nc_in))
x = Conv2D(32, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
x = conv_block(x,64, use_instance_norm=False)
x = conv_block(x,128)
x = conv_block(x,256)
x = conv_block(x,512)
x = conv_block(x,1024)
x = Dense(1024)(Flatten()(x))
x = Dense(4*4*1024)(x)
x = Reshape((4, 4, 1024))(x)
out = upscale_ps(512)(x)
return Model(inputs=inp, outputs=out)
def Decoder_ps(nc_in=512, input_size=8):
input_ = Input(shape=(input_size, input_size, nc_in))
x = input_
x = upscale_ps(256)(x)
x = upscale_ps(128)(x)
x = upscale_ps(64)(x)
x = res_block(x, 64, dilation=2)
out64 = Conv2D(64, kernel_size=3, padding='same')(x)
out64 = LeakyReLU(alpha=0.1)(out64)
out64 = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(out64)
x = upscale_ps(32)(x)
x = res_block(x, 32)
x = res_block(x, 32)
alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
out = concatenate([alpha, rgb])
return Model(input_, [out, out64] )
encoder = Encoder()
decoder_A = Decoder_ps()
decoder_B = Decoder_ps()
x = Input(shape=self.img_shape)
netGA = Model(x, decoder_A(encoder(x)))
netGB = Model(x, decoder_B(encoder(x)))
try:
netGA.load_weights(str(self.model_dir / netGAH5))
netGB.load_weights(str(self.model_dir / netGBH5))
print ("Generator models loaded.")
except:
print ("Generator weights files not found.")
pass
return netGA, netGB
def build_discriminator(self):
def conv_block_d(input_tensor, f, use_instance_norm=True):
x = input_tensor
x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
if use_instance_norm:
x = inst_norm()(x)
x = LeakyReLU(alpha=0.2)(x)
return x
def Discriminator(nc_in, input_size=128):
inp = Input(shape=(input_size, input_size, nc_in))
#x = GaussianNoise(0.05)(inp)
x = conv_block_d(inp, 64, False)
x = conv_block_d(x, 128, True)
x = conv_block_d(x, 256, True)
x = conv_block_d(x, 512, True)
out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x)
return Model(inputs=[inp], outputs=out)
netDA = Discriminator(self.nc_D_inp)
netDB = Discriminator(self.nc_D_inp)
try:
netDA.load_weights(str(self.model_dir / netDAH5))
netDB.load_weights(str(self.model_dir / netDBH5))
print ("Discriminator models loaded.")
except:
print ("Discriminator weights files not found.")
pass
return netDA, netDB
def load(self, swapped):
if swapped:
print("swapping not supported on GAN")
# TODO load is done in __init__ => look how to swap if possible
return True
def save_weights(self):
self.netGA.save_weights(str(self.model_dir / netGAH5))
self.netGB.save_weights(str(self.model_dir / netGBH5))
self.netDA.save_weights(str(self.model_dir / netDAH5))
self.netDB.save_weights(str(self.model_dir / netDBH5))
print ("Models saved.")


@@ -0,0 +1,253 @@
import time
import cv2
import numpy as np
from keras.layers import *
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import backend as K
from lib.training_data import TrainingDataGenerator, stack_images
class GANTrainingDataGenerator(TrainingDataGenerator):
def __init__(self, random_transform_args, coverage, scale, zoom):
super().__init__(random_transform_args, coverage, scale, zoom)
def color_adjust(self, img):
return img / 255.0 * 2 - 1
class Trainer():
random_transform_args = {
'rotation_range': 20,
'zoom_range': 0.1,
'shift_range': 0.05,
'random_flip': 0.5,
}
def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
K.set_learning_phase(1)
assert batch_size % 2 == 0, "batch_size must be an even number"
self.batch_size = batch_size
self.model = model
self.use_lsgan = True
self.use_mixup = True
self.mixup_alpha = 0.2
self.use_perceptual_loss = perceptual_loss
self.lrD = 1e-4 # Discriminator learning rate
self.lrG = 1e-4 # Generator learning rate
generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 2)
self.train_batchA = generator.minibatchAB(fn_A, batch_size)
self.train_batchB = generator.minibatchAB(fn_B, batch_size)
self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.setup()
def setup(self):
distorted_A, fake_A, fake_sz64_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA)
distorted_B, fake_B, fake_sz64_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB)
real_A = Input(shape=self.model.img_shape)
real_B = Input(shape=self.model.img_shape)
if self.use_lsgan:
self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
else:
self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))
# ========== Define Perceptual Loss Model==========
if self.use_perceptual_loss:
from keras.models import Model
from keras_vggface.vggface import VGGFace
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
vggface.trainable = False
out_size55 = vggface.layers[36].output
out_size28 = vggface.layers[78].output
out_size7 = vggface.layers[-2].output
vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7])
vggface_feat.trainable = False
else:
vggface_feat = None
#TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters)
loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, fake_sz64_A, distorted_A, vggface_feat)
loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, fake_sz64_B, distorted_B, vggface_feat)
loss_GA += 3e-3 * K.mean(K.abs(mask_A))
loss_GB += 3e-3 * K.mean(K.abs(mask_B))
w_fo = 0.01
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1))
loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1))
loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2))
weightsDA = self.model.netDA.trainable_weights
weightsGA = self.model.netGA.trainable_weights
weightsDB = self.model.netDB.trainable_weights
weightsGB = self.model.netGB.trainable_weights
# Adam(..).get_updates(...)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA)
self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA)
self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)
training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB)
self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates)
training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB)
self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)
def first_order(self, x, axis=1):
img_nrows = x.shape[1]
img_ncols = x.shape[2]
if axis == 1:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
elif axis == 2:
return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
else:
return None
def train_one_step(self, iter, viewer):
# ---------------------
# Train Discriminators
# ---------------------
# Select a random half batch of images
epoch, warped_A, target_A = next(self.train_batchA)
epoch, warped_B, target_B = next(self.train_batchB)
# Train discriminators for one batch
errDA = self.netDA_train([warped_A, target_A])
errDB = self.netDB_train([warped_B, target_B])
# Train generators for one batch
errGA = self.netGA_train([warped_A, target_A])
errGB = self.netGB_train([warped_B, target_B])
# For calculating average losses
self.errDA_sum += errDA[0]
self.errDB_sum += errDB[0]
self.errGA_sum += errGA[0]
self.errGB_sum += errGB[0]
self.avg_counter += 1
print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f'
% (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter),
end='\r')
if viewer is not None:
self.show_sample(viewer)
def cycle_variables(self, netG):
distorted_input = netG.inputs[0]
fake_output = netG.outputs[0]
fake_output64 = netG.outputs[1]
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output)
rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output)
masked_fake_output = alpha * rgb + (1-alpha) * distorted_input
fn_generate = K.function([distorted_input], [masked_fake_output])
fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
fn_bgr = K.function([distorted_input], [rgb])
return distorted_input, fake_output, fake_output64, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr
def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None):
alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb)
fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb)
fake = alpha * fake_rgb + (1-alpha) * distorted
if self.use_mixup:
dist = Beta(self.mixup_alpha, self.mixup_alpha)
lam = dist.sample()
# ==========
mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])
# ==========
output_mixup = netD(mixup)
loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup))
#output_fake = netD(concatenate([fake, distorted])) # dummy
loss_G = 1 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
else:
output_real = netD(concatenate([real, distorted])) # positive sample
output_fake = netD(concatenate([fake, distorted])) # negative sample
loss_D_real = self.loss_fn(output_real, K.ones_like(output_real))
loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake))
loss_D = loss_D_real + loss_D_fake
loss_G = 1 * self.loss_fn(output_fake, K.ones_like(output_fake))
# ==========
loss_G += K.mean(K.abs(fake_rgb - real))
loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64])))
# ==========
# Perceptual Loss
if not vggface_feat is None:
def preprocess_vggface(x):
x = (x + 1)/2 * 255 # channel order: BGR
x -= [93.5940, 104.7624, 129.]
return x
pl_params = (0.02, 0.3, 0.5)
real_sz224 = tf.image.resize_images(real, [224, 224])
real_sz224 = Lambda(preprocess_vggface)(real_sz224)
# ==========
fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224])
fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
# ==========
real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224)
loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
return loss_D, loss_G
def show_sample(self, display_fn):
_, wA, tA = next(self.train_batchA)
_, wB, tB = next(self.train_batchB)
display_fn(self.showG(tA, tB, self.path_A, self.path_B), "raw")
display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "masked")
display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask")
# Reset the averages
self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0
self.avg_counter = 0
def showG(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
], axis=1 )
figure_B = np.stack([
test_B,
np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
], axis=1 )
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure
def showG_mask(self, test_A, test_B, path_A, path_B):
figure_A = np.stack([
test_A,
(np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
(np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
], axis=1 )
figure_B = np.stack([
test_B,
(np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
(np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
], axis=1 )
figure = np.concatenate([figure_A, figure_B], axis=0 )
figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:])
figure = stack_images(figure)
figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
return figure

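As an aside on the two tensor tricks above: cycle_variables splits the generator's 4-channel output into a soft alpha mask plus BGR colour and composites it back onto the distorted input, while define_loss optionally blends the discriminator's real and fake input pairs with a Beta-distributed mixup weight. A minimal numpy sketch of both operations (the shapes and the mixup_alpha value of 0.2 are illustrative, not taken from this diff):

    import numpy as np

    # Toy batch: 2 images of 64x64. Channel 0 of the generator output is the
    # soft mask, channels 1-3 are the synthesised colour image.
    fake_output = np.random.rand(2, 64, 64, 4)
    distorted = np.random.rand(2, 64, 64, 3)
    real = np.random.rand(2, 64, 64, 3)

    alpha = fake_output[:, :, :, :1]               # soft mask in [0, 1]
    bgr = fake_output[:, :, :, 1:]                 # colour channels
    fake = alpha * bgr + (1 - alpha) * distorted   # the masked_fake_output above

    # Mixup: blend the (real, distorted) and (fake, distorted) channel-stacked
    # pairs; the discriminator is trained towards lam on the blend and the
    # generator towards (1 - lam).
    lam = np.random.beta(0.2, 0.2)                 # mixup_alpha assumed to be 0.2
    real_pair = np.concatenate([real, distorted], axis=-1)
    fake_pair = np.concatenate([fake, distorted], axis=-1)
    mixup = lam * real_pair + (1 - lam) * fake_pair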

@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://github.com/shaoanlu/"""
__version__ = '0.1.0'
from .Model import GANModel as Model
from .Trainer import Trainer


@ -0,0 +1,145 @@
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
import numpy as np
class InstanceNormalization(Layer):
"""Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016).
Normalize the activations of the previous layer at each step,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
# Arguments
axis: Integer, the axis that should be normalized
(typically the features axis).
For instance, after a `Conv2D` layer with
`data_format="channels_first"`,
set `axis=1` in `InstanceNormalization`.
Setting `axis=None` will normalize all values in each instance of the batch.
Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
# Output shape
Same shape as input.
# References
- [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022)
"""
def __init__(self,
axis=None,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(InstanceNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
ndim = len(input_shape)
if self.axis == 0:
raise ValueError('Axis cannot be zero')
if (self.axis is not None) and (ndim == 2):
raise ValueError('Cannot specify axis for rank 1 tensor')
self.input_spec = InputSpec(ndim=ndim)
if self.axis is None:
shape = (1,)
else:
shape = (input_shape[self.axis],)
if self.scale:
self.gamma = self.add_weight(shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.built = True
def call(self, inputs, training=None):
input_shape = K.int_shape(inputs)
reduction_axes = list(range(0, len(input_shape)))
if (self.axis is not None):
del reduction_axes[self.axis]
del reduction_axes[0]
mean = K.mean(inputs, reduction_axes, keepdims=True)
stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
normed = (inputs - mean) / stddev
broadcast_shape = [1] * len(input_shape)
if self.axis is not None:
broadcast_shape[self.axis] = input_shape[self.axis]
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
normed = normed * broadcast_gamma
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
normed = normed + broadcast_beta
return normed
def get_config(self):
config = {
'axis': self.axis,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(InstanceNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})

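Because the module registers the layer with get_custom_objects, models saved with InstanceNormalization can be reloaded without passing custom_objects by hand. A minimal usage sketch (the toy model and the import path are illustrative, not from this diff):

    from keras.layers import Conv2D
    from keras.models import Sequential, load_model
    # Adjust the import to wherever instance_normalization.py lives in the tree.
    from instance_normalization import InstanceNormalization

    model = Sequential([
        Conv2D(16, kernel_size=3, padding='same', input_shape=(64, 64, 3)),
        InstanceNormalization(axis=3),  # normalise each channel per sample
    ])
    model.save('tmp_model.h5')
    reloaded = load_model('tmp_model.h5')  # resolved via the registry above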

@ -6,13 +6,13 @@ from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from lib.ModelAE import ModelAE, TrainerAE
from .Model_Original import AutoEncoder, Trainer
from lib.PixelShuffler import PixelShuffler
IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 512
class Model(ModelAE):
class Model(AutoEncoder):
def initModel(self):
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
x = Input(shape=IMAGE_SHAPE)
@ -63,6 +63,3 @@ class Model(ModelAE):
x = self.upscale(64)(x)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return KerasModel(input_, x)
class Trainer(TrainerAE):
"""Empty inheritance"""


@ -0,0 +1,35 @@
# AutoEncoder base classes
encoderH5 = 'encoder.h5'
decoder_AH5 = 'decoder_A.h5'
decoder_BH5 = 'decoder_B.h5'
class AutoEncoder:
def __init__(self, model_dir):
self.model_dir = model_dir
self.encoder = self.Encoder()
self.decoder_A = self.Decoder()
self.decoder_B = self.Decoder()
self.initModel()
def load(self, swapped):
(face_A,face_B) = (decoder_AH5, decoder_BH5) if not swapped else (decoder_BH5, decoder_AH5)
try:
self.encoder.load_weights(str(self.model_dir / encoderH5))
self.decoder_A.load_weights(str(self.model_dir / face_A))
self.decoder_B.load_weights(str(self.model_dir / face_B))
print('loaded model weights')
return True
except Exception as e:
print('Failed loading existing training data.')
print(e)
return False
def save_weights(self):
self.encoder.save_weights(str(self.model_dir / encoderH5))
self.decoder_A.save_weights(str(self.model_dir / decoder_AH5))
self.decoder_B.save_weights(str(self.model_dir / decoder_BH5))
print('saved model weights')

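The base class leaves Encoder, Decoder and initModel to the plugin, and the Trainer plugin further down expects the result to expose autoencoder_A/autoencoder_B. A hypothetical minimal subclass to illustrate that contract (the dense architecture is a placeholder, not the plugin's):

    from pathlib import Path
    from keras.layers import Dense, Flatten, Input, Reshape
    from keras.models import Model as KerasModel

    class ToyModel(AutoEncoder):
        IMAGE_SHAPE = (64, 64, 3)

        def initModel(self):
            # Share one encoder between two per-identity decoders.
            x = Input(shape=self.IMAGE_SHAPE)
            self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x)))
            self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x)))
            self.autoencoder_A.compile(optimizer='adam', loss='mean_absolute_error')
            self.autoencoder_B.compile(optimizer='adam', loss='mean_absolute_error')

        def Encoder(self):
            inp = Input(shape=self.IMAGE_SHAPE)
            return KerasModel(inp, Dense(256)(Flatten()(inp)))

        def Decoder(self):
            inp = Input(shape=(256,))
            out = Reshape(self.IMAGE_SHAPE)(Dense(64 * 64 * 3, activation='sigmoid')(inp))
            return KerasModel(inp, out)

    model = ToyModel(Path('models'))  # load()/save_weights() then work as above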

@ -6,13 +6,13 @@ from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.optimizers import Adam
from lib.ModelAE import ModelAE, TrainerAE
from .AutoEncoder import AutoEncoder
from lib.PixelShuffler import PixelShuffler
IMAGE_SHAPE = (64, 64, 3)
ENCODER_DIM = 1024
class Model(ModelAE):
class Model(AutoEncoder):
def initModel(self):
optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
x = Input(shape=IMAGE_SHAPE)
@ -62,7 +62,4 @@ class Model(ModelAE):
x = self.upscale(128)(x)
x = self.upscale(64)(x)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return KerasModel(input_, x)
class Trainer(TrainerAE):
"""Empty inheritance"""
return KerasModel(input_, x)


@ -0,0 +1,50 @@
import time
import numpy
from lib.training_data import TrainingDataGenerator, stack_images
class Trainer():
random_transform_args = {
'rotation_range': 10,
'zoom_range': 0.05,
'shift_range': 0.05,
'random_flip': 0.4,
}
def __init__(self, model, fn_A, fn_B, batch_size, *args):
self.batch_size = batch_size
self.model = model
generator = TrainingDataGenerator(self.random_transform_args, 160)
self.images_A = generator.minibatchAB(fn_A, self.batch_size)
self.images_B = generator.minibatchAB(fn_B, self.batch_size)
def train_one_step(self, iter, viewer):
epoch, warped_A, target_A = next(self.images_A)
epoch, warped_B, target_B = next(self.images_B)
loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A)
loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B)
print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B),
end='\r')
if viewer is not None:
viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training")
def show_sample(self, test_A, test_B):
figure_A = numpy.stack([
test_A,
self.model.autoencoder_A.predict(test_A),
self.model.autoencoder_B.predict(test_A),
], axis=1)
figure_B = numpy.stack([
test_B,
self.model.autoencoder_B.predict(test_B),
self.model.autoencoder_A.predict(test_B),
], axis=1)
figure = numpy.concatenate([figure_A, figure_B], axis=0)
figure = figure.reshape((4, 7) + figure.shape[1:])
figure = stack_images(figure)
return numpy.clip(figure * 255, 0, 255).astype('uint8')

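show_sample arranges the preview as rows of [input, same-identity reconstruction, swap]: 14 samples of A stacked over 14 of B, reshaped into a 4x7 grid of triplets before stack_images tiles them. stack_images itself lives in lib/training_data.py and isn't part of this diff, so the tiling below is an assumed numpy equivalent, shown only to make the shapes concrete:

    import numpy as np

    figure = np.random.rand(28, 3, 64, 64, 3)           # 14 A-rows + 14 B-rows of triplets
    figure = figure.reshape((4, 7) + figure.shape[1:])  # 4x7 grid, as in show_sample

    def tile_images(grid):
        # grid: (rows, cols, 3, H, W, C) -> one image, triplets laid out side by side
        r, c, n, h, w, ch = grid.shape
        grid = grid.reshape(r, c * n, h, w, ch)
        return grid.transpose(0, 2, 1, 3, 4).reshape(r * h, c * n * w, ch)

    preview = np.clip(tile_images(figure) * 255, 0, 255).astype('uint8')
    print(preview.shape)  # (256, 1344, 3)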

@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
__author__ = """Based on https://reddit.com/u/deepfakes/"""
__version__ = '0.1.0'
from .Model import Model
from .Trainer import Trainer
from .AutoEncoder import AutoEncoder


@ -1,200 +1,235 @@
import cv2
import re
import os
from pathlib import Path
from tqdm import tqdm
from lib.cli import DirectoryProcessor, FullPaths
from lib.utils import BackgroundGenerator, get_folder
from plugins.PluginLoader import PluginLoader
class ConvertImage(DirectoryProcessor):
filename = ''
def create_parser(self, subparser, command, description):
self.parser = subparser.add_parser(
command,
help="Convert a source image to a new one with the face swapped.",
description=description,
epilog="Questions and feedback: \
https://github.com/deepfakes/faceswap-playground"
)
def add_optional_arguments(self, parser):
parser.add_argument('-m', '--model-dir',
action=FullPaths,
dest="model_dir",
default="models",
help="Model directory. A directory containing the trained model \
you wish to process. Defaults to 'models'")
parser.add_argument('-t', '--trainer',
type=str,
choices=("Original", "LowMem", "GAN"), # case sensitive because this is used to load a plug-in.
default="Original",
help="Select the trainer that was used to create the model.")
parser.add_argument('-s', '--swap-model',
action="store_true",
dest="swap_model",
default=False,
help="Swap the model. Instead of A -> B, swap B -> A.")
parser.add_argument('-c', '--converter',
type=str,
choices=("Masked", "Adjust", "GAN"), # case sensitive because this is used to load a plugin.
default="Masked",
help="Converter to use.")
parser.add_argument('-D', '--detector',
type=str,
choices=("hog", "cnn"), # case sensitive because this is used to load a plugin.
default="hog",
help="Detector to use. 'cnn' detects much more angles but will be much more resource intensive and may fail on large files.")
parser.add_argument('-fr', '--frame-ranges',
nargs="+",
type=str,
help="frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use --frame-ranges 10-50 90-100. \
Files must have the frame-number as the last number in the name!"
)
parser.add_argument('-d', '--discard-frames',
action="store_true",
dest="discard_frames",
default=False,
help="When used with --frame-ranges discards frames that are not processed instead of writing them out unchanged."
)
parser.add_argument('-f', '--filter',
type=str,
dest="filter",
default="filter.jpg",
help="Reference image for the person you want to process. Should be a front portrait"
)
parser.add_argument('-b', '--blur-size',
type=int,
default=2,
help="Blur size. (Masked converter only)")
parser.add_argument('-S', '--seamless',
action="store_true",
dest="seamless_clone",
default=False,
help="Seamless mode. (Masked converter only)")
parser.add_argument('-M', '--mask-type',
type=str.lower, # lowercase this, because it's just a string later on.
dest="mask_type",
choices=["rect", "facehull", "facehullandrect"],
default="facehullandrect",
help="Mask to use to replace faces. (Masked converter only)")
parser.add_argument('-e', '--erosion-kernel-size',
dest="erosion_kernel_size",
type=int,
default=None,
help="Erosion kernel size. (Masked converter only). Positive values apply erosion which reduces the edge \
of the swapped face. Negative values apply dilation which allows the swapped face to cover more space.")
parser.add_argument('-sm', '--smooth-mask',
action="store_true",
dest="smooth_mask",
default=True,
help="Smooth mask (Adjust converter only)")
parser.add_argument('-aca', '--avg-color-adjust',
action="store_true",
dest="avg_color_adjust",
default=True,
help="Average color adjust. (Adjust converter only)")
return parser
def process(self):
# Original & LowMem models go with Adjust or Masked converter
# GAN converter & model must go together
# Note: GAN prediction outputs a mask + an image, while the others predict only an image
model_name = self.arguments.trainer
conv_name = self.arguments.converter
if conv_name.startswith("GAN"):
assert model_name.startswith("GAN") is True, "GAN converter can only be used with GAN model!"
else:
assert model_name.startswith("GAN") is False, "GAN model can only be used with GAN converter!"
model = PluginLoader.get_model(model_name)(get_folder(self.arguments.model_dir))
if not model.load(self.arguments.swap_model):
print('Model Not Found! A valid model must be provided to continue!')
exit(1)
converter = PluginLoader.get_converter(conv_name)(model.converter(False),
blur_size=self.arguments.blur_size,
seamless_clone=self.arguments.seamless_clone,
mask_type=self.arguments.mask_type,
erosion_kernel_size=self.arguments.erosion_kernel_size,
smooth_mask=self.arguments.smooth_mask,
avg_color_adjust=self.arguments.avg_color_adjust
)
batch = BackgroundGenerator(self.prepare_images(), 1)
# frame ranges stuff...
self.frame_ranges = None
# split out the frame ranges and parse out "min" and "max" values
minmax = {
"min": 0, # never any frames less than 0
"max": float("inf")
}
if self.arguments.frame_ranges:
self.frame_ranges = [tuple(map(lambda q: minmax[q] if q in minmax.keys() else int(q), v.split("-"))) for v in self.arguments.frame_ranges]
# last number regex. I know regex is hacky, but it's reliablyhacky(tm).
self.imageidxre = re.compile(r'(\d+)(?!.*\d)')
for item in batch.iterator():
self.convert(converter, item)
def check_skipframe(self, filename):
try:
idx = int(self.imageidxre.findall(filename)[0])
return not any(map(lambda b: b[0]<=idx<=b[1], self.frame_ranges))
except Exception: # no frame ranges set, or no frame number in the filename
return False
def convert(self, converter, item):
try:
(filename, image, faces) = item
skip = self.check_skipframe(filename)
if self.arguments.discard_frames and skip:
return
if not skip: # process as normal
for idx, face in faces:
image = converter.patch_image(image, face)
output_file = get_folder(self.output_dir) / Path(filename).name
cv2.imwrite(str(output_file), image)
except Exception as e:
print('Failed to convert image: {}. Reason: {}'.format(filename, e))
def prepare_images(self):
self.read_alignments()
is_have_alignments = self.have_alignments()
for filename in tqdm(self.read_directory()):
image = cv2.imread(filename)
if is_have_alignments:
if self.have_face(filename):
faces = self.get_faces_alignments(filename, image)
else:
print ('no alignment found for {}, skipping'.format(os.path.basename(filename)))
continue
else:
faces = self.get_faces(image)
yield filename, image, faces
import cv2
import re
import os
from pathlib import Path
from tqdm import tqdm
from lib.cli import DirectoryProcessor, FullPaths
from lib.utils import BackgroundGenerator, get_folder, get_image_paths
from plugins.PluginLoader import PluginLoader
class ConvertImage(DirectoryProcessor):
filename = ''
def create_parser(self, subparser, command, description):
self.parser = subparser.add_parser(
command,
help="Convert a source image to a new one with the face swapped.",
description=description,
epilog="Questions and feedback: \
https://github.com/deepfakes/faceswap-playground"
)
def add_optional_arguments(self, parser):
parser.add_argument('-m', '--model-dir',
action=FullPaths,
dest="model_dir",
default="models",
help="Model directory. A directory containing the trained model \
you wish to process. Defaults to 'models'")
parser.add_argument('-a', '--input-aligned-dir',
action=FullPaths,
dest="input_aligned_dir",
default=None,
help="Input \"aligned directory\". A directory that should contain the \
aligned faces extracted from the input files. If you delete faces from \
this folder, they'll be skipped during conversion. If no aligned dir is \
specified, all faces will be converted.")
parser.add_argument('-t', '--trainer',
type=str,
choices=("Original", "LowMem", "GAN", "GAN128"), # case sensitive because this is used to load a plug-in.
default="Original",
help="Select the trainer that was used to create the model.")
parser.add_argument('-s', '--swap-model',
action="store_true",
dest="swap_model",
default=False,
help="Swap the model. Instead of A -> B, swap B -> A.")
parser.add_argument('-c', '--converter',
type=str,
choices=("Masked", "Adjust"), # case sensitive because this is used to load a plugin.
default="Masked",
help="Converter to use.")
parser.add_argument('-D', '--detector',
type=str,
choices=("hog", "cnn"), # case sensitive because this is used to load a plugin.
default="hog",
help="Detector to use. 'cnn' detects much more angles but will be much more resource intensive and may fail on large files.")
parser.add_argument('-fr', '--frame-ranges',
nargs="+",
type=str,
help="frame ranges to apply transfer to e.g. For frames 10 to 50 and 90 to 100 use --frame-ranges 10-50 90-100. \
Files must have the frame-number as the last number in the name!"
)
parser.add_argument('-d', '--discard-frames',
action="store_true",
dest="discard_frames",
default=False,
help="When used with --frame-ranges discards frames that are not processed instead of writing them out unchanged."
)
parser.add_argument('-f', '--filter',
type=str,
dest="filter",
default="filter.jpg",
help="Reference image for the person you want to process. Should be a front portrait"
)
parser.add_argument('-b', '--blur-size',
type=int,
default=2,
help="Blur size. (Masked converter only)")
parser.add_argument('-S', '--seamless',
action="store_true",
dest="seamless_clone",
default=False,
help="Use cv2's seamless clone. (Masked converter only)")
parser.add_argument('-M', '--mask-type',
type=str.lower, # lowercase this, because it's just a string later on.
dest="mask_type",
choices=["rect", "facehull", "facehullandrect"],
default="facehullandrect",
help="Mask to use to replace faces. (Masked converter only)")
parser.add_argument('-e', '--erosion-kernel-size',
dest="erosion_kernel_size",
type=int,
default=None,
help="Erosion kernel size. (Masked converter only). Positive values apply erosion which reduces the edge of the swapped face. Negative values apply dilation which allows the swapped face to cover more space.")
parser.add_argument('-mh', '--match-histogram',
action="store_true",
dest="match_histogram",
default=False,
help="Use histogram matching. (Masked converter only)")
parser.add_argument('-sm', '--smooth-mask',
action="store_true",
dest="smooth_mask",
default=True,
help="Smooth mask (Adjust converter only)")
parser.add_argument('-aca', '--avg-color-adjust',
action="store_true",
dest="avg_color_adjust",
default=True,
help="Average color adjust. (Adjust converter only)")
return parser
def process(self):
# Original & LowMem models go with Adjust or Masked converter
# Note: GAN prediction outputs a mask + an image, while the others predict only an image
model_name = self.arguments.trainer
conv_name = self.arguments.converter
self.input_aligned_dir = None
model = PluginLoader.get_model(model_name)(get_folder(self.arguments.model_dir))
if not model.load(self.arguments.swap_model):
print('Model Not Found! A valid model must be provided to continue!')
exit(1)
input_aligned_dir = Path(self.arguments.input_dir)/Path('aligned')
if self.arguments.input_aligned_dir is not None:
input_aligned_dir = self.arguments.input_aligned_dir
try:
self.input_aligned_dir = [Path(path) for path in get_image_paths(input_aligned_dir)]
if len(self.input_aligned_dir) == 0:
print('Aligned directory is empty, no faces will be converted!')
elif len(self.input_aligned_dir) <= len(get_image_paths(self.arguments.input_dir)) / 3: # compare face count to frame count, not to the path string's length
print('Aligned directory contains far fewer images than the input directory, are you sure this is the right directory?')
except Exception:
print('Aligned directory not found. All faces listed in the alignments file will be converted.')
converter = PluginLoader.get_converter(conv_name)(model.converter(False),
trainer=self.arguments.trainer,
blur_size=self.arguments.blur_size,
seamless_clone=self.arguments.seamless_clone,
mask_type=self.arguments.mask_type,
erosion_kernel_size=self.arguments.erosion_kernel_size,
match_histogram=self.arguments.match_histogram,
smooth_mask=self.arguments.smooth_mask,
avg_color_adjust=self.arguments.avg_color_adjust
)
batch = BackgroundGenerator(self.prepare_images(), 1)
# frame ranges stuff...
self.frame_ranges = None
# split out the frame ranges and parse out "min" and "max" values
minmax = {
"min": 0, # never any frames less than 0
"max": float("inf")
}
if self.arguments.frame_ranges:
self.frame_ranges = [tuple(map(lambda q: minmax[q] if q in minmax.keys() else int(q), v.split("-"))) for v in self.arguments.frame_ranges]
# last number regex. I know regex is hacky, but it's reliablyhacky(tm).
self.imageidxre = re.compile(r'(\d+)(?!.*\d)')
for item in batch.iterator():
self.convert(converter, item)
def check_skipframe(self, filename):
try:
idx = int(self.imageidxre.findall(filename)[0])
return not any(map(lambda b: b[0]<=idx<=b[1], self.frame_ranges))
except Exception: # no frame ranges set, or no frame number in the filename
return False
def check_skipface(self, filename, face_idx):
aligned_face_name = '{}_{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
aligned_face_file = Path(self.arguments.input_aligned_dir) / Path(aligned_face_name)
# TODO: Remove this temporary fix for backwards compatibility of filenames
bk_compat_aligned_face_name = '{}{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
bk_compat_aligned_face_file = Path(self.arguments.input_aligned_dir) / Path(bk_compat_aligned_face_name)
return aligned_face_file not in self.input_aligned_dir and bk_compat_aligned_face_file not in self.input_aligned_dir
def convert(self, converter, item):
try:
(filename, image, faces) = item
skip = self.check_skipframe(filename)
if self.arguments.discard_frames and skip:
return
if not skip: # process frame as normal
for idx, face in faces:
if self.input_aligned_dir is not None and self.check_skipface(filename, idx):
print ('face {} for frame {} was deleted, skipping'.format(idx, os.path.basename(filename)))
continue
image = converter.patch_image(image, face, 64 if "128" not in self.arguments.trainer else 128)
# TODO: This switch between 64 and 128 is a hack for now. We should have a separate cli option for size
output_file = get_folder(self.output_dir) / Path(filename).name
cv2.imwrite(str(output_file), image)
except Exception as e:
print('Failed to convert image: {}. Reason: {}'.format(filename, e))
def prepare_images(self):
self.read_alignments()
is_have_alignments = self.have_alignments()
for filename in tqdm(self.read_directory()):
image = cv2.imread(filename)
if is_have_alignments:
if self.have_face(filename):
faces = self.get_faces_alignments(filename, image)
else:
print ('no alignment found for {}, skipping'.format(os.path.basename(filename)))
continue
else:
faces = self.get_faces(image)
yield filename, image, faces

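For clarity, the frame-range parsing above turns each 'a-b' token into a numeric tuple, with the literal words min and max standing in for 0 and infinity; a frame is then skipped when its index falls in none of the ranges. A small sketch of the same logic:

    minmax = {"min": 0, "max": float("inf")}

    def parse_ranges(tokens):
        return [tuple(minmax[q] if q in minmax else int(q) for q in v.split("-"))
                for v in tokens]

    frame_ranges = parse_ranges(["10-50", "90-100"])
    print(frame_ranges)                                       # [(10, 50), (90, 100)]
    print(not any(lo <= 42 <= hi for lo, hi in frame_ranges))  # False: frame 42 is converted
    print(not any(lo <= 75 <= hi for lo, hi in frame_ranges))  # True: frame 75 is skipped
    print(parse_ranges(["min-max"]))                           # [(0, inf)]: everything matches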

@ -94,7 +94,7 @@ class ExtractTrainingData(DirectoryProcessor):
resized_image = self.extractor.extract(image, face, 256)
output_file = get_folder(self.output_dir) / Path(filename).stem
cv2.imwrite(str(output_file) + str(idx) + Path(filename).suffix, resized_image)
cv2.imwrite('{}_{}{}'.format(str(output_file), str(idx), Path(filename).suffix), resized_image)
f = {
"x": face.x,
"w": face.w,

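The single changed line above is the naming fix that check_skipface in convert.py compensates for: extracted faces now get an underscore between the frame stem and the face index. A quick sketch with a hypothetical frame name:

    from pathlib import Path

    filename, face_idx = 'scene_0042.png', 0
    new_name = '{}_{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
    old_name = '{}{}{}'.format(Path(filename).stem, face_idx, Path(filename).suffix)
    print(new_name)  # scene_0042_0.png -- what extract now writes
    print(old_name)  # scene_00420.png  -- the old, ambiguous naming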

@ -1,194 +1,199 @@
import cv2
import numpy
import time
from threading import Lock
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader
class TrainingProcessor(object):
arguments = None
def __init__(self, subparser, command, description='default'):
self.parse_arguments(description, subparser, command)
self.lock = Lock()
def process_arguments(self, arguments):
self.arguments = arguments
print("Model A Directory: {}".format(self.arguments.input_A))
print("Model B Directory: {}".format(self.arguments.input_B))
print("Training data directory: {}".format(self.arguments.model_dir))
self.process()
def parse_arguments(self, description, subparser, command):
parser = subparser.add_parser(
command,
help="This command trains the model for the two faces A and B.",
description=description,
epilog="Questions and feedback: \
https://github.com/deepfakes/faceswap-playground"
)
parser.add_argument('-A', '--input-A',
action=FullPaths,
dest="input_A",
default="input_A",
help="Input directory. A directory containing training images for face A.\
Defaults to 'input'")
parser.add_argument('-B', '--input-B',
action=FullPaths,
dest="input_B",
default="input_B",
help="Input directory. A directory containing training images for face B.\
Defaults to 'input'")
parser.add_argument('-m', '--model-dir',
action=FullPaths,
dest="model_dir",
default="models",
help="Model directory. This is where the training data will \
be stored. Defaults to 'model'")
parser.add_argument('-p', '--preview',
action="store_true",
dest="preview",
default=False,
help="Show preview output. If not specified, write progress \
to file.")
parser.add_argument('-v', '--verbose',
action="store_true",
dest="verbose",
default=False,
help="Show verbose output")
parser.add_argument('-s', '--save-interval',
type=int,
dest="save_interval",
default=100,
help="Sets the number of iterations before saving the model.")
parser.add_argument('-w', '--write-image',
action="store_true",
dest="write_image",
default=False,
help="Writes the training result to a file even on preview mode.")
parser.add_argument('-t', '--trainer',
type=str,
choices=("Original", "LowMem", "GAN"),
default="Original",
help="Select which trainer to use, LowMem for cards < 2gb.")
parser.add_argument('-bs', '--batch-size',
type=int,
default=64,
help="Batch size, as a power of 2 (64, 128, 256, etc)")
parser.add_argument('-ag', '--allow-growth',
action="store_true",
dest="allow_growth",
default=False,
help="Sets allow_growth option of Tensorflow to spare memory on some configs")
parser.add_argument('-ep', '--epochs',
type=int,
default=1000000,
help="Length of training in epochs.")
parser = self.add_optional_arguments(parser)
parser.set_defaults(func=self.process_arguments)
def add_optional_arguments(self, parser):
# Override this for custom arguments
return parser
def process(self):
import threading
self.stop = False
self.save_now = False
thr = threading.Thread(target=self.processThread, args=(), kwargs={})
thr.start()
if self.arguments.preview:
print('Using live preview')
while True:
try:
with self.lock:
for name, image in self.preview_buffer.items():
cv2.imshow(name, image)
key = cv2.waitKey(1000)
if key == ord('\n') or key == ord('\r'):
break
if key == ord('s'):
self.save_now = True
except KeyboardInterrupt:
break
else:
input() # TODO how to catch a specific key instead of Enter?
# there isn't a good multi-platform solution: https://stackoverflow.com/questions/3523174/raw-input-in-python-without-pressing-enter
print("Exit requested! The trainer will complete its current cycle, save the models and quit (it can take up to a couple of seconds depending on your training speed). If you want to kill it now, press Ctrl+C")
self.stop = True
thr.join() # waits until thread finishes
def processThread(self):
if self.arguments.allow_growth:
self.set_tf_allow_growth()
print('Loading data, this may take a while...')
# this is so that you can enter case insensitive values for trainer
trainer = self.arguments.trainer
trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir))
model.load(swapped=False)
images_A = get_image_paths(self.arguments.input_A)
images_B = get_image_paths(self.arguments.input_B)
trainer = PluginLoader.get_trainer(trainer)
trainer = trainer(model, images_A, images_B, batch_size=self.arguments.batch_size)
try:
print('Starting. Press "Enter" to stop training and save model')
for epoch in range(0, self.arguments.epochs):
save_iteration = epoch % self.arguments.save_interval == 0
trainer.train_one_step(epoch, self.show if (save_iteration or self.save_now) else None)
if save_iteration:
model.save_weights()
if self.stop:
model.save_weights()
exit()
if self.save_now:
model.save_weights()
self.save_now = False
except KeyboardInterrupt:
try:
model.save_weights()
except KeyboardInterrupt:
print('Saving model weights has been cancelled!')
exit(0)
except Exception as e:
print(e)
exit(1)
def set_tf_allow_growth(self):
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list="0"
set_session(tf.Session(config=config))
preview_buffer = {}
def show(self, image, name=''):
try:
if self.arguments.preview:
with self.lock:
self.preview_buffer[name] = image
elif self.arguments.write_image:
cv2.imwrite('_sample_{}.jpg'.format(name), image)
except Exception as e:
print("could not preview sample")
print(e)
import cv2
import numpy
import time
from threading import Lock
from lib.utils import get_image_paths, get_folder
from lib.cli import FullPaths
from plugins.PluginLoader import PluginLoader
class TrainingProcessor(object):
arguments = None
def __init__(self, subparser, command, description='default'):
self.parse_arguments(description, subparser, command)
self.lock = Lock()
def process_arguments(self, arguments):
self.arguments = arguments
print("Model A Directory: {}".format(self.arguments.input_A))
print("Model B Directory: {}".format(self.arguments.input_B))
print("Training data directory: {}".format(self.arguments.model_dir))
self.process()
def parse_arguments(self, description, subparser, command):
parser = subparser.add_parser(
command,
help="This command trains the model for the two faces A and B.",
description=description,
epilog="Questions and feedback: \
https://github.com/deepfakes/faceswap-playground"
)
parser.add_argument('-A', '--input-A',
action=FullPaths,
dest="input_A",
default="input_A",
help="Input directory. A directory containing training images for face A.\
Defaults to 'input'")
parser.add_argument('-B', '--input-B',
action=FullPaths,
dest="input_B",
default="input_B",
help="Input directory. A directory containing training images for face B.\
Defaults to 'input'")
parser.add_argument('-m', '--model-dir',
action=FullPaths,
dest="model_dir",
default="models",
help="Model directory. This is where the training data will \
be stored. Defaults to 'model'")
parser.add_argument('-p', '--preview',
action="store_true",
dest="preview",
default=False,
help="Show preview output. If not specified, write progress \
to file.")
parser.add_argument('-v', '--verbose',
action="store_true",
dest="verbose",
default=False,
help="Show verbose output")
parser.add_argument('-s', '--save-interval',
type=int,
dest="save_interval",
default=100,
help="Sets the number of iterations before saving the model.")
parser.add_argument('-w', '--write-image',
action="store_true",
dest="write_image",
default=False,
help="Writes the training result to a file even on preview mode.")
parser.add_argument('-t', '--trainer',
type=str,
choices=("Original", "LowMem", "GAN", "GAN128"),
default="Original",
help="Select which trainer to use, LowMem for cards < 2gb.")
parser.add_argument('-pl', '--use-perceptual-loss',
action="store_true",
dest="perceptual_loss",
default=False,
help="Use perceptual loss while training")
parser.add_argument('-bs', '--batch-size',
type=int,
default=64,
help="Batch size, as a power of 2 (64, 128, 256, etc)")
parser.add_argument('-ag', '--allow-growth',
action="store_true",
dest="allow_growth",
default=False,
help="Sets allow_growth option of Tensorflow to spare memory on some configs")
parser.add_argument('-ep', '--epochs',
type=int,
default=1000000,
help="Length of training in epochs.")
parser = self.add_optional_arguments(parser)
parser.set_defaults(func=self.process_arguments)
def add_optional_arguments(self, parser):
# Override this for custom arguments
return parser
def process(self):
import threading
self.stop = False
self.save_now = False
thr = threading.Thread(target=self.processThread, args=(), kwargs={})
thr.start()
if self.arguments.preview:
print('Using live preview')
while True:
try:
with self.lock:
for name, image in self.preview_buffer.items():
cv2.imshow(name, image)
key = cv2.waitKey(1000)
if key == ord('\n') or key == ord('\r'):
break
if key == ord('s'):
self.save_now = True
except KeyboardInterrupt:
break
else:
input() # TODO how to catch a specific key instead of Enter?
# there isn't a good multi-platform solution: https://stackoverflow.com/questions/3523174/raw-input-in-python-without-pressing-enter
print("Exit requested! The trainer will complete its current cycle, save the models and quit (it can take up to a couple of seconds depending on your training speed). If you want to kill it now, press Ctrl+C")
self.stop = True
thr.join() # waits until thread finishes
def processThread(self):
if self.arguments.allow_growth:
self.set_tf_allow_growth()
print('Loading data, this may take a while...')
# this is so that you can enter case insensitive values for trainer
trainer = self.arguments.trainer
trainer = "LowMem" if trainer.lower() == "lowmem" else trainer
model = PluginLoader.get_model(trainer)(get_folder(self.arguments.model_dir))
model.load(swapped=False)
images_A = get_image_paths(self.arguments.input_A)
images_B = get_image_paths(self.arguments.input_B)
trainer = PluginLoader.get_trainer(trainer)
trainer = trainer(model, images_A, images_B, self.arguments.batch_size, self.arguments.perceptual_loss)
try:
print('Starting. Press "Enter" to stop training and save model')
for epoch in range(0, self.arguments.epochs):
save_iteration = epoch % self.arguments.save_interval == 0
trainer.train_one_step(epoch, self.show if (save_iteration or self.save_now) else None)
if save_iteration:
model.save_weights()
if self.stop:
model.save_weights()
exit()
if self.save_now:
model.save_weights()
self.save_now = False
except KeyboardInterrupt:
try:
model.save_weights()
except KeyboardInterrupt:
print('Saving model weights has been cancelled!')
exit(0)
except Exception as e:
print(e)
exit(1)
def set_tf_allow_growth(self):
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list="0"
set_session(tf.Session(config=config))
preview_buffer = {}
def show(self, image, name=''):
try:
if self.arguments.preview:
with self.lock:
self.preview_buffer[name] = image
elif self.arguments.write_image:
cv2.imwrite('_sample_{}.jpg'.format(name), image)
except Exception as e:
print("could not preview sample")
print(e)
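One detail worth noting in the new trainer hand-off: --use-perceptual-loss is passed positionally after batch_size, and the Original/LowMem Trainer.__init__ accepts *args, so non-GAN trainers silently ignore the flag while the GAN trainers can consume it. A schematic of the two signatures, with the GAN side assumed since its __init__ is not shown in this diff:

    # Original/LowMem trainer (as in plugins/Model_Original above):
    # *args quietly swallows the perceptual_loss argument.
    class OriginalTrainer:
        def __init__(self, model, fn_A, fn_B, batch_size, *args):
            self.batch_size = batch_size

    # GAN trainer (signature assumed, not shown in this diff):
    class GANTrainer:
        def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss):
            self.use_perceptual_loss = perceptual_loss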