• 卷积神经网络——vgg16网络及其python实现


    1、介绍     

            VGG-16网络包括13个卷积层和3个全连接层,网络结构较LeNet-5等网络变得十分复杂,但同时也有不错的效果。VGG16有强大的拟合能力在当时取得了非常的效果,但同时VGG也有部分不足:
    1、巨大参数量导致训练时间过长,调参难度较大;
    2、模型所需内存容量大,VGG的权值文件很大,用到实际应用会比较困难。

    2、结构原理 

    这是经典的vgg网络,输入图片大小为224*224。

    下面这为官方给出的几种VGG结构图。

     

     现在多用的为D模型。

    简单介绍下过程,输入224*224大小的图片,然后用两次64个3*3的卷积核进行全采集,也就是补零采集,保证特征不丢失,得到64*224*224的特征;池化层得到64*112*112;再利用128个3*3的卷积核进行特征采集两次,得到特征112*112*128;池化得到56*56*128大小特征.........反复这样操作,最后卷积完得到7*7*512的特征,然后利用全连接层进行展开,最后得到1000个特征,随后进行概率分类操作。

    3、python实现

            选用的数据集为fashion数据集,具体请另外了解。数据可直接在库中导入,本文用class网络编写神经网络程序。

    1. class VGG16(Model):
    2. def __init__(self):
    3. super(VGG16, self).__init__()
    4. self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same')
    5. self.b1 = BatchNormalization()
    6. self.a1 = Activation('relu')
    7. self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', )
    8. self.b2 = BatchNormalization()
    9. self.a2 = Activation('relu')
    10. self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    11. self.d1 = Dropout(0.2)
    12. self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
    13. self.b3 = BatchNormalization()
    14. self.a3 = Activation('relu')
    15. self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
    16. self.b4 = BatchNormalization()
    17. self.a4 = Activation('relu')
    18. self.p2 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    19. self.d2 = Dropout(0.2)
    20. self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    21. self.b5 = BatchNormalization()
    22. self.a5 = Activation('relu')
    23. self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    24. self.b6 = BatchNormalization()
    25. self.a6 = Activation('relu')
    26. self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    27. self.b7 = BatchNormalization()
    28. self.a7 = Activation('relu')
    29. self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    30. self.d3 = Dropout(0.2)
    31. self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    32. self.b8 = BatchNormalization()
    33. self.a8 = Activation('relu')
    34. self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    35. self.b9 = BatchNormalization()
    36. self.a9 = Activation('relu')
    37. self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    38. self.b10 = BatchNormalization()
    39. self.a10 = Activation('relu')
    40. self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    41. self.d4 = Dropout(0.2)
    42. self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    43. self.b11 = BatchNormalization()
    44. self.a11 = Activation('relu')
    45. self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    46. self.b12 = BatchNormalization()
    47. self.a12 = Activation('relu')
    48. self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    49. self.b13 = BatchNormalization()
    50. self.a13 = Activation('relu')
    51. self.p5 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    52. self.d5 = Dropout(0.2)
    53. self.flatten = Flatten()
    54. self.f1 = Dense(512, activation='relu')
    55. self.d6 = Dropout(0.2)
    56. self.f2 = Dense(512, activation='relu')
    57. self.d7 = Dropout(0.2)
    58. self.f3 = Dense(10, activation='softmax')
    59. def call(self, x):
    60. x = self.c1(x)
    61. x = self.b1(x)
    62. x = self.a1(x)
    63. x = self.c2(x)
    64. x = self.b2(x)
    65. x = self.a2(x)
    66. x = self.p1(x)
    67. x = self.d1(x)
    68. x = self.c3(x)
    69. x = self.b3(x)
    70. x = self.a3(x)
    71. x = self.c4(x)
    72. x = self.b4(x)
    73. x = self.a4(x)
    74. x = self.p2(x)
    75. x = self.d2(x)
    76. x = self.c5(x)
    77. x = self.b5(x)
    78. x = self.a5(x)
    79. x = self.c6(x)
    80. x = self.b6(x)
    81. x = self.a6(x)
    82. x = self.c7(x)
    83. x = self.b7(x)
    84. x = self.a7(x)
    85. x = self.p3(x)
    86. x = self.d3(x)
    87. x = self.c8(x)
    88. x = self.b8(x)
    89. x = self.a8(x)
    90. x = self.c9(x)
    91. x = self.b9(x)
    92. x = self.a9(x)
    93. x = self.c10(x)
    94. x = self.b10(x)
    95. x = self.a10(x)
    96. x = self.p4(x)
    97. x = self.d4(x)
    98. x = self.c11(x)
    99. x = self.b11(x)
    100. x = self.a11(x)
    101. x = self.c12(x)
    102. x = self.b12(x)
    103. x = self.a12(x)
    104. x = self.c13(x)
    105. x = self.b13(x)
    106. x = self.a13(x)
    107. x = self.p5(x)
    108. x = self.d5(x)
    109. x = self.flatten(x)
    110. x = self.f1(x)
    111. x = self.d6(x)
    112. x = self.f2(x)
    113. x = self.d7(x)
    114. y = self.f3(x)
    115. return y
    116. model = VGG16()

    读取数据

    1. import tensorflow as tf
    2. import os
    3. import numpy as np
    4. from matplotlib import pyplot as plt
    5. from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
    6. from tensorflow.keras import Model
    7. np.set_printoptions(threshold=np.inf)
    8. fashion = tf.keras.datasets.fashion_mnist
    9. (x_train, y_train), (x_test, y_test) = fashion.load_data()
    10. x_train, x_test = x_train / 255.0, x_test / 255.0
    11. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
    12. x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

    迭代训练

    1. model.compile(optimizer='adam',
    2. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    3. metrics=['sparse_categorical_accuracy'])
    4. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
    5. save_weights_only=True,
    6. save_best_only=True)
    7. history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1,
    8. callbacks=[cp_callback])

    绘制结果图

    1. acc = history.history['sparse_categorical_accuracy']
    2. val_acc = history.history['val_sparse_categorical_accuracy']
    3. loss = history.history['loss']
    4. val_loss = history.history['val_loss']
    5. plt.subplot(1, 2, 1)
    6. plt.plot(acc, label='Training Accuracy')
    7. plt.plot(val_acc, label='Validation Accuracy')
    8. plt.title('Training and Validation Accuracy')
    9. plt.legend()
    10. plt.subplot(1, 2, 2)
    11. plt.plot(loss, label='Training Loss')
    12. plt.plot(val_loss, label='Validation Loss')
    13. plt.title('Training and Validation Loss')
    14. plt.legend()
    15. plt.show()

     

     虽然不是很稳定,但总的来说准确率还可以。

  • 相关阅读:
    C++知识点总结(6):高精度乘法真题代码
    获取1688店铺所有商品、店铺列表api
    html页面直接使用elementui Plus时间线 + vue3
    Java数字处理类--数字格式化
    鸿蒙开发接口媒体:【@ohos.multimedia.audio (音频管理)】
    Vue组件路由
    基于国产芯片RK1126的智能视频分析网关
    尿检设备“智能之眼”:维视智造推出MV-MC 系列医疗专用相机
    桥接模式
    Log4j “史诗级 ”漏洞背后:项目只有三位赞助者;RISC-V 基金会加速设计 RISC-V GPU;Linux 5.16 将延期发布 | 开源日报
  • 原文地址:https://blog.csdn.net/abc1234abcdefg/article/details/125495965