【DLI】Generative AI with Diffusion Models通關秘籍

Generative AI with Diffusion Models,加載時間在20分鐘左右,耐心等待。
在這里插入圖片描述
在這里插入圖片描述

6.2TODO

在這里插入圖片描述

這里是在設置擴散模型的參數,代碼里的FIXME部分需要根據上下文進行替換。以下是各個FIXME的替換說明:
1.a_bar 是 a 的累積乘積,在 PyTorch 里可以用 torch.cumprod 實現。
2.sqrt_a_bar、sqrt_one_minus_a_bar 和 sqrt_a_inv 都是對輸入張量求平方根,可使用 torch.sqrt 實現。
3.pred_noise_coeff 中的 FIXME(1 - a_bar) 同樣是求平方根,用 torch.sqrt 即可。
以下是替換后的代碼:

nrows = 10
ncols = 15T = nrows * ncols
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)a = 1.0 - B
a_bar = torch.cumprod(a, dim=0)
sqrt_a_bar = torch.sqrt(a_bar)  # Mean Coefficient
sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar)  # St. Dev. Coefficient# Reverse diffusion variables
sqrt_a_inv = torch.sqrt(1 / a)
pred_noise_coeff = (1 - a) / torch.sqrt(1 - a_bar)  # Predicted Noise Coefficient

在擴散模型里,正向擴散過程 q 函數是按照如下公式把原始圖像 x_0 逐步添加噪聲變成 x_t 的
在這里插入圖片描述
FIXME 部分應該分別用 sqrt_a_bar_t 和 sqrt_one_minus_a_bar_t 來替換。
在這個 q 函數中,按照擴散模型的正向過程公式,把原始圖像 x_0 和隨機噪聲 noise 按一定比例組合,從而得到加噪后的圖像 x_t。

def q(x_0, t):t = t.int()noise = torch.randn_like(x_0)sqrt_a_bar_t = sqrt_a_bar[t, None, None, None]sqrt_one_minus_a_bar_t = sqrt_one_minus_a_bar[t, None, None, None]x_t = sqrt_a_bar_t * x_0 + sqrt_one_minus_a_bar_t * noisereturn x_t, noise

在反向擴散過程中,我們要根據當前的潛在圖像,當前時間步 , 以及預測的噪聲 來恢復上一個時間步的圖像。在這里插入圖片描述
在這個 reverse_q 函數中,我們根據反向擴散過程的公式,從當前的潛在圖像和預測的噪聲中恢復上一個時間步的圖像。如果當前時間步為 0,則表示反向擴散過程完成。否則,我們會添加一些噪聲以模擬擴散過程。下面是對代碼中 FIXME 部分的分析與替換:

@torch.no_grad()
def reverse_q(x_t, t, e_t):t = t.int()pred_noise_coeff_t = pred_noise_coeff[t]sqrt_a_inv_t = sqrt_a_inv[t]u_t = sqrt_a_inv_t * (x_t - pred_noise_coeff_t * e_t)if t[0] == 0:  # All t values should be the samereturn u_t  # Reverse diffusion complete!else:B_t = B[t - 1]  # Apply noise from the previous timestepnew_noise = torch.randn_like(x_t)return u_t + torch.sqrt(B_t) * new_noise

在這里插入圖片描述

6.3TODO

在這里插入圖片描述

每個類的功能來添加正確模塊名 依次改寫FIXME 即可:

DownBlock進行下采樣操作,包含卷積和池化相關的塊
EmbedBlock將輸入進行線性變換和激活
GELUConvBlock使用了卷積、組歸一化和 GELU 激活函數,通常是一個卷積塊
RearrangePoolBlock使用了 Rearrange 進行張量重排和卷積操作
ResidualConvBlock使用了兩個卷積塊并進行了殘差連接
SinusoidalPositionEmbedBlock實現了正弦位置嵌入的功能
UpBlock上采樣操作,包含轉置卷積和卷積塊

6.4TODO

