Free-form Flows: Make Any Architecture a Normalizing Flow

模型在生成稳定分子的速度超过之前模型两个数量级。见第5节。

目标是,本文介绍的方法将允许从业者将更多时间用于将领域知识整合到他们的模型中,并允许通过最大限度似然估计解决更多问题。

是否可以将FFF解释为VAE。在附录A.2中,我们提供了一个论点,即它可以,但它具有非常灵活的后验分布

与最近基于样条spline-based的和基于ODE的正规化流相竞争的性能

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

此作者之前的论文:

自由形态流:使任何架构成为正规化流Normalizing Flow

摘要

正规化流Normalizing Flow是直接最大化似然的生成模型。以前,正规化流的设计在很大程度上受到分析可逆性需求的限制。我们通过一种训练过程克服了这一限制,该过程使用变量变换公式梯度的高效估计器。这使得任何保持维度的神经网络都可以通过最大似然训练作为生成模型。我们的方法允许将重点放在精确地调整归纳偏差以适应手头任务。具体来说,我们在分子生成基准测试中取得了优异的结果,利用E(n)-等变网络大大改善了采样速度。此外,我们的方法在逆问题基准测试中具有竞争力,同时采用现成的ResNet架构。我们在 https://github.com/vislearn/FFF 上发布了我们的代码。

1 引言

生成模型已经在各种应用中积极展示了它们的实用性,成功地扩展到高维数据分布的场景,从图像合成到分子生成(Rombach等人,2022; Hoogeboom等人,2022)。正规化流(Dinh等人,2015; Rezende和Mohamed,2015)推动了这一进展,特别是在科学领域,使从业者能够直接优化数据似然,从而促进了学习复杂数据分布的统计严谨方法。限制正规化流Normalizing Flow的其他生成模型(特别是扩散模型)的表达力和受欢迎程度的主要因素是,它们的表达力受到架构约束的极大限制,即确保双射性和计算雅可比行列式的必要性

在这项工作中,我们提出了一种方法,使正规化流摆脱了传统的架构限制,因此引入了一类灵活的新的最大限度似然模型。对于模型构建者来说,这将重点从满足可逆性要求转移到结合最佳归纳偏差以解决手头问题。我们的目标是,本文介绍的方法将允许从业者将更多时间用于将领域知识整合到他们的模型中,并允许通过最大限度似然估计解决更多问题。

关键的方法论创新是将最近提出的一种用于训练自编码器的方法(Sorrenson等人,2024)适应到保持维度的模型上。诀窍是通过编码器和解码器雅可比的一对向量-雅可比和雅可比-向量乘积来估计编码器雅可比行列式的梯度,这些乘积在标准自动微分软件库中很容易获得。我们展示了在全维背景下,许多困扰瓶颈自编码器模型解释的理论困难消失了,优化可以被解释为正规化流Normalizing Flow训练的放松,这在原始解上是紧密的。

在分子生成中,旋转等变性已经被证明是一个关键的归纳偏差,我们的方法优于传统的正规化流,并且比以往的方法快一个数量级以上的速度生成有效样本。此外,基于模拟的推断(SBI simulation-based inference (SBI))的实验强调了模型的多功能性。我们发现,我们的训练方法在最小化微调要求的情况下取得了竞争性能。

总结我们的贡献如下:

  • 我们通过引入最大限度似然训练来去除正规化流Normalizing Flow的所有架构约束。我们称我们的模型为自由形态流(FFF free-form flow),见图1和第3节。

  • 我们证明了训练在重建损失最小的情况下与传统正规化流优化具有相同的最小值,见第4节。

  • 我们在逆问题和分子生成基准测试上展示了最小化微调的竞争性能,超越了基于ODE的模型。扩散模型相比,我们的模型生成稳定分子的速度超过两个数量级。见第5节。

2 相关工作

正规化流通常依赖于专门的架构,这些架构是可逆的,并且具有易于管理的雅可比行列式(见第3.1节)。见Papamakarios等人(2021);Kobyzev等人(2021)的概述。

