• PyTorch入门学习(十五):现有网络模型的使用及修改


    目录

    一、使用现有网络模型

    二、修改现有网络模型


    一、使用现有网络模型

    PyTorch提供了许多流行的深度学习模型,这些模型在大规模图像数据集上进行了预训练。其中一个著名的模型是VGG16。下面是如何使用VGG16模型的示例代码:

    1. import torchvision
    2. from torch import nn
    3. from torchvision.models import VGG16
    4. # 使用不带预训练权重的VGG16模型
    5. vgg16_false = torchvision.models.vgg16(pretrained=False)
    6. # 使用预训练权重的VGG16模型
    7. vgg16_true = torchvision.models.vgg16(pretrained=True)
    8. print(vgg16_false)
    9. print(vgg16_true)

    在上述代码中,使用torchvision.models.vgg16来加载VGG16模型。通过pretrained参数,我们可以选择是否加载预训练的权重。vgg16_false代表一个不带预训练权重的VGG16模型,而vgg16_true代表一个带有预训练权重的模型。

    二、修改现有网络模型

    一旦加载了现有的网络模型,可以对其进行修改,以满足特定任务的需求。下面是如何修改VGG16模型的示例代码:

    1. import torchvision
    2. from torch import nn
    3. from torchvision.models import VGG16
    4. # 加载带有预训练权重的VGG16模型
    5. vgg16 = torchvision.models.vgg16(pretrained=True)
    6. # 添加一个新的线性层,将输出从1000类修改为10类
    7. vgg16.classifier.add_module('add_linear', nn.Linear(1000, 10))
    8. # 修改VGG16模型的最后一个全连接层
    9. vgg16.classifier[6] = nn.Linear(4096, 10)
    10. print(vgg16)

    在上述代码中,加载了一个带有预训练权重的VGG16模型,并通过add_module方法添加了一个新的线性层,将输出从1000类修改为10类。此外,还演示了如何通过修改模型的索引来改变VGG16模型的最后一个全连接层。

    这种方法可以帮助您快速构建适用于特定任务的模型,而无需从头开始训练整个网络。

    完整代码如下:

    1. import torchvision
    2. from torch import nn
    3. from torchvision.models import VGG16_Weights
    4. # train_data = torchvision.datasets.ImageNet("D:\\Python_Project\\pytorch\\data_image_net",split="train",download=True,transform=torchvision.transforms.ToTensor())
    5. # 错误原因:参数pretrained自0.13起已弃用,将在0.15后删除,要改用“weights”。
    6. vgg16_false = torchvision.models.vgg16(weights=None)
    7. vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
    8. # print(vgg16_true)
    9. # 要想用于 CIFAR10 数据集, 可以在网络下面多加一行,转成10分类的输出,这样输出的结果,跟下面的不一样,位置不一样
    10. # vgg16_true.add_module('add_Linear',nn.Linear(1000,10))
    11. # print(vgg16_true)
    12. vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
    13. # 层级不同
    14. # 如何利用现有的网络,改变结构
    15. print(vgg16_true)
    16. # 上面是添加层,下面是如何修改VGG里面的层内容
    17. print(vgg16_false)
    18. vgg16_false.classifier[6] = nn.Linear(4096,10) # 中括号里的内容,是网络输出结果自带的索引,套进这种格式,就可以直接修改那一层的内容
    19. print(vgg16_false)

    参考资料:

    视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

  • 相关阅读:
    齐活了,Grafana 发布大规模持续性能分析开源数据库 - Phlare
    Python批量保存Excel文件中的图表为图片
    Sa-Token 一个轻量级Java权限认证框架
    【SQLite】三、SQLite 的常用语法
    Web测试如何让IT门外汉更好的入门篇
    独立开发了一款Material3风格的RSS阅读器 - Agr Reader
    Webpack 和 Vite 的区别
    牛客小白月赛73DE
    Acrel-2000系列监控系统在亚运手球比赛馆建设10kV供配电工程中的应用
    用GPT干的18件事,能够真正提高学习生产力,建议收藏
  • 原文地址:https://blog.csdn.net/qq_46179411/article/details/134229358