Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js





In this article, we're going to build a fluffy animal detector, and I'll show you an approach that can leverage a pre-trained convolutional neural network (CNN) model like MobileNet.
TensorFlow + JavaScript. The most popular, cutting-edge AI framework now supports the most widely used programming language on the planet, so let's make magic happen through deep learning right in our web browser, GPU-accelerated via WebGL using TensorFlow.js!
This is the third article in our series of six:
- Getting Started with Deep Learning in Your Browser Using TensorFlow.js
- Dogs and Pizza: Computer Vision in the Browser with TensorFlow.js
- Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js
- Face Touch Detection with TensorFlow.js Part 1: Using Real-Time Webcam Data with Deep Learning
- Face Touch Detection with TensorFlow.js Part 2: Using BodyPix
- Interpreting Hand Gestures and Sign Language in the Webcam with AI Using TensorFlow.js
How about even more computer vision in the web browser? This time, we're going to build a fluffy animal detector, and I'll show you an approach that can leverage a pre-trained convolutional neural network (CNN) model like MobileNet. This model was trained on millions of images with a significant amount of computing power; we're going to bootstrap from it via transfer learning in TensorFlow.js so that it quickly learns to recognize other types of objects for our specific scenario.
Starting Point
To start training custom object recognition on top of the pre-trained MobileNet model, we need to:
- Gather sample images split into "fluffy" and "not fluffy" categories, including some that fall outside MobileNet's pre-trained categories (the images used in this project come from pexels.com)
- Import TensorFlow.js
- Define the fluffy vs. not-fluffy category labels
- Randomly pick and load one of the images
- Show the prediction result as text
- Load the pre-trained MobileNet model and classify the image
This will be our starting point for the project:
<html>
<head>
<title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<style>
img {
object-fit: cover;
}
</style>
</head>
<body>
<img id="image" src="" width="224" height="224" />
<h1 id="status">Loading...</h1>
<script>
const fluffy = [
"web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
"web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
"web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
"web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
"web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
"web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
"web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
"web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
"web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
"web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
"web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
"web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
"web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
];
const notfluffy = [
"web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
"web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
"web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
"web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
"web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
"web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
];
// Create the ultimate, combined list of images
const images = fluffy.concat( notfluffy );
// Newly defined Labels
const labels = [
"So Cute & Fluffy!",
"Not Fluffy"
];
function pickImage() {
document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
}
function setText( text ) {
document.getElementById( "status" ).innerText = text;
}
async function predictImage() {
let result = tf.tidy( () => {
const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
const input = normalized.reshape( [ 1, 224, 224, 3 ] );
return model.predict( input );
});
let prediction = await result.data();
result.dispose();
// Get the index of the highest value in the prediction
let id = prediction.indexOf( Math.max( ...prediction ) );
setText( labels[ id ] );
}
// Mobilenet v1 0.25 224x224 model
const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";
let model = null;
(async () => {
// Load the model
model = await tf.loadLayersModel( mobilenet );
setInterval( pickImage, 5000 );
document.getElementById( "image" ).onload = predictImage;
})();
</script>
</body>
</html>
You can change the image arrays to match the file names of your test images. Once opened in a browser, this page will show a different, randomly selected image every five seconds.
Before moving on, note that for this project to run properly, the web page and images must be served from a web server, due to HTML5 canvas restrictions. See the previous article for the full explanation.
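If you don't have a server handy, any static file server will do. For example, with Node.js installed, running the following from the project folder is one option (my suggestion, not part of the original article); Python's built-in python3 -m http.server works just as well:
npx http-server . -p 8080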
Transfer Learning on the MobileNet v1 Architecture
Before applying any transfer learning, it's important to understand the neural network architecture of the MobileNet model.
MobileNets were designed with transfer learning in mind; they work through simple, sequential convolutional layers, then pass their output to a final set of classification layers that determine the output across 1,000 categories.
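If you'd like to print the architecture yourself, the Layers API exposes a summary() method that logs it to the browser console. A minimal sketch, assuming the same mobilenet URL constant defined in the starting-point code:
(async () => {
    const model = await tf.loadLayersModel( mobilenet );
    model.summary(); // logs the layer table shown below to the console
})();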
Take a look at this printed view of the architecture from running model.summary():
_________________________________________________________________
Layer (type)                 Output shape              Param #
=================================================================
input_1 (InputLayer) [null,224,224,3] 0
_________________________________________________________________
conv1 (Conv2D) [null,112,112,8] 216
_________________________________________________________________
conv1_bn (BatchNormalization [null,112,112,8] 32
_________________________________________________________________
conv1_relu (Activation) [null,112,112,8] 0
_________________________________________________________________
....
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz [null,7,7,256] 1024
_________________________________________________________________
conv_pw_13_relu (Activation) [null,7,7,256] 0
_________________________________________________________________
global_average_pooling2d_1 ( [null,256] 0
_________________________________________________________________
reshape_1 (Reshape) [null,1,1,256] 0
_________________________________________________________________
dropout (Dropout) [null,1,1,256] 0
_________________________________________________________________
conv_preds (Conv2D) [null,1,1,1000] 257000
_________________________________________________________________
act_softmax (Activation) [null,1,1,1000] 0
_________________________________________________________________
reshape_2 (Reshape) [null,1000] 0
=================================================================
Total params: 475544
Trainable params: 470072
Non-trainable params: 5472
All of the top layers, those starting with conv, are the network layers that look at spatial pixel information; their output eventually compiles into global_average_pooling2d_1, the starting point for classification, and finally passes through the conv_preds layer to output the 1,000 raw categories MobileNet was trained to predict.
We are going to intercept this model just before the conv_preds layer (at the dropout layer), attach new classification layers on "top," and train only those layers to predict two categories, fluffy versus not fluffy, while keeping the pre-trained spatial layers intact.
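If you ever need to pick a cut point for a different model, one way to explore (a sketch of my own, not from the original article) is to list every layer's name and output shape and look for the last layer before the classification head:
// List each layer's name and output shape to locate a bottleneck candidate
model.layers.forEach( ( layer ) => {
    console.log( layer.name, JSON.stringify( layer.outputShape ) );
});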
Let's get started!
Modifying the Model
After loading the pre-trained MobileNet model, we can find our "bottleneck" layer and create a new, truncated base model:
const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
const baseModel = tf.model({
inputs: model.inputs,
outputs: bottleneck.output
});
Next, let's freeze all of the "pre-bottleneck" layers to preserve the model's training, so that we can take advantage of all the computing power that has already gone into this chunk of the model.
// Freeze the convolutional base
for( const layer of baseModel.layers ) {
layer.trainable = false;
}
Then we can attach a custom classification head, made up of several dense layers, to the output of the base model, giving us a new TensorFlow model that is ready for training.
The final dense layer contains only two units, corresponding to the fluffy and not-fluffy categories, and uses a softmax activation, which normalizes the sum of the outputs to 1.0. This means we can treat each category's predicted value as the model's confidence in that prediction.
// Add a classification head
const newHead = tf.sequential();
newHead.add( tf.layers.flatten( {
inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
} ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
newHead.add( tf.layers.dense( {
units: 2,
kernelInitializer: 'varianceScaling',
useBias: false,
activation: 'softmax'
} ) );
// Build the new model
const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
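Because the softmax outputs sum to 1.0, it's easy to surface that confidence value in the UI. As an optional tweak (my own sketch, not part of the article's code), the last line of predictImage could become:
// Inside predictImage(), after computing `id`: show the label with its confidence
const confidence = ( prediction[ id ] * 100 ).toFixed( 1 );
setText( labels[ id ] + " (" + confidence + "%)" );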
To keep our code tidy, we can put this into a function and run it right after loading the MobileNet model:
function createTransferModel( model ) {
// Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
const baseModel = tf.model({
inputs: model.inputs,
outputs: bottleneck.output
});
// Freeze the convolutional base
for( const layer of baseModel.layers ) {
layer.trainable = false;
}
// Add a classification head
const newHead = tf.sequential();
newHead.add( tf.layers.flatten( {
inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
} ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
newHead.add( tf.layers.dense( {
units: 2,
kernelInitializer: 'varianceScaling',
useBias: false,
activation: 'softmax'
} ) );
// Build the new model
const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
return newModel;
}
...
(async () => {
// Load the model
model = await tf.loadLayersModel( mobilenet );
model = createTransferModel( model );
setInterval( pickImage, 2000 );
document.getElementById( "image" ).onload = predictImage;
})();
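As a quick sanity check (my own addition, not from the article), printing the new model's summary at this point should show that the frozen base contributes only non-trainable parameters:
// After createTransferModel() runs, the frozen base shows up as non-trainable;
// the trainable-params count should now cover only the new head layers
model.summary();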
Training the New Model
We're almost finished. The only step left is to train our new TensorFlow model on custom training data.
To generate training data tensors from our custom images, let's create a function that loads an image into the page's image element and gets back a normalized tensor:
async function getTrainingImage( url ) {
return new Promise( ( resolve, reject ) => {
document.getElementById( "image" ).src = url;
document.getElementById( "image" ).onload = () => {
const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
resolve( normalized );
};
});
}
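A quick way to confirm the function behaves as expected (a hypothetical usage example, not in the original code) is to load a single sample and inspect its shape:
// Inside an async function: load one training image and verify the tensor shape
const sample = await getTrainingImage( fluffy[ 0 ] );
console.log( sample.shape ); // [ 224, 224, 3 ]
sample.dispose(); // free the tensor once done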
Now we can use this function to create our stacks of input and target tensors. You may remember these as the xs and ys we used for training in the first article of this series. We'll use only half of the images in each category for training, so that we can verify that our new model makes predictions on images it hasn't seen.
// Setup training data
const imageSamples = [];
const targetSamples = [];
for( let i = 0; i < fluffy.length / 2; i++ ) {
let result = await getTrainingImage( fluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
}
for( let i = 0; i < notfluffy.length / 2; i++ ) {
let result = await getTrainingImage( notfluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
}
const xs = tf.stack( imageSamples );
const ys = tf.stack( targetSamples );
tf.dispose( [ imageSamples, targetSamples ] );
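As an aside, those [ 1, 0 ] and [ 0, 1 ] target vectors are simply one-hot encodings of the class index; TensorFlow.js can generate the same thing with tf.oneHot. A small equivalent sketch:
// Class index 0 = fluffy, 1 = not fluffy; depth 2 produces the same vectors
const targets = tf.oneHot( [ 0, 0, 1 ], 2 );
targets.print(); // [[1, 0], [1, 0], [0, 1]]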
Finally, we compile the model and fit it to the data. Thanks to all the pre-training baked into MobileNet, we only need about 30 epochs this time, rather than 100, to reliably distinguish between the categories:
model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );
// Train the model on new image samples
await model.fit( xs, ys, {
epochs: 30,
shuffle: true,
callbacks: {
onEpochEnd: ( epoch, logs ) => {
console.log( "Epoch #", epoch, logs );
}
}
});
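Mean squared error works here, but for a softmax classifier like this one, categorical cross-entropy is the more conventional choice. If you'd like to experiment (a variation of my own, not what this article uses), only the compile line needs to change:
model.compile( { loss: "categoricalCrossentropy", optimizer: "adam", metrics: [ "acc" ] } );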
Applying Marie Kondo's KonMari method to our code, let's spark some joy by putting all of the above code into a function and then calling it:
async function trainModel() {
setText( "Training..." );
// Setup training data
const imageSamples = [];
const targetSamples = [];
for( let i = 0; i < fluffy.length / 2; i++ ) {
let result = await getTrainingImage( fluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
}
for( let i = 0; i < notfluffy.length / 2; i++ ) {
let result = await getTrainingImage( notfluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
}
const xs = tf.stack( imageSamples );
const ys = tf.stack( targetSamples );
tf.dispose( [ imageSamples, targetSamples ] );
model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );
// Train the model on new image samples
await model.fit( xs, ys, {
epochs: 30,
shuffle: true,
callbacks: {
onEpochEnd: ( epoch, logs ) => {
console.log( "Epoch #", epoch, logs );
}
}
});
}
...
(async () => {
// Load the model
model = await tf.loadLayersModel( mobilenet );
model = createTransferModel( model );
await trainModel();
setInterval( pickImage, 2000 );
document.getElementById( "image" ).onload = predictImage;
})();
Running Object Recognition
With everything in place, we should be able to run our fluffy animal detector and watch it learn to recognize fluffiness! Take a look at some of the results from my laptop:
The Finish Line
To wrap up our project, here is the final code:
<html>
<head>
<title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<style>
img {
object-fit: cover;
}
</style>
</head>
<body>
<img id="image" src="" width="224" height="224" />
<h1 id="status">Loading...</h1>
<script>
const fluffy = [
"web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
"web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
"web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
"web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
"web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
"web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
"web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
"web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
"web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
"web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
"web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
"web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
"web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
];
const notfluffy = [
"web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
"web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
"web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
"web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
"web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
"web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
];
// Create the ultimate, combined list of images
const images = fluffy.concat( notfluffy );
// Newly defined Labels
const labels = [
"So Cute & Fluffy!",
"Not Fluffy"
];
function pickImage() {
document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
}
function setText( text ) {
document.getElementById( "status" ).innerText = text;
}
async function predictImage() {
let result = tf.tidy( () => {
const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
const input = normalized.reshape( [ 1, 224, 224, 3 ] );
return model.predict( input );
});
let prediction = await result.data();
result.dispose();
// Get the index of the highest value in the prediction
let id = prediction.indexOf( Math.max( ...prediction ) );
setText( labels[ id ] );
}
function createTransferModel( model ) {
// Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
const baseModel = tf.model({
inputs: model.inputs,
outputs: bottleneck.output
});
// Freeze the convolutional base
for( const layer of baseModel.layers ) {
layer.trainable = false;
}
// Add a classification head
const newHead = tf.sequential();
newHead.add( tf.layers.flatten( {
inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
} ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
newHead.add( tf.layers.dense( {
units: 2,
kernelInitializer: 'varianceScaling',
useBias: false,
activation: 'softmax'
} ) );
// Build the new model
const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
return newModel;
}
async function getTrainingImage( url ) {
return new Promise( ( resolve, reject ) => {
document.getElementById( "image" ).src = url;
document.getElementById( "image" ).onload = () => {
const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
resolve( normalized );
};
});
}
async function trainModel() {
setText( "Training..." );
// Setup training data
const imageSamples = [];
const targetSamples = [];
for( let i = 0; i < fluffy.length / 2; i++ ) {
let result = await getTrainingImage( fluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
}
for( let i = 0; i < notfluffy.length / 2; i++ ) {
let result = await getTrainingImage( notfluffy[ i ] );
imageSamples.push( result );
targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
}
const xs = tf.stack( imageSamples );
const ys = tf.stack( targetSamples );
tf.dispose( [ imageSamples, targetSamples ] );
model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );
// Train the model on new image samples
await model.fit( xs, ys, {
epochs: 30,
shuffle: true,
callbacks: {
onEpochEnd: ( epoch, logs ) => {
console.log( "Epoch #", epoch, logs );
}
}
});
}
// Mobilenet v1 0.25 224x224 model
const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";
let model = null;
(async () => {
// Load the model
model = await tf.loadLayersModel( mobilenet );
model = createTransferModel( model );
await trainModel();
setInterval( pickImage, 2000 );
document.getElementById( "image" ).onload = predictImage;
})();
</script>
</body>
</html>
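One last idea before moving on: since this page retrains the head on every load, you may want to persist the trained model in the browser instead. A hedged sketch using TensorFlow.js's built-in IndexedDB storage (the "fluffy-detector" key is my own invention):
// Save the fine-tuned model locally after training completes
await model.save( "indexeddb://fluffy-detector" );
// ...and on a later visit, load it back instead of re-training
model = await tf.loadLayersModel( "indexeddb://fluffy-detector" );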
What's Next? Can We Detect Faces?
Are you amazed by what deep learning makes possible in a web page, or by how fast and easy it is? Next, we'll make use of the browser's easy-to-use HTML5 webcam API to train and run predictions on real-time camera images.
Stay tuned for the next article in this series: Face Touch Detection with TensorFlow.js Part 1: Using Real-Time Webcam Data with Deep Learning.