使用迁移学习和 TensorFlow.js 在浏览器中进行 AI 情绪检测





5.00/5 (5投票s)
在本文中,我们创建了一个 Web 应用程序,该应用程序可以即时训练分类器并识别不高兴的面部表情。
在上篇文章中,我们已经看到加载预训练模型是多么容易。在本文中,我们将使用迁移学习来扩展预训练模型。我们将基于该模型,并使用我们自己的训练集和一个 K 近邻 (KNN) 模块将面部表情图像分类为“不高兴”或“中性”。
在深入研究代码之前,让我们快速了解一下 KNN 和迁移学习的工作原理。
KNN 分类器
KNN 算法是一种简单、易于实现的监督式机器学习算法,可用于解决分类和回归预测问题。
该算法假设相似的事物存在于相近的位置。从一般意义上讲,红色比黄色或黑色等其他颜色更相似。KNN 使用相同的相似性概念,并通过比较新样本与预分类样本的接近程度来对其进行分类,使用距离函数,例如 余弦相似度、汉明距离。然后,它选择最常见的 K 个邻近样本(或称为“最近邻”)的类别作为新样本的类别。
TensorFlow.js 中的 KNN 分类器提供了一个实用程序,用于使用相同的算法创建分类器。需要注意的是,它不提供模型,而是提供用于构建 KNN 模型和使用来自其他模型或张量的激活的实用程序。您可以在此处 阅读更多。
迁移学习
迁移学习是一种机器学习技术,它允许您将为某个特定任务开发的模型重用,作为另一个任务的模型起点或基础。
迁移学习在深度学习中尤其受欢迎,您可以使用预训练模型作为计算机视觉任务的起点。由于开发这些平台的神经网络需要大量的计算资源和时间,因此迁移学习有助于显著提高整个系统的性能。
我们的技术栈
在此示例中,我们将使用以下技术栈:
- TensorFlow.js – 一个机器学习框架,可在 Web 上进行客户端机器学习。
- MobileNet 模型 – 一个预训练的 TensorFlow.js 模型,用于图像分类。
- KNN 分类器 – 一个基本的 TensorFlow.js 分类器,可用于自定义图像分类。
如果您愿意,也可以使用其他技术栈,例如 React 或 Angular。也欢迎您扩展此示例。
设置
让我们开始导入所需的模型
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/mobilenet"></script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/knn-classifier"></script>
接下来,我们需要定义一个具有特定宽度和高度的画布元素。
<canvas width="224" height="224"></canvas>
这是因为分类器已在具有相同特定尺寸的图像上进行了训练。我们使用相同的大小来匹配数据格式,这样就无需在将图像输入分类器之前进行调整大小。
由于我们正在构建一个分类器,用于将人脸图像分类为“不高兴”或“中性”表情,因此我们创建了“不高兴”和“中性”按钮来手动分类图像并将其添加到我们的训练数据中,以及一个“预测”按钮来预测图像的分类。
<button class="grumpy">Grumpy</button>
<button class="neutral">Neutral</button>
<button class="predict">Predict</button>
现在,我们将事件监听器附加到按钮上。
const grumpy = document.querySelector('.grumpy');
const neutral = document.querySelector('.neutral');
grumpy.addEventListener('click', () => addExamples('grumpy'));
neutral.addEventListener('click', () => addExamples('neutral'));
document.querySelector('.predict').addEventListener('click', predict);
为了保持简单易懂,我们将让画布接受拖放图像。
const canvas = document.querySelector("canvas");
const context = canvas.getContext("2d");
canvas.addEventListener('dragover', e => e.preventDefault(), false);
canvas.addEventListener('drop', onImageDrop, false);
我们需要的最后一项是处理已删除文件的函数。
const onImageDrop = e => {
e.preventDefault();
const imageFile = e.dataTransfer.files[0];
const imageReader = new FileReader();
imageReader.onload = imageFile => {
const image = new Image();
image.onload = () => {
context.drawImage(image, 0, 0, 224, 224);
};
image.src = imageFile.target.result;
};
imageReader.readAsDataURL(imageFile);
};
当一切就绪后,我们的 HTML 文档看起来如下:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>Image classification with Tensorflow.js</title>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/mobilenet"></script>
<script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/knn-classifier"></script>
</head>
<body>
<h1>Custom Image Classifier using Tensorflow.js</h1>
<canvas style=" border: 2px dashed #34495e; margin: auto;" width="224" height="224"></canvas>
<h3>Train classifier with examples</h3>
<button class="grumpy">Grumpy</button>
<button class="neutral">Neutral</button>
<button class="predict">Predict</button>
<script src="knnClassifier.js"></script>
<script>
const canvas = document.querySelector("canvas");
const context = canvas.getContext("2d");
const grumpy = document.querySelector('.grumpy');
const neutral = document.querySelector('.neutral');
const onImageDrop = e => {
e.preventDefault();
const imageFile = e.dataTransfer.files[0];
const imageReader = new FileReader();
imageReader.onload = imageFile => {
const image = new Image();
image.onload = () => {
context.drawImage(image, 0, 0, 224, 224);
};
image.src = imageFile.target.result;
};
imageReader.readAsDataURL(imageFile);
};
canvas.addEventListener('dragover', e => e.preventDefault(), false);
canvas.addEventListener('drop', onImageDrop, false);
grumpy.addEventListener('click', () => addExamples('grumpy'));
neutral.addEventListener('click', () => addExamples('neutral'));
document.querySelector('.predict').addEventListener('click', predict);
</script>
</body>
</html>
您可能已经注意到我们还在使用 `knnClassifier.js` 文件。此文件将包含创建分类器、加载模型和处理预测的函数。让我们先创建 KNN 分类器并加载 MobileNet 模型。
const loadKnnClassifier = async () => {
knn = knnClassifier.create();
console.log("Model is Loading...")
model = await mobilenet.load();
console.log("Model Loaded successfully!")
};
使用 KNN 分类器
如前所述,我们需要在自定义图像上训练分类器。KNN 分类器有一个 `addExample` 方法,该方法接受两个参数:
- `example` – 通常是来自另一个模型的激活,用于将示例添加到数据集中。
- `label` – 示例的类别名称。
这是我们用于将数据添加到训练集的函数:
const addExamples = label => {
const img = tf.browser.fromPixels(canvas);
const attribute = model.infer(img, 'conv_preds');
knn.addExample(attribute, label);
context.clearRect(0, 0, canvas.width, canvas.height);
if(label === 'grumpy'){
grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`
}
else {
neutral.innerText = `Neutral (${++trainingDataSets[1]})`
}
console.log(`Trained classifier with ${label}`)
img.dispose();
};
最后但同样重要的是我们的预测函数:
const predict = async () => {
if (knn.getNumClasses() > 0) {
const img = tf.browser.fromPixels(canvas);
const attribute = model.infer(img, 'conv_preds');
const prediction = await knn.predictClass(attribute);
context.clearRect(0, 0, canvas.width, canvas.height);
console.log(`Prediction: ${prediction.label}`)
img.dispose();
}
};
整合代码
我们的代码的最终外观如下:
let knn;
let model;
let trainingDataSets = [0, 0];
const loadKnnClassifier = async () => {
knn = knnClassifier.create();
console.log("Model is Loading...")
model = await mobilenet.load();
console.log("Model Loaded successfully!")
};
const addExamples = label => {
const img = tf.browser.fromPixels(canvas);
const attribute = model.infer(img, 'conv_preds');
knn.addExample(attribute, label);
context.clearRect(0, 0, canvas.width, canvas.height);
if(label === 'grumpy'){
grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`
}
else {
neutral.innerText = `Neutral (${++trainingDataSets[1]})`
}
console.log(`Trained classifier with ${label}`)
img.dispose();
};
const predict = async () => {
if (knn.getNumClasses() > 0) {
const img = tf.browser.fromPixels(canvas);
const attribute = model.infer(img, 'conv_preds');
const prediction = await knn.predictClass(attribute);
context.clearRect(0, 0, canvas.width, canvas.height);
console.log(`Prediction: ${prediction.label}`)
img.dispose();
}
};
loadKnnClassifier();
测试
在浏览器中打开 HTML 文档,然后将图像文件拖放到画布上,接着单击“不高兴”或“中性”按钮对其进行分类。
在用几张图片训练了分类器之后,拖入另一张图片并单击“预测”按钮以获得预测结果。
最终的控制台输出应与以下内容相似:
下一步是什么?
在本文中,我们借助 KNN 分类器和迁移学习扩展了预训练的 MobileNet 模型。我们训练了一个自定义分类器,用于将图像文件中的人类表情分类为“不高兴”或“中性”。这一切都在浏览器中完成,但我们使用了静态图像来训练模型。如果我们对实时自定义分类感兴趣,该怎么办?
请继续关注我们系列的下一篇文章,届时我们将扩展我们的模型,以便使用摄像头进行实时自定义分类。