• torch 的数据加载 Datasets & DataLoaders


    点赞收藏关注!
    如需要转载,请注明出处!

    torch的模型加载有两种方式:
    Datasets & DataLoaders

    torch本身可以提供两数据加载函数
    torch.utils.data.DataLoader()和torch.utils.data.Dataset()

    其中torch.utils.data 是PyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

    加载函数后可以实现数据集代码与模型训练代码分离,以获得更好的可读性和模块化
    Dataset定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。制作自己的数据集必须要实现三个函数:

    • init()函数在实例化Dataset对象时运行一次
    • len()函数返回数据集中样本的数量
    • getitem()函数的作用是:从给定索引 index,从数据集中加载并返回一个样本并将其转换为张量。
    import torch
    from torch.utils.data import Dataset
    
    class CreateDataset(Dataset):
        def __init__(self, data):
            self.data = data
    
        def __getitem__(self, index):
            # 根据索引获取样本
            return self.data[index]
    
        def __len__(self):
            # 返回数据集大小
            return len(self.data)
    
    # 创建数据集对象
    data = [[255,255,255],[255,245,235],[225,226,227]]
    dataset = CreateDataset(data)
    
    # 根据索引获取样本
    sample = dataset[1]
    print(sample)
    # [255,245,235]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    数据处理模块其他的功能:

    • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
    from torch.utils.data import TensorDataset
    import torch
    from torch.utils.data import DataLoader
     
    a = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99]])
    b = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
    train_ids = TensorDataset(a, b)
    
    for x_train, y_label in train_ids:
        print(x_train, y_label)
    
    
    
    ##############################################################################################
    #tensor([11, 22, 33]) tensor(0)
    #tensor([44, 55, 66]) tensor(1)
    #tensor([77, 88, 99]) tensor(2)
    #tensor([11, 22, 33]) tensor(0)
    #tensor([44, 55, 66]) tensor(1)
    #tensor([77, 88, 99]) tensor(2)
    #tensor([11, 22, 33]) tensor(0)
    #tensor([44, 55, 66]) tensor(1)
    #tensor([77, 88, 99]) tensor(2)
    #tensor([11, 22, 33]) tensor(0)
    #tensor([44, 55, 66]) tensor(1)
    #tensor([77, 88, 99]) tensor(2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。

    • Subset: 数据集的子集类,用于从数据集中选择指定的样本。定义了一个子集的索引列表indices,它可以根据需要进行调整。然后,我们使用Subset类创建了一个名为subset的子集对象,它接受两个参数:原始数据集dataset和子集的索引列表indices。

    indices = [0, 2, 4]  # 子集的索引列表
    subset = Subset(dataset, indices)
    
    • 1
    • 2
    • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
    import torch
    import torchvision
    # from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    # 准备数据集
    from torch import nn
    from torch.utils.data import DataLoader
    
    # 定义训练的设备
    device = torch.device("cuda")
    #读取数据
    data_transform = transforms.Compose([
        transforms.Resize(size=(224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
    ])
    full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
    # length 数据集总长度
    full_data_size = len(full_dataset)
    print("总数据集的长度为:{}".format(full_data_size))
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    
    #在这里
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    #在这里
    
    train_data_size = len(train_dataset)
    test_data_size = len(test_dataset)
    # 如果train_data_size=10, 训练数据集的长度为:10
    print("训练数据集的长度为:{}".format(train_data_size))
    print("测试数据集的长度为:{}".format(test_data_size))
    >>>
    总数据集的长度为:100
    训练数据集的长度为:80
    测试数据集的长度为:20
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
    #链接两个数据集
    dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset]) 
    #导入数据集
    loader = torch.utils.data.DataLoader( 
            dataset=dataset, 
            batch_size=cfg.batch_size, 
            shuffle=True, 
            drop_last=True, 
            num_workers=cfg.n_workers)
         
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • get_worker_info: 获取当前数据加载器所在的进程信息。torch.utils.data.get_worker_info() 在worker进程中返回各种有用的信息(包括worker id、dataset replica、initial seed等),在main进程中返回None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。分片数据集特别有用。

    如有帮助,点赞收藏关注!

  • 相关阅读:
    Python Turtle Graphics 绘制I Love You字符
    MyBatis
    Chai的入门
    退运险业务及系统架构演进史
    安卓Android 架构模式及UI布局设计
    在Jupyter里面安装torch的历程
    ICC2: keepout、spacing_rules、clock_cell_spacing
    Flutter 剪裁(Clip)
    vue3-admin商品管理后台项目(后台布局layout布局开发二)
    版本控制系统:Perforce Helix Core -2023
  • 原文地址:https://blog.csdn.net/weixin_42362399/article/details/134524424