【筆記】擴散模型(一一):Stable Diffusion XL 理論與實現

論文鏈接:SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis

官方實現:Stability-AI/generative-models

非官方實現:huggingface/diffusers

Stable Diffusion XL (SDXL) 是 Stablility AI 對 Stable Diffusion 進行改進的工作,主要通過一些工程化的手段提高了 SD 模型的生成能力。相比于 Stable Diffusion,SDXL 對模型架構、條件注入、訓練策略等都進行了優化,并且還引入了一個額外的 refiner,用于對生成圖像進行超分,得到高分辨率圖像。

Stable Diffusion XL

模型架構改進

SDXL 對模型的 VAE、UNet 和 text encoder 都進行了改進,下面依次介紹一下。

VAE

相比于 Stable Diffusion,SDXL 對 VAE 模型進行了重新訓練,訓練時使用了更大的 batchsize(256,Stable Diffusion 使用的則是 9),并且使用了指數移動平均,得到了性能更強的 VAE。

需要注意的是,SD 2.x 的 VAE 相對 SD 1.x 只對 decoder 進行了微調,而 encoder 不變。因此兩者的 latent space 是相同的,VAE 可以互相交換使用。但 SDXL 對 encoder 也進行了微調,因此 latent space 發生了變化,SDXL 不能用 SD 的 VAE,SD 也不能用 SDXL 的 VAE。另外,由于 SDXL 的 VAE 在 fp16 下會發生溢出,所以其必須在 fp32 類型下進行推理。

UNet

SDXL 使用了更大的 UNet 模塊,具體來說做了以下幾個變化:

  1. 為了提高效率,使用了更少的 3 個 stage,而不是 SD 使用的 4 個 stage;
  2. 將 transformer block 移動到更深的 stage,第一個 stage 沒有 transformer block;
  3. 使用更多的 transformer block。

詳情可以看下邊的表格,可以看到第二個 stage 和第三個 stage 分別使用了 2 個和 10 個 transformer block,最后 UNet 的整體參數量變成了大約 3 倍。

模型SDXLSD1.4/1.5SD 2.0/2.1
UNet 參數量2.6 B860 M865 M
Transformer block 數量[0, 2, 10][1, 1, 1, 1][1, 1, 1, 1]
通道倍增系數[1, 2, 4][1, 2, 4, 4][1, 2, 4, 4]

Text Encoder

SDXL 還使用了更強的 text encoder,其同時使用了 OpenCLIP ViT-bigG 和 OpenAI CLIP ViT-L,使用時同時用兩個 encoder 處理文本,并將倒數第二層特征拼接起來,得到一個 1280+768=2048 通道的文本特征作為最終使用的文本嵌入。

除此之外,SDXL 還使用 OpenCLIP ViT-bigG 的 pooled text embedding 映射到 time embedding 維度并與之相加,作為輔助的文本條件注入。

其與 SD 1.x 和 2.x 的比較如下表所示:

模型SDXLSD 1.4/1.5SD 2.0/2.1
文本編碼器OpenCLIP ViT-bigG & CLIP ViT-LCLIP ViT-LOpenCLIP ViT-H
特征通道數20487681024
Pooled text emb.OpenCLIP ViT-bigGN/AN/A

Refine Model

除了上述結構變化之外,SDXL 還級聯了一個 refine model 用來細化模型的生成結果。這個 refine model 相當于一個 img2img 模型,在模型中的位置如下所示:

SDXL 整體架構,refine model 級聯在基礎模型的后方

這個 refine model 的主要目的是進一步提高圖像的生成質量。其是單獨訓練的,專注于對高質量高分辨率數據的學習,并且只在比較低的 noise level 上(即前 200 個時間步)進行訓練。

在推理階段,首先從 base model 完成正常的生成過程,然后再加一些噪音用 refine model 進一步去噪。這樣可以使圖像的細節得到一定的提升。

Refine model 的結構與 base model 有所不同,主要體現在以下幾個方面:

  1. Refine model 使用了 4 個 stage,特征維度采用了 384(base model 為 320);
  2. Transformer block 在各個 stage 的數量為 [0, 4, 4, 0],最終參數量為 2.3 B,略小于 base model;
  3. 條件注入方面:text encoder 只使用了 OpenCLIP ViT-bigG;并且同樣也使用了尺寸和裁剪的條件注入(這個下文會講);除此之外還使用了 aesthetic score 作為條件。

