一起來感受下贝叶斯流网络玩转离散数据的原理(四)

技术

极市导读

本文主要解析了BFN对离散数据建模的原理。首先介绍了离散数据的表示方式,然后讨论了输入分布、发送者分布、输出分布和接收者分布的建模方法。接着详细解释了对离散数据加噪的方法,并介绍了疯狂采样的概念。最后,讨论了贝叶斯更新函数的意义和计算方法。

加入极市CV技术交流群,走在计算机视觉的最前沿

BTNs是怎么玩转生成即压缩的?详解结合贝叶斯统计和深度学习的生成模型 — Bayesian Flow Networks(一)

结合贝叶斯推断的去噪生成模型?详解BFN在连续型数据场景下的实现— Bayesian Flow Networks(二)

详解贝叶斯流网络在离散化数据场景下的实现—Bayesian Flow Networks(三)

前言

本文主要解析 BFN 对离散数据建模的原理。章节的设置与本系列前几篇文章类似,首先说明下离散数据的表示方式 、然后对 BFN 数学框架 进行详细的解析并给出对应的伪代码 ,接着对实验结果 进行分析,最后是发散性话题 :谈谈 CW 对于 BFN 在一些方面的理解。

这篇文章是本系列最难啃的一篇,涉及的数学推理比较多,其中许多部分,作者在 paper 中仅给出了简略的过程,秉持着不无聊的风格,CW 在 paper 内容的基础上给出了更为详细的推导与分析。另外,对于作者直接引用而没有推导过程的一些公式,CW 也力所能及地给出了对应的证明过程(算是很贴心了叭~),为了不打扰到各位客官观赏正文的兴致,这些证明过程被置于附录部分,感兴趣的不妨去瞄瞄~

ps: 本系列的几篇文章多少存在一定程度上的联系,特别是离不开首篇文章所描述的数据压缩游戏的背景与基本数学框架,建议各位观众先对首篇文章有一定程度的理解后再来看后面几篇。

离散数据的表示方式

离散数据可谓是大家的老朋友了,无论是在 paper 中还是实际生活中,它都很常見,比如:文本字符、物体类别 等。在这篇 paper 中,作者将离散数据表示为类别 ,它是 维向量每一维的值代表某个类别(从1开始)

假设将数据划分为 类, 且 表示从 1 到 的整数, 那么数据就表示为:

输入分布:数据分布的“先天”信念

输入分布可谓是一种“偏见”,因为毫无道理地就认为数据会属于某种特定的分布,所以 CW 才说它是数据分布的先天信念。它虽然一开始毫无道理,但后来是会通过观测样本而得到关于真实数据的经验性信息来进行自我纠正的(所以,别方哈)。

既然将数据视作类别,那么自然地就会想到分类分布(Categorical Distribution),而数据在每一维都代表一个类别,对应于一个分类分布,于是整个输入分布就建模为多个分类分布的联合分布

记输入分布的参数 , 其中每一维 是一个 维向量, 维数与类别数相等, 其中每个维度表示属于某个类别的概率。 表示 维的 概率单纯形(probability simplex), 之所以是 维, 是因为 个类别的概率之和为 1 , 知道了其中的 个概率就能顺势推出剩下的 1 个。

参数 的取值范围是 , 表示概率, 是 维向量, 其中每一维的参数 都表示类别数为 的分类分布, 且相互独立。 于是, 输入分布的概率质量函数则表示为 在数据真实类别上的联合概率:

先验设置

如今我们处在一个讲究人性化的社会,先验当然就选择设为“公平分配”的均匀分布 啦:

其中 是 维向量, 每一维上的值都等于 , 每个类别都拥有相等的概率, 十分公平。

其实,作者之所以选择均匀分布这种简单的先验,是因为这在数学上会更简单,并且他认为网络能够很容易学到数据的真实先验(毕竟是他自己出版的故事,得有信念)。

对网络输入进行缩放

另外,由于输入分布的参数是概率值,如果直接输入给 BFN,那么它永远只能接收到非负值,因此在将其喂给 BFN 前,会先乘以2再减去1,从而缩放至

输出分布:新数据的缔造者

输出分布是生成新数据的直接依据,可谓是新数据的缔造者,下面来看看对于离散数据它是如何建模的。

将输入分布的参数 和时间变量 喂给 BFN, 它会输出:, 其中每个 都是 维向量, 与 个类别相对应。

注意,每一维对应的输出都由输入分布参数在所有维度 (而非仅仅是单个对应维度)上的变量( 而非)经 BFN “处理”后所获得。因此,输出分布会拥有关联上下文信息的能力 ,这点 CW 其实在本系列的首篇文章中已经讲过。

