65.9K
CodeProject 正在变化。 阅读更多。
Home

使用 Keras 构建移动风格迁移 CycleGAN

starIconstarIconstarIconstarIconstarIcon

5.00/5 (1投票)

2021 年 6 月 16 日

CPOL

3分钟阅读

viewsIcon

6163

downloadIcon

92

在本文中,我们将从头开始实现一个 CycleGAN。

引言

在本系列文章中,我们将介绍一个基于 循环一致对抗网络 (CycleGAN) 的移动端图像到图像转换系统。我们将构建一个可以执行非配对图像到图像转换的 CycleGAN,并向您展示一些有趣但学术上很深入的例子。我们还将讨论如何将这样一个用 TensorFlow 和 Keras 构建的训练好的网络转换为 TensorFlow Lite,并在移动设备上用作应用程序。

我们假设您熟悉深度学习的概念,以及 Jupyter Notebooks 和 TensorFlow。欢迎您下载项目代码。

上一篇文章中,我们讨论了 CycleGAN 架构。现在我们已经完成了理论。在本文中,我们将从头开始实现 CycleGAN。

我们的 CycleGAN 将使用马到斑马数据集执行非配对图像到图像转换,您可以下载该数据集。我们将使用 TensorFlow 和 Keras 实现我们的网络,使用来自 Pix.Pix 库的生成器和判别器。我们将通过 tensorflow_examples 包导入生成器和判别器,以简化实现。但是,在后续的文章中,我们还将向您展示如何从头开始构建新的生成器和判别器。

重要的是要提到 CycleGAN 是一个非常耗电和耗内存的网络。您的系统必须至少有 8 GB 的 RAM 和一个好于或等于 GTX 1660 Ti 的好 GPU,才能训练和运行 CycleGAN,而不会出现内存不足错误或超时。

我们将使用 GoogleColab 训练我们的网络,这是一种托管的 Jupyter Notebook 服务,可免费访问包括 GPU 在内的计算资源。最重要的是,它与某些其他云计算服务不同,它是免费的。

处理数据集

让我们加载数据集并应用一些预处理技术,例如裁剪、抖动和镜像,这将有助于我们避免网络过度拟合

  • 图像抖动将图像调整为 286 x 286 像素,然后从随机选择的起始点将其裁剪为 256 x 256 像素
  • 图像镜像水平翻转图像,从左到右。

以上技术在原始 CycleGAN 论文中有描述。

我们将把我们的数据上传到 Google Drive,以便 Google Colab 可以访问它。数据上传后,我们就可以开始读取数据了。或者,您可以简单地在代码中使用tfds.load,直接从 TensorFlow 数据集包中加载数据集,就像我们将在下面做的那样。

首先,让我们导入一些必需的依赖项

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

现在我们将下载数据集并对其应用上述的增强技术

dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

加载数据后,让我们添加一些预处理函数

def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image

# normalizing images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image

def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # randomly mirroring
  image = tf.image.random_flip_left_right(image)

  return image

def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

def preprocess_image_test(image, label):
  image = normalize(image)
  return image

现在,我们将读取图像

train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
############################Mirroring and jittering
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random mirroring')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

这是一个抖动图像的例子。

构建生成器和判别器

现在,我们从 pix2pix 模型中导入生成器和判别器。我们将使用 基于 U-Net 的生成器,而不是 CycleGAN 论文中使用的残差块生成器。我们将使用 U-Net,因为它具有不太复杂的结构,并且比残差块需要更少的计算。但是,我们将在另一篇文章中探索基于残差块的生成器。

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

有了生成器和判别器,我们就可以开始设置损失了。由于 CycleGAN 是非配对的图像到图像转换,因此不需要配对数据来训练网络。因此,没有人可以保证输入图像和目标图像在训练期间构成有意义的对。这就是计算循环一致性损失以使网络正确映射非常重要的原因

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5

def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

现在,我们计算循环一致性损失,以确保转换结果接近原始图像

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

最后,我们为生成器和判别器设置优化器

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

后续步骤

下一篇文章中,我们将向您展示如何训练我们的 CycleGAN 将马翻译成斑马,并将斑马翻译成马。敬请关注!

© . All rights reserved.