DiT: Transformer上的擴散模型

論文(ICCV 2023):Scalable Diffusion Models with Transformers

代碼和工程網頁:https://www.wpeebles.com/DiT.html

DiTs(Diffusion Transformers)是首個基于Transformer架構的擴散模型!它在圖像的潛在空間上進行訓練,用transformer替換了常用的U-Net主干網絡;DiTs具備?可擴展性?,當增加Transformer的深度/寬度,或輸入token的數量時,計算復雜度變高(Gflops↑),生成質量往往更好(FID↓);在經典的ImageNet 256×256 和 512×512 的類別條件生成任務中,作者推出的DiT-XL/2模型在質量上打敗了當時其它所有的擴散模型(FID分別為3.04和2.27)。

當年Transformer架構的推出改變了NLP等領域的發展趨勢,在自回歸模型中被廣泛應用,但在其它生成建模框架中卻應用不多——比如擴散模型;擴散模型在圖像生成領域一直位于前列,當時普遍采用的是U-Net這類卷積網絡架構。

U-Net

2020年DDPM的這篇工作首次將U-Net作為擴散模型的核心架構,該網絡有以下幾點特性

1. 多特征提取:通過下采樣和上采樣結構,能捕捉圖像的全局和局部信息;

2. 跳躍連接:幫助保留低級特征(如邊緣、紋理),提升生成質量;

3. 計算高效:相比于純Transformer,卷積U-Net在圖像生成任務中更輕量。

其設計源于PixelCNN++和Conditional GANs,早期在生成模型中的應用包括逐像素生成圖像、用于圖像到圖像的轉換任務。在DDPM中用到的是其結合了ResNet和注意力機制的變體

1. 替換原始U-Net的普通卷積塊,改用殘差連接,緩解深層網絡梯度消失的問題;

2. 在低分辨率層(多次下采樣后)插入空間自注意力模塊,以捕捉長程依賴關系,算是受Transformer啟發在卷積架構中的局部使用;

3. 其他有用Group Normalization代替BatchNorm,對小批量更魯棒;采用自適應歸一化注入條件信息(如時間步、類別標簽)(代碼示例如下)

class adaLN_Zero_Block(nn.Module):def __init__(self, dim):super().__init__()self.norm = nn.LayerNorm(dim, elementwise_affine=False) # create a LayerNorm layer without learnable parameters, just for normalization (minus mean, divide by variance)self.adaLN = nn.Linear(dim, 6 * dim) # input condition c of shape [B, dim] -> 6 * dim parameters for AdaLNnn.init.zeros_(self.adaLN.weight) # zero initialization for the linear layer, so that it behaves like a standard LayerNorm at the beginningdef forward(self, x, c):gamma1, beta1, gamma2, beta2, alpha1, alpha2 = self.adaLN(c).chunk(6, dim=-1)x = alpha1 * x + gamma1 * self.norm(x) + beta1 # normalization + scaling and shifting (like residual connection)x = alpha2 * x + gamma2 * self.norm(x) + beta2return x

因為DDPM的成功,故后續很多工作也一直沿用這種設計,而DiT的出現正反映了生成模型從卷積主導到注意力主導的趨勢。

動機

U-Net有一些歸納偏置(inductive bias),包括

1. 局部性:通過卷積核的局部感受野處理圖像,假設相鄰像素間強相關;

2. 平移等變性:卷積操作對于圖像平移具有不變性;

3. 層次化多尺度建模:通過上/下采樣結構可以捕捉不同尺度的特征。

傳統觀點認為這些是擴散模型高性能的關鍵,但作者認為這些偏置于擴散模型而言并不是必要的,完全可以用Tranformer替代。作者通過DiT證明,擴散模型的性能取決于計算資源和通用架構能力,而非特定的結構偏好。近年來深度學習不同任務和領域逐漸收斂到少數幾種通用架構,DiT的提出為擴散模型提供了更靈活的架構選擇,使其能受益于Transformer的擴展性和跨領域(NLP、CV……)技術積累。

ViT

20年提出的ViT(Vision Transformer)首次將純Transformer架構成功應用于圖像分類任務,核心是將圖像視為序列,用NLP中Transformer的方式處理圖像,打破傳統CNN在CV領域的主導地位。關鍵步驟為

1. 圖像分塊(patch embedding):將圖像分割為固定大小的path,每個patch展平為一個向量

2. 添加位置編碼(position embedding):由于Transformer本身不感知空間位置,需為每個patch添加可學習的位置編碼

3. 分類標記(class token):在序列開頭插入一個可學習的 [CLS] token,然后使用標準的Transformer編碼器(多頭注意力+MLP)處理patch序列,最后用 [CLS] token的輸出做分類

ViT無卷積操作,在中小規模數據集上表現可能不如CNN,但在超大數據集上顯著優于ResNet,且可擴展性強。

先驗知識

擴散方程

在高斯擴散模型中設定對真實數據x_0的前向加噪過程為q(x_t|x_0)=\mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I),其中\bar{\alpha}_t是超參數,應用重參數化技巧,可通過下式直接對x_t進行采樣

x_t = \sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t\epsilon_t \sim \mathcal{N}(0, I)

擴散模型會學著如何去噪,p_{\theta}(x_{t-1}|x_t)=\mathcal{N}(\mu_{\theta}(x_t), \Sigma_{\theta}(x_t))的各式統計量p_{\theta}正是神經網絡要預測的東西,訓練損失是x_0的對數似然的變分下界,即

