DeepSeek基于注意力模型的可控圖像生成

DeepSeek大模型高性能核心技術與多模態融合開發 - 商品搜索 - 京東

圖像的加噪與模型訓練

在擴散模型的訓練過程中,首先需要對輸入的信號進行加噪處理,經典的加噪過程是在圖像進行向量化處理后在其中添加正態分布,而正態分布的值也是與時間步相關的。這樣逐步向圖像中添加噪聲,直到圖像變得完全噪聲化。

import torch   T = 1000  # Diffusion過程的總步數  # 前向diffusion計算參數
# (T,) 生成一個線性間隔的tensor,用于計算每一步的噪聲水平  
betas = torch.linspace(0.0001, 0.02, T)    
alphas = 1 - betas  # (T,) 計算每一步的保留率  
# alpha_t累乘 (T,),計算每一步累積的保留率 
alphas_cumprod = torch.cumprod(alphas, dim=-1)   
# alpha_t-1累乘(T,),為計算方差做準備
alphas_cumprod_prev = torch.cat((torch.tensor([1.0]), alphas_cumprod[:-1]), dim=-1)    
# denoise用的方差(T,),計算每一步的去噪方差 
variance = (1 - alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)   # 執行前向加噪  
def forward_add_noise(x, t):  # batch_x: (batch,channel,height,width), batch_t: (batch_size,)  noise = torch.randn_like(x)  # 為每幅圖片生成第t步的高斯噪聲   (batch,channel,height,width)  # 根據當前步數t,獲取對應的累積保留率,并調整其形狀以匹配輸入x的形狀    batch_alphas_cumprod = alphas_cumprod[t].view(x.size(0), 1, 1, 1)    # 基于公式直接生成第t步加噪后的圖片    x = torch.sqrt(batch_alphas_cumprod) * x + torch.sqrt(1 - batch_alphas_cumprod) * noise    return x, noise  # 返回加噪后的圖片和生成的噪聲

上面這段代碼首先定義了擴散模型的前向過程中需要的參數,包括每一步的噪聲水平betas、保留率alphas、累積保留率alphas_cumprod以及用于去噪的方差variance。然后定義了一個函數forward_add_noise,該函數接受一個圖像x和步數t作為輸入。根據擴散模型的前向過程,向圖像中添加噪聲,并返回加噪后的圖像和生成的噪聲。

讀者可以采用以下代碼嘗試完成為圖像添加噪聲的演示:

import matplotlib.pyplot as plt 
from dataset import MNISTdataset=MNIST()
# 兩幅圖片拼batch, (2,1,48,48)    
x=torch.stack((dataset[0][0],dataset[1][0]),dim=0) # 原圖
plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
plt.imshow(x[0].permute(1,2,0))
plt.subplot(1,2,2)
plt.imshow(x[1].permute(1,2,0))
plt.show()# 隨機時間步
t=torch.randint(0,T,size=(x.size(0),))
print('t:',t)# 加噪
x=x*2-1 # [0,1]像素值調整到[-1,1]之間,以便與高斯噪聲值范圍匹配
x,noise=forward_add_noise(x,t)
print('x:',x.size())
print('noise:',noise.size())# 加噪圖
plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
plt.imshow(((x[0]+1)/2).permute(1,2,0))   
plt.subplot(1,2,2)
plt.imshow(((x[0]+1)/2).permute(1,2,0))
plt.show()

運行結果如圖9-13所示。

在此基礎上,我們可以完成對Dit模型的訓練,代碼如下:

