
Deep Learning using Python + Keras (Chapter 3): ResNet


19 Jun 2018

CPOL


This is the third article in a series that introduces deep learning coding in Python with the Keras framework.

Introduction

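This article does not introduce deep learning itself. It assumes that you already know the basics of deep learning and a little Python. Its main goal is to present the basics of the Keras framework and to use it together with other well-known libraries to run a quick experiment and draw first conclusions.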

Background

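This article presents the ResNet architecture, which was introduced by Microsoft and won the ILSVRC (ImageNet Large Scale Visual Recognition Challenge) in 2015. The architecture is described in the paper "Deep Residual Learning for Image Recognition" (He et al., 2015).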

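The key idea is to make the network much deeper by introducing residual (shortcut) connections based on an identity mapping. The shortcut carries the input of a block unchanged past its convolutional layers and adds it to their output, which makes very deep networks easier to train.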
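
A residual connection can be illustrated without any framework. The sketch below is a toy version in plain Python, not the Keras implementation: `relu`, `linear`, and `residual_block` are names made up for this illustration, and two small fully connected layers stand in for the convolutions of a real block.

```python
def relu(values):
    """Element-wise max(0, v)."""
    return [max(0.0, v) for v in values]


def linear(values, weights):
    """Multiply a vector by a weight matrix given as a list of rows."""
    return [sum(w * v for w, v in zip(row, values)) for row in weights]


def residual_block(x, w1, w2):
    """Compute relu(F(x) + x), where F is two small linear layers."""
    hidden = relu(linear(x, w1))
    residual = linear(hidden, w2)
    # The shortcut: the block input is added back to the learned residual.
    return relu([r + xi for r, xi in zip(residual, x)])


x = [1.0, 2.0, 3.0]
zeros = [[0.0] * 3 for _ in range(3)]
print(residual_block(x, zeros, zeros))  # [1.0, 2.0, 3.0]
```

With all weights at zero, the block returns a non-negative input unchanged. In other words, an added block can start out as an identity mapping and only has to learn a correction to it.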

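We will run the same experiment as in the previous chapters. I will not show the parts that load the CIFAR-100 dataset, set up the experiment, and install the Python libraries. All of these are the same as in the previous chapter.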

Using the Code

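Keras provides this architecture, but by default its ResNet50 requires input images of at least 197 x 197 pixels, so we will define a smaller architecture.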

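# Imports assume Keras 2.1.x, where keras.applications.resnet50 exposes
# conv_block and identity_block; later releases moved these helpers.
from keras import backend as K
from keras.applications import resnet50
from keras.engine.topology import get_source_inputs
from keras.layers import (AveragePooling2D, Dense, Flatten,
                          GlobalAveragePooling2D, GlobalMaxPooling2D,
                          Input, ZeroPadding2D)
from keras.models import Model

def CustomResNet50(include_top=True, input_tensor=None, 
                   input_shape=(32,32,3), pooling=None, classes=100):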
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = ZeroPadding2D(padding=(2, 2), name='conv1_pad')(img_input)
    
    x = resnet50.conv_block(x, 3, [32, 32, 64], stage=2, block='a')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='b')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='c')

    x = resnet50.conv_block(x, 3, [64, 64, 256], stage=3, block='a', strides=(1, 1))
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='c')

    x = resnet50.conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='d')

    x = resnet50.conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='d')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='e')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='f')

    x = resnet50.conv_block(x, 3, [512, 512, 2048], stage=6, block='a')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='b')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='c')

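    # A (1, 1) pooling window is a no-op: the feature map passes through unchanged.
    x = AveragePooling2D((1, 1), name='avg_pool')(x)

    if include_top:
        x = Flatten()(x)
        # The name 'fc1000' is inherited from Keras' ResNet50; the layer has `classes` units.
        x = Dense(classes, activation='softmax', name='fc1000')(x)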
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='resnet50')

    return model
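
As a hand check on the listing above, the following sketch recomputes the spatial size after each stage and the parameter counts of the first `conv_block`. It assumes, as the Keras model summary for this network indicates, that only the first 1x1 convolution of each `conv_block` is strided and that the 3x3 convolution preserves the spatial size. The helper names (`conv_params`, `bn_params`, `strided_size`, `conv_block_params`) are illustrative and not part of Keras.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + c_out


def bn_params(channels):
    """BatchNormalization keeps gamma, beta, moving mean and moving variance."""
    return 4 * channels


def strided_size(size, stride):
    """Spatial size after a 1x1 convolution with the given stride."""
    return (size - 1) // stride + 1


def conv_block_params(c_in, filters, kernel=3):
    """Parameter counts of the four convolutions in one conv_block."""
    f1, f2, f3 = filters
    return {
        "branch2a": conv_params(1, c_in, f1),
        "branch2b": conv_params(kernel, f1, f2),
        "branch2c": conv_params(1, f2, f3),
        "branch1": conv_params(1, c_in, f3),  # projection shortcut
    }


# (input channels, filters, stride) of the conv_block that opens each stage.
STAGES = [
    (3, [32, 32, 64], 2),
    (64, [64, 64, 256], 1),
    (256, [128, 128, 512], 2),
    (512, [256, 256, 1024], 2),
    (1024, [512, 512, 2048], 2),
]

size = 32 + 2 * 2  # ZeroPadding2D(padding=(2, 2)) adds 2 pixels on every side
sizes = []
for c_in, filters, stride in STAGES:
    size = strided_size(size, stride)
    sizes.append(size)

print(sizes)  # [18, 18, 9, 5, 3]
print(conv_block_params(3, [32, 32, 64]))
# {'branch2a': 128, 'branch2b': 9248, 'branch2c': 2112, 'branch1': 256}
print(bn_params(32), bn_params(64))  # 128 256
```

These numbers can be compared with the Output Shape and Param # columns that `model.summary()` prints for layers such as `res2a_branch2a` and `bn2a_branch2a`.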

As in the previous articles, we compile the model with the same parameters:

def create_custom_resnet50():
  model = CustomResNet50(include_top=True, input_tensor=None, 
                         input_shape=(32,32,3), pooling=None, classes=100)
  
  return model
  
custom_resnet50_model = create_custom_resnet50()
custom_resnet50_model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc', 'mse'])
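
To give a feel for the `categorical_crossentropy` loss used in the compile call, here is a minimal sketch of the formula for a single sample in plain Python. It is a simplified illustration rather than the Keras implementation (which also clips and normalises the predictions), and the class index 42 is an arbitrary choice.

```python
import math


def categorical_crossentropy(y_true, y_pred):
    """-sum(t * log(p)) for one sample with a one-hot target y_true."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)


num_classes = 100  # CIFAR-100
y_true = [0.0] * num_classes
y_true[42] = 1.0  # arbitrary class index chosen for illustration

# A model that knows nothing assigns the same probability to every class.
uniform = [1.0 / num_classes] * num_classes
baseline = categorical_crossentropy(y_true, uniform)
print(round(baseline, 3))  # 4.605
```

The value ln(100), roughly 4.6, is a useful reference point: an untrained 100-class model should report a loss in that region, and training should push it below that.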

Once that is done, we can view a summary of the created model.

custom_resnet50_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 36, 36, 3)    0           input_1[0][0]                    
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 18, 18, 32)   128         conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_1[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_2[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 18, 18, 64)   256         conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 18, 18, 64)   256         res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, 18, 18, 64)   0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 18, 18, 64)   0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_3[0][0]               
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_4[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, 18, 18, 64)   0           bn2b_branch2c[0][0]              
                                                                 activation_3[0][0]               
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 18, 18, 64)   0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_6[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_8[0][0]               
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, 18, 18, 64)   0           bn2c_branch2c[0][0]              
                                                                 activation_6[0][0]               
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 18, 18, 64)   0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 18, 18, 64)   4160        activation_9[0][0]               
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_11[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 18, 18, 256)  16640       activation_9[0][0]               
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 18, 18, 256)  1024        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, 18, 18, 256)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 18, 18, 256)  0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_12[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_13[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_14[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 18, 18, 256)  0           bn3b_branch2c[0][0]              
                                                                 activation_12[0][0]              
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 18, 18, 256)  0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_15[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_16[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_17[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, 18, 18, 256)  0           bn3c_branch2c[0][0]              
                                                                 activation_15[0][0]              
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 18, 18, 256)  0           add_6[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 9, 9, 128)    32896       activation_18[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_19[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_20[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 9, 9, 512)    131584      activation_18[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 9, 9, 512)    2048        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_7 (Add)                     (None, 9, 9, 512)    0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 9, 9, 512)    0           add_7[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_21[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_22[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_23[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_8 (Add)                     (None, 9, 9, 512)    0           bn4b_branch2c[0][0]              
                                                                 activation_21[0][0]              
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 9, 9, 512)    0           add_8[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_24[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_25[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_26[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 9, 9, 512)    0           bn4c_branch2c[0][0]              
                                                                 activation_24[0][0]              
__________________________________________________________________________________________________
activation_27 (Activation)      (None, 9, 9, 512)    0           add_9[0][0]                      
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_27[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_28 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_28[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_29[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, 9, 9, 512)    0           bn4d_branch2c[0][0]              
                                                                 activation_27[0][0]              
__________________________________________________________________________________________________
activation_30 (Activation)      (None, 9, 9, 512)    0           add_10[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 5, 5, 256)    131328      activation_30[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_31[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_32[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 5, 5, 1024)   525312      activation_30[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 5, 5, 1024)   4096        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_11 (Add)                    (None, 5, 5, 1024)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 5, 5, 1024)   0           add_11[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_33[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_34[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_35[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 5, 5, 1024)   0           bn5b_branch2c[0][0]              
                                                                 activation_33[0][0]              
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 5, 5, 1024)   0           add_12[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_36[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_37[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_38[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, 5, 5, 1024)   0           bn5c_branch2c[0][0]              
                                                                 activation_36[0][0]              
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 5, 5, 1024)   0           add_13[0][0]                     
__________________________________________________________________________________________________
res5d_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_39[0][0]              
__________________________________________________________________________________________________
bn5d_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2a[0][0]              
__________________________________________________________________________________________________
res5d_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_40[0][0]              
__________________________________________________________________________________________________
bn5d_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2b[0][0]              
__________________________________________________________________________________________________
res5d_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_41[0][0]              
__________________________________________________________________________________________________
bn5d_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5d_branch2c[0][0]             
__________________________________________________________________________________________________
add_14 (Add)                    (None, 5, 5, 1024)   0           bn5d_branch2c[0][0]              
                                                                 activation_39[0][0]              
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 5, 5, 1024)   0           add_14[0][0]                     
__________________________________________________________________________________________________
res5e_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_42[0][0]              
__________________________________________________________________________________________________
bn5e_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2a[0][0]              
__________________________________________________________________________________________________
res5e_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_43[0][0]              
__________________________________________________________________________________________________
bn5e_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2b[0][0]              
__________________________________________________________________________________________________
res5e_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_44[0][0]              
__________________________________________________________________________________________________
bn5e_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5e_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 5, 5, 1024)   0           bn5e_branch2c[0][0]              
                                                                 activation_42[0][0]              
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 5, 5, 1024)   0           add_15[0][0]                     
__________________________________________________________________________________________________
res5f_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_45[0][0]              
__________________________________________________________________________________________________
bn5f_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2a[0][0]              
__________________________________________________________________________________________________
res5f_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_46[0][0]              
__________________________________________________________________________________________________
bn5f_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_47 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2b[0][0]              
__________________________________________________________________________________________________
res5f_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_47[0][0]              
__________________________________________________________________________________________________
bn5f_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5f_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 5, 5, 1024)   0           bn5f_branch2c[0][0]              
                                                                 activation_45[0][0]              
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 5, 5, 1024)   0           add_16[0][0]                     
__________________________________________________________________________________________________
res6a_branch2a (Conv2D)         (None, 3, 3, 512)    524800      activation_48[0][0]              
__________________________________________________________________________________________________
bn6a_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2a[0][0]              
__________________________________________________________________________________________________
res6a_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_49[0][0]              
__________________________________________________________________________________________________
bn6a_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_50 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2b[0][0]              
__________________________________________________________________________________________________
res6a_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_50[0][0]              
__________________________________________________________________________________________________
res6a_branch1 (Conv2D)          (None, 3, 3, 2048)   2099200     activation_48[0][0]              
__________________________________________________________________________________________________
bn6a_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6a_branch2c[0][0]             
__________________________________________________________________________________________________
bn6a_branch1 (BatchNormalizatio (None, 3, 3, 2048)   8192        res6a_branch1[0][0]              
__________________________________________________________________________________________________
add_17 (Add)                    (None, 3, 3, 2048)   0           bn6a_branch2c[0][0]              
                                                                 bn6a_branch1[0][0]               
__________________________________________________________________________________________________
activation_51 (Activation)      (None, 3, 3, 2048)   0           add_17[0][0]                     
__________________________________________________________________________________________________
res6b_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_51[0][0]              
__________________________________________________________________________________________________
bn6b_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_52 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2a[0][0]              
__________________________________________________________________________________________________
res6b_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_52[0][0]              
__________________________________________________________________________________________________
bn6b_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_53 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2b[0][0]              
__________________________________________________________________________________________________
res6b_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_53[0][0]              
__________________________________________________________________________________________________
bn6b_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6b_branch2c[0][0]             
__________________________________________________________________________________________________
add_18 (Add)                    (None, 3, 3, 2048)   0           bn6b_branch2c[0][0]              
                                                                 activation_51[0][0]              
__________________________________________________________________________________________________
activation_54 (Activation)      (None, 3, 3, 2048)   0           add_18[0][0]                     
__________________________________________________________________________________________________
res6c_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_54[0][0]              
__________________________________________________________________________________________________
bn6c_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_55 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2a[0][0]              
__________________________________________________________________________________________________
res6c_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_55[0][0]              
__________________________________________________________________________________________________
bn6c_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_56 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2b[0][0]              
__________________________________________________________________________________________________
res6c_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_56[0][0]              
__________________________________________________________________________________________________
bn6c_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6c_branch2c[0][0]             
__________________________________________________________________________________________________
add_19 (Add)                    (None, 3, 3, 2048)   0           bn6c_branch2c[0][0]              
                                                                 activation_54[0][0]              
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 3, 3, 2048)   0           add_19[0][0]                     
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 3, 3, 2048)   0           activation_57[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 18432)        0           avg_pool[0][0]                   
__________________________________________________________________________________________________
fc1000 (Dense)                  (None, 100)          1843300     flatten_1[0][0]                  
==================================================================================================
Total params: 25,461,700
Trainable params: 25,407,812
Non-trainable params: 53,888
__________________________________________________________________________________________________
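
The parameter counts in the summary above can be checked by hand. The following is a minimal sketch, assuming the usual Keras counting rules: a Conv2D layer has kernel height x kernel width x input channels x filters weights plus one bias per filter, BatchNormalization stores four values per channel, and Dense has inputs x units weights plus one bias per unit. The helper names are illustrative and are not part of Keras.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance: four values per channel."""
    return 4 * channels


def dense_params(inputs, units):
    """One weight per (input, unit) pair plus one bias per unit."""
    return inputs * units + units


# Each value can be compared with the matching row of the summary.
print(conv2d_params(1, 1, 1024, 256))   # res5c_branch2a -> 262400
print(conv2d_params(3, 3, 256, 256))    # res5c_branch2b -> 590080
print(conv2d_params(1, 1, 256, 1024))   # res5c_branch2c -> 263168
print(batchnorm_params(1024))           # bn5c_branch2c  -> 4096
print(dense_params(3 * 3 * 2048, 100))  # fc1000         -> 1843300
```

The last line shows why the final Dense layer is comparatively large: flattening the 3 x 3 x 2048 feature map produces 18,432 inputs for the 100-class classifier.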

The next step is to train the model.

crn50 = custom_resnet50_model.fit(x=x_train, y=y_train, batch_size=32, 
        epochs=10, verbose=1, validation_data=(x_test, y_test), shuffle=True)

Train on 50000 samples, validate on 10000 samples
Epoch 1/10
 50000/50000 [==============================] - 441s 9ms/step - loss: 4.5655 - acc: 0.0817 
  - mean_squared_error: 0.0101 - val_loss: 4.2085 - val_acc: 0.1228 - val_mean_squared_error: 0.0099
Epoch 2/10
 50000/50000 [==============================] - 434s 9ms/step - loss: 4.1448 - acc: 0.1348 
  - mean_squared_error: 0.0098 - val_loss: 4.2032 - val_acc: 0.1236 - val_mean_squared_error: 0.0099
Epoch 3/10
 50000/50000 [==============================] - 433s 9ms/step - loss: 4.2682 - acc: 0.1146 
  - mean_squared_error: 0.0099 - val_loss: 4.3306 - val_acc: 0.1066 - val_mean_squared_error: 0.0100
Epoch 4/10
 50000/50000 [==============================] - 434s 9ms/step - loss: 4.1581 - acc: 0.1340 
  - mean_squared_error: 0.0098 - val_loss: 4.1405 - val_acc: 0.1384 - val_mean_squared_error: 0.0098
Epoch 5/10
 50000/50000 [==============================] - 431s 9ms/step - loss: 3.9395 - acc: 0.1653 
  - mean_squared_error: 0.0096 - val_loss: 3.8838 - val_acc: 0.1718 - val_mean_squared_error: 0.0095
Epoch 6/10
 50000/50000 [==============================] - 432s 9ms/step - loss: 3.9598 - acc: 0.1698 
  - mean_squared_error: 0.0096 - val_loss: 4.0047 - val_acc: 0.1608 - val_mean_squared_error: 0.0096
Epoch 7/10
 50000/50000 [==============================] - 433s 9ms/step - loss: 3.8715 - acc: 0.1797 
  - mean_squared_error: 0.0095 - val_loss: 4.2620 - val_acc: 0.1184 - val_mean_squared_error: 0.0099
Epoch 8/10
 50000/50000 [==============================] - 434s 9ms/step - loss: 3.9661 - acc: 0.1666 
  - mean_squared_error: 0.0096 - val_loss: 3.8181 - val_acc: 0.1898 - val_mean_squared_error: 0.0095
Epoch 9/10
 50000/50000 [==============================] - 434s 9ms/step - loss: 3.8110 - acc: 0.1901 
  - mean_squared_error: 0.0095 - val_loss: 3.7521 - val_acc: 0.1966 - val_mean_squared_error: 0.0094
Epoch 10/10
 50000/50000 [==============================] - 432s 9ms/step - loss: 3.7247 - acc: 0.2048 
  - mean_squared_error: 0.0094 - val_loss: 3.8206 - val_acc: 0.1929 - val_mean_squared_error: 0.0095

Let's look at the training and validation metrics graphically (using the matplotlib library, of course).

plt.figure(0)
plt.plot(crn50.history['acc'],'r')
plt.plot(crn50.history['val_acc'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Training Accuracy vs Validation Accuracy")
plt.legend(['train','validation'])
 
plt.figure(1)
plt.plot(crn50.history['loss'],'r')
plt.plot(crn50.history['val_loss'],'g')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Training Loss vs Validation Loss")
plt.legend(['train','validation'])
 
plt.show()

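Training gives acceptable results and the model generalizes well: after epoch 10, the gap between training accuracy (0.2048) and validation accuracy (0.1929) is only 0.0119.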
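
The 0.0119 figure is the difference between training and validation accuracy after the last epoch. A minimal sketch of that calculation follows; the lists are copied from the training log above, and in the notebook the same dictionary would presumably come from `crn50.history`.

```python
# Accuracy values copied from the 10-epoch training log.
history = {
    "acc":     [0.0817, 0.1348, 0.1146, 0.1340, 0.1653,
                0.1698, 0.1797, 0.1666, 0.1901, 0.2048],
    "val_acc": [0.1228, 0.1236, 0.1066, 0.1384, 0.1718,
                0.1608, 0.1184, 0.1898, 0.1966, 0.1929],
}

# Per-epoch gap between training and validation accuracy.
gaps = [round(train - val, 4)
        for train, val in zip(history["acc"], history["val_acc"])]

final_gap = gaps[-1]
print(final_gap)  # 0.0119
```

A small positive gap suggests little overfitting so far. Looking at the whole `gaps` list rather than only the last value also shows how much the gap fluctuates from epoch to epoch.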

Confusion Matrix

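Once we have trained our model, we want to look at other metrics before drawing any conclusion about the usability of the model we have created. To do this, we will build the confusion matrix and then look at the precision, recall, and F1 score metrics (see Wikipedia).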

To build the confusion matrix, we first need to make predictions on the test set. Then we can create the matrix and display those metrics.

crn50_pred = custom_resnet50_model.predict(x_test, batch_size=32, verbose=1)
crn50_predicted = np.argmax(crn50_pred, axis=1)

crn50_cm = confusion_matrix(np.argmax(y_test, axis=1), crn50_predicted)

# Visualize the confusion matrix as a heatmap
crn50_df_cm = pd.DataFrame(crn50_cm, range(100), range(100))
plt.figure(figsize = (20,14))
sn.set(font_scale=1.4) #for label size
sn.heatmap(crn50_df_cm, annot=True, annot_kws={"size": 12}) # font size
plt.show()

The next step is to show the metrics.

crn50_report = classification_report(np.argmax(y_test, axis=1), crn50_predicted)
print(crn50_report)

             precision    recall  f1-score   support

          0       0.46      0.32      0.38       100
          1       0.25      0.17      0.20       100
          2       0.17      0.09      0.12       100
          3       0.05      0.62      0.09       100
          4       0.18      0.06      0.09       100
          5       0.25      0.05      0.08       100
          6       0.11      0.14      0.12       100
          7       0.15      0.12      0.13       100
          8       0.21      0.20      0.20       100
          9       0.49      0.21      0.29       100
         10       0.11      0.03      0.05       100
         11       0.08      0.05      0.06       100
         12       0.38      0.13      0.19       100
         13       0.23      0.10      0.14       100
         14       0.18      0.05      0.08       100
         15       0.14      0.06      0.08       100
         16       0.19      0.24      0.21       100
         17       0.40      0.19      0.26       100
         18       0.19      0.24      0.21       100
         19       0.20      0.22      0.21       100
         20       0.42      0.31      0.36       100
         21       0.31      0.23      0.26       100
         22       0.35      0.09      0.14       100
         23       0.36      0.37      0.37       100
         24       0.31      0.49      0.38       100
         25       0.17      0.03      0.05       100
         26       0.43      0.06      0.11       100
         27       0.11      0.03      0.05       100
         28       0.31      0.35      0.33       100
         29       0.12      0.10      0.11       100
         30       0.27      0.33      0.30       100
         31       0.11      0.09      0.10       100
         32       0.22      0.20      0.21       100
         33       0.23      0.30      0.26       100
         34       0.17      0.05      0.08       100
         35       0.09      0.02      0.03       100
         36       0.10      0.23      0.14       100
         37       0.15      0.16      0.16       100
         38       0.08      0.24      0.12       100
         39       0.23      0.18      0.20       100
         40       0.26      0.20      0.22       100
         41       0.45      0.49      0.47       100
         42       0.12      0.17      0.14       100
         43       0.11      0.02      0.03       100
         44       0.14      0.09      0.11       100
         45       0.08      0.01      0.02       100
         46       0.07      0.29      0.12       100
         47       0.55      0.18      0.27       100
         48       0.23      0.31      0.26       100
         49       0.27      0.23      0.25       100
         50       0.12      0.05      0.07       100
         51       0.28      0.09      0.14       100
         52       0.47      0.62      0.54       100
         53       0.25      0.13      0.17       100
         54       0.18      0.25      0.21       100
         55       0.00      0.00      0.00       100
         56       0.27      0.27      0.27       100
         57       0.27      0.11      0.16       100
         58       0.15      0.41      0.22       100
         59       0.18      0.10      0.13       100
         60       0.41      0.63      0.50       100
         61       0.33      0.32      0.32       100
         62       0.15      0.07      0.09       100
         63       0.31      0.26      0.28       100
         64       0.11      0.11      0.11       100
         65       0.15      0.11      0.13       100
         66       0.10      0.06      0.08       100
         67       0.15      0.15      0.15       100
         68       0.37      0.66      0.47       100
         69       0.38      0.25      0.30       100
         70       0.21      0.04      0.07       100
         71       0.27      0.54      0.36       100
         72       0.20      0.01      0.02       100
         73       0.30      0.21      0.25       100
         74       0.14      0.15      0.14       100
         75       0.30      0.29      0.29       100
         76       0.40      0.40      0.40       100
         77       0.13      0.14      0.13       100
         78       0.15      0.08      0.10       100
         79       0.14      0.05      0.07       100
         80       0.08      0.05      0.06       100
         81       0.14      0.11      0.12       100
         82       0.37      0.24      0.29       100
         83       0.08      0.02      0.03       100
         84       0.10      0.11      0.10       100
         85       0.23      0.39      0.29       100
         86       0.36      0.21      0.26       100
         87       0.21      0.19      0.20       100
         88       0.05      0.06      0.05       100
         89       0.24      0.18      0.20       100
         90       0.21      0.24      0.22       100
         91       0.33      0.31      0.32       100
         92       0.11      0.11      0.11       100
         93       0.16      0.10      0.12       100
         94       0.38      0.26      0.31       100
         95       0.21      0.50      0.30       100
         96       0.22      0.23      0.22       100
         97       0.10      0.18      0.13       100
         98       0.12      0.02      0.03       100
         99       0.24      0.08      0.12       100

avg / total       0.22      0.19      0.19     10000
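
The columns of this report can be derived directly from a confusion matrix. The sketch below uses a made-up 3-class matrix (not the CIFAR-100 results) and assumes the scikit-learn convention that rows are true classes and columns are predicted classes.

```python
import numpy as np

# Toy confusion matrix with invented counts: rows = true class, columns = predicted class.
cm = np.array([[5, 1, 0],
               [2, 3, 1],
               [0, 2, 6]])

true_positives = np.diag(cm)

# Precision: correct predictions / everything predicted as that class (column sums).
precision = true_positives / cm.sum(axis=0)

# Recall: correct predictions / everything that truly belongs to that class (row sums).
recall = true_positives / cm.sum(axis=1)

# F1 score: harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)

# Support: number of true samples per class.
support = cm.sum(axis=1)

for label in range(len(cm)):
    print(label, round(precision[label], 2), round(recall[label], 2),
          round(f1[label], 2), support[label])
```

This toy matrix has no empty rows or columns, so no division by zero occurs. A class with no correct predictions, such as class 55 in the report above, needs that case handled explicitly.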

ROC Curve

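The ROC curve is used with binary classifiers because it is a good tool for viewing the true positive rate against the false positive rate. The following lines show the code for a multiclass ROC curve. This code is from DloLogy, but you can also go to the Scikit Learn documentation page.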

from sklearn.datasets import make_classification
from sklearn.preprocessing import label_binarize
from numpy import interp  # scipy.interp is deprecated in newer SciPy; numpy.interp takes the same arguments
from itertools import cycle

n_classes = 100

from sklearn.metrics import roc_curve, auc

# Plot linewidth.
lw = 2

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], crn50_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), crn50_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at these points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
plt.figure(1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes-97), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()


# Zoom in view of the upper left corner.
plt.figure(2)
plt.xlim(0, 0.2)
plt.ylim(0.8, 1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(10), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
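
To make the per-class step in the code above more concrete, here is a minimal NumPy-only sketch of what a ROC curve and its area are for a single class treated as one-vs-rest. It is a simplified illustration, not scikit-learn's implementation; `roc_points` and `trapezoid_auc` are hypothetical helper names and the four samples are toy data.

```python
import numpy as np


def roc_points(y_true, scores):
    """Return (fpr, tpr), one point per distinct threshold, highest threshold first."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    positives = y_true.sum()
    negatives = len(y_true) - positives
    fpr, tpr = [0.0], [0.0]  # the curve starts at the origin
    for threshold in np.unique(scores)[::-1]:
        predicted = scores >= threshold
        tpr.append((predicted & (y_true == 1)).sum() / positives)
        fpr.append((predicted & (y_true == 0)).sum() / negatives)
    return np.array(fpr), np.array(tpr)


def trapezoid_auc(fpr, tpr):
    """Area under the curve using the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))


# Toy data: 1 marks samples of the class of interest, 0 marks every other class.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

fpr, tpr = roc_points(y_true, scores)
print(fpr)                      # [0.  0.  0.5 0.5 1. ]
print(tpr)                      # [0.  0.5 0.5 1.  1. ]
print(trapezoid_auc(fpr, tpr))  # 0.75
```

In the article's code, `y_test[:, i]` plays the role of `y_true` and `crn50_pred[:, i]` the role of `scores` for class `i`. The micro-average flattens all classes into one such binary problem, and the macro-average interpolates the per-class curves onto a common set of false positive rates and averages them.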

Then, we will save the model and the training history so they can be compared with future results.

#Model
custom_resnet50_model.save(path_base + '/crn50.h5')

#Historical results
with open(path_base + '/crn50_history.txt', 'wb') as file_pi:
  pickle.dump(crn50.history, file_pi)
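
For the later comparison, the saved history has to be read back. The following sketch shows the round trip with `pickle`; the dictionary is a small stand-in for `crn50.history`, and the temporary directory is a hypothetical replacement for the article's `path_base`.

```python
import os
import pickle
import tempfile

# Stand-in for crn50.history (first two epochs of the log above).
history = {"val_acc": [0.1228, 0.1236], "val_loss": [4.2085, 4.2032]}

# Hypothetical location; the article uses its own path_base.
path_base = tempfile.mkdtemp()
history_path = os.path.join(path_base, "crn50_history.txt")

# Save the history in binary mode, as in the article.
with open(history_path, "wb") as file_pi:
    pickle.dump(history, file_pi)

# Read it back later, also in binary mode.
with open(history_path, "rb") as file_pi:
    restored = pickle.load(file_pi)

print(restored["val_acc"])
```

Note that the restored object is a plain dictionary, so it is indexed directly (`restored['val_acc']`) rather than through a `.history` attribute. The saved model file can be reloaded with `keras.models.load_model`.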

Model Comparison

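The next step is to compare these results with the metrics of the previous experiments. We will compare the accuracy, loss, and mean squared error of the models seen in the previous chapters, including several VGG models trained with the same parameters.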

plt.figure(0)
plt.plot(snn.history['val_acc'],'r')
plt.plot(scnn.history['val_acc'],'g')
plt.plot(vgg16.history['val_acc'],'b')
plt.plot(vgg19.history['val_acc'],'y')
plt.plot(vgg16Bis.history['val_acc'],'m')
plt.plot(crn50.history['val_acc'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy by Model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])

# Separate figure number so the loss curves are not drawn on top of the accuracy plot
plt.figure(1)
plt.plot(snn.history['val_loss'],'r')
plt.plot(scnn.history['val_loss'],'g')
plt.plot(vgg16.history['val_loss'],'b')
plt.plot(vgg19.history['val_loss'],'y')
plt.plot(vgg16Bis.history['val_loss'],'m')
plt.plot(crn50.history['val_loss'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Loss")
plt.title("Validation Loss by Model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])

# Separate figure number for the mean squared error curves
plt.figure(2)
plt.plot(snn.history['val_mean_squared_error'],'r')
plt.plot(scnn.history['val_mean_squared_error'],'g')
plt.plot(vgg16.history['val_mean_squared_error'],'b')
plt.plot(vgg19.history['val_mean_squared_error'],'y')
plt.plot(vgg16Bis.history['val_mean_squared_error'],'m')
plt.plot(crn50.history['val_mean_squared_error'],'gold')
plt.xticks(np.arange(0, 11, 2.0))
plt.rcParams['figure.figsize'] = (8, 6)
plt.xlabel("Num of Epochs")
plt.ylabel("Mean Squared Error")
plt.title("Validation MSE by Model")
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])

plt.show()

Conclusion

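As you can see, this architecture marks a turning point. It achieves the best results of the architectures tested so far. It also does well on training time, because it allows the number of layers to be increased while training in an acceptable time. Finally, its number of parameters is considerably lower than that of the VGG architectures.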

In the next article, we will show DenseNet.
