

自从人工智能开始火热以来,不断涌现出让人们惊叹的新技术。其中,基于Diffusion的技术突破层出不穷,每一次进展都给人们带来新的震撼。无论是Stable Diffusion、DALL·E 2,还是最近风靡的Sora,都离不开Diffusion的贡献。本教程将从经典的DDPM开始,结合代码和数学推导,一直探讨到最近的SOTA,带领您深入了解这一领域。
注: 假设读者知道贝叶斯定理,以及有高等数学基础。
1. 基于直觉的理解
首先,我们要理解为什么要叫做“扩散”模型,以及为什么他的生成效果和稳定性都比其他生成模型优秀。
这里引用一下谭旭老师的文章。正如图中所示,我们把从Z到X的过程当作数据的生成过程。Diffusion从河右岸过来的航线不是可学习的,而是人工设计的,能保证到达河左岸的码头位置,虽然有些随机性,但是符合一个先验分布(一般是高斯分布),这样方便我们在生成数据的时候选择左岸出发的码头位置。因为训练模型的时候要求我们一步步打卡来时经过的浮标,在生成数据的时候,基本上也能遵守这些潜在的浮标位置,一步步打卡到达右岸码头。
这样一步一个脚印的方式极大地提高了生成过程中的稳定性和准确性,其他的生成方式例如GAN就如同训练一个神射手,其精准度和稳定性不言而喻。但是,凡事都有代价,Diffusion的代价就是生成(采样)速度慢,于是近些年很多研究者针对DIffusion采样慢的问题做了很多工作,最经典的例如DDIM,将会在下一篇文章中讲到。
我们将这一过程放大,得到下面这一过程,其中表示符合正态分布的初始噪音,表示我们需要的真实图像,表示生成(采样/降噪)过程,表示训练(增噪)过程。
中的表示这个过程涉及到模型中的参数,结果是由模型估计的。这一过程不涉及模型。这一初始噪音为什么要符合正态分布,这一问题会在下文中得到解决。这里的和只是为了将分布区分开,都表示分布,并没有实质区别。
2. 数学公式推导
首先,我们需要知道高斯分布的一些基础知识。
一个均值为方差为的高斯分布的概率分布函数为:
--------------------------------------------------------------------------------(1)
同时,高斯分布有可加性,即:
------------------------------------------------------------(2)
------------------------------------------------------------(3)
以及: ----------------------------------------------------------------------(4)
2.1 加噪过程
首先,一个基本事实是DDPM将加噪过程定义为一个纯粹的马尔可夫过程,即当前状态只取决于上一个状态,可描述为:
---------------------------------------------------------------(5)
其中,由加噪程度定义,随的增大而增大(一般0.0001 -> 0.02),这是因为刚开始的加噪只需要一点点噪声就可以制造很有信息量的样本对供模型学习,随着训练的进行需要更多的噪声来突出与之前样本对的变化。另外,是一个单位矩阵。
利用高斯分布的性质(公式(4)),我们可以得到:
-------------------------------------------------------------------------------(6)
其中。
让我们接着(6)往下推:
----------------------------------------------------------------------- (7)
这个时候我们需要用到高中学到的数学归纳法,将(7)代入(6)得到:
化简:
---------------------------------------------(8)
再由高斯分布的可加性(公式(2)):
--------(9)
最后式(8)可以改写为:
------------------------------------------------------------------(10)
通过递推,我们显然可以得到:
------------------------------------------------------------(11)
到此,我们可以看出很多信息,之所以把加噪过程设计成式(5)的数学样式,在我看来是为了达到最终式(11)的结果。
由于,T足够大,因此, 。
进一步可以得到,。
终于,小船从真实分布出发,到达了名为高斯分布的彼岸,接下来我们需要初始化一个高斯分布,顺着路上我们设立的浮标,回到真实分布。
2.2 去噪过程
那我们该怎么回去呢,这个时候我们就得借助模型,让模型告诉我们怎样找到浮标。
去噪过程实际上就是求,考虑能不能利用上加噪过程中的信息,这个时候我们自然而然想到贝叶斯定理。
-------------------------------------------------------------------------------(12)
式(12)中还是有很多变量不知道(当然不知道,如果知道直接数学方法就推过去了,还要模型干嘛),于是我们引入加噪过程中的信息,引入。
---------------------------------------------------------------------(13)
这是真实的预测公式,这个时候已经可以看出来等式的右边均已知且符合高斯分布。
我们可以进一步推导,看看到底有没有解:
首先,我们将式(1)展开,得到:
-------------------------------------------------------------(14)
接着使用式(14)展开化简式(13):
(将记为)
-----------------------------------------(15)
--------------------------------------------------(16)
---(17 )
注:是因为 基于马尔科夫假设。
对比式(14)和式(17),比较和的系数,可以求得:
-------------------------------------------------------------------------------------(18)
----------------------------------------------------------------------(19)
于是,我们得到了真实分布:
-----------------------(20)
通过观察真实分布我们可以发现,方差是一个通过计算就能够得出来的量,如果我们想要让模型去尽可能地接近真实分布,只需要利用模型去对齐均值,即:
-----------------------------------------------------------------------(21)
观察式(19)可以发现,要想对齐均值,只需要利用模型给定来预测即可(这正是加噪过程中创造的大量数据对),这样做确实可以,但是需要很大的算力,DDPM的作者采用了一种方法,让预测原图转换成预测噪声,神经网络模型似乎天生比较擅长预测残差。
联想到、、之间存在着联系,将式(11)变形:
------------------------------------------------------------------------------(22)
将式(22)带入式(19):
---------------------------------------------------------------------------------(23)
模型就由预测转换成预测,利用模型给定来预测。
表示由公式推导从加噪到的标准高斯分布,也就是说模型预测的是噪声。
3. 设计模型
3.1 训练过程
由式(21)可知,我们需要的是在所有时刻,真实加入的噪音与模型预测的噪音差距越小越好,可以使用MSE来作为loss。
既然是所有时刻,只需要在训练时随机一个时间,再随机一个噪音,利用公式(11)得到。这样我们获得了模型所需要的所有输入和标签。这里附一张原论文的算法图。
注意:torch.randn函数生成的随机数使用了标准正态分布作为随机数生成的基础分布。
3.2 采样过程
现在,我们已经可以通过模型得到预测的噪音,就可以通过式(23)得到的的均值,它非常贴近真实分布。
在下文的实例中,直接将均值加上一个随机噪音作为一个简单的采样点,毕竟在高斯分布中均值的概率是最高的。
4. 从代码理解
代码来自此处,我做了注解,帮助读者理解模型。
4.1 准备工作
早期的Diffusion采用Unet为架构,现在有DiT这样的使用Transformer代替Unet的模型。模型的意义只是为了预测噪音,所以为了方便演示,这里使用Unet。
4.2 Unet主体
注意:DDPM不包含条件生成,这里的context是用于介绍后面的条件生成部分。
4.3 构建模型
sprite shape: (89400, 16, 16, 3) labels shape: (89400, 5)
4.4 训练
epoch 0 100%|██████████| 894/894 [00:25<00:00, 34.65it/s] saved model at ./weights/model_0.pth epoch 1 100%|██████████| 894/894 [00:21<00:00, 41.18it/s] epoch 2 100%|██████████| 894/894 [00:21<00:00, 41.48it/s] epoch 3 100%|██████████| 894/894 [00:21<00:00, 41.44it/s] epoch 4 100%|██████████| 894/894 [00:21<00:00, 41.46it/s] saved model at ./weights/model_4.pth epoch 5 100%|██████████| 894/894 [00:21<00:00, 41.55it/s] epoch 6 100%|██████████| 894/894 [00:22<00:00, 40.06it/s] epoch 7 100%|██████████| 894/894 [00:21<00:00, 41.59it/s] epoch 8 100%|██████████| 894/894 [00:21<00:00, 41.43it/s] saved model at ./weights/model_8.pth epoch 9 100%|██████████| 894/894 [00:22<00:00, 39.76it/s] epoch 10 100%|██████████| 894/894 [00:21<00:00, 41.59it/s] epoch 11 100%|██████████| 894/894 [00:21<00:00, 41.59it/s] epoch 12 100%|██████████| 894/894 [00:21<00:00, 41.45it/s] saved model at ./weights/model_12.pth epoch 13 100%|██████████| 894/894 [00:21<00:00, 41.33it/s] epoch 14 100%|██████████| 894/894 [00:21<00:00, 40.86it/s] epoch 15 100%|██████████| 894/894 [00:21<00:00, 41.48it/s] epoch 16 100%|██████████| 894/894 [00:21<00:00, 41.31it/s] saved model at ./weights/model_16.pth epoch 17 100%|██████████| 894/894 [00:21<00:00, 41.30it/s] epoch 18 100%|██████████| 894/894 [00:21<00:00, 41.55it/s] epoch 19 100%|██████████| 894/894 [00:21<00:00, 41.58it/s] epoch 20 100%|██████████| 894/894 [00:21<00:00, 41.38it/s] saved model at ./weights/model_20.pth epoch 21 100%|██████████| 894/894 [00:21<00:00, 41.25it/s] epoch 22 100%|██████████| 894/894 [00:21<00:00, 41.36it/s] epoch 23 100%|██████████| 894/894 [00:21<00:00, 41.11it/s] epoch 24 100%|██████████| 894/894 [00:21<00:00, 41.05it/s] saved model at ./weights/model_24.pth epoch 25 100%|██████████| 894/894 [00:21<00:00, 41.21it/s] epoch 26 100%|██████████| 894/894 [00:21<00:00, 41.74it/s] epoch 27 100%|██████████| 894/894 [00:21<00:00, 41.27it/s] epoch 28 100%|██████████| 894/894 [00:21<00:00, 41.39it/s] saved model at ./weights/model_28.pth epoch 29 100%|██████████| 894/894 [00:23<00:00, 38.47it/s] epoch 30 100%|██████████| 894/894 [00:21<00:00, 41.08it/s] epoch 31 100%|██████████| 894/894 [00:21<00:00, 41.50it/s]saved model at ./weights/model_31.pth
4.5 采样
观察 Epoch 0 采样
Loaded in Model
gif animating frame 31 of 32
<Figure size 640x480 with 0 Axes>
观察 Epoch 4 采样
Loaded in Model
gif animating frame 31 of 32
<Figure size 640x480 with 0 Axes>
观察 Epoch 8 采样
Loaded in Model
gif animating frame 31 of 32
<Figure size 640x480 with 0 Axes>
观察 Epoch 16 采样
Loaded in Model
gif animating frame 31 of 32
<Figure size 640x480 with 0 Axes>
观察 Epoch 32 采样
Loaded in Model
gif animating frame 31 of 32
<Figure size 640x480 with 0 Axes>