由于作者将数据表示为类别, 因此这里就要将网络的输出转化为类别所对应的概率, 也就是要将它搞到 范围内。条件反射, 很容易使我们想到使用 softmax, 于是:

picture.image

也就是说, 将网络在每一维上输出的 维向量都经过 softmax, 使其对应于类别数为 的分类分布。

面对已知数据 , 就将经过 softmax 处理后的输出在每一维上取出数据在对应维度上的类别所对应的概率, 最后连乘起来所得到的联合概率分布, 就是输出分布。

特别地,对于只有两个类别的数据,那么网络在每个维度上就只输出一个值(而非2维向量),并且以 sigmoid 替代 softmax 。于是, 在由 。于是,在由 计算出其中一个类别的概率后,再根据 1 减去这个概率就能得到剩下那个类别所对应的概率,即:

picture.image

另外,作者还进一步谈到:虽然在面对类别数 的情况,也可以由其中的 类概率推断出剩下的那一类所对应的概率, 但这种做法潜在地要求网络做更复杂的推理(有点像使用排除法), 可能会导致网络学习(收玫)得更慢和更困难, 所以作者还是采取了 分类的做法。

对于这点, CW 想了下, 若网络在 维中的每个维度上仅输出 维向量, 那么一种实现 分类的方式可能是需要将这 个输出加起来先经过 sigmoid 做二分类(看是否属于剩下的那一类), 然后再将这 个值喂给 softmax 做多分类, 也是比较麻烦~

发送者分布:数据的信使

发送者分布是用于构造观测样本的,而观测样本能够将真实数据的信息传递出去(从而才能是先验进行后验更新),因此 CW 将发送者分布视作数据的信使。

如何对离散数据加噪

经过本系列前几篇文章的“洗礼”,大家已经知道发送者分布是要在原始数据上加噪,而这里的数据是离散型的类别,那么如何对这种表示为类别的数据进行加噪呢?

由上可知,输入、输出分布都建模为分类分布,借此灵感,在这里我们可以将原始数据分布脑补为一个非常 "hard" 的分布——这个分布在原始数据所对应的类别上,概率为1;而在其余类别上,概率则为0。至于加噪的手段,就是将此概率分布变 "soft"——从真实类别上“挖”走一部分概率值平摊到其余类别上

picture.image

其中 可理解为一个平滑因子, 其越小则分布变得越平滑, 以至于为 0 时则变为均匀分布。 表示Kronecker delta function(克罗内克德尔塔函数), 当 时其值为1, 否则为 0 ; 也就是只有在数据所对应的那个类别上为 1 , 其余情况均为 0 。 则表示 个类别中的某一个。

容易看出, 式对于任意类别都是大于 0 的, 且 , 于是此处又可定义一个分类分布(以 为条件):

至此,我们得到了一个“含有噪声”(平滑)的分类分布。那么,噪声样本(观测样本)该如何构造呢?

如果直接从以上分布中采样,那么很可能会抽到其它不是原数据所对应的类别,特别是当 的时候(趋向于均匀分布)。并且,类别是以整数来表示的,这些整数并不能反映类别之间的关系。于是,当采样到非数据所对应的类别时,观测样本就不能反映出原始数据的信息,这就不符合游戏规则了;另外,就算采样到了数据的真实类别,但这时观测样本却变得与原始数据一模一样,这就不是噪声样本了。

所以,接下来要想办法构造出一种观测样本,使其既能与原始数据有关联(最好能通过某些参数控制相关程度),同时又不能与后者一模一样。

疯狂采样

既然从以上平滑的分类分布中采样一次不可行,那么不妨试试采样多次会出现什么情况。

作者引入了这样一个场景: 对以上平滑的分类分布进行 次独立的采样 , 那么每个类别都会有被采样到的机率, 记 为第 类总共被采样到的次数, 于是采样结果就表示为一个 维向量 (paper 中是 觉得不一定每个类别都会被采样到, 你们觉得呢)。

为了便于书写(偷懒),接下来暂时省略维度上标:(d) (我跟作者学的),也就是接下来的推导都是限定在单个维度上进行讨论的。

相信统计学玩得6的朋友们一眼就看出,随机变量 是服从 多项分布(Multinomial Distribution) 的:

picture.image

以上第2步是代入了多项分布的概率质量函数(具体推导过程见附录1), paper 中少了 次幂(见 paper 的式 (134), 作者你能不能走点心..), 最后1步是将式 代入。在 paper 中, 的下标为 , 但其实应该是 (CW 不禁怀疑作者是不是喝多了)。

根据 大数定律(Law of Large Numbers), 有:

也就是说, 当试验次数足够大时, 频率会趋近于真实概率。于是, 当试验次数足够大时, 根据采样结果 就能推断出原始数据 , 因为原始数据类别所对应的次数一定是最多的,这源于原数据类别所对应的概率最大(根据 式, 可知 时最大)。同时, 当 时, 由于各类别被抽到的概率都差不多, 因此也需要更多的试验次数来甄别出数据真实的类别, 甚至于 , 可谓是“疯狂采样”(踩踩踩)

于是, 作者将 和 结合, 定义出一个有限值精度: , 当 占优势时, , 此时采样结果会变得几乎无法反映原始数据信息(因为各类别的概率相差无几, 所以在试验后被采样到的次数都差不多); 反之, 当 占优势时, , 试验次数已经大到足以甄别出数据的真实类别与其它类别, 尽管它们之间的差异很小(因为有大数定律的保障), 此时采样结果比较能反映出原始数据的信息。

至于为何 是平方, 作者没有解释, 看完 paper 后, 感觉这只是为了后续推导出一个好看且易用的形式(看完本章你也能 feel 到)。

进一步, 设离散概率 , where 0<p_k<1,=""k=""{1,=""k}0<p\_k<1, \forall="" k="" \in\{1,="" k\}="" 。若="" c=""=""multi(m,=""p)c="" \sim="" \operatorname{multi}(m,="" p),="" 当="" m=""=""m="" \rightarrow="" \infty="" 时,="" 根据="" 拉普拉斯中心极限定理,="" 有:<="" p="">

其中 是 的单位阵。注意, 这里对应 paper 中的式 (137), 但在 paper 中, 分母漏了 , 不过影响不大, 只是协方差相差一个系数而已。 为了迎合作者的故事, 接下来就沿用 paper 的式 (137), 即:

picture.image

观测样本的诞生与采样

话说, 以上疯狂采样的场景到底有什么用呢? 虽然采样结果能反映出原始数据信息(能从采样结果推断出原始数据类别), 但通过上述分析, 要做到稳定的话, 采样次数 几乎要趋于 , 所以直接以采样结果作为观测样本是不切实际的。然而, 采样结果又确确实实能够反映出原始数据的信息, 因此还是得将以上场景利用起来, 接下来 CW 就带大家一起来看看作者是怎么做的。

首先定义一个变量 , 然后将类别 对应的观测样本定义为:

现在来分析下以上式子的实际含义: 相当于将 次采样次数平摊到各类别上, 也就是平均每个类别会被采样到的次数, 于是 就衡量了当前类别被采样到的次数与平均值相比如何。 同时, 由于 是非负的, 因此 就与 呈正相关。并且, 当 增加(即 占优势)时, 采样结果会更加集中于数据真实的类别, 也就是 会(远)大于 。由此可知, 当 是正值且变得越大时, 它与原始数据的关系就越紧密、越能反映出原始数据的信息。

嗯,这么看来,这个观测样本还是能通过“及格线”的 —— 能根据它的值反映出其与原始数据的相关程度。但是,关键问题还没有解决:它该如何计算?要是真的从以上定义式去计算的话,那么又变得不切实际了,因为还是得进行几乎无数次的疯狂采样。

这是,你们是否不禁 yy:要是能够知道观测样本 所服从的分布,并且这个分布的概率密度函数是有解析形式的,使得我们能够直接从中采样,那么就万事大吉了!

做人就是要敢想才有前(钱)途, 并且, 冷静下来, 你会发现这个想法并非天(不)马(切)行(实)空 (际), 因为前面我们已经推导出 是服从正态分布的, 而 与 存在解析形式的数学关系, 而且是线性关系,所以前者应该也是服从正态分布的(根据正态分布的性质)。 不过,光吹水没用, talk is cheap, CW show you the math!

根据 的定义式, 可得:

接着,利用 随机变量的变量变换定理(Change of Variables):

picture.image

(关于一维随机变量的分布变换的具体推导过程请见附2)

接下来, 我们要利用一个对数恒等式(贼香 ) : (具体推导过程详见附3), 并将其套在前面的 上, 于是:

picture.image

以上第一步利用了对数恒等式并且仅展开到第一阶,并且 ,是满足约束的。

同时,当 时,有:

picture.image

将式 (iii), (iv) 代入式 (ii), 得:

picture.image

OMG!! 我们真的推导出观测样本所服从的分布了(yy 成真)!但是请注意,这是有约束条件的,即:。

如今, 我们将前面丢掉的维度上标 拿回来。同时, 由于以上的 仅是某个类别所对应的观测样本, 而总共有 个类别, 因此每个数据维度上都有 种采样结果(各维度独立进行疯狂采样)。对应地, 每个数据维度上的观测样本都是 维向量: , 于是:

picture.image

其中, 1 是值全为 1 的 维向量, 是 的单位阵, 而 是原始数据的某个维度映射到 个类别的 one-hot 向量, 其中的第 个分量为:

picture.image

