65.9K
CodeProject 正在变化。 阅读更多。
Home

从头开始构建风格迁移 CycleGAN

starIconstarIconstarIconstarIcon
emptyStarIcon
starIcon

4.80/5 (2投票s)

2021年6月18日

CPOL

4分钟阅读

viewsIcon

8654

downloadIcon

96

在本文中,我们实现了一个基于残差的生成器的 CycleGAN。

引言

在本系列文章中,我们将介绍一个基于循环一致对抗网络(CycleGAN)的移动图像到图像翻译系统。我们将构建一个可以执行非成对图像到图像翻译的 CycleGAN,并向您展示一些有趣但学术上深刻的例子。我们还将讨论如何将使用 TensorFlow 和 Keras 构建的这样训练好的网络转换为 TensorFlow Lite,并在移动设备上作为应用程序使用。

我们假设您熟悉深度学习的概念,以及Jupyter Notebook和 TensorFlow。欢迎您下载项目代码。

本系列的上一篇文章中,我们训练和评估了一个使用 U-Net 作为生成器的 CycleGAN。在本文中,我们将实现一个使用残差作为生成器的 CycleGAN。

从头开始构建 CycleGAN

最初的 CycleGAN 是使用基于残差的生成器构建的。让我们从头开始实现这种类型的 CycleGAN。我们将构建网络并对其进行训练,以使用包含和不包含伪影的眼底图像数据集来减少眼底图像中的伪影。

网络将如上所示,将带有伪影的眼底图像转换为不带伪影的图像,反之亦然。

CycleGAN 的设计将包括以下步骤:

  • 构建判别器
  • 构建残差块
  • 构建生成器
  • 构建完整模型

在开始加载数据之前,让我们导入一些必要的库和包。

#the necessary imports

from random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from matplotlib import pyplot

加载数据集

与上一篇文章中的做法不同,这次我们将使用本地机器(而不是 Google Colab)来训练 CycleGAN。因此,眼底数据集应首先下载和处理。我们将使用 Jupyter Notebook 和 TensorFlow 来构建和训练这个网络。

from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed

# load all images in a directory into memory

def load_images(path, size=(256,256)):
	data_list = list()
	# enumerate filenames in directory, assume all are images
	for filename in listdir(path):
		# load and resize the image
		pixels = load_img(path + filename, target_size=size)
		# convert to numpy array
		pixels = img_to_array(pixels)
		# store
		data_list.append(pixels)
	return asarray(data_list)

# dataset path
path = r'C:/Users/abdul/Desktop/ContentLab/P3/Fundus/'
# load dataset A
dataA1 = load_images(path + 'trainA/')
dataAB = load_images(path + 'testA/')
dataA = vstack((dataA1, dataAB))
print('Loaded dataA: ', dataA.shape)
# load dataset B
dataB1 = load_images(path + 'trainB/')
dataB2 = load_images(path + 'testB/')
dataB = vstack((dataB1, dataB2))
print('Loaded dataB: ', dataB.shape)
# save as compressed numpy array
filename = 'Artifcats.npz'
savez_compressed(filename, dataA, dataB)
print('Saved dataset: ', filename)

数据加载完成后,就可以创建一个显示一些训练图像的函数了。

# load and plot the prepared dataset

from numpy import load
from matplotlib import pyplot

# load the dataset

data = load('Artifacts.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)

# plot source images

n_samples = 3
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + i)
	pyplot.axis('off')
	pyplot.imshow(dataA[i].astype('uint8'))

# plot target image

for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + n_samples + i)
	pyplot.axis('off')
	pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()

构建判别器

正如我们之前讨论过的,判别器是一个卷积神经网络(CNN),它包含许多卷积层,以及LeakyReLU实例归一化层。

def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_image = Input(shape=image_shape)
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	# define model
	model = Model(in_image, patch_out)
	# compile model
	model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
	return model

构建好判别器后,我们可以创建它的副本,这样我们就有两个相同的判别器:`DiscA` 和 `DiscB`。

image_shape=(256,256,3)
DiscA=define_discriminator(image_shape)
DiscB=define_discriminator(image_shape)
DiscA.summary()

构建残差块

下一步是为我们的生成器创建残差块。该块是一组二维卷积层,其中每两层后面都跟着一个实例归一化层。

# generator a resnet block

def resnet_block(n_filters, input_layer):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# first layer convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# second convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	# concatenate merge channel-wise with input layer
	g = Concatenate()([g, input_layer])
	return g

构建生成器

残差块的输出将通过生成器的最后一部分(解码器),在这里图像将被上采样并调整到其原始大小。由于编码器尚未定义,我们将构建一个定义解码器和编码器部分并将它们连接到残差块的函数。

# define the generator model

def define_generator(image_shape, n_resnet=9):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)

	g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d128
	g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d256
	g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# R256
	for _ in range(n_resnet):
		g = resnet_block(256, g)
	# u128
	g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# u64
	g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)

	g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
	g = InstanceNormalization(axis=-1)(g)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

现在,我们定义生成器 `genA` 和 `genB`。

genA=define_generator(image_shape, 9)
genB=define_generator(image_shape, 9)

构建 CycleGAN

定义了生成器和判别器之后,我们现在可以构建整个 CycleGAN 模型并设置其优化器和其他学习参数。

#define a composite model

