nanodiffusion代碼逐行理解之diffusion

目錄

  • 一、diffusion創建
  • 二、GaussianDiffusion定義
  • 三、代碼理解
    • def __init__(self,model,img_size,img_channels,num_classes,betas, loss_type="l2", ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):
    • def remove_noise(self, x, t, y, use_ema=True):
    • def sample(self, batch_size, device, y=None, use_ema=True):
    • def perturb_x(self, x, t, noise):
    • def get_losses(self, x, t, y):
    • def forward(self, x, y=None):
    • def generate_cosine_schedule(T, s=0.008)和def generate_linear_schedule(T, low, high):

一、diffusion創建

diffusion = GaussianDiffusion(model,args.img_size,args.img_channels,args.num_classes,betas,ema_decay=args.ema_decay,ema_update_rate=args.ema_update_rate,ema_start=2000,loss_type=args.loss_type,)

二、GaussianDiffusion定義

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom functools import partial
from copy import deepcopyfrom ddpm.ema import EMA
from ddpm.utils import extractclass GaussianDiffusion(nn.Module):def __init__(self,model,img_size,img_channels,num_classes,betas,loss_type="l2",ema_decay=0.9999,ema_start=5000,ema_update_rate=1,):super().__init__()self.model = modelself.ema_model = deepcopy(model)self.ema = EMA(ema_decay)self.ema_decay = ema_decayself.ema_start = ema_startself.ema_update_rate = ema_update_rateself.step = 0self.img_size = img_sizeself.img_channels = img_channelsself.num_classes = num_classesif loss_type not in ["l1", "l2"]:raise ValueError("__init__() got unknown loss type")self.loss_type = loss_typeself.num_timesteps = len(betas)alphas = 1.0 - betasalphas_cumprod = np.cumprod(alphas)to_torch = partial(torch.tensor, dtype=torch.float32)self.register_buffer("betas", to_torch(betas))self.register_buffer("alphas", to_torch(alphas))self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))self.register_buffer("sigma", to_torch(np.sqrt(betas)))def update_ema(self):self.step += 1if self.step % self.ema_update_rate == 0:if self.step < self.ema_start:self.ema_model.load_state_dict(self.model.state_dict())else:self.ema.update_model_average(self.ema_model, self.model)@torch.no_grad()def remove_noise(self, x, t, y, use_ema=True):if use_ema:return ((x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape))else:return ((x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape))@torch.no_grad()def sample(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)return x.cpu().detach()@torch.no_grad()def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)yield x.cpu().detach()def perturb_x(self, x, t, noise):return (extract(self.sqrt_alphas_cumprod, t, x.shape) * x +extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise)   def get_losses(self, x, t, y):noise = torch.randn_like(x)perturbed_x = self.perturb_x(x, t, noise)estimated_noise = self.model(perturbed_x, t, y)if self.loss_type == "l1":loss = F.l1_loss(estimated_noise, noise)elif self.loss_type == "l2":loss = F.mse_loss(estimated_noise, noise)return lossdef forward(self, x, y=None):b, c, h, w = x.shapedevice = x.deviceif h != self.img_size[0]:raise ValueError("image height does not match diffusion parameters")if w != self.img_size[0]:raise ValueError("image width does not match diffusion parameters")t = torch.randint(0, self.num_timesteps, (b,), device=device)return self.get_losses(x, t, y)def generate_cosine_schedule(T, s=0.008):def f(t, T):return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2alphas = []f0 = f(0, T)for t in range(T + 1):alphas.append(f(t, T) / f0)betas = []for t in range(1, T + 1):betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))return np.array(betas)def generate_linear_schedule(T, low, high):return np.linspace(low, high, T)

三、代碼理解

Input:
x: (N, img_channels, *img_size)
y: (N)
Output:
scalar loss tensor
Args:
model (nn.Module):估計高斯噪聲的模型
img_size (tuple): (H, W)
img_channels (int): 圖像通道數
betas (np.ndarray): diffusion betas 數組
loss_type (string): loss type, “l1” or “l2” 類型
ema_decay (float): model weights exponential moving average decay
ema_start (int): number of steps before EMA
ema_update_rate (int): number of steps before each EMA update
“”"

def init(self,model,img_size,img_channels,num_classes,betas, loss_type=“l2”, ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):

在這里插入圖片描述
np.cumprod返回數組沿指定軸的累計積。
a=[a1,a2,a3,a4,a5]
np.cumprod(a)=array([a1,a1a2,a1a2a3,a1a2a3a4,a1a2a3a4a5])。

def remove_noise(self, x, t, y, use_ema=True):

(x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape)

這個函數就是去除第t-1到第t步的噪聲
在這里插入圖片描述
在這個函數里面調用了extract函數。實現的功能:提取時間步t時對應的參數

def extract(a, t, x_shape):b, *_ = t.shapeout = a.gather(-1, t)return out.reshape(b, *((1,) * (len(x_shape) - 1)))