最终,发送者分布就定义为:

picture.image

其中, , 其中 维中的每个分量都是 维的 one-hot 向量; 1 是值全为1的 维向量; 是 的单位阵。

多说几句: 为何要定义 这个变量以及 的形式为何是那样的, 对此作者并没有解释, 从前面的推导过程来看, 这两者的定义貌似仅仅是为了能够利用起对数恒等式, 从而最终顺利推导出一个解析形式。

接收者分布:对于消息的“暴力猜测”

消息接收者 Bob(别忘了这篇 paper 是以数据压缩游戏作为背景来论述的) 由于不知道原始数据长什么样子,因此只能暴力猜测——将所有类别都以输出分布作为概率加权起来,这就是接收者分布。有点像:我虽然不知道准确值,但我考虑了所有可能性,于是“命中率”理应高一些。

在 维数据的每个维度上,接收者分布是发送者分布在输出分布上的期望,而所有单个维度的联合分布就是接受者分布,即:

picture.image

贝叶斯更新函数:贝叶斯定理的“表演者”

贝叶斯更新函数就是根据贝叶斯定理计算后验概率,下面就一起来观看它的“演出”~

先验、似然、后验 傻傻分不清楚?

要使用贝叶斯定理计算后验概率, 关键是要先理清 先验、似然 以及 后验, 也就是搞清楚它们三个到底是谁。为了便于分析, 我们先在单个数据维度上进行讨论, 并且暂时忽略维度上标 。

假设我们现在要得到第 步的参数(输入分布的参数) , 它由第 步的参数 经过后验更新而来。对于数据 , 其对应的先验是 , 这是输入分布的定义; 至于观测结果, 则是前面引入的疯狂采样场景中的采样结果 。于是, 似然则是 ; 从而, 后验则为:

picture.image

非常丝(顺)滑(利)!经过一番推导,现在可以用观测样本来计算后验了。注意,以上的 代表自然指数。

贝叶斯更新函数的意义

以上是 (它是 维向量, 注意前面我们忽略了数据维度 ) 的第 个分量所对应的后验更新。于是, 对于第 个分量, 贝叶斯更新函数就定义为:

picture.image

现在,我们考虑回数据维度,即: 。于是,贝叶斯更新函数就是:

picture.image

注意, 以上 。也就是说, 这个贝叶斯更新函数相当于让 在每个数据维度 里, 对每个类别 都做归一化(在类别这个维度里) , 归一化因子是对应数据维度里, 所有 之和(是不是有点 softmax 的味道 ) 。

有了这个贝叶斯更新函数, 进行后验更新可就方便多了。由于我们是在 (从而精度 为有限值)的约束下玩的, 因此原本需要在多项分布 上进行几乎无数次采样得到结果 , 进而才能利用这个观测结果进行后验更新; 然而,现在却可以直接从发送者分布中采样出观测样本, 然后按照 式来计算就可以实现后验更新了 ,真香~

贝叶斯更新分布:考虑得更为周到的贝叶斯更新函数

贝叶斯更新分布考虑到了观测样本的所有可能性 ,它是贝叶斯更新函数在发送者分布上的期望,即:

picture.image

精度可加性:助力后验更新实现“跳步”更新

精度可加性是后验更新所拥有的性质。简单来说, 就是: 若 由 经过后验更新而来,而 又由 经过后验更新而来, 那么 可直接通过 和 经过后验更新而得到。

现在,我们从贝叶斯更新函数入手来证明精度可加性。

假设 ;, 其中 分别表示贝叶斯更新函数和观测样本。于是:

picture.image

对比式 (v), 我们知道以上就是贝叶斯更新函数的形式, 也就是说, 若想要直接由 经过贝叶斯更新函数计算出 , 那么所用到的观测样本就是 。但是, 其中的精度参数还不可知 (所以 暂时用 ? 来表示)。根据贝叶斯更新函数的定义, 可以知道其中的精度参数就是观测样本(在发送者分布中)的精度, 于是, 这里要求的 就是观测样本 所对应的精度。

由于:

picture.image

因此 相当于是从精度为 的发送者分布中采样出来的观测样本。如今,谜底揭开:

picture.image

以上是仅考虑到单个观测样本的贝叶斯更新函数,其所对应的贝叶斯更新分布(考虑了观测样本的所有可能性)则是:

picture.image

而这其实考虑到了 的所有可能性,也就是对其边缘化,所以,以上结果实质上是个边缘分布,即:

picture.image

accuracy schedule 的设置:炼丹的真实写照

在本系列前面的文章中, 讲过 accuracy schedule 即 的表达式是遵循“输入分布的期望熵随时间线性减小"这一原则而推导出来的。然而, 在本文的离散数据场景下, 期望熵 并没有解析形式, 于是作者就靠直觉脑洞出一个他心里认为是合理的形式:

而 就是最终时刻 的 accuracy schedule, 作者说这个值需要根据不同的场景凭经验去设置(Em.. 是炼丹的真实写照)。进一步, 我们就可以得到精度的表达式为:

觉得, 作者之所以将 设置为时间变量 的二次函数, 是因为这样求导后所得到的精度 就成为时间变量的一次函数,从而精度会随时间线性递增,于是在一定程度上可认为数据信息也是随时间线性地流入输入分布中(通过 的后验更新),进而对应的信息摘则线性减少。

贝叶斯流分布:通过 softmax 实现后验更新

贝叶斯流分布是直接由最初的先验 经过贝叶斯更新分布就可往后实现任意步的后验更新,即:

picture.image

通过前面章节的内容,我们知道这是精度可加性为其提供了可能性 ,于是从最初时刻到当前时刻的所有精度累加起来就得到了 。将式 (vi) 代入,得到:

picture.image

WOW!!! 这也太惊喜了叭 以上这个结果告诉我们, 在每一步中, 只要从发送者分布 中采样出观测样本 , 然后将其输入 softmax 函数, 就可以实现 的后验更新了!

以上结果同时也向我们反映出后验更新的“靠谱性”:发送者分布充当了 softmax 的 logits 源头,于是当精度提高时, 由于观测样本会更“接近”于原始数据, 因此 logits 的分布会更集中于原始数据, 从而 会变得趋向于 (原始数据类别映射为 one-hot 形式的向量), 进而 BFN 通过接收 也会获得越来越多关于原始数据的信息, 最终它预测的分布(即输出分布)也越来越靠谱。(简直不要太香 )

贝叶斯流的收敛轨迹

picture.image

以上是在二分类数据上测试的结果,显示了输入分布参数根据贝叶斯流分布进行后验更新的轨迹,起始先验设为0.5,也就是对两个类别五五分;最终收敛至真实的类别,从而概率值变为1。

可以看到,轨迹很长且波动很频繁,也就是 paper 中说的很 noisy,说明收敛得慢。这是因为数据是离散的等距点,而输入分布参数是连续的概率值,不容易集中至某个具体的值上(比如形成如 "0.908... vs 0.092..." 会比形成如 "1 vs 0" 这样的概率分布来得容易),因此无法轻易地快速“靠近”数据对应的类别,这算是离散值与连续的概率流之间的 gap

重构损失

老套路, 重构损失就是利用输出分布充当似然函数, 然后在最终时刻 计算负对数似然。由于考虑到 各种取值的可能性, 因此还要在贝叶斯流分布上计算期望:

与 BFN 在建模连续数据和离散化数据时一样,这个重构损失也是不参与训练的

离散时间的损失函数

在本系列的首篇文章中,CW 已经向大家推导出(第 iii 步的)离散时间的损失函数是:

picture.image

于是,代入我们在前面章节中推导出来的发送者分布和接收者分布的概率密度公式,就有:

picture.image

其中, 多了一项期望 是把 KL 散度中发送者分布的期望给并进去了。另外, 是 BFN (对应在单个数据维度)的输出。同时, 因为在离散时间下 。所以, 代入 的公式, 可得:

连续时间的损失函数

在连续时间的情况下,发送者分布和接收者分布之间的 KL 散度拥有以下通用形式:

picture.image

(关于的含义请回顾本系列首篇文章,这里不再重复阐述)

并且,最终推导出( ttt 时刻下的)连续时间的损失函数是:

picture.image

但我们前面推导出发送者分布的形式是 , 若要变身为 这种样子, 那么最方便快捷的方式, 莫过于将 乘以 再加 (其实不加1也可, 只不过加1后所得均值的形式会更简便), 这样之后所服从的分布就是 , 从而:

记 , 现在来把目光聚焦在接收者分布上:

picture.image

上面之所以能够把均值中 的单位向量给“消失掉”, 是因为接收者分布是混合高斯分布, 展开来其均值中的单位向量实际是: , 所以 中的 相当于摊分给了各个高斯分布。另外, 关于 的证明详见本系列的首篇文章, 以上最后一步是将狄拉克函数中的常数提取出来。

于是, 式 (vii) 的 正好可以和以上的 对上(因为 ), 而 则对应为 , 从而:

picture.image

将以上结果连同 代入式 (viii), 得到:

picture.image

你或许会疑惑:损失函数不应该是发送者分布和接收者分布之间的 KL 散度么? 怎么变成了后来引入的变量 对应的两个分布之间的 KL 散度了? 搞一大堆公式想坑我是不! ?

不敢不敢 之所以够胆这么搞, 是因为KL 散度对于随机变量的仿射变换具有等价性 , 而 是妥妥的仿射变换, 所以各位客观无须恐方。

