前段时间拜读了微信AI在OCR领域的一篇技术报告,其中有一个文本识别方向挺有意思的实践是在训练识别网络的时候,利用CNN+BLSTM提取文本行的序列特征,同时采用muti-head的结构,在训练时,以CTC为主,Attention Decoder和ACE辅助训练。在预测时,考虑到速度和性能,只采用CTC进行解码预测,具体的网络结构如下。
![](http://dingyue.ws.126.net/2022/0303/c24ecafbj00r8632r001ed200oz00nwg00oz00nw.jpg)
图1 muti-head ctc+attetion+ace识别网络
笔者最近正在进行复现该技术报告的相关工作,在此之前,因为对ACE loss的了解并不深入,特地通读了一下ACE loss的原论文。所以该篇博文,主要针对ACE loss的具体原理做一个解读。由于笔者的能力有限,如果有解读不到位或者是错误的地方,还望各位多多指正。
论文的原文链接如下:
https://arxiv.org/abs/1904.08364
GIT链接如下:
https://github.com/summerlvsong/Aggregation-Cross-Entropy
ACE loss 简介
目前主流的OCR pipeline基本上都是采用检测+识别两步走的方式完成的,导致现有文本识别任务中必须解决针对不定长文本行的识别问题。 不定长文本行的文本识别,本质上是一个序列识别的问题。
针对序列识别的问题,两种业界比较主流的做法是从ASR任务中借鉴来的 CRNN+CTC方式,或者是从机器翻译任务中借鉴来的seq2seq+attention方式。
CRNN+CTC只能解决1-D的序列识别问题,在长文本识别,中文识别任务中表现出来了不错的效果。同时,得益于CTC计算中的前向后向递推迭代计算方式,使得其在运行效率上也有不俗的表现。但当文本行的形变较大时,CTC的效果就会受到比较大的影响。
Seq2Seq+attention的识别方式,原则上能够解决2-D的序列识别问题,但受限于RNN网络在长序列识别中的局限性,以及seq2seq的串行机制,导致这种方式,在长序列文本识别和运行效率上的表现并不十分尽如人意。
ACE loss就是为了弥补这两种方式的设计缺陷提出的,全称是Aggregation Cross-Entropy聚合交叉熵。文章中描述ACE不仅能够解决2-D文本的识别问题,还在时间复杂度和空间复杂度上优于CTC loss。
![](http://dingyue.ws.126.net/2022/0303/a7f9c05ej00r8632r001wd200u000e0g00u000e0.jpg)
图2 三种算法的对比
具体实现细节
那么这个ACE loss 究竟有什么优势,又是怎么实现的呢?
我们不难知道,网络经过CNN+BLSTM提取特征+softmax之后,会得到一个 的后验概率矩阵,其中 是需要识别的字符集合长度,T是序列的长度,如下图所示。
![](http://dingyue.ws.126.net/2022/0303/461a2b3cj00r8632s0012d200dt00nqg00dt00nq.jpg)
图3 网络输出的后验概率矩阵
定义文本行的annotation为 , 为 的长度, 第 个位置的字符记为 ,文本行图片记为 ,训练集记为 。那么某一个满足输入图片为 ,网络权重为 ,解码序列为 的序列概率可以表示为:
![](http://dingyue.ws.126.net/2022/0303/b229dd38j00r8632w0006d200pm0031g00g2001w.jpg)
不难得知,所有序列识别网络的目的,都是使得目标序列的整体概率最大,因此序列识别网络的loss,基本都是基于这个原则设计出来的。同时,为了降低loss计算的复杂度,我们一般会把序列概率中的乘法运算,换成加法运算,因此就有:
![](http://dingyue.ws.126.net/2022/0303/9ebed293j00r8632w000cd200r7002xg00g2001p.jpg)
那么该网络的整体loss,可以设计成如下形式:
![](http://dingyue.ws.126.net/2022/0303/07ee862fj00r8632w000dd200om005rg00g2003r.jpg)
LOSS的形式虽然简单,但想要从网络encodding的后验概率矩阵中提取出满足序列输出结果为 的所有链路的集合并不是一件容易的事,主要矛盾在于序列长度 和目标输出 的长度 在大多数情况下是不一样的,当 和 的差别越大,符合条件的链路就会越多。在CTC loss 中,作者巧妙提出了一种动态规划的思路来求解这一问题的精确解,而ACE作者的思路则有些清奇,既然无法快速的求解这一问题的精确解,那么就用一种快速的方法来求解这一问题的近似解。
具体做法是,通过监督 长度的序列内,每个字符的出现次数,与 label 中每个字符出现次数的关系,来估计 的值,其本质是忽略了序列的位置属性,将一个排列问题抽象成组合问题。
![](http://dingyue.ws.126.net/2022/0303/cc82c9dbj00r86330000ed200nu006cg00g20049.jpg)
注:这里的 。
举个例子来理解一下,我们以序列长度 , 标注为cocacola的文本行实例为例,如下图:
![](http://dingyue.ws.126.net/2022/0303/32e969f5j00r86331000bd2007q008hg007q008h.jpg)
图4
cocacola文本实例的长度 为8,其中共有4个字符出现,其中 出现2次, 出现3次, 出现1次, 出现2次。序列的长度 为10,要满足序列输出结果为cocacola,则必然会有 10-8=2 个节点会被预测为空(这里有点像CTC的Blank)。我们用 来表示 字符在 中的出现次数,不难得到: ,用 表示 字符在 中的出现频率,则 ,其中 表示空字符。
同时,我们从CRNN网络encoding的后验矩阵中,求取每一个字符 在 时间内出现次数的平均概率,如下图:
![](http://dingyue.ws.126.net/2022/0303/6c5bab87j00r86338000id200hg006sg00hg006s.jpg)
图5
其中 为 时刻 字符的出现概率, 为 在长度为 的序列中出现的平均概率:
![](http://dingyue.ws.126.net/2022/0303/68fa2c37j00r8633b0002d200gs0027g00ee001v.jpg)
到这里问题就变成了一个回归问题,使得 和 label 的分布尽可能地接近。
原文中作者提到了两种思路:(1)MSE:也就是L2范数 (2)交叉熵
(1)MSE:
该方法下损失函数 可以转化为这种形式:
![](http://dingyue.ws.126.net/2022/0303/6f83f739j00r8633d000ed200qm003ag00g2001z.jpg)
特别地,我们可以得到 空字符 次数 。
在反向传播地过程中,我们考虑对于每一个输入图片 ,我们求取 对于 时刻预测为k字符的节点 的偏导:
![](http://dingyue.ws.126.net/2022/0303/9a3e502bj00r8633f0009d200nz002yg00g2001y.jpg)
因为encoding输出的时候的激活函数是softmax,所以我们考虑softmax的输出形式又能够得到某一时刻,节点 与该时刻其他节点 的关系。
![](http://dingyue.ws.126.net/2022/0303/c059413fj00r8633f0003d200i0002jg00g20029.jpg)
![](http://dingyue.ws.126.net/2022/0303/a9fb44d5j00r8633g0009d200n1002yg00g20021.jpg)
其中,当 时, ,当 时, 。
推导至此,我们可以得到此时的 相较于后验概率矩阵中每一个节点的偏导可以表示成如下的形式:
![](http://dingyue.ws.126.net/2022/0303/d468ead8j00r8633h000ld200mj009ng00e90063.jpg)
从上式的结果形式来看,基于L2范数的ACE损失(这种情况下是不是应该叫AL2损失,哈哈^~^), 会在训练过程中显现出严重的梯度消失问题。
作者解释,在训练的开始阶段, ,在识别字典的集合比较大的情况下, 非常大,这样就会导致 的值在训练的开始阶段比较小,尽管 的值是合理的,但两者在数量级上可能依然会有一定的悬殊,进而造成梯度消失的问题,使得训练无法收敛。
(2)交叉熵
该方法下损失函数 可以转化为这种形式:
![](http://dingyue.ws.126.net/2022/0303/542138e9j00r8633j000ed200q5003cg00g20021.jpg)
同样的我们不难得到:
![](http://dingyue.ws.126.net/2022/0303/586d5defj00r8633j000cd200ml003ag00g2002b.jpg)
根据公式(5)(8)(9)(12),可以得到(13)(14)(15):
![](http://dingyue.ws.126.net/2022/0303/0ea66691j00r8633k000id200m70097g00d9005h.jpg)
![](http://dingyue.ws.126.net/2022/0303/3f22dfc9j00r8633k0009d200n9003ag00eb0020.jpg)
因为 和 的是同一个数量级,并且可以认为在训练的开始阶段他俩是近似相等的,都接近 ,因此梯度消失的问题,得到了很好的缓解,ACE loss由此诞生,先聚合次数,再进行交叉熵,这就是聚合交叉熵。
更进一步地,作者还在2-D序列预测的问题上进行了拓展。这也不难理解,无非就是将长度为 的序列问题,拓展为形状为 的序列预测问题,于是乎公式(11)可以作如下拓展,这里就不展开描写了。
![](http://dingyue.ws.126.net/2022/0303/6d5e3dffj00r8633n000bd200jk0065g00dz004d.jpg)
总结
最后,小小的总结一下。
(1)ACE loss在时间复杂度和空间复杂度上要优于CTC loss, ACE loss的时间复杂度为 借助GPU的并行计算能力,时间复杂度可以来到可怕的 。而借助前向后向推理方式的CTC loss的时间复杂度为 。
(2)ACE loss支持2-D序列的预测
(3)因为多分类任务中softmax激活函数的存在,导致MSE loss会产生梯度消失问题,因此作者采用交叉熵作为loss
(4)本质上,ACE是一种对于序列问题求解的弱监督loss(毕竟不是直接求解序列的整体概率,而是通过序列中字符的出现次数来进行监督),所以在某些情况下可能会产生收敛困难的问题,看到不少的实践者都表明复现效果不理想,因此联合别的LOSS一起训练会取得更好的效果。
来源:知乎
作者:zzzzhkzzzz
|深延科技|
深延科技成立于2018年1月,中关村高新技术企业,是拥有全球领先人工智能技术的企业AI服务专家。以计算机视觉、自然语言处理和数据挖掘核心技术为基础,公司推出四款平台产品——深延智能数据标注平台、深延AI开发平台、深延自动化机器学习平台、深延AI开放平台,为企业提供数据处理、模型构建和训练、隐私计算、行业算法和解决方案等一站式AI平台服务。
热门跟贴