條件注入的改進

SDXL 引入了額外的條件注入來改善訓練過程中的數據處理問題,主要包括圖像尺寸圖像裁剪問題。

圖像尺寸條件

Stable Diffusion 的訓練通常分為多個階段,先在 256 分辨率的數據上進行訓練,再在 512 分辨率的數據上進行訓練,每次訓練時需要過濾掉尺寸小于訓練尺寸的圖像。根據統計,如果直接丟棄所有分辨率不足 256 的圖像,會浪費大約 40% 的數據。如果不希望丟棄圖像,可以使用超分辨率模型先將圖像超分到所需的分辨率,但這樣會導致數據質量的降低,影響訓練效果。

為了解決這個問題,作者加入了一種額外的圖像尺寸條件注入。作者將原圖的寬高進行傅立葉編碼,然后將特征拼接起來加到 time embedding 上作為額外條件。

在訓練時直接使用原圖的寬高作為條件,推理的時候可以自定義寬高,生成不同質量的圖像,下面的圖是一個例子,可以看到當以較小的尺寸為條件時,生成的圖比較模糊,反之則清晰且細節更豐富:

SDXL 中的 size conditioning

圖像裁剪條件

在 SD 訓練時使用的是固定尺寸(例如 512x512),使用時需要對原圖進行處理。一般的處理流程是先 resize 到短邊為目標尺寸,然后沿著長邊進行裁剪。這種裁剪會導致圖像部分缺失的問題,例如生成的圖像部分會出現部分缺失,就是因為裁剪后的數據是不完整的。

為了解決這個問題,SDXL 在訓練時把裁剪位置的坐標也當作條件注入進來。具體做法是把左上角的像素坐標值也進行傅立葉編碼+拼接,再加到 time embedding 上,這樣模型就能得知使用的數據是在什么位置進行的裁剪。

在推理階段,只需要將這個條件設置為 (0, 0) 即可得到正常圖像。如果設置成其他的值則能得到裁剪圖像,例如下邊圖里的效果。(感覺還是很神奇的,竟然這種條件能 work,而且沒有和圖像尺寸的條件混淆)

SDXL 中的 crop conditioning

訓練策略的改進

多比例微調

在訓練階段使用的都是正方形圖像,但是現實圖像很多都是有一定的長寬比的圖像。因此在訓練后的微調階段,還使用了一種多比例微調的策略。

具體來說,這種方法預先將訓練集按照圖像長寬比不同分成多個 bucket,在微調時每次隨機選取一個 bucket,并從中采樣出一個 batch 的數據進行訓練。在原論文中給出了一個表格,從表中可以看到選取的長寬比從 0.25(對應 512×2048512\times2048512×2048 分辨率) 到 4(對應 2048×5122048\times5122048×512 分辨率)不等,并且總像素數基本都維持在 102421024^210242 左右。

在這個微調階段,除了使用尺寸和裁剪條件注入,還使用了 bucket size(也就是生成的目標大小) 作為一個條件,用相同的方式進行了注入。經過這樣的條件注入和微調,模型就能生成多種長寬比的圖像。

Noise Offset

在多比例微調的階段,SDXL 還使用了一種 noise offset 的方法,來解決 SD 只能生成中等亮度圖像、而無法生成純黑或者純白圖像的問題。這個問題出現的原因是在訓練和采樣階段之間存在一定的 bias,訓練時在最后一個時間步的時候實際上是沒有加噪的,所以會出現一些問題,解決方案也比較簡單,就是在訓練的時候給噪聲添加一個比較小的 offset 即可。

SDXL 代碼解析

這里依然以 diffusers 提供的訓練代碼為主進行分析,模型架構的改變主要體現在加載的預訓練模型中(之后應該會出一期怎么看 huggingface 里的那些文件以及 config.json 的教程),這里主要分析一下各種條件注入和訓練策略是怎么實現的。

各種條件注入

首先是尺寸和裁剪的條件注入,在圖像進行預處理的階段就記錄下了每張圖的原始尺寸以及裁剪位置:

def preprocess_train(examples):images = [image.convert("RGB") for image in examples[image_column]]original_sizes = []all_images = []crop_top_lefts = []for image in images:# 在這里記錄原始尺寸original_sizes.append((image.height, image.width))# 調整圖片大小image = train_resize(image)# 以 0.5 的概率進行隨機翻轉if args.random_flip and random.random() < 0.5:image = train_flip(image)# 進行裁剪if args.center_crop:y1 = max(0, int(round((image.height - args.resolution) / 2.0)))x1 = max(0, int(round((image.width - args.resolution) / 2.0)))image = train_crop(image)else:y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))image = crop(image, y1, x1, h, w)# 在這里記錄裁剪位置crop_top_left = (y1, x1)crop_top_lefts.append(crop_top_left)image = train_transforms(image)all_images.append(image)examples["original_sizes"] = original_sizesexamples["crop_top_lefts"] = crop_top_leftsexamples["pixel_values"] = all_imagesreturn examples

隨后原始尺寸和裁剪位置被進行編碼,可以看到下邊這部分包含了三部分的條件注入:

def compute_time_ids(original_size, crops_coords_top_left):target_size = (args.resolution, args.resolution)# 包括三部分的條件注入,分別為:# 1. 原始尺寸;2. 裁剪位置;3. 目標尺寸add_time_ids = list(original_size + crops_coords_top_left + target_size)add_time_ids = torch.tensor([add_time_ids])add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)return add_time_idsadd_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
)

最后把 pooled prompt embedding 也加入進來:

unet_added_conditions = {"time_ids": add_time_ids}
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})

這樣四種條件注入就都準備好了,在 forward 時直接傳到 UNet 的 added_cond_kwargs 參數即可參與計算。這些參數在 get_aug_embed 中被組合起來添加到 time embedding 上:

# pooled text embedding
text_embeds = added_cond_kwargs.get("text_embeds")
# 1. 原始尺寸;2. 裁剪位置;3. 目標尺寸
time_ids = added_cond_kwargs.get("time_ids")
# 處理得到最終加到 time embedding 上的條件嵌入
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

在上邊的代碼里用到了兩個對象 self.add_time_projself.add_embedding,定義為:

self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

這兩個對象中,Timesteps 應該是負責傅立葉編碼,TimestepEmbedding 則負責對編碼后的結果進行嵌入。

Noise Offset

這個實現很簡單,就在加噪前對 noise 隨機偏移一下即可:

if args.noise_offset:# https://www.crosslabs.org//blog/diffusion-with-offset-noisenoise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device)

多尺度微調

根據我的觀察,diffusers 里沒有直接提供多尺度微調相關的代碼,應該是默認在訓練之前已經自行處理好了各個 bucket 的圖像。印象中前段時間某個組織開源了一份分 bucket 的代碼,不過因為當時沒保存所以現在找不到了,能找到的主要是 kohya-ss/sd-scripts 的一個實現。

大體的原理是先創建一系列桶,然后對于每張圖片,選擇長寬比最接近的一個桶,然后進行裁剪,裁剪到和這個桶對應的分辨率相同。由于相鄰兩個桶之間的分辨率之差為 64,所以最多裁剪 32 像素,對訓練的影響并不大。在將圖片分桶之后,則可以按照每個桶的數據比例作為概率進行采樣。如果某些桶中的數據量不足一個 batch,則把這個桶中的數據都放入一個公共桶中,并以標準的 1024×10241024\times10241024×1024 分辨率進行訓練。

如果讀者有興趣自己閱讀代碼,可以先看 library.model_util 模塊中的 make_bucket_resolutions,這個方法創建了一系列分辨率的 bucket,并在 library.train_util.BucketManager 中調用,用來創建 bucket。這個 BucketManager 提供了一個方法 select_bucket,用來為某個特定分辨率的圖像選擇 bucket。最后在 library.train_util.BaseDataset 中,會對每張圖片調用 select_bucket 選擇 bucket,再將對應的圖片加入到選擇的 bucket 中。

總結

感覺 SDXL 是一個比較工程的工作,尤其是對模型架構的修改,比較大力出奇跡。除此之外感覺對數據的理解還是很重要的,除了修改模型架構之外的其他工作都是圍繞著數據展開的,這也是比較值得學習的思路。

參考資料:

  1. 深入淺出完整解析Stable Diffusion XL(SDXL)核心基礎知識
  2. 擴散模型(七)| SDXL

本文原文以 CC BY-NC-SA 4.0 許可協議發布于 筆記|擴散模型(一一):Stable Diffusion XL 理論與實現,轉載請注明出處。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/93637.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/93637.shtml
英文地址,請注明出處:http://en.pswp.cn/web/93637.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