伪代码

个人感觉作者写得最好的就是伪代码部分了(CW 也不知道这属于赞赏还是..),因为 paper 中秀出来的伪代码和推导出来的公式是能够完美对应上的。所以,如果你们看下面伪代码时有感到疑惑的地方,都可以回顾下前文内容,一定能够找到对应的答案。

网络输出及预测

这部分是 BFN 接收输入分布参数 并作预测 (即输出分布)。预测结果由 BFN 先输出拥有上下文信息的向量 ,然后经过后处理形成概率分布。在二分类的情况下,后处理使用 sigmoid 函数; 在多分类的情况则使用 softmax。

picture.image

离散时间的训练过程

picture.image

前三行标绿处表示使用贝叶斯流分布进行后验更新的过程,最后两个标绿处之所以要再次从发送者分布中采样观测样本,是因为最终在发送者分布上使用了蒙特卡洛采样来近似计算发送者分布和接收者分布之间的 KL 散度

连续时间的训练过程

picture.image

可以看到,在连续时间的情况下,由于损失函数拥有解析形式,因此不需要再次在发送者分布中采样(因为没有使用蒙特卡洛采样),这与离散时间的情况不同。

采样生成

picture.image

采样生成的过程大致可以概括为:先设置先验为均匀分布 (见第一行标绿处),然后在每个时间步都根据从发送者分布中采样出的观测样本来更新先验,直至预设的总时间步后,最终从输出分布中采样出对应的类别(见倒数第二行)。

这里特别说明下, 第二个标绿处表示的是从输出分布中采样, 这个采样结果 是 维向量, 其中每一维都是 1 K 的整数, 则代表其在每一维上都映射为 维的 one-hot 向量, 因此总共是 维。

另外, 这里的后验更新不能使用贝叶斯流分布, 因为在每个时间步生成的数据样本 (对应到发送者分布中的 ) 都是不同的,所以只能使用贝叶斯更新函数/分布(见最后四个有颜色标注处)。

实验结果

对于离散数据的建模效果,作者选了两个数据集来进行分析,分别是:动态二值化(dynamically binarized) 的 MNIST(28x28 分辨率的二值图像) 和 text8(每个文本序列样本的长度为 256 个字符,字符由27个“字母”组成,空格以下划线代替,也被当作一个“字母”)。作者在实验中都使用连续时间的损失函数进行训练 ,离散时间的损失函数仅用作测试评估(evaluated for testing only)。

为了简单叙述(其实是想偷懒)但又不失一般性,CW 仅就 MNIST 的实验结果来和大家吹吹水~

MNIST(dynamically binarized)

picture.image

作者使用的并非原始的 MNIST 数据集,而是动态二值化的版本,它是从 二值化(binarized) MINST 的基础上发展而来的。二值化就是将原本 MNIST 图像的像素灰度值(0~255)变为二值(0 or 1) ,至于二值化的策略则有许多种,常见的是先将像素值归一化至 内的连续值作为 伯努利分布 的概率,然后从其中进行采样,从而采样结果就是二值化的像素值 0 或 1。动态二值化的原理类似,只不过在每个 batch 都使用不同的二值化策略

对于网络结构和其它的实验设置,CW 就不在这里阐述了,各位客官看 paper 即可。这里提一点,就是作者使用了衰减率为0.9999的 EMA: Exponential Moving Average(指数滑动平均) 模型来做测试评估和采样生成

下面来看一些实验结果:

  • 离散时间损失 vs 连续时间损失

picture.image

上表使用了离散时间的损失函数和连续时间的损失函数对测试集做评估。其中,计算离散时间的损失函数需要事先指定总的步数,所以 "n-steps" 标有实际数字的,就代表使用离散时间的损失函数;从而,剩下标有 " " 的一项就代表使用的是连续时间的损失函数。

作者对测试集进行2000次遍历,每次遍历都抽取出一张图,然后分别计算每组(对应以上表中的每列)对应的 loss,(每组)最终记录在表格中的结果就是这2000次 loss 的平均值。

由上表可知,loss 值随着步数的增加而减少,且连续时间的 loss 值最小(连续时间下的总步数可看作是无穷) 。特别说明下,这里的 loss 还包含了最后一步的重构损失 (记住,它并没有参与训练过程,如前文所述)。另外,由于数据集的图像分辨率是28x28=784,也就是说一张图片的像素个数总共有784个,因此 784 就相当于是自回归模型在生成一张图片时所需要的采样步数。

通过以上实验结果,我们可以认为:由于网络是使用连续时间的损失函数进行训练的,因此在评估时,连续时间的 loss 值最低;且离散时间的损失函数在步数越多(也就是越趋近于连续时间)的情况下,其对应的 loss 值也越低。