from torch.utils.data import DataLoader  # 導入PyTorch的數據加載工具  
from dataset import MNIST  # 從dataset模塊導入MNIST數據集類  
from diffusion import forward_add_noise  # 從diffusion模塊導入forward_add_noise函數,用于向圖像添加噪聲  
import torch  # 導入PyTorch庫  
from torch import nn  # 從PyTorch導入nn模塊,包含構建神經網絡所需的工具  
import os  # 導入os模塊,用于處理文件和目錄路徑  
from dit import DiT  # 從dit模塊導入DiT模型  
# 判斷是否有可用的CUDA設備,如果有則使用GPU,否則使用CPU  
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'    dataset=MNIST()  # 實例化MNIST數據集對象  T = 1000  # 設置擴散過程中的總時間步數  
model=DiT(img_size=28,patch_size=4,channel=1,emb_size=64,label_num=10,dit_num=3,head=4).to(DEVICE)  # 實例化DiT模型并移至指定設備  
#model.load_state_dict(torch.load('./saver/model.pth'))  # 可選:加載預訓練模型參數  # 使用Adam優化器,學習率設置為0.001
optimzer=torch.optim.Adam(model.parameters(),lr=1e-3)   
loss_fn=nn.L1Loss()  # 使用L1損失函數(即絕對值誤差均值)  '''訓練模型'''  
EPOCH=300  # 設置訓練的總輪次  
BATCH_SIZE=300  # 設置每個批次的大小  if __name__ == '__main__':  from tqdm import tqdm  # 導入tqdm庫,用于在訓練過程中顯示進度條  dataloader=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=10,persistent_workers=True)  # 創建數據加載器  iter_count=0  for epoch in range(EPOCH):  # 遍歷每個訓練輪次  pbar = tqdm(dataloader, total=len(dataloader))  # 初始化進度條  for imgs,labels in pbar:  # 遍歷每個批次的數據  x=imgs*2-1  # 將圖像的像素范圍從[0,1]轉換到[-1,1],與噪聲高斯分布的范圍對應  t=torch.randint(0,T,(imgs.size(0),))  # 為每幅圖片生成一個隨機的t時刻  y=labels  # 向圖像添加噪聲,返回加噪后的圖像和添加的噪聲x,noise=forward_add_noise(x,t)    # 模型預測添加的噪聲pred_noise=model(x.to(DEVICE),t.to(DEVICE),y.to(DEVICE))    # 計算預測噪聲和實際噪聲之間的L1損失loss=loss_fn(pred_noise,noise.to(DEVICE))    optimzer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向傳播,計算梯度  optimzer.step()  # 更新模型參數  # 更新進度條描述pbar.set_description(f"epoch:{epoch + 1}, train_loss:{loss.item():.5f}")    if epoch % 20 == 0:  # 每20輪保存一次模型  torch.save(model.state_dict(),'./saver/model.pth')  print("base diffusion saved")

?讀者自行查看代碼運行結果。

基于注意力模型的可控圖像生成

DiT模型的可控圖像生成是在我們訓練的基礎上,逐漸對正態分布的噪聲圖像進行按步驟的脫噪過程。這一過程不僅要求模型具備精準的噪聲預測能力,還需確保脫噪步驟的細膩與連貫,從而最終實現從純粹噪聲到目標圖像的華麗蛻變。

完整的可控圖像生成代碼如下:

import torch   from dit import DiT  
import matplotlib.pyplot as plt   
# 導入diffusion模塊中的所有內容,這通常包含一些與擴散模型相關的預定義變量和函數
from diffusion import *    # 設置設備為GPU或CPU  
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'   
DEVICE = "cpu"  # 強制使用CPU  T = 1000  # 擴散步驟的總數  def backward_denoise(model,x,y):  steps=[x.clone(),]  # 初始化步驟列表,包含初始噪聲圖像  global alphas,alphas_cumprod,variance  # 這些是從diffusion模塊導入的全局變量  x=x.to(DEVICE)  # 將輸入x移動到指定的設備  alphas=alphas.to(DEVICE)  alphas_cumprod=alphas_cumprod.to(DEVICE)  variance=variance.to(DEVICE)  y=y.to(DEVICE)  # 將標簽y移動到指定的設備  model.eval()  # 設置模型為評估模式  with torch.no_grad():  # 在不計算梯度的情況下運行,節省內存和計算資源  for time in range(T-1,-1,-1):  # 從T-1到0逆序迭代  t=torch.full((x.size(0),),time).to(DEVICE)  # 創建一個包含當前時間步的tensor  # 預測x_t時刻的噪聲  noise=model(x,t,y)      # 生成t-1時刻的圖像  shape=(x.size(0),1,1,1)  mean=1/torch.sqrt(alphas[t].view(*shape))*  \  (  x-   (1-alphas[t].view(*shape))/torch.sqrt(1-alphas_cumprod[t].view(*shape))*noise  )  if time!=0:  x=mean+ \  torch.randn_like(x)* \  torch.sqrt(variance[t].view(*shape))  else:  x=mean  x=torch.clamp(x, -1.0, 1.0).detach()  # 確保x的值在[-1,1]之間,并分離計算圖  steps.append(x)  return steps  # 初始化DiT模型  
model=DiT(img_size=28,patch_size=4,channel=1,emb_size=64,label_num=10,dit_num=3,head=4).to(DEVICE)  
model.load_state_dict(torch.load('./saver/model.pth'))  # 加載模型權重  # 生成噪聲圖  
batch_size=10  
x=torch.randn(size=(batch_size,1,28,28))  # 生成隨機噪聲圖像  
y=torch.arange(start=0,end=10,dtype=torch.long)   # 生成標簽  # 逐步去噪得到原圖  
steps=backward_denoise(model,x,y)  # 繪制數量  
num_imgs=20  # 繪制還原過程  
plt.figure(figsize=(15,15))  
for b in range(batch_size):  for i in range(0,num_imgs):  idx=int(T/num_imgs)*(i+1)  # 計算要繪制的步驟索引  # 像素值還原到[0,1]  final_img=(steps[idx][b].to('cpu')+1)/2  # tensor轉回PIL圖  final_img=final_img.permute(1,2,0)  # 調整通道順序以匹配圖像格式  plt.subplot(batch_size,num_imgs,b*num_imgs+i+1)plt.imshow(final_img)  
plt.show()  # 顯示圖像

