65.9K
CodeProject 正在变化。 阅读更多。
Home

微调 VGG16 以对服装进行分类

starIconstarIconstarIconstarIconstarIcon

5.00/5 (3投票s)

2021 年 3 月 17 日

CPOL

2分钟阅读

viewsIcon

12858

在本文中,我们将向您展示如何训练 VGG19 来识别人们穿的衣服。

引言

DeepFashion 这样的数据集的可用性为时尚行业开辟了新的可能性。 在本系列文章中,我们将展示一个由 AI 驱动的 深度学习 系统,该系统可以通过帮助我们更好地了解客户的需求来彻底改变时装设计行业。

在这个项目中,我们将使用

我们假设您熟悉深度学习的概念,以及 Jupyter Notebook 和 TensorFlow。 如果您不熟悉 Jupyter Notebook,请从 本教程 开始。 欢迎下载 项目代码

上一篇文章 中,我们向您展示了如何加载 DeepFashion 数据集,以及如何重组 VGG16 模型以适应我们的服装分类任务。 在本文中,我们将训练 VGG16 对 15 个不同的服装类别进行分类并评估模型性能。

训练 VGG16

VGG16 的迁移学习从冻结模型权重开始,这些权重是通过在诸如 ImageNet 这样的巨大数据集上训练模型获得的。 这些学习到的权重和过滤器为网络提供了强大的特征提取能力,这将有助于我们在训练它对服装类别进行分类时提高其性能。 因此,只有全连接 (FC) 层将被训练,同时保持模型的特征提取部分几乎冻结(通过设置非常低的 learning rate,比如 0.001)。 让我们将特征提取层设置为 False 来冻结它们

for layer in conv_model.layers:
    layer.trainable = False

现在,我们可以编译我们的模型,同时选择学习率 (0.001) 和优化器 (Adamax)

full_model.compile(loss='categorical_crossentropy',
                  optimizer=keras.optimizers.Adamax(lr=0.001),
                  metrics=['acc'])

编译后,我们可以使用 fit_generator 函数开始模型训练,因为我们使用了 ImageDataGenerator 来加载我们的数据。 我们将分别使用标记为 train_datasetval_dataset 的数据来训练和验证我们的网络。 我们将训练三个 epoch,但可以根据网络性能增加这个数字。

history = full_model.fit_generator(
    train_dataset, 
    validation_data = val_dataset,
    workers=0,
    epochs=3,
)

运行以上代码将产生以下输出

现在,为了绘制网络的学习曲线和损失曲线,让我们添加 plot_history 函数

def plot_history(history, yrange):
    '''Plot loss and accuracy as a function of the epoch,
    for the training and validation datasets.
    '''
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # Get number of epochs
    epochs = range(len(acc))

    # Plot training and validation accuracy per epoch
    plt.plot(epochs, acc)
    plt.plot(epochs, val_acc)
    plt.title('Training and validation accuracy')
    plt.ylim(yrange)
    
    # Plot training and validation loss per epoch
    plt.figure()

    plt.plot(epochs, loss)
    plt.plot(epochs, val_loss)
    plt.title('Training and validation loss')
    
    plt.show()
    
plot_history(history, yrange=(0.9,1))

此函数将生成这两个图

在新的图像上评估 VGG16

我们的网络在训练期间表现良好。 因此,它在测试它之前没有见过的衣服图像时也应该表现良好,对吧? 我们将在我们的测试图像集上对其进行测试。

首先,让我们加载测试集,然后使用 model.evaluate 函数将测试图像传递给模型以衡量网络精度。

from tensorflow.keras.preprocessing.image import ImageDataGenerator
test_dir=r'C:\Users\abdul\Desktop\ContentLab\P2\DeepFashion\Test'
test_datagen = ImageDataGenerator()

test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=3, class_mode='categorical')
# X_test, y_test = next(test_generator)

Testresults = full_model.evaluate(test_generator)
print("test loss, test acc:", Testresults)

好吧,很明显我们的网络训练有素。 没有过度拟合:它在测试集上达到了 92% 的准确率。

后续步骤

下一篇文章 中,我们将使用手机摄像头拍摄的真实图像来评估 VGG19。 敬请关注!

© . All rights reserved.