\mathcal{L}(\theta)=-p(x_0|x_1) + \Sigma_t \mathcal{D}_{KL}(q^*(x_{t-1}|x_t, x_0) || p_{\theta}(x_{t-1}|x_t)),再除以一個與訓練無關的附加項。

因為q^*p_{\theta}均為高斯分布,故可用它們的均值和協方差來評估DKL。將\mu_{\theta}重參數化為噪聲預測網絡\epsilon_t時,訓練損失可簡化為預測和實際噪聲之間的均方誤差,即

\mathcal{L}_{simple}(\theta)=\left \| \epsilon_{\theta}(x_t) - \epsilon_t \right \|_2^2

為了同時優化協方差矩陣\Sigma_{\theta},會用\mathcal{L}_{simple}訓練\epsilon_{\theta}的同時用完整損失\mathcal{L}訓練\Sigma_{\theta}

def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):"""Compute training losses for a single timestep.:param model: the model to evaluate loss on.:param x_start: the [N x C x ...] tensor of inputs.:param t: a batch of timestep indices.:param model_kwargs: if not None, a dict of extra keyword arguments topass to the model. This can be used for conditioning.:param noise: if specified, the specific Gaussian noise to try to remove.:return: a dict with the key "loss" containing a tensor of shape [N].Some mean or variance settings may also have other keys."""# 前向計算if model_kwargs is None:model_kwargs = {}if noise is None:noise = th.randn_like(x_start)x_t = self.q_sample(x_start, t, noise=noise)terms = {}# 直接計算變分下界損失,適用于理論嚴格的訓練if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:# compute the variational lower-bound termterms["loss"] = self._vb_terms_bpd(model=model,x_start=x_start,x_t=x_t,t=t,clip_denoised=False,model_kwargs=model_kwargs,)["output"]if self.loss_type == LossType.RESCALED_KL:terms["loss"] *= self.num_timesteps# 簡化為均方誤差損失elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:model_output = model(x_t, t, **model_kwargs)# 方差可學習時,也要為其計算變分下界if self.model_var_type in [ModelVarType.LEARNED,ModelVarType.LEARNED_RANGE,]:B, C = x_t.shape[:2]assert model_output.shape == (B, C * 2, *x_t.shape[2:])# 將模型輸出拆分為均值和方差model_output, model_var_values = th.split(model_output, C, dim=1)# 凍結均值部分的梯度,確保KL損失僅優化方差frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)terms["vb"] = self._vb_terms_bpd(model=lambda *args, r=frozen_out: r,x_start=x_start,x_t=x_t,t=t,clip_denoised=False,)["output"]if self.loss_type == LossType.RESCALED_MSE:# Divide by 1000 for equivalence with initial implementation.# Without a factor of 1/1000, the VB term hurts the MSE term.(損失加權)terms["vb"] *= self.num_timesteps / 1000.0# 目標選擇 + 算MSEtarget = {ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],ModelMeanType.START_X: x_start,ModelMeanType.EPSILON: noise,}[self.model_mean_type]assert model_output.shape == target.shape == x_start.shapeterms["mse"] = mean_flat((target - model_output) ** 2)if "vb" in terms:terms["loss"] = terms["mse"] + terms["vb"]else:terms["loss"] = terms["mse"]else:raise NotImplementedError(self.loss_type)return terms

p_{\theta}訓好后,初始化x_{t_{max}} \sim \mathcal{N}(0,I),然后通過重參數化技巧采樣x_{t-1} \sim p_{\theta}(x_{t-1}|x_t)

免分類器指導CFG

條件擴散模型會拿額外信息作為輸入,比如類標簽c,如何讓生成的樣本既符合條件c又高質量呢?如分類器引導這類傳統方法會額外訓練一個分類器,而這里提出一種免分類器的方案,同樣可以幫助采樣過程找到最大化log p(c|x)x

根據貝葉斯公式,有logp(c|x)\propto log p(x|c) - log p(x),故取梯度有

\bigtriangledown_x log p(c|x) \propto \bigtriangledown_x log p(x|c) - \bigtriangledown_x log p(x)

將模型輸出視作分數函數,\bigtriangledown_x log p(x|c) \rightarrow \epsilon_{\theta}(x_t, c)\bigtriangledown_x log p(x) \rightarrow \epsilon_{\theta}(x_t, \phi),另外\bigtriangledown_x log p(c|x)不用分類器計算,而是視作條件修正項,通過條件預測\epsilon_{\theta} (x_t,c)和無條件預測\epsilon_{\theta} (x_t,\phi)的差進行估計

\bigtriangledown_x log p(c|x) \propto \epsilon_{\theta}(x_t,c) - \epsilon_{\theta}(x_t, \phi)

于是可通過下列式子引導采樣:

\hat{\epsilon}_{\theta}(x_t, c) \\ = \epsilon_{\theta}(x_t, \phi)+s \cdot \bigtriangledown _x logp(x|c) \propto \epsilon_{\theta}(x_t, \phi)+s \cdot (\epsilon_{\theta} (x_t,c) - \epsilon_{\theta}(x_t, \phi))

即對條件/無條件預測結果進行插值,其中s=1時\hat{\epsilon}_{\theta}(x_t, c) = \epsilon_{\theta} (x_t,c),退化為普通條件采樣,s>1時增強條件控制,生成的樣本更符合c,但多樣性可能降低。在采樣過程中,會用引導后的\hat{\epsilon}_{\theta}(x_t, c)代替原始模型輸出\epsilon_{\theta}(x_t, c)

在訓練時,會隨機將條件c替換為“null”,讓模型同時學習條件生成和無條件生成