a: Tensor:(1000,)
t: Tensor:(128,)
x_shape: torch.Size([128, 1, 28, 28])
最終返回的是Tensor:(128,1,1,1)
模型定義在初始化函數中,模型調用定義在forward函數中。

def sample(self, batch_size, device, y=None, use_ema=True):

    def sample(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)return x.cpu().detach()

def perturb_x(self, x, t, noise):

在圖像中添加噪聲

        return (extract(self.sqrt_alphas_cumprod, t, x.shape) * x +extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise)   

在這里插入圖片描述

def get_losses(self, x, t, y):

計算添加噪聲和估計噪聲的損失

        noise = torch.randn_like(x)perturbed_x = self.perturb_x(x, t, noise)estimated_noise = self.model(perturbed_x, t, y)if self.loss_type == "l1":loss = F.l1_loss(estimated_noise, noise)elif self.loss_type == "l2":loss = F.mse_loss(estimated_noise, noise)return loss

def forward(self, x, y=None):

前向函數很簡單,隨機b個t,然后計算對應的噪聲損失。

    def forward(self, x, y=None):b, c, h, w = x.shapedevice = x.deviceif h != self.img_size[0]:raise ValueError("image height does not match diffusion parameters")if w != self.img_size[0]:raise ValueError("image width does not match diffusion parameters")t = torch.randint(0, self.num_timesteps, (b,), device=device)return self.get_losses(x, t, y)

def generate_cosine_schedule(T, s=0.008)和def generate_linear_schedule(T, low, high):

這個函數就是兩種不同的生成betas的方法。betas數組是從小到大排列的。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/41076.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/41076.shtml
英文地址,請注明出處:http://en.pswp.cn/web/41076.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL 集群

MySQL 集群有多種類型&#xff0c;每種類型都有其特定的用途和優勢。以下是一些常見的 MySQL 集群解決方案&#xff1a; 1. MySQL Replication 描述&#xff1a;MySQL 復制是一種異步復制機制&#xff0c;允許將一個 MySQL 數據庫的數據復制到一個或多個從服務器。 用途&…

bug——多重定義

bug——多重定義 你的問題是在C代碼中遇到了"reference to data is ambiguous"的錯誤。這個錯誤通常發生在你嘗試引用一個具有多重定義的變量時。 在你的代碼中&#xff0c;你定義了一個全局變量data&#xff0c;同時&#xff0c;C標準庫中也有一個名為data的函數模板…

【云原生】Kubernetes部署高可用平臺手冊

部署Kubernetes高可用平臺 文章目錄 部署Kubernetes高可用平臺基礎環境一、基礎環境配置1.1、關閉Swap1.2、添加hosts解析1.3、橋接IPv4流量傳遞到iptables的鏈 二、配置Kubernetes的VIP2.1、安裝Nginx2.2、修改Nginx配置文件2.3、啟動服務2.4、安裝Keepalived2.5、修改配置文件…

Linux 定時任務詳解:全面掌握 cron 和 at 命令

Linux 定時任務詳解&#xff1a;全面掌握 cron 和 at 命令 Linux 系統中定時任務的管理對于運維和開發人員來說都是至關重要的。通過定時任務&#xff0c;可以在特定時間自動執行腳本或命令&#xff0c;提高系統自動化程度。本文將詳細介紹 Linux 中常用的定時任務管理工具 cr…

一拖二快充線:生活充電新風尚,高效便捷解決雙設備充電難題

一拖二快充線在生活應用領域的優勢與雙接充電的便攜性問題 在現代快節奏的生活中&#xff0c;電子設備已成為我們不可或缺的日常伴侶。無論是智能手機、平板電腦還是筆記本電腦&#xff0c;它們在我們的工作、學習和娛樂中扮演著至關重要的角色。然而&#xff0c;隨著設備數量…

優化:遍歷List循環查找數據庫導致接口過慢問題

前提&#xff1a; 我們在寫查詢的時候&#xff0c;有時候會遇到多表聯查&#xff0c;一遇到多表聯查大家就會直接寫sql語句&#xff0c;不會使用較為方便的LambdaQueryWrapper去查詢了。作為一個2024新進入碼農世界的小白&#xff0c;我喜歡使用LambdaQueryWrapper&#xff0c;…

產品經理系列1—如何實現一個電商系統

具體筆記如下&#xff0c;主要按獲客—找貨—下單—售后四個部分進行模塊拆解

代碼隨想錄算法訓練Day58|LeetCode417-太平洋大西洋水流問題、LeetCode827-最大人工島

太平洋大西洋水流問題 力扣417-太平洋大西洋水流問題 有一個 m n 的矩形島嶼&#xff0c;與 太平洋 和 大西洋 相鄰。 “太平洋” 處于大陸的左邊界和上邊界&#xff0c;而 “大西洋” 處于大陸的右邊界和下邊界。 這個島被分割成一個由若干方形單元格組成的網格。給定一個…

用 Emacs 寫代碼有哪些值得推薦的插件