上面的代碼展示了使用DiT進行圖像去噪的完整過程。首先,它導入了必要的庫和模塊,包括PyTorch、DiT模型、matplotlib繪圖模塊,以及從diffusion模塊導入的一些預定義變量和函數,這些通常與擴散模型相關。然后,代碼設置了計算設備為CPU(盡管提供了檢測GPU可用性的選項),并定義了擴散步驟的總數。

backward_denoise函數是實現圖像去噪的核心。它接受一個DiT模型、一批噪聲圖像以及對應的標簽作為輸入。在這個函數內部,它首先將輸入移動到指定的計算設備,然后將模型設置為評估模式,并開始一個不計算梯度的循環,從最后一個擴散步驟開始逆向迭代至第一步。在每一步中,它使用模型預測當前步驟的噪聲,然后根據擴散模型的公式計算上一步的圖像。這個過程一直持續到生成原始圖像。

接下來,代碼初始化了DiT模型,并加載了預訓練的權重。然后,它生成了一批隨機噪聲圖像和對應的標簽,并使用backward_denoise函數對這些噪聲圖像進行去噪,逐步還原出原始圖像。運行結果如圖9-14所示。

圖9-14? 基于DiT模型的可控圖像生成

可見,我們使用生成代碼繪制了去噪過程的圖像,展示了從完全噪聲的圖像逐步還原為清晰圖像的過程。通過調整通道順序和像素值范圍,它將tensor格式的圖像轉換為適合繪制的格式,并使用matplotlib庫的subplot函數在一個大圖中展示了所有步驟的圖像。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/82330.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/82330.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/82330.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

第十六屆藍橋杯B組第二題

當時在考場的時候這一道題目 無論我是使用JAVA的大數(BIGTHGER)還是賽后 使用PY 都是沒有運行出來 今天也是突發奇想在B站上面搜一搜 看了才知道這也是需要一定的數學思維 通過轉換 設X來把運算式精簡化 避免運行超時 下面則是代碼 public class lanba…

HT71663同步升壓2.7V-13V輸入10A聚能芯半導體禾潤一級代理

在便攜式設備飛速發展的今天,電源轉換效率與產品尺寸始終是行業難以平衡的難題。但現在,HT71663 高功率全集成升壓轉換器強勢登場,一舉打破僵局,為便攜式系統帶來顛覆性的高效小尺寸解決方案!? HT71663 的卓越性能&am…

Unity:輸入系統(Input System)與持續檢測鍵盤按鍵(Input.GetKey)

目錄 Unity 的兩套輸入系統: 🔍 Input.GetKey 詳解 🎯 對比:常用的輸入檢測方法 技術底層原理(簡化版) 示例:角色移動 為什么會被“新輸入系統”替代? Unity 的兩套輸入系統&…

港大今年開源了哪些SLAM算法?

過去的5個月,香港大學 MaRS 實驗室陸續開源了四套面向無人機的在線 SLAM 框架:**FAST-LIVO2 、Point-LIO(grid-map 分支) 、Voxel-SLAM 、Swarm-LIO2 **。這四套框架覆蓋了單機三傳感器融合、高帶寬高速機動、長時間多級地圖優化以…

【質量管理】TRIZ因果鏈分析:解碼質量問題的“多米諾效應“

為什么要使用因果鏈分析 沒有發現問題并不等于沒有問題。愛因斯坦曾說,如果我只有一個小時的時間來拯救世界,我將花45分鐘時間分析問題,10分鐘的時間來檢查問題,最后5分鐘的時間來解決問題。可見問題分析的重要性。 在質量管理實踐…

線程中常用的方法

知識點詳細說明 Java線程的核心方法集中在Thread類和Object類中,以下是新增整合后的常用方法分類解析: 1. 線程生命周期控制 方法作用注意事項start()啟動新線程,JVM調用run()方法多次調用會拋出IllegalThreadStateException(線程狀態不可逆)。run()線程的任務邏輯直接調…

c++:迭代器(Iterator)