def forward_with_cfg(self, x, t, y, cfg_scale):"""實現帶免分類器指導(CFG)的模型前向傳播,同時處理條件/無條件預測y為條件標簽,形狀[2B],前B個為真實條件,后B個為空條件參考來自OpenAI的GLIDE項目"""half = x[: len(x) // 2]combined = torch.cat([half, half], dim=0) # 保證條件/無條件預測的噪聲完全一致model_out = self.forward(combined, t, y)# 將模型輸出分為兩部分,默認只對前3個通道(RGB)應用CFG# 當模型輸出通道數>3,或顯式配置了方差學習,或架構設計包含額外預測任務,rest都不為空eps, rest = model_out[:, :3], model_out[:, 3:]# 將噪聲預測eps分為條件/無條件兩部分,形狀均為[B, 3, H, W]cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)# 指導后輸出half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)eps = torch.cat([half_eps, half_eps], dim=0) # 保持形狀return torch.cat([eps, rest], dim=1)
...
# 在sample.py中的調用邏輯class_labels = [207, 360, 387, 974, 88, 979, 417, 279] # 條件化的類別索引列表
n = len(class_labels) # 要同時生成的圖像類別數n=8
z = torch.randn(n, 4, latent_size, latent_size, device=device) # 在淺空間標準正態分布采樣所得初始噪聲
y = torch.tensor(class_labels, device=device) # 打包成張量的條件信息# 設置免分類器引導
z = torch.cat([z, z], 0) # 用于條件/無條件生成的兩份數據
y_null = torch.tensor([1000] * n, device=device) # 將1000當作“null”
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) # 超參cfg_scale即s# 采樣
samples = diffusion.p_sample_loop(model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)
samples = vae.decode(samples / 0.18215).sample # 0.18215是常見的尺度因子,潛在向量進入模型前乘以該值以匹配統計特性,故解碼時要移除save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))

作者將現成的VAEs混入基于Transformer的DDPMs

from diffusers.models import AutoencoderKL # 從Hugging Face上下載VAE
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) # 拉取參數,兩類優化策略可選:均方誤差mse,指數移動平均ema
...
samples = vae.decode(samples / 0.18215).sample
...
# HF的庫中diffusers/models/autoencoders/vae.py,定義編碼器代碼如下
class Encoder(nn.Module):r"""The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.Args:in_channels (`int`, *optional*, defaults to 3):The number of input channels.out_channels (`int`, *optional*, defaults to 3):The number of output channels.down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
卷積類(無注意力):DownBlock2D ResnetDownsampleBlock2D SkipDownBlock2D DownEncoderBlock2D KDownBlock2D
帶自注意力的(self-attention):AttnDownBlock2D AttnDownEncoderBlock2D AttnSkipDownBlock2D
帶跨注意力(cross-attention):CrossAttnDownBlock2D SimpleCrossAttnDownBlock2D KCrossAttnDownBlock2Doptions.block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):The number of output channels for each block.layers_per_block (`int`, *optional*, defaults to 2):The number of layers per block.norm_num_groups (`int`, *optional*, defaults to 32):The number of groups for normalization.act_fn (`str`, *optional*, defaults to `"silu"`):The activation function to use. See `~diffusers.models.activations.get_activation` for available options.double_z (`bool`, *optional*, defaults to `True`):Whether to double the number of output channels for the last block."""def __init__(self,in_channels: int = 3,out_channels: int = 3,down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),block_out_channels: Tuple[int, ...] = (64,),layers_per_block: int = 2,norm_num_groups: int = 32,act_fn: str = "silu",double_z: bool = True,mid_block_add_attention=True,):super().__init__()self.layers_per_block = layers_per_blockself.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1,) # 將輸入通道數映射為第一個下采樣塊的通道數self.down_blocks = nn.ModuleList([])# down(逐步降低空間分辨率、增加通道數,提取高層語義特征)output_channel = block_out_channels[0]for i, down_block_type in enumerate(down_block_types):input_channel = output_channeloutput_channel = block_out_channels[i]is_final_block = i == len(block_out_channels) - 1down_block = get_down_block(down_block_type,num_layers=self.layers_per_block,in_channels=input_channel,out_channels=output_channel,add_downsample=not is_final_block, # 最后一層down_block不做下采樣resnet_eps=1e-6,downsample_padding=0,resnet_act_fn=act_fn,resnet_groups=norm_num_groups,attention_head_dim=output_channel,temb_channels=None,)self.down_blocks.append(down_block)# mid(保持空間分辨率不變,但可以增加感受野和特征交互)self.mid_block = UNetMidBlock2D(in_channels=block_out_channels[-1],resnet_eps=1e-6,resnet_act_fn=act_fn,output_scale_factor=1,resnet_time_scale_shift="default",attention_head_dim=block_out_channels[-1],resnet_groups=norm_num_groups,temb_channels=None,add_attention=mid_block_add_attention, # 加一個注意力層?)# outself.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)self.conv_act = nn.SiLU()conv_out_channels = 2 * out_channels if double_z else out_channels # 輸出通道數翻倍,分別表示均值和方差self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)self.gradient_checkpointing = Falsedef forward(self, sample: torch.Tensor) -> torch.Tensor:r"""The forward method of the `Encoder` class."""sample = self.conv_in(sample)if torch.is_grad_enabled() and self.gradient_checkpointing: # 若啟用梯度計算,且開啟了梯度檢查點for down_block in self.down_blocks:sample = self._gradient_checkpointing_func(down_block, sample)sample = self._gradient_checkpointing_func(self.mid_block, sample)else:for down_block in self.down_blocks:sample = down_block(sample)sample = self.mid_block(sample)# post-processsample = self.conv_norm_out(sample)sample = self.conv_act(sample)sample = self.conv_out(sample)return sample

