• Pytorch中DataLoader的使用方法


    在Pytorch中,torch.utils.data中的Dataset与DataLoader是处理数据集的两个函数,用来处理加载数据集。通常情况下,使用的关键在于构建dataset类。

    一:dataset类构建。

    在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。

    1. class dataset:
    2. def __init__(self,...):
    3. ...
    4. def __len__(self,...):
    5. return n
    6. def __getitem__(self,item):
    7. return data[item]

    正常情况下,该数据集是要继承Pytorch中Dataset类的,但实际操作中,即使不继承,数据集类构建后仍可以用Dataloader()加载的。

    在dataset类中,__len__(self)返回数据集中数据个数,__getitem__(self,item)表示每次返回第item条数据。

    二:DataLoader使用

    在构建dataset类后,即可使用DataLoader加载。DataLoader中常用参数如下:

    1.dataset:需要载入的数据集,如前面构造的dataset类。

    2.batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个batch进行训练。

    3.shuffle:是否在打乱数据集样本顺序。True为打乱,False反之。

    4.drop_last:是否舍去最后一个batch的数据(很多情况下数据总数N与batch size不整除,导致最后一个batch不为batch size)。True为舍去,False反之。

    三:举例

    兔兔以指标为1,数据个数为100的数据为例。

    1. import torch
    2. from torch.utils.data import DataLoader
    3. class dataset:
    4. def __init__(self):
    5. self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
    6. self.y=(torch.sin(self.x)+1)/2
    7. def __len__(self):
    8. return 100
    9. def __getitem__(self, item):
    10. return self.x[item],self.y[item]
    11. data=DataLoader(dataset(),batch_size=10,shuffle=True)
    12. for batch in data:
    13. print(batch)

    当然,利用这个数据集可以进行简单的神经网络训练。

    1. from torch import nn
    2. data=DataLoader(dataset(),batch_size=10,shuffle=True)
    3. bp=nn.Sequential(nn.Linear(1,5),
    4. nn.Sigmoid(),
    5. nn.Linear(5,1),
    6. nn.Sigmoid())
    7. optim=torch.optim.Adam(params=bp.parameters())
    8. Loss=nn.MSELoss()
    9. for epoch in range(10):
    10. print('the {} epoch'.format(epoch))
    11. for batch in data:
    12. yp=bp(batch[0])
    13. loss=Loss(yp,batch[1])
    14. optim.zero_grad()
    15. loss.backward()
    16. optim.step()
  • 相关阅读:
    Illustrator 2022 for mac (AI 2022中文版)
    acwing算法提高之图论--拓扑排序
    Docker 容器的数据卷的使用
    Springboot基于web的游泳馆信息管理系统 毕业设计-附源码281444
    常见开源协议介绍
    【JAVA】-- setBorder
    斗地主案例及一些实现规则
    萌新训练赛(1)
    【SA8295P 源码分析 (一)】54 - /ifs/bin/startupmgr 程序工作流程分析 及 script.c 介绍
    Kubernetes 基于 helm 部署高可用 harbor
  • 原文地址:https://blog.csdn.net/weixin_60737527/article/details/126754254