• RK3568笔记四:基于TensorFlow花卉图像分类部署


    若该文为原创文章,转载请注明原文出处。

    基于正点原子的ATK-DLRK3568部署测试。

    花卉图像分类任务,使用使用 tf.keras.Sequential 模型,简单构建模型,然后转换成 RKNN 模型部署到ATK-DLRK3568板子上。

    在 PC 使用 Windows 系统安装 tensorflow,并创建虚拟环境进行训练,然后切换到VM下的RK3568环境,使用rknn-toolkit2把模型转成rknn模型部署到RK3568板子上测试。

    一、介绍

           TensorFlow 是一个基于数据流编程(dataflow programming)的符号数学系统,被广泛应用于机器学习(machine learning)算法的编程实现,其前身是谷歌的神经网络算法库 DistBelief。

    使用 tf.keras.Sequential 模型对花卉图像进行分类。

    二、环境搭建

    1、创建虚拟环境

     conda create -n tensorflow_env python=3.8 -y

    2、激活环境

    conda activate tensorflow_env

    3、安装环境

    1. pip install numpy
    2. pip install tensorflow
    3. pip install pillow

    三、训练

    1、下载数据集

    https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

    数据集不好下载,自行处理。

    2、训练

    tensorflow_classification.py

    1. import numpy as np
    2. import tensorflow as tf
    3. from tensorflow import keras
    4. from tensorflow.keras import layers
    5. from tensorflow.keras.models import Sequential
    6. # 获取
    7. import pathlib
    8. #dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    9. #data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    10. data_dir = './flower_photos'
    11. data_dir = pathlib.Path(data_dir)
    12. batch_size = 32
    13. img_height = 180
    14. img_width = 180
    15. # 划分数据
    16. train_ds = tf.keras.utils.image_dataset_from_directory(
    17. data_dir,
    18. validation_split=0.2,
    19. subset="training",
    20. seed=123,
    21. image_size=(img_height, img_width),
    22. batch_size=batch_size)
    23. val_ds = tf.keras.utils.image_dataset_from_directory(
    24. data_dir,
    25. validation_split=0.2,
    26. subset="validation",
    27. seed=123,
    28. image_size=(img_height, img_width),
    29. batch_size=batch_size)
    30. class_names = train_ds.class_names
    31. #print(class_names)
    32. # 处理数据
    33. normalization_layer = layers.Rescaling(1./255)
    34. train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
    35. val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
    36. num_classes = len(class_names)
    37. data_augmentation = keras.Sequential(
    38. [
    39. layers.RandomFlip("horizontal",
    40. input_shape=(img_height,
    41. img_width,
    42. 3)),
    43. layers.RandomRotation(0.1),
    44. layers.RandomZoom(0.1),
    45. ]
    46. )
    47. model = Sequential([
    48. data_augmentation,
    49. layers.Conv2D(16, 3, padding='same', activation='relu'),
    50. layers.MaxPooling2D(),
    51. layers.Conv2D(32, 3, padding='same', activation='relu'),
    52. layers.MaxPooling2D(),
    53. layers.Conv2D(64, 3, padding='same', activation='relu'),
    54. layers.MaxPooling2D(),
    55. layers.Dropout(0.2),
    56. layers.Flatten(),
    57. layers.Dense(128, activation='relu'),
    58. layers.Dense(num_classes, name="outputs")
    59. ])
    60. model.compile(optimizer='adam',
    61. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    62. metrics=['accuracy'])
    63. model.summary()
    64. # 训练模型
    65. epochs=15
    66. history = model.fit(
    67. train_ds,
    68. validation_data=val_ds,
    69. epochs=epochs,
    70. )
    71. # 测试模型
    72. #sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
    73. #sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)
    74. sunflower_path = './test_180.jpg'
    75. img = tf.keras.utils.load_img(
    76. sunflower_path, target_size=(img_height, img_width)
    77. )
    78. img_array = tf.keras.utils.img_to_array(img)
    79. img_array = tf.expand_dims(img_array, 0) # Create a batch
    80. predictions = model.predict(img_array)
    81. score = tf.nn.softmax(predictions[0])
    82. print(
    83. "This image most likely belongs to {} with a {:.2f} percent confidence."
    84. .format(class_names[np.argmax(score)], 100 * np.max(score))
    85. )
    86. # Convert the model.
    87. converter = tf.lite.TFLiteConverter.from_keras_model(model)
    88. tflite_model = converter.convert()
    89. # Save the model.
    90. with open('model.tflite', 'wb') as f:
    91. f.write(tflite_model)

    代码有点需要注意,代码屏蔽了下载的功能,所以需要预先下载数据集,如果没有下载数据集,就需要把下载的代码开启。

    1. #dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    2. #data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)

    执行下面命令开始训练:

    python tensorflow_classification.py

    等待一会,会生成model.tflite模型文件。

    四、RKNN模型转换

    转换代码通过下面代码:

    rknn_transfer.py

    1. import numpy as np
    2. import cv2
    3. from rknn.api import RKNN
    4. import tensorflow as tf
    5. img_height = 180
    6. img_width = 180
    7. IMG_PATH = 'test.jpg'
    8. class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    9. if __name__ == '__main__':
    10. # Create RKNN object
    11. #rknn = RKNN(verbose='Debug')
    12. rknn = RKNN()
    13. # Pre-process config
    14. print('--> Config model')
    15. rknn.config(mean_values=[0, 0, 0], std_values=[255, 255, 255], target_platform='rk3568')
    16. print('done')
    17. # Load model
    18. print('--> Loading model')
    19. ret = rknn.load_tflite(model='model.tflite')
    20. if ret != 0:
    21. print('Load model failed!')
    22. exit(ret)
    23. print('done')
    24. # Build model
    25. print('--> Building model')
    26. ret = rknn.build(do_quantization=False)
    27. #ret = rknn.build(do_quantization=True,dataset='./dataset.txt')
    28. if ret != 0:
    29. print('Build model failed!')
    30. exit(ret)
    31. print('done')
    32. # Export rknn model
    33. print('--> Export rknn model')
    34. ret = rknn.export_rknn('./model.rknn')
    35. if ret != 0:
    36. print('Export rknn model failed!')
    37. exit(ret)
    38. print('done')
    39. #Init runtime environment
    40. print('--> Init runtime environment')
    41. ret = rknn.init_runtime()
    42. # if ret != 0:
    43. # print('Init runtime environment failed!')
    44. # exit(ret)
    45. print('done')
    46. img = cv2.imread(IMG_PATH)
    47. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    48. img = cv2.resize(img,(180,180))
    49. img = np.expand_dims(img, 0)
    50. #print('--> Accuracy analysis')
    51. #rknn.accuracy_analysis(inputs=['./test.jpg'])
    52. #print('done')
    53. print('--> Running model')
    54. outputs = rknn.inference(inputs=[img])
    55. print(outputs)
    56. outputs = tf.nn.softmax(outputs)
    57. print(outputs)
    58. print(
    59. "This image most likely belongs to {} with a {:.2f} percent confidence."
    60. .format(class_names[np.argmax(outputs)], 100 * np.max(outputs))
    61. )
    62. #print("图像预测是:", class_names[np.argmax(outputs)])
    63. print('--> done')
    64. rknn.release()

    运行后会生成RKNN模型

    五、部署

    把rknnlite_inference.py和图片,及模型model.rknn拷贝到开发板上,终端运行即可。

    rknnlite_inference.py源码:

    1. import numpy as np
    2. import cv2
    3. from rknnlite.api import RKNNLite
    4. IMG_PATH = 'test.jpg'
    5. RKNN_MODEL = 'model.rknn'
    6. img_height = 180
    7. img_width = 180
    8. class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    9. # Create RKNN object
    10. rknn_lite = RKNNLite()
    11. # load RKNN model
    12. print('--> Load RKNN model')
    13. ret = rknn_lite.load_rknn(RKNN_MODEL)
    14. if ret != 0:
    15. print('Load RKNN model failed')
    16. exit(ret)
    17. print('done')
    18. # Init runtime environment
    19. print('--> Init runtime environment')
    20. ret = rknn_lite.init_runtime()
    21. if ret != 0:
    22. print('Init runtime environment failed!')
    23. exit(ret)
    24. print('done')
    25. # load image
    26. img = cv2.imread(IMG_PATH)
    27. img = cv2.resize(img,(180,180))
    28. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    29. img = np.expand_dims(img, 0)
    30. # runing model
    31. print('--> Running model')
    32. outputs = rknn_lite.inference(inputs=[img])
    33. print("result: ", outputs)
    34. print(
    35. "This image most likely belongs to {}."
    36. .format(class_names[np.argmax(outputs)])
    37. )
    38. rknn_lite.release()

    终端中执行:python rknnlite_inference.py

    结果识别为sunflowers。

    如有侵权,或需要完整代码,请及时联系博主。

  • 相关阅读:
    束测后台实操文档2-OpenWrt
    JS 数据结构:队列
    离散小波变换(概念与应用)
    操作系统 day10(调度的概念、层次、七状态模型)
    跟着cherno手搓游戏引擎【27】升级2DRenderer(添加旋转)
    阿里P8高级专家,耗时多年整理SpringBoot指南文档
    函数的参数
    java-后端调用第三方接口返回图片流给前端
    基于Arduino开发板的太阳能灯光控制器
    java通过minio下载pdf附件
  • 原文地址:https://blog.csdn.net/weixin_38807927/article/details/133951001