AI Learning: CIFAR-10 Classification and Recognition with a VGG Network (5)


    This post tries a VGG-style network for classifying the CIFAR-10 dataset.

    1 Import the required modules

    import numpy as np
    
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import models, layers
    
    import matplotlib.pyplot as plt
    
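    As a side note (my addition, not part of the original post): the exact numbers reported below will vary from run to run. Seeding the random number generators makes runs easier to compare, although GPU nondeterminism can still cause small differences:

    # optional: seed NumPy and TensorFlow RNGs for more repeatable runs
    np.random.seed(42)
    tf.random.set_seed(42)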

    2 Load the CIFAR-10 dataset

    # load CIFAR-10 dataset
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
    # train_images: 50000*32*32*3, train_labels: 50000*1, test_images: 10000*32*32*3, test_labels: 10000*1
    
    # normalize pixel values from [0, 255] to [0, 1]
    train_input = train_images/255.0
    test_input = test_images/255.0
    # keep labels as integer class indices, shape (N, 1)
    train_output = train_labels
    test_output = test_labels
    
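    A quick sanity check (my addition, not in the original post) confirms the shapes and value ranges before building the network:

    # verify shapes and pixel range of the preprocessed data
    print(train_input.shape, train_output.shape)   # (50000, 32, 32, 3) (50000, 1)
    print(test_input.shape, test_output.shape)     # (10000, 32, 32, 3) (10000, 1)
    print(train_input.min(), train_input.max())    # 0.0 1.0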

    3 Build the neural network

    First, define a function that builds the model:

    def build_model():
        model = models.Sequential()
        
        # 1st layer, input shape (32,32,3)
        model.add(layers.Conv2D(64, (3,3), padding='same', input_shape=(32,32,3)))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(0.3))
        
        # 2nd layer, input shape (32,32,64)
        model.add(layers.Conv2D(64, (3,3), padding='same'))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same'))
    
        # 3rd layer, (16,16,64)
        model.add(layers.Conv2D(128, (3,3), padding='same'))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(0.4))
                  
        # 4th layer, (16,16,128)
        model.add(layers.Conv2D(128, (3,3), padding='same'))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D(pool_size=(2,2)))
        
        # 5th layer, (8,8,128)
        model.add(layers.Conv2D(256, (3, 3), padding='same'))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(0.4))
    
        # 6th layer, (8,8,256)
        model.add(layers.Flatten())
        model.add(layers.Dense(512))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        
        # 7th layer, 512
        model.add(layers.Dropout(0.5))
        model.add(layers.Dense(10))
        model.add(layers.Activation('softmax'))

        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])
        
        return model
    
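    As an aside, the same stack can be expressed more compactly by looping over a VGG-style configuration list. This is an equivalent sketch of my own, not the original author's code:

    def build_model_compact():
        # same architecture as build_model(), written as a loop over a config list;
        # each tuple is (filters, dropout rate or None, whether to max-pool after)
        cfg = [(64, 0.3, False), (64, None, True),
               (128, 0.4, False), (128, None, True),
               (256, 0.4, False)]

        model = models.Sequential()
        model.add(layers.InputLayer(input_shape=(32, 32, 3)))
        for filters, drop, pool in cfg:
            model.add(layers.Conv2D(filters, (3, 3), padding='same'))
            model.add(layers.Activation('relu'))
            model.add(layers.BatchNormalization())
            if drop is not None:
                model.add(layers.Dropout(drop))
            if pool:
                model.add(layers.MaxPooling2D(pool_size=(2, 2)))

        # classifier head, identical to build_model()
        model.add(layers.Flatten())
        model.add(layers.Dense(512))
        model.add(layers.Activation('relu'))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(0.5))
        model.add(layers.Dense(10))
        model.add(layers.Activation('softmax'))

        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])
        return model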

    The model stacks five convolutional layers followed by two fully connected layers. Now call the function:

    # build model
    network = build_model()
    
    # show network summary
    network.summary()
    

    The summary is printed as follows:

    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    conv2d (Conv2D)              (None, 32, 32, 64)        1792      
    _________________________________________________________________
    activation (Activation)      (None, 32, 32, 64)        0         
    _________________________________________________________________
    batch_normalization (BatchNo (None, 32, 32, 64)        256       
    _________________________________________________________________
    dropout (Dropout)            (None, 32, 32, 64)        0         
    _________________________________________________________________
    conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928     
    _________________________________________________________________
    activation_1 (Activation)    (None, 32, 32, 64)        0         
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 32, 32, 64)        256       
    _________________________________________________________________
    max_pooling2d (MaxPooling2D) (None, 16, 16, 64)        0         
    _________________________________________________________________
    conv2d_2 (Conv2D)            (None, 16, 16, 128)       73856     
    _________________________________________________________________
    activation_2 (Activation)    (None, 16, 16, 128)       0         
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 16, 16, 128)       512       
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 16, 16, 128)       0         
    _________________________________________________________________
    conv2d_3 (Conv2D)            (None, 16, 16, 128)       147584    
    _________________________________________________________________
    activation_3 (Activation)    (None, 16, 16, 128)       0         
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 16, 16, 128)       512       
    _________________________________________________________________
    max_pooling2d_1 (MaxPooling2 (None, 8, 8, 128)         0         
    _________________________________________________________________
    conv2d_4 (Conv2D)            (None, 8, 8, 256)         295168    
    _________________________________________________________________
    activation_4 (Activation)    (None, 8, 8, 256)         0         
    _________________________________________________________________
    batch_normalization_4 (Batch (None, 8, 8, 256)         1024      
    _________________________________________________________________
    dropout_2 (Dropout)          (None, 8, 8, 256)         0         
    _________________________________________________________________
    flatten (Flatten)            (None, 16384)             0         
    _________________________________________________________________
    dense (Dense)                (None, 512)               8389120   
    _________________________________________________________________
    activation_5 (Activation)    (None, 512)               0         
    _________________________________________________________________
    batch_normalization_5 (Batch (None, 512)               2048      
    _________________________________________________________________
    dropout_3 (Dropout)          (None, 512)               0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                5130      
    _________________________________________________________________
    activation_6 (Activation)    (None, 10)                0         
    =================================================================
    Total params: 8,954,186
    Trainable params: 8,951,882
    Non-trainable params: 2,304
    _________________________________________________________________
    
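    Note that the Flatten → Dense(512) connection alone contributes 8,389,120 of the 8,954,186 parameters, about 94% of the model; this is typical of VGG-style networks, where the first fully connected layer dominates the parameter count. The 2,304 non-trainable parameters are the moving means and variances of the six BatchNormalization layers.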

    4 Train the model

    Train the model by calling fit:

    # train model
    history = network.fit(train_input, train_output, epochs=30, batch_size=256, validation_split=0.1)
    

    Training runs for 30 epochs with a batch size of 256; the log is shown below:

    Epoch 1/30
    176/176 [==============================] - 15s 62ms/step - loss: 1.7285 - sparse_categorical_accuracy: 0.4502 - val_loss: 4.3244 - val_sparse_categorical_accuracy: 0.1566
    Epoch 2/30
    176/176 [==============================] - 9s 53ms/step - loss: 1.0967 - sparse_categorical_accuracy: 0.6184 - val_loss: 4.2188 - val_sparse_categorical_accuracy: 0.2258
    Epoch 3/30
    176/176 [==============================] - 9s 53ms/step - loss: 0.8392 - sparse_categorical_accuracy: 0.7048 - val_loss: 1.5962 - val_sparse_categorical_accuracy: 0.5262
    Epoch 4/30
    176/176 [==============================] - 9s 53ms/step - loss: 0.6949 - sparse_categorical_accuracy: 0.7549 - val_loss: 0.7939 - val_sparse_categorical_accuracy: 0.7430
    Epoch 5/30
    176/176 [==============================] - 9s 53ms/step - loss: 0.5994 - sparse_categorical_accuracy: 0.7880 - val_loss: 0.9743 - val_sparse_categorical_accuracy: 0.7168
    Epoch 6/30
    176/176 [==============================] - 9s 54ms/step - loss: 0.5187 - sparse_categorical_accuracy: 0.8173 - val_loss: 0.8175 - val_sparse_categorical_accuracy: 0.7520
    Epoch 7/30
    176/176 [==============================] - 9s 54ms/step - loss: 0.4544 - sparse_categorical_accuracy: 0.8398 - val_loss: 0.8302 - val_sparse_categorical_accuracy: 0.7544
    Epoch 8/30
    176/176 [==============================] - 9s 54ms/step - loss: 0.3901 - sparse_categorical_accuracy: 0.8629 - val_loss: 0.7184 - val_sparse_categorical_accuracy: 0.7934
    Epoch 9/30
    176/176 [==============================] - 9s 54ms/step - loss: 0.3390 - sparse_categorical_accuracy: 0.8794 - val_loss: 0.8141 - val_sparse_categorical_accuracy: 0.7962
    Epoch 10/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.2886 - sparse_categorical_accuracy: 0.8964 - val_loss: 0.9829 - val_sparse_categorical_accuracy: 0.7804
    Epoch 11/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.2630 - sparse_categorical_accuracy: 0.9075 - val_loss: 0.7088 - val_sparse_categorical_accuracy: 0.8034
    Epoch 12/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.2362 - sparse_categorical_accuracy: 0.9164 - val_loss: 0.5813 - val_sparse_categorical_accuracy: 0.8336
    Epoch 13/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.2086 - sparse_categorical_accuracy: 0.9269 - val_loss: 0.7702 - val_sparse_categorical_accuracy: 0.8014
    Epoch 14/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1860 - sparse_categorical_accuracy: 0.9345 - val_loss: 0.7444 - val_sparse_categorical_accuracy: 0.8254
    Epoch 15/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1748 - sparse_categorical_accuracy: 0.9398 - val_loss: 0.7130 - val_sparse_categorical_accuracy: 0.8184
    Epoch 16/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1582 - sparse_categorical_accuracy: 0.9443 - val_loss: 0.7712 - val_sparse_categorical_accuracy: 0.8226
    Epoch 17/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.1459 - sparse_categorical_accuracy: 0.9488 - val_loss: 0.8808 - val_sparse_categorical_accuracy: 0.8086
    Epoch 18/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1329 - sparse_categorical_accuracy: 0.9530 - val_loss: 0.7062 - val_sparse_categorical_accuracy: 0.8340
    Epoch 19/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1323 - sparse_categorical_accuracy: 0.9538 - val_loss: 0.6216 - val_sparse_categorical_accuracy: 0.8380
    Epoch 20/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.1243 - sparse_categorical_accuracy: 0.9575 - val_loss: 0.6749 - val_sparse_categorical_accuracy: 0.8334
    Epoch 21/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1206 - sparse_categorical_accuracy: 0.9586 - val_loss: 0.7408 - val_sparse_categorical_accuracy: 0.8268
    Epoch 22/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.1105 - sparse_categorical_accuracy: 0.9615 - val_loss: 0.7999 - val_sparse_categorical_accuracy: 0.8314
    Epoch 23/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.1064 - sparse_categorical_accuracy: 0.9633 - val_loss: 0.6867 - val_sparse_categorical_accuracy: 0.8396
    Epoch 24/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.0974 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.6695 - val_sparse_categorical_accuracy: 0.8422
    Epoch 25/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.0908 - sparse_categorical_accuracy: 0.9687 - val_loss: 0.7222 - val_sparse_categorical_accuracy: 0.8306
    Epoch 26/30
    176/176 [==============================] - 10s 54ms/step - loss: 0.0907 - sparse_categorical_accuracy: 0.9689 - val_loss: 0.6841 - val_sparse_categorical_accuracy: 0.8384
    Epoch 27/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.0866 - sparse_categorical_accuracy: 0.9696 - val_loss: 0.8356 - val_sparse_categorical_accuracy: 0.8286
    Epoch 28/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.0898 - sparse_categorical_accuracy: 0.9690 - val_loss: 0.6899 - val_sparse_categorical_accuracy: 0.8392
    Epoch 29/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.0867 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.7572 - val_sparse_categorical_accuracy: 0.8338
    Epoch 30/30
    176/176 [==============================] - 10s 55ms/step - loss: 0.0796 - sparse_categorical_accuracy: 0.9728 - val_loss: 0.7699 - val_sparse_categorical_accuracy: 0.8336
    
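    The log shows training accuracy still climbing while validation accuracy plateaus around 0.83 from roughly epoch 12 onward, a sign of overfitting. One common remedy, my suggestion rather than something used in the original post, is to stop early and keep the best weights via a Keras callback:

    # hypothetical variant: stop when validation accuracy stops improving
    early_stop = keras.callbacks.EarlyStopping(
        monitor='val_sparse_categorical_accuracy',
        patience=5,
        restore_best_weights=True)

    history = network.fit(train_input, train_output,
                          epochs=30, batch_size=256,
                          validation_split=0.1,
                          callbacks=[early_stop])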

    Training ends with an accuracy of 0.9728 on the training data and 0.8336 on the held-out validation split (the log reports validation accuracy, not test accuracy). Plot the training history:

    # plot training history
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']

    plt.figure(figsize=(10,3))

    plt.subplot(1,2,1)
    plt.plot(loss, color='blue', label='train')
    plt.plot(val_loss, color='red', label='validation')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()

    plt.subplot(1,2,2)
    plt.plot(acc, color='blue', label='train')
    plt.plot(val_acc, color='red', label='validation')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()
    

    The resulting plots are shown below.

    [Figure: training vs. validation loss (left) and accuracy (right) over the 30 epochs]

    Both loss and accuracy improve over the course of training for the training and validation data, indicating better generalization than the previous models in this series.

    5 Test the trained model

    # evaluate model
    network.evaluate(test_input, test_output, verbose=2)
    

    This prints the loss and accuracy on the test set:

    313/313 - 1s - loss: 0.7931 - sparse_categorical_accuracy: 0.8315
    [0.7930535078048706, 0.8314999938011169]
    
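    Since evaluate() also returns the loss and metric values as a list, they can be captured directly instead of only printed, a small addition of mine:

    # capture the returned metrics for later use
    test_loss, test_acc = network.evaluate(test_input, test_output, verbose=0)
    print('test accuracy: %.4f' % test_acc)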

    The result is close to the validation metrics observed during training. Next, plot the predictions for the first 100 test images:

    # predict on test data
    predict_output = network.predict(test_input)
    
    # rows and columns of the subplot grid
    m = 10
    n = 10
    num = m*n
    
    # labels of category
    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    
    # figure size
    plt.figure(figsize=(15,15))
    
    # plot the first 100 test images together with their predicted labels
    for i in range(num):
        plt.subplot(m,n,i+1)

        type_index = np.argmax(predict_output[i])
        label = labels[type_index]

        # black label = correct prediction, red label = misclassified
        clr = 'black' if type_index == test_labels[i][0] else 'red'

        plt.imshow(test_images[i])
        plt.xticks([])
        plt.yticks([])

        plt.xlabel(label, color=clr)
    
    plt.show()
    

    The final figure is shown below.

    [Figure: 10×10 grid of the first 100 test images, each labeled with its predicted class]

    Red labels mark incorrect predictions; their share is roughly in line with the accuracy measured on the test set. Compared with the previous models, this type of network shows better recognition performance and generalization.
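    To see which classes the errors concentrate in, a per-class accuracy breakdown can be computed with NumPy alone. This sketch is my addition and was not part of the original post:

    # per-class accuracy over the whole test set
    pred_labels = np.argmax(predict_output, axis=1)
    true_labels = test_labels.flatten()
    for k, name in enumerate(labels):
        mask = true_labels == k
        acc_k = np.mean(pred_labels[mask] == true_labels[mask])
        print('%-12s %.3f' % (name, acc_k))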

    Reference: https://blog.csdn.net/Mind_programmonkey/article/details/121049217

  Original article: https://blog.csdn.net/mbdong/article/details/127724152