• 【迁移学习】


    1 迁移学习的思路

    迁移学习的思路是利用预训练模型的卷积部分(卷积基)提取数据集的图片特征,然后重新练最后的全连接部分(分类器),迁移学习的特征提取部分(卷积基)不能发生变化。

    2 迁移学习的步骤

    迁移学习的思路有3步:
    (1)冻结预训练模型的卷积基
    (2)根据问题重新设置分类器,如需要分2类,则out_features=2
    (3)用自己的数据训练设置好的分类器,注意:只优化分类器参数

    3 具体步骤

    torchvision提供了可以加载的预训练模型:

    alexnet
    convnex
    densenet
    efficientnet
    feature_extraction
    googlenet
    inception
    mnasnet
    mobilenet
    mobilenetv2
    mobilenetv3
    regnet
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    使用代码:

    import torchvision
    方法1:
    model = torchvision.models.vgg16(pretrained=True)	# pretrained=True表式仅仅加载网络结构,而不加载网络参数
    # 方法2
    model = models.vgg16(weights= models.VGG16_Weights.DEFAULT)
    print(model)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    输出如下:

    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
    ...
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    冻结卷积基的参数,避免模型参数被破坏,准确率下降

    for param in model.features.parameters():
        param.requires_grad = False
    
    model.classifier[-1].out_features = 4	# 4分类
    
    model = model.to(device)	# 模型传入型芯片,一般GPU上
    loss_fn = nn.CrossEntropyLoss()	#根据具体问题自定义
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)	#注意这里只优化分类器参数
    
    ....接训练代码
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  • 相关阅读:
    Kubernetes---使用端口转发来访问集群中的应用
    JAVA 设计模式篇
    表的内连接
    MySQL之InnoDB的锁类型与锁原理
    springboot中如何在测试环境下进行web环境模拟测试
    FilterRegistrationBean能不能排除指定url
    .Net CLR
    npm包管理工具
    周赛361(模拟、枚举、记忆化搜索、统计子数组数目(前缀和+哈希)、LCA应用题)
    Spring Cloud之服务注册与发现(Eureka)
  • 原文地址:https://blog.csdn.net/m0_46256255/article/details/133184650