一类工作通过将简单层(耦合块)串联起来构建可逆架构,这些层很容易逆转,并且具有三角形的雅可比矩阵,这使得计算行列式变得容易(Dinh等人,2015)。通过堆叠许多层及其通用性已经在理论上得到了确认(Huang等人,2020;Teshima等人,2020;Koehler等人,2021;Draxler等人,2022,2023)。已经提出了许多耦合块的选择,如MAF(Papamakarios等人,2017)、RealNVP(Dinh等人,2017)、Glow(Kingma和Dhariwal,2018)、神经样条流(Durkan等人,2019),见Kobyzev等人(2021)的概述。与分析可逆性不同,我们的模型依赖于重建损失来强制近似可逆性。

另一线工作通过使用ResNet结构并限制每个残差层的Lipschitz常数来确保可逆性(Behrmann等人,2019;Chen等人,2019)。类似地,神经ODE(Chen等人,2018;Grathwohl等人,2019)采用ResNets的连续极限,保证在温和条件下的可逆性。这些模型在训练期间需要评估多个步骤,因此变得相当昂贵。此外,雅可比行列式必须估计,增加了开销。像这些方法一样,我们必须估计雅可比行列式的梯度,但可以更有效地做到这一点。流匹配Lipman等人(2023);Liu等人(2023);Albergo和Vanden-Eijnden(2023)提高了这些连续正规化流的训练速度和质量,但仍然涉及昂贵的多步骤采样过程。从构造上讲,我们的方法由单一模型评估组成,我们对架构没有限制,除了由任务手头指示的归纳偏差。

两个有趣的方法(Gresele等人,2020;Keller等人,2021)计算或估计雅可比行列式的梯度,但严格限制在仅包含纯方阵权重矩阵且没有残差块的架构中。我们除了保持维度外没有架构限制。中间激活和权重矩阵可以有任何维度,允许任何网络拓扑。

3 方法

3.1 正规化流

正规化流(Rezende和Mohamed,2015)是一类生成模型,它们学习一个可逆函数fθ(x) : RD → RD,将来自给定数据分布q(x)的样本x映射到潜在代码z。目标是让z遵循一个简单的目标分布,通常是多变量标准正态分布。

从生成模型pθ(x)中获得的样本通过将简单目标分布p(z)的样本通过学习到的函数的逆映射来获得:

(下图截图)

这需要一个可处理的逆函数。传统上,这是通过可逆层(如耦合块)(Dinh等人,2015)或以其他方式限制函数类来实现的。我们通过一个简单的重建损失来替换这个约束,并学习第二个函数gϕ ≈ f−1 θ 作为确切逆函数的近似

需要一个可处理的雅可比行列式的行列式,以考虑密度的变化。因此,模型似然的值由可逆函数的变量变换公式给出:

pθ(x) = p(Z = fθ(x))|Jθ(x)|。1

打开网易新闻 查看精彩图片

这里,Jθ(x)表示fθ在x处的雅可比矩阵,|·|表示其行列式的绝对值。

正规化流通过最小化真实分布和学习分布之间的Kullback-Leibler (KL) 散度来训练。这等价于最大化训练数据的似然:

DKL(q(x)∥pθ(x)) = Ex∼q(x)[log q(x) − log pθ(x)] = Ex[− log p(fθ(x)) − log |Jθ(x)|] + const。

通过等式(1),这需要在x处计算fθ的雅可比矩阵Jθ(x)的行列式。如果我们想准确计算这个值,我们需要计算完整的雅可比矩阵,这需要通过fθ进行D次反向传播,这对于大多数现代应用来说是禁止的。因此,正规化流文献的大部分内容都涉及到构建可逆架构,这些架构具有表现力并允许更有效地计算雅可比行列式的行列式。我们通过一个技巧来绕过这个问题,这个技巧允许我们有效地估计梯度∇θ log |Jθ(x)|,注意到这个量足以进行梯度下降

3.2 梯度技巧

本节的结果是对Caterini等人(2021)和Sorrenson等人(2024)的结果的改编。

