码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【Batch Normalization 在CNN中的实现细节】


    目录

    • 1. BN在MLP中的实现步骤
    • 2. BN在CNN中的实现细节
      • 2.1 训练过程
      • 2.2 前向推断过程

    整天说Batch Norm,CNN的论文里离不开Batch Norm。BN可以使每层输入数据分布相对稳定,加速模型训练时的收敛速度。但BN操作在CNN中具体是如何实现的呢?

    1. BN在MLP中的实现步骤

    首先快速回顾下BN在MLP中是怎样的,步骤如下图:
    在这里插入图片描述
    图片来源:BN原论文

    一句话概括就是对于每个特征,求一个batch求均值和方差。然后该特征减去均值除以标准差,再进行一个可以学习参数的线性缩放(即乘以γ加上β)即可。本质上是对输入进行线性缩放,使得每层的输入数值分布相对稳定。

    在MLP中更具体的实现可以参考BN原论文,也可以看这篇讲解——Batch Normalization原理与实战,写得非常清楚。

    2. BN在CNN中的实现细节

    其实本质上还是对于每层进行线性缩放,使得每层的数据分布相对稳定。但由于MLP中的一层是一个1D的向量,而CNN的每一层是一个3D的tensor,所以具体实现细节还是有差异的。
    在这里插入图片描述
    在这里插入图片描述
    整体过程还是和上面的步骤一致:

    2.1 训练过程

    在这里插入图片描述
    只是我们不在对每个特征,求一个batch的均值和方差;而是对每个feature map,求一个batch的均值和方差。

    也可以这么表述:MLP是对在特征维度(即每个神经元)求均值和方差,而CNN是在通道(channel)维度(即每个feature map)求均值和方差。

    表述得更清楚一些:

    在这里插入图片描述
    上图中红框圈出的是CNN的一层,形状是(C, H, W),其中,C是通道个数(即feature map的数量),H和W是feature map的长和宽。我们可以这么理解:CNN的一层是一个3D的矩阵,是由C层2D的feature map组成的。

    若一个batch有N张图片,则红框圈出的一层的形状为(N, C, H, W),即共有NxCxHxW个数值。

    在训练时,BN做的就是在该batch中,对该层每一个feature map的所有数据(即第二个维度C,共有C个feature map),求均值和方差。即对NxHxW个数据求均值μ和方差σ。【一个feature map有HxW个数据,一个batch中有N张图片】即对第i个feature map,有与之对应的μ_i和σ_i。

    因为一层有C个feature map,所以就共有C个μ_i和σ_i。i∈[1, C]

    其余过程就和MLP一样了:第i个feature map每个元素减去对应的均值μ_i除以标准差,然后再进行一个可以学习参数的线性缩放(即乘以γ加上β)即可。

    2.2 前向推断过程

    和MLP一样,在训练过程中,每个batch都会求出C个均值和方差(因为每个feature map都会有一个均值和方差)。

    而在前向推断过程,则对于所有的batch的每个feature map的均值和方差做平均,论文中是求出所有均值的期望和方差的无偏估计,如下:
    在这里插入图片描述
    得到前向推断时每一层feature map的均值μ_i_test和方差σ_i_test。然后第i个feature map每个元素减去对应的均值μ_i_test除以标准差,然后再进行一个的线性缩放(即乘以γ加上β)即可。其中γ和β是训练过程已经训练好的。

    总而言之:MLP是对在特征维度(即每个神经元)对一个batch求均值和方差,而CNN是在通道(channel)维度(即每个feature map)对一个batch求均值和方差。

    END:)

    总觉得文字还是无法完全把我的想法表述清楚,最好的学习方法还是去看原论文,然后自己多琢磨。

    参考:
    1.【原论文】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    2. Batch Normalization原理与实战

    3. 【李沐老师课程】批量归一化【动手学深度学习v2】

  • 相关阅读:
    leetcode刷题(124)——64. 最小路径和
    Linux(openssl):解析证书
    灵魂一问:一个Java文件的执行全部过程你确定都清楚吗?
    Windows软件架构概念
    封装包头基本信息-TCP/UDP头-IP包头-帧头
    私域增长 | 私域会员:9大连锁行业15个案例集锦
    Ajax解析
    springboot:异步注解@Async的前世今生
    面试常问的dubbo的spi机制到底是什么?
    C#,《数值算法:科学计算的艺术,Numerical Recipes: The Art of Scientific Computing》
  • 原文地址:https://blog.csdn.net/qq_44166630/article/details/127266651
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号