DiT的設計

整體架構如下圖

處理256×256×3的圖片,映射到潛空間,變成32×32×4的z,作為DiT的輸入。

DiT的第一層為“patchify”,對z進行分塊,線性嵌入每塊,將空間表征變成含T個維度為d的token的序列,然后將標準的正余弦版本的位置嵌入應用于所有token

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):"""輸入一個 grid_size×grid_size 的網格輸出每個網格點的位置編碼向量,形為 (grid_size×grid_size, embed_dim)從后續兩個函數的“assert”語句來看,embed_dim需為4的倍數前/后一半列為x/y信息,再1/4列為正/余弦,即[x_sin_block, x_cos_block, y_sin_block, y_cos_block]    若存在cls_token,在編碼前補零MAE(帶掩碼的自編碼器)官方實現"""grid_h = np.arange(grid_size, dtype=np.float32)grid_w = np.arange(grid_size, dtype=np.float32)grid = np.meshgrid(grid_w, grid_h)  # here w goes 列表grid含兩個形為(grid_size, grid_size)的數組[W, H],分別記錄每個網格點的w/h坐標grid = np.stack(grid, axis=0) # 堆疊成3維張量,形為(2, grid_size, grid_size)grid = grid.reshape([2, 1, grid_size, grid_size])pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)if cls_token and extra_tokens > 0:pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)return pos_embeddef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):assert embed_dim % 2 == 0  # 確保embed_dim是偶數# 將embed_dim分成兩半,分別用于編碼兩個維度的位置emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])emb = np.concatenate([emb_h, emb_w], axis=1)return embdef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):assert embed_dim % 2 == 0# 將embed_dim分成兩半,分別用于正/余弦編碼omega = np.arange(embed_dim // 2, dtype=np.float64)omega /= embed_dim / 2.omega = 1. / 10000**omega  # 頻率衰減因子 (D/2,)pos = pos.reshape(-1)  # 展平坐標 (M=1×grid_size×grid_size,)(-1為占位符)out = np.einsum('m,d->md', pos, omega)  # 外積計算位置*頻率 (M, D/2)emb_sin = np.sin(out)  # 正弦部分 (M, D/2)emb_cos = np.cos(out)  # 余弦部分 (M, D/2)emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)return emb

看到這塊位置編碼時,我想到旋轉位置編碼RoPE,都是用正余弦函數來編碼位置信息,頻率都是按指數遞減(\theta_i = 10000^{-2i/d}, i=0, 1,...\frac{d}{2}-1),但這里是加法

P_{\theta, (i,j))}^d(X) = \begin{pmatrix} x_0\\ ...\\ x_{d/4}\\ ...\\ x_{d/2}\\ ...\\ x_{3d/4}\\...\\ x_{d-1} \end{pmatrix} + \begin{pmatrix} sinj\theta_0\\ ...\\cosj\theta_0 \\ ...\\ sini\theta_0\\ ...\\ cosi\theta_0 \\...\\cosi\theta_{d/4}\end{pmatrix}

其中(i,j)是該patch在網格中的絕對位置,在RoPE中則是通過“旋轉”將相對位置m編碼入嵌入向量

R_{\theta,m}^d(X) =\begin{pmatrix} x_0\\ ...\\ x_{d/4}\\ ...\\ x_{d/2}\\ ...\\ x_{3d/4}\\...\\ x_{d-1} \end{pmatrix} \bigotimes \begin{pmatrix} cosm\theta_0\\ ...\\cosm\theta_{d/4} \\ ...\\ cosm\theta_0 \\ ...\\ cosm\theta_{d/4} \\...\\cosm\theta_{d/2-1}\end{pmatrix}+ \begin{pmatrix} -x_{d/2}\\ ...\\ -x_{3d/4}\\ ...\\ x_0\\ ...\\ x_{d/4}\\...\\ x_{d/2-1} \end{pmatrix} \bigotimes \begin{pmatrix} sinm\theta_0\\ ...\\sinm\theta_{d/4} \\ ...\\ sinm\theta_0 \\ ...\\ sinm\theta_{d/4} \\...\\sinm\theta_{d/2-1}\end{pmatrix}

將query和key向量的每一對維度當作平面坐標,按角度旋轉一定角度(位置和頻率決定),將相對位置信息融入自注意力的內積計算中:(x',y') = (xcos\omega-ysin\omega, xsin\omega+ycos\omega)


T的大小由超參分塊大小p決定,p減半,T將呈平方增長,讓Transformer Gflops至少變成原來的平方,但對下游參數量并無顯著影響,因為模型的主要參數(Transformer層的權重)僅由隱藏維度d層數N決定,與T無關。作者添加了p = 2, 4, 8多種設計。

而后輸入的tokens會被系列transformer塊處理。除了噪聲圖片,擴散模型有時還會處理額外的條件信息,比如時間步t、類標簽c、自然語言等。作者嘗試了4種用來處理的transformer塊變體,對標準ViT塊設計引入微小但重要的修改:

1. 上下文條件輸入(In-context conditioning)。將條件信息作為額外的token拼接到輸入序列中,與圖像token一起處理。類似于[CLS] token,無需修改標準的ViT塊,在最后一塊中將條件token移除即可。引入的Gflops微乎其微。

