使用 Keras 构建移动风格迁移 CycleGAN





5.00/5 (1投票)
在本文中,我们将从头开始实现一个 CycleGAN。
引言
在本系列文章中,我们将介绍一个基于 循环一致对抗网络 (CycleGAN) 的移动端图像到图像转换系统。我们将构建一个可以执行非配对图像到图像转换的 CycleGAN,并向您展示一些有趣但学术上很深入的例子。我们还将讨论如何将这样一个用 TensorFlow 和 Keras 构建的训练好的网络转换为 TensorFlow Lite,并在移动设备上用作应用程序。
我们假设您熟悉深度学习的概念,以及 Jupyter Notebooks 和 TensorFlow。欢迎您下载项目代码。
在上一篇文章中,我们讨论了 CycleGAN 架构。现在我们已经完成了理论。在本文中,我们将从头开始实现 CycleGAN。
我们的 CycleGAN 将使用马到斑马数据集执行非配对图像到图像转换,您可以下载该数据集。我们将使用 TensorFlow 和 Keras 实现我们的网络,使用来自 Pix.Pix 库的生成器和判别器。我们将通过 tensorflow_examples 包导入生成器和判别器,以简化实现。但是,在后续的文章中,我们还将向您展示如何从头开始构建新的生成器和判别器。
重要的是要提到 CycleGAN 是一个非常耗电和耗内存的网络。您的系统必须至少有 8 GB 的 RAM 和一个好于或等于 GTX 1660 Ti 的好 GPU,才能训练和运行 CycleGAN,而不会出现内存不足错误或超时。
我们将使用 GoogleColab 训练我们的网络,这是一种托管的 Jupyter Notebook 服务,可免费访问包括 GPU 在内的计算资源。最重要的是,它与某些其他云计算服务不同,它是免费的。
处理数据集
让我们加载数据集并应用一些预处理技术,例如裁剪、抖动和镜像,这将有助于我们避免网络过度拟合
- 图像抖动将图像调整为 286 x 286 像素,然后从随机选择的起始点将其裁剪为 256 x 256 像素
- 图像镜像水平翻转图像,从左到右。
以上技术在原始 CycleGAN 论文中有描述。
我们将把我们的数据上传到 Google Drive,以便 Google Colab 可以访问它。数据上传后,我们就可以开始读取数据了。或者,您可以简单地在代码中使用tfds.load
,直接从 TensorFlow 数据集包中加载数据集,就像我们将在下面做的那样。
首先,让我们导入一些必需的依赖项
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
AUTOTUNE = tf.data.AUTOTUNE
现在我们将下载数据集并对其应用上述的增强技术
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
加载数据后,让我们添加一些预处理函数
def random_crop(image):
cropped_image = tf.image.random_crop(
image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image
# normalizing images to [-1, 1]
def normalize(image):
image = tf.cast(image, tf.float32)
image = (image / 127.5) - 1
return image
def random_jitter(image):
# resizing to 286 x 286 x 3
image = tf.image.resize(image, [286, 286],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# randomly cropping to 256 x 256 x 3
image = random_crop(image)
# randomly mirroring
image = tf.image.random_flip_left_right(image)
return image
def preprocess_image_train(image, label):
image = random_jitter(image)
image = normalize(image)
return image
def preprocess_image_test(image, label):
image = normalize(image)
return image
现在,我们将读取图像
train_horses = train_horses.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
train_zebras = train_zebras.map(
preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
test_horses = test_horses.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
test_zebras = test_zebras.map(
preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
BUFFER_SIZE).batch(1)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
############################Mirroring and jittering
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Horse with random mirroring')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_horse[0] * 0.5 + 0.5)
plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
这是一个抖动图像的例子。
构建生成器和判别器
现在,我们从 pix2pix 模型中导入生成器和判别器。我们将使用 基于 U-Net 的生成器,而不是 CycleGAN 论文中使用的残差块生成器。我们将使用 U-Net,因为它具有不太复杂的结构,并且比残差块需要更少的计算。但是,我们将在另一篇文章中探索基于残差块的生成器。
OUTPUT_CHANNELS = 3
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
有了生成器和判别器,我们就可以开始设置损失了。由于 CycleGAN 是非配对的图像到图像转换,因此不需要配对数据来训练网络。因此,没有人可以保证输入图像和目标图像在训练期间构成有意义的对。这就是计算循环一致性损失以使网络正确映射非常重要的原因
LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
real_loss = loss_obj(tf.ones_like(real), real)
generated_loss = loss_obj(tf.zeros_like(generated), generated)
total_disc_loss = real_loss + generated_loss
return total_disc_loss * 0.5
def generator_loss(generated):
return loss_obj(tf.ones_like(generated), generated)
现在,我们计算循环一致性损失,以确保转换结果接近原始图像
def calc_cycle_loss(real_image, cycled_image):
loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
return LAMBDA * loss1
def identity_loss(real_image, same_image):
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return LAMBDA * 0.5 * loss
最后,我们为生成器和判别器设置优化器
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
后续步骤
在下一篇文章中,我们将向您展示如何训练我们的 CycleGAN 将马翻译成斑马,并将斑马翻译成马。敬请关注!