Diffusion擴散模型
本文基于Hugging Face:The Annotated Diffusion Model一文翻譯遷移而來,同時參考了由淺入深了解Diffusion Model一文。
關于擴散模型(Diffusion Models)有很多種理解,本文的介紹是基于denoising diffusion probabilistic model (DDPM),DDPM已經在(無)條件圖像/音頻/視頻生成領域取得了較多顯著的成果,現有的比較受歡迎的的例子包括由OpenAI主導的GLIDE和DALL-E 2、由海德堡大學主導的潛在擴散和由Google Brain主導的圖像生成。
實際上生成模型的擴散概念已經在(Sohl-Dickstein et al., 2015)中介紹過。然而,直到(Song et al., 2019)(斯坦福大學)和(Ho et al., 2020)(在Google Brain)才各自獨立地改進了這種方法。
本文是在Phil Wang基于PyTorch框架的復現的基礎上(而它本身又是基于TensorFlow實現),遷移到MindSpore AI框架上實現的。
實驗中我們采用離散時間(潛在變量模型)的觀點
模型簡介
什么是Diffusion Model?
如果將Diffusion與其他生成模型(如Normalizing Flows、GAN或VAE)進行比較,它并沒有那么復雜,它們都將噪聲從一些簡單分布轉換為數據樣本,Diffusion也是從純噪聲開始通過一個神經網絡學習逐步去噪,最終得到一個實際圖像。 Diffusion對于圖像的處理包括以下兩個過程:
-
我們選擇的固定(或預定義)正向擴散過程?𝑞 :它逐漸將高斯噪聲添加到圖像中,直到最終得到純噪聲
-
一個學習的反向去噪的擴散過程?𝑝𝜃 :通過訓練神經網絡從純噪聲開始逐漸對圖像去噪,直到最終得到一個實際的圖像
由?𝑡 索引的正向和反向過程都發生在某些有限時間步長?𝑇(DDPM作者使用?𝑇=1000)內。從𝑡=0開始,在數據分布中采樣真實圖像?𝐱0(本文使用一張來自ImageNet的貓圖像形象的展示了diffusion正向添加噪聲的過程),正向過程在每個時間步長?𝑡 都從高斯分布中采樣一些噪聲,再添加到上一個時刻的圖像中。假定給定一個足夠大的?𝑇 和一個在每個時間步長添加噪聲的良好時間表,您最終會在?𝑡=𝑇通過漸進的過程得到所謂的各向同性的高斯分布。
擴散模型實現原理
Diffusion 前向過程
所謂前向過程,即向圖片上加噪聲的過程。雖然這個步驟無法做到圖片生成,但這是理解diffusion model以及構建訓練樣本至關重要的一步。 首先我們需要一個可控的損失函數,并運用神經網絡對其進行優化。
設?𝑞(𝑥0)是真實數據分布,由于?𝑥0~𝑞(𝑥0),所以我們可以從這個分布中采樣以獲得圖像?𝑥0 。接下來我們定義前向擴散過程?𝑞(𝑥𝑡|𝑥𝑡?1) ,在前向過程中我們會根據已知的方差?0<𝛽1<𝛽2<...<𝛽𝑇<1在每個時間步長 t 添加高斯噪聲,由于前向過程的每個時刻 t 只與時刻 t-1 有關,所以也可以看做馬爾科夫過程:
回想一下,正態分布(也稱為高斯分布)由兩個參數定義:平均值?𝜇 和方差?𝜎2≥0 。基本上,在每個時間步長?𝑡 處的產生的每個新的(輕微噪聲)圖像都是從條件高斯分布中繪制的,其中
我們可以通過采樣然后設置
請注意,?𝛽𝑡在每個時間步長?𝑡(因此是下標)不是恒定的:事實上,我們定義了一個所謂的“動態方差”的方法,使得每個時間步長的?𝛽𝑡可以是線性的、二次的、余弦的等(有點像動態學習率方法)。
因此,如果我們適當設置時間表,從?𝐱0開始,我們最終得到?𝐱1,...,𝐱𝑡,...,𝐱𝑇,即隨著?𝑡 的增大?𝐱𝑡會越來越接近純噪聲,而?𝐱𝑇就是純高斯噪聲。
那么,如果我們知道條件概率分布?𝑝(𝐱𝑡?1|𝐱𝑡),我們就可以反向運行這個過程:通過采樣一些隨機高斯噪聲?𝐱𝑇,然后逐漸去噪它,最終得到真實分布?𝐱0中的樣本。但是,我們不知道條件概率分布?𝑝(𝐱𝑡?1|𝐱𝑡)。這很棘手,因為需要知道所有可能圖像的分布,才能計算這個條件概率。
Diffusion 逆向過程
為了解決上述問題,我們將利用神經網絡來近似(學習)這個條件概率分布?𝑝𝜃(𝐱𝑡?1|𝐱𝑡), 其中?𝜃是神經網絡的參數。如果說前向過程(forward)是加噪的過程,那么逆向過程(reverse)就是diffusion的去噪推斷過程,而通過神經網絡學習并表示?𝑝𝜃(𝐱𝑡?1|𝐱𝑡)的過程就是Diffusion 逆向去噪的核心。
現在,我們知道了需要一個神經網絡來學習逆向過程的(條件)概率分布。我們假設這個反向過程也是高斯的,任何高斯分布都由2個參數定義:
-
由?𝜇𝜃參數化的平均值
-
由?𝜇𝜃參數化的方差
綜上,我們可以將逆向過程公式化為
其中平均值和方差也取決于噪聲水平?𝑡,神經網絡需要通過學習來表示這些均值和方差。
-
注意,DDPM的作者決定保持方差固定,讓神經網絡只學習(表示)這個條件概率分布的平均值?𝜇𝜃。
-
本文我們同樣假設神經網絡只需要學習(表示)這個條件概率分布的平均值?𝜇𝜃。
為了導出一個目標函數來學習反向過程的平均值,作者觀察到?𝑞和?𝑝𝜃的組合可以被視為變分自動編碼器(VAE)。因此,變分下界(也稱為ELBO)可用于最小化真值數據樣本?𝐱0的似然負對數(有關ELBO的詳細信息,請參閱VAE論文(Kingma等人,2013年)),該過程的ELBO是每個時間步長的損失之和?𝐿=𝐿0+𝐿1+...+𝐿𝑇 ,其中,每項的損失?𝐿𝑡(除了?𝐿0)實際上是2個高斯分布之間的KL發散,可以明確地寫為相對于均值的L2-loss!
如Sohl-Dickstein等人所示,構建Diffusion正向過程的直接結果是我們可以在條件是?𝐱0(因為高斯和也是高斯)的情況下,在任意噪聲水平上采樣?𝐱𝑡,而不需要重復應用?𝑞 去采樣?𝐱𝑡,這非常方便。使用
我們就有
這意味著我們可以采樣高斯噪聲并適當地縮放它,然后將其添加到?𝐱0中,直接獲得?𝐱𝑡 。
請注意,𝛼ˉ𝑡已知?𝛽𝑡方差計劃的函數,因此也是已知的,可以預先計算。這允許我們在訓練期間優化損失函數?𝐿的隨機項。或者換句話說,在訓練期間隨機采樣?𝑡并優化?𝐿𝑡。
正如Ho等人所展示的那樣,這種性質的另一個優點是可以重新參數化平均值,使神經網絡學習(預測)構成損失的KL項中噪聲的附加噪聲。這意味著我們的神經網絡變成了噪聲預測器,而不是(直接)均值預測器。其中,平均值可以按如下方式計算:
最終的目標函數?𝐿𝑡 如下 (隨機步長 t 由?(𝜖~𝑁(0,𝐈))給定):
在這里,?𝐱0是初始(真實,未損壞)圖像,?𝜖是在時間步長?𝑡采樣的純噪聲,𝜖𝜃(𝐱𝑡,𝑡)是我們的神經網絡。神經網絡是基于真實噪聲和預測高斯噪聲之間的簡單均方誤差(MSE)進行優化的。
訓練算法現在如下所示:
換句話說:
-
我們從真實未知和可能復雜的數據分布中隨機抽取一個樣本?𝑞(𝐱0)
-
我們均勻地采樣11和𝑇之間的噪聲水平𝑡(即,隨機時間步長)
-
我們從高斯分布中采樣一些噪聲,并使用上面定義的屬性在?𝑡時間步上破壞輸入
-
神經網絡被訓練以基于損壞的圖像?𝐱𝑡來預測這種噪聲,即基于已知的時間表?𝐱𝑡上施加的噪聲
實際上,所有這些都是在批數據上使用隨機梯度下降來優化神經網絡完成的。
U-Net神經網絡預測噪聲
神經網絡需要在特定時間步長接收帶噪聲的圖像,并返回預測的噪聲。請注意,預測噪聲是與輸入圖像具有相同大小/分辨率的張量。因此,從技術上講,網絡接受并輸出相同形狀的張量。那么我們可以用什么類型的神經網絡來實現呢?
這里通常使用的是非常相似的自動編碼器,您可能還記得典型的"深度學習入門"教程。自動編碼器在編碼器和解碼器之間有一個所謂的"bottleneck"層。編碼器首先將圖像編碼為一個稱為"bottleneck"的較小的隱藏表示,然后解碼器將該隱藏表示解碼回實際圖像。這迫使網絡只保留bottleneck層中最重要的信息。
在模型結構方面,DDPM的作者選擇了U-Net,出自(Ronneberger et al.,2015)(當時,它在醫學圖像分割方面取得了最先進的結果)。這個網絡就像任何自動編碼器一樣,在中間由一個bottleneck組成,確保網絡只學習最重要的信息。重要的是,它在編碼器和解碼器之間引入了殘差連接,極大地改善了梯度流(靈感來自于(He et al., 2015))。
可以看出,U-Net模型首先對輸入進行下采樣(即,在空間分辨率方面使輸入更小),之后執行上采樣。
實踐環境準備
python版本:Python 3.9.19
安裝所需依賴
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install download dataset matplotlib tqdm
完整的依賴環境如下:
pip listPackage Version
------------------------------ --------------
absl-py 2.1.0
aiofiles 22.1.0
aiosqlite 0.20.0
altair 5.3.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.2.2
asttokens 2.0.5
astunparse 1.6.3
attrs 23.2.0
auto-tune 0.1.0
autopep8 1.5.5
Babel 2.15.0
backcall 0.2.0
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
certifi 2024.6.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
contextlib2 21.6.0
contourpy 1.2.1
cycler 0.12.1
dataflow 0.0.1
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
dnspython 2.6.1
download 0.3.5
easydict 1.13
email_validator 2.2.0
entrypoints 0.4
exceptiongroup 1.2.0
executing 0.8.3
fastapi 0.111.0
fastapi-cli 0.0.4
fastjsonschema 2.20.0
ffmpy 0.3.2
filelock 3.15.3
flake8 3.8.4
fonttools 4.53.0
fqdn 1.5.1
fsspec 2024.6.0
gitdb 4.0.11
GitPython 3.1.43
gradio 4.26.0
gradio_client 0.15.1
h11 0.14.0
hccl 0.1.0
hccl-parser 0.1
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.4
idna 3.7
importlib-metadata 7.0.1
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.28.0
ipympl 0.9.4
ipython 8.15.0
ipython-genutils 0.2.0
ipywidgets 8.1.3
isoduration 20.11.0
isort 5.13.2
jedi 0.17.2
Jinja2 3.1.4
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter_client 7.4.9
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter-resource-usage 0.7.2
jupyter_server 2.14.1
jupyter_server_fileid 0.9.2
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.5
jupyterlab 3.6.7
jupyterlab_code_formatter 2.2.1
jupyterlab_git 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp 4.3.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab-system-monitor 0.8.0
jupyterlab-topbar 0.6.1
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.6
mccabe 0.6.1
mdurl 0.1.2
mindspore 2.2.14
mindvision 0.1.0
mistune 3.0.2
ml_collections 0.1.1
mpmath 1.3.0
msadvisor 1.0.0
mypy-extensions 1.0.0
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbdime 4.0.1
nbformat 5.10.4
nest-asyncio 1.6.0
notebook 6.5.7
notebook_shim 0.2.4
numpy 1.26.4
op-compile-tool 0.1.0
op-gen 0.1
op-test-frame 0.1
opc-tool 0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
orjson 3.10.5
overrides 7.7.0
packaging 23.2
pandas 2.2.2
pandocfilters 1.5.1
parso 0.7.1
pathlib2 2.3.7.post1
pathspec 0.12.1
pexpect 4.8.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.1
platformdirs 4.2.2
pluggy 1.5.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 5.27.1
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pycodestyle 2.6.0
pycparser 2.22
pydantic 2.7.4
pydantic_core 2.18.4
pydocstyle 6.3.0
pydub 0.25.1
pyflakes 2.2.0
Pygments 2.15.1
pylint 3.2.3
pyparsing 3.1.2
pytest 8.0.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-json-logger 2.0.7
python-jsonrpc-server 0.4.0
python-language-server 0.36.2
python-multipart 0.0.9
pytoolconfig 1.3.1
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
referencing 0.35.1
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rope 1.13.0
rpds-py 0.18.1
ruff 0.4.10
schedule-search 0.0.1
scikit-learn 1.5.0
scipy 1.13.1
semantic-version 2.10.0
Send2Trash 1.8.3
setuptools 69.5.1
shellingham 1.5.4
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
snowballstemmer 2.2.0
soupsieve 2.5
stack-data 0.2.0
starlette 0.37.2
sympy 1.12.1
synr 0.5.0
te 0.4.0
terminado 0.18.1
threadpoolctl 3.5.0
tinycss2 1.3.0
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
typer 0.12.3
types-python-dateutil 2.9.0.20240316
typing_extensions 4.11.0
tzdata 2024.1
ujson 5.10.0
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.1
uvloop 0.19.0
watchfiles 0.22.0
wcwidth 0.2.5
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 11.0.3
wheel 0.43.0
widgetsnbextension 4.0.11
y-py 0.6.2
yapf 0.40.2
ypy-websocket 0.8.4
zipp 3.17.0
實踐運行所需最小內存:30GB
實踐代碼
構建Diffusion模型
下面,我們逐步構建Diffusion模型。
首先,我們定義了一些幫助函數和類,這些函數和類將在實現神經網絡時使用。
import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import downloadimport mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScalerms.set_seed(0)def rearrange(head, inputs):b, hc, x, y = inputs.shapec = hc // headreturn inputs.reshape((b, head, c, x * y))def rsqrt(x):res = ops.sqrt(x)return ops.inv(res)def randn_like(x, dtype=None):if dtype is None:dtype = x.dtyperes = ops.standard_normal(x.shape).astype(dtype)return resdef randn(shape, dtype=None):if dtype is None:dtype = ms.float32res = ops.standard_normal(shape).astype(dtype)return resdef randint(low, high, size, dtype=ms.int32):res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)return resdef exists(x):return x is not Nonedef default(val, d):if exists(val):return valreturn d() if callable(d) else ddef _check_dtype(d1, d2):if ms.float32 in (d1, d2):return ms.float32if d1 == d2:return d1raise ValueError('dtype is not supported.')class Residual(nn.Cell):def __init__(self, fn):super().__init__()self.fn = fndef construct(self, x, *args, **kwargs):return self.fn(x, *args, **kwargs) + x# 定義上下采樣
def Upsample(dim):return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)def Downsample(dim):return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
位置向量
由于神經網絡的參數在時間(噪聲水平)上共享,作者使用正弦位置嵌入來編碼𝑡,靈感來自Transformer(Vaswani et al., 2017)。對于批處理中的每一張圖像,神經網絡"知道"它在哪個特定時間步長(噪聲水平)上運行。
SinusoidalPositionEmbeddings
模塊采用(batch_size, 1)
形狀的張量作為輸入(即批處理中幾個有噪聲圖像的噪聲水平),并將其轉換為(batch_size, dim)
形狀的張量,其中dim
是位置嵌入的尺寸。然后,我們將其添加到每個剩余塊中。
class SinusoidalPositionEmbeddings(nn.Cell):def __init__(self, dim):super().__init__()self.dim = dimhalf_dim = self.dim // 2emb = math.log(10000) / (half_dim - 1)emb = np.exp(np.arange(half_dim) * - emb)self.emb = Tensor(emb, ms.float32)def construct(self, x):emb = x[:, None] * self.emb[None, :]emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)return emb
ResNet/ConvNeXT塊
接下來,我們定義U-Net模型的核心構建塊。DDPM作者使用了一個Wide ResNet塊(Zagoruyko et al., 2016),但Phil Wang決定添加ConvNeXT(Liu et al., 2022)替換ResNet,因為后者在圖像領域取得了巨大成功。
在最終的U-Net架構中,可以選擇其中一個或另一個,本文選擇ConvNeXT塊構建U-Net模型。
class Block(nn.Cell):def __init__(self, dim, dim_out, groups=1):super().__init__()self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)self.proj = c(dim, dim_out, 3, padding=1, pad_mode='pad')self.norm = nn.GroupNorm(groups, dim_out)self.act = nn.SiLU()def construct(self, x, scale_shift=None):x = self.proj(x)x = self.norm(x)if exists(scale_shift):scale, shift = scale_shiftx = x * (scale + 1) + shiftx = self.act(x)return xclass ConvNextBlock(nn.Cell):def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):super().__init__()self.mlp = (nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))if exists(time_emb_dim)else None)self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")self.net = nn.SequentialCell(nn.GroupNorm(1, dim) if norm else nn.Identity(),nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),nn.GELU(),nn.GroupNorm(1, dim_out * mult),nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),)self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()def construct(self, x, time_emb=None):h = self.ds_conv(x)if exists(self.mlp) and exists(time_emb):assert exists(time_emb), "time embedding must be passed in"condition = self.mlp(time_emb)condition = condition.expand_dims(-1).expand_dims(-1)h = h + conditionh = self.net(h)return h + self.res_conv(x)
Attention模塊
接下來,我們定義Attention模塊,DDPM作者將其添加到卷積塊之間。Attention是著名的Transformer架構(Vaswani et al., 2017),在人工智能的各個領域都取得了巨大的成功,從NLP到蛋白質折疊。Phil Wang使用了兩種注意力變體:一種是常規的multi-head self-attention(如Transformer中使用的),另一種是LinearAttention(Shen et al., 2018),其時間和內存要求在序列長度上線性縮放,而不是在常規注意力中縮放。 要想對Attention機制進行深入的了解,請參照Jay Allamar的精彩的博文。
class Attention(nn.Cell):def __init__(self, dim, heads=4, dim_head=32):super().__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)self.map = ops.Map()self.partial = ops.Partial()def construct(self, x):b, _, h, w = x.shapeqkv = self.to_qkv(x).chunk(3, 1)q, k, v = self.map(self.partial(rearrange, self.heads), qkv)q = q * self.scale# 'b h d i, b h d j -> b h i j'sim = ops.bmm(q.swapaxes(2, 3), k)attn = ops.softmax(sim, axis=-1)# 'b h i j, b h d j -> b h i d'out = ops.bmm(attn, v.swapaxes(2, 3))out = out.swapaxes(-1, -2).reshape((b, -1, h, w))return self.to_out(out)class LayerNorm(nn.Cell):def __init__(self, dim):super().__init__()self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')def construct(self, x):eps = 1e-5var = x.var(1, keepdims=True)mean = x.mean(1, keep_dims=True)return (x - mean) * rsqrt((var + eps)) * self.gclass LinearAttention(nn.Cell):def __init__(self, dim, heads=4, dim_head=32):super().__init__()self.scale = dim_head ** -0.5self.heads = headshidden_dim = dim_head * headsself.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)self.to_out = nn.SequentialCell(nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),LayerNorm(dim))self.map = ops.Map()self.partial = ops.Partial()def construct(self, x):b, _, h, w = x.shapeqkv = self.to_qkv(x).chunk(3, 1)q, k, v = self.map(self.partial(rearrange, self.heads), qkv)q = ops.softmax(q, -2)k = ops.softmax(k, -1)q = q * self.scalev = v / (h * w)# 'b h d n, b h e n -> b h d e'context = ops.bmm(k, v.swapaxes(2, 3))# 'b h d e, b h d n -> b h e n'out = ops.bmm(context.swapaxes(2, 3), q)out = out.reshape((b, -1, h, w))return self.to_out(out)
組歸一化
DDPM作者將U-Net的卷積/注意層與群歸一化(Wu et al., 2018)。下面,我們定義一個PreNorm
類,將用于在注意層之前應用groupnorm。
class PreNorm(nn.Cell):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.GroupNorm(1, dim)def construct(self, x):x = self.norm(x)return self.fn(x)
條件U-Net?
我們已經定義了所有的構建塊(位置嵌入、ResNet/ConvNeXT塊、Attention和組歸一化),現在需要定義整個神經網絡了。請記住,網絡?𝜖𝜃(𝐱𝑡,𝑡) 的工作是接收一批噪聲圖像+噪聲水平,并輸出添加到輸入中的噪聲。
更具體的: 網絡獲取了一批(batch_size, num_channels, height, width)
形狀的噪聲圖像和一批(batch_size, 1)
形狀的噪音水平作為輸入,并返回(batch_size, num_channels, height, width)
形狀的張量。
網絡構建過程如下:
-
首先,將卷積層應用于噪聲圖像批上,并計算噪聲水平的位置
-
接下來,應用一系列下采樣級。每個下采樣階段由2個ResNet/ConvNeXT塊 + groupnorm + attention + 殘差連接 + 一個下采樣操作組成
-
在網絡的中間,再次應用ResNet或ConvNeXT塊,并與attention交織
-
接下來,應用一系列上采樣級。每個上采樣級由2個ResNet/ConvNeXT塊+ groupnorm + attention + 殘差連接 + 一個上采樣操作組成
-
最后,應用ResNet/ConvNeXT塊,然后應用卷積層
最終,神經網絡將層堆疊起來,就像它們是樂高積木一樣(但重要的是了解它們是如何工作的)。
class Unet(nn.Cell):def __init__(self,dim,init_dim=None,out_dim=None,dim_mults=(1, 2, 4, 8),channels=3,with_time_emb=True,convnext_mult=2,):super().__init__()self.channels = channelsinit_dim = default(init_dim, dim // 3 * 2)self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)dims = [init_dim, *map(lambda m: dim * m, dim_mults)]in_out = list(zip(dims[:-1], dims[1:]))block_klass = partial(ConvNextBlock, mult=convnext_mult)if with_time_emb:time_dim = dim * 4self.time_mlp = nn.SequentialCell(SinusoidalPositionEmbeddings(dim),nn.Dense(dim, time_dim),nn.GELU(),nn.Dense(time_dim, time_dim),)else:time_dim = Noneself.time_mlp = Noneself.downs = nn.CellList([])self.ups = nn.CellList([])num_resolutions = len(in_out)for ind, (dim_in, dim_out) in enumerate(in_out):is_last = ind >= (num_resolutions - 1)self.downs.append(nn.CellList([block_klass(dim_in, dim_out, time_emb_dim=time_dim),block_klass(dim_out, dim_out, time_emb_dim=time_dim),Residual(PreNorm(dim_out, LinearAttention(dim_out))),Downsample(dim_out) if not is_last else nn.Identity(),]))mid_dim = dims[-1]self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):is_last = ind >= (num_resolutions - 1)self.ups.append(nn.CellList([block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),block_klass(dim_in, dim_in, time_emb_dim=time_dim),Residual(PreNorm(dim_in, LinearAttention(dim_in))),Upsample(dim_in) if not is_last else nn.Identity(),]))out_dim = default(out_dim, channels)self.final_conv = nn.SequentialCell(block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1))def construct(self, x, time):x = self.init_conv(x)t = self.time_mlp(time) if exists(self.time_mlp) else Noneh = []for block1, block2, attn, downsample in self.downs:x = block1(x, t)x = block2(x, t)x = attn(x)h.append(x)x = downsample(x)x = self.mid_block1(x, t)x = self.mid_attn(x)x = self.mid_block2(x, t)len_h = len(h) - 1for block1, block2, attn, upsample in self.ups:x = ops.concat((x, h[len_h]), 1)len_h -= 1x = block1(x, t)x = block2(x, t)x = attn(x)x = upsample(x)return self.final_conv(x)
正向擴散
我們已經知道正向擴散過程在多個時間步長𝑇中,從實際分布逐漸向圖像添加噪聲,根據差異計劃進行正向擴散。最初的DDPM作者采用了線性時間表:
-
我們將正向過程方差設置為常數,從𝛽1=10?4線性增加到𝛽𝑇=0.02。
-
但是,它在(Nichol et al., 2021)中表明,當使用余弦調度時,可以獲得更好的結果。
下面,我們定義了𝑇時間步的時間表。
def linear_beta_schedule(timesteps):beta_start = 0.0001beta_end = 0.02return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)
首先,讓我們使用?𝑇=200時間步長的線性計劃,并定義我們需要的?β𝑡中的各種變量,例如方差?𝛼ˉ𝑡的累積乘積。下面的每個變量都只是一維張量,存儲從?𝑡到?𝑇的值。重要的是,我們還定義了extract
函數,它將允許我們提取一批適當的?𝑡索引。
# 擴散200步
timesteps = 200# 定義 beta schedule
betas = linear_beta_schedule(timesteps=timesteps)# 定義 alphas
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))# 計算 q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)def extract(a, t, x_shape):b = t.shape[0]out = Tensor(a).gather(t, -1)return out.reshape(b, *((1,) * (len(x_shape) - 1)))
用貓圖像說明如何在擴散過程的每個時間步驟中添加噪音。
# 下載貓貓圖像
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)from PIL import Imageimage = Image.open('./image_cat/jpg/000000039769.jpg')
base_width = 160
image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
image.show()
噪聲被添加到mindspore張量中,而不是Pillow圖像。我們將首先定義圖像轉換,允許我們從PIL圖像轉換到mindspore張量(我們可以在其上添加噪聲),反之亦然。
這些轉換相當簡單:我們首先通過除以255255來標準化圖像(使它們在?[0,1][0,1]?范圍內),然后確保它們在?[?1,1][?1,1]?范圍內。DPPM論文中有介紹到:
假設圖像數據由?{0,1,...,255}中的整數組成,線性縮放為?[?1,1]?, 這確保了神經網絡反向過程在從標準正常先驗?𝑝(𝐱𝑇)開始的一致縮放輸入上運行。
from mindspore.dataset import ImageFolderDatasetimage_size = 128
transforms = [Resize(image_size, Inter.BILINEAR),CenterCrop(image_size),ToTensor(),lambda t: (t * 2) - 1
]path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),extensions=['.jpg', '.jpeg', '.png', '.tiff'],num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)
反向變換,它接收一個包含?[?1,1][?1,1]?中的張量,并將它們轉回 PIL 圖像:
import numpy as npreverse_transform = [lambda t: (t + 1) / 2,lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWClambda t: t * 255.,lambda t: t.asnumpy().astype(np.uint8),ToPIL()
]def compose(transform, x):for d in transform:x = d(x)return xreverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()
定義前向擴散過程
def q_sample(x_start, t, noise=None):if noise is None:noise = randn_like(x_start)return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)def get_noisy_image(x_start, t):# 添加噪音x_noisy = q_sample(x_start, t=t)# 轉換為 PIL 圖像noisy_image = compose(reverse_transform, x_noisy[0])return noisy_image# 設置 time step
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()
不同的時間步驟可視化此情況:
import matplotlib.pyplot as pltdef plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):if not isinstance(imgs[0], list):imgs = [imgs]num_rows = len(imgs)num_cols = len(imgs[0]) + with_orig_, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)for row_idx, row in enumerate(imgs):row = [image] + row if with_orig else rowfor col_idx, img in enumerate(row):ax = axs[row_idx, col_idx]ax.imshow(np.asarray(img), **imshow_kwargs)ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])if with_orig:axs[0, 0].set(title='Original image')axs[0, 0].title.set_size(8)if row_title is not None:for row_idx in range(num_rows):axs[row_idx, 0].set(ylabel=row_title[row_idx])plt.tight_layout()plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])
定義給定模型的損失函數
def p_losses(unet_model, x_start, t, noise=None):if noise is None:noise = randn_like(x_start)x_noisy = q_sample(x_start=x_start, t=t, noise=noise)predicted_noise = unet_model(x_noisy, t)loss = nn.SmoothL1Loss()(noise, predicted_noise)# todoloss = loss.reshape(loss.shape[0], -1)loss = loss * extract(p2_loss_weight, t, loss.shape)return loss.mean()
denoise_model
將是我們上面定義的U-Net。我們將在真實噪聲和預測噪聲之間使用Huber損失。
數據準備與處理
在這里我們定義一個正則數據集。數據集可以來自簡單的真實數據集的圖像組成,如Fashion-MNIST、CIFAR-10或ImageNet,其中線性縮放為?[?1,1]。
每個圖像的大小都會調整為相同的大小。有趣的是,圖像也是隨機水平翻轉的。根據論文內容:我們在CIFAR10的訓練中使用了隨機水平翻轉;我們嘗試了有翻轉和沒有翻轉的訓練,并發現翻轉可以稍微提高樣本質量。
本實驗我們選用Fashion_MNIST數據集,我們使用download下載并解壓Fashion_MNIST數據集到指定路徑。此數據集由已經具有相同分辨率的圖像組成,即28x28。
# 下載MNIST數據集
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)from mindspore.dataset import FashionMnistDatasetimage_size = 28
channels = 1
batch_size = 16fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)transforms = [RandomHorizontalFlip(),ToTensor(),lambda t: (t * 2) - 1
]dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(16, drop_remainder=True)x = next(dataset.create_dict_iterator())
print(x.keys())
從擴散模型生成新圖像是通過反轉擴散過程來實現的:我們從𝑇開始,我們從高斯分布中采樣純噪聲,然后使用我們的神經網絡逐漸去噪(使用它所學習的條件概率),直到我們最終在時間步𝑡=0結束。如上圖所示,我們可以通過使用我們的噪聲預測器插入平均值的重新參數化,導出一個降噪程度較低的圖像?𝐱𝑡?1。請注意,方差是提前知道的。
理想情況下,我們最終會得到一個看起來像是來自真實數據分布的圖像。
下面的代碼實現了這一點。
def p_sample(model, x, t, t_index):betas_t = extract(betas, t, x.shape)sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)if t_index == 0:return model_meanposterior_variance_t = extract(posterior_variance, t, x.shape)noise = randn_like(x)return model_mean + ops.sqrt(posterior_variance_t) * noisedef p_sample_loop(model, shape):b = shape[0]# 從純噪聲開始img = randn(shape, dtype=None)imgs = []for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)imgs.append(img.asnumpy())return imgsdef sample(model, image_size, batch_size=16, channels=3):return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
請注意,上面的代碼是原始實現的簡化版本。
訓練過程
下面,我們開始訓練吧!
# 定義動態學習率
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)# 定義 Unet模型
unet_model = Unet(dim=image_size,channels=channels,dim_mults=(1, 2, 4,)
)name_list = []
for (name, par) in list(unet_model.parameters_and_names()):name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):item.name = name_list[i]i += 1# 定義優化器
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)# 定義前向過程
def forward_fn(data, t, noise=None):loss = p_losses(unet_model, data, t, noise)return loss# 計算梯度
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)# 梯度更新
def train_step(data, t, noise):loss, grads = grad_fn(data, t, noise)optimizer(grads)return lossimport time# 由于時間原因,epochs設置為1,可根據需求進行調整
epochs = 1for epoch in range(epochs):begin_time = time.time()for step, batch in enumerate(dataset.create_tuple_iterator()):unet_model.set_train()batch_size = batch[0].shape[0]t = randint(0, timesteps, (batch_size,), dtype=ms.int32)noise = randn_like(batch[0])loss = train_step(batch[0], t, noise)if step % 500 == 0:print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)end_time = time.time()times = end_time - begin_timeprint("training time:", times, "s")# 展示隨機采樣效果unet_model.set_train(False)samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")
推理過程(從模型中采樣)
要從模型中采樣,我們可以只使用上面定義的采樣函數:
采樣64個圖片
# 采樣64個圖片
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
# 展示一個隨機效果
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
import matplotlib.animation as animationrandom_index = 53fig = plt.figure()
ims = []
for i in range(timesteps):im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)ims.append([im])animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()
總結
請注意,DDPM論文表明擴散模型是(非)條件圖像有希望生成的方向。自那以后,diffusion得到了(極大的)改進,最明顯的是文本條件圖像生成。下面,我們列出了一些重要的(但遠非詳盡無遺的)后續工作:
-
改進的去噪擴散概率模型(Nichol et al., 2021):發現學習條件分布的方差(除平均值外)有助于提高性能
-
用于高保真圖像生成的級聯擴散模型([Ho et al., 2021):引入級聯擴散,它包括多個擴散模型的流水線,這些模型生成分辨率提高的圖像,用于高保真圖像合成
-
擴散模型在圖像合成上擊敗了GANs(Dhariwal et al., 2021):表明擴散模型通過改進U-Net體系結構以及引入分類器指導,可以獲得優于當前最先進的生成模型的圖像樣本質量
-
無分類器擴散指南([Ho et al., 2021):表明通過使用單個神經網絡聯合訓練條件和無條件擴散模型,不需要分類器來指導擴散模型
-
具有CLIP Latents (DALL-E 2) 的分層文本條件圖像生成 (Ramesh et al., 2022):在將文本標題轉換為CLIP圖像嵌入之前使用,然后擴散模型將其解碼為圖像
-
具有深度語言理解的真實文本到圖像擴散模型(ImageGen)(Saharia et al., 2022):表明將大型預訓練語言模型(例如T5)與級聯擴散結合起來,對于文本到圖像的合成很有效
請注意,此列表僅包括在撰寫本文,即2022年6月7日之前的重要作品。
目前,擴散模型的主要(也許唯一)缺點是它們需要多次正向傳遞來生成圖像(對于像GAN這樣的生成模型來說,情況并非如此)。然而,有正在進行中的研究表明只需要10個去噪步驟就能實現高保真生成。
參考
-
The Annotated Diffusion Model
-
由淺入深了解Diffusion Model