2. 交叉注意力塊(Cross-attention block)。將條件信息單獨做一個長為2的序列,在多頭自注意力塊后添加額外的多頭交叉注意力層,類似于原始Transformer中的編碼器-解碼器交叉注意力機制,或者淺空間擴散模型(LDMs)對類標簽進行條件化的方法。引入的Gflops最高,大致占總開銷的15%。

3. 自適應層歸一化塊(adaLN block)。adaLN在GANs和以UNet為主干的擴散模型中有廣泛應用,作者將其替換掉標準Transformer塊中的Layer Norm,從原本的直接學習逐維縮放和平移的參數\gamma\beta,變成從條件信息(t和c的嵌入向量之和)回歸出來。在這3種塊中,引入的Gflops最少,計算效率最高,也是唯一能將條件信息影響所有token的方法。

4. 自適應歸一化層-零初始化塊(adaLN-Zero Block)。先前在ResNets上的工作發現,將殘差塊初始化為一個恒等函數是有益的,殘差塊的數學形式為y=x+F(x)F(x)為殘差函數,恒等映射即在初始狀態下使y\approx x,讓模型先學到全局結構,再逐漸利用殘差分支進行細化,這樣訓練會更穩定,收斂更快。例如在U-Net中將每個殘差塊最后一層卷積權重初始化為0,作者對adaLN的DiT塊做類似修改,讓其再回歸一個逐維縮放參數\alpha用在殘差連接前。系數\gamma\beta\alpha都是由條件調制網絡——一個小型MLP(非線性+線性層)回歸出來的,就將該MLP最后一層的weight和bias初始化為0,所有輸出自然為0,如下

# initialization for AdaLN modulation
self.adaLN_modulation = nn.Sequential(nn.SiLU(),nn.Linear(dim, dim * 6)
)
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
...
# 從條件生成調制參數(多頭自注意力/多層感知機的平移和縮放參數、殘差分支門控系數)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)

和原始的adaLN塊一樣,adaLN-Zero引入的Gflops微乎其微。

依次堆疊N個DiT塊,每個塊的隱藏維度為d。參考ViT,采用標準Transformer配置,同時在N、d和注意力頭數上進行聯合縮放,具體為4種配置:DiT-S、DiT-B、DiT-L和DiT-XL,涵蓋從0.3到118.6Gflops的廣泛模型規模和計算量分配,以評估不同規模下的性能表現,詳細配置如下

在最后的DiT塊中,需將圖片token序列解碼成對噪聲和對角(不同像素/通道/特征之間的噪聲獨立)協方差的預測,這兩部分輸出需與原始空間輸入保持形狀一致,為此使用一個標準的線性解碼器:先對序列施加最終的(自適應)層歸一化,再將每個token線性映射成大小為p×p×2C的張量,其中C為DiT空間輸入的通道數,最后將得到的token重排列回其原始的空間布局,得到噪聲和協方差。

整個DiT設計空間包括了patch大小、Transformer塊結構和模型規模。

實驗

作者設計了積累模型,用以探索DiT設計空間,研究擴展性,模型根據其配置和潛在分片大小p進行命名,例如DiT-XL/2,就是XL版模型規模,同時p=2。

訓練設置

在ImageNet數據集(蠻有挑戰性的基準)的256×256和512×512的圖片上,訓練類條件潛在DiT模型。對最終的線性層要么進行零初始化,要么像ViT那樣初始化。訓練所有模型時均采用AdamW,固定學習率1 \times 10^{-4},無權重衰減,批量大小256,唯一用的數據增強是水平翻轉。

與先前ViTs的工作不同,作者沒發現學習率預熱和正則化的必要性,甚至不用時,所有模型配置下的訓練都高度穩定,也沒看到訓練Transformer時的任何常見損失尖峰。

按生成式建模文獻中的常用做法,在訓練DiT的過程中采用指數滑動平均EMA,衰減系數為0.9999


EMA權重的更新公式為

\theta_{EMA} \leftarrow decay \cdot \theta_{EMA} + (1-decay) \cdot \theta_t

其中decay=0.9999,表示每次更新只用到當前權重變化的0.01%,主要依賴于歷史權重。用這個“平滑版本”的參數\theta_{EMA}作為最終評估或采樣的模型權重/


在所有模型規模和分塊大小下采用相同的訓練參數,和ADM(U-Net作為骨干網絡,又比原始DDPM加入注意力層,即Attention-based Diffusion Model)幾乎完全一致。

擴散設置

對于擴散模型,作者采用現成的預訓練好的VAE作為圖像的編解碼器,VAE編碼器下采樣因子為8,假設給定RGB圖像x的形狀為256×256×3,那么z=E(x)的形狀就是32×32×4,實驗中所有擴散模型都在這個潛在空間中進行操作。在從擴散模型中采樣出一個新的潛在向量后,用VAE解碼器將其解碼成像素x=D(z)

所有實驗同樣沿用了ADM論文中擴散過程的超參數設置:1. 設置最大擴散步數t_{max} = 1000,將噪聲方差\beta_t1e-4線性增加到2e-2;2. 反向過程的高斯分布p_{\theta}(x_{t-1}|x_t) = \mathcal{N}(\mu_{\theta}, \Sigma_{\theta})中的協方差矩陣\Sigma_{\theta}也假設為對角形式,由模型預測(VS. 固定);3. 使用正余弦編碼+MLP得到時間步的嵌入向量,與標簽嵌入結合,用于調制網絡的中間層。

評價指標