基于这个结果,作者推测:若直接使用离散时间的损失函数来进行训练,那么在测试集上评估时,离散时间的 loss 值应该会更低。

  • 输入分布 vs 输出分布

picture.image

以上是对两张测试集图像做评估时,输入、输出分布的可视化效果。这是在 区间内均匀划分为20步,每步展示出对应的可视化效果图。

通过上面的结果,我们可以看出,一开始,输入分布的先验设置为均匀分布,然后通过多次后验更新而逐渐接近于真实分布;而输出分布一开始则更加贴近于训练集中每个像素的边缘先验 ,这个先验是根据训练集中的所有图像的像素值的频率来估计的,它反映了图像的一些统计特征,如:亮度、对比度、边缘 等,这点从它前期预测的是“多个数字的叠加”这个现象就可以看出。这也说明 BFN 能够纠正输入分布的先验且产生趋向于从训练数据中学到的经验性先验

从最终效果来看,输出分布的噪声程度更低 ,并且输出分布能够随着数据信息的不断流入(通过 BFN 接收输入分布的参数),从预测多个数字的叠加逐渐变为预测某个确定的数字 ,这说明 BFN 能够利用上下文信息来解决歧义(ambiguity)和噪声

你或许会感慨:怎么都是在表扬输出分布,输入分布就这么不堪么..

也不是,其实输入分布也很关键,因为输出分布之所以能有如此表现,也是靠输入分布的参数 给它传递信息的,只不过它有 BFN 的加持,利用了神经网络的优势——整合上下文信息与建模多维变量之间的复杂关系,才能够这么秀。只能说,输入分布主打辅助,不过辅助打得再好也没有主攻亮眼,世道就是如此~

  • loss 随时间的变化情况

picture.image

以上是在测试集上, 连续时间的损失函数随时间变化的情况。可以发现, loss 并非均匀地下降, 作者猜测重新缩放 可能会取得更好的效果。另外, 根据最开始给出的那张实验表, 我们知道重构损失 , 这是相对比较高的, 最直接的改善方式是增加 , 但作者发现这样么做会导致收玫慢且最终性能也不好, 说明前面凭直觉设计出来的 accruacy schedule( ) 对于离散数据来说是次优的。

吹吹水

前面都是“正经”内容 —— 关于 BFN 建模离散数据的原理与实验结果,是 CW 根据 paper 中的内容做的详细解析。而在这部分,CW 就不想那么正经了 —— 稍微发散一下,吹吹水(谈谈一些理解)~

收敛慢

首先想说的是,如前文“贝叶斯流分布”那一章所述,在训练期间,输入分布的参数轨迹会收敛得比较慢,那么进一步可能会导致 BFN 收敛慢,毕竟其要通过接收输入分布的参数来获取(关于数据分布的)信息。但是,这种现象(输入分布的参数轨迹收敛慢)在建模连续数据时通常是不存在的(见 paper 中的 Figure 4)。

So.. 造成这两种差别的原因到底在哪里呢?

除了在前文中谈到的“离散数据与连续的概率值之间的 gap”之外,这里进一步脑补下其它可能的因素。

先从数据本身出发来进行分析。离散数据表示为 的类别(此处忽略维度), 类别只能出现在有限的数据点中 , 且从数值上来看是体现不出它们之间的联系的(你不能说 1 和 2 比 1 和 3 更“亲密”)。在将其用于后验更新时, 需转换成 one-hot 形式的向量 , 它具有稀疏性, 信息密度低 , 不太容易使先验(输入分布的参数)快速地学习到真值(进而成为真实数据的经验性先验), 于是需要经历较多次数的后验更新, 从而造成输入分布的参数轨迹“漫长”。从微观上理解,信息密度低, 相当于供先验学习的样本少, 于是学得慢。

另外,若数据中各类别分布得不均匀,那么就容易导致某个类别学得比较慢 。可能更“打击”人的是,基于传统统计学的后验更新在高维空间中本身能力有限 ,相对不那么容易快速地学习到真值(这应该也是作者的小心思:用神经网络来学多一个输出分布)。

总的来说,导致输入分布的参数轨迹收敛慢的(可能)因素有:

  • 离散数据具有稀疏性,信息密度低;
  • 离散数据与连续的概率值之间存在 gap;
  • 传统统计学在高维空间中的建模能力有限;
  • 数据集中类别分布不均匀

面临的挑战

接下来谈谈 CW 认为 BFN 这种玩法可能面临的一些 challenge.