def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
	# ensure the model we're updating is trainable
	g_model_1.trainable = True
	# mark discriminator as not trainable
	d_model.trainable = False
	# mark other generator model as not trainable
	g_model_2.trainable = False
	# discriminator element
	input_gen = Input(shape=image_shape)
	gen1_out = g_model_1(input_gen)
	output_d = d_model(gen1_out)
	# identity element
	input_id = Input(shape=image_shape)
	output_id = g_model_1(input_id)
	# forward cycle
	output_f = g_model_2(gen1_out)
	# backward cycle
	gen2_out = g_model_2(input_id)
	output_b = g_model_1(gen2_out)
	# define model graph
	model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
	# define optimization algorithm configuration
	opt = Adam(lr=0.0002, beta_1=0.5)
	# compile model with weighting of least squares loss and L1 loss
	model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
	return model

现在让我们定义两个模型(A 和 B),其中一个将眼底图像伪影转换为无伪影的眼底图像(`AtoB`),另一个将无伪影转换为带伪影的眼底图像(`BtoA`)。

comb_modelA=define_composite_model(genA,DiscA,genB,image_shape)
comb_modelB=define_composite_model(genB,DiscB,genA,image_shape)

训练 CycleGAN

现在我们的模型已经完成,我们将创建一个训练函数,该函数定义训练参数,并计算生成器和判别器的损失,以及在训练期间更新权重。该函数将按如下方式运行:

  1. 将图像输入生成器。
  2. 通过生成器获得生成的图像。
  3. 将生成的图像传回生成器,以验证我们是否可以从生成的图像中预测出原始图像。
  4. 使用生成器对真实图像执行身份映射。
  5. 将步骤 1 中生成的图像传递给相应的判别器。
  6. 找到生成器的总损失(对抗损失 + 循环损失 + 身份损失)。
  7. 找到判别器的损失。
  8. 更新生成器权重。
  9. 更新判别器权重。
  10. 将损失以字典的形式返回。
# train the cycleGAN model

def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):

	# define properties of the training run
	n_epochs, n_batch, = 30, 1
	# determine the output square shape of the discriminator
	n_patch = d_model_A.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# prepare image pool for fakes
	poolA, poolB = list(), list()
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
		X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
		X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)
		# update fakes from pool
		X_fakeA = update_image_pool(poolA, X_fakeA)
		X_fakeB = update_image_pool(poolB, X_fakeB)
		# update generator B->A via adversarial and cycle loss
		g_loss2, _, _, _, _  = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
		# update discriminator for A -> [real/fake]
		dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
		dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
		# update generator A->B via adversarial and cycle loss
		g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
		# update discriminator for B -> [real/fake]
		dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
		dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
		# summarize performance
		print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))
		# evaluate the model performance every so often
		if (i+1) % (bat_per_epo * 1) == 0:
			# plot A->B translation
			summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
			# plot B->A translation
			summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
		if (i+1) % (bat_per_epo * 5) == 0:
			# save the models
			save_models(i, g_model_AtoB, g_model_BtoA)

以下是在训练过程中使用的一些函数。

#load and prepare training images

def load_real_samples(filename):
	# load the dataset
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# The generate_real_samples() function below implements this

# select a batch of random samples, returns images and target

def generate_real_samples(dataset, n_samples, patch_shape):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# retrieve selected images
	X = dataset[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return X, y

# generate a batch of images, returns images and targets

def generate_fake_samples(g_model, dataset, patch_shape):
	# generate fake instance
	X = g_model.predict(dataset)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y
# update image pool for fake images

def update_image_pool(pool, images, max_size=50):
	selected = list()
	for image in images:
		if len(pool) < max_size:
			# stock the pool
			pool.append(image)
			selected.append(image)
		elif random() < 0.5:
			# use image, but don't add it to the pool
			selected.append(image)
		else:
			# replace an existing image and use replaced image
			ix = randint(0, len(pool))
			selected.append(pool[ix])
			pool[ix] = image
	return asarray(selected)

我们添加了一些函数来保存最佳模型并可视化眼底图像伪影减少的效果。

def save_models(step, g_model_AtoB, g_model_BtoA):
	# save the first generator model
	filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
	g_model_AtoB.save(filename1)
	# save the second generator model
	filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
	g_model_BtoA.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# generate samples and save as a plot and save the model

def summarize_performance(step, g_model, trainX, name, n_samples=5):
	# select a sample of input images
	X_in, _ = generate_real_samples(trainX, n_samples, 0)
	# generate translated images
	X_out, _ = generate_fake_samples(g_model, X_in, 0)
	# scale all pixels from [-1,1] to [0,1]
	X_in = (X_in + 1) / 2.0
	X_out = (X_out + 1) / 2.0
	# plot real images
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_in[i])
	# plot translated image
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_out[i])
	# save plot to file
	filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
	pyplot.savefig(filename1)
	pyplot.close()

train(DiscA, DiscB, genA, genB, comb_modelA, comb_modelB, dataset)

评估性能

使用上述函数,我们将网络训练了 30 个 epoch。结果表明,我们的网络能够减少眼底图像中的伪影。

伪影到无伪影转换(**AtoB**)的结果如下所示:

还计算了无伪影到伪影(`BtoA`)的眼底图像转换;这里有一些例子。

结论

正如人工智能先驱Yann LeCun所说的关于 GAN 的话,“(它)是过去十年中最有趣的深度学习思想”。我们希望,通过本系列,我们已经帮助您理解了为什么 GAN 是非常有趣的想法。我们知道您可能会觉得本系列中提出的概念有点重和模糊,但没关系。CycleGAN 非常难以一次读懂,在您理解之前,可以多看几遍本系列。

最后,如果您喜欢本系列中的内容,请记住,您总可以改进它!为什么不利用您的新技能,创造一些伟大的东西,然后写下来并在 CodeProject 上分享呢?

© . All rights reserved.