和先前工作對齊,比較250步DDPM采樣后的FID-50K。FID是衡量生成圖像質量的常用指標,依賴于Inception網絡提取的特征統計(均值&協方差),任何圖像預處理細節的差異(如圖像尺寸、歸一化方式、插值方法)都會讓數值產生顯著變化,故為了避免計算流程差異,作者直接將生成樣本保存,然后用AMD論文提供的TensorFlow版評估工具統一計算FID,即可保證和ADM的結果完全可比(一樣的代碼、相同的預處理)。默認不使用CFG,除非特別注明。補充報告IS(圖像清晰度和多樣性)、sFID(更關注圖像的空間細節)、Precision/Recall(衡量生成圖像和真實分布的覆蓋率recall和質量)作為參考指標。

計算

所有DiT模型均用JAX框架實現,在TPU v3 prod上訓練,JAX是Google提出的一個高性能數值計算庫,支持自動微分、XLA編譯、TPU加速,非常適合大規模模型訓練,TPU v3是Google第三代張量處理硬件單元,每個pod是由多個TPU芯片組成的集群,可大規模并行訓練。DiT-L-2作為計算量最大的模型,在256個TPU芯片上協同訓練,全局批量大小為256時,每秒約5.7次迭代(一個批次數據完成一次前向+反向+參數更新)

DiT塊設計

4種:in-context(119.4Gflops)、cross-attention(137.6Gflops)、adaLN(118.6Gflops)、adaLN-zero(118.6Gflops)。FID結果如下

adaLN-Zero塊的FID最低,計算最高效,在400K訓練迭代后,近in-context模型的一半,說明該調節機制對模型質量有重要影響,相較于原始adaLN的性能提升也說明初始化方式的重要性,后續就采用adaLN-Zero了。

模型規模和分塊大小

12個DiT模型,4種模型配置(S、B、L、XL),3種分塊大小(8、4、2),表現如下,圓圈面積表示擴散模型的flops,注意到DiT-L和DiT-XL在Gflops上相對較近。可以看出隨著計算量增大(增大模型尺寸、減少分塊大小),FID所反映出的性能還是穩步提升的

或者從下圖中可以更明顯地看出模型規模和分塊大小對性能的影響(迭代12次的結果),要想降低FID,可以讓transformer更深更寬,或分塊尺寸更小

下圖進一步說明Gflops與FID間的強相關,顯示400K步訓練后的結果,總Gflops相近的,FID值就差不多,比如DiT-S/2和DiT-B/4

這種強相關性在其他指標上也有體現(采用MSE重建損失下微調的VAE解碼器,ft-MSE VAE decoder,更傾向于準確還原像素,而非追求主觀感知效果)

但是模型更大,訓練會更高效。下圖繪制不同模型的FID隨總訓練計算量的變化,其中訓練計算量估摸著為Gflops × 批量大小?× 訓練步數 × 3(反向傳遞的計算量算前向的兩倍)。可以看到,小模型即使訓練更久,FID也不見得比大模型低

另外類似地,除分塊大小外其它條件均相同的情況下,模型在相同訓練Gflops下也會有不同表現,例如同樣是大致訓練10^{10}Gflops,XL/4就比XL/2表現更好。

將可擴展性效果可視化,還是400K步訓練后,使用相同的起始噪聲、采樣噪聲和類別標簽,由12個DiT模型分別采樣出一張圖片,如下(差距還是蠻明顯的)

效果SOTA

最高Gflops的模型DiT-XL/2,在訓練7M步后,個別采樣效果圖如下(ImageNet 512×512和256×256)

在256×256分辨率下,和當時類條件生成模型相比,打敗了SOTA(StyleGAN-XL),同時將此前LDM達到的最好的FID-50K從3.60降到了2.27,在所有測試的cfg設置下,都取得了比LDM-4和LDM-8更高的召回值(Recall)

其計算(118.6Gflops)相較于潛空間的U-Net模型,例如LDM-4(103.6Gflops)更為高效,比像素空間U-Net模型,例如ADM(1120Gflops)或ADM-U(742Fflops)則要高效得多

再思考一個問題,小模型能否憑借更多步采樣勝過大模型?從下圖可以看到,增大采樣步數并不能彌補模型Gflops的不足,都是訓練了400K步,采樣步數為[16, 32, 64, 128, 256, 1000],針對每個采樣步數的設置,繪制FID及采樣Gflops。DiT-L/2采樣1000步(80.7 Tflops),也比不過DiT-XL/2采樣128步(至少少了5倍,15.2 Tflops),fid-0K為25.9 VS. 23.7。

run_DiT.ipynb

也許有必要講一下代碼庫中這個用來演示的記事本文件,它用預訓練好的DiT模型進行采樣。也許你用Google Colab打開會好看一點,這個在線Jupyter Notebook環境網址為:https://colab.research.google.com/,因為文件中有些Colab的記事本參數化功能(#@param),例如下圖中顯示的交互面板

開頭簡要介紹了一下:DiTs是用ImageNet訓練出來的“class-conditional latent”擴散模型,將DDPM框架中的U-Nets替換為了transformer,也就是說DiTs將類別標簽作為條件輸入,由VAE編碼器將輸入映射到潛在空間后,再由Transformer在潛空間進行處理,輸出由VAE解碼器映射回圖像空間。DiT在ImageNet這個基準測試集上打敗了當時所有先進的擴散模型。

下面是工程頁、HuggingFace、論文、GitHub的網址。

1. Setup

建議用GPU來運行這個記事本,在Google Colab中的操作步驟就是依次點擊Runtime > Change runtime type > Hardware accelerator > GPU,這樣Colab就會分配一臺帶GPU的虛擬機,可能資源不是很充足或者要收費,可能比較慢,用自己的服務器可能快些,anyway,自己選擇

