• tensorflow跑手写体实验


    目录

    1、环境条件

    2、代码实现

    3、总结


    1、环境条件

    1. pycharm编译器
    2. python3.0环境
    3. tensorflow2.0依赖
    4. matplotlib依赖(用于画图)

    2、代码实现

    1. import tensorflow as tf
    2. from tensorflow.keras.datasets import mnist
    3. from tensorflow.keras.preprocessing import image
    4. import numpy as np
    5. import matplotlib.pyplot as plt
    6. # 加载并预处理 MNIST 数据集
    7. (x_train, y_train), (x_test, y_test) = mnist.load_data()
    8. x_train, x_test = x_train / 255.0, x_test / 255.0
    9. print(x_train)
    10. print(x_test)
    11. # 构建 LeNet-5 模型
    12. model = tf.keras.models.Sequential([
    13. tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),
    14. tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    15. tf.keras.layers.Conv2D(64, kernel_size=(5, 5), activation='relu'),
    16. tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    17. tf.keras.layers.Flatten(),
    18. tf.keras.layers.Dense(120, activation='relu'),
    19. tf.keras.layers.Dense(84, activation='relu'),
    20. tf.keras.layers.Dense(10, activation='softmax')
    21. ])
    22. model.compile(optimizer='adam',
    23. loss='sparse_categorical_crossentropy',
    24. metrics=['accuracy'])
    25. # 重塑数据以适应模型
    26. x_train = x_train.reshape(-1, 28, 28, 1)
    27. x_test = x_test.reshape(-1, 28, 28, 1)
    28. # 训练模型
    29. model.fit(x_train, y_train, epochs=5)
    30. # 评估模型
    31. test_loss, test_acc = model.evaluate(x_test, y_test)
    32. print(f'测试准确率: {test_acc}')
    33. # 保存模型
    34. model.save('lenet-5_model.h5')
    35. print('模型已保存至 lenet-5_model.h5')
    36. # 加载模型
    37. loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
    38. print('模型已加载')
    39. # 加载并预处理本地图片
    40. def load_and_preprocess_image(image_path):
    41. img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    42. img_array = image.img_to_array(img)
    43. img_array = img_array / 255.0 # 归一化
    44. img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
    45. return img_array
    46. # 预测本地图片
    47. image_path = '4.png' # 替换为你的本地图片路径
    48. img_array = load_and_preprocess_image(image_path)
    49. # 使用加载的模型进行预测
    50. predictions = loaded_model.predict(img_array)
    51. predicted_label = np.argmax(predictions)
    52. # 打印预测结果
    53. print(f'预测结果: {predicted_label}')
    54. # 显示图片
    55. plt.imshow(img_array[0, :, :, 0], cmap='gray')
    56. plt.title(f'预测结果: {predicted_label}')
    57. plt.show()

            解释:image_path为本地图片路径,通过model.save()方法实现模型的保存功能,下次预测使用的时候直接使用训练好的模型即可。下面将给出可直接预测的代码:

    1. import tensorflow as tf
    2. from tensorflow.keras.preprocessing import image
    3. import numpy as np
    4. import matplotlib.pyplot as plt
    5. from matplotlib.font_manager import FontProperties
    6. # 加载模型
    7. loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
    8. print('模型已加载')
    9. # 加载并预处理本地图片
    10. def load_and_preprocess_image(image_path):
    11. img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    12. img_array = image.img_to_array(img)
    13. img_array = img_array / 255.0 # 归一化
    14. img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
    15. return img_array
    16. # 预测本地图片
    17. image_path = '7.png' # 替换为你的本地图片路径
    18. img_array = load_and_preprocess_image(image_path)
    19. # 使用加载的模型进行预测
    20. predictions = loaded_model.predict(img_array)
    21. predicted_label = np.argmax(predictions)
    22. # 打印预测结果
    23. print(f'预测结果: {predicted_label}')
    24. # 设置支持中文的字体
    25. font_path = "C:/Windows/Fonts/simhei.ttf" # 替换为你的字体路径,例如 SimHei.ttf
    26. font_prop = FontProperties(fname=font_path)
    27. # 显示图片
    28. plt.imshow(img_array[0, :, :, 0], cmap='gray')
    29. plt.title(f'预测结果: {predicted_label}', fontproperties=font_prop)
    30. plt.show()

    3、总结

            使用tensorflow完成手写体图片的识别功能,其主要难点在安装依赖环境,其他的都是比较简单的事情。

    学习之所以会想睡觉,是因为那是梦开始的地方。
    ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)
                                                                                                            ------不写代码不会凸的小刘

  • 相关阅读:
    JavaEE:网络初识
    Avalonia开发(一)环境搭建
    LNMP架构
    Leetcode581. 最短无序连续子数组
    如果你会玩这4个自媒体运营工具,副业收入6000+很轻松
    VSC/SMC(十五)——基于模糊逼近的积分滑模控制
    jupyter使用教程及python语法基础
    go: 如何编写一个正确的udp服务端
    【学习笔记】正则表达式及其在VS Code,Word中查找替换的应用
    shell入门第6课 环境变量
  • 原文地址:https://blog.csdn.net/qq_40834643/article/details/140102171