• torchvision.transforms 数据预处理:ToTensor()


    ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。

    1、ToTensor() 函数的作用

    必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor,其实它的功能不止于此

    看一下 ToTensor() 函数的源码:

    class ToTensor:
        """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
    
        Converts a PIL Image or numpy.ndarray (H x W x C) in the range
        [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
        or if the numpy.ndarray has dtype = np.uint8
    
        In the other cases, tensors are returned without scaling.
    
        .. note::
            Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
            transforming target image masks. See the `references`_ for implementing the transforms for image masks.
    
        .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
        """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    大意是:

    (1)将 PIL Image 或 numpy.ndarray 转为 tensor

    (2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。

    (3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。

    2、读取图像时 PIL 和 opencv 的选择

    在自己建立 dataset 迭代器时,一般操作是检索数据集图像的路径,然后使用 PIL 库或 opencv库读取图片路径。

    2.1 使用 PIL

    import numpy as np
    from PIL import Image
    
    filePath="Dataset/FFHQ/00000.png"
    img1=Image.open(filePath)
    print(f"img1 = {img1}")    
    # img1 = 
    
    img2 = np.array(img1)
    print(f"img2 = {img2}")
    
    """
    img2 = [[[  0 130 146]
      [  0 128 144]
      [  0 125 141]
      ...
      [133 162 164]
      [133 157 159]
      [134 157 163]]]
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    可以看到,使用 PIL.Image 读取的图像是一种 PIL 类,mode=RGB,要想获得图像的像素值还需要将其转为 np.array 格式。

    而 opencv 可以直接将图像读取为 np.array 格式,因此首选 opencv 。

    2.2 使用 opencv

    import cv2
    
    filePath="Dataset/FFHQ/00000.png"
    img=cv2.imread(filePath)
    print(f"img.shape = {img.shape}")     # img.shape = (128, 128, 3)
    print(f"img = {img}")     # img.dtype = uint8
    
    """
    img = [[[146 130   0]
      [144 128   0]
      [141 125   0]
      ...
      [164 162 133]
      [159 157 133]
      [163 157 134]]]
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    仔细对比PIL 和 opencv 的输出结果可以发现,PIL默认输出的图片格式为 RGB,而opencv输出的是BGR格式。

    使用opencv读取的图像是[H,W,C]大小的,数据格式是 np.uint8 ,经过 ToTensor() 会进行归一化。而其他的数据类型(如 np.int8)经过 ToTensor() 数值不变,不进行归一化,后面会详细讲述。并且经过ToTensor()后图像格式变为 [C,H,W]。

    3、ToTensor() 的使用

    3.1 关键知识点

    不管是使用 PLT还是opencv,最终得到都是 np.array类型。因此:

    ToTensor() 是将 np.array 的数据 转为 tensor 格式

    这里一定要明确几个点:

    (1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。
    (2)np.array 浮点型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。
    (3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8
        经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。
    (4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。
    (5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。
    (6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    3.2 代码示例

    下面通过代码熟悉 ToTensor() 的使用,分别看一下 np.uint8 和 非 np.uint8 类型的 np.array 经过 ToTensor() 之后的输出。

    (1) np.uint8 类型

    import numpy as np
    from torchvision import transforms
    
    data = np.array([
        [0, 5, 10, 20, 0],
        [255, 125, 180, 255, 196]
    ], dtype=np.uint8)
    
    tensor = transforms.ToTensor()(data)
    print(tensor)
    """
    tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
             [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    (2)非 np.uint8 类型

    import numpy as np
    from torchvision import transforms
    
    data = np.array([
        [0, 5, 10, 20, 0],
        [255, 125, 180, 255, 196]
    ])      # data.dtype = int32
    
    tensor = transforms.ToTensor()(data)
    print(tensor)
    """
    tensor([[[  0,   5,  10,  20,   0],
             [255, 125, 180, 255, 196]]], dtype=torch.int32)
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
  • 相关阅读:
    LeetCode每日一练 —— 160. 相交链表
    领英-如何合并或注销重复领英帐号及利用领英高效开发客户技巧
    探究WPF中文字模糊的问题:TextOptions的用法
    【前端】NodeJS:包管理工具
    【Python机器学习】零基础掌握ShrunkCovariance协方差估计
    关于hash表的一些练习题
    代码随想录刷题|完全背包理论基础 LeetCode 518. 零钱兑换II 377. 组合总和 Ⅳ
    DecorView和android.R.id.content的关系
    zookeeper节点数据类型介绍及集群搭建
    从零开始搭建仿抖音短视频APP-构建后端项目
  • 原文地址:https://blog.csdn.net/qq_43799400/article/details/127785104