码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 使用Dataset 和DataLoader 加载数据集


    一、加载数据过程

    PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。 它表示可在数据集上迭代的 Python,并支持

    • 映射式和迭代式的数据集,
    • 自定义数据加载顺序,
    • 自动批次,
    • 单进程和多进程数据加载,
    • 自动内存固定。

    这些选项由 DataLoader 的构造函数参数配置,构造函数的签名如下:

    如下如显示了dataLoader的过程,shuffle将Dataset里的数据打乱,batch_size=2

    二、模型建立流程

    1、准备数据集(Dataset和DataLoader)2、继承Module类设计自己的模型

    3、使用PyTorch APi 构造损失函数和优化器  4、采用前向传播、返向回馈、更新 反复训练。

    三、代码实现

    import torch.nn
    import numpy as np
    from torch.utils.data import Dataset, DataLoader

    class DiabetesDataset(Dataset):
       
    def __init__(self, filepath):
            xy = np.loadtxt(filepath,
    delimiter=',', dtype=np.float32)
           
    self.len = xy.shape[0]
           
    self.x_data = torch.from_numpy(xy[:, :-1])
           
    self.y_data = torch.from_numpy(xy[:, [-1]])

       
    def __getitem__(self, index):
           
    return self.x_data[index], self.y_data[index]

       
    def __len__(self):
           
    return self.len


    dataset = DiabetesDataset(
    'diabetes.csv.gz')
    train_loader = DataLoader(
    dataset=dataset, batch_size=64, shuffle=True, num_workers=2)


    # 继承类Module,自动会实现反向计算图
    class Model(torch.nn.Module):
       
    # 构造方法
       
    def __init__(self):
           
    super(Model, self).__init__()
           
    self.linear1 = torch.nn.Linear(8, 6)
           
    self.linear2 = torch.nn.Linear(6, 4)
           
    self.linear3 = torch.nn.Linear(4, 1)
           
    self.sigmoid = torch.nn.Sigmoid()

       
    def forward(self, x):
            x =
    self.sigmoid(self.linear1(x))
            x =
    self.sigmoid(self.linear2(x))
            x =
    self.sigmoid(self.linear3(x))
           
    return x


    model = Model()

    criterion = torch.nn.BCELoss(
    size_average=True)
    optimizer = torch.optim.SGD(model.parameters(),
    lr=0.1)

    if __name__=='__main__':
       
    for epoch in range(100):
           
    for i, data in enumerate(train_loader, 0):
               
    #1.prepare data
               
    inputs, labels = data
               
    #2.Forward
                
    y_pred = model(inputs)
                loss = criterion(y_pred, labels)
               
    print(epoch, loss.item())
               
    #3.Backward
               
    optimizer.zero_grad()
                loss.backward()
               
    #4.Update
               
    optimizer.step()

    四、运行结果

  • 相关阅读:
    Python学习笔记--类的定义和调用
    实时目标检测新高地之#YOLOv7#更快更强的目标检测器
    从零开始的Django框架入门到实战教程(内含实战实例) - 09 初试Ajax之任务界面(学习笔记)
    JavaScript中if语句优化和部分语法糖小技巧推荐
    【深度学习】用Pytorch完成MNIST手写数字数据集的训练和测试
    RIP小实验配置及缺省路由下发
    Unity/C# 舍入的五种写法
    PointNet++改进策略 :模块改进 | EdgeConv | DGCNN, 动态图卷积在3d任务上应用
    Fiddler配置及使用
    MySQL数据库入门到大牛_04_运算符
  • 原文地址:https://blog.csdn.net/axiaoquan/article/details/127649061
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号