这里,我们推导了如何有效地估计方程(2)中最大似然损失的梯度,即使架构不提供一种有效的方式来计算变量变换项log |Jθ(x)|。我们通过估计log |Jθ(x)|的梯度来避免这个计算,通过一对向量-雅可比和雅可比-向量乘积来估计,这些乘积在标准自动微分软件库中很容易获得。

梯度通过迹估计器 Gradient via trace estimator

定理3.1。让fθ : RD → RD是一个由θ参数化的C1可逆函数。那么,对于所有x ∈ RD:

∇θi log |Jθ(x)| = tr((∇θiJθ(x))(Jθ(x))−1)。

打开网易新闻 查看精彩图片

证明是通过直接应用雅可比公式,见附录A.1。这本身并不是一个简化,因为等式(3)的右侧现在涉及到计算雅可比矩阵及其逆矩阵。然而,我们可以通过Hutchinson迹估计器来估计它(这里我们为了简单省略了对x的依赖):

tr((∇θiJθ)J−1 θ ) = Ev[vT(∇θiJθ)J−1 θ v] ≈ 1/K ∑k=1K vT k (∇θiJθ)J−1 θ vk。

现在我们需要计算的就是点积vT(∇θiJθ)和J−1 θ v,其中随机向量v ∈ RD必须具有单位协方差。

通过函数逆矩阵求逆矩阵 Matrix inverse via function inverse

为了计算J−1 θ v,我们注意到,当fθ是可逆的时,fθ的雅可比矩阵的逆矩阵是逆函数f−1 θ的雅可比矩阵:

J−1 θ (x) = (∇xfθ(x))−1 = ∇zf −1 θ (z = fθ(x))。

这意味着J−1 θ v只是一个与向量v的雅可比矩阵f−1 θ的点积。这个雅可比-向量乘积可以通过前向自动微分很容易地获得

使用stop-gradient

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

4 理论

在本节中,我们提供了三个定理,这些定理强调了我们方法的有效性。首先,我们展示了使用精确逆的Lf^-1优化是数据分布和生成分布之间扩散散度的一个界限。其次,我们展示了在什么条件下,放松的Lg(使用非精确逆的损失)的梯度等于Lf^-1的梯度。最后,也是最重要的,我们展示了Lf^-1的解是最大似然解,其中pθ(x) = q(x)。此外,Lf^-1的每个临界点也是Lg的临界点,这意味着优化Lg在实践中等同于优化Lf^-1,除了一些额外的临界点,我们认为这些临界点在实践中并不重要。请参考附录A以获取本节结果的详细推导和证明。

4.1 损失推导

除了前几节中给出的直观发展,Lf^-1(等式(7))可以严格地推导为数据的噪声版本和模型的噪声版本之间的KL散度的界限,称为扩散KL散度(Zhang等人,2020)。这个界限是一种证据下界(ELBO)的形式,如VAEs(Kingma和Welling,2014)中所使用的。

打开网易新闻 查看精彩图片

由于上述推导类似于ELBO,我们可以问是否可以将FFF解释为VAE。在附录A.2中,我们提供了一个论点,即它可以,但它具有非常灵活的后验分布,与VAE后验中通常使用的简单分布(如高斯分布)形成对比。因此,它不会受到典型的VAE失败模式的影响,例如糟糕的重建和过度正则化。

请注意,上述定理陈述的是解码器分布pϕ(x)的结果,而不是用于激励损失函数的编码器分布pθ(x)。虽然这似乎起初是反生产的,但实际上,优化pϕ(x)以匹配数据分布比优化pθ(x)以匹配数据分布更有用,因为pϕ(x)是我们用于从数据生成的模型。无论如何,可以简单地证明D' ≥ DKL(q(x) || pθ(x)),其中D'具有与Lf^-1相同的梯度(见附录A.2),因此,随着优化的增加,编码器和解码器模型都将变得更接近数据分布。

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

4.3 临界点

