• Tensorflow入门实战 T06-Vgg16 明星识别


    目录

    1、前言

    2、 完整代码

    3、运行过程+结果

    4、遇到的问题

    5、小结


    1、前言

    这周主要是使用VGG16模型,完成明星照片识别。

    2、 完整代码

    1. from keras.utils import losses_utils
    2. from tensorflow import keras
    3. from keras import layers, models
    4. import os, PIL, pathlib
    5. import matplotlib.pyplot as plt
    6. import tensorflow as tf
    7. import numpy as np
    8. from keras.callbacks import ModelCheckpoint, EarlyStopping
    9. gpus = tf.config.list_physical_devices("GPU")
    10. if gpus:
    11. gpu0 = gpus[0] # 如果有多个GPU,仅使用第0个GPU
    12. tf.config.experimental.set_memory_growth(gpu0, True) # 设置GPU显存用量按需使用
    13. tf.config.set_visible_devices([gpu0], "GPU")
    14. # 导入数据
    15. data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data"
    16. data_dir = pathlib.Path(data_dir)
    17. # 查看数据
    18. image_count = len(list(data_dir.glob('*/*.jpg')))
    19. print("图片总数为:",image_count) # 1800
    20. roses = list(data_dir.glob('Jennifer Lawrence/*.jpg'))
    21. img = PIL.Image.open(str(roses[0]))
    22. # img.show() # 查看图片
    23. # 数据预处理
    24. # 1、加载数据
    25. batch_size = 32
    26. img_height = 224
    27. img_width = 224
    28. print('data_dir======>',data_dir)
    29. """
    30. 关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
    31. """
    32. train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    33. data_dir,
    34. validation_split=0.1,
    35. subset="training",
    36. label_mode="categorical",
    37. seed=123,
    38. image_size=(img_height, img_width),
    39. batch_size=batch_size)
    40. """
    41. 关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
    42. """
    43. val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    44. data_dir,
    45. validation_split=0.1,
    46. subset="validation",
    47. label_mode="categorical",
    48. seed=123,
    49. image_size=(img_height, img_width),
    50. batch_size=batch_size)
    51. class_names = train_ds.class_names
    52. print(class_names)
    53. # 可视化数据
    54. plt.figure(figsize=(20, 10))
    55. for images, labels in train_ds.take(1):
    56. for i in range(20):
    57. ax = plt.subplot(5, 10, i + 1)
    58. plt.imshow(images[i].numpy().astype("uint8"))
    59. plt.title(class_names[np.argmax(labels[i])])
    60. plt.axis("off")
    61. plt.show()
    62. # 再次检查数据
    63. for image_batch, labels_batch in train_ds:
    64. print(image_batch.shape) # (32, 224, 224, 3)
    65. print(labels_batch.shape) # (32, 17)
    66. break
    67. # 配置数据集
    68. AUTOTUNE = tf.data.AUTOTUNE
    69. train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    70. val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    71. # 构建CNN网络
    72. """
    73. 关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995
    74. layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
    75. 关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
    76. """
    77. model = models.Sequential([
    78. keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),
    79. layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3
    80. layers.AveragePooling2D((2, 2)), # 池化层1,2*2采样
    81. layers.Conv2D(32, (3, 3), activation='relu'), # 卷积层2,卷积核3*3
    82. layers.AveragePooling2D((2, 2)), # 池化层2,2*2采样
    83. layers.Dropout(0.5),
    84. layers.Conv2D(64, (3, 3), activation='relu'), # 卷积层3,卷积核3*3
    85. layers.AveragePooling2D((2, 2)),
    86. layers.Dropout(0.5),
    87. layers.Conv2D(128, (3, 3), activation='relu'), # 卷积层3,卷积核3*3
    88. layers.Dropout(0.5),
    89. layers.Flatten(), # Flatten层,连接卷积层与全连接层
    90. layers.Dense(128, activation='relu'), # 全连接层,特征进一步提取
    91. layers.Dense(len(class_names)) # 输出层,输出预期结果
    92. ])
    93. # model.summary() # 打印网络结构
    94. # 训练模型
    95. # 1、设置动态学习率
    96. # 设置初始学习率
    97. initial_learning_rate = 1e-4
    98. lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    99. initial_learning_rate,
    100. decay_steps=60, # 敲黑板!!!这里是指 steps,不是指epochs
    101. decay_rate=0.96, # lr经过一次衰减就会变成 decay_rate*lr
    102. staircase=True)
    103. # 将指数衰减学习率送入优化器
    104. optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    105. model.compile(optimizer=optimizer,
    106. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    107. metrics=['accuracy'])
    108. # 损失函数
    109. # 调用方式1:
    110. model.compile(optimizer="adam",
    111. loss='categorical_crossentropy',
    112. metrics=['accuracy'])
    113. # 调用方式2:
    114. # model.compile(optimizer="adam",
    115. # loss=tf.keras.losses.CategoricalCrossentropy(),
    116. # metrics=['accuracy'])
    117. # sparse_categorical_crossentropy(稀疏性多分类的对数损失函数)
    118. # 调用方式1:
    119. model.compile(optimizer="adam",
    120. loss='categorical_crossentropy',
    121. metrics=['accuracy'])
    122. # ↑↑↑↑这里出现报错,需要将 sparse_categorical_crossentropy 改成→ categorical_crossentropy↑↑
    123. # 调用方式2:
    124. # model.compile(optimizer="adam",
    125. # loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    126. # metrics=['accuracy'])
    127. # 函数原型
    128. tf.keras.losses.SparseCategoricalCrossentropy(
    129. from_logits=False,
    130. reduction=losses_utils.ReductionV2.AUTO,
    131. name='sparse_categorical_crossentropy'
    132. )
    133. epochs = 100
    134. # 保存最佳模型参数
    135. checkpointer = ModelCheckpoint('best_model.h5',
    136. monitor='val_accuracy',
    137. verbose=1,
    138. save_best_only=True,
    139. save_weights_only=True)
    140. # 设置早停
    141. earlystopper = EarlyStopping(monitor='val_accuracy',
    142. min_delta=0.001,
    143. patience=20,
    144. verbose=1)
    145. # 网络模型训练
    146. history = model.fit(train_ds,
    147. validation_data=val_ds,
    148. epochs=epochs,
    149. callbacks=[checkpointer, earlystopper])
    150. # 模型评估
    151. acc = history.history['accuracy']
    152. val_acc = history.history['val_accuracy']
    153. loss = history.history['loss']
    154. val_loss = history.history['val_loss']
    155. epochs_range = range(len(loss))
    156. plt.figure(figsize=(12, 4))
    157. plt.subplot(1, 2, 1)
    158. plt.plot(epochs_range, acc, label='Training Accuracy')
    159. plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    160. plt.legend(loc='lower right')
    161. plt.title('Training and Validation Accuracy')
    162. plt.subplot(1, 2, 2)
    163. plt.plot(epochs_range, loss, label='Training Loss')
    164. plt.plot(epochs_range, val_loss, label='Validation Loss')
    165. plt.legend(loc='upper right')
    166. plt.title('Training and Validation Loss')
    167. plt.show()
    168. # 指定图片进行预测
    169. # 加载效果最好的模型权重
    170. model.load_weights('best_model.h5')
    171. from PIL import Image
    172. import numpy as np
    173. img = Image.open("/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p06_vgg16/data/Jennifer Lawrence/003_963a3627.jpg") #这里选择你需要预测的图片
    174. image = tf.image.resize(img, [img_height, img_width])
    175. img_array = tf.expand_dims(image, 0)
    176. predictions = model.predict(img_array) # 这里选用你已经训练好的模型
    177. print("预测结果为:",class_names[np.argmax(predictions)])

    3、运行过程+结果

    【查看图片】

    【模型运行过程---第21epoch就早停了】

    【训练精度、损失-----显然结果很很差】

    4、遇到的问题

    ① 在运行代码的时候遇到报错:

    错误:Graph execution error: Detected at node 'sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits' defined at (most recent call last):

    出现这个问题来自我们使用的损失函数。

    1. model.compile(optimizer="adam",
    2. loss='sparse_categorical_crossentropy',
    3. metrics=['accuracy'])

    解决办法:

    将损失函数里面的loss='sparse_categorical_crossentropy' 改成 'categorical_crossentropy',即可解决报错问题。

    关于sparse_categorical_crossentropy和categorical_crossentropy的更多细节,详细参考这篇博文:交叉熵损失_多分类交叉熵损失函数-CSDN博客

    5、小结

    原始模型,跑出来效果很差很差!!!

    (1)将原来的Adam优化器换成SGD优化器,效果如下:

    (2)后续再补充,最近在写结课论文,有些忙。

  • 相关阅读:
    用HTML+CSS做一个漂亮简单的个人网页——樱木花道篮球3个页面 学生个人网页设计作品 学生个人网页模板 简单个人主页
    YOLO目标检测——交通标志数据集+已标注voc和yolo格式标签下载分享
    用两个栈实现一个队列
    2652. 倍数求和
    基于Springboot高校毕业生招聘管理信息系统-计算机毕设 附源码 28393
    干货!解决IDEA中项目出现cannot resolve method ‘XXXXX(java.lang.String)’问题
    nginx 发布vue项目 页面刷新出现404问题
    持安科技入选数说安全《2023中国网络安全市场年度报告》
    【vue】0到1的常规vue3项目起步
    2022债市波动分析
  • 原文地址:https://blog.csdn.net/Miss_liangrm/article/details/139946653