運行下列代碼單元,克隆DiT的GitHub倉庫(下載的文件存儲在Google Drive中),安裝并配置PyTorch和所需依賴,在當前會話(session)中只需執行一次,但虛擬機是臨時的,一旦斷開或閑置超時,所有安裝的東西都會丟失

!git clone https://github.com/facebookresearch/DiT.git
# 開頭的 ! 表示在 Notebook 的系統 Shell 里執行命令,Colab默認當前工作目錄/content
import DiT, os
os.chdir('DiT')
# 將當前工作目錄切換到剛克隆的倉庫中,相當于 Shell 中的 cd
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
# 設置環境變量PYTHONPATH,告訴后續子進程去哪些路徑找包
!pip install diffusers timm --upgrade
'''
安裝/升級兩大依賴
diffusers:HF的擴散模型庫
timm:PyTorch圖像模型庫,提供很多視覺backbone/工具
'''
# DiT imports:
import torch
from torchvision.utils import save_image # 將一批Tensor圖保存為網格圖
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model # 下載/定位預訓練權重
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False) # 推理模型,關閉梯度
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":print("GPU not found. Using CPU instead.")

然后下載模型DiT-XL/2,你可以選分辨率為512×512還是256×256,也可以LDM VAE換掉

image_size = 256 #@param [256, 512]
# 下拉選擇框 256還是512
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# 選哪個Stable Diffusion的VAE,EMA版穩定更常用,用MSE目標微調過的版本像素更保真
latent_size = int(image_size) // 8
# SD的VAE編碼器會將分辨率下采樣8倍
# Load model:
model = DiT_XL_2(input_size=latent_size).to(device)
# 網絡規模最大號,分塊大小2
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
# 加載模型對應的預訓練權重
model.eval() # important!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)
# 從HF拉取預訓練VAE(首次會下載到緩存,之后直接讀取)

2. Sample from Pre-trained DiT Models

采樣時有多個可供用戶自定義的選項

文件中也提供了ImageNet中類別編號相應鏈接(the full list of ImageNet classes),點進去可以查到類別標簽“207”是“golden retriever”,即“金毛獵犬”

# Set user inputs:
seed = 0 #@param {type:"number"}
# 顯示數字框,由用戶選擇隨機種子
torch.manual_seed(seed)
# 使初始噪聲和每步采樣用到的隨機數可復現(相同用戶設置下,每次運行得到的圖片都一樣)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 2 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps)) # 將步數作字符串傳入# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
# 通道數為4來自Stable Diffusion的 VAE latent 約定
y = torch.tensor(class_labels, device=device)# Setup classifier-free guidance:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)# Sample images(一次前向里同時跑“有條件”和“無條件”兩路):
samples = diffusion.p_sample_loop(model.forward_with_cfg, z.shape, z, clip_denoised=False,model_kwargs=model_kwargs, progress=True, device=device
)
'''
model.forward_with_cfg:模型的前向函數,內部將前/后半(有/無條件)分開前向,再按cfg公式混合
z.shape:告知采樣器目標形狀
clip_denoised=False:不裁剪預測值
model_kwargs=model_kwargs:附加傳給model.forward_with_cfg的字典(里面有y和cfg_scale)
progress=True:打開進度條
'''
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples(后半無條件結果用于cfg輔助)
samples = vae.decode(samples / 0.18215).sample# Save and display images:
save_image(samples, "sample.png", nrow=int(samples_per_row),normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)

采樣出的圖片

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/92913.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/92913.shtml
英文地址,請注明出處:http://en.pswp.cn/web/92913.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL 索引:索引為什么使用 B+樹?(詳解B樹、B+樹)

文章目錄一、二叉查找樹(BST):不平衡二、平衡二叉樹(AVL):旋轉耗時三、紅黑樹:樹太高由一個例子總結索引的特點基于哈希表實現的哈希索引高效的查找方式:二分查找基于二分查找思想的二叉查找樹升級版的BST樹:AVL 樹四、…

ESP32入門開發·VScode空白項目搭建·點亮一顆LED燈

目錄 1. 環境搭建 2. 創建項目 3. 調試相關介紹 4. 代碼編寫 4.1 包含頭文件 4.2 引腳配置 4.3 設置輸出電平 4.4 延時函數 4.5 調試 1. 環境搭建 默認已經搭建好環境,如果未搭建好可參考: ESP32入門開發Windows平臺下開發環境的搭建…

ONLYOFFICE AI 智能體上線!與編輯器、新的 AI 提供商等進行智能交互

ONLYOFFICE AI 插件?迎來重要更新,帶來了新功能和更智能的交互體驗。隨著 AI 智能體(現為測試版)的上線、帶來更多 AI 提供商支持以及其他新功能,AI 插件已經成為功能強大的文檔智能助理。 關于 ONLYOFFICE ONLYOFFICE 文檔是多…

【C++進階學習】第十一彈——C++11(上)——右值引用和移動語義

前言: 前面我們已經將C的重點語法講的大差不差了,但是在C11版本之后,又出來了很多新的語法,其中有一些作用還是非常大的,今天我們就先來學習其中一個很重要的點——右值引用以及它所擴展的移動定義 目錄 一、左值引用和…

【IoTDB】363萬點/秒寫入!IoTDB憑何領跑工業時序數據庫賽道?

【作者主頁】Francek Chen 【專欄介紹】???大數據與數據庫應用??? 大數據是規模龐大、類型多樣且增長迅速的數據集合,需特殊技術處理分析以挖掘價值。數據庫作為數據管理的關鍵工具,具備高效存儲、精準查詢與安全維護能力。二者緊密結合&#xff0…