以下定理陈述了我们的主要结果:优化Lg(等式(7))几乎等同于优化Lf^-1,并且Lf^-1的解是最大似然解,其中pθ(x) = q(x)。请注意,这是一个关于函数级别的结果:如果我们说f是Lf^-1的临界点,我们的意思是,向f添加任何无限小的偏差δf不会改变Lf^-1。这些最优解可能不在通过我们选择的网络架构下的梯度下降可达到的函数集中,并且特定的神经网络实现可能会引入未在定理中捕获的局部最小值。

打开网易新闻 查看精彩图片

定理 4.3. 假设fθ和gϕ是C1的,并且假设fθ是全局可逆的。假设q(x)是有限的并且在任何地方都有支撑。那么对于任何β > 0,Lf^-1的临界点(在函数级别上)是这样的:

  1. 对于所有z,gϕ(z) = f^-1_θ(z),

  2. 对于所有x,pθ(x) = q(x),

  3. 所有临界点都是全局最小值。

此外,Lf^-1的每个最小值是Lg的临界点。如果重建损失是最小的,Lg没有额外的临界点。

请注意,如果重建损失不是最小的,即fθ和gϕ不是全局可逆的,Lg可能会有额外的临界点。一个例子是当fθ和gϕ都是零函数,且q(x)的均值为零时。我们可以通过确保β足够大以不容许高重建损失来避免这样的解决方案。在附录B.4中,我们提供了如何在实践中选择β的指导。

图2提供了一个启发性的例子。在这里,数据和潜在空间是一维的,f和g是每个只有一个参数的简单线性函数。因此,我们可以在2D图中可视化梯度景观。我们看到原点处的额外临界点是一个鞍点:既有汇聚的梯度也有发散的梯度。在随机梯度下降中,我们不太可能收敛到鞍点,因为以确定性方式收敛到它的点集在参数空间中的测度为零。因此,在本例中,Lg将收敛到与Lf^-1相同的解。

此外,它具有更平滑的梯度景观(在 a = 0 时没有发散的梯度)。虽然这在这个简单的例子中可能不重要,但在更高维度的情况下,其中相邻区域的雅可比矩阵可能是不一致的(如果特征值有不同的符号),能够穿越雅可比矩阵奇异的区域而不必克服过大的梯度障碍是有益的。

打开网易新闻 查看精彩图片

5 实验

在本节中,我们展示了自由形态流(FFF)的实际能力。我们主要将性能与基于构造可逆的架构的正规化流进行比较。首先,在逆问题基准测试中,我们展示了使用自由形态架构提供了与最近基于样条spline-based的和基于ODE的正规化流相竞争的性能。尽管超参数调整很少,但这一成就展示了FFF易于适应新任务。其次,在两个分子生成基准测试中,我们展示了可以在正规化流中使用专门的网络。特别是,我们采用了等变图神经网络E(n)-GNN(Satorras等人,2021b)。这个E(n)-FFF在似然方面优于基于ODE的等变正规化流,并且生成稳定分子的速度明显快于扩散模型。

5.1 基于模拟的推断

生成模型的一个流行应用是解决逆问题。这里的目标是从观测中估计隐藏参数。由于逆问题通常是模糊的,由生成模型表示的概率分布是一个合适的解决方案。从贝叶斯角度来看,这个概率分布是给定观测的参数的后验。我们通过条件生成模型学习这个后验

特别是,我们关注基于模拟的推断(SBI,Radev等人(2022,2021);Bieringer等人(2021)),我们想要预测模拟的参数。训练数据是参数和从模拟中生成的输出对。我们在Lueckmann等人(2021)提出的基准上训练FFF模型,该基准由十个不同难度的逆问题组成,每个问题有三个不同的模拟预算(即训练集大小)。模型通过分类器2样本测试(C2ST)(Lopez-Paz和Oquab,2017;Friedman,2003)进行评估,其中训练一个分类器来区分训练生成模型的样本和真实的参数后验。然后,模型性能报告为分类器准确性,其中0.5表示与真实后验无法区分的分布。我们在十个不同的观测上平均这种准确性。在图3中,我们报告了我们模型的C2ST,并将其与基于神经样条流(Durkan等人,2019)和SBI的流匹配(Wildberger等人,2023)的基线进行了比较。我们的方法表现出竞争力,特别是在低模拟预算的范围内超过了现有方法。关于超参数调整,我们发现一个简单的全连接架构,带有跳跃连接,在数据集上有效,只需对更大的数据集进行微小修改以增加容量。我们确定了足够大的重建权重β,以便训练变得稳定。我们在附录C.1中提供了所有数据集和更多的训练细节。

