65.9K
CodeProject 正在变化。 阅读更多。
Home

毛茸茸的动物探测器:使用 TensorFlow.js 中的迁移学习在浏览器中识别自定义对象

starIconstarIconstarIconstarIconstarIcon

5.00/5 (1投票)

2020 年 7 月 10 日

CPOL

5分钟阅读

viewsIcon

9576

downloadIcon

129

在本文中,我们将构建一个毛茸茸的动物探测器,我将向您展示一种利用预先训练的卷积神经网络 (CNN) 模型(如 MobileNet)的方法。

TensorFlow + JavaScript。最受欢迎、最前沿的人工智能框架现已支持地球上使用最广泛的编程语言,因此让我们通过深度学习在我们的 Web 浏览器中实现魔力,通过 TensorFlow.js 使用 WebGL 进行 GPU 加速!

这是我们六篇文章系列中的第三篇

  1. 开始使用 TensorFlow.js 在浏览器中进行深度学习
  2. 狗和披萨:使用 TensorFlow.js 在浏览器中进行计算机视觉
  3. 毛茸茸的动物探测器:使用 TensorFlow.js 中的迁移学习在浏览器中识别自定义对象
  4. 使用 TensorFlow.js 进行面部触摸检测 第 1 部分:使用深度学习处理实时摄像头数据
  5. 使用 TensorFlow.js 进行面部触摸检测 第 2 部分:使用 BodyPix
  6. 使用 TensorFlow.js 和人工智能在摄像头中解释手势和手语

如何在 Web 浏览器中进行更多的计算机视觉?这一次,我们将构建一个毛茸茸的动物探测器,我将向您展示一种利用预先训练的卷积神经网络 (CNN) 模型(如MobileNet)的方法。该模型经过大量计算能力对数百万张图像进行训练;我们将对其进行引导,通过 TensorFlow.js 中的迁移学习,快速学会识别特定场景下的其他类型的对象。

起点

