博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Batch Normalization原理与使用过程
阅读量:5997 次
发布时间:2019-06-20

本文共 3666 字,大约阅读时间需要 12 分钟。

  •  阅读《Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising》时,开始接触一些深度学习的知识
  • - []
  • - []
  • - []
  • - []
作者:张俊林,新浪微博AI Lab担任资深算法专家Batch Normalization(简称BN)自从提出之后,因为效果特别好,很快被作为深度学习的标准工具应用在了各种场合。BN大法虽然好,但是也存在一些局限和问题, 诸如当BatchSize太小时效果不佳、对RNN等动态网络无法有效应用BN等。针对BN的问题,最近两年又陆续有基于BN思想的很多改进Normalization模型被提出。 BN是深度学习进展中里程碑式的工作之一,无论是希望深入了解深度学习,还是在实践中解决实际问题,BN及一系列改进Normalization工作都是绕不开的重要环节。
  • 2018-NIPS-How Does Batch Normalization Help Optimization : 论文链接:https://arxiv.org/pdf/1805.11604v3.pdf
摘要:批归一化(BatchNorm)是一种广泛采用的技术,用于更快速、更稳定地训练深度神经网络(DNN)。尽管应用广泛,但 BatchNorm 有效的确切原因我们尚不清楚。 人们普遍认为,这种效果源于在训练过程中控制层输入分布的变化来减少所谓的「内部协方差偏移」。本文证明这种层输入分布稳定性与 BatchNorm 的成功几乎没有关系。 相反,我们发现 BatchNorm 会对训练过程产生更重要的影响:它使优化解空间更加平滑了。这种平滑使梯度更具可预测性和稳定性,从而使训练过程更快。

 

Batch Normalization是由google提出的一种训练优化方法。参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift

个人觉得BN层的作用是加快网络学习速率,论文中提及其它的优点都是这个优点的副产品。
网上对BN解释详细的不多,大多从原理上解释,没有说出实际使用的过程,这里从what, why, how三个角度去解释BN。

What is BN

Normalization是数据标准化(归一化,规范化),Batch 可以理解为批量,加起来就是批量标准化。

先说Batch是怎么确定的。在CNN中,Batch就是训练网络所设定的图片数量batch_size。

Normalization过程,引用论文中的解释:

这里写图片描述
输入:输入数据x1..xm(这些数据是准备进入激活函数的数据)
计算过程中可以看到,
1.求数据均值;
2.求数据方差;
3.数据进行标准化(个人认为称作正态化也可以)
4.训练参数γ,β
5.输出y通过γ与β的线性变换得到原来的数值
在训练的正向传播中,不会改变当前输出,只记录下γ与β

在反向传播的时候,根据求得的γ与β通过链式求导方式,求出学习速率以至改变权值

这里写图片描述

Why is BN

解决的问题是梯度消失与梯度爆炸。

关于梯度消失,以sigmoid函数为例子,sigmoid函数使得输出在[0,1]之间。
这里写图片描述
事实上x到了一定大小,经过sigmoid函数的输出范围就很小了,参考下图
这里写图片描述
如果输入很大,其对应的斜率就很小,我们知道,其斜率(梯度)在反向传播中是权值学习速率。所以就会出现如下的问题,
这里写图片描述
在深度网络中,如果网络的激活输出很大,其梯度就很小,学习速率就很慢。假设每层学习梯度都小于最大值0.25,网络有n层,因为链式求导的原因,第一层的梯度小于0.25的n次方,所以学习速率就慢,对于最后一层只需对自身求导1次,梯度就大,学习速率就快。
这会造成的影响是在一个很大的深度网络中,浅层基本不学习,权值变化小,后面几层一直在学习,结果就是,后面几层基本可以表示整个网络,失去了深度的意义。

关于梯度爆炸,根据链式求导法,

第一层偏移量的梯度=激活层斜率1x权值1x激活层斜率2x…激活层斜率(n-1)x权值(n-1)x激活层斜率n
假如激活层斜率均为最大值0.25,所有层的权值为100,这样梯度就会指数增加。

How to use BN

先解释一下对于图片卷积是如何使用BN层。