毫无疑问,相比于许多其它生成模型,BFN 的一大亮点在于解释性高 ,毕竟贝叶斯爸爸理论的自洽性还是很牛逼的。但同时也源于这点,导致 BFN 依赖于对概率流的精确估计 ,若根据贝叶斯统计而“造出来”的输入分布不靠谱,那么 BFN 吐出来的输出分布也就变得是来搞笑的.. 因为关于数据分布的信息是由输入分布的参数给到 BFN 的,BFN 是在此基础上进一步整合上下文信息与处理高维变量的复杂关系(也可认为是在对输入分布做校正),所以输入分布本身的靠谱与否很是关键 。总的来说,这个问题紧密关系着 BFN 的生成质量

还有关于生成多样性 的问题,在这方面 CW 暂时未看出 BFN 的优势。如果说 diffusion models 可以通过调节噪声水平来制造多样性,那么相应地,BFN 可以通过调节 来实现,因为这个变量本质上就是在控制噪声水平。

但如前面“实验结果”那一章所看到的, 的设置是个麻烦问题 ,目前作者也是凭直觉去搞的,于是导致了 loss 下降不均以及最后的重构损失较高的现象,重构得不好则最终会体现在模型的生成质量上(不逼真、严重偏离数据分布等)。

Anything u wanna talk, pls let me know ~

End.. NO!

本来计划这篇文章是该系列的终篇,奈何上周从北京浪了一圈回来后,惊喜(恐)地发现作者居然开源了!搞得我是既激动又无奈。激动在于开源了嘛我自己也可以玩一把过过瘾;无奈在于.. 嗯,你们懂的。

无意外的话,CW 会于下一篇文章对 BFN 的源码实现进行解析 ,这将会是不无聊的风格~

附1:多项分布的概率质量函数

注:这部分对应 paper 中的式。

多项分布(Multinomial Distribution) 是 二项分布(Binomial Distribution) 的推广,后者在每次随机试验中只会出现两种结果,而前者则可以有多种。这两者之间的区别类似于 “多分类 vs 二分类”、“Softmax vs Sigmoid” 的赶脚~

假设如今有多项分布 为随机试验次数, 每次随机试验都可能有 种互斥 的结果: , 将它们出现(发生)的次数 记为随机变量:, 这个多项分布代表的事件就是: 在 次相互独立 的随机试验中, 这 种结果出现的次数分别为: 。于是有: 。

记这 种结果在每次 随机试验中发生的概率分别为: 。那么, 如果我们要求在 次随机试验中按照某种特定顺序 使这些结果分别出现 次(比如:前 次的结果都是 , 接下来的 次结果都是 , 则这样发生的概率就是:

然而, 并非一定出现在最开始的前 次试验, 也不一定非得在接下来的 次试验中出现。也就是说,在 次随机试验中,这些结果出现的顺序并不要求是固定的(咱们很民主,没人强求你), 只要让最终的结局是 出现了 次即可。

由此可知, 满足 这个多项分布的情况有很多种, 而其中每种发生的概率都是 : 。于是, 现在的关键在于统计出所有可能的情况

( CW 知道你们最讨厌“易得”、“易证”、“显然可得”这种话术,所以不敢说)

这是个组合问题。在 次试验中选 次作为 发生的“时间点”、然后在剩下的 次中选 次作为 的、接着再于剩下的 次中选 次作为 的... 以此类推, 所有可能的情况数就是:

picture.image

综上, 由于每种情况发生的概率都是 ,因此目标事件发生的概率就是:

所以说, paper 中的式 (134), (135) 存在错误, 连乘的每项应该对应加上 次幂:

picture.image

(不小心逛了下 Arxiv,发现最新一版的 paper,于11月27日发布的,已经改正了这个问题,浅浅表扬下作者~)

并且, 式 (135) 的下标应为 而非 高度怀疑作者大大在写 paper 前喝多了..

附2:一维随机变量的分布变换

注:这部分对应 paper 中的式(144)

记随机变量 的概率密度函数为 , 分布函数为 。现有另一随机变量 , 且函数 严格单调 , 其反函数 存在(也是严格单调递增)。

记 的概率密度函数为 , 分布函数为 , 我们需要求证的是:

假设函数 单调递增, 则 的分布函数 为:

由于 严格单调递增, 因此:

于是:

概率密度函数即分布函数的导数 ,所以:

picture.image

当单调递减时,由于:

picture.image

因此,最终会得出:

综上, 无论 单调递增还是递减, 都有:

彩蛋: 对于多维随机变量: 的情况, 则对应变为 雅可比行列式(Jacobian) 的绝对值 , 即有:

picture.image

附3:对数恒等式

注:这部分对应 paper 中的式(149)

这个证明其实用到了 在 处的泰勒展开(没错,泰勒就是 yyds~),也就是 麦克劳林公式。

picture.image

其中 代表 阶导数。

现在, 对 进行麦克劳林展开:

picture.image

paper 中的式 (149) 就是仅估计到二阶无穷小所对应的结果。

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论