• 深度学习零基础学习之路——第三章 数据可视化TensorBoard和TorchVision的介绍


    Python深度学习入门

    第一章 Python深度学习入门之环境软件配置
    第二章 Python深度学习入门之数据处理Dataset的使用
    第三章 数据可视化TensorBoard和TochVision的使用
    第四章 UNet-Family中Unet、Unet++和Unet3+的简介
    第五章 个人数据集的制作



    前言

      我们通过上一节的学习知道了如何使用DataSet类读取文件夹中的图片数据,并对数据进行整理排序。接下来我们就需要将这些数据进行可视化显示,方便对数据进行分析和对比。我在学习之前对于数据的可视化会使用Python的画图库matlibplot,通过这个库我们可以绘制很多图形,比如折线图、柱状图等等。但是这个画图库太麻烦了,需要自己编写代码,今天我们要聊的TensorBoard不需要编写代码就可以显示数据,并支持多种类型的数据,比如离散数据、图片数据等等。


    一、TensorBoard是什么?

      Tensorboard是TensorFlow提供的一组可视化工具,可以帮助开发者方便的理解、调试、优化TensorFlow 程序。为后续的学习提供了巨大的方便。

      说到调试,我们一般的调试都是通过Pycharm在某行代码处打个断点,然后一步一步执行,并关注Pycharm的变量池,找到问题点。关键是有的时候我们的程序并没有错误,只是训练结果并不是很满意,想优化程序,但是在运行过程中我们不知道每一步得到的结果是什么,变量池的数据也不是很直观。这时TensorBoard就可以解决以上问题。TensorBoard可以展示我们程序的训练过程结果,更直观的关注代码的效果。
    在这里插入图片描述

    二、TensorBoard的使用

    1、引入库

    from torch.utils.tensorboard import SummaryWriter
    from PIL import Image
    import numpy as np
    import os
    

    2、读入数据

      我们利用上一章下载的蚂蚁和蜜蜂的数据作为数据集,使用TensorBoard进行显示。

    image_path = 'H:\\learn\\hymenoptera_data\\train\\ants_image' # 此处为数据集的文件夹路径
    image_name = os.listdir(image_path)
    

    该处我们从文件夹中获取到了蚂蚁数据集的图片名称,因为我们要显示图片,所以我们要将每一张图片放入TensorBoard中。

    3、加载数据

      首先我们需要通过SummaryWriter创建一个存储数据的文件夹,然后再向这个文件夹中添加需要显示的数据和图片。

    writer = SummaryWriter('logs')		# 在当前目录下创建logs文件夹,存储需要显示的数据
    #  因为我们的数据都存储在一个列表中,因此我们需要遍历列表将每张图片给加载进去,这里我们测试10张图片
    for i in range(10):
    	image = Image.open(os.path.join(image_path, image_name[i])) # 读取图片
    	image_array = np.array(image)	# 将image图片的类型转换成TensorBoard需要的类型
    	writer.add_image('test', image_array, i, dataformats='HWC')
    
    # 我们再测试一下显示折线图
    for i in range(100):
    	writer.add_scalar("y = 2x",3 * i, i)
    
    writer.close() # 最后记得关闭writer
    

      然后运行代码,可以看到程序运行完后会终止。然后我们打开Terminal,进入到pytorch环境运行以下代码

    (pytorch) PS H:> tensorboard --logdir=‘logs’ --port=‘6006’

      这个 --port=‘6006’ 这条指令可以修改界面显示的端口号,默认是6006。

    在这里插入图片描述
      出现上面的输出,就说明TensorBoard启动成功,我们点击链接或者复制用浏览器打开即可显示数据内容。

    在这里插入图片描述

    4、方法详解

    接下来我们来解析上面代码用到的方法:

    1. add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats=“CHW”)
      tag:该参数是设置图表的标题title,区别于其他图表。
      imge_tensor: 该参数是设置需要显示的图片,并且要注意的是图片的尺寸格式要求,一般是tensor类型的图片。
      global_step:该参数类似于设置每一张图片的角标,就是图片演化步数。
      walltime:这个参数用的比较少,我们用默认值就好了。
      dataformats:设置数据图片格式,默认是’CHW’,意思是(C:通道数,H:高,W:宽)。
    2. add_scalar(self,tag,scalar_value,global_step=None,walltime=None,new_style=False,double_precision=False)
      tag:和add_image()方法的tag是一个意思,设置图表的标题。
      scalar_value:是图表每一步的值,可以看成是折线图的Y值。
      global_step:是图表训练的步数,可以看成是折线图的X值。
      walltime、new_style、double_prcision:这几个参数一般也用不着,使用默认值即可。

      我们通过上面TensorBoard显示的图来理解一下这几个参数的含义。

    在这里插入图片描述

    三、torchvision的使用

      torchvision是Pytorch的一个图形库,在后续深度学习的学习过程中相当重要,因此我们需要着重探究torchvision下的包。

    • torchvision.transforms:这个包主要用于一些图形变换,例如图片格式的转换、图片的正规化等等,这个用的比较多。
    • torchvision.dataset: 这个包用于加载数据集的,并且可以通过这个包去下载比较常用的数据集,比如CIFAR10、MNIST等等
    • torchvision.models: 该模块包含了一些已经定义好的模型,例如:AlexNet, VGG, ResNet 和 Densenet等等

    1、torchvision.transforms的使用

      我们通过查看transforms的源代码可以看到它提供了很多图形变换的类,我们可以通过Pycharm的Structure来查看他所有的方法。
    在这里插入图片描述

    • 其中ToTensor类是比较重要的,这个类是用来将一张‘PIL Image’类型或者‘numpy.ndarray’类型的图片转成tensor类型的图片,在后续的图片处理过程中很多地方都需要tensor类型的图片,例如上面使用TensorBoard显示图片的时候就需要tensor类型的照片。
    • Compose类顾名思义就是组合变换,他可以对一张图片进行多种变换,看一下源代码举的例子:
    Example:
            >>> transforms.Compose([
            >>>     transforms.CenterCrop(10),					#首先对图片进行中心裁剪
            >>>     transforms.PILToTensor(),					#然后将图片转换成tensor类型的
            >>>     transforms.ConvertImageDtype(torch.float),	#再对图片进行类型转换
            >>> ])
    
    • Normalize类是对一张tensor类型的图片进行正则化的,因此输入图片的类型必须是tensor类型的
    • Resize类是对输入图片进行尺寸的修改,这里好像是对输入图片没有要求。

      torchvision包中常用的就是上面这几个类,在后续的学习中我们会经常用到,因此我们要熟练掌握,也可以经常去看看torchvision的官网:https://pytorch.org/vision/stable/index.html

    2、torchvision.dataset的使用

      torchvision.dataset包我们主要是用来下载数据集的,在之前的学习中,我们都是通过数据集链接在浏览器中下载的,比如第二章中的蚂蚁和蜜蜂的数据集。今天通过dataset包我们就可以下载比较常用的数据集,比如CIFAR10、MNIST等等。
    在这里插入图片描述

    class CIFAR10(VisionDataset):
        """`CIFAR10 `_ Dataset.
    
        Args:
            root (string): Root directory of dataset where directory
                ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
            train (bool, optional): If True, creates dataset from training set, otherwise
                creates from test set.
            transform (callable, optional): A function/transform that takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            download (bool, optional): If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
    
        """
        def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
        )
    

    通过上面的源码我们可以知道下载CIFAR10数据集需要传入5个参数:

    1. root:这个参数就是数据集下载到本地的哪个文件夹下,就是存放数据集的地址路径。
    2. train:这个参数是用来选择下载的是训练集(True)还是测试集(False),默认是训练集。
    3. transform:这个参数是将下载的数据集根据transform进行图形变换,就是我们上面学的torchvision.transform类。
    4. target_transform:这个参数用的比较少,一般使用默认值,我觉得应该这个参数的意思是是否需要在target上进行图形变换。
    5. download:这个参数是用来决定是否从网上下载,如果已下载了就不会再次下载了。一般我们会将它设置为True。

      接下来我们就通过代码来测试一下通过torchvision.dataset类从网上下载CIFAR10数据集。

    在这里插入图片描述
      通过上面的截图我们可以看到数据集已经下载下来了,在同一目录下创建了一个cifar10文件夹,虽然下载速度比较慢,但是这样下载这些数据集比较方便。我们也可以通过这种方式去下载其他的数据集。

    总结

      以上的TensorBoard和TorchVision是Pytorch中比较重要的库,在后续深度学习的学习过程中会经常用到这些知识点。因此我们还是需要反复的去使用这些库函数和方法,并且要多去看官方文档,官方文档才是最具权威性的学习资料。然后大家有什么问题可以在下方评论留言,大家一起学习进步!!

  • 相关阅读:
    Spring框架系列(7) - Spring IOC实现原理详解之IOC初始化流程
    基于ARM的字符串拷贝实验(嵌入式系统)
    智能四向穿梭车机器人系统库架一体解决方案|四向车密集型智能自动化立体库立体货架供应
    2022!影响百万用户金融信用评分,Equifax被告上法庭,罪魁祸首——『数据漂移』!
    C语言和Java中RSA算法一些注意事项
    深入剖析多层双向LSTM的输入输出
    Spring Security认证源码解析(示意图)
    实在智能AI+RPA:引领数字化转型的超自动化智能体
    asp.net core 生命周期
    WebGIS外包开发流程
  • 原文地址:https://blog.csdn.net/Monkey_King_GL/article/details/127014473