IEEE 2025 | 重磅開源!SLAM框架用“法向量+LRU緩存”,將三維重建效率飆升72%!

一、前言 當前研究領域在基于擴散模型的文本到圖像生成技術方面取得了顯著進展,尤其在視覺條件控制方面。然而,現有方法(如ControlNet)在組合多個視覺條件時存在明顯不足,主要表現為獨立控制分支在去噪過程中容易引入…

無人機遙控器教練模式技術要點

一、技術要點1.控制權仲裁機制:核心功能:清晰定義主控權歸屬邏輯(默認為學員,但教練隨時可接管)。切換方式:通常通過教練遙控器上的物理開關(瞬時或鎖定型)或軟件按鈕觸發。切換邏輯…

【跨服務器的數據自動化下載--安裝公鑰,免密下載】

跨服務器的數據自動化下載功能介紹:上代碼:發現好久沒寫csdn了,說多了都是淚~~ 以后會更新一些自動化工作的腳本or 小tricks,歡迎交流。分享一個最近在業務上寫的較為實用的自動化腳本,可以批量從遠端服務器下載指定數…

C++-->stl: list的使用

前言list的認識list是可以在固定時間(O(1))內在任意位置進行插入和刪除的序列式容器,并且該容器可以前后雙向迭代。 2. list的底層是雙向鏈表結構,雙向鏈表中每個元素存儲在互不相關的獨立節點中&#xff0…

本地WSL部署接入 whisper + ollama qwen3:14b 總結字幕

1. 實現功能 M4-1 接入 whisper ollama qwen3:14b 總結字幕 自動下載視頻元數據如果有字幕,只下載字幕使用 ollama 的 qwen3:14b 對字幕內容進行總結 2.運行效果 source /root/anaconda3/bin/activate ytdlp 🔍 正在提取視頻元數據… 📝 正在…

《Linux運維總結:Shell腳本高級特性之變量間接調用》

總結:整理不易,如果對你有幫助,可否點贊關注一下? 更多詳細內容請參考:Linux運維實戰總結 一、變量間接調用 在Shell腳本中,變量間接調用是一種高級特性,它允許你通過另一個變量的值來動態地訪問…

ABP VNext + Akka.NET:高并發處理與分布式計算

ABP VNext Akka.NET:高并發處理與分布式計算 🚀 用 Actor 模型把高并發寫入“分片→串行化”,把鎖與競態壓力轉回到代碼層面的可控順序處理;依托 Cluster.Sharding 橫向擴容,Persistence 宕機可恢復,Strea…

[激光原理與應用-250]:理論 - 幾何光學 - 透鏡成像的優缺點,以及如克服缺點

透鏡成像是光學系統中應用最廣泛的技術,其通過折射原理將物體信息轉換為圖像,但存在像差、環境敏感等固有缺陷。以下是透鏡成像的優缺點及針對性改進方案:一、透鏡成像的核心優點高效集光能力透鏡通過曲面設計將分散光線聚焦到一點&#xff0…

測試匠談 | AI語音合成之大模型性能優化實踐

「測試匠談」是優測云服務平臺傾心打造的內容專欄,匯集騰訊各大產品的頂尖技術大咖,為大家傾囊相授開發測試領域的知識技能與實踐,讓測試工作變得更加輕松高效。 本期嘉賓介紹 Soren,騰訊TEG技術事業群質量工程師,負責…

用天氣預測理解分類算法-從出門看天氣到邏輯回歸

一、生活中的決策難題:周末郊游的「天氣判斷」 周末計劃郊游時,你是不是總會打開天氣預報反復確認?看到 "25℃、微風、無雨" 就興奮收拾行李,看到 "35℃、暴雨" 就果斷取消計劃。這個判斷過程,其…

HTTPS服務

HTTPS服務 一、常見的端口 http ------ 80 明文 https ------ 443 數據加密 dns ------ 53 ssh ------ 22 telent ------ 23 HTTPS http ssl或者tls (安全模式) 二、原理: c(客戶端…

【Android筆記】Android 自定義 TextView 實現垂直漸變字體顏色(支持 XML 配置)

Android 自定義 TextView 實現垂直漸變字體顏色(支持 XML 配置) 在 Android UI 設計中,字體顏色的漸變效果能讓界面看起來更加精致與現代。常見的漸變有從左到右、從上到下等方向,但 Android 的 TextView 默認并不支持垂直漸變。…

CANopen Magic調試軟件使用

一、軟件安裝與硬件連接1.1 系統要求操作系統:Windows 7/10/11 (64位)硬件接口:支持Vector/PEAK/IXXAT等主流CAN卡推薦配置:4GB內存,2GHz以上CPU1.2 安裝步驟運行安裝包CANopen_Magic_Setup.exe選擇安裝組件(默認全選&…

前端css學習筆記3:偽類選擇器與偽元素選擇器

本文為個人學習總結,如有謬誤歡迎指正。前端知識眾多,后續將繼續記錄其他知識點! 目錄 前言 一、偽類選擇器 1.概念 2.動態選擇器(用戶交互) 3.結構偽類 :first-child:選擇所有兄弟元素的…

深入探索 PDF 數據提取:PyMuPDF 與 pdfplumber 的對比與實戰

在數據處理和分析領域,PDF 文件常常包含豐富的文本、表格和圖形信息。然而,從 PDF 中提取這些數據并非易事,尤其是當需要保留格式和顏色信息時。幸運的是,Python 社區提供了多個強大的庫來幫助我們完成這項任務,其中最…