• 【模型提分tricks】Adversarial Weight Perturbation(AWP)对抗训练


    在如今AI遍及各行各业的情况下,现在不论是搞科研还是做比赛,一个非常重要的问题就是提升模型的robust,让训练出来的模型能更好的泛化到一个从未见过的测试集上,以此减小线上和线下的gap。

    对抗训练 Adversarial training

    我们知道模型训练是一个ERM(经验风险最小化,Empirical Risk Minimization)的过程,而对抗训练就是为了增强模型的抗干扰能力。

    实现

    经典训练

    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        
        # 将模型的参数梯度初始化为0
        optimizer.zero_grad()
        
        # forward + backward + optimize
        predicts = model(inputs)          # 前向传播计算预测值
        loss = loss_fn(predicts, labels)  # 计算当前损失
        loss.backward()       # 反向传播计算梯度
    	loss.backward()
        optimizer.step()                  # 更新所有参数 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    加入AWP训练

    AWP

    
    class AWP:
        """
        Implements weighted adverserial perturbation
        adapted from: https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
        """
    
        def __init__(self, model, optimizer, adv_param="weight", adv_lr=1, adv_eps=0.0001):
            self.model = model
            self.optimizer = optimizer
            self.adv_param = adv_param
            self.adv_lr = adv_lr
            self.adv_eps = adv_eps
            self.backup = {}
            self.backup_eps = {}
    
        def attack_backward(self, inputs, labels):
            if self.adv_lr == 0:
                return
            self._save()
            self._attack_step()
    
            y_preds = self.model(inputs)
    
            adv_loss = self.criterion(y_preds, labels)
            self.optimizer.zero_grad()
            return adv_loss
    
        def _attack_step(self):
            e = 1e-6
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None and self.adv_param in name:
                    norm1 = torch.norm(param.grad)
                    norm2 = torch.norm(param.data.detach())
                    if norm1 != 0 and not torch.isnan(norm1):
                        # 在损失函数之前获得梯度
                        r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                        param.data.add_(r_at)
                        param.data = torch.min(
                            torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
                        )
    
        def _save(self):
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None and self.adv_param in name:
                    if name not in self.backup:
                        self.backup[name] = param.data.clone()
                        grad_eps = self.adv_eps * param.abs().detach()
                        self.backup_eps[name] = (
                            self.backup[name] - grad_eps,
                            self.backup[name] + grad_eps,
                        )
    
        def _restore(self,):
            for name, param in self.model.named_parameters():
                if name in self.backup:
                    param.data = self.backup[name]
            self.backup = {}
            self.backup_eps = {}
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59

    参数Args:

    • adv_param (str): 要攻击的layer name,一般攻击第一层 或者全部weight参数效果较好

    • adv_lr (float): 攻击步长,这个参数相对难调节,如果只攻击第一层embedding,一般用1比较好,全部参数用0.1比较好。

    • adv_eps (float): 参数扰动最大幅度限制,范围(0~ +∞),一般设置(0,1)之间相对合理一点。

    • start_epoch (int): (0~ +∞)什么时候开始扰动,默认是0,如果效果不好可以调节值模型收敛一半的时候再开始攻击。

    """
        使用AWP的训练过程
    """
    # 初始化AWP
    awp = AWP(model, loss_fn, optimizer, adv_lr=awp_lr, adv_eps=awp_eps)
    
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        
        # 将模型的参数梯度初始化为0
        optimizer.zero_grad()
        
        # forward + backward + optimize
        predicts = model(inputs)          # 前向传播计算预测值
        loss = loss_fn(predicts, labels)  # 计算当前损失
        loss.backward()       # 反向传播计算梯度
        # 指定从第几个epoch开启awp,一般先让模型学习到一定程度之后
        if awp_start >= epoch:
            loss = awp.attack_backward(inputs, labels)
            loss.backward()
            awp._restore()                    # 恢复到awp之前的model
        optimizer.step()                  # 更新所有参数 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    注意:使用AWP训练时间大概是原来的两倍

    参考

  • 相关阅读:
    一个由登录接口引发的思考
    ti代理商:好的ti代理商有哪些分销
    java计算机毕业设计Vue潍坊学院宿舍管理系统设计与实现MyBatis+系统+LW文档+源码+调试部署
    vscode 源代码不能自动stage change
    【设备树添加节点】
    JS进阶第一篇:手写call apply bind
    HTML期末学生大作业-班级校园我的校园网页设计与实现html+css+javascript
    福建省发改委福州市营商办莅临育润大健康事业部指导视察工作
    GitHub WebHook 使用教程
    电脑硬件——CPU
  • 原文地址:https://blog.csdn.net/u014297502/article/details/126857184