打开网易新闻 查看精彩图片

前段时间拜读了微信AI在OCR领域的一篇技术报告,其中有一个文本识别方向挺有意思的实践是在训练识别网络的时候,利用CNN+BLSTM提取文本行的序列特征,同时采用muti-head的结构,在训练时,以CTC为主,Attention Decoder和ACE辅助训练。在预测时,考虑到速度和性能,只采用CTC进行解码预测,具体的网络结构如下。

打开网易新闻 查看精彩图片

图1 muti-head ctc+attetion+ace识别网络

笔者最近正在进行复现该技术报告的相关工作,在此之前,因为对ACE loss的了解并不深入,特地通读了一下ACE loss的原论文。所以该篇博文,主要针对ACE loss的具体原理做一个解读。由于笔者的能力有限,如果有解读不到位或者是错误的地方,还望各位多多指正。

论文的原文链接如下:

https://arxiv.org/abs/1904.08364

GIT链接如下:

https://github.com/summerlvsong/Aggregation-Cross-Entropy

ACE loss 简介

目前主流的OCR pipeline基本上都是采用检测+识别两步走的方式完成的,导致现有文本识别任务中必须解决针对不定长文本行的识别问题。 不定长文本行的文本识别,本质上是一个序列识别的问题。

针对序列识别的问题,两种业界比较主流的做法是从ASR任务中借鉴来的 CRNN+CTC方式,或者是从机器翻译任务中借鉴来的seq2seq+attention方式。

CRNN+CTC只能解决1-D的序列识别问题,在长文本识别,中文识别任务中表现出来了不错的效果。同时,得益于CTC计算中的前向后向递推迭代计算方式,使得其在运行效率上也有不俗的表现。但当文本行的形变较大时,CTC的效果就会受到比较大的影响。

Seq2Seq+attention的识别方式,原则上能够解决2-D的序列识别问题,但受限于RNN网络在长序列识别中的局限性,以及seq2seq的串行机制,导致这种方式,在长序列文本识别和运行效率上的表现并不十分尽如人意。

ACE loss就是为了弥补这两种方式的设计缺陷提出的,全称是Aggregation Cross-Entropy聚合交叉熵。文章中描述ACE不仅能够解决2-D文本的识别问题,还在时间复杂度和空间复杂度上优于CTC loss。

打开网易新闻 查看精彩图片

图2 三种算法的对比

具体实现细节

那么这个ACE loss 究竟有什么优势,又是怎么实现的呢?

我们不难知道,网络经过CNN+BLSTM提取特征+softmax之后,会得到一个 的后验概率矩阵,其中 是需要识别的字符集合长度,T是序列的长度,如下图所示。

打开网易新闻 查看精彩图片

图3 网络输出的后验概率矩阵

定义文本行的annotation为 , 为 的长度, 第 个位置的字符记为 ,文本行图片记为 ,训练集记为 。那么某一个满足输入图片为 ,网络权重为 ,解码序列为 的序列概率可以表示为:

打开网易新闻 查看精彩图片

不难得知,所有序列识别网络的目的,都是使得目标序列的整体概率最大,因此序列识别网络的loss,基本都是基于这个原则设计出来的。同时,为了降低loss计算的复杂度,我们一般会把序列概率中的乘法运算,换成加法运算,因此就有:

打开网易新闻 查看精彩图片

那么该网络的整体loss,可以设计成如下形式:

打开网易新闻 查看精彩图片

LOSS的形式虽然简单,但想要从网络encodding的后验概率矩阵中提取出满足序列输出结果为 的所有链路的集合并不是一件容易的事,主要矛盾在于序列长度 和目标输出 的长度 在大多数情况下是不一样的,当 和 的差别越大,符合条件的链路就会越多。在CTC loss 中,作者巧妙提出了一种动态规划的思路来求解这一问题的精确解,而ACE作者的思路则有些清奇,既然无法快速的求解这一问题的精确解,那么就用一种快速的方法来求解这一问题的近似解。

具体做法是,通过监督 长度的序列内,每个字符的出现次数,与 label 中每个字符出现次数的关系,来估计 的值,其本质是忽略了序列的位置属性,将一个排列问题抽象成组合问题。

打开网易新闻 查看精彩图片

注:这里的 。

举个例子来理解一下,我们以序列长度 , 标注为cocacola的文本行实例为例,如下图:

打开网易新闻 查看精彩图片

图4

cocacola文本实例的长度 为8,其中共有4个字符出现,其中 出现2次, 出现3次, 出现1次, 出现2次。序列的长度 为10,要满足序列输出结果为cocacola,则必然会有 10-8=2 个节点会被预测为空(这里有点像CTC的Blank)。我们用 来表示 字符在 中的出现次数,不难得到: ,用 表示 字符在 中的出现频率,则 ,其中 表示空字符。

同时,我们从CRNN网络encoding的后验矩阵中,求取每一个字符 在 时间内出现次数的平均概率,如下图:

打开网易新闻 查看精彩图片

图5

其中 为 时刻 字符的出现概率, 为 在长度为 的序列中出现的平均概率:

打开网易新闻 查看精彩图片

到这里问题就变成了一个回归问题,使得 和 label 的分布尽可能地接近。

原文中作者提到了两种思路:(1)MSE:也就是L2范数 (2)交叉熵

(1)MSE:

该方法下损失函数 可以转化为这种形式:

打开网易新闻 查看精彩图片

特别地,我们可以得到 空字符 次数 。

在反向传播地过程中,我们考虑对于每一个输入图片 ,我们求取 对于 时刻预测为k字符的节点 的偏导:

打开网易新闻 查看精彩图片

因为encoding输出的时候的激活函数是softmax,所以我们考虑softmax的输出形式又能够得到某一时刻,节点 与该时刻其他节点 的关系。

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

其中,当 时, ,当 时, 。

推导至此,我们可以得到此时的 相较于后验概率矩阵中每一个节点的偏导可以表示成如下的形式:

打开网易新闻 查看精彩图片

从上式的结果形式来看,基于L2范数的ACE损失(这种情况下是不是应该叫AL2损失,哈哈^~^), 会在训练过程中显现出严重的梯度消失问题。

作者解释,在训练的开始阶段, ,在识别字典的集合比较大的情况下, 非常大,这样就会导致 的值在训练的开始阶段比较小,尽管 的值是合理的,但两者在数量级上可能依然会有一定的悬殊,进而造成梯度消失的问题,使得训练无法收敛。

(2)交叉熵

该方法下损失函数 可以转化为这种形式:

打开网易新闻 查看精彩图片

同样的我们不难得到:

打开网易新闻 查看精彩图片

根据公式(5)(8)(9)(12),可以得到(13)(14)(15):

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

因为 和 的是同一个数量级,并且可以认为在训练的开始阶段他俩是近似相等的,都接近 ,因此梯度消失的问题,得到了很好的缓解,ACE loss由此诞生,先聚合次数,再进行交叉熵,这就是聚合交叉熵。

更进一步地,作者还在2-D序列预测的问题上进行了拓展。这也不难理解,无非就是将长度为 的序列问题,拓展为形状为 的序列预测问题,于是乎公式(11)可以作如下拓展,这里就不展开描写了。

打开网易新闻 查看精彩图片

总结

最后,小小的总结一下。

(1)ACE loss在时间复杂度和空间复杂度上要优于CTC loss, ACE loss的时间复杂度为 借助GPU的并行计算能力,时间复杂度可以来到可怕的 。而借助前向后向推理方式的CTC loss的时间复杂度为 。

(2)ACE loss支持2-D序列的预测

(3)因为多分类任务中softmax激活函数的存在,导致MSE loss会产生梯度消失问题,因此作者采用交叉熵作为loss

(4)本质上,ACE是一种对于序列问题求解的弱监督loss(毕竟不是直接求解序列的整体概率,而是通过序列中字符的出现次数来进行监督),所以在某些情况下可能会产生收敛困难的问题,看到不少的实践者都表明复现效果不理想,因此联合别的LOSS一起训练会取得更好的效果。

来源:知乎

作者:zzzzhkzzzz

深延科技|

打开网易新闻 查看精彩图片

深延科技成立于2018年1月,中关村高新技术企业,是拥有全球领先人工智能技术的企业AI服务专家。以计算机视觉、自然语言处理和数据挖掘核心技术为基础,公司推出四款平台产品——深延智能数据标注平台、深延AI开发平台、深延自动化机器学习平台、深延AI开放平台,为企业提供数据处理、模型构建和训练、隐私计算、行业算法和解决方案等一站式AI平台服务。