5.2 分子生成

自由形态正规化流(FFF)对底层网络fθ和gϕ不做任何假设,除了它们保持维度。我们可以利用这种灵活性来处理那些应该在架构中内置明确约束的任务,与那些源于可处理优化需求的约束(如耦合块)相反。

作为一个展示,我们将FFF应用于分子生成。这里的任务是学习大量原子x1, ..., xN ∈ R^n的联合分布。每个生成模型的预测都应该产生每个原子的物理有效位置:x = (x1, ..., xN) ∈ R^(N×n)。

空间中的原子物理系统具有重要的对称性:如果分子在空间中被移动或旋转,其属性不会改变。这意味着一个分子的生成模型应该无论方向和转换如何,都产生相同的概率:

这里,旋转Q ∈ R^n×n通过绕原点旋转或反射每个原子xi ∈ R^n来作用于x,而t ∈ R^n对每个原子应用相同的平移。正式地说,(Q, t)是欧几里得群E(n)的实现。上述等式(8)意味着分布pϕ(x)在欧几里得群E(n)下是不变的

Köhler等人(2020);Toth等人(2020)表明,如果潜在分布p(z)在群G下是不变的,并且一个生成模型gϕ(z)对G是等变的,那么得到分布也对G不变。等变性意味着将任何群作用应用于输入(例如旋转和平移),然后应用gϕ,应该得到与首先应用gϕ然后应用群相同的结果。例如,对于

欧几里得群:

这意味着我们可以通过使正规化流NFs对欧几里得群等变来构造一个对欧几里得群不变的分布,如等式(9)所示。以前的工作已经证明,这种归纳偏差比数据增强更有效,数据增强是在训练时对每个数据点应用随机旋转和平移(Köhler等人,2020;Hoogeboom等人,2022)。

因此,我们选择一个E(n)等变网络作为我们的FFF中的网络fθ(x)和gϕ(z)。我们采用了Satorras等人(2021b)提出的E(n)-GNN。我们将这个模型称为E(n)-自由形态流(E(n)-FFF)。我们在附录C.2中提供了实现细节。

E(n)-GNN也是以前分子正规化流的基础。然而,据我们所知,所有这样的架构的实现都基于神经ODEs,其中流被参数化为微分方程dx/dt = fθ(x(t), t)。在训练期间,可以通过使用修正流或流匹配目标(Liu等人,2023;Lipman等人,2023;Albergo和Vanden-Eijnden,2023)来避免求解ODE。然而,它们仍然有一个缺点,即它们需要在采样时积分ODE。相比之下,我们的模型只需要调用一次fϕ(z)来进行采样。

Boltzmann Generator

我们在学习Boltzmann分布时测试了我们的E(n)-FFF:

q(x) ∝ e^(-βu(x)),

打开网易新闻 查看精彩图片

其中u(x) ∈ R是将原子位置x = (x1, ..., xN)作为输入的能量函数。一个近似q(x)的生成模型pϕ(x)可以用作Boltzmann生成器(Noé等人,2019)。Boltzmann生成器的想法是,如果访问u(x),则即使pϕ(x)与q(x)不同,也可以在训练后重新加权生成器的样本。这对于评估q(x)中的样本以进行下游任务是必要的:重新加权样本允许从生成模型pϕ(x)的样本计算期望值Ex∼q(x)[O(x)] = Ex∼pϕ(x)[ q(x)/pϕ(x)O(x)],如果pϕ(x)和q(x)具有相同的支持。

