SLOT STRUCTURED WORLD MODELS
插槽结构世界模型
https://arxiv.org/pdf/2402.03326
摘要
感知单个物体及其相互作用并进行推理的能力,是构建智能人工系统需要实现的目标。最先进的方法使用前馈编码器提取物体嵌入,并使用潜在图神经网络对这些物体嵌入之间的相互作用进行建模。然而,前馈编码器无法提取以物体为中心的表示,也无法分离外观相似的多个物体。为解决这些问题,我们引入了槽结构世界模型(SSWM),这是一类将以物体为中心的编码器(基于槽注意力)与基于潜在图的动力学模型相结合的世界模型。我们在具有简单物理相互作用规则的 Spriteworld 基准测试中评估了我们的方法,其中槽结构世界模型在一系列具有动作条件物体相互作用的(多步)预测任务上始终优于基线。重现论文实验的所有代码可从https://github.com/JonathanCollu/Slot-Structured-World-Models获取。
1 引言
区分视觉场景的各个组成部分并对其相互作用进行推理的能力是人类认知的关键方面(Spelke & Kinzler,2007)。它使我们能够对环境建立扎实的理解,因此也被认为是人工智能系统的重要要求(Battaglia 等人,2018)。理想情况下,我们需要这样的模型:接收原始图像,灵活表示场景中的物体,并能预测动作对单个物体及其相互作用的影响(Ha & Schmidhuber,2018;Moerland 等人,2023)。确实有几篇论文解决了这一挑战(Van Steenkiste 等人,2018;Kipf 等人,2019;Watters 等人,2019b)。一个特别成功的想法是提取场景中的物体,并使用图神经网络(GNN)(Wu 等人,2020)对物体之间的成对相互作用进行建模,该方法作为 “对比学习结构化世界模型”(C - SWM)(Kipf 等人,2019)取得了最先进的结果。
然而,正如 C - SWM 的作者所指出的,当前基于 GNN 的方法仍然存在挑战(Kipf 等人,2019)。该方法使用前馈编码器将场景嵌入为潜在 GNN 模型的固定嵌入集,这种方法有几个局限性。首先,固定的前馈编码器无法区分外观(近似)相同的多个物体:它们会在相同的特征图中被检测到(此问题的图示见图 1)。此外,可发现物体的数量在架构中是固定的,因此在推理时无法变化。
为克服上述限制,我们建议改为将基于潜在 GNN 的过渡模型与以物体为中心的编码器(Engelcke 等人,2019;Greff 等人,2019;Burgess 等人,2019;Locatello 等人,2020;Biza 等人,2023)结合使用。此类编码器学习返回一组嵌入,每个嵌入表示场景中单个物体的信息。一种成功的方法是槽注意力(SA)(Locatello 等人,2020),它利用物体特定槽之间的(迭代)竞争注意力机制,迫使单个物体进入不同的槽。这些方法重复应用相同的编码器来提取每个物体(共享信息),由于竞争注意力,可以分离相似物体,并且可以在推理时调整初始化槽的数量。因此,像槽注意力这样的以物体为中心的编码器提供了我们期望在基于 GNN 的动力学模型中用于下游任务的确切特征。
因此,本文提出了一种新型动力学模型,该模型嵌入了以对象为中心的编码器和基于 GNN 的世界模型。具体而言,我们使用槽注意力(Slot Attention)为潜在 GNN 动力学模型(受 C-SWM 启发)生成嵌入,称之为槽结构世界模型(SSWM)。SSWM 的高层架构如图 2 所示。需要注意的是,将以对象为中心的编码器与基于 GNN 的动力学模型相结合的思路具有普适性,人们可以轻松地将槽注意力替换为任何其他以对象为中心的嵌入方法(只要该方法能生成绑定单个对象的特征向量集合)。
为了测试我们的方法,我们扩展了著名的 Spriteworld 基准(Watters 等人,2019a)—— 该基准最初为 COBRA(Watters 等人,2019b)中的以对象为中心学习而设计 —— 加入了物理交互规则。在原始 Spriteworld 中,当一个对象推入另一个对象时会开始重叠,而在我们的新交互式 Spriteworld 环境中,第二个对象会被推开。这确保了对象之间真正发生交互,同时也需要更丰富的表示(因为对象的精确形状开始变得重要,而在之前的设置中仅需位置和速度即可)。对一系列交互式 Spriteworld 任务的实验评估表明,SSWM 在 1 步、5 步和 10 步预测准确性上持续优于最先进的基线模型 C-SWM。定量评估显示,尽管在较长预测范围内细微偏差会逐渐累积,但 SSWM 确实能做出准确预测。
总之,本文提出的 SSWM 是首个具备以下特性的学习动力学模型:
能够从原始像素输入中分离单个对象,并对其(基于动作条件的)交互进行推理
能够区分外观相似的多个对象
在(多步)预测任务中,数值表现优于最先进的以对象为中心的动力学模型 C-SWM
2 背景
Kipf 等人在 2019 年提出的 C-SWM 架构由三个不同的子模块组成:对象提取器、对象编码器和关系转换模型。
3 相关工作 以对象为中心的表示学习
该领域的若干研究提出了在对象发现等任务中成功的无监督方法。这些方法的一些示例包括 Engelcke 等人(2019)、Greff 等人(2019)、Burgess 等人(2019)、Locatello 等人(2020),它们学习将原始图像表示为一组与单个对象绑定的潜在向量。与上述其他使用多个编码 - 解码步骤的工作不同,槽位注意力模型(Slot Attention,Locatello 等人,2020)借助一个简单而有效的迭代注意力模块,仅需执行一次编码步骤。这一特性使得槽位注意力模型在计算效率上远高于其前身。由于这一原因及其他因素(如数据和训练效率),我们在本工作中选择了该模型。出于同样的原因,这一方法已被更多近期工作进一步扩展。特别是,如 Chang 等人(2022)、Jia 等人(2022)提出了扩展方法以改进槽位优化过程,而最新的不变槽位注意力模型(Invariant Slot Attention,Biza 等人,2023)则引入了一种方法,使槽位注意力的表示对位置、比例和旋转等特征具有不变性。
基于对象的环境模型
鉴于许多环境具有结构化性质,其中多个实体(智能体和对象)相互作用,学习此类环境的鲁棒且准确的模型需要特定的架构偏差。基于这一前提,许多先前的方法(如 Sukhbaatar 等人,2016;Chang 等人,2016;Battaglia 等人,2016;Watters 等人,2017;Hoshen,2017;Wang 等人,2018;Van Steenkiste 等人,2018;Kipf 等人,2018;Sanchez-Gonzalez 等人,2018;Xu 等人,2019;Kipf 等人,2019)采用了基于图神经网络的解决方案来构建结构化世界模型。在这些方法中,C-SWM 因其能够同时学习状态表示和转移模型且不依赖任何基于像素的损失而脱颖而出。然而,该方法的编码器无法分离具有相同外观的对象的信息。该领域的另一相关方法是 COBRA(Watters 等人,2019b),其与我们的方法类似,采用预训练的以对象为中心的模型(MONet,Burgess 等人,2019)来获得结构化表示。与 C-SWM 和本工作不同,COBRA 不建模对象之间的关系,因为其转移模型未实现为图神经网络。
4 环境
Spriteworld 环境(Watters 等人,2019a)由 COBRA(Watters 等人,2019b)引入,是一个视觉基准任务,其中不同形状和颜色的对象可以四处移动。然而,当其中两个对象被推向彼此时,它们会开始重叠(而非相互推挤)。这种物理对象交互的缺失使得 Spriteworld 的动态仅能通过对象位置进行编码(Watters 等人,2019b)。为了在需要更丰富对象表示的更复杂环境中挑战我们的模型,我们因此用简单物理规则扩展了 Spriteworld,形成交互式 Spriteworld(可在https://github.com/JonathanCollu/Interactive-Spriteworld获取),其中具身智能体可以移动对象(图 3)。
该环境包含一个黑色背景的正方形竞技场和五个对象。在每个 episode 开始时,每个对象的形状从三种形状(圆形、正方形和三角形)的离散均匀分布中采样,位置从覆盖整个空间的连续分布中采样,(纯色)颜色从沿三个通道(HSV)的连续分布中采样。智能体由一个较小尺寸的白色精灵表示。该智能体可以采取四种可能的动作之一(向下、向上、向左或向右移动),并通过与其他对象碰撞来移动它们。
图 3 的第一行展示了交互式 Spriteworld 的一些示例观测。我们可以清楚地观察到具有相似或相同外观的对象场景,我们预计 C-SWM 编码器会在此类场景中遇到困难(见图 1)。第二行说明了实现的简单物理规则:如果移动的精灵 / 对象 A 撞击静止对象 B,B 会获得 A 的运动,依此类推直到链条停止。为了对这些动态进行建模,学习到的转移模型需要包含每个对象的位置、形状和大小信息的表示,以确定两个对象接触的确切时刻。
5方法论
SA 设计
由于转换模型的损失是通过成对比较预测的槽向量与下一个状态编码得到的槽向量来计算的,我们更倾向于保持它们的顺序不变。这可以避免基于相似性度量的计算成本高昂的排序。因此,我们学习固定的槽初始化,而不是从(学习到的)正态分布中采样它们。然而,需要注意的是,这种选择在我们的实验中固定了推理时的槽数量。
6 实验 6.1 指标
我们采用与 Kipf 等人(2019)相同的指标,即:k 级命中数(H@k)和平均倒数排名(MRR)。这些指标允许在潜在空间中直接评估,而无需训练单独的解码器。
- k 级命中数
:我们首先根据每个预测对象向量与整个数据集中真实编码对象向量的距离对其进行排序(类似于 Kipf 等人(2019)的方法)。当推断向量的排名不超过 k 时,每个对象向量的 H@k 得分为 1,否则为 0。我们报告整个数据集中 “1 级命中” 的百分比。
- 平均倒数排名
:MRR 是上述排名的汇总,定义为评估数据集中所有 n 个样本的倒数排名的平均值:MRR = 1/N × ∑(n=1 到 N) 1/rankn,其中 rankn 是第 n 个样本的排名。
请注意,潜在空间评估指标确实存在一些边缘情况,我们将在附录 A.2 中进一步报告。
所有结果均取自四次独立重复实验的平均值,超参数见附录表 4。我们使用带有默认前馈编码器的 C-SWM 作为基线,并增加了算法 1 中描述的迭代 GNN 模块。如原始论文所述,我们将每个对象的特征图数量从 1 个增加到 4 个,以建模位置信息和其他相关信息。训练数据集包含 6×10⁴个样本,其中三分之一通过环境中随机策略收集,其余部分通过人类演示收集,以获取随机动作选择无法产生的多样化复杂交互序列。测试集分为三个按难度升序排列的不相交子集,每个子集包含 10³ 个未见样本,分为 10 个 episode,每个 episode 有 100 个时间步。第一组包含 10 条 “无碰撞” 轨迹,其中智能体是唯一移动的精灵;第二组包含多个步骤序列(每个 episode),其中智能体一次携带一个精灵;最后一组包含复杂轨迹,其中智能体可以一次或链式携带多个精灵。这种划分旨在突出两类模型的优势和局限性,并帮助在不同设置下识别和解释它们的行为。
6.3 定量分析
表 1 显示了 SSWM 和 C-SWM 在交互式 Spriteworld 中三种预测任务的定量性能指标,分为 1 步、5 步和 10 步预测范围。首先,我们观察到在所有研究的设置中,SSWM 的性能均优于 C-SWM。第一列显示 1 步预测准确率,在交互式 Spriteworld 的未见实例上达到接近 100% 的分数。有趣的是,在动作仅影响智能体的场景(测试 1)中的准确率几乎与智能体移动一个对象(测试 2)或多个对象(测试 3)的场景中的准确率相当。这表明 SSWM 能够准确地对对象交互进行泛化。
当我们将预测范围增加到 5 步(第二列)和 10 步(第三列)时,我们看到 SSWM 和 C-SWM 的预测准确率都开始下降。这种现象可能是由于误差累积导致的,这是展开多步预测模型时的一个众所周知的现象(Talvitie,2014)。我们还看到 SSWM 开始更明显地优于 C-SWM,尤其是在最长的预测范围内。这表明 SSWM 在对象交互方面具有更好的泛化性能。有趣的是,C-SWM 确实设法在训练数据上最小化了损失函数(底行),但这样做时没有编码对象形状等不可或缺的信息(我们将在下一节讨论这个主题)。这并不奇怪,因为 C-SWM 的对比目标并没有明确地引导编码器朝这个方向发展,并且网络可以找到许多唯一识别对象的解决方案。相比之下,SSWM 似乎更好地编码了所有对象属性,因此能够实现更准确的长距离预测。
6.4 定性分析
我们还希望对模型的预测进行定性评估,理想情况下在原始像素空间中(以提高可解释性)。因此,我们使用槽位注意力预训练阶段获得的解码器来展示 SSWM 所做的预测。对于这些可视化,我们使用四次重复实验中获得的最佳模型。请注意,动态模型从未明确针对这些像素空间重建进行训练,而仅在潜在空间中接受损失。当然,对于 C-SWM,我们没有此解码器,而是展示对象提取器生成的掩码以评估嵌入的质量。
SSWM 预测
图 4 展示了 SSWM 在三种测试场景中的预测示例:仅智能体移动(左上 3×4 块)、智能体移动单个对象(右上 3×4 块)或智能体移动多个对象(底部 3×4 块)。在每个块中,三行分别展示了 1 步(顶行)、5 步(中间行)和 10 步(底行)预测的示例。每个设置旁边的四张图像依次显示起始状态、预测的下一状态、真实的下一状态以及两者之间的误差。
当仅智能体移动时(图的左上块),我们看到模型预测是准确的,预测的下一状态与观察到的下一状态之间差异很小。与前一段的定量结果一致,随着步数的增加(由于误差累积),预测的准确性确实会降低。右上块显示了智能体在移动中携带一个精灵的情况。同样,我们看到模型如何准确捕捉智能体及其接触对象的预期移动,但偏差会随着时间累积。对于 10 步预测,模型略微误判了智能体与对象(绿色三角形)之间的接触点,因此在预测中该对象最终位置过高。
图的底部块展示了一个场景,其中智能体移动多个对象,推动蓝色三角形,而蓝色三角形本身又推动紫色三角形。我们可以清楚地看到,SSWM 可以对这种多对象交互进行建模,因为紫色三角形确实被预测为以与所选动作一致的方式移动。我们再次看到一些轻微的预测误差,这与之前的观察一致,并且会随着时间累积。与定量结果一致,当多个对象交互时(主要是因为更多对象可能被误预测),误差累积得更快。
C-SWM 表示
我们假设基线(C-SWM)表现出的泛化能力不足是由于所学嵌入的质量较差。为了验证这一假设,图 7 显示了 C-SWM 编码器输出的每个槽位获得的特征图。请注意,每个槽位由四个特征图组成,我们将它们上下绘制。显然,交互式 Spriteworld 任务需要编码来表示对象的确切形状(以确定接触点)。然而,从可视化中我们可以清楚地看到,C-SWM 的表示既不表示对象的确切形状,甚至也不单独表示它们。这表明 C-SWM 的 CNN 对象提取器没有学习到隔离对象并包含此任务所需特征的过滤器。这一观察结果与表 1 中报告的定量结果一起,证实了 C-SWM 学习了一种确实优化了公式 3 中的目标但未编码所有相关信息的解决方案这一假设。相比之下,SSWM 的槽位注意力编码器所学的特征确实隔离了对象并捕捉了它们的确切形状,如图 1 的底行所示。这很可能解释了表 1 中 SSWM 更好的定量性能。
7 结论
本文介绍了槽位结构化世界模型(Slot Structured World Models,SSWM),这是一个简单且灵活的框架,它将以对象为中心的编码器与基于图神经网络(GNN)的潜在转移模型相结合。通过学习更具信息量的表示,完整的 SSWM 架构优于最先进的以对象为中心的动态模型 C-SWM,同时还能够在场景中区分具有相似外观的多个对象。定性分析表明,C-SWM 过度拟合训练数据,未学习有意义的潜在表示,而 SSWM 通过槽位注意力学习了更好的表示,显著提升了预测性能。
未来工作有几个方向。首先,我们目前学习固定数量的槽位初始化,因为这使我们能够直接构建成对的潜在损失。然而,像槽位注意力这样的以对象为中心的编码器通过训练槽位初始化分布,自然允许潜在槽位数量(可发现对象)的变化。当然,我们随后需要基于相似性度量构建成对的潜在对象损失,这可能会使优化更加不稳定。另一个解决方案是端到端地训练整个架构。这也可能改善定性的像素空间预测:我们目前完全依赖于 SA 自动编码学习的表示,但帧间变化的动态损失可能会更强调智能体和对象的行为。第三个方向是提高多步预测性能,例如,我们可以在多步目标上进行训练(Abbeel & Ng, 2004)。最后,在下游决策任务(即基于模型的强化学习)中测试这些模型也将是有趣的。
附录 A.
每次迭代的最终更新通过门控循环单元(GRU)处理,随后连接带残差连接的多层感知机(MLP)获得。
为了执行对象发现,槽位会被解码为图像和掩码,然后组合并求和以生成单张图像。因此,训练目标是最小化输入图像与重建图像之间的均方误差。
A.2 额外实验
本节展示了一些额外实验,这些实验最初旨在解决多步预测问题。因此,我们首先将 SSWM 在长时预测中的困难归因于槽位注意力(Slot Attention)的嵌入。我们认为,若表示中所有特征相互纠缠(如槽位注意力的特征),可能会限制潜在转移模型的准确性(和泛化能力)。尽管解纠缠(disentanglement)显然是理想特性,但后续实验表明,多步预测问题的根源并非在此。此外,这些实验还揭示了我们实验中涉及的指标可能导致误判的一些极限情况。
A.2.1 SSWM
本小节的结果通过将槽位注意力替换为 DISA(?)获得,该模型可确保位置、比例、形状和纹理等特征的解纠缠。
表 2 中的结果可能给人一种错觉,即 DISA 提供的解纠缠表示足以解决多步预测问题。但遗憾的是,图 A.2.1 的可视化示例显示,其与定量指标存在显著差异。需要说明的是,使用该编码器时我们也观察到一些定性表现良好的例子,但我们更倾向于展示槽位注意力的实验,因为其定量结果与像素空间的观测更一致。
如第 6.3 节简要提到的,用于评估模型在潜在空间中性能的指标可能远不能代表转移模型的真实质量。这可以解释为:转移模型的训练目标是最小化其预测与潜在空间中真实下一状态的距离,若模型在此空间中泛化良好,评估时很可能获得最优的 H@1 和 MRR 分数。然而,即便预测与目标的距离在相关指标中可忽略不计(在紧凑的潜在空间中,不同状态间的距离数量级较小),微小扰动也可能导致像素空间中的严重误预测,如图 A.2.2 所示。
A.2.2 C-SWM
我们使用表 4 中的相同配置对基线模型进行了复现,每个对象仅设置 2 个特征图,结果呈现出类似现象。
原文链接: https://arxiv.org/pdf/2402.03326
热门跟贴