这里写图片描述
这是文章卷积神经网络CNN(1)中5x5的图片通过valid卷积得到的3x3特征图(粉红色)。特征图里的值,作为BN的输入,也就是这9个数值通过BN计算并保存γ与β,通过γ与β使得输出与输入不变。假设输入的batch_size为m,那就有m*9个数值,计算这m*9个数据的γ与β并保存。正向传播过程如上述,对于反向传播就是根据求得的γ与β计算梯度。
这里需要着重说明2个细节:
1.网络训练中以batch_size为最小单位不断迭代,很显然,新的batch_size进入网络,机会有新的γ与β,因此,在BN层中,有总图片数/batch_size组γ与β被保存下来。
2.图像卷积的过程中,通常是使用多个卷积核,得到多张特征图,对于多个的卷积核需要保存多个的γ与β。

结合论文中给出的使用过程进行解释

这里写图片描述
输入:待进入激活函数的变量
输出:
1.对于K维的输入,假设每一维包含m个变量,所以需要K个循环。每个循环中按照上面所介绍的方法计算γ与β。这里的K维,在卷积网络中可以看作是卷积核个数,如网络中第n层有64个卷积核,就需要计算64次。
需要注意,在正向传播时,会使用γ与β使得BN层输出与输入一样。
2.在反向传播时利用γ与β求得梯度从而改变训练权值(变量)。
3.通过不断迭代直到训练结束,求得关于不同层的γ与β。如网络有n个BN层,每层根据batch_size决定有多少个变量,设定为m,这里的mini-batcherB指的是特征图大小*batch_size,即m=特征图大小*batch_size,因此,对于batch_size为1,这里的m就是每层特征图的大小。
4.不断遍历训练集中的图片,取出每个batch_size中的γ与β,最后统计每层BN的γ与β各自的和除以图片数量得到平均直,并对其做无偏估计直作为每一层的E[x]与Var[x]。
5.在预测的正向传播时,对测试数据求取γ与β,并使用该层的E[x]与Var[x],通过图中11:所表示的公式计算BN层输出。
注意,在预测时,BN层的输出已经被改变,所以BN层在预测的作用体现在此处

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):    # 通过 autograd 来判断当前模式为训练模式或预测模式。    if not autograd.is_training():        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差。        X_hat = (X - moving_mean) / nd.sqrt(moving_var + eps)    else:        assert len(X.shape) in (2, 4)        if len(X.shape) == 2:            # 使用全连接层的情况,计算特征维上的均值和方差。            mean = X.mean(axis=0)            var = ((X - mean) ** 2).mean(axis=0)        else:            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要            # 保持 X 的形状以便后面可以做广播运算。            mean = X.mean(axis=(0, 2, 3), keepdims=True)            var = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)        # 训练模式下用当前的均值和方差做标准化。        X_hat = (X - mean) / nd.sqrt(var + eps)        # 更新移动平均的均值和方差。        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean        moving_var = momentum * moving_var + (1.0 - momentum) * var    Y = gamma * X_hat + beta  # 拉升和偏移。    return Y, moving_mean, moving_var

 

你可能感兴趣的文章
使用VUE时,在html中需要将单参数转化为多参数时
查看>>
当微信小程序遇上TensorFlow:终章
查看>>
Hystrix都停更了,我为什么还要学?
查看>>
nodejs中流(stream)的理解之可读流
查看>>
java面试题总结(开发者必备)
查看>>
Block 形式的通知中心观察者是否需要手动注销
查看>>
CSS进阶——绝对定位元素的宽高是如何定义的
查看>>
认证鉴权与API权限控制在微服务架构中的设计与实现(二)
查看>>
Android从零撸美团(四) - 美团首页布局解析及实现 - Banner+自定义View+SmartRefreshLayout下拉刷新上拉加载更多...
查看>>
Android Paging分页库的学习(一)—— 结合本地数据进行分页加载
查看>>
自己手写一个 SpringMVC 框架
查看>>
在linux中安装mysql并解决中文乱码问题
查看>>
Beego框架的一条神秘日志引发的思考
查看>>
项目开发框架-SSM
查看>>
Kotlin 开发中文周报 —— 100
查看>>
关于为什么学习React Native三点原因
查看>>
Vector、ArrayList、LinkList集合框架的使用与理解
查看>>
Dagger2 知识梳理(2) @Qulifier 和 @Named 解决依赖注入迷失
查看>>
golang和任意一种windows开发语言之间的进程内通信到底难不难吖
查看>>
多Activity切换的生命周期问题
查看>>