我们在基准任务DW4、LJ13和LJ55上评估了自由形态流(FFF)作为Boltzmann生成器的性能(Köhler等人,2020;Klein等人,2023)。这里,成对势能v(xi, xj)被累加作为总能量u(x):

u(x) = ∑i,j v(xi, xj).

DW4使用双井势vDW,并考虑2D中的四个粒子。LJ13和LJ55都采用13个和55个粒子在3D空间中的Lennard-Jones势vLJ(见附录C.2.3了解更多细节)。我们利用Klein等人(2023)提供的的数据集,该数据集通过MCMC获得了p(x)的样本。

在表1中,我们将我们的模型与(i)基于最大似然训练的等变ODE正规化流E(n)-NF(Satorras等人,2021a),以及(ii)通过最优运输(等变)流匹配训练的两个等变ODE(Klein等人,2023)进行了比较。我们发现我们的模型在负对数似然方面与竞争对手相当或更好。此外,E(n)FFFs的采样速度明显快于竞争对手,因为模型只需要评估一次学习到的网络,而积分ODE则需要多次评估。

打开网易新闻 查看精彩图片

QM9 Molecules

作为第二个分子生成基准测试,我们在生成新分子方面测试了E(3)-FFF的性能。因此,我们在QM9数据集(Ruddigkeit等人,2012;Ramakrishnan等人,2014)上进行训练,该数据集包含不同原子数量的分子,最大分子包含29个原子。生成模型的目标不仅是预测每个分子中原子的位置x = (x1, ..., xN) ∈ R^3,还要预测每个原子的属性hi(原子类型(分类),原子电荷(序数))。

我们再次采用E(3)-GNN(Satorras等人,2021b)。网络作用于坐标xi ∈ R^3的部分对旋转、反射和平移(欧几里得群E(3))是等变的。网络在这些操作下保持原子属性h不变。

我们在图1中展示了我们模型的样本。由于自由形态流只需要一次网络评估来采样,它们在固定时间窗口内生成的稳定分子数量比E(3)-扩散模型(Hoogeboom等人,2022)多两个数量级,

打开网易新闻 查看精彩图片
打开网易新闻 查看精彩图片

6

在这项工作中,我们提出了一种新的范式——自由形态流(FFF),用于规范化流,它使得训练任意保持维度的神经网络以最大似然度成为可能。通过重构损失实现可逆性,并通过高效的替代方法最大化似然度。以前,设计规范化流受到分析可逆性的需要的限制。自由形态流允许实践者专注于数据和适当的归纳偏差。

我们展示了自由形态流是最大似然训练的一个确切放松,只要重构损失最小化,就会收敛到相同的解。我们提供了一种对FFF训练的解释,即将数据的噪声版本和生成分布之间的KL散度的下界最小化。此外,如果fθ和gφ是真正的逆函数,这个界限是紧密的。

在实践中,自由形态流的表现与以前的规范化流相当或更好,通过仅需要单个函数评估来展示快速采样,并且易于调整。我们在附录B中提供了一个实用的指南,用于将它们适应于新问题。

OVERVIEW

The appendix is structured into three parts:

• Appendix A: A restatement and proof of all theoretical claims in the main text, along with some additional

results.

– Appendix A.1: The gradient of the log-determiant can be written as a trace.

– Appendix A.2: A derivation of the loss as a lower bound on a KL divergence.

– Appendix A.3: A bound on the difference between the true gradient of the log-determinant and the

estimator used in this work.

– Appendix A.4: Properties of the critical points of the loss.

– Appendix A.5: Exploration of behavior of the loss in the low β regime, where the solution may not be

globally invertible.

• Appendix B: Practical tips on how to train free-form flows and adapt them to new problems.

– Appendix B.1: Tips on how to set up and initialize the model.

– Appendix B.2: Code for computing the loss function.

– Appendix B.3: Details on how to estimate likelihoods.

– Appendix B.4: Tips on how to tune β.

• Appendix C: Details necessary to reproduce all experimental results in the main text.

– Appendix C.1: Simulation-based inference.

– Appendix C.2: Molecule generation.