• LMFLOSS:专治解决不平衡医学图像分类的新型混合损失函数 (附代码)


     

    论文地址:https://arxiv.org/pdf/2212.12741.pdf

    代码地址:https://github.com/SanaNazari/LMFLoss

    1.是什么?

    LMFLOSS是一种用于不平衡医学图像分类的混合损失函数。它是由Focal Loss和LDAM Loss的线性组合构成的,旨在更好地处理不平衡数据集。Focal Loss通过强调难以分类的样本来提高模型的性能,而LDAM Loss则考虑了数据集的类别分布来调整权重。

    2.为什么?

    先来简单回顾下,对于类别不均衡问题,以往的方法是如何解决的。大体上主要有两种,即以数据为中心驱动和以算法为中心的解决方案。

    数据策略

    以数据为中心的类别不均衡解决方法主要有两种:过采样欠采样。过采样试图为少数类别生成人工数据点,而欠采样旨在消除多数类别的样本。

    算法策略

    算法层面的策略,特别是在深度学习领域,主要侧重于开发损失函数来应对类不平衡问题,而不是直接操纵数据。一种简单的方式便是为每个类别都设置相应的权重,以便与多数类别相比,少数类别样本的错误分类受到更严重的惩罚。另一种方法是为每个训练样本自适应地设置一个唯一的权重,以便硬样本获得更高的权重。

    作者便提出了一种称为 Large Margin aware Focal (LMF) Loss 的新型损失函数,以缓解医学成像中的类不平衡问题。该损失函数动态地同时考虑硬样本和类分布。

    3.怎么样

    3.1 Focal Loss

    说到类别不均衡的损失函数,不得不提的便是 Focal Loss。对于分类问题,大家常用的便是交叉熵损失 BCE Loss,该损失函数对所有类别均一视同仁,即赋予同等的权重学习。而 Focal Loss 主要就是交叉熵损失改进的,通过引入 \alpha 和 \gamma 两个调节因子来调整样本数量和样本难易程度,以便模型专注于学习少数类。具体公式如下:

    3.2 LDAM Loss

    《 Learning imbalanced datasets with label-distribution-aware margin loss 》 这篇文章中提出了另一项减轻类不平衡问题的工作,称为标签分布感知边距(LDAM)损失。作者建议对少数类引入比多数类更强的正则化,以减少它们的泛化误差。如此一来,损失函数保持了模型学习多数类并强调少数类的能力。LDAM 损失侧重于每个类的最小边际和获得每个类和统一标签测试错误,而不是鼓励大多数类训练样本与决策边界的大边距。换句话说,它只会鼓励少数群体获得相对较大的利润。此外,作者提出了用于获得多个类别 1、2、...、k 的类别相关边距的公式: \gamma _{j} = \frac{C}{n_{j}1/4^{}}.

    这里 j∈1,...,k 表示特定类,n_{j}表示每个类别的样本数,C为固定的常数。现在,让我们定义出一个样本对 (x,y),x 为样本,y为对应的标签,同时给定一个模型 f。考虑下面这个函数映射:x_{y}=f(x)_{y};我们令 u=e^{z_{y}-p_{y}},这里对于每一个类别j∈1,...,k 都有 p_{j}=\frac{C}{n_{j}^{1/4}}。因此,LDAM 损失便可以定义为:

    3.3 LMF Loss

    Focal Loss 创建了一种机制,可以更加强调模型难以分类的样本;通常,来自少数群体的样本将属于这一类。相比之下,LDAM Loss 通过考虑数据集的类别分布来判断权重。我们假设与单独使用每个功能相比,同时利用这两个功能可以产生有效的结果。因此,作者提出的 Large Margin aware Focal (LMF) 损失是 Focal 损失和由两个超参数加权的 LDAM 的线性组合,公式如下:

    这里,α 和 β 是常数,被认为是可以调整的超参数。 因此,本文提出的损失函数在单个框架中联合优化了两个独立的损失函数。通过反复试验,作者发现将相同的权重分配给两个组件会产生良好的结果。

    3.4 代码实现

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed May 24 17:03:06 2023
    4. @author: Sana
    5. """
    6. import torch
    7. import torch.nn as nn
    8. import torch.nn.functional as F
    9. import numpy as np
    10. from ..builder import LOSSES
    11. class FocalLoss(nn.Module):
    12. def __init__(self, alpha, gamma=2):
    13. super().__init__()
    14. self.alpha = alpha
    15. self.gamma = gamma
    16. def forward(self, output, target):
    17. num_classes = output.size(1)
    18. assert len(self.alpha) == num_classes, \
    19. 'Length of weight tensor must match the number of classes'
    20. logp = F.cross_entropy(output, target, self.alpha)
    21. p = torch.exp(-logp)
    22. focal_loss = (1 - p) ** self.gamma * logp
    23. return torch.mean(focal_loss)
    24. class LDAMLoss(nn.Module):
    25. def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
    26. """
    27. max_m: The appropriate value for max_m depends on the specific dataset and the severity of the class imbalance.
    28. You can start with a small value and gradually increase it to observe the impact on the model's performance.
    29. If the model struggles with class separation or experiences underfitting, increasing max_m might help. However,
    30. be cautious not to set it too high, as it can cause overfitting or make the model too conservative.
    31. s: The choice of s depends on the desired scale of the logits and the specific requirements of your problem.
    32. It can be used to adjust the balance between the margin and the original logits. A larger s value amplifies
    33. the impact of the logits and can be useful when dealing with highly imbalanced datasets.
    34. You can experiment with different values of s to find the one that works best for your dataset and model.
    35. """
    36. super(LDAMLoss, self).__init__()
    37. m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
    38. m_list = m_list * (max_m / np.max(m_list))
    39. m_list = torch.cuda.FloatTensor(m_list)
    40. self.m_list = m_list
    41. assert s > 0
    42. self.s = s
    43. self.weight = weight
    44. def forward(self, x, target):
    45. index = torch.zeros_like(x, dtype=torch.uint8)
    46. index.scatter_(1, target.data.view(-1, 1), 1)
    47. index_float = index.type(torch.cuda.FloatTensor)
    48. batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
    49. batch_m = batch_m.view((-1, 1))
    50. x_m = x - batch_m
    51. output = torch.where(index, x_m, x)
    52. return F.cross_entropy(self.s * output, target, weight=self.weight)
    53. @LOSSES.register_module()
    54. class LMFLoss(nn.Module):
    55. def __init__(self, cls_num_list, weight, alpha=1, beta=1, gamma=2, max_m=0.5, s=30):
    56. super().__init__()
    57. self.focal_loss = FocalLoss(weight, gamma)
    58. self.ldam_loss = LDAMLoss(cls_num_list, max_m, weight, s)
    59. self.alpha = alpha
    60. self.beta = beta
    61. def forward(self, output, target):
    62. focal_loss_output = self.focal_loss(output, target)
    63. ldam_loss_output = self.ldam_loss(output, target)
    64. total_loss = self.alpha * focal_loss_output + self.beta * ldam_loss_output
    65. return total_loss

    参考:Focal Loss 后继之秀 | LMFLOSS:专治解决不平衡医学图像分类的新型混合损失函数

  • 相关阅读:
    力扣labuladong——一刷day34
    物联网感知-分布式光纤振动传感主机实现基本原理
    递归排列枚举2(c++)
    element ui修改select选择框背景色和边框色
    计算机网络概述及因特网
    聊聊写代码的20个反面教材
    ubuntu18安装coova chilli精简
    docker发布镜像到阿里云与私服
    Restful API 设计示例
    【错误记录】Android Studio 中最新的 Gradle 配置中设置插件依赖 ( 2023 年 8 月 24 日 | 最新 Gradle 中配置插件依赖的变化 | 增加 Maven 仓库源 )
  • 原文地址:https://blog.csdn.net/PLANTTHESON/article/details/133999512