65.9K
CodeProject 正在变化。 阅读更多。
Home

训练运行 CycleGAN 以进行移动风格迁移

starIconstarIconstarIconstarIconstarIcon

5.00/5 (2投票s)

2021年6月17日

CPOL

3分钟阅读

viewsIcon

6347

downloadIcon

67

在本文中,我们训练一个带有基于 U-Net 的生成器的 CycleGAN。

引言

在本系列文章中,我们将介绍一个基于 循环一致对抗网络 (CycleGAN) 的移动端图像到图像转换系统。我们将构建一个可以执行非配对图像到图像转换的 CycleGAN,并向您展示一些有趣但学术上深刻的例子。我们还将讨论如何将这样一个用 TensorFlow 和 Keras 构建的训练好的网络转换为 TensorFlow Lite,并将其用作移动设备上的应用程序。

我们假设您熟悉深度学习的概念,以及 Jupyter Notebooks 和 TensorFlow。 欢迎您下载项目代码。

之前的文章中,我们从头开始实现了一个 CycleGAN。 在本文中,我们将在 horse2zebra 数据集上训练和测试该网络,并评估其性能。

训练 CycleGAN

是时候训练我们的 CycleGAN 来执行一些有趣的翻译,例如将马匹转换为斑马,反之亦然。 我们将首先设置一个检查点路径来保存最佳模型

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

首先,我们将训练 20 个 epoch,看看这是否足以获得可接受的结果。 根据获得的结果,我们可能需要增加 epoch 的数量。 即使您的训练结果看起来不错,预测也可能不太准确。 因此,80 到 100 个 epoch 更有可能让您获得完美的翻译,但是,除非您使用的是具有非常高规格的系统或付费的基于云的计算服务(例如 AWSMicrosoft Azure),否则这将需要超过 3 天的训练时间。

EPOCHS = 20
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()


def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

上面的训练循环执行以下操作

  • 获取预测
  • 计算损失
  • 使用反向传播计算梯度
  • 将梯度应用于优化器

在训练期间,网络将从训练集中选择一张随机图像,并将其与其翻译版本一起显示,以便我们可视化每次 epoch 后性能如何变化,如下图所示。

for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

评估 CycleGAN

一旦 CycleGAN 经过训练,我们就可以开始向其输入新图像并评估其在将马匹转换为斑马以及反之亦然方面的性能。

让我们在数据集中的图像上测试我们训练好的 CycleGAN,并可视化其泛化能力。 我们将使用 generate_images 函数,该函数将提取一些图像,将它们传递到训练好的网络中,并显示翻译结果。

def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

现在,您可以选择任何测试图像并可视化翻译结果

for inp in test_horses.take(5):
  generate_images(generator_g, inp)

以下是一些在网络仅经过 20 个 epoch 的训练后获得的示例。 对于如此短的训练,结果非常好。 您可以通过添加更多 epoch 来改进它们。

季节转换 CycleGAN

我们可以将我们设计的网络用于不同的任务,例如白天到夜晚的转换或季节转换。 为了训练我们的网络进行季节转换,我们所需要做的就是将训练数据集更改为 summer2winter

我们在上述数据集上训练了我们的网络 80 个 epoch。看看结果。

后续步骤

在本文中,我们训练了一个带有基于 U-Net 的生成器的 CycleGAN。 在下一篇文章中,我们将向您展示如何实现基于残差的生成器,并在医疗数据集上训练生成的 CycleGAN。 敬请关注!

© . All rights reserved.