• 基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(二)之知识收集阶段


    前面学习了模型的构建与训练过程,然而在实验过程中,博主依旧对数据集与模型之间的关系有些疑惑,首先是论文说这是一个混合数据集,但事实上博主在实验时是将三个数据集分开的,那么在数据读取时是如何混合的呢,是每个epoch使用同一个数据集,下一个epoch再换数据集,还是再epoch中随机取数据集中的一部分。
    此外,教师模型总共有三个,其模型构造是完全相同的,不同之处在于三个教师模型是在不同的数据集训练得到的,即其权重参数是固定的,那么在训练过程中,从代码来看,原始的教师网络权重是不改变的,那么说如何更新学生网络呢?带着这些疑问,开始今天的学习。

    数据集加载

    首先需要明确的是数据集加载时是将三个数据集进行了合并,只不过会按照三个数据集进行区别,即生成list形式。train_loader的相关参数设置如下:

    在这里插入图片描述

    模型训练

    模型的训练分为两个阶段,分别是知识收集阶段与知识检验阶段,即knowlwdge collect(kc)knowledge exam(ke)两阶段。

    在这里插入图片描述

    在开始前,需要声明必须要将batch-size设置为3以上,否则会无法加载数据集,当batch=6时,可以看到图像实际上并不是均分的,事实上每次划分都是随机从不同数据集中获取,那么也就存在某个批次中某一种数据集没有被挑选的情况,这种情况是存在的,但可以通过验证方式排除掉,只在三种数据都有的情况下进行训练。

    在这里插入图片描述

    首先是知识收集阶段:
    声明损失函数,这里的损失函数有两个,分别是L1损失与通过VGG网络计算的软损失(SCRLoss)

    criterion_l1, criterion_scr, _ = criterions
    
    • 1

    在这里插入图片描述

    模型开启traineval,关于两者的区别:

    model.train()的作用是启用 Batch NormalizationDropout。在train模式,Dropout层会按照设定的参数p设置保留激活单元的概率,如keep_prob=0.8,Batch Normalization层会继续计算数据的meanvar并进行更新。
    model.eval()的作用是不启用 Batch NormalizationDropout。在eval模式下,Dropout层会让所有的激活单元都通过,而Batch Normalization层会停止计算和更新meanvar,直接使用在训练阶段已经学出的meanvar值。在使用model.eval()时就是将模型切换到测试模式,在这里,模型就不会像在训练模式下一样去更新权重。
    但是需要注意的是model.eval()不会影响各层的梯度计算行为,即会和训练模式一样进行梯度计算和存储,只是不进行反向传播。

    model.train()#  model开启train
    ckt_modules.train()
    for teacher_network in teacher_networks:#为教师网络开启eval()
    	teacher_network.eval()
    
    • 1
    • 2
    • 3
    • 4

    随后便进入核心代码模块了:这里包含模型运算,特征映射,损失计算等过程
    这里我们对应论文的创新点来看代码。
    首先是进度条加载,这里是对数据集加载train_load的封装

    pBar = tqdm(train_loader, desc='Training')
    
    • 1

    遍历数据,判断数据是否为空,这里曾经困扰过博主一段时间,因为每次遍历时target_image都为空,只要将batch-size设置为3以上即可。

    for target_images, input_images in pBar:
    	if target_images is None: continue
    	target_images = target_images.cuda()
    	input_images = [images.cuda() for images in input_images]
    	preds_from_teachers = []
    
    • 1
    • 2
    • 3
    • 4
    • 5

    可以看到,此时已经将输入图像,目标图像转换为tensor格式,其中input_imageslist形式,每张图像为torch.Size([1, 3, 224, 224])

    在这里插入图片描述

    而target_images为完全为tensor格式,shape为torch.Size([3, 3, 224, 224])

    在这里插入图片描述

    简要描述知识收集阶段

    teacher_networks即为教师网络列表,单个的教师网络模型与学生网络是相同的,将数据输入教师网络时,由于需要使用教师网络的中间特征,因此return_feat为True,最终的输出结果为预测结果图与中间特征图,预测结果图会作为 “真值” 来训练学生网络,并计算软损失,中间特征图会与学生网络进行映射到同一特征域来进行特征转移,并将教师网络的预测结果与学生网络的预测结果求SCRLoss。

    preds_from_teachers = []
    features_from_each_teachers = []
    with torch.no_grad():
    for i in range(len(teacher_networks)):
    	preds, features = teacher_networks[i](input_images[i], return_feat=True)
    	preds_from_teachers.append(preds)
    	features_from_each_teachers.append(features)		
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    随后将图像输入教师模型,教师模型不更新权重,只是用模型输出的特征来帮助学生网络来训练,称为软损失。核心代码如下:

    preds, features = teacher_networks[i](input_images[i], return_feat=True)
    
    • 1

    将图像 i 输入对应的教师网络 i,这里的i指的是教师网络的索引,这里博主开始曾经有过疑惑,此时的batch_size为3,刚好与教师网络数量对应,因此可以使用该网络,那如果batch_size为6,9时呢,后面的岂不是都无法输入模型了吗,随后博主将batch_size改为6,发现此时的input_image依旧是list形式,但每个list中的内容已经发生了改变,可以看到其是按照不同的数据集类型做了区分,这就是为何input_image要使用listtarget_imagetensor的原因了。现在之前的疑惑也就消失了。
    在这里插入图片描述

    随后获得输出结果pred,即预测结果,也就是恢复后的图像。可以看到其与输入图像的维度是一致的,对于第一个网络的第一组输入图像,都为:torch.Size([3, 3, 224, 224])

    在这里插入图片描述
    而返回的中间特征图像如图所示,可以看到输出的不同大小的特征图,总共有4组,即4组不同大小的特征图,每组3张图像,通道数,宽高则不相同。
    第 1 组数据集(教师网络)的中间特征图:

    在这里插入图片描述
    第 3 组数据集(教师网络)的中间特征图:

    在这里插入图片描述

    随后经过三个网络模型的运算,将结果加入列表:

    preds_from_teachers.append(preds)
    features_from_each_teachers.append(features)
    
    • 1
    • 2

    在这里插入图片描述
    在这里插入图片描述

    随后将教师网络的预测值转换为tensor格式,因为在最终学生网络的输出是tensor

    preds_from_teachers = torch.cat(preds_from_teachers)
    
    • 1

    原本list变为tensor
    在这里插入图片描述
    接下来这段是对feature按照特征图大小进行分组,现在的特征图是按照数据集划分为3组,为方便后面做特征映射,将其按照特征图大小分为四组。

    for layer in range(len(features_from_each_teachers[0])):
    	features_from_teachers.append([features_from_each_teachers[i][layer] for i in range(len(teacher_networks))])
    
    • 1
    • 2

    在这里插入图片描述

    随后便是将输入图像输入学生网络输出结果与中间特征图,这里的输入图像实际上是将input_image进行拼接,原本的input是按照数据集分类为3个list,此时直接将3个list拼接为tensor的形式:

    这里以batch-size=6为例。

    在这里插入图片描述

    preds_from_student, features_from_student = model(torch.cat(input_images), return_feat=True)
    
    • 1

    由于博主将batch设置为6会报显存溢出,因此这里改为4,可以看到中间特征图依旧是四组,不过每组的第一个值由6变为了4,其余都没有改变。
    可以看到list为4组,代表4组不同尺度特征图,每组里面又有一个list,每个list中包含不同数据集(教师网络的特征图)分别是2,1,1。

    在这里插入图片描述
    同理输出结果也是由6变4。

    在这里插入图片描述

    CKT模块(特征转移)

    随后便是中间特征图映射了,其过程其实也很简单,即将教师网络特征如与学生网络特征图同时输入CKT模型中,并获得输出结果,将输出结果做损失即可。
    在这里插入图片描述

    PFE_loss, PFV_loss = 0., 0.
    for i, (s_features, t_features) in enumerate(zip(features_from_student, features_from_teachers)):
    	t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)
    	PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
    	PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))
    
    • 1
    • 2
    • 3
    • 4
    • 5

    可以看到输入的教师网络特征与学生网络特征是相同格式的,都是以list的形式区分不同大小的特征图

    在这里插入图片描述
    输入值:
    经过遍历后,学生网络的特征图分为四组,分别对应不同尺度的特征图,但没有区分数据集,因为本身学生网络就是不区分数据集的。
    在这里插入图片描述
    而教师网络却是list形式,每个数据集分别对应2,1,1个图像数量
    在这里插入图片描述
    CKT网络定义:

    class CKTModule(nn.Module):
        def __init__(self, channel_t, channel_s, channel_h, n_teachers):
            super().__init__()
            self.teacher_projectors = TeacherProjectors(channel_t, channel_h, n_teachers)
            self.student_projector = StudentProjector(channel_s, channel_h)
        def forward(self, teacher_features, student_feature):
            teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)
            student_projected_feature = self.student_projector(student_feature)
            return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    具体结构如下,CKT模块共有4个,即对应不同尺度的特征图,注意功能便是进行一系列的特征映射与转换。

    CKTModule(
        (teacher_projectors): TeacherProjectors(
          (PFPs): ModuleList(
            (0): Sequential(
              (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            (1): Sequential(
              (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            (2): Sequential(
              (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
          (IPFPs): ModuleList(
            (0): Sequential(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
            (2): Sequential(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ReLU(inplace=True)
              (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            )
          )
        )
        (student_projector): StudentProjector(
          (PFP): Sequential(
            (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): ReLU(inplace=True)
            (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          )
        )
      )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45

    特征转移实际上也是通过损失函数来进行的,即通过一个网学习特征,从而达到特征转移的效果。
    最终获得三个结果,分别是教师网络结构特征,教师网络重构特征,学生网络结构特征。核心代码如下:

    teacher_projected_feature, teacher_reconstructed_feature = self.teacher_projectors(teacher_features)
    student_projected_feature = self.student_projector(student_feature)
    return teacher_projected_feature, teacher_reconstructed_feature, student_projected_feature
    
    • 1
    • 2
    • 3

    输入值:教师网络的特征图与学生网络的特征图

    在这里插入图片描述

    进入之后需要转换:

    在这里插入图片描述

    输出值:
    与输入值一样,学生网络结构特征的输出值为tensor形式

    在这里插入图片描述
    而教师网络特征与教师网络重构特征的输出值依旧为list形式。

    在这里插入图片描述

    在这里插入图片描述

    随后求特征损失与重构损失即可,这里需要注意的是,尽管学生网络的训练过程中是不区分数据集的,但在计算损失的过程中,特征图,输出值还是能够一一对应的,因为将list拼接为tensor后的顺序是没错的。

    PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
    PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))
    
    • 1
    • 2

    在这里插入图片描述

    在这里插入图片描述

    最终求总损失与SCR损失即可,值得注意的是SCR损失需要使用VGG网络做特征变换后再计算。
    L1损失较为简单,输入为学生网络预测值与教师网络预测值

    T_loss = criterion_l1(preds_from_student, preds_from_teachers)
    SCR_loss = 0.1 * criterion_scr(preds_from_student, target_images, torch.cat(input_images))
    
    • 1
    • 2

    关于criterion_l1函数,其实际上是首先使用VGG网络进行特征变换,其输入数据分别是学生网络预测值,目标图像以及输入图像。
    SCRLoss定义如下:根据在forward中的代码可知,其首先将输入值分别输入VGG网络进行特征变换,随后在将输出值计算L1损失。
    其中,detch方法是返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_gradfalse,得到的这个tensor永远不需要计算其梯度,不具有grad。即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
    这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()tensor就会停止,不能再继续向前进行传播。
    最终乘以对应的权重,返回最后的损失。

    class SCRLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.vgg = Vgg19().cuda()
            self.l1 = nn.L1Loss()
            self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        def forward(self, a, p, n):
            a_vgg, p_vgg, n_vgg = self.vgg(a), self.vgg(p), self.vgg(n)
            loss = 0
            d_ap, d_an = 0, 0
            for i in range(len(a_vgg)):
                d_ap = self.l1(a_vgg[i], p_vgg[i].detach())
                d_an = self.l1(a_vgg[i], n_vgg[i].detach())
                contrastive = d_ap / (d_an + 1e-7)
                loss += self.weights[i] * contrastive
            return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    可以看到最后的损失值是Tensor形式的。
    中间生成的特征则有5个,除了4个中间特征图外,还有一个是最终输出结果,即恢复到224x224的,只是通道维度是64而非3。

    在这里插入图片描述

    在这里插入图片描述

    至此,知识收集阶段便完成了。接下来便是知识测试阶段。

    关于损失函数中VGG网络的问题

    博主在刚开始时并不太明白为何要使用一个VGG网络来对特征图做提取后再计算损失,后来了解到这种损失称为感知损失(Perceptual Loss)。

    感知损失就是通过一个固定的网络(通常使用预训练的VGG16或者VGG19),分别以真实图像(Ground Truth)、网络生成结果(Prediciton)作为其输入,得到对应的输出特征:feature_gt、feature_pre,然后使用feature_gt与feature_pre构造损失(通常为L2损失),逼近真实图像与网络生成结果之间的深层信息,也就是感知信息,相比普通的L2损失而言,可以增强输出特征的细节信息。
    即:此处的固定网络视为一个函数f,feature_gt=f(Ground Truth),feature_pre=f(Prediciton) ,我们的目的是最小化feature_gt与feature_pre之间的差异,即最小化feature_gt、feature_pre构成的感知损失。

    那么Perceptual Loss如何构造?

    • 设置固定网络(如ImageNet上预训练好的VGG16),该网络参数固定,不进行更新;
    • 以真实图像(Ground Truth)、网络生成结果(Prediciton)作为其输入,得到对应的输出特征:feature_gtfeature_pre
    • 使用feature_gtfeature_pre构造损失,如L1损失
  • 相关阅读:
    27岁,准备转行做网络安全渗透,完全零基础,有前途吗?
    Learning Git Branch 题解(基础、高级、Git远程仓库)
    阻抗与导纳的理解
    SpringCloud&Gateway理论与实践
    基于weixin小程序乡村旅游系统的设计
    985大学新增专业,考数据结构+自然语言处理!中央民族大学新增语言信息安全...
    Oracle11g在红帽Linux上的安装教程
    Kafka知识点总结
    Linux系列之查找jar包安装目录
    【ROS进阶篇】第四讲 ROS中的重名问题(节点、话题与参数)
  • 原文地址:https://blog.csdn.net/pengxiang1998/article/details/133466166