什么是Wasserstein距离

AI领域中,“距离”和“相似性”是理解数据和模型行为的关键概念。在众多衡量分布之间差异的方法中,Wasserstein距离(也称为地球移动距离,英文:Earth Mover’s Distance, EMD)脱颖而出,为我们提供了一个更直观、更稳定的度量标准。它在人工智能,特别是生成对抗网络(GAN)等领域发挥了重要作用。

一、 什么是Wasserstein距离?——从“搬土”说起

想象一下你有两堆沙子:一堆是你实际观察到的数据(真实数据分布),另一堆是你的AI模型生成的数据(生成数据分布)。这两堆沙子的形状、位置和大小可能各不相同。现在,你的任务是把第一堆沙子(模型生成的沙子)重新塑造成第二堆沙子(真实沙子)。你需要雇佣一台推土机来完成这项工作。

Wasserstein距离衡量的就是完成这项“搬土”任务所需的最小“工作量”。 这里的“工作量”通常定义为:你移动了多少沙子,乘以这些沙子平均移动了多远的距离。 如果两堆沙子完全相同,那么不需要移动任何沙子,工作量就是0。如果它们完全不相干,或者形状差异很大,那么就需要做更多的“功”。

这个形象的比喻就是**地球移动距离(Earth Mover’s Distance)**这个名字的由来,它是在1781年由Gaspard Monge首次提出的一个关于最优传输(Optimal Transport)的问题概念。 直到后来,列昂尼德·瓦瑟施泰因(Leonid Vaseršteǐn)等人的研究才将其应用于概率分布的比较中,并最终以他的名字命名。

二、 为什么Wasserstein距离如此特别?——与其他“距离”的区别

在计算机科学和机器学习中,我们还有其他衡量两个概率分布之间差异的方法,其中最常见的是KL散度(Kullback-Leibler Divergence)JS散度(Jensen-Shannon Divergence)。 那么,相较于它们,Wasserstein距离有什么优势呢?

  1. 对重叠度不敏感,提供有意义的梯度信息

    • 想象两堆沙子,如果它们之间完全没有重叠(比如一堆沙子全部在左边,另一堆全部在右边),那么KL散度或JS散度可能会给出无限大或常数的值,这使得我们无法判断哪堆沙子更“靠近”另一堆,也就无法知道应该如何调整模型去“搬动”沙子以缩小距离。 这在机器学习算法中可能导致梯度消失,模型无法有效学习。
    • Wasserstein距离则不同。即使两堆沙子完全没有重叠,它也能根据沙子需要移动的距离给出有意义的数值。 比如,两堆沙子相距10米的工作量,显然比相距100米的工作量要小。这个数值提供了一个平滑的、可以有效优化的梯度信息,使得模型能够明确知道“往哪个方向努力”才能让生成的沙子更像真实的沙子。
    • 你可以把它理解为:KL/JS散度可能只关心两堆沙子“是不是不一样”,但Wasserstein距离更能衡量它们“在哪里不一样,以及不一样到什么程度”
  2. 考虑了“路径”和“成本”

    • KL散度和JS散度更多地关注两个分布在每个点上的概率差异。
    • Wasserstein距离则着眼于如何最优地将一个分布中的“质量”(比如沙子)转换到另一个分布中。它不仅仅测量差异的总量,还测量消除这种差异所需的“成本”或“工作量”,这个成本与移动的“距离”以及“质量”有关。
  3. 几何直观性

    • Wasserstein距离与物理直觉高度吻合,即“搬土工程”的比喻。这使得即使是非专业人士也能更容易地理解其内在含义。

三、 Wasserstein距离在AI中的应用

Wasserstein距离之所以在AI领域受到关注,很大程度上归功于其在**生成对抗网络(GAN)**中的应用。

1. 生成对抗网络(GANs)的稳定性提升:
传统的GANs在训练时经常会遇到模式崩溃(mode collapse)和训练不稳定等问题。这部分原因在于其损失函数(通常基于JS散度)在两个分布重叠度很低时会梯度消失。
2017年提出的**Wasserstein GAN (WGAN)**就是为了解决这个问题。 WGAN将原本的损失函数替换为Wasserstein距离,使得判别器(Critic)能够为生成器(Generator)提供更有意义的梯度信号,即使真实数据分布和生成数据分布之间重叠很小。 这使得WGAN的训练更加稳定,生成的样本质量更高,多样性也更好。它能更好地衡量生成图像与真实图像分布之间的距离(或差异)。

2. 图像处理与计算机视觉:
Wasserstein距离在图像处理中被用于衡量两幅图像之间的差异。 相比于传统的像素级比较,它能更好地考虑图像的结构信息和空间关系。 例如,在图像检索中,它可以用来寻找与查询图像最相似的图像,即使图像有变形或噪声。 此外,它还在图像生成、风格迁移等任务中发挥作用。

3. 数据漂移检测:
在机器学习模型部署之后,输入数据的分布可能会随时间发生变化,这被称为“数据漂移”(Data Drift),可能导致模型性能下降。 Wasserstein距离可以用来有效地衡量新数据分布与训练数据分布之间的差异,从而检测数据漂移。 相比于KL散度,Wasserstein距离在检测出复杂数据分布或大型数据集的结构变化时,表现更具鲁棒性。

4. 其他应用:
除了上述领域,Wasserstein距离还在自然语言处理、计算生物学(如比较细胞计数数据集的持久图)和地球物理学逆问题等领域有所应用。 它甚至被用于集成信息理论中,以计算概念和概念结构之间的差异。

四、 展望未来

尽管Wasserstein距离有其计算成本相对较高(尤其是在高维数据上)的缺点, 但是它在机器学习,特别是生成模型和数据分析中的独特优势,使得它成为了一个不可或缺的工具。随着计算资源的进步和新算法的开发,相信Wasserstein距离的应用将更加广泛和深入,为AI领域带来更多创新和突破。