• UNet网络


    UNet

    本博客主要对UNet网络进行讲解,以下为文章目录:

    • UNet 原论文讲解
    • 网络结构
    • 数据集介绍
    • 评价指标
    • 损失计算
    • 代码

    本文参考资料如下:

    • UNet原论文 https://arxiv.org/pdf/1505.04597.pdf
    • U-Net网络结构讲解(语义分割) https://www.bilibili.com/video/BV1Vq4y127fB/
    • 语义分割前言 https://www.bilibili.com/video/BV1ev411P7dR/
    • 转置卷积(transposed convolution) https://www.bilibili.com/video/BV1mh411J7U4
    • U-Net原理分析与代码解读 https://zhuanlan.zhihu.com/p/150579454?utm_medium=social&utm_oi=933364753825456128&utm_psn=1565058304753229824&utm_source=wechat_session
    • 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割) https://www.bilibili.com/video/BV1rq4y1w7xM?p=1&vd_source=61b5ff132eca1d44ecddf022bf9b9def

    UNet 原论文讲解

    UNet原论文于2015年提出,主要影响领域是生物医学影像。

    请添加图片描述

    UNet网络结构

    UNet网络是以Encoder-Decoder为主要形式,它的主要网络结构形式如下:

    在这里插入图片描述

    我们在这里展开讲解一下网络的实现形式,借助如下两个公式:

    1. 输入特征矩阵channel = 卷积核深度;输出特征矩阵channel = 卷积核个数
    2. 输出特征矩阵大小计算公式: N = ( W − F + 2 P ) / S + 1 N=(W-F+2P)/S+1 N=(WF+2P)/S+1。其中输入图片大小 W ∗ W W*W WW,卷积核大小 F ∗ F F*F FF,步长 S S S,补充(padding)像素数 P P P
    conv 3x3, ReLU

    在每一次下采样后都会给一个翻倍卷积核个数的卷积层,使得输出特征矩阵的channel翻倍;或者每次上采样后会给一个缩小卷积核个数的卷积层,使得输出特征矩阵的channel缩小为二分之一。

    输出特征矩阵大小 = (W - 3 + 0)/1 + 1 = W -2 。 故输出特征矩阵的大小-2。(主流实现形式为(W - 3 + 2)/1 + 1 = W, 故特征矩阵大小不变,最终输出大小=最终输入大小)

    可能会加上BN层。

    copy and crop

    进行中心裁剪后使得特征矩阵大小相等后再拼接。(

    max pool 2x2

    MaxPooling下采样层,使得输出特征矩阵的大小变为原先的二分之一,channel不变。但在随后的conv 3x3中channel变为两倍。

    up-conv 2x2

    上采样,这里采用的是转置卷积,线性插值等。经过这一层操作会使得输出矩阵大小变为原来的两倍,channel变为原来的二分之一。

    conv 1x1

    最后的一层卷积层没有激活函数,特征矩阵大小不变,只改变channel。

    UNet 分多批次训练

    对于一个医学图像,他的图片大小往往是较大的,为了能够有一个较好的预测结果,我们往往将原图进行裁剪后分批次进行训练,如下图:

    在这里插入图片描述

    镜像得到边缘缺失

    根据UNet原论文,我们最终要预测的输出图片(下图Fig.2. 黄色框区域)是小于输入图片(下图Fig.2. 蓝色框区域)的,故我们需要对输入图片的像素进行扩充才能使得最终输出图片的大小等于输入图片的大小。我们采用进行镜像方式对原输入图进行扩充,如下所示:

    在这里插入图片描述

    数据集展示

    对于UNet数据集如下。其中图a指的是原图片,是一种灰度图片;图b是进行标注后的图片,不同颜色代表不同细胞;图c是UNet语义分割图片,前景和背景表示细胞和非细胞;但是细胞和细胞之间空隙较小,比较难分割,所以这里采用了图d的权重标注,使得细胞间的权重增加,使得更好地分割细胞。

    在这里插入图片描述

    网络结构

    网络结构见下:

    在这里插入图片描述

    (array)原始图片为image shape=(height, width, 3) dtype=dtype('uint8')label torch.size=(height, width) dtype=dtype('float64')

    (tensor)经过transforms后dataset中图片为image shape=torch.size([3, 480, 480]) dtype=torch.float32label torch.size=([480, 480]) dtype=torch.int64

    (tensor)经过dataloader后最前面加上了一层batch数据。

    (tensor)image经过网络后输出的output shape=torch.size([4, 2, 480, 480]) dtype=torch.float32

    output为两个channel,是为了在计算dice loss时,一个用来计算前景,一个用来计算背景。

    数据集介绍

    我们使用Drive数据集,包含20张测试图,20张验证图,用于分离眼球中血管,属于二分类问题。其中测试集中每一组测试有1张image,1张mask,1张label,我们通过自定义的dataset将mask和label合并,最终得到输出label的Tensor信息为:1表示前景,0表示背景,255表示丢弃。

    评价指标

    评价指标以Global Accuracy,mean Accuracy,mean IoU为主。

    损失计算

    损失计算以cross entropy loss和dice loss为主。将两个loss分别计算后相加。

    cross entropy

    我们用output和target进行cross entropy的loss计算。

    # x 		= 	output: 	[4, 2, 480, 480]
    # target	=    label: 	   [4, 480, 480]
    loss = nn.functional.cross_entropy(x, target, ignore_index=ignore_index, weight=loss_weight)
    
    • 1
    • 2
    • 3

    Dice Loss

    dice loss全称Dice similarity coefficient,用于度量两个集合相似性,使用交并比进行计算,计算公式如下:

    请添加图片描述

    需要注意的是在使用训练集的训练图片进行训练时,我们计算方法如下,我们使用矩阵乘法将对应位置的元素相乘:

    请添加图片描述

    但是在使用训练集的验证图片进行验证,或者使用测试集时,我们则对输出X进行0,1离散分布,再将得到的值和Y相乘。

    代码实现时,通过上面描述,我们意识到在计算Dice Loss前,我们需要先得到预测前景概率矩阵和前景GT标签矩阵,所以我们先要build label,在这里我们使用one-hot编码对label处理,分别生成前景和背景的label标签[N, H, W] -> [N, H, W, C],将label的元素转成float类型后再使用permute方法调整维度[N, H, W, C] -> [N, C, H, W]。于是我们的label变为了label shape=torch.size([4, 2, 480, 480]) dtype=torch.float32,再进行dice损失计算。

    代码

    代码托管在github上https://github.com/yingmuzhi/artificial_intelligence/tree/master/Unet,如果要将自己的数据放在该模型上跑,

    1. 先要将你的数据按照DRIVE中的文件进行排列
    2. 重写my_dataset.py代码,使得输入Image为三通道,label为一通道。其中label中255表示mask(丢弃),1表示前景,0表示背景。
  • 相关阅读:
    Java-反射
    java教材订购系统计算机毕业设计MyBatis+系统+LW文档+源码+调试部署
    【Linux】03_编译原理
    position定位总结+元素选择器+window对象的子对象
    element vue表格单选
    每日一个设计模式之【工厂模式】
    c语言程序范例
    LLM 技术图谱(LLM Tech Map)& Kubernetes (K8s) 与AIGC的结合应用
    【已解决】pycharm 突然每次点击都开新页面,关不掉怎么办?
    C++编程题目2
  • 原文地址:https://blog.csdn.net/qq_43369406/article/details/127354108