在這個 get_context_mask 函數里,其目的是隨機丟棄上下文信息。要實現隨機丟棄,通常會使用 torch.bernoulli 函數。torch.bernoulli 函數會依據給定的概率來生成一個二進制掩碼張量,其中每個元素為 1 的概率就是傳入的概率值。
在這個函數中,我們希望以 drop_prob 的概率丟棄上下文,所以每個元素保留的概率是 1 - drop_prob。因此,FIXME 處應該填入 bernoulli。

def get_context_mask(c, drop_prob):c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device)c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device)return c_hot, c_mask

代碼解釋:
c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device):將輸入的 c 轉換為獨熱編碼向量,并且移動到指定的設備(如 GPU)上。
c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device):生成一個與 c_hot 形狀相同的二進制掩碼張量,每個元素以 1 - drop_prob 的概率為 1,以 drop_prob 的概率為 0。
return c_hot, c_mask:返回獨熱編碼向量和二進制掩碼張量。
這樣,你就可以使用這個函數來隨機丟棄上下文信息了。

在這里插入圖片描述

在擴散模型里,通常采用均方誤差損失(Mean Squared Error Loss,MSE)來衡量預測噪聲 noise_pred 和實際添加的噪聲 noise 之間的差異。因為均方誤差能夠很好地衡量兩個向量之間的平均平方誤差,這對于擴散模型中預測噪聲的準確性評估是很合適的。
在 PyTorch 中,nn.functional.mse_loss 函數可用于計算均方誤差損失。所以 FIXME 處應填入 mse_loss。

def get_loss(model, x_0, t, *model_args):x_noisy, noise = q(x_0, t)noise_pred = model(x_noisy, t/T, *model_args)return F.mse_loss(noise, noise_pred)

代碼解釋
x_noisy, noise = q(x_0, t):調用 q 函數給原始圖像 x_0 添加噪聲,得到加噪后的圖像 x_noisy 以及實際添加的噪聲 noise。
noise_pred = model(x_noisy, t/T, *model_args):把加噪后的圖像 x_noisy 和歸一化后的時間步 t/T 輸入到模型 model 中,得到模型預測的噪聲 noise_pred。
return F.mse_loss(noise, noise_pred):使用 F.mse_loss 函數計算實際噪聲 noise 和預測噪聲 noise_pred 之間的均方誤差損失并返回。
通過使用均方誤差損失,模型能夠學習到如何更準確地預測添加到圖像中的噪聲,從而在反向擴散過程中更好地恢復原始圖像。

下一個 TODO

  1. c_drop_prob 的設置
    c_drop_prob 是上下文丟棄概率,一般在訓練過程中會采用線性衰減策略,也就是在訓練初期以較高概率丟棄上下文,隨著訓練的推進逐漸降低丟棄概率。在代碼中,我們可以簡單地將其設置為一個隨著訓練輪數逐漸降低的值。
  2. get_context_mask 函數的輸入
    get_context_mask 函數需要一個上下文標簽作為輸入,在代碼里這個標簽應該從 batch 中獲取。通常假設 batch 的第二個元素為上下文標簽。

optimizer = Adam(model.parameters(), lr=0.001)
epochs = 5
preview_c = 0model.train()
for epoch in range(epochs):# 線性衰減上下文丟棄概率c_drop_prob = max(0.1, 1 - epoch / epochs)  #這里我調整了順序for step, batch in enumerate(dataloader):optimizer.zero_grad()t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()x = batch[0].to(device)# 假設 batch 的第二個元素是上下文標簽c = batch[1].to(device)c_hot, c_mask = get_context_mask(c, c_drop_prob)loss = get_loss(model, x, t, c_hot, c_mask)loss.backward()optimizer.step()if epoch % 1 == 0 and step % 100 == 0:print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {preview_c}")c_drop_prob = 0  # Do not drop context for previewc_hot, c_mask = get_context_mask(torch.Tensor([preview_c]).to(device), c_drop_prob)sample_images(model, IMG_CH, IMG_SIZE, ncols, c_hot, c_mask)preview_c = (preview_c + 1) % N_CLASSES