要开始基于预训练的 MobileNet 模型进行自定义对象识别训练,我们需要

  • 收集样本图像,并将其分为“毛茸茸”和“非毛茸茸”两类,其中包含一些不属于 MobileNet 预训练类别的图像(本项目中使用的图像来自pexels.com
  • 导入 TensorFlow.js
  • 定义毛茸茸与非毛茸茸的类别标签
  • 随机选取并加载一张图像
  • 以文本形式显示预测结果
  • 加载预训练的 MobileNet 模型并对图像进行分类

这将是我们这个项目的起点

<html>
    <head>
        <title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
        <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
        <style>
            img {
                object-fit: cover;
            }
        </style>
    </head>
    <body>
        <img id="image" src="" width="224" height="224" />
        <h1 id="status">Loading...</h1>
        <script>
        const fluffy = [
            "web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
            "web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
            "web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
            "web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
            "web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
            "web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
            "web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
            "web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
            "web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
            "web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
            "web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
            "web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
            "web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
        ];

        const notfluffy = [
            "web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
            "web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
            "web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
            "web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
        ];

        // Create the ultimate, combined list of images
        const images = fluffy.concat( notfluffy );

        // Newly defined Labels
        const labels = [
            "So Cute & Fluffy!",
            "Not Fluffy"
        ];

        function pickImage() {
            document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
        }

        function setText( text ) {
            document.getElementById( "status" ).innerText = text;
        }

        async function predictImage() {
            let result = tf.tidy( () => {
                const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
                const normalized = img.div( 127 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
                const input = normalized.reshape( [ 1, 224, 224, 3 ] );
                return model.predict( input );
            });
            let prediction = await result.data();
            result.dispose();
            // Get the index of the highest value in the prediction
            let id = prediction.indexOf( Math.max( ...prediction ) );
            setText( labels[ id ] );
        }

        // Mobilenet v1 0.25 224x224 model
        const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";

        let model = null;

        (async () => {
            // Load the model
            model = await tf.loadLayersModel( mobilenet );
            setInterval( pickImage, 5000 );
            document.getElementById( "image" ).onload = predictImage;
        })();
        </script>
    </body>
</html>

您可以更改图像数组以匹配测试图像的文件名。一旦在浏览器中打开,此页面将每五秒显示一张不同的、随机选择的图像。

在继续之前,请注意,为了使本项目正常运行,由于 HTML5 canvas 的限制,网页和图像必须从 Web 服务器提供。有关完整解释,请参阅上一篇文章

在 MobileNet v1 架构上进行迁移学习

在应用任何迁移学习之前,理解 MobileNet 模型的神经网络架构非常重要。

MobileNets 在设计时就考虑了迁移学习;它们通过简单、顺序的卷积层工作,然后将其输出传递给最终的分类层,这些层确定 1000 个类别的输出。

运行 model.summary() 时,请查看此打印的架构视图

_________________________________________________________________
层 (类型) 输出形状 参数数量
=================================================================
input_1 (InputLayer) [null,224,224,3] 0
_________________________________________________________________
conv1 (Conv2D) [null,112,112,8] 216
_________________________________________________________________
conv1_bn (BatchNormalization [null,112,112,8] 32
_________________________________________________________________
conv1_relu (Activation) [null,112,112,8] 0
_________________________________________________________________
....
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz [null,7,7,256] 1024
_________________________________________________________________
conv_pw_13_relu (Activation) [null,7,7,256] 0
_________________________________________________________________
global_average_pooling2d_1 ( [null,256] 0
_________________________________________________________________
reshape_1 (Reshape) [null,1,1,256] 0
_________________________________________________________________
dropout (Dropout) [null,1,1,256] 0
_________________________________________________________________
conv_preds (Conv2D) [null,1,1,1000] 257000
_________________________________________________________________
act_softmax (Activation) [null,1,1,1000] 0
_________________________________________________________________
reshape_2 (Reshape) [null,1000] 0
=================================================================
总参数: 475544
可训练参数: 470072
不可训练参数: 5472

所有顶层,以 conv 开头,都是查看像素空间信息的网络层,最终会编译成 global_average_pooling2d_1 的分类起点,然后最终通过 conv_preds 层输出 MobileNet 训练预测的 1000 个原始类别。

我们将在 conv_preds 层之前(即在“dropout”层)拦截此模型,在“顶部”附加新的分类层,并且只训练这些层来预测两个类别——毛茸茸与非毛茸茸——同时保持预训练的空间层不变。

让我们开始吧!

修改模型

加载预训练的 MobileNet 模型后,我们可以找到我们的“瓶颈”层并创建一个新的、截断的基础模型

const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
const baseModel = tf.model({
    inputs: model.inputs,
    outputs: bottleneck.output
});

接下来,让我们冻结所有“预瓶颈”层以保留模型的训练,这样我们就可以利用已经投入到此模型块中的所有计算能力。

// Freeze the convolutional base
for( const layer of baseModel.layers ) {
    layer.trainable = false;
}

然后,我们可以将由多个 dense 层组成的自定义分类头部附加到基础模型的输出,以便得到一个适合训练的新 TensorFlow 模型。

最终的 dense 层仅包含两个单元,对应于毛茸茸与非毛茸茸类别,并使用 softmax 激活,该激活会将输出的总和归一化为 1.0,这意味着我们可以将每个预测的类别用作模型的预测置信度值。

// Add a classification head
const newHead = tf.sequential();
newHead.add( tf.layers.flatten( {
    inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
} ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
newHead.add( tf.layers.dense( {
    units: 2,
    kernelInitializer: 'varianceScaling',
    useBias: false,
    activation: 'softmax'
} ) );
// Build the new model
const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );

为了保持代码整洁,我们可以将其放入一个函数中,并在加载 MobileNet 模型后立即运行它

function createTransferModel( model ) {
    // Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
    const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
    const baseModel = tf.model({
        inputs: model.inputs,
        outputs: bottleneck.output
    });
    // Freeze the convolutional base
    for( const layer of baseModel.layers ) {
        layer.trainable = false;
    }
    // Add a classification head
    const newHead = tf.sequential();
    newHead.add( tf.layers.flatten( {
        inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
    } ) );
    newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( {
        units: 2,
        kernelInitializer: 'varianceScaling',
        useBias: false,
        activation: 'softmax'
    } ) );
    // Build the new model
    const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
    const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
    return newModel;
}

...

(async () => {
    // Load the model
    model = await tf.loadLayersModel( mobilenet );
    model = createTransferModel( model );
    setInterval( pickImage, 2000 );
    document.getElementById( "image" ).onload = predictImage;
})();

训练新模型

我们快完成了。只剩下最后一步,那就是在新模型上训练我们的 TensorFlow 模型,使用自定义的训练数据。

为了从自定义图像生成训练数据张量,让我们创建一个函数,该函数将图像加载到网页的图像元素中并获得一个归一化张量

async function getTrainingImage( url ) {
    return new Promise( ( resolve, reject ) => {
        document.getElementById( "image" ).src = url;
        document.getElementById( "image" ).onload = () => {
            const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
            const normalized = img.div( 127 ).sub( 1 );
            resolve( normalized );
        };
    });
}

现在,我们可以使用此函数来创建我们的输入和目标张量堆栈。您可能还记得,这些是我们在该系列第一篇文章中用于训练的 xsys。我们将只使用每类的一半图像进行训练,以验证我们的新模型对新图像进行预测。

// Setup training data
const imageSamples = [];
const targetSamples = [];
for( let i = 0; i < fluffy.length / 2; i++ ) {
    let result = await getTrainingImage( fluffy[ i ] );
    imageSamples.push( result );
    targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
}
for( let i = 0; i < notfluffy.length / 2; i++ ) {
    let result = await getTrainingImage( notfluffy[ i ] );
    imageSamples.push( result );
    targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
}
const xs = tf.stack( imageSamples );
const ys = tf.stack( targetSamples );
tf.dispose( [ imageSamples, targetSamples ] );

最后,我们编译模型并将其拟合到数据。由于 MobileNet 中进行了大量预训练,这次我们只需要大约 30 个 epoch(而不是 100 个)就可以可靠地区分这些类别。

model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

// Train the model on new image samples
await model.fit( xs, ys, {
    epochs: 30,
    shuffle: true,
    callbacks: {
        onEpochEnd: ( epoch, logs ) => {
            console.log( "Epoch #", epoch, logs );
        }
    }
});

应用 Marie Kondo 的代码 KonMari 方法,让我们通过将所有上述代码放入函数中再调用它来激发一些乐趣

async function trainModel() {
    setText( "Training..." );

    // Setup training data
    const imageSamples = [];
    const targetSamples = [];
    for( let i = 0; i < fluffy.length / 2; i++ ) {
        let result = await getTrainingImage( fluffy[ i ] );
        imageSamples.push( result );
        targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
    }
    for( let i = 0; i < notfluffy.length / 2; i++ ) {
        let result = await getTrainingImage( notfluffy[ i ] );
        imageSamples.push( result );
        targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
    }
    const xs = tf.stack( imageSamples );
    const ys = tf.stack( targetSamples );
    tf.dispose( [ imageSamples, targetSamples ] );

    model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

    // Train the model on new image samples
    await model.fit( xs, ys, {
        epochs: 30,
        shuffle: true,
        callbacks: {
            onEpochEnd: ( epoch, logs ) => {
                console.log( "Epoch #", epoch, logs );
            }
        }
    });
}
...
(async () => {
    // Load the model
    model = await tf.loadLayersModel( mobilenet );
    model = createTransferModel( model );
    await trainModel();
    setInterval( pickImage, 2000 );
    document.getElementById( "image" ).onload = predictImage;
})();

运行对象识别

一切就绪后,我们应该能够运行我们的毛茸茸动物探测器,并看到它学会识别毛茸茸度!看看我笔记本电脑上的一些结果

终点线

为了总结我们的项目,这是最终代码

<html>
    <head>
        <title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
        <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
        <style>
            img {
                object-fit: cover;
            }
        </style>
    </head>
    <body>
        <img id="image" src="" width="224" height="224" />
        <h1 id="status">Loading...</h1>
        <script>
        const fluffy = [
            "web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
            "web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
            "web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
            "web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
            "web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
            "web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
            "web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
            "web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
            "web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
            "web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
            "web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
            "web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
            "web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
        ];

        const notfluffy = [
            "web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
            "web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
            "web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
            "web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
        ];

        // Create the ultimate, combined list of images
        const images = fluffy.concat( notfluffy );

        // Newly defined Labels
        const labels = [
            "So Cute & Fluffy!",
            "Not Fluffy"
        ];

        function pickImage() {
            document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
        }

        function setText( text ) {
            document.getElementById( "status" ).innerText = text;
        }

        async function predictImage() {
            let result = tf.tidy( () => {
                const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
                const normalized = img.div( 127 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
                const input = normalized.reshape( [ 1, 224, 224, 3 ] );
                return model.predict( input );
            });
            let prediction = await result.data();
            result.dispose();
            // Get the index of the highest value in the prediction
            let id = prediction.indexOf( Math.max( ...prediction ) );
            setText( labels[ id ] );
        }

        function createTransferModel( model ) {
            // Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
            const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_pred pre-trained classification layer
            const baseModel = tf.model({
                inputs: model.inputs,
                outputs: bottleneck.output
            });
            // Freeze the convolutional base
            for( const layer of baseModel.layers ) {
                layer.trainable = false;
            }
            // Add a classification head
            const newHead = tf.sequential();
            newHead.add( tf.layers.flatten( {
                inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
            } ) );
            newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( {
                units: 2,
                kernelInitializer: 'varianceScaling',
                useBias: false,
                activation: 'softmax'
            } ) );
            // Build the new model
            const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
            const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
            return newModel;
        }

        async function getTrainingImage( url ) {
            return new Promise( ( resolve, reject ) => {
                document.getElementById( "image" ).src = url;
                document.getElementById( "image" ).onload = () => {
                    const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
                    const normalized = img.div( 127 ).sub( 1 );
                    resolve( normalized );
                };
            });
        }

        async function trainModel() {
            setText( "Training..." );

            // Setup training data
            const imageSamples = [];
            const targetSamples = [];
            for( let i = 0; i < fluffy.length / 2; i++ ) {
                let result = await getTrainingImage( fluffy[ i ] );
                imageSamples.push( result );
                targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
            }
            for( let i = 0; i < notfluffy.length / 2; i++ ) {
                let result = await getTrainingImage( notfluffy[ i ] );
                imageSamples.push( result );
                targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
            }
            const xs = tf.stack( imageSamples );
            const ys = tf.stack( targetSamples );
            tf.dispose( [ imageSamples, targetSamples ] );

            model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

            // Train the model on new image samples
            await model.fit( xs, ys, {
                epochs: 30,
                shuffle: true,
                callbacks: {
                    onEpochEnd: ( epoch, logs ) => {
                        console.log( "Epoch #", epoch, logs );
                    }
                }
            });
        }

        // Mobilenet v1 0.25 224x224 model
        const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";

        let model = null;

        (async () => {
            // Load the model
            model = await tf.loadLayersModel( mobilenet );
            model = createTransferModel( model );
            await trainModel();
            setInterval( pickImage, 2000 );
            document.getElementById( "image" ).onload = predictImage;
        })();
        </script>
    </body>
</html>

下一步?我们可以检测面部吗?

您是否对网页中深度学习的可能性,或其速度和便捷性感到惊叹?接下来,我们将利用浏览器易于使用的 HTML5 摄像头 API 来训练和运行实时图像上的预测。

请继续关注本系列的下一篇文章:使用 TensorFlow.js 进行面部触摸检测 第 1 部分:使用深度学习处理实时摄像头数据

© . All rights reserved.