目錄 🚪什么是迭代器? 🔧 迭代器的本質 為什么不用普通數組或下標? STL容器的迭代器并不是共用一個類型! 迭代器的類型(Iterator Categories) 📦 常見容器的迭代器類型 ? 迭…

【文件系統—散列結構文件】

文章目錄 一、實驗目的實驗內容設計思路 三、實驗代碼實現四、總結 一、實驗目的 理解linux文件系統的內部技術,掌握linux與文件有關的系統調用命令,并在此基礎上建立面向隨機檢索的散列結構文件;## 二、實驗內容與設計思想 實驗內容 1.設…

力扣26——刪除有序數組中的重復項

目錄 1.題目描述: 2.算法分析: 3.代碼展示: 1.題目描述: 給你一個 非嚴格遞增排列 的數組 nums ,請你 原地 刪除重復出現的元素,使每個元素 只出現一次 ,返回刪除后數組的新長度。元素的 相對…

ggplot2 | GO barplot with gene list

1. 效果圖 2. 代碼 數據是GO的輸出結果,本文使用的是 metascape 輸出的excel挑選的若干行。 # 1. 讀取數據 datread.csv("E:\\research\\scPolyA-seq2\\GO-APA-Timepoint\\test.csv", sep"\t") head(dat)# 2. 選擇所需要的列 dat.usedat[, c(…

學習搭子,秘塔AI搜索

什么是秘塔AI搜索 《秘塔AI搜索》的網址:https://metaso.cn/ 功能:AI搜索和知識學習,其中學習部分是亮點,也是主要推薦理由。對應的入口:https://metaso.cn/study 推薦理由 界面細節做工精良《今天學點啥》板塊的知…

【C語言】--指針超詳解(三)

目錄 一.數組名的理解 二.使用指針訪問數組 三.一維數組傳參的本質 四.冒泡排序 五.二級指針 六.指針數組 6.1--指針數組的定義 6.2--指針數組模擬二維數組 🔥個人主頁:草莓熊Lotso的個人主頁 🎬作者簡介:C方向學習者 &…

Linux防火墻

1.防火墻是一種位于內部網絡與外部網絡之間的網絡安全系統,它依照特定的規則,允許或限制傳輸的數據通過,以保護內部網絡的安全。以下從功能、分類、工作原理等方面為你詳細講解: 功能訪問控制:這是防火墻最主要的功能。…

嵌入式培訓之C語言學習完(十七)結構體、共用體、枚舉、typedef關鍵字與位運算

目錄 一、結構體(struct關鍵字) (一)聲明一個結構體數據類型 (二)結構體的成員初始化與賦值 a、結構體變量賦值 b、結構體成員初始化 c、結構體的定義形式 (三)考點&#xff…

Python字典:數據操作的核心容器

在Python編程生態中,字典(dict)是最常用且功能強大的內置數據結構之一。它以鍵值對(Key-Value Pair)的形式存儲數據,為快速查找、靈活映射關系提供了天然支持。無論是數據清洗、算法實現還是Web開發&#x…

按位寬提取十六進制值

需求:給出一個十六進制值,要求提取high和low位之間的值。比如16ha0f0,這是一個16bit寬的十六進制數0xa0f0,提取[15:12]范圍內的值。 def extract_bits(value, high, low):"""從 value 中提取 [high:low] 位的值:p…

LeRobot 項目部署運行邏輯(六)——visualize_dataset_html.py/visualize_dataset.py

可視化腳本包括了兩個方法:遠程下載 huggingface 上的數據集和使用本地數據集 腳本主要使用兩個: 目前來說,ACT 采集訓練用的是統一時間長度的數據集,此外,這兩個腳本最大的問題在于不能裁剪,這也是比較好…

SSTI模版注入

1、概念 SSTI是一種常見的Web安全漏洞,它允許攻擊者通過注入惡意模板代碼,使服務器在渲染模板時執行非預期的操作。 (1)渲染模版 至于什么是渲染模版:服務器端渲染模板是一種Web開發技術,它允許在服務器端…

關于點膠機的精度

一、精度: 1:X/y軸定位精度常通在5個絲左右,Z軸在3個絲左右, 如果采用伺服電機絲桿配置,可提升至于個2絲左右。 2:膠水控制精度:通過噴閥驅動器,氣壓等參數,實現膠量控制&#xf…

gitee推送更新失敗問題記錄:remote: error: hook declined to update refs/heads/master

問題描述: gitee推送更新時,提示: 解決方法: 登錄Gitee,進入【個人主頁】 點擊【個人設置】 更改郵箱的配置,如下: 更改“禁止命令行推送暴露個人郵箱”,將其關閉: