• LeNet-5网络结构详解和minist手写数字识别项目实践


    参考论文:Gradient-Based Learning Applied to Document Recognition

    组成

    ​ 网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。

    输入层:32×32(输入层不计入网络层数)

    卷积层1:6个5×5的卷积核,步长为1——>(6,28,28)

    池化层1:MaxPooling,(2,2)——>(6,14,14)

    卷积层2:16个5×5的卷积核,步长为1——>(16,10,10)

    池化层2:MaxPooling,(2,2)——>(16,5,5)

    全连接层1:120,——>16×5×5×120+120

    全连接层2:84,——>120×84+84

    全连接层3:10,——>84*10+10

    在这里插入图片描述

    特点

    1. 第一次将卷积神经网络应用于实际操作中,是通过梯度下降训练卷积神经网络的鼻祖算法之一;
    2. 奠定了卷积神经网络的基本结构,即卷积、非线性激活函数、池化、全连接;
    3. 使用局部感受野,权值共享,池化(下采样)来实现图像的平移,缩放和形变的不变性,其中卷积层用
    4. 来识别图像里的空间模式,如线条和物体局部特征,最大池化层则用来降低卷积层对位置的敏感性;

    lenet5-minist手写数字识别项目实践

    import tensorflow  as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Input, Dense, Activation, Conv2D, MaxPooling2D, Flatten
    from tensorflow.keras.datasets import mnist
    
    # 加载mnist数据
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    print(type(x_train)) # 
    print(x_train.shape) # (60000, 28, 28)
    print(x_test.shape) # (10000, 28, 28)
    print(y_train.shape)#(60000,)
    print(y_test.shape)#(10000,)
    
    # 数据类型转换,uint8 to float32,(60000, 28, 28) to (60000, 28, 28, 1)
    x_train = x_train.reshape(-1, 28, 28, 1)
    print(x_train.dtype) # uint8
    x_train = x_train.astype('float32')
    print(x_train.dtype) # float32
    print(x_train.shape)#uint8类型的(60000, 28, 28)变成float32的(60000, 28, 28, 1)
    y_train = y_train.astype('float32')
    x_test = x_test.reshape(-1, 28, 28, 1)
    x_test = x_test.astype('float32')
    y_test = y_test.astype('float32')
    print(y_train)
    
    # x数据归一化,255为像素最大值
    x_train /= 255
    x_test /= 255
    
    #label为0~9共10个类别,将其转换成one-hot编码
    from tensorflow.python.keras.utils.np_utils import to_categorical
    y_train_new = to_categorical(num_classes=10, y=y_train)
    print(y_train_new)
    y_test_new = to_categorical(num_classes=10, y=y_test)
    
    
    # LeNet_5网络结构搭建
    def LeNet_5():
        model = Sequential()
        model.add(Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation='tanh', input_shape=[28, 28, 1]))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(120, activation='tanh'))
        model.add(Dense(84, activation='tanh'))
        model.add(Dense(10, activation='softmax'))
        return model
    
    
    # LeNet_5模型训练和编译
    def train_model(xtr,ytr,batch_size,epochs,n_val):
        model = LeNet_5()
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        tensorboard = tf.keras.callbacks.TensorBoard(histogram_freq=1)
        model.fit(xtr, ytr, batch_size=batch_size, epochs=epochs, validation_split=n_val, shuffle=True,callbacks=[tensorboard])
        #shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
        return model,tensorboard
    
    
    model,tensorboard = train_model(x_train, y_train_new,64,20,0.2)
    tf.saved_model.save(model,'LeNet_5-1')
    # 返回测试集损失函数值和准确率
    loss1, accuracy1 = model.evaluate(x_train, y_train_new)
    loss2, accuracy2 = model.evaluate(x_test, y_test_new)
    print(loss1, accuracy1)#0.027162624523043633 0.9929999709129333
    print(loss2, accuracy2)#0.061963751912117004 0.9837999939918518
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    tensorboard --logdir=C:\Users\ThinkStation\Desktop\logs\train
    
    • 1

    在这里插入图片描述
    在这里插入图片描述
    模型保存
    在这里插入图片描述

  • 相关阅读:
    代码随想录第39天 | ● 198.打家劫舍 ● 213.打家劫舍II ● 337.打家劫舍III
    利用地质年代图谱精准判读文献中的地质时间
    低资源场景下的命名实体识别
    MySQL导出csv数据文件
    JS使用工具函数
    Python 在Word中查找并高亮指定文本
    浅谈原型链
    2022年最新甘肃建筑施工焊工(建筑特种作业)模拟题库及答案解析
    【Android】固件结构
    2022暑期复习-Day6
  • 原文地址:https://blog.csdn.net/xiaofeixia002X/article/details/128206542