• Animal recognition with Xception (TensorFlow)


    1. Project data and source code

    Available for download on GitHub:

    https://github.com/chenshunpeng/Animal-recognition-based-on-xception

    2. Task overview

    The data is organized as follows:

    data
    ├── cat (folder with 1000 images)
    │
    ├── chook (folder with 1000 images)
    │
    ├── dog (folder with 1000 images)
    │
    └── horse (folder with 1000 images)
    
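    The data must be split into a training set (train) and a validation set (val). The model is trained on train so that, given an image from val, it can identify which of the four classes (cat, chook, dog, horse) the image belongs to.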

    3. Data processing

    3.1. Data preprocessing

    Set up the GPU environment for training:

    import tensorflow as tf
    
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)  # allocate GPU memory on demand
        tf.config.set_visible_devices([gpus[0]],"GPU")
    
    # Print the device list to confirm that a GPU is available
    print(gpus)
    
    Output:

    [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
    
    Load the image data:

    import matplotlib.pyplot as plt
    # Font settings so that Chinese text renders correctly in figures
    plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
    plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly
    
    import os,PIL
    
    # Set random seeds so the results are as reproducible as possible
    import numpy as np
    np.random.seed(1)
    
    import tensorflow as tf
    tf.random.set_seed(1)
    
    import pathlib
    
    data_dir = "./data"
    
    data_dir = pathlib.Path(data_dir)
    
    image_count = len(list(data_dir.glob('*/*')))
    
    print("Total number of images:",image_count)

    Output:

     Total number of images: 4000
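
The total alone does not show whether the classes are balanced. As a minimal sketch (the helper name `count_images_per_class` and the demo file names are assumptions for illustration, not part of the original project), the same `glob('*/*')` pattern can be grouped by parent folder:

```python
import pathlib
import tempfile
from collections import Counter


def count_images_per_class(data_dir):
    """Count files in each class folder of a one-folder-per-class dataset."""
    counts = Counter(
        path.parent.name
        for path in pathlib.Path(data_dir).glob("*/*")
        if path.is_file()
    )
    return dict(sorted(counts.items()))


# Demo on a throwaway directory (hypothetical class folders, empty files).
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_files in [("cat", 3), ("dog", 2)]:
        folder = pathlib.Path(tmp) / class_name
        folder.mkdir()
        for i in range(n_files):
            (folder / f"{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)  # {'cat': 3, 'dog': 2}
```

If the dataset matches the directory tree shown earlier, calling `count_images_per_class("./data")` should report 1000 files for each of the four folders.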
    
    Next, initialize the parameters and use image_dataset_from_directory to load the data from disk into a tf.data.Dataset.

    Function signature:

    tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        labels="inferred",
        label_mode="int",
        class_names=None,
        color_mode="rgb",
        batch_size=32,
        image_size=(256, 256),
        shuffle=True,
        seed=None,
        validation_split=None,
        subset=None,
        interpolation="bilinear",
        follow_links=False,
    )
    
    Official documentation: tf.keras.utils.image_dataset_from_directory

    Code:

    batch_size = 4
    img_height = 299
    img_width  = 299
    
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    Output:

    Found 4000 files belonging to 4 classes.
    Using 3200 files for training.
    
    Configure the validation set in the same way:

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    Output:

    Found 4000 files belonging to 4 classes.
    Using 800 files for validation.
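
The 3200 / 800 figures follow directly from `validation_split=0.2`. The sketch below reproduces the arithmetic only; it assumes the validation count is the truncated product of the split fraction and the file count, and `split_counts` is a hypothetical helper, not a Keras function. Both calls must use the same `seed` and `validation_split` so that the two subsets do not overlap.

```python
def split_counts(num_files, validation_split):
    """Return (train_count, val_count) for a fractional validation split."""
    num_val = int(validation_split * num_files)  # assumed: truncate toward zero
    return num_files - num_val, num_val


train_count, val_count = split_counts(4000, 0.2)
print(train_count, val_count)  # 3200 800
```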
    
    We can print the dataset's labels through class_names; the labels correspond to the directory names in alphabetical order.

    class_names = train_ds.class_names
    print(class_names)
    
    Output:

    ['cat', 'chook', 'dog', 'horse']
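
Because labels are inferred from sorted folder names, each integer label is simply the folder's position in that sorted list. A small illustration in plain Python (the on-disk order in `folder_names` is made up for the example):

```python
# Folder names in an arbitrary order, as a directory listing might return them.
folder_names = ["horse", "dog", "chook", "cat"]

# Sorting gives the class order; the index in this list is the integer label.
sorted_names = sorted(folder_names)
class_to_index = {name: index for index, name in enumerate(sorted_names)}

print(sorted_names)    # ['cat', 'chook', 'dog', 'horse']
print(class_to_index)  # {'cat': 0, 'chook': 1, 'dog': 2, 'horse': 3}
```

This mapping is what lets `class_names[label]` turn an integer label back into a readable name.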
    
    Check the shape of one batch:

    for image_batch, labels_batch in train_ds:
        print(image_batch.shape)
        print(labels_batch.shape)
        break
    
    Output:

    (4, 299, 299, 3)
    (4,)
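
The batch shape also determines how many steps one epoch takes and roughly how much memory a batch needs. The calculation below is a back-of-the-envelope sketch; it assumes the images are delivered as float32 (4 bytes per value), which is an assumption worth checking with `image_batch.dtype`.

```python
import math

batch_size = 4
img_height, img_width, channels = 299, 299, 3

# Number of batches per epoch for the 3200 / 800 split shown above.
train_batches = math.ceil(3200 / batch_size)
val_batches = math.ceil(800 / batch_size)

# Values in one image batch, and its size assuming float32 storage.
values_per_batch = batch_size * img_height * img_width * channels
bytes_per_batch = values_per_batch * 4

print(train_batches, val_batches)  # 800 200
print(values_per_batch)            # 1072812
print(bytes_per_batch)             # 4291248
```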
    
    3.2. Visualizing the data

    plt.figure(figsize=(10, 5))  # figure is 10 wide and 5 high
    plt.suptitle("Sample images")
    
    num = -1
    for images, labels in train_ds.take(2):
        for i in range(4):
            num = num + 1
            ax = plt.subplot(2, 4, num + 1)  
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
    
    plt.savefig('pic1.jpg', dpi=600)  # save once, after all subplots are drawn, at the given resolution
    
    Output:

    [Figure: a 2×4 grid of sample training images, each titled with its class name]

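    3.3. Configuring the dataset

    shuffle(): shuffles the data; for details see: Understanding buffer_size in the dataset shuffle method

    prefetch(): prefetches data to speed up execution; for details see: Better performance with the tf.data API

    cache(): caches the dataset in memory to speed up execution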

    AUTOTUNE = tf.data.AUTOTUNE
    
    train_ds = (
        train_ds.cache()
        .shuffle(1000)
    #     .map(train_preprocessing)    # a preprocessing function can be set here
    #     .batch(batch_size)           # batch_size was already set in image_dataset_from_directory
        .prefetch(buffer_size=AUTOTUNE)
    )
    
    val_ds = (
        val_ds.cache()
        .shuffle(1000)
    #     .map(val_preprocessing)    # a preprocessing function can be set here
    #     .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
        .prefetch(buffer_size=AUTOTUNE)
    )
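
To make the role of the `shuffle` buffer size concrete, here is a pure-Python simulation of a buffered shuffle. It is a simplified model of the idea, not TensorFlow's implementation, and `buffered_shuffle` is a hypothetical name. The buffer is filled first; after that, each incoming element evicts a randomly chosen element from the buffer.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in an order randomized through a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)      # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]          # emit a random buffered element
        buffer[index] = item         # the new element takes its slot
    rng.shuffle(buffer)              # input exhausted: drain what is left
    yield from buffer


unchanged = list(buffered_shuffle(range(10), buffer_size=1))
mixed = list(buffered_shuffle(range(10), buffer_size=4, seed=0))

print(unchanged)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: a buffer of 1 cannot reorder
print(sorted(mixed) == list(range(10)))  # True: every element appears exactly once
```

With a small buffer the first emitted element can only come from the first `buffer_size` inputs, so the shuffle is only local. Note also that in the pipeline above the dataset is already batched, so `shuffle(1000)` reorders whole batches of 4 images rather than individual images.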
    
    4. Network design

    4.1. A brief introduction to Xception

    For details, see: Zhihu

    Paper: Xception: Deep Learning with Depthwise Separable Convolutions

    Reference code: https://github.com/keras-team/keras-applications/blob/master/keras_applications/xception.py

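    Xception was proposed by Google in October 2016 (the MobileNet v1 and MobileNet v2 papers followed in 2017 and 2018). It combines design ideas also found in ResNet, Inception, and MobileNet v1: taking Inception v3 as the template, it replaces the convolutions in the basic Inception modules with depthwise separable convolutions and adds residual connections.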
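
The saving from a depthwise separable convolution can be checked with simple parameter counting. The sketch below assumes no bias terms (the model code in this article uses `use_bias=False`); the helper names are chosen for illustration. A standard k×k convolution needs k·k·C_in·C_out weights, while the separable version needs k·k·C_in weights for the depthwise step plus C_in·C_out for the 1×1 pointwise step.

```python
def conv_params(kernel, c_in, c_out):
    """Weights in a standard kernel x kernel convolution without bias."""
    return kernel * kernel * c_in * c_out


def separable_conv_params(kernel, c_in, c_out):
    """Weights in a depthwise separable convolution without bias."""
    depthwise = kernel * kernel * c_in   # one kernel x kernel filter per input channel
    pointwise = c_in * c_out             # 1x1 convolution that mixes channels
    return depthwise + pointwise


standard = conv_params(3, 728, 728)
separable = separable_conv_params(3, 728, 728)
first_sepconv = separable_conv_params(3, 64, 128)

print(standard)       # 4769856
print(separable)      # 536536
print(first_sepconv)  # 8768
```

The values 536536 and 8768 agree with the parameter counts that `model.summary()` reports later in this article for the 728-channel separable layers and for `block2_sepconv1`.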

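    Xception's structure builds on ResNet, and the whole network is divided into three parts: Entry, Middle, and Exit

    • The Entry part mainly performs repeated downsampling to reduce the spatial dimensions
    • The Middle part repeatedly learns correlations and refines the features; it consists of 8 repeated blocks. Every regular convolution and separable convolution is followed by BN, although the figure does not show this
    • Finally, the Exit part aggregates and organizes the features, which are then passed to the FC layer for the final representation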

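    The overall flow of the network is shown in the figure below. The Xception architecture uses 36 convolutional layers as its feature-extraction base. These 36 layers are organized into 14 modules, all of which have residual connections around them except the first and the last.

    [Figure: Xception architecture diagram (Entry flow, Middle flow, Exit flow)]

    In short, the Xception architecture is a linear stack of depthwise separable convolution layers with residual connections. This makes the architecture easy to modify; it can be defined in only 30-40 lines of code.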
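
The figures of 36 convolutional layers and 14 modules can be tallied from the block layout used in the implementation that follows. This sketch counts only the convolutions on the main path and assumes the 1×1 shortcut convolutions are not included in the 36; the dictionary is written for illustration.

```python
# Main-path convolutions per block (1x1 shortcut convolutions not counted).
convs_per_block = {"block1": 2}                                  # two regular convolutions
convs_per_block.update({f"block{i}": 2 for i in range(2, 5)})    # entry flow, separable
convs_per_block.update({f"block{i}": 3 for i in range(5, 13)})   # middle flow, separable
convs_per_block.update({"block13": 2, "block14": 2})             # exit flow, separable

total_modules = len(convs_per_block)
total_convs = sum(convs_per_block.values())
residual_modules = total_modules - 2   # all except the first and the last

print(total_modules)     # 14
print(total_convs)       # 36
print(residual_modules)  # 12
```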

    4.2. Designing the network model

    #====================================#
    #     Xception network definition
    #====================================#
    from tensorflow.keras.preprocessing import image
    
    from tensorflow.keras.models import Model
    from tensorflow.keras import layers
    from tensorflow.keras.layers import Dense,Input,BatchNormalization,Activation,Conv2D,SeparableConv2D,MaxPooling2D
    from tensorflow.keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D
    from tensorflow.keras import backend as K
    from tensorflow.keras.applications.imagenet_utils import decode_predictions
    
    
    def Xception(input_shape = [299,299,3],classes=1000):
    
        img_input = Input(shape=input_shape)
        
        #=================#
        #   Entry flow
        #=================#
        #  block1
        # 299,299,3 -> 147,147,64
        x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
        x = BatchNormalization(name='block1_conv1_bn')(x)
        x = Activation('relu', name='block1_conv1_act')(x)
        x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
        x = BatchNormalization(name='block1_conv2_bn')(x)
        x = Activation('relu', name='block1_conv2_act')(x)
    
    
        # block2
        # 147,147,64 -> 74,74,128
        residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
        residual = BatchNormalization()(residual)
    
        x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
        x = BatchNormalization(name='block2_sepconv1_bn')(x)
        x = Activation('relu', name='block2_sepconv2_act')(x)
        x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
        x = BatchNormalization(name='block2_sepconv2_bn')(x)
    
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
        x = layers.add([x, residual])
    
        # block3
        # 74,74,128 -> 37,37,256
        residual = Conv2D(256, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
        residual = BatchNormalization()(residual)
    
        x = Activation('relu', name='block3_sepconv1_act')(x)
        x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
        x = BatchNormalization(name='block3_sepconv1_bn')(x)
        x = Activation('relu', name='block3_sepconv2_act')(x)
        x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
        x = BatchNormalization(name='block3_sepconv2_bn')(x)
    
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
        x = layers.add([x, residual])
    
        # block4
        # 37,37,256 -> 19,19,728
        residual = Conv2D(728, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
        residual = BatchNormalization()(residual)
    
        x = Activation('relu', name='block4_sepconv1_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
        x = BatchNormalization(name='block4_sepconv1_bn')(x)
        x = Activation('relu', name='block4_sepconv2_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
        x = BatchNormalization(name='block4_sepconv2_bn')(x)
    
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
        x = layers.add([x, residual])
    
        #=================#
        # Middle flow
        #=================#
        # block5--block12
        # 19,19,728 -> 19,19,728
        for i in range(8):
            residual = x
            prefix = 'block' + str(i + 5)
    
            x = Activation('relu', name=prefix + '_sepconv1_act')(x)
            x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
            x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
            x = Activation('relu', name=prefix + '_sepconv2_act')(x)
            x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
            x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
            x = Activation('relu', name=prefix + '_sepconv3_act')(x)
            x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
            x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
    
            x = layers.add([x, residual])
    
        #=================#
        #    Exit flow
        #=================#
        # block13
        # 19,19,728 -> 10,10,1024
        residual = Conv2D(1024, (1, 1), strides=(2, 2),
                          padding='same', use_bias=False)(x)
        residual = BatchNormalization()(residual)
    
        x = Activation('relu', name='block13_sepconv1_act')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
        x = BatchNormalization(name='block13_sepconv1_bn')(x)
        x = Activation('relu', name='block13_sepconv2_act')(x)
        x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
        x = BatchNormalization(name='block13_sepconv2_bn')(x)
    
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
        x = layers.add([x, residual])
    
        # block14
        # 10,10,1024 -> 10,10,2048
        x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
        x = BatchNormalization(name='block14_sepconv1_bn')(x)
        x = Activation('relu', name='block14_sepconv1_act')(x)
    
        x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
        x = BatchNormalization(name='block14_sepconv2_bn')(x)
        x = Activation('relu', name='block14_sepconv2_act')(x)
    
        x = GlobalAveragePooling2D(name='avg_pool')(x)
        x = Dense(classes, activation='softmax', name='predictions')(x)
    
        inputs = img_input
    
        model = Model(inputs, x, name='xception')
    
        return model
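
The spatial sizes that `model.summary()` reports can be reproduced by hand. The sketch below uses the usual output-size rules: with 'valid' padding the output is floor((n - k) / s) + 1, and with 'same' padding it is ceil(n / s). The helper name `conv_out` is chosen for illustration.

```python
import math


def conv_out(size, kernel, stride=1, padding="valid"):
    """Output length along one spatial axis of a convolution or pooling layer."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1


trace = [299]
trace.append(conv_out(trace[-1], 3, stride=2))   # block1_conv1, 'valid', stride 2
trace.append(conv_out(trace[-1], 3))             # block1_conv2, 'valid', stride 1
for _ in range(4):                               # pools in block2, 3, 4 and 13
    trace.append(conv_out(trace[-1], 3, stride=2, padding="same"))

print(trace)  # [299, 149, 147, 74, 37, 19, 10]
```

The separable convolutions use 'same' padding with stride 1, so they leave the spatial size unchanged; only the two block1 convolutions and the four stride-2 pooling layers shrink it.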
    
    Print the model summary:

    model = Xception()
    # Print the model summary
    model.summary()
    
    Output:

    Model: "xception"
    __________________________________________________________________________________________________
     Layer (type)                   Output Shape         Param #     Connected to                     
    ==================================================================================================
     input_1 (InputLayer)           [(None, 299, 299, 3  0           []                               
                                    )]                                                                
                                                                                                      
     block1_conv1 (Conv2D)          (None, 149, 149, 32  864         ['input_1[0][0]']                
                                    )                                                                 
                                                                                                      
     block1_conv1_bn (BatchNormaliz  (None, 149, 149, 32  128        ['block1_conv1[0][0]']           
     ation)                         )                                                                 
                                                                                                      
     block1_conv1_act (Activation)  (None, 149, 149, 32  0           ['block1_conv1_bn[0][0]']        
                                    )                                                                 
                                                                                                      
     block1_conv2 (Conv2D)          (None, 147, 147, 64  18432       ['block1_conv1_act[0][0]']       
                                    )                                                                 
                                                                                                      
     block1_conv2_bn (BatchNormaliz  (None, 147, 147, 64  256        ['block1_conv2[0][0]']           
     ation)                         )                                                                 
                                                                                                      
     block1_conv2_act (Activation)  (None, 147, 147, 64  0           ['block1_conv2_bn[0][0]']        
                                    )                                                                 
                                                                                                      
     block2_sepconv1 (SeparableConv  (None, 147, 147, 12  8768       ['block1_conv2_act[0][0]']       
     2D)                            8)                                                                
                                                                                                      
     block2_sepconv1_bn (BatchNorma  (None, 147, 147, 12  512        ['block2_sepconv1[0][0]']        
     lization)                      8)                                                                
                                                                                                      
     block2_sepconv2_act (Activatio  (None, 147, 147, 12  0          ['block2_sepconv1_bn[0][0]']     
     n)                             8)                                                                
                                                                                                      
     block2_sepconv2 (SeparableConv  (None, 147, 147, 12  17536      ['block2_sepconv2_act[0][0]']    
     2D)                            8)                                                                
                                                                                                      
     block2_sepconv2_bn (BatchNorma  (None, 147, 147, 12  512        ['block2_sepconv2[0][0]']        
     lization)                      8)                                                                
                                                                                                      
     conv2d (Conv2D)                (None, 74, 74, 128)  8192        ['block1_conv2_act[0][0]']       
                                                                                                      
     block2_pool (MaxPooling2D)     (None, 74, 74, 128)  0           ['block2_sepconv2_bn[0][0]']     
                                                                                                      
     batch_normalization (BatchNorm  (None, 74, 74, 128)  512        ['conv2d[0][0]']                 
     alization)                                                                                       
                                                                                                      
     add (Add)                      (None, 74, 74, 128)  0           ['block2_pool[0][0]',            
                                                                      'batch_normalization[0][0]']    
                                                                                                      
     block3_sepconv1_act (Activatio  (None, 74, 74, 128)  0          ['add[0][0]']                    
     n)                                                                                               
                                                                                                      
     block3_sepconv1 (SeparableConv  (None, 74, 74, 256)  33920      ['block3_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block3_sepconv1_bn (BatchNorma  (None, 74, 74, 256)  1024       ['block3_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block3_sepconv2_act (Activatio  (None, 74, 74, 256)  0          ['block3_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block3_sepconv2 (SeparableConv  (None, 74, 74, 256)  67840      ['block3_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block3_sepconv2_bn (BatchNorma  (None, 74, 74, 256)  1024       ['block3_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     conv2d_1 (Conv2D)              (None, 37, 37, 256)  32768       ['add[0][0]']                    
                                                                                                      
     block3_pool (MaxPooling2D)     (None, 37, 37, 256)  0           ['block3_sepconv2_bn[0][0]']     
                                                                                                      
     batch_normalization_1 (BatchNo  (None, 37, 37, 256)  1024       ['conv2d_1[0][0]']               
     rmalization)                                                                                     
                                                                                                      
     add_1 (Add)                    (None, 37, 37, 256)  0           ['block3_pool[0][0]',            
                                                                      'batch_normalization_1[0][0]']  
                                                                                                      
     block4_sepconv1_act (Activatio  (None, 37, 37, 256)  0          ['add_1[0][0]']                  
     n)                                                                                               
                                                                                                      
     block4_sepconv1 (SeparableConv  (None, 37, 37, 728)  188672     ['block4_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block4_sepconv1_bn (BatchNorma  (None, 37, 37, 728)  2912       ['block4_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block4_sepconv2_act (Activatio  (None, 37, 37, 728)  0          ['block4_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block4_sepconv2 (SeparableConv  (None, 37, 37, 728)  536536     ['block4_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block4_sepconv2_bn (BatchNorma  (None, 37, 37, 728)  2912       ['block4_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     conv2d_2 (Conv2D)              (None, 19, 19, 728)  186368      ['add_1[0][0]']                  
                                                                                                      
     block4_pool (MaxPooling2D)     (None, 19, 19, 728)  0           ['block4_sepconv2_bn[0][0]']     
                                                                                                      
     batch_normalization_2 (BatchNo  (None, 19, 19, 728)  2912       ['conv2d_2[0][0]']               
     rmalization)                                                                                     
                                                                                                      
     add_2 (Add)                    (None, 19, 19, 728)  0           ['block4_pool[0][0]',            
                                                                      'batch_normalization_2[0][0]']  
                                                                                                      
     block5_sepconv1_act (Activatio  (None, 19, 19, 728)  0          ['add_2[0][0]']                  
     n)                                                                                               
                                                                                                      
     block5_sepconv1 (SeparableConv  (None, 19, 19, 728)  536536     ['block5_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block5_sepconv1_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block5_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block5_sepconv2_act (Activatio  (None, 19, 19, 728)  0          ['block5_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block5_sepconv2 (SeparableConv  (None, 19, 19, 728)  536536     ['block5_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block5_sepconv2_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block5_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     block5_sepconv3_act (Activatio  (None, 19, 19, 728)  0          ['block5_sepconv2_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block5_sepconv3 (SeparableConv  (None, 19, 19, 728)  536536     ['block5_sepconv3_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block5_sepconv3_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block5_sepconv3[0][0]']        
     lization)                                                                                        
                                                                                                      
     add_3 (Add)                    (None, 19, 19, 728)  0           ['block5_sepconv3_bn[0][0]',     
                                                                      'add_2[0][0]']                  
                                                                                                      
     block6_sepconv1_act (Activatio  (None, 19, 19, 728)  0          ['add_3[0][0]']                  
     n)                                                                                               
                                                                                                      
     block6_sepconv1 (SeparableConv  (None, 19, 19, 728)  536536     ['block6_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block6_sepconv1_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block6_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block6_sepconv2_act (Activatio  (None, 19, 19, 728)  0          ['block6_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block6_sepconv2 (SeparableConv  (None, 19, 19, 728)  536536     ['block6_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block6_sepconv2_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block6_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     block6_sepconv3_act (Activatio  (None, 19, 19, 728)  0          ['block6_sepconv2_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block6_sepconv3 (SeparableConv  (None, 19, 19, 728)  536536     ['block6_sepconv3_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block6_sepconv3_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block6_sepconv3[0][0]']        
     lization)                                                                                        
                                                                                                      
     add_4 (Add)                    (None, 19, 19, 728)  0           ['block6_sepconv3_bn[0][0]',     
                                                                      'add_3[0][0]']                  
                                                                                                      
     block7_sepconv1_act (Activatio  (None, 19, 19, 728)  0          ['add_4[0][0]']                  
     n)                                                                                               
                                                                                                      
     block7_sepconv1 (SeparableConv  (None, 19, 19, 728)  536536     ['block7_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block7_sepconv1_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block7_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block7_sepconv2_act (Activatio  (None, 19, 19, 728)  0          ['block7_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block7_sepconv2 (SeparableConv  (None, 19, 19, 728)  536536     ['block7_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block7_sepconv2_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block7_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     block7_sepconv3_act (Activatio  (None, 19, 19, 728)  0          ['block7_sepconv2_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block7_sepconv3 (SeparableConv  (None, 19, 19, 728)  536536     ['block7_sepconv3_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block7_sepconv3_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block7_sepconv3[0][0]']        
     lization)                                                                                        
                                                                                                      
     add_5 (Add)                    (None, 19, 19, 728)  0           ['block7_sepconv3_bn[0][0]',     
                                                                      'add_4[0][0]']                  
                                                                                                      
     block8_sepconv1_act (Activatio  (None, 19, 19, 728)  0          ['add_5[0][0]']                  
     n)                                                                                               
                                                                                                      
     block8_sepconv1 (SeparableConv  (None, 19, 19, 728)  536536     ['block8_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block8_sepconv1_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block8_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block8_sepconv2_act (Activatio  (None, 19, 19, 728)  0          ['block8_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block8_sepconv2 (SeparableConv  (None, 19, 19, 728)  536536     ['block8_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block8_sepconv2_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block8_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     block8_sepconv3_act (Activatio  (None, 19, 19, 728)  0          ['block8_sepconv2_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block8_sepconv3 (SeparableConv  (None, 19, 19, 728)  536536     ['block8_sepconv3_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block8_sepconv3_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block8_sepconv3[0][0]']        
     lization)                                                                                        
                                                                                                      
     add_6 (Add)                    (None, 19, 19, 728)  0           ['block8_sepconv3_bn[0][0]',     
                                                                      'add_5[0][0]']                  
                                                                                                      
     block9_sepconv1_act (Activatio  (None, 19, 19, 728)  0          ['add_6[0][0]']                  
     n)                                                                                               
                                                                                                      
     block9_sepconv1 (SeparableConv  (None, 19, 19, 728)  536536     ['block9_sepconv1_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block9_sepconv1_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block9_sepconv1[0][0]']        
     lization)                                                                                        
                                                                                                      
     block9_sepconv2_act (Activatio  (None, 19, 19, 728)  0          ['block9_sepconv1_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block9_sepconv2 (SeparableConv  (None, 19, 19, 728)  536536     ['block9_sepconv2_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block9_sepconv2_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block9_sepconv2[0][0]']        
     lization)                                                                                        
                                                                                                      
     block9_sepconv3_act (Activatio  (None, 19, 19, 728)  0          ['block9_sepconv2_bn[0][0]']     
     n)                                                                                               
                                                                                                      
     block9_sepconv3 (SeparableConv  (None, 19, 19, 728)  536536     ['block9_sepconv3_act[0][0]']    
     2D)                                                                                              
                                                                                                      
     block9_sepconv3_bn (BatchNorma  (None, 19, 19, 728)  2912       ['block9_sepconv3[0][0]']        
     lization)                                                                                        
                                                                                                      
     add_7 (Add)                    (None, 19, 19, 728)  0           ['block9_sepconv3_bn[0][0]',     
                                                                      'add_6[0][0]']                  
                                                                                                      
     block10_sepconv1_act (Activati  (None, 19, 19, 728)  0          ['add_7[0][0]']                  
     on)                                                                                              
                                                                                                      
     block10_sepconv1 (SeparableCon  (None, 19, 19, 728)  536536     ['block10_sepconv1_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block10_sepconv1_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block10_sepconv1[0][0]']       
     alization)                                                                                       
                                                                                                      
     block10_sepconv2_act (Activati  (None, 19, 19, 728)  0          ['block10_sepconv1_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block10_sepconv2 (SeparableCon  (None, 19, 19, 728)  536536     ['block10_sepconv2_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block10_sepconv2_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block10_sepconv2[0][0]']       
     alization)                                                                                       
                                                                                                      
     block10_sepconv3_act (Activati  (None, 19, 19, 728)  0          ['block10_sepconv2_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block10_sepconv3 (SeparableCon  (None, 19, 19, 728)  536536     ['block10_sepconv3_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block10_sepconv3_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block10_sepconv3[0][0]']       
     alization)                                                                                       
                                                                                                      
     add_8 (Add)                    (None, 19, 19, 728)  0           ['block10_sepconv3_bn[0][0]',    
                                                                      'add_7[0][0]']                  
                                                                                                      
     block11_sepconv1_act (Activati  (None, 19, 19, 728)  0          ['add_8[0][0]']                  
     on)                                                                                              
                                                                                                      
     block11_sepconv1 (SeparableCon  (None, 19, 19, 728)  536536     ['block11_sepconv1_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block11_sepconv1_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block11_sepconv1[0][0]']       
     alization)                                                                                       
                                                                                                      
     block11_sepconv2_act (Activati  (None, 19, 19, 728)  0          ['block11_sepconv1_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block11_sepconv2 (SeparableCon  (None, 19, 19, 728)  536536     ['block11_sepconv2_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block11_sepconv2_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block11_sepconv2[0][0]']       
     alization)                                                                                       
                                                                                                      
     block11_sepconv3_act (Activati  (None, 19, 19, 728)  0          ['block11_sepconv2_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block11_sepconv3 (SeparableCon  (None, 19, 19, 728)  536536     ['block11_sepconv3_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block11_sepconv3_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block11_sepconv3[0][0]']       
     alization)                                                                                       
                                                                                                      
     add_9 (Add)                    (None, 19, 19, 728)  0           ['block11_sepconv3_bn[0][0]',    
                                                                      'add_8[0][0]']                  
                                                                                                      
     block12_sepconv1_act (Activati  (None, 19, 19, 728)  0          ['add_9[0][0]']                  
     on)                                                                                              
                                                                                                      
     block12_sepconv1 (SeparableCon  (None, 19, 19, 728)  536536     ['block12_sepconv1_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block12_sepconv1_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block12_sepconv1[0][0]']       
     alization)                                                                                       
                                                                                                      
     block12_sepconv2_act (Activati  (None, 19, 19, 728)  0          ['block12_sepconv1_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block12_sepconv2 (SeparableCon  (None, 19, 19, 728)  536536     ['block12_sepconv2_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block12_sepconv2_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block12_sepconv2[0][0]']       
     alization)                                                                                       
                                                                                                      
     block12_sepconv3_act (Activati  (None, 19, 19, 728)  0          ['block12_sepconv2_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block12_sepconv3 (SeparableCon  (None, 19, 19, 728)  536536     ['block12_sepconv3_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block12_sepconv3_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block12_sepconv3[0][0]']       
     alization)                                                                                       
                                                                                                      
     add_10 (Add)                   (None, 19, 19, 728)  0           ['block12_sepconv3_bn[0][0]',    
                                                                      'add_9[0][0]']                  
                                                                                                      
     block13_sepconv1_act (Activati  (None, 19, 19, 728)  0          ['add_10[0][0]']                 
     on)                                                                                              
                                                                                                      
     block13_sepconv1 (SeparableCon  (None, 19, 19, 728)  536536     ['block13_sepconv1_act[0][0]']   
     v2D)                                                                                             
                                                                                                      
     block13_sepconv1_bn (BatchNorm  (None, 19, 19, 728)  2912       ['block13_sepconv1[0][0]']       
     alization)                                                                                       
                                                                                                      
     block13_sepconv2_act (Activati  (None, 19, 19, 728)  0          ['block13_sepconv1_bn[0][0]']    
     on)                                                                                              
                                                                                                      
     block13_sepconv2 (SeparableCon  (None, 19, 19, 1024  752024     ['block13_sepconv2_act[0][0]']   
     v2D)                           )                                                                 
                                                                                                      
     block13_sepconv2_bn (BatchNorm  (None, 19, 19, 1024  4096       ['block13_sepconv2[0][0]']       
     alization)                     )                                                                 
                                                                                                      
     conv2d_3 (Conv2D)              (None, 10, 10, 1024  745472      ['add_10[0][0]']                 
                                    )                                                                 
                                                                                                      
     block13_pool (MaxPooling2D)    (None, 10, 10, 1024  0           ['block13_sepconv2_bn[0][0]']    
                                    )                                                                 
                                                                                                      
     batch_normalization_3 (BatchNo  (None, 10, 10, 1024  4096       ['conv2d_3[0][0]']               
     rmalization)                   )                                                                 
                                                                                                      
     add_11 (Add)                   (None, 10, 10, 1024  0           ['block13_pool[0][0]',           
                                    )                                 'batch_normalization_3[0][0]']  
                                                                                                      
     block14_sepconv1 (SeparableCon  (None, 10, 10, 1536  1582080    ['add_11[0][0]']                 
     v2D)                           )                                                                 
                                                                                                      
     block14_sepconv1_bn (BatchNorm  (None, 10, 10, 1536  6144       ['block14_sepconv1[0][0]']       
     alization)                     )                                                                 
                                                                                                      
     block14_sepconv1_act (Activati  (None, 10, 10, 1536  0          ['block14_sepconv1_bn[0][0]']    
     on)                            )                                                                 
                                                                                                      
     block14_sepconv2 (SeparableCon  (None, 10, 10, 2048  3159552    ['block14_sepconv1_act[0][0]']   
     v2D)                           )                                                                 
                                                                                                      
     block14_sepconv2_bn (BatchNorm  (None, 10, 10, 2048  8192       ['block14_sepconv2[0][0]']       
     alization)                     )                                                                 
                                                                                                      
     block14_sepconv2_act (Activati  (None, 10, 10, 2048  0          ['block14_sepconv2_bn[0][0]']    
     on)                            )                                                                 
                                                                                                      
     avg_pool (GlobalAveragePooling  (None, 2048)        0           ['block14_sepconv2_act[0][0]']   
     2D)                                                                                              
                                                                                                      
     predictions (Dense)            (None, 1000)         2049000     ['avg_pool[0][0]']               
                                                                                                      
    ==================================================================================================
    Total params: 22,910,480
    Trainable params: 22,855,952
    Non-trainable params: 54,528
    __________________________________________________________________________________________________
    

    Set a dynamic learning rate

    # Set the initial learning rate
    initial_learning_rate = 1e-4
    
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate, 
            decay_steps=300,      # Note: this counts steps (batches), not epochs
            decay_rate=0.96,     # after each decay, lr becomes decay_rate * lr
            staircase=True)
    
    # Pass the exponentially decaying learning rate to the optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
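
    To make the schedule concrete, here is a minimal pure-Python sketch of the staircase decay rule, lr = initial_learning_rate * decay_rate ^ floor(step / decay_steps). The helper name staircase_exponential_decay is hypothetical, and steps_per_epoch = 800 is an assumption read off the "800/800" progress bars in the training log; it only mirrors the formula and is not the TensorFlow implementation itself.

```python
def staircase_exponential_decay(step, initial_lr=1e-4, decay_steps=300, decay_rate=0.96):
    """Learning rate after `step` optimizer steps when staircase=True."""
    # The exponent is floored, so the rate drops in discrete jumps
    # once every `decay_steps` steps instead of shrinking continuously.
    return initial_lr * decay_rate ** (step // decay_steps)


steps_per_epoch = 800  # assumption: taken from the "800/800" progress bar in the log

for epoch in (1, 5, 10, 20):
    step = epoch * steps_per_epoch
    lr = staircase_exponential_decay(step)
    print(f"after epoch {epoch:2d} (step {step:5d}): lr = {lr:.3e}")
```

    Because decay_steps counts batches, the rate has already decayed twice by the end of the first 800-step epoch, which is why the steps-versus-epochs distinction matters.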
    

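    Compile the model

    • Loss function (loss): measures how far the model's predictions are from the true labels during training. sparse_categorical_crossentropy is used here; it works the same way as categorical_crossentropy (multi-class cross-entropy loss), except that the true labels are integer-encoded (for example, class 0 is represented by the number 0 and class 3 by the number 3). Official documentation: tf.keras.losses.SparseCategoricalCrossentropy
    • Optimizer (optimizer): determines how the model is updated based on the data it sees and its loss function. Adam is used here (official documentation: tf.keras.optimizers.Adam)
    • Metrics (metrics): used to monitor the training and testing steps. This example uses accuracy, the fraction of images that are classified correctly (official documentation: tf.keras.metrics.Accuracy)
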
    model.compile(optimizer=optimizer,
                  loss     ='sparse_categorical_crossentropy',
                  metrics  =['accuracy'])
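
    The relationship between the two losses can be checked by hand. The sketch below is plain Python rather than the Keras implementation, and the probability vector and label are made-up values for illustration; it shows that an integer label and its one-hot encoding give the same cross-entropy.

```python
import math


def categorical_crossentropy(one_hot, probs):
    # Full sum over classes; only the true class has a non-zero target.
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def sparse_categorical_crossentropy(label, probs):
    # The integer label simply indexes the predicted probability.
    return -math.log(probs[label])


probs = [0.1, 0.2, 0.6, 0.1]      # hypothetical softmax output for 4 classes
label = 2                          # integer-encoded true class
one_hot = [0.0, 0.0, 1.0, 0.0]     # the same label, one-hot encoded

sparse_loss = sparse_categorical_crossentropy(label, probs)
dense_loss = categorical_crossentropy(one_hot, probs)
print(sparse_loss, dense_loss)
```

    The sparse form avoids building one-hot vectors, which suits a dataset whose labels are already plain integers.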
    

    Train the model

    epochs = 20
    
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )
    

    Training results:

    Epoch 1/20
    800/800 [==============================] - 464s 564ms/step - loss: 1.4314 - accuracy: 0.4584 - val_loss: 1.0577 - val_accuracy: 0.5475
    Epoch 2/20
    800/800 [==============================] - 447s 559ms/step - loss: 0.9087 - accuracy: 0.6228 - val_loss: 0.8191 - val_accuracy: 0.6612
    Epoch 3/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.6728 - accuracy: 0.7403 - val_loss: 0.8190 - val_accuracy: 0.6687
    Epoch 4/20
    800/800 [==============================] - 447s 559ms/step - loss: 0.3362 - accuracy: 0.8841 - val_loss: 0.8249 - val_accuracy: 0.6913
    Epoch 5/20
    800/800 [==============================] - 447s 559ms/step - loss: 0.1415 - accuracy: 0.9566 - val_loss: 0.9374 - val_accuracy: 0.6975
    Epoch 6/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.0840 - accuracy: 0.9809 - val_loss: 1.2619 - val_accuracy: 0.6737
    Epoch 7/20
    800/800 [==============================] - 447s 558ms/step - loss: 0.0574 - accuracy: 0.9862 - val_loss: 0.7897 - val_accuracy: 0.7738
    Epoch 8/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.0369 - accuracy: 0.9912 - val_loss: 0.8976 - val_accuracy: 0.7350
    Epoch 9/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0276 - accuracy: 0.9966 - val_loss: 0.7896 - val_accuracy: 0.7725
    Epoch 10/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.0223 - accuracy: 0.9969 - val_loss: 0.7084 - val_accuracy: 0.7812
    Epoch 11/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.0108 - accuracy: 0.9978 - val_loss: 0.8445 - val_accuracy: 0.7588
    Epoch 12/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0102 - accuracy: 0.9975 - val_loss: 0.7577 - val_accuracy: 0.7850
    Epoch 13/20
    800/800 [==============================] - 446s 558ms/step - loss: 0.0062 - accuracy: 0.9991 - val_loss: 0.7447 - val_accuracy: 0.7837
    Epoch 14/20
    800/800 [==============================] - 445s 557ms/step - loss: 0.0034 - accuracy: 0.9987 - val_loss: 1.0870 - val_accuracy: 0.7063
    Epoch 15/20
    800/800 [==============================] - 445s 557ms/step - loss: 0.0100 - accuracy: 0.9978 - val_loss: 0.8212 - val_accuracy: 0.7725
    Epoch 16/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0089 - accuracy: 0.9981 - val_loss: 0.8604 - val_accuracy: 0.7688
    Epoch 17/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0068 - accuracy: 0.9984 - val_loss: 0.7941 - val_accuracy: 0.7887
    Epoch 18/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0037 - accuracy: 0.9994 - val_loss: 0.9039 - val_accuracy: 0.7650
    Epoch 19/20
    800/800 [==============================] - 446s 557ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.8278 - val_accuracy: 0.7812
    Epoch 20/20
    800/800 [==============================] - 446s 557ms/step - loss: 6.7889e-04 - accuracy: 1.0000 - val_loss: 0.8216 - val_accuracy: 0.7812
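
    In this log, training accuracy climbs to 1.0000 while validation accuracy levels off at roughly 0.77 to 0.79, which usually indicates overfitting. As a rough sketch (not part of the original code), the snippet below copies the validation metrics from the log above and looks up the best epochs; the variable names are illustrative.

```python
# Validation metrics copied from the 20-epoch training log above.
val_accuracy = [0.5475, 0.6612, 0.6687, 0.6913, 0.6975, 0.6737, 0.7738,
                0.7350, 0.7725, 0.7812, 0.7588, 0.7850, 0.7837, 0.7063,
                0.7725, 0.7688, 0.7887, 0.7650, 0.7812, 0.7812]
val_loss = [1.0577, 0.8191, 0.8190, 0.8249, 0.9374, 1.2619, 0.7897,
            0.8976, 0.7896, 0.7084, 0.8445, 0.7577, 0.7447, 1.0870,
            0.8212, 0.8604, 0.7941, 0.9039, 0.8278, 0.8216]

# Epochs are 1-based in the log, list indices are 0-based.
best_acc_epoch = val_accuracy.index(max(val_accuracy)) + 1
best_loss_epoch = val_loss.index(min(val_loss)) + 1

print("best val_accuracy:", max(val_accuracy), "at epoch", best_acc_epoch)
print("lowest val_loss:  ", min(val_loss), "at epoch", best_loss_epoch)
```

    The last epoch is not the best one on either metric, so saving the weights of the best validation epoch is worth considering.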
    

    5. Model evaluation

    5.1. Accuracy evaluation

    Accuracy and loss curves

    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    epochs_range = range(epochs)
    
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()
    

    [Figure: training and validation accuracy (left) and loss (right) curves]
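
    In the training log, training accuracy reaches 1.0000 while validation accuracy stays around 0.77 to 0.79, which suggests overfitting. As a rough sketch of how that gap could be read off the history numerically, the helper below finds the epoch with the best validation accuracy. The name summarize_history and its first_epoch parameter are hypothetical, and the sample values are copied from epochs 14-20 of the log only, not from the full run.

```python
def summarize_history(history_dict, first_epoch=1):
    """Return the best validation epoch and the train/validation accuracy gap there."""
    acc = history_dict["accuracy"]
    val_acc = history_dict["val_accuracy"]
    # Index of the highest validation accuracy (the first one if there are ties)
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return {
        "best_epoch": first_epoch + best_idx,
        "best_val_accuracy": val_acc[best_idx],
        "gap_at_best": acc[best_idx] - val_acc[best_idx],
    }


# Sample values copied from epochs 14-20 of the training log (not the full run)
sample = {
    "accuracy":     [0.9987, 0.9978, 0.9981, 0.9984, 0.9994, 1.0000, 1.0000],
    "val_accuracy": [0.7063, 0.7725, 0.7688, 0.7887, 0.7650, 0.7812, 0.7812],
}
summary = summarize_history(sample, first_epoch=14)
print(summary)
```

    With the real run, summarize_history(history.history) should work the same way, since history.history is a plain dict of per-epoch lists.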

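    5.2. Plotting the confusion matrix

    For an introduction to confusion_matrix(), see the sklearn.metrics.confusion_matrix documentation.

    Seaborn is a higher-level API built on top of the Matplotlib core library. Its advantages are more pleasing color schemes and more finely styled plot elements.

    Define a function, plot_cm, that plots the confusion matrix: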

    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import pandas as pd
    
    # Define a function that plots the confusion matrix
    def plot_cm(labels, predictions):
        
        # Compute the confusion matrix
        conf_numpy = confusion_matrix(labels, predictions)
        # Convert the matrix to a DataFrame labelled with the class names
        conf_df = pd.DataFrame(conf_numpy, index=class_names, columns=class_names)
        
        plt.figure(figsize=(8, 7))
        
        sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
        
        plt.title('Confusion matrix', fontsize=15)
        plt.ylabel('True label', fontsize=14)
        plt.xlabel('Predicted label', fontsize=14)
        plt.savefig('pic3.jpg', dpi=600)  # save at the given resolution
    

    Output:

    [Figure: confusion matrix heatmap]

    Save the model:
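
    The call to plot_cm does not appear in the code above. One possible way to build its two arguments is sketched here, assuming the dataset yields (images, labels) batches with integer labels and the model returns one score per class. The helper collect_predictions and the stand-in data are hypothetical and only there so the sketch runs without the trained model.

```python
import numpy as np


def collect_predictions(batches, predict_fn):
    """Gather true labels and predicted class indices over all batches."""
    true_labels, pred_labels = [], []
    for images, labels in batches:
        probs = np.asarray(predict_fn(images))  # shape: (batch, n_classes)
        pred_labels.extend(np.argmax(probs, axis=1).tolist())
        true_labels.extend(np.asarray(labels).reshape(-1).tolist())
    return true_labels, pred_labels


def fake_predict(images):
    # Stand-in for model.predict: pretend the mean pixel value encodes the class
    idx = images.reshape(len(images), -1).mean(axis=1).astype(int)
    return np.eye(4)[idx]


# Stand-in batches: two batches of two tiny "images" each
fake_batches = [
    (np.stack([np.full((2, 2, 3), 0.0), np.full((2, 2, 3), 1.0)]), np.array([0, 1])),
    (np.stack([np.full((2, 2, 3), 2.0), np.full((2, 2, 3), 2.0)]), np.array([2, 3])),
]

val_label, val_pre = collect_predictions(fake_batches, fake_predict)
print(val_label, val_pre)
```

    With the real objects, this would presumably become val_label, val_pre = collect_predictions(val_ds, model.predict), followed by plot_cm(val_label, val_pre).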

    # Save the model
    model.save('model/model.h5')
    # Load the model
    new_model = tf.keras.models.load_model('model/model.h5')
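
    To check that a saved model was restored faithfully, the predictions before and after the round trip can be compared. The sketch below uses a tiny stand-in network (hypothetical, not the Xception model from this article) and a temporary directory, so it runs on its own.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny stand-in model (hypothetical; not the Xception network from the article)
demo_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

x = np.random.rand(5, 4).astype("float32")

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.h5")
    demo_model.save(path)  # the .h5 suffix selects the HDF5 format
    restored = tf.keras.models.load_model(path)

    before = demo_model.predict(x, verbose=0)
    after = restored.predict(x, verbose=0)

# The restored model should reproduce the original outputs
same_outputs = bool(np.allclose(before, after, atol=1e-6))
print("Round trip preserved predictions:", same_outputs)
```

    The same comparison could be applied to model and new_model on one validation batch.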
    

    5.3. Making predictions

    plt.figure(figsize=(15, 7))  # figure is 15 wide and 7 high
    plt.suptitle("Prediction results")
    plt.subplots_adjust(wspace=0.2, hspace=0.2)
    
    num = 0
    for images, labels in val_ds.take(2):  # two batches, 4 images from each
        for i in range(4):
            ax = plt.subplot(2, 4, num + 1)
            num += 1
            
            # Show the image
            plt.imshow(images[i].numpy().astype("uint8"))
            
            # Add a batch dimension so the model accepts a single image
            img_array = tf.expand_dims(images[i], 0)
            
            # Use the model to predict the animal in the image
            predictions = model.predict(img_array)
            plt.title("True value: {}\nPredicted value: {}".format(class_names[labels[i]], class_names[np.argmax(predictions)]))
            plt.axis("off")
    
    # Save once, after all subplots have been drawn
    plt.savefig('pic4.jpg', dpi=400)
    

    Result:

    [Figure: 2 x 4 grid of validation images, each titled with its true and predicted class]
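
    The step that turns the model output into a class name can be sketched in isolation. This assumes the model returns one score per class and that class_names follows the alphanumeric order of the data sub-directories (cat, chook, dog, horse). The helper decode_prediction and the sample scores are hypothetical.

```python
import numpy as np

# Assumed class order: sub-directory names sorted alphanumerically
class_names = ["cat", "chook", "dog", "horse"]


def decode_prediction(scores, class_names):
    """Map a (1, n_classes) score array to (class name, score of that class)."""
    scores = np.asarray(scores).reshape(-1)
    idx = int(np.argmax(scores))
    return class_names[idx], float(scores[idx])


# Made-up scores for a single image, shaped like model.predict output
label, confidence = decode_prediction([[0.05, 0.10, 0.80, 0.05]], class_names)
print(label, confidence)
```

    Returning the score alongside the name makes it easier to spot low-confidence predictions when reviewing misclassified images.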

  • Original article: https://blog.csdn.net/qq_45550375/article/details/126455124