• 【PyTorch】Torchvision


    三、Torchvision

    PyTorch官网https://pytorch.org

    1、Dataset

    数据集描述:https://www.cs.toronto.edu/~kriz/cifar.html

    数据集使用说明:

    CIFAR10数据集https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10

    参数说明:

    • root:数据集存放位置
    • train:True(训练集)、False(测试集)
    • transform:变化
    • target_transform:target变化
    • download:是否下载

    基本使用:

    import torchvision
    
    train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
    test_set = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
    
    print(test_set[0])
    print(test_set.classes)
    
    img, target = test_set[0]
    print(img)
    print(target)
    print(test_set.classes[target])
    img.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    Files already downloaded and verified
    Files already downloaded and verified
    (<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F0220>, 3)
    ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    <PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F00D0>
    3
    cat
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    转为Tensor类型: 并使用TensorBoard显示

    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    
    dataset_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    
    train_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=True, download=True)
    test_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=False, download=True)
    
    writer = SummaryWriter("logs")
    for i in range(10):
        img, target = test_set[i]
        writer.add_image("test_set", img, i)
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    2、DataLoader

    介绍:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

    参数说明:

    • batch_size:每批要加载多少个样品(默认:1)
    • shuffle:True(重新洗牌),(默认:False)
    • num_workers:使用多少个子进程来加载数据,(默认:0 表示主进程)
    • drop_last:是否舍去最后(除不尽的)

    2.1 test_data

    import torchvision
    from torch.utils.data import DataLoader
    
    # 准备测试集
    test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
    
    # 测试集第一张图片及target
    img, target = test_data[0]
    print(img.shape)
    print(target)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    torch.Size([3, 32, 32]) # 3通道 32 * 32
    3
    
    • 1
    • 2

    2.2 test_loader

    import torchvision
    from torch.utils.data import DataLoader
    
    # 准备测试集
    test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
    
    test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
    
    # 测试集第一张图片及target
    # img, target = test_data[0]
    # print(img.shape)
    # print(target)
    
    # test_loader
    for data in test_loader:
        imgs, targets = data
        print(imgs.shape)
        print(targets)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    torch.Size([4, 3, 32, 32]) # 4张 3通道 32 * 32
    tensor([1, 2, 0, 8]) # 4张图片的target糅合在一起
    ...
    ...
    
    • 1
    • 2
    • 3
    • 4

    注意:target[1, 2, 0, 8]并不是按序采样,而是随机的!

    2.3 drop_last

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    # 准备测试集
    test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
    
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
    
    # batch_size=64
    writer = SummaryWriter("logs")
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("test_data", imgs, step)
        step += 1
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    注意:最后一次采样只有16张图像,这是因为参数drop_last=False

    当不满足每一次都取一定值的图片时,可以显示真实剩下的或者直接舍去(drop_last=True)。

    当我们设置为drop_last=True时,就会舍去最后一组采样:

    2.4 shuffle

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    # 准备测试集
    test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
    
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
    
    # shuffle=False
    writer = SummaryWriter("logs")
    
    for epoch in range(2):
        step = 0
        for data in test_loader:
            imgs, targets = data
            writer.add_images("Epoch:{}".format(epoch), imgs, step)
            step += 1
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    注意:两者采样完全相同,如果想要 “洗牌”,应设置shuffle=True

  • 相关阅读:
    单调队列代码模板
    一文带你学透Java Servlet(建议收藏)
    Java|学习|多线程
    【计算机视觉 | CNN】Image Model Blocks的常见算法介绍合集(一)
    java计算机毕业设计学生课堂互动教学系统源码+mysql数据库+lw文档+系统+调试部署
    Prometheus基于Consul的 Redis 多实例监控方案
    计算机毕业设计Python+Django的高校学生项目校内申报平台系统(源码+系统+mysql数据库+Lw文档)
    C#目录和文件管理
    计算机网络基础 ARP协议 详详解----看完我的总结你就不用看别人的了!
    【Go】excelize库实现excel导入导出封装(一),自定义导出样式、隔行背景色、自适应行高、动态导出指定列、动态更改表头
  • 原文地址:https://blog.csdn.net/m0_70885101/article/details/127897302