學習安卓APP開發,10年磨一劍,b4a/Android Studio

學習安卓APP開發 記得上次學APP都是在2016年前了&#xff0c;一晃就過去了10年。 當時用ANDROID studio打開一個空項目和編繹分別用了300秒&#xff0c;一下就用了10分鐘。 后來買了一臺一萬多的電腦&#xff0c;CPU換成了I5 8600K 4.2GHZ*6核&#xff0c;再加上M2固態硬盤。 編…

調試技巧(vs2022 C語言)

調試之前首先要保證我們的腦袋是清晰的&#xff0c;我們調試的過程主要是看代碼有沒有按照我們的想法去運行調試最常使用的幾個快捷鍵F5啟動調試&#xff0c;經常用來直接跳到下一個斷點處&#xff08;F5通常和F9配合使用&#xff0c;打了斷點按F5程序可以直接運行到斷點處&…

MySQL深度理解-Innodb底層原理

1.MySQL的內部組件結構大體來說&#xff0c;MySQL可以分為Server層和存儲引擎層兩部分。2.Server層Server層主要包括連接器、查詢緩存、分析器、優化器和執行器等&#xff0c;涵蓋MySQL的大多數核心服務功能&#xff0c;以及所有的內置函數&#xff08;如日期、時間、數據和加密…

QFtp在切換目錄、上傳文件、下載文件、刪除文件等一系列操作時,如何按照預期操作指令順序執行

FTP服務初始化時&#xff0c;考慮到重連、以及commandFinished信號信號執行完成置m_bCmdFinished 為true; void ICore::connectFtpServer() {if(g_pFile nullptr){g_pFile new QFile;}if(g_pFtp){g_pFtp->state();g_pFtp->abort();g_pFtp->deleteLater();g_pFtp n…

JavaSE高級-02

文章目錄1. 多線程1.1 創建線程的三種方式多線程的創建方式一&#xff1a;繼承Thread類多線程的創建方式二&#xff1a;實現Runnable接口多線程的創建方式三&#xff1a;實現Callable接口三種線程的創建方式對比Thread的常用方法1.2 線程安全線程同步方式一&#xff1a;同步代碼…

從舒適度提升到能耗降低再到安全保障,樓宇自控作用關鍵

在現代建筑的發展歷程中&#xff0c;樓宇自動化控制系統&#xff08;BAS&#xff09;已從單純的設備管理工具演變為集舒適度優化、能耗控制與安全保障于一體的核心技術。隨著物聯網和人工智能的深度應用&#xff0c;樓宇自控系統正以數據為紐帶&#xff0c;重構人與建筑的關系。…

圖像分類精度評價的方法——誤差矩陣、總體精度、用戶精度、生產者精度、Kappa 系數

本文詳細介紹 “圖像分類精度評價的方法”。 圖像分類后&#xff0c;需要評估分類結果的準確性&#xff0c;以判斷分類器的性能和結果的可靠性。 常涉及到下面幾個概念&#xff08;指標&#xff09; 誤差矩陣、總體精度、用戶精度、生產者精度和 Kappa 系數。1. 誤差矩陣&#…

【科普向-第一篇】數字鑰匙生態全景:手機廠商、車廠與協議之爭

目錄 一、協議標準之爭&#xff1a;誰制定規則&#xff0c;誰掌控入口 1.1 ICCE&#xff1a;中國車企主導的自主防線 1.2 ICCOA&#xff1a;手機廠商的生態突圍 1.3 CCC&#xff1a;國際巨頭的高端壁壘 1.4 協議對比 二、底層技術路線&#xff1a;成本與安全的博弈 2.1B…

dockerfile及docker常用操作

1: docker 編寫 Dockerfile 是用于構建 Docker 鏡像的文本文件&#xff0c;包含一系列指令和參數&#xff0c;用于定義鏡像的構建過程 以下是關鍵要點&#xff1a; 一、基本結構 ?FROM?&#xff1a;必須作為第一條指令&#xff0c;指定基礎鏡像&#xff08;如 FROM python:3.…

[vibe coding-lovable]lovable是不是ai界的復制忍者卡卡西?

在火影忍者的世界里&#xff0c;卡卡西也被稱為復制忍者&#xff0c;因為大部分忍術都可以被其Copy! 截圖提示:實現這個效果 -> 發給Lovalbe -> 生成的的效果如下&#xff0c;雖然不是1比1還原&#xff0c;但是這個效果也很驚艷。 這個交互設計&#xff0c;這個UI效果&am…