以下是一些用于 Emacs 寫代碼的值得推薦的插件&#xff1a; Ido-mode&#xff1a;交互式操作模式&#xff0c;它用列出當前目錄所有文件的列表來取代常規的打開文件提示符&#xff0c;能讓操作更可視化&#xff0c;快速遍歷文件。Smex&#xff1a;可替代普通的 M-x 提示符&…

【Unity】unity學習掃盲知識點

1、建議檢查下SystemInfo的引用。這個是什么 Unity的SystemInfo類提供了一種獲取關于當前硬件和操作系統的信息的方法。這包括設備類型&#xff0c;操作系統&#xff0c;處理器&#xff0c;內存&#xff0c;顯卡&#xff0c;支持的Unity特性等。使用SystemInfo類非常簡單。它的…

【python】生成完全數

定義 如果一個數恰好等于它的真因子之和&#xff0c;則稱該數為“完全數” [2]。各個小于它的約數&#xff08;真約數&#xff0c;列出某數的約數&#xff0c;去掉該數本身&#xff0c;剩下的就是它的真約數&#xff09;的和等于它本身的自然數叫做完全數&#xff08;Perfect …

Linux 查看磁盤是不是 ssd 的方法

lsblk 命令檢查 $ lsblk -d -o name,rota如果 ROTA 值為 1&#xff0c;則磁盤類型為 HDD&#xff0c;如果 ROTA 值為 0&#xff0c;則磁盤類型為 SSD。可以在上面的屏幕截圖中看到 sda 的 ROTA 值是 1&#xff0c;表示它是 HDD。 2. 檢查磁盤是否旋轉 $ cat /sys/block/sda/q…

php使用PHPExcel 導出數據表到Excel文件

直接上干貨&#xff1a;<?php$cards_list Cards::find($parameters);$objPHPExcel new \PHPExcel(); $objPHPExcel->getProperties()->setCreator("jiequan")->setLastModifiedBy("jiequan")->setTitle("card List")->setS…

Vuetify3: 根據滾動距離顯示/隱藏搜索組件

我們在使用vuetify3開發的時候&#xff0c;產品需要實現當搜索框因滾動條拉拽的時候&#xff0c;消失&#xff0c;搜索組件再次出現在頂部位置。這個我們需要獲取滾動高度&#xff0c;直接參考vuetify3 滾動指令???????&#xff0c;執行的時候發現一個問題需要設置 max-…

在什么情況下你會使用設計模式

設計模式是在軟件開發中解決常見問題的最佳實踐。它們提供了可復用的解決方案&#xff0c;使得代碼更加模塊化、易于理解和維護。以下是在什么情況下你可能會使用設計模式的一些常見情況&#xff1a; 代碼重復&#xff1a;當你發現項目中多處出現相同或相似的代碼結構時&#x…

機器學習之保存與加載

前言 模型的數據需要存儲和加載&#xff0c;這節介紹存儲和加載的方式方法。 存和加載模型權重 保存模型使用save_checkpoint接口&#xff0c;傳入網絡和指定的保存路徑&#xff0c;要加載模型權重&#xff0c;需要先創建相同模型的實例&#xff0c;然后使用load_checkpoint…

Autosar Dcm配置-0x85服務配置及使用-基于ETAS軟件

文章目錄 前言Dcm配置DcmDsdDcmDsp代碼實現總結前言 0x85服務用來控制DTC設置的開啟和關閉。某OEM3.0架構強制支持0x85服務,本文介紹ETAS工具中的配置 Dcm配置 DcmDsd 配置0x85服務 此處配置只在擴展會話下支持(具體需要根據需求決定),兩個子服務Disable為0x02,Enable…

馮諾依曼體系結構與操作系統(Linux)

文章目錄 前言馮諾依曼體系結構&#xff08;硬件&#xff09;操作系統&#xff08;軟件&#xff09;總結 前言 馮諾依曼體系結構&#xff08;硬件&#xff09; 上圖就是馮諾依曼體系結構圖&#xff0c;主要包括輸入設備&#xff0c;輸出設備&#xff0c;存儲器&#xff0c;運算…

Go高級庫存照片源碼v5.3

GoStock – 免費和付費庫存照片腳本這是一個免費和付費共享高質量庫存照片的平臺,用戶可以上傳照片與整個社區和訪客分享,并可以通過 PayPal 接收捐款。此外,用戶還可以點贊、評論、分享和收藏您最喜歡的照片。 下載 特征: 使用Laravel 10構建訂閱系統Stripe 連接漸進式網頁…

從零開始讀RocketMq源碼(一)生產者啟動

目錄 前言 獲取源碼 總概論 生產者實例 源碼 A-01:設置生產者組名稱 A-02:生產者服務啟動 B-01&#xff1a;初始化狀態 B-02&#xff1a;該方法再次對生產者組名稱進行校驗 B-03&#xff1a;判斷是否為默認生產者組名稱 B-04: 該方法是為了實例化MQClientInstance對…