Customization of latent space in semi-supervised Variational AutoEncoder

半监督变分自动编码器中潜在空间的定制

https://www.sciencedirect.com/science/article/pii/S0167865523003288

https://github.com/an-seunghwan/EXoN_official

摘要

我们提出了一种新颖的变分自编码器 (VAE) 半监督学习方法,该方法通过我们的可解释编码器网络 (EXoN) 生成定制的潜在空间。定制涉及手动设计插值和结构约束,例如接近性,从而增强潜在空间的解释性。为了提高分类性能,我们引入了一种新的半监督分类方法,称为 SCI(软标签一致性插值)。将分类损失和 Kullback-Leibler 散度结合起来对于构建可解释的潜在空间至关重要。此外,生成样本的可变性由一个主动潜在子空间决定,该子空间有效地捕捉了独特的特征。我们使用 MNIST、SVHN 和 CIFAR-10 数据集进行了实验,结果表明,我们的方法生成了可解释的潜在空间,同时显著减少了分析潜在空间内表示模式所需的工作量。

关键词:变分自编码器 可解释的潜在空间 半监督 定制化

1. 引言

变分自编码器 (VAE) [1] 旨在实现两个主要目标:构建有意义的潜在空间并准确重建原始观测结果。在 VAE 框架中,神经网络通常用于参数化概率模型,充当观测域和潜在空间之间的映射。然而,由于不存在封闭形式的概率模型,直接应用最大似然估计并不合适。相反,通常采用变分贝叶斯方法来最大化证据下界 (ELBO)。这种方法允许在贝叶斯框架内有效估计模型的参数和潜在表示。

已经提出了许多技术用于 VAE 框架中的(半)监督学习 [2–5]。尽管它们取得了进展,但现有的半监督 VAE 模型仍然遇到实际限制。[2–4] 引入了额外的离散潜在空间来表示标签。然而,这些模型缺乏通过不同标签图像之间的插值生成新图像的定量度量。此外,使用他们提出的方法对潜在空间施加结构约束(例如,强制标记观测值之间的潜在特征接近性)是具有挑战性的。

为了增强潜在空间的解释性,我们提出了一种新颖的半监督 VAE 模型,该模型结合了插值和结构约束的手动设计。我们采用高斯混合模型作为先验和后验潜在分布 [6,7]。我们方法的关键概念是构建一个允许定制潜在空间的编码器,在保留接近信息的同时,建立每个混合成分与特定标签之间的对应关系。换句话说,混合分布的每个质心都被训练为一个标识符,代表一个特定的标签,而基于标签的手动分解潜在空间为用户提供了定制的、可解释的表示。与 [5] 中的方法不同,我们训练的潜在空间在不同的训练过程中保持一致。这意味着我们可以在不使用整个观测值探索潜在空间的情况下,始终从潜在空间中获得关于插值图像的有意义信息。

此外,我们引入了一种新颖的度量方法来评估潜在变量的表示能力,使我们能够检查特定潜在子空间的重要性。我们的研究结果表明,我们 VAE 模型中的编码器选择性地激活了潜在空间的一个子集,有效地捕捉了生成样本的独特特征。这种激活通过与编码器相关的后验方差进行量化。受基因生物学术语的启发,我们将我们的编码器命名为 EXoN(可解释编码器网络)。

2. 相关工作

零样本学习中的设计潜在空间

在零样本学习中,潜在空间的设计在表示多模态观测和促进准确的零样本学习中起着关键作用。[8,9] 旨在从不同模态学习联合潜在变量,并构建对齐的潜在空间。为了实现对齐,这些方法利用接近性度量,如最大均值差异、平方损失互信息和 Wasserstein 距离,以确保不同模态的潜在分布之间的一致性。我们提出的方法的一个显著区别在于,即使在经历不同的训练过程后,潜在空间结构与手动预设计的一致性仍然保持不变。这一独特特性使我们能够在不需要使用完整观测值广泛探索潜在空间的情况下,获得关于插值图像的有意义信息。