代碼解釋
c_drop_prob 的設置:運用線性衰減策略,在訓練初期 c_drop_prob 為 0.9,隨著訓練的推進逐漸降低到 0.1。
get_context_mask 函數的輸入:假設 batch 的第二個元素是上下文標簽,將其傳入 get_context_mask 函數。
訓練過程:在每個訓練步驟中,先將梯度清零,接著計算損失,再進行反向傳播和參數更新。每訓練 100 個步驟,就打印一次損失信息并進行一次樣本生成。
通過這些修改,代碼就能正常運行,從而開始訓練模型。
在這里插入圖片描述

6.5TODO

在擴散模型的采樣過程中,為了給擴散過程添加權重,一般會根據給定的權重 w 對保留上下文的預測噪聲 e_t_keep_c 和丟棄上下文的預測噪聲 e_t_drop_c 進行加權組合。在這里插入圖片描述
在代碼中,FIXME 處應該根據上述公式進行計算,將 e_t_keep_c 和 e_t_drop_c 按照權重 w 進行組合。具體的代碼如下:

def sample_w(model, c, w):input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)n_samples = len(c)w = torch.tensor([w]).float()w = w[:, None, None, None].to(device)  # Make w broadcastablex_t = torch.randn(n_samples, *input_size).to(device)# One c for each wc = c.repeat(len(w), 1)# Double the batchc = c.repeat(2, 1)# Don't drop context at test timec_mask = torch.ones_like(c).to(device)c_mask[n_samples:] = 0.0x_t_store = []for i in range(0, T)[::-1]:# Duplicate t for each samplet = torch.tensor([i]).to(device)t = t.repeat(n_samples, 1, 1, 1)# Double the batchx_t = x_t.repeat(2, 1, 1, 1)t = t.repeat(2, 1, 1, 1)# Find weighted noisee_t = model(x_t, t, c, c_mask)e_t_keep_c = e_t[:n_samples]e_t_drop_c = e_t[n_samples:]e_t = w * e_t_keep_c + (1 - w) * e_t_drop_c# Deduplicate batch for reverse diffusionx_t = x_t[:n_samples]t = t[:n_samples]x_t = reverse_q(x_t, t, e_t)return x_t

## TODO

在擴散模型里,權重 w 可用于控制上下文信息在生成過程中的影響程度。w 值越接近 1,生成結果就越依賴上下文信息;w 值越接近 0,生成結果受上下文信息的影響就越小。若要讓生成的數字能夠被持續識別,你可以試著增大 w 的值,以此增強上下文信息對生成過程的影響。
下面是修改后的代碼,你可以調整 w 的值來觀察生成結果:

model.eval()
w = 5.0  # 可以嘗試不同的值,通常大于 1 能增強上下文的影響
c = torch.arange(N_CLASSES).to(device)
c_drop_prob = 0 
c_hot, c_mask = get_context_mask(c, c_drop_prob)x_0 = sample_w(model, c_hot, w)
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES))

代碼解釋
w = 5.0:把 w 的值設為 5.0,你可以根據實際情況調整這個值。通常,當 w 大于 1 時,上下文信息的影響會得到增強,這樣生成的數字可能會更易于識別。
x_0 = sample_w(model, c_hot, w):調用 sample_w 函數生成圖像,將 w 作為參數傳入。
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES)):把生成的圖像轉換為可視化的形式。
你可以多次運行這段代碼,并且調整 w 的值,直到生成的數字能夠被穩定識別。

至此結束。
在這里插入圖片描述

完整代碼都在圖片里

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/75483.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/75483.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/75483.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

如何在本地部署魔搭上千問Qwen2.5-VL-32B-Instruct-AWQ模型在顯卡1上面運行推理,并開啟api服務

環境: 云服務器Ubuntu NVIDIA H20 96GB Qwen2.5-VL-32B Qwen2.5-VL-72B 問題描述: 如何在本地部署魔搭上千問Qwen2.5-VL-32B-Instruct-AWQ模型在顯卡1上面運行推理,并開啟api服務 解決方案: 1.環境準備 硬件要求 顯卡1(顯存需≥48GB,推薦≥64GB)CUDA 11.7或更高…

基于方法分類的無監督圖像去霧論文

在之前的博客中,我從研究動機的角度對無監督圖像去霧論文進行了分類,而現在我打算根據論文中提出的方法進行新的分類。 1. 基于對比學習的方法 2022年 論文《UCL-Dehaze: Towards Real-world Image Dehazing via Unsupervised Contrastive Learning》&a…

4月3號.

JDK7前時間相關類: 時間的相關知識: Data時間類: //1.創建對象表示一個時間 Date d1 new Date(); //System.out.println(d1);//2.創建對象表示一個指定的時間 Date d2 new Date(0L); System.out.println(d2);//3.setTime修改時間 //1000毫秒1秒 d2.setTime(1000L); System.o…

數據結構與算法:子數組最大累加和問題及擴展

前言 子數組最大累加和問題看似簡單,但能延伸出的題目非常多,千題千面,而且會和其他算法結合出現。 一、最大子數組和 class Solution { public:int maxSubArray(vector<int>& nums) {int n=nums.size();vector<int>dp(n);//i位置往左能延伸出的最大累加…

MIT6.828 Lab3-2 Print a page table (easy)

實驗內容 實現一個函數來打印頁表的內容&#xff0c;幫助我們更好地理解 xv6 的三級頁表結構。 修改內容 kernel/defs.h中添加函數聲明&#xff0c;方便其它函數調用 void vmprint(pagetable_t);// lab3-2 Print a page tablekernel/vm.c中添加函數具體定義 采用…

2025高頻面試設計模型總結篇

文章目錄 設計模型概念單例模式工廠模式策略模式責任鏈模式 設計模型概念 設計模式是前人總結的軟件設計經驗和解決問題的最佳方案&#xff0c;它們為我們提供了一套可復用、易維護、可擴展的設計思路。 &#xff08;1&#xff09;定義&#xff1a; 設計模式是一套經過驗證的…

Java基礎:面向對象進階(二)

01-static static修飾成員方法 static注意事項&#xff08;3種&#xff09; static應用知識&#xff1a;代碼塊 static應用知識&#xff1a;單列模式 02-面向對象三大特征之二&#xff1a;繼承 什么是繼承&#xff1f; 使用繼承有啥好處? 權限修飾符 單繼承、Object類 方法重…

Spring框架如何做EhCache緩存?

在Spring框架中&#xff0c;緩存是一種常見的優化手段&#xff0c;用于減少對數據庫或其他資源的訪問次數&#xff0c;從而提高應用性能。Spring提供了強大的緩存抽象&#xff0c;支持多種緩存實現&#xff08;如EhCache、Redis、Caffeine等&#xff09;&#xff0c;并可以通過…

NVIDIA顯卡

NVIDIA顯卡作為全球GPU技術的標桿&#xff0c;其產品線覆蓋消費級、專業級、數據中心、移動計算等多個領域&#xff0c;技術迭代貫穿架構創新、AI加速、光線追蹤等核心方向。以下從技術演進、產品矩陣、核心技術、生態布局四個維度展開深度解析&#xff1a; 一、技術演進&…

【BUG】生產環境死鎖問題定位排查解決全過程

目錄 生產環境死鎖問題定位排查解決過程0. 表面現象1. 問題分析&#xff08;1&#xff09;數據庫連接池資源耗盡&#xff08;2&#xff09;數據庫鎖競爭(3) 代碼實現問題 2. 分析解決(0) 分析過程&#xff08;1&#xff09;優化數據庫連接池配置&#xff08;2&#xff09;優化數…

【計算機網絡應用層】

文章目錄 計算機網絡應用層詳解一、前言二、應用層的功能三、常見的應用層協議1. HTTP/HTTPS&#xff08;超文本傳輸協議&#xff09;2. DNS&#xff08;域名系統&#xff09;3. FTP&#xff08;文件傳輸協議&#xff09;4. SMTP/POP3/IMAP&#xff08;電子郵件協議&#xff09…

Linux 虛擬化方案

一、Linux 虛擬化技術分類 1. 全虛擬化 (Full Virtualization) 特點&#xff1a;Guest OS 無需修改&#xff0c;完全模擬硬件 代表技術&#xff1a; KVM (Kernel-based Virtual Machine)&#xff1a;主流方案&#xff0c;集成到 Linux 內核 QEMU&#xff1a;硬件模擬器&…

樹莓派 5 換清華源

首先備份原設置 cp /etc/apt/sources.list ~/sources.list.bak cp /etc/apt/sources.list.d/raspi.list ~/raspi.list.bak修改配置 /etc/apt/sources.list 文件替換內容如下&#xff08;原內容刪除&#xff09; deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm …

WGAN原理及實現(pytorch版)

WGAN原理及實現 一、WGAN原理1.1 原始GAN的缺陷1.2 Wasserstein距離的引入1.3 Kantorovich-Rubinstein對偶1.4 WGAN的優化目標1.4 數學推導步驟1.5 權重裁剪 vs 梯度懲罰1.6 優勢1.7 總結 二、WGAN實現2.1 導包2.2 數據加載和處理2.3 構建生成器2.4 構建判別器2.5 訓練和保存模…

Unity網絡開發基礎 (3) Socket入門 TCP同步連接 與 簡單封裝練習

本文章不作任何商業用途 僅作學習與交流 教程來自Unity唐老獅 關于練習題部分是我觀看教程之后自己實現 所以和老師寫法可能不太一樣 唐老師說掌握其基本思路即可,因為前端程序一般不需要去寫后端邏輯 1.認識Socket的重要API Socket是什么 Socket&#xff08;套接字&#xff0…

【linux】一文掌握 ssh和scp 指令的詳細用法(ssh和scp 備忘速查)

文章目錄 入門連接執行SCP配置位置SCP 選項配置示例ProxyJumpssh-copy-id SSH keygenssh-keygen產生鑰匙類型known_hosts密鑰格式 此快速參考備忘單提供了使用 SSH 的各種方法。 參考&#xff1a; OpenSSH 配置文件示例 (cyberciti.biz)ssh_config (linux.die.net) 入門 連…

真實筆試題

文章目錄 線程題樹的深度遍歷 線程題 實現一個類支持100個線程同時向一個銀行賬戶中存入一元錢.需通過同步機制消除競態條件,當所有線程執行完成后,賬戶余額必須精確等于100元 package com.itheima.thread;public class ShowMeBug {private double balance; // 賬戶余額priva…

2.2 路徑問題專題:LeetCode 63. 不同路徑 II

動態規劃解決LeetCode 63題&#xff1a;不同路徑 II&#xff08;含障礙物&#xff09; 1. 題目鏈接 LeetCode 63. 不同路徑 II 2. 題目描述 一個機器人位于 m x n 網格的左上角&#xff0c;每次只能向右或向下移動一步。網格中可能存在障礙物&#xff08;標記為 1&#xff…

2874. 有序三元組中的最大值 II

給你一個下標從 0 開始的整數數組 。nums 請你從所有滿足 的下標三元組 中&#xff0c;找出并返回下標三元組的最大值。 如果所有滿足條件的三元組的值都是負數&#xff0c;則返回 。i < j < k(i, j, k)0 下標三元組 的值等于 。(i, j, k)(nums[i] - nums[j]) * nums[k…

【論文筆記】Llama 3 技術報告

Llama 3中的頂級模型是一個擁有4050億參數的密集Transformer模型&#xff0c;并且它的上下文窗口長度可以達到128,000個tokens。這意味著它能夠處理非常長的文本&#xff0c;記住和理解更多的信息。Llama 3.1的論文長達92頁&#xff0c;詳細描述了模型的開發階段、優化策略、模…