torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
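Parameters:
- root: the root directory of the dataset. Each subdirectory of root is treated as one class.
- transform: an optional function/transform applied to each loaded image (a PIL image with the default loader), e.g. transforms.RandomCrop. Defaults to None.
- target_transform: an optional function/transform applied to the class label (target). Defaults to None.
- loader: a function that loads an image given its path. By default the image is opened with PIL and converted to RGB.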
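When you index the dataset, ImageFolder calls loader, then transform, then target_transform, in that order. The sketch below mirrors that order in plain Python so it can be run without torchvision. It is a simplified illustration, not torchvision's actual code; the class name MiniImageFolder and the toy sample paths are hypothetical.

```python
from typing import Any, Callable, List, Optional, Tuple


class MiniImageFolder:
    """Hypothetical, simplified stand-in for ImageFolder's indexing logic."""

    def __init__(
        self,
        samples: List[Tuple[str, int]],
        loader: Callable[[str], Any],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[int], Any]] = None,
    ) -> None:
        self.samples = samples  # list of (path, class_index) pairs
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)  # 1. load the raw image from its path
        if self.transform is not None:
            sample = self.transform(sample)  # 2. transform the image
        if self.target_transform is not None:
            target = self.target_transform(target)  # 3. transform the label
        return sample, target


# Toy callables: the "loader" just upper-cases the path instead of reading a file.
ds = MiniImageFolder(
    samples=[("data/cat/1.jpg", 0), ("data/dog/2.jpg", 1)],
    loader=lambda p: p.upper(),
    transform=lambda s: s + "!",
    target_transform=lambda t: t * 10,
)
print(ds[1])  # ('DATA/DOG/2.JPG!', 10)
```

Because transform only ever sees the output of loader, a custom loader must return whatever type the transforms expect (a PIL image for most torchvision transforms).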
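In addition, the dataset object returned by this API has the following member variables:
- classes: a list of class names, i.e. the subdirectory names of root in sorted order.
- class_to_idx: a dict mapping each class name to its class index.
- imgs: a list of (image path, class index) tuples, one per image.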
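To show where the three attributes printed in the example below (classes, class_to_idx, imgs) come from, here is a simplified re-implementation of the directory scan using only the standard library. It is a sketch, not torchvision's code: scan_image_folder is a hypothetical name, error handling (e.g. for empty class folders) is omitted, and the extension tuple is assumed to match torchvision's accepted image extensions in recent versions.

```python
import os
from typing import Dict, List, Tuple

# Assumed list of accepted extensions (matched case-insensitively).
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def scan_image_folder(root: str) -> Tuple[List[str], Dict[str, int], List[Tuple[str, int]]]:
    """Hypothetical helper: rebuild classes, class_to_idx and imgs for `root`."""
    # Every direct subdirectory of root is a class; names are sorted alphabetically.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    imgs: List[Tuple[str, int]] = []
    for name in classes:
        # Walk the class directory (nested folders included) in sorted order.
        for dirpath, _, fnames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(fnames):
                # Keep only files whose extension looks like an image.
                if fname.lower().endswith(IMG_EXTENSIONS):
                    imgs.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, imgs
```

For a plain directory tree (no symlinks), scan_image_folder('../input/data') should closely match dataset.classes, dataset.class_to_idx and dataset.imgs. Note that class indices follow the alphabetical order of folder names, so renaming a folder can change every label.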
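Example:
The data is organized as follows: every class is a subdirectory of the root directory ../input/data, and the images belonging to that class are stored inside that subdirectory.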
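If you want to try the example below without a real dataset, the following sketch builds a toy tree in this layout using only the standard library. The helper name make_toy_image_folder, the class names cat and dog, the file names and the 256x256 image size are all hypothetical choices; 256 is used so that RandomCrop(224) has room to crop. The files are binary PPM images, an extension ImageFolder accepts and PIL can read.

```python
import os
from typing import Sequence


def make_toy_image_folder(
    root: str, classes: Sequence[str] = ("cat", "dog"), per_class: int = 2, size: int = 256
) -> None:
    """Hypothetical helper: create root/<class>/<n>.ppm filled with one solid colour."""
    for class_index, name in enumerate(classes):
        class_dir = os.path.join(root, name)
        os.makedirs(class_dir, exist_ok=True)
        for n in range(per_class):
            # Binary PPM (P6): ASCII header "P6 width height maxval", then raw RGB bytes.
            header = f"P6\n{size} {size}\n255\n".encode("ascii")
            colour = bytes([(80 * class_index) % 256, (40 * n) % 256, 128])
            with open(os.path.join(class_dir, f"{n}.ppm"), "wb") as f:
                f.write(header + colour * (size * size))
```

Calling make_toy_image_folder('../input/data') before running the example gives it two classes with two images each, enough for one batch of size 2.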

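import torchvision
import torchvision.transforms as transforms
from torch.utils import data

# Randomly crop each image to 224x224, then convert it to a tensor.
# (RandomCrop raises an error on images smaller than 224 px unless pad_if_needed=True.)
trans = transforms.Compose([transforms.RandomCrop(224), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('../input/data', transform=trans)
print(dataset.classes)       # class names, i.e. the sorted subdirectory names
print(dataset.class_to_idx)  # {class name: class index}
print(dataset.imgs)          # [(image path, class index), ...]
print()

train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
for img, label in train_loader:
    print(img.shape)  # torch.Size([2, 3, 224, 224]): the default loader converts to RGB
    print(label)      # class indices of the images in this batch
    break             # only inspect the first batch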
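The loop above receives batches of two randomly chosen images. The sketch below shows, in plain Python, the index bookkeeping that shuffle=True with batch_size=2 roughly corresponds to. It is a simplified illustration, not PyTorch's implementation (which uses RandomSampler and BatchSampler with its own random generator); shuffled_batches is a hypothetical name.

```python
import random
from typing import Iterator, List, Optional


def shuffled_batches(n: int, batch_size: int, seed: Optional[int] = None) -> Iterator[List[int]]:
    """Hypothetical sketch: one epoch of shuffled index batches over 0..n-1."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # a fresh random order for this epoch
    for start in range(0, n, batch_size):
        # The final batch may be smaller, like DataLoader's default drop_last=False.
        yield indices[start:start + batch_size]
```

Each batch of indices is then turned into dataset[i] samples, and the default collate function stacks the image tensors along a new first dimension. Stacking only works if every image has the same shape after the transform, which is why the example crops everything to 224x224 before ToTensor().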