半监督 VAE 学习方法 在 [2] 中,标签被视为离散潜在变量,类似于 [3],与标签相关的后验分布作为分类器。然而,在推导 ELBO 时,分类器仅通过未标记数据集进行训练。为了解决这个问题,[2] 引入了一个额外的分类损失,利用标记数据集来增强分类器和后验推断的训练。[4] 提出了平滑 ELBO,它直接将标签信息集成到 ELBO 中。此外,[4] 通过引入基于数据空间中最佳插值的分类损失来提高半监督分类性能。另一方面,[5] 将潜在变量和标签的潜在空间集成在一起,将标签分类器纳入高斯先验分布的混合中。为了增强标签的后验推断,[5] 引入了 [2] 的额外监督分类损失。

半监督分类 已经提出了各种正则化技术来提高半监督分类的性能。-模型 [10] 利用暗知识的概念,通过最小化从同一输入获得的不同输出向量之间的平方差来促进预测一致性。虚拟对抗训练 (VAT) [11] 通过在虚拟对抗方向上分配相同的标签分布来训练模型,该方向最大化输出分布之间的差异。为了防止过度拟合模型预测,MixMatch [12] 和 PLCB [13] 利用通过数据增强和混合策略 [14] 获得的伪标签。

主动潜在维度 [15,16] 引入了统计方法来识别捕捉观测重要信息并在大数据生成中起关键作用的主动潜在维度。[15] 提出的统计方法源自对编码器分布参数的经验分析。另一方面,[16] 提出的统计方法特别适用于解码器的均值向量遵循仿射映射的情况。

3. 提议

3.1.模型假设

3.2. EXoN:半监督VAE

在我们的半监督VAE模型中,惩罚函数的应用采用了与[2]相似的理念。然而,规范化目标函数的推导源自于(2),这与现有研究不同。

3.3. SCI:软标签一致性插值

分类器对给定数据点的标签条件分布进行局部近似。如果分类器仅在小型标记数据集上进行训练,数据空间将包含许多空白区域。因此,分类器在这些空白区域的预测性能严重依赖于分类器的基本假设。因此,虽然(3)的推导在数学上是可行的,但它不能保证半监督分类的最先进性能。然而,我们可以通过使用大量未标记数据点进行线性插值来填补这些空白区域,从而减少分类器的方差。 因此,为了增强半监督分类的性能,我们引入了一种名为软标签一致性插值(SCI)损失的新损失函数。这种损失利用了线性插值和伪标记技术[12,13]。

SCI损失由三部分组成:(1)一个插值的新图像与一对未标记图像,(2)一对未标记图像的伪标签,以及(3)交叉熵的凸组合。

一致性插值意味着从图像到伪标签的线性映射 的存在。众所周知,这种混合策略可以改善泛化误差 [14]。有趣的是,在我们 VAE 模型中对 (⋅; ∗) 的估计也被其他现有的半监督学习方法 [4,13] 使用,以下算法提供了一个通用框架来估计 VAE 模型中的 (⋅; ∗)。设 () 是通过第 步训练 VAE 获得的 的估计值,(+1) 是以下最优插值问题的解 [4]

3.4. 活跃的潜在子空间

其中 [] = {1, … , }。我们的数值研究发现,该子空间代表了生成样本的信息特征,并且该子空间可以有效地用于生成高质量图像(见第 4.2 节)。这些结果与 VQ-VAE [20] 的结果一致,即图像生成过程主要依赖于近似确定性的编码映射,并且映射的细化可以提高生成图像的质量。

4. 实验

由于各种超参数的计算问题,我们选择了一些最能代表我们主张的离散值。值得注意的是,所有比较模型具有相同的潜在维度大小。为了可重复性,代码已公开发布,并可在 https://github.com/an-seunghwan/EXoN_official 获取。

