@浙大疏錦行 Python day51
Review day: DDPM (Denoising Diffusion Probabilistic Model). The code below covers the diffusion process itself, the building blocks, and the UNet noise-prediction model.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class DenoiseDiffusion:
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model
        self.n_steps = n_steps
        self.device = device
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)  # beta schedule
        self.alpha = 1. - self.beta                                   # alpha values
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)             # alpha_bar values
        self.sigma2 = self.beta                                       # sigma_t^2 used during sampling
        self.tools = Tools()

    # Forward diffusion: mean and variance of the Gaussian distribution that x_t follows
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mean = self.tools.gather(self.alpha_bar, t) ** 0.5 * x0
        var = 1 - self.tools.gather(self.alpha_bar, t)
        return mean, var

    # Forward diffusion: generate x_t
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        if eps is None:
            eps = torch.randn_like(x0)
        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps  # x_t, the image after noise has been added up to step t

    # Used only during sampling: one step of the denoise process.
    # Infers x_{t-1} from x_t and t; loop it n times for full sampling.
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        eps_theta = self.eps_model(xt, t)
        alpha_bar = self.tools.gather(self.alpha_bar, t)
        alpha = self.tools.gather(self.alpha, t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = self.tools.gather(self.sigma2, t)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** 0.5) * eps  # mean + sigma_t * eps

    # Loss function. Which model's parameters get updated here? Only eps_model's.
    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)  # the injected noise is Gaussian
        eps_theta = self.eps_model(xt, t)     # model prediction
        return F.mse_loss(noise, eps_theta)   # MSE loss
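The class above calls self.tools.gather(...), but the Tools helper itself is not shown in this note. In typical DDPM implementations this helper simply picks out the schedule constant for each sample's time step t and reshapes it so it broadcasts over an image batch. A minimal sketch of what it might look like (only the call signature comes from the code above; the body is an assumption):

class Tools:
    """Hypothetical helper; only its gather signature is implied by the code above."""

    def gather(self, consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Select consts[t] for every sample in the batch, then reshape to
        # (batch_size, 1, 1, 1) so it broadcasts over (batch, channels, height, width)
        c = consts.gather(-1, t)
        return c.reshape(-1, 1, 1, 1)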
# Activation function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    """Each residual block uses two CNN layers for feature extraction."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        Params:
            in_channels: number of channels of the input feature map
            out_channels: number of channels of the feature map produced by the residual block
            time_channels: dimension of the time embedding; e.g. t starts as the integer 1
                (meaning step 1) and is turned into a vector of shape (1, time_channels)
            n_groups: Group Norm hyperparameter
            dropout: dropout rate
        """
        super().__init__()
        # First conv layer = Group Norm + CNN
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        # Second conv layer = Group Norm + CNN
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        # When in_c == out_c, the residual connection adds input and output directly;
        # when in_c != out_c, a 1x1 conv first maps the input to out_c channels before adding
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))  # 1x1 conv to change channels
        else:
            self.shortcut = nn.Identity()  # placeholder
        # time_channels may differ from out_c, so apply a linear projection to the t vector
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Params:
            x: input x_t, shape (batch_size, in_channels, height, width)
            t: time embedding, shape (batch_size, time_channels)
        (best read alongside the architecture diagram)
        """
        # 1. Pass the input through the first conv layer
        h = self.conv1(self.act1(self.norm1(x)))
        # 2. Project the time embedding from time_c to out_c and add it to the feature map
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # 3. Pass through the second conv layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
        # 4. Return the result after the residual connection
        return h + self.shortcut(x)
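A quick shape check (the numbers below are illustrative, not from the post): a ResidualBlock keeps the spatial size and changes only the channel count, while mixing the time embedding into the feature map.

# Shape check for ResidualBlock (illustrative values)
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(4, 64, 32, 32)     # (batch, in_channels, H, W)
t_emb = torch.randn(4, 256)        # (batch, time_channels)
out = block(x, t_emb)
print(out.shape)                   # torch.Size([4, 128, 32, 32])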
# Attention block: multi-head self-attention over spatial positions
class AttentionBlock(nn.Module):
    """The attention module uses the same principle and implementation as
    multi-head attention in the Transformer."""

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        Params:
            n_channels: number of channels of the feature map that attention is applied to
            n_heads: number of attention heads
            d_k: dimension handled by each attention head
            n_groups: Group Norm hyperparameter
        """
        super().__init__()
        # In general d_k = n_channels // n_heads, and n_channels must be divisible by n_heads;
        # here the default is d_k = n_channels (with the default n_heads = 1)
        if d_k is None:
            d_k = n_channels
        # Group Norm
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Multi-head attention layer: computes the input tokens times q, k, v in one projection
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # MLP layer
        self.output = nn.Linear(n_heads * d_k, n_channels)
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        Params:
            x: input x_t, shape (batch_size, in_channels, height, width)
            t: time embedding, shape (batch_size, time_channels)
        (best read alongside the architecture diagram)
        """
        # t is not used here, but is accepted so the signature matches ResidualBlock
        _ = t
        # Get the shape
        batch_size, n_channels, height, width = x.shape
        # Reshape the input to (batch_size, height*width, n_channels);
        # these dimensions correspond to the Transformer's (batch_size, seq_length, token_embedding)
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Compute q, k, v in a single matrix multiplication: self.projection produces the
        # concatenation of all three, with shape (batch_size, height*width, n_heads, 3 * d_k)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split the concatenation; each part has shape (batch_size, height*width, n_heads, d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Standard attention-score computation
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        attn = attn.softmax(dim=2)
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape the result to (batch_size, height*width, n_heads * d_k)
        # Recall: n_heads * d_k = n_channels
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # MLP layer; output shape (batch_size, height*width, n_channels)
        res = self.output(res)
        # Residual connection
        res += x
        # Restore the output from sequence form back to image form,
        # shape (batch_size, n_channels, height, width)
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return res
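A similar check for AttentionBlock (illustrative numbers): the module reshapes the feature map into sequence form, attends, and reshapes back, so the output shape equals the input shape.

# AttentionBlock preserves the feature-map shape (illustrative values)
attn = AttentionBlock(n_channels=128)
x = torch.randn(4, 128, 16, 16)
print(attn(x).shape)               # torch.Size([4, 128, 16, 16])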
class DownBlock(nn.Module):
    """One Encoder step: ResidualBlock + (optional) AttentionBlock."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 use_attention: bool = False):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_channels)
        if use_attention:
            self.attn_block = AttentionBlock(out_channels)
        else:
            self.attn_block = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res_block(x, t)
        x = self.attn_block(x)
        return x


class UpBlock(nn.Module):
    """One Decoder step: ResidualBlock + (optional) AttentionBlock.
    Its input has in_channels + out_channels channels because the skip connection
    from the matching Encoder level is concatenated on before the block is called."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 use_attention: bool = False):
        super().__init__()
        self.res_block = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if use_attention:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res_block(x, t)
        x = self.attn(x)
        return x
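One point that is easy to miss: UpBlock builds its ResidualBlock with in_channels + out_channels because, in the UNet forward pass, the Encoder's skip connection is concatenated onto the input first. A small illustration with made-up sizes (here the "skip" is just the encoder output concatenated with itself, to mimic torch.cat((x, s), dim=1)):

# UpBlock expects the concatenated skip connection as input (illustrative values)
down = DownBlock(in_channels=64, out_channels=128, time_channels=256, use_attention=True)
up = UpBlock(in_channels=128, out_channels=128, time_channels=256)
x = torch.randn(2, 64, 16, 16)
t_emb = torch.randn(2, 256)
enc = down(x, t_emb)                 # (2, 128, 16, 16)
cat = torch.cat((enc, enc), dim=1)   # (2, 256, 16, 16), mimicking the skip concatenation
print(up(cat, t_emb).shape)          # torch.Size([2, 128, 16, 16])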
class TimeEmbedding(nn.Module):
    def __init__(self, n_channels: int):
        """
        Params:
            n_channels: i.e. time_channels
        """
        super().__init__()
        self.n_channels = n_channels
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        Params:
            t: integer time step t, shape (batch_size,)
        """
        # The transformation below is the same as the Transformer's positional encoding.
        # (Strongly recommended: run it yourself and print the result and shape of every
        # step, it makes things much easier to understand -- see the short example after
        # this class.)
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
        # Output shape: (batch_size, time_channels)
        return emb
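Following the suggestion in the comment above, a tiny run that shows the final embedding shape (the values of n_channels and t are arbitrary):

# TimeEmbedding turns integer steps into (batch_size, time_channels) vectors
time_emb = TimeEmbedding(n_channels=256)      # half_dim = 256 // 8 = 32
t = torch.tensor([1, 5, 100, 999])            # integer time steps, shape (4,)
emb = time_emb(t)
print(emb.shape)                              # torch.Size([4, 256])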
class Upsample(nn.Module):
    """Upsampling: doubles the spatial resolution with a transposed convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """Downsampling: halves the spatial resolution with a stride-2 convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        _ = t
        return self.conv(x)


class MiddleBlock(nn.Module):
    """Bottleneck of the UNet: ResidualBlock + AttentionBlock + ResidualBlock."""

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
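A quick check of the resampling layers (illustrative sizes): Downsample halves the spatial resolution and Upsample doubles it, while the channel count stays fixed. The unused t argument only keeps the interface uniform with the other blocks.

# Downsample halves, Upsample doubles the spatial resolution (illustrative values)
x = torch.randn(2, 64, 32, 32)
t_emb = torch.randn(2, 256)                   # ignored by these modules
print(Downsample(64)(x, t_emb).shape)         # torch.Size([2, 64, 16, 16])
print(Upsample(64)(x, t_emb).shape)           # torch.Size([2, 64, 64, 64])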
class UNet(nn.Module):
    """Main architecture of the DDPM UNet denoising model."""

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        Params:
            image_channels: number of channels of the original input image; 3 for RGB
            n_channels: before entering the UNet, the original image goes through an initial
                convolution; this is that convolution's out_channels (the first dark-green
                arrow in the top-left of the diagram)
            ch_mults: out_channels multiplier for each Encoder (downsampling) level, e.g.
                ch_mults[i] = 2 means level i has twice as many out_channels as level i-1;
                the Decoder (upsampling) uses the reversed ch_mults in the same way
            is_attn: for each Encoder/Decoder level, whether to add attention after the CNN
                feature extraction (see the AttentionBlock above)
            n_blocks: number of DownBlocks/UpBlocks per Encoder/Decoder level (see diagram);
                each Decoder level actually uses n_blocks + 1 UpBlocks
        """
        super().__init__()
        # While the Encoder downsamples / the Decoder upsamples, the image shrinks/grows,
        # producing a new resolution at every change. n_resolutions is the number of
        # distinct resolutions, i.e. the number of Encoder/Decoder levels.
        n_resolutions = len(ch_mults)

        # Initial projection of the original image, e.g. 32*32*3 -> 32*32*64 in the diagram
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding; TimeEmbedding is the nn.Module subclass defined above
        self.time_emb = TimeEmbedding(n_channels * 4)

        # --------------------------
        # Encoder
        # --------------------------
        # Each element of `down` is one piece of the Encoder
        down = []
        # Initialize out_channels and in_channels
        out_channels = in_channels = n_channels
        # Iterate over the levels
        for i in range(n_resolutions):
            # out_channels of this level, following ch_mults
            out_channels = in_channels * ch_mults[i]
            # Each level has n_blocks DownBlocks
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # The Encoder downsamples after every level except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
        # self.down is the complete Encoder
        self.down = nn.ModuleList(down)

        # --------------------------
        # Middle
        # --------------------------
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # --------------------------
        # Decoder
        # --------------------------
        # Essentially mirrors the Encoder; read alongside the architecture diagram
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            # n_blocks UpBlocks at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # One extra UpBlock that reduces the channel count
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # The Decoder upsamples after every level except the last
            if i > 0:
                up.append(Upsample(in_channels))
        # self.up is the complete Decoder
        self.up = nn.ModuleList(up)

        # Final Group Norm, activation and CNN (maps the top Decoder feature map
        # back to the original image channels)
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Params:
            x: input x_t, shape (batch_size, in_channels, height, width)
            t: input time step, shape (batch_size,)
        """
        # Get the time embedding
        t = self.time_emb(t)
        # Initial CNN on the original image
        x = self.image_proj(x)

        # -----------------------
        # Encoder
        # -----------------------
        h = [x]
        # First half of the U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # -----------------------
        # Middle
        # -----------------------
        x = self.middle(x, t)

        # -----------------------
        # Decoder
        # -----------------------
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Skip connection: concatenate with the matching Encoder feature map
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)

        return self.final(self.act(self.norm(x)))
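To tie everything together, a minimal end-to-end sketch: build the UNet and DenoiseDiffusion, run one training step, then run the sampling loop. The hyperparameters (image size, n_steps, learning rate) are illustrative assumptions rather than values from the post, and it relies on a gather helper such as the Tools sketch given earlier.

# Minimal end-to-end usage sketch (illustrative hyperparameters)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
eps_model = UNet(image_channels=3, n_channels=64).to(device)
diffusion = DenoiseDiffusion(eps_model, n_steps=1000, device=device)
optimizer = torch.optim.Adam(eps_model.parameters(), lr=2e-4)

# One training step: only eps_model's parameters are updated by the MSE loss
x0 = torch.randn(8, 3, 32, 32, device=device)   # stand-in for a batch of real images
loss = diffusion.loss(x0)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Sampling: start from pure noise and denoise step by step, from t = n_steps-1 down to 0
with torch.no_grad():
    xt = torch.randn(4, 3, 32, 32, device=device)
    for step in reversed(range(diffusion.n_steps)):
        t = torch.full((xt.shape[0],), step, device=device, dtype=torch.long)
        xt = diffusion.p_sample(xt, t)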