在真实数据上运行 AI 时尚分类





5.00/5 (4投票s)
在本文中,我们使用手机摄像头拍摄的真实图像来评估VGG19。
引言
像DeepFashion这样的数据集的可用性为时尚行业开辟了新的可能性。在本系列文章中,我们将展示一个由人工智能驱动的深度学习系统,通过帮助我们更好地理解客户需求,从而彻底改变时尚设计行业。
在这个项目中,我们将使用
- Jupyter Notebook 作为 IDE
- 库
- DeepFashion 数据集的自定义子集——相对较小,以减少计算和内存开销
我们假设您熟悉深度学习的概念,以及Jupyter Notebook和TensorFlow。如果您是Jupyter Notebook的新手,请从本教程开始。欢迎您下载项目代码。
在上一篇文章中,我们训练了VGG16模型并评估了它在测试图像集上的性能。在本文中,我们将评估我们训练好的网络在一些测试图像以及用相机拍摄的图像上的表现,以验证模型在检测可能包含多个服装类别的图像中的真实服装时的鲁棒性。
在测试图像上评估
让我们将一张来自“牛仔裤”类别的图像输入网络,看看网络能否正确分类该服装。请注意,选定的图像将很难分类,因为它将包含不止一种服装类型:例如,牛仔裤和上衣。图像将被读取并使用preprocess_input
进行处理,该函数会调整图像大小并对其进行缩放以适应训练网络的输入。
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
img_path = r'C:\Users\abdul\Desktop\ContentLab\P2\DeepFashion\Test\Jeans\img_00000052.jpg'
img = image.load_img(img_path, target_size=(224,224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
plt.imshow(img)
选择图像后,我们将通过模型并获取输出(预测)。
def get_class_string_from_index(index):
for class_string, class_index in test_generator.class_indices.items():
if class_index == index:
return class_string
Predicted_Class=np.argmax(c, axis = 1)
print('Predicted_Class is:', Predicted_Class) #Get the rounded value of the predicted class
true_index = 5
# print('true_label is:', true_labels) #Get the rounded value of the predicted class
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(Predicted_Class))
如上图所示,模型已成功将类别识别为“牛仔裤”。
计算错误分类图像的数量
让我们进一步研究模型在检测服装类别方面的鲁棒性。为此,我们将创建一个函数,该函数将从测试集中选择一个随机图像批次,并将其输入模型以预测其类别,然后计算错误分类图像的数量。
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=3, class_mode='categorical')
X_test, y_test = next(test_generator)
X_test=X_test/255
preds = full_model.predict(X_test)
pred_labels = np.argmax(preds, axis=1)
true_labels = np.argmax(y_test, axis=1)
print (pred_labels)
print (true_labels)
如您在上文所见,我们定义了批量大小为三,以避免计算机内存问题。这意味着网络将仅选择三张图像并对它们进行分类,以计算这三张图像中错误分类的数量。您可以根据需要增加批量大小。
现在,让我们计算错误分类图像的数量。
mispred_img = X_test[pred_labels!=true_labels]
mispred_true = true_labels[pred_labels!=true_labels]
mispred_pred = pred_labels[pred_labels!=true_labels]
print ('number of misclassified images:', mispred_img.shape[0])
如果找到错误分类的图像,让我们使用此函数绘制它们
def plot_img_results(array, true, pred, i, n=1):
# plot the image and the target for sample i
ncols = 3
nrows = n/ncols + 1
fig = plt.figure( figsize=(ncols*2, nrows*2), dpi=100)
for j in range(n):
index = j+i
plt.subplot(nrows,ncols, j+1)
plt.imshow(array[index])
plt.title('true: {} pred: {}'.format(true[index], pred[index]))
plt.axis('off')
plot_img_results(mispred_img, mispred_true, mispred_pred, 0, len(mispred_img))
要查看每个类号指代哪个类,请运行以下命令
Classes[13]
使用特定数据集评估模型
现在,我们将创建一个函数,该函数将从任何数据集(例如训练、测试或验证)中选择任何图像,并在图像下方显示“真实与预测类别”的结果。为了使结果更易于解释,我们将显示类别名称(例如,“牛仔裤”),而不是类号(例如,“5”)。
def get_class_string_from_index(index):
for class_string, class_index in test_generator.class_indices.items():
if class_index == index:
return class_string
test_generator = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=7, class_mode='categorical')
X_test, y_test = next(test_generator)
X_test=X_test/255
image = X_test[2]
true_index = np.argmax(y_test(2)])
plt.imshow(image)
plt.axis('off')
plt.show()
# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = full_model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index))
使用相机图像评估模型
在这一部分,我们将研究模型在相机拍摄的图像上的性能。我们拍摄了12张衣服放在床上的图像,以及穿着不同类型衣服的个体的图像,并让训练好的模型对它们进行分类。为了增加趣味性,我们选择了男装(因为大多数训练图像都是女装)。这些衣服没有进行分类。我们只是将它们输入网络,让它找出这些衣服属于哪个类别。
网络在高质量图像(高对比度、未翻转的图像)上表现良好。有些图像被分配了正确的类别,有些被分配了相似的类别,而另一些则被错误标记。
提高网络性能
正如我们在前面几节中所示,网络的性能相当不错。但是,它可以得到改进。这与数据有关吗?是的,它是:原始的DeepFashion数据集非常庞大,而我们只使用了其中很小的一部分。
让我们使用数据增强来增加网络训练数据的量。这可能会提高网络在各种类型和不同质量的新图像上的测试性能。数据增强的目标是提高网络的泛化能力。通过在增强图像上训练网络来实现此目标,这些增强图像可以覆盖训练网络在测试真实图像时可能遇到的所有图像排列。
在Keras中,数据增强易于实现。您可以简单地在ImageDataGenerator
函数中添加所需的增强操作类型:旋转、缩放、平移、翻转等。我们实现了增强的DataLoad
函数将如下所示
from tensorflow.keras.preprocessing.image import ImageDataGenerator
batch_size = 3
def DataLoad(shape, preprocessing):
'''Create the training and validation datasets for
a given image shape.
'''
imgdatagen = ImageDataGenerator(
preprocessing_function = preprocessing,
rotation_range=10, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.15, z oom_range=0.1,
channel_shift_range=10., horizontal_flip=True,
validation_split = 0.1,
)
height, width = shape
train_dataset = imgdatagen.flow_from_directory(
os.getcwd(),
target_size = (height, width),
classes = ['Blazer', 'Blouse', 'Cardigan', 'Dress', 'Jacket',
'Jeans', 'Jumpsuit', 'Romper', 'Shorts', 'Skirts', 'Sweater', 'Sweatpants', 'Tank', 'Tee', 'Top'],
batch_size = batch_size,
subset = 'training',
)
val_dataset = imgdatagen.flow_from_directory(
os.getcwd(),
target_size = (height, width),
classes = ['Blazer', 'Blouse', 'Cardigan', 'Dress', 'Jacket',
'Jeans', 'Jumpsuit', 'Romper', 'Shorts', 'Skirts', 'Sweater', 'Sweatpants', 'Tank', 'Tee', 'Top'],
batch_size = batch_size,
subset = 'validation'
)
return train_dataset, val_dataset
以下代码显示了ImageDataGenerator
如何增强图像,并附带一些示例。
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%matplotlib inline
def plotImages(images_arr):
fig, axes = plt.subplots(1, 10, figsize=(20,20))
axes = axes.flatten()
for img, ax in zip( images_arr, axes):
ax.imshow(img)
ax.axis('off')
plt.tight_layout()
plt.show()
gen = ImageDataGenerator(rotation_range=10, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.15, zoom_range=0.1,
channel_shift_range=10., horizontal_flip=True)
现在,我们可以读取任何图像并显示它,以及它的增强衍生物。
image_path = r'C:\Users\abdul\Desktop\ContentLab\P2\DeepFashion\Train\Blouse\img_00000003.jpg'
image = np.expand_dims(plt.imread(image_path),0)
plt.imshow(image[0])
上述图像的增强图像显示在下方。
aug_iter = gen.flow(image)
aug_images = [next(aug_iter)[0].astype(np.uint8) for i in range(10)]
plotImages(aug_images)
后续步骤
在下一篇文章中,我们将向您展示如何构建一个用于时尚设计生成的生成对抗网络 (GAN)。敬请关注!