2022.11.28接着开启新的一章
2022.11.29继续学习

import torchvision
train_set = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=False, download=True)
root:数据集的路径train:如果为 True 则创建数据集,如果为 False 则创建测试集download:为 True 则从网络上下载数据集到 root 路径下,如果该路径下已有数据集则不进行下载。print(test_set[0])
(, 3) 。其中 3 代表的是该测试样例对应的映射到整数上的类别print(test_set.classes)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img, target = test_set[0]
print(img) #
print(target) #3
img.show()
由于该数据集的图片是 32×32 像素的,所以不是很清晰:

再查看对应的类别的名称:
print(test_set.classes[target]) # cat
PIL 类型的图片转换成 ToTensor 类型的图片:dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=False, transform=dataset_transform, download=True)
ToTensor 类型的图片显示到 tensorboard 中了:writer = SummaryWriter("log-CIFAR10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set-CIFAR10", img, i)
writer.close()