技術賦能安全:智慧工地構建城市建設新防線

城市建設的熱潮中&#xff0c;工地安全始終是關乎生命與發展的核心議題。江西新余火災等事故的沉痛教訓&#xff0c;暴露了傳統工地監管的諸多短板——流動焊機“行蹤難覓”&#xff0c;無證動火作業屢禁不止&#xff0c;每一次監管缺位都可能引發災難性后果。如今&#xff0c;…

Sublime Text 代碼編輯器(Mac中文)

原文地址&#xff1a;Sublime Text Mac 代碼編輯器 sublime text Mac一款輕量級的文本編輯器&#xff0c;擁有豐富的功能和插件。 它支持多種編程語言&#xff0c;包括C、Java、Python、Ruby等&#xff0c;可以幫助程序員快速編寫代碼。 Sublime Text的界面簡潔、美觀&#…

如何制定項目時間線,合理預計?

制定一份現實可行且行之有效的項目時間線&#xff0c;是一個系統性的分解、估算與排序過程&#xff0c;而非簡單的日期羅列。核心步驟包括&#xff1a;明確項目范圍與可交付成果、利用工作分解結構&#xff08;WBS&#xff09;進行任務拆解、科學估算各項任務的持續時間、識別并…

RSA詳解

一、RSA 簡介RSA 是一種公鑰密碼體制&#xff0c;由羅納德?李維斯特&#xff08;Ron Rivest&#xff09;、阿迪?薩莫爾&#xff08;Adi Shamir&#xff09;和倫納德?阿德曼&#xff08;Leonard Adleman&#xff09;于 1977 年提出&#xff0c;算法名稱由他們三人姓氏的首字母…

Linux獲取物理硬盤總容量

獲取物理硬盤總容量: 1.查看單個硬盤: 使用 lsblk 或 fdisk -l (需要 sudo) 命令。它們會直接列出物理硬盤 (sda, nvme0n1 等) 和它們的分區,并顯示硬盤的總物理容量。 abcd四塊物理盤,只掛載使用3塊,留一塊未使用 最常見的原因通常是配置了熱備盤(RAID 1/5/6/10 等冗余…

STM32學習筆記14-I2C硬件控制

I2C外設簡介STM32內部集成了硬件I2C收發電路&#xff08;硬件收發器&#xff1a;自動生產波形&#xff0c;自動翻轉電平等&#xff09;&#xff0c;可以由硬件自動執行時鐘生成、起始終止條件生成、應答位收發、數據收發等功能&#xff0c;減輕CPU的負擔——軟件只需要寫入控制…

電子電氣架構 --- 軟件開發數字化轉型

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 簡單,單純,喜歡獨處,獨來獨往,不易合同頻過著接地氣的生活,除了生存溫飽問題之外,沒有什么過多的欲望,表面看起來很高冷,內心熱情,如果你身…

我國空間站首次應用專業領域 AI大模型

據中國載人航天工程辦公室消息&#xff0c;北京時間2025年8月15日22時47分&#xff0c;經過約6.5小時的出艙活動&#xff0c;神舟二十號乘組航天員陳冬、陳中瑞、王杰密切協同&#xff0c;在空間站機械臂和地面科研人員的配合支持下&#xff0c;圓滿完成既定任務&#xff0c;出…

WPF真入門教程35--手搓WPF出真汁【蜀味正道CS版】

1、項目介紹 本項目采用多層架構設計&#xff0c;使用wpf&#xff0c;Panuon.UI.Silver控件庫&#xff0c;AduSkin皮膚&#xff0c;MVVM等技術開發具有復雜交互和視覺效果的CS應用程序。WPF適用于企業級桌面應用&#xff1a;如ERP、CRM系統&#xff0c;需復雜表單和報表。WPF適…

JMeter與大模型融合應用之構建AI智能體:評審性能測試腳本

JMeter與大模型融合應用之構建AI智能體&#xff1a;評審性能測試腳本 一、引言 隨著DevOps和持續測試的普及&#xff0c;性能測試已成為軟件開發生命周期中不可或缺的環節。Apache JMeter作為最流行的開源性能測試工具之一&#xff0c;被廣泛應用于各種性能測試場景。然而&…