torchivison图形库:
1.torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils: 其他的一些有用的方法。
- train_transforms = transforms.Compose([
- transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸
- # transforms.RandomHorizontalFlip(), # 随机水平翻转
- transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
- transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
- mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
- ])
-
- test_transform = transforms.Compose([
- transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸
- transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
- transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
- mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
- ])
-
- total_data = datasets.ImageFolder("D:/P7/49-data/",transform=train_transforms)
- print(total_data)