4.1. MNIST 数据集

我们使用了 MNIST 数据集 [21] 来考虑我们 VAE 模型中的二维潜在空间。数值被缩放到 -1 到 1 的范围内。编码器返回 10 成分高斯混合分布参数、混合概率、均值向量和对角协方差元素。因此,编码器将  映射到 ((0, 1)。特别是混合概率由编码器中的分类器生成。Gumbel-Max 技巧 [22] 用于采样离散潜在变量。编码器、解码器和分类器的详细网络架构在附录 A.4.2 中描述。定制的 () 如图 1(a) 所示。每个标签从 3 点钟方向逆时针分配给混合成分。请注意,图 1(a) 展示了一个概念中心的示例。混合分布的第 个成分对应于 MNIST 数据集中数字 (− 1) 的分布,其中 = 1, …, 10(有关详细的预设计设置,请参见附录 A.4.1)。

我们的模型与 [2,4,5] 进行了比较,模拟结果显示,我们的拟合模型在 59,900 个未标记和 100 个标记图像上实现了有竞争力的分类性能,错误率为 3.23%(有关比较结果,请参见附录 A.2)。实现细节在附录 A.4.1 中描述。

4.1.1. 正则化的效果

首先,研究了 (7) 中调谐参数 在拟合潜在空间中的作用。图 1(b) 的上部面板显示了来自 (|; , ) 的样本,其中 是测试数据集中的一个观测值,下部面板显示了从潜在空间上的网格点生成的图像。上部面板表明,较大的 通过间接增加目标函数中 KL-散度项的权重,将 (|; , ) 正则化为 ()。下部面板显示,每个生成的图像完全匹配预设计潜在空间上定义的标签。此外,确认了生成的图像根据我们对概念中心的排列自然地插值在预设计潜在空间上。这些结果表明,所提出的 VAE 生成了带有标签的可解释潜在空间。有关各种 值的附加评估结果(负平均单尺度结构相似性 (negative SSIM) [23]、分类错误率和 KL-散度),请参见附录 A.2.1。

4.1.2. 生成图像的多样性

为了在大 下最大化 (7),(8) 应接近零,并且 | 和 | 之间的互信息对于所有类别 也应接近零。 和 的条件独立性意味着 (; ) 对于每个 不依赖于 ,因为假设 | ∼ ((; ), ⋅)。因此,如图 1(b) 的底部所示,当 较大时,特定标签的潜在混合成分无法捕捉属于相应标签的观测的复杂模式(见图 2)。

4.1.3. 定制插值

图 3 显示了从先验结构上的插值重建的两组图像。潜在变量 A 和 B 分别从标签 0 和 1 的混合成分中采样。假设从点 A 到 B 的插值路径产生带有标签 0 和 1 的插值图像。然而,图 3 的左侧面板显示了在插值路径中间的标签既不是 0 也不是 1 的重建图像。这意味着在用 Parted-VAE [5] 训练之前,插值路径是不可预测的。然而,由于我们的潜在空间可以手动设计,我们的模型插值路径仅由带有标签 0 和 1 的插值图像组成(见图 3 的右侧面板)。

此外,我们可以通过控制混合成分之间的接近度来预先确定插值路径和插值的分辨率(见下一小节)。

4.1.4. 控制接近度

我们研究了根据先验混合成分接近度的各种预设计,插值图像的模式。我们使用仅带有 0 和 1 标签的 MNIST 数据集子集,因此使用 2 成分高斯混合分布。所有高斯成分都具有对角协方差矩阵;它们的对角元素均为 4。我们将位置参数设置为 (−, 0), (, 0),这决定了接近度。我们将两个位置参数之间的距离设置为 8、16、24 和 32。图像从 2 维潜在空间上从 (−10, 0) 到 (10, 0) 的等间距线段的 11 个点生成。图 4 显示,如果 () 的两个位置参数彼此相距较远,插值图像的变化会更慢,潜在空间有效地适应了预设计的特征。

4.2. SVHN 和 CIFAR-10 数据集

我们将我们的 VAE 模型应用于 SVHN [24] 和 CIFAR-10 [25] 数据集,这些图像有十个标签。对于 SVHN 数据集,我们仅使用官方训练集中的 73,257 张图像,并使用每类 100 张标记图像评估我们的模型。对于 CIFAR-10 数据集,使用每类 400 张标记图像和 46,000 张未标记图像。对于这两个数据集,图像的所有值都被缩放到 (−1, 1) 的范围内。使用 256 维潜在空间和 10 混合高斯分布。先验混合分布的每个成分具有单独的均值向量,所有成分共享相同的协方差(有关详细的预设计先验、超参数和实现设置,请参见附录 A.4.1)。Gumbel-Max 技巧 [22] 用于采样离散潜在变量。编码器、解码器和分类器的网络架构如附录 A.4.3 所示。

4.2.1. 比较

表 1 提供了我们提出的模型在 SVHN 和 CIFAR-10 数据集上半监督分类性能及其图像生成能力的全面定量比较。由于所考虑的分类模型不是生成性的,因此不适用于 Inception Score [26]。SVHN 和 CIFAR-10 数据集上的测试错误率表明,EXoN 结合 SCI 损失在半监督分类中实现了有竞争力的性能。虽然 EXoN 在 CIFAR-10 数据集上可能没有达到最高的 Inception Score,但我们的模型在恢复和插值性能方面优于其他模型,如图 5 和图 2 所示(有关我们模型生成的其他图像,请参见附录 A.3.3)。

4.2.2. 活跃的潜在子空间

5. 结论和局限性

在这项研究中,我们提出了一种使用给定标签数据定制VAE模型潜在空间的通用方法。以往的研究,如半监督VAE、迁移学习和零样本学习,仅在表面上通过结构损失函数描述了潜在空间的定制。相比之下,本文解释了所提出的方法是如何通过调整混合模型中KL散度的局部近似性能来构建潜在空间的。特别是,VAE中先验分布的多模态特性在有意义地解释潜在空间中发挥了重要作用。

同时,通过从潜在变量进行图像插值和恢复实验,我们能够确认所提出的模型与现有模型相比能够生成更加真实的图像。特别是,能够根据用户偏好手动设计潜在空间的能力是其他模型所不具备的显著优势。此外,从我们的模型中获得的理论结果部分解释了SCI损失的合理性,SCI损失已被用来提高现有模型的预测能力。

我们推测,与潜在空间中的特征向量一致对齐的增强数据可以增强用于分类任务的其他VAE模型的预测性能。

然而,我们的模型并没有提供在缺乏标签数据时潜在空间的可解释构造方法,并且在分析可以被多个因素划分的更复杂形式的数据时仍存在局限性。解码器中的概率分布学习以及为解释潜在空间自动生成标签预计将成为未来模型扩展的主要研究课题。

我们研究的另一个局限性在于缺乏通过视觉感知实验对我们一致性插值假设的经验验证。未来的研究可以考虑进行视觉感知研究,以提供关于我们方法中一致性插值假设的有效性和局限性的经验证据。

本研究中使用的数据集可通过以下链接访问:

  • MNIST: https://www.tensorflow.org/datasets/catalog/mnist

  • SVHN: http://ufldl.stanford.edu/housenumbers

  • CIFAR-10: https://www.tensorflow.org/datasets/catalog/cifar10

此外,GitHub 仓库 https://github.com/an-seunghwan/EXoN_official 包含了详细代码,用于复制本文中提出的模型和分析,使用上述数据集,并附有全面的使用说明。

https://www.sciencedirect.com/science/article/pii/S0167865523003288

https://github.com/an-seunghwan/EXoN_official