wan2.1代碼筆記

GPU內存不夠,可以先運行umt5,然后再運行wanpipeline,參考FLUX.1代碼筆記,或者使用ComfyUI。
下面使用隨機數代替umt5 embedding。

import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2___1-T2V-1___3B-Diffusers"vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, tokenizer=None,text_encoder=None,vae=vae, torch_dtype=torch.bfloat16)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.to("cuda")
prompt_embeds = torch.randn(1,226,4096).to('cuda')  #隨機數
negative_prompt_embeds = torch.randn(1,226,4096).to('cuda') #隨機數output = pipe(prompt=None,negative_prompt=None,prompt_embeds = prompt_embeds,negative_prompt_embeds = negative_prompt_embeds,num_inference_steps = 1,height=480,width=832,num_frames=81,guidance_scale=6.0,).frames[0]export_to_video(output, "output.mp4", fps=16)

在這里插入圖片描述

WanPipeline的步驟和文生圖的步驟基本一致。
1.檢查輸入;
2.定義參數;
3.encode prompt;
4.準備timesteps;
5.準備latent;
6.循環去噪,最后decode.
在這里插入圖片描述
圖5,Wan-VAE 在時間維度上壓縮了4倍,空間維度上長和寬分別壓縮了8倍。
channel數為16,latent的維度就是(1,16,21,60,104)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 #(81-1)/4+1shape = (batch_size, #1num_channels_latents, #16num_latent_frames,int(height) // self.vae_scale_factor_spatial, #480/8int(width) // self.vae_scale_factor_spatial, #832/8)

WanTransformer3DModel

在patchify中,WanTransformer3DModel 使用(1,2,2)的3D卷積核,將輸入的序列轉換為(B,L,D)維度,其中B為batch size,L為(1+T/4)×H/16×W/16,D為latent的維度。

self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)rotary_emb = self.rope(hidden_states) #(1,1,32760,64)
hidden_states = self.patch_embedding(hidden_states) #(1,1536,21,30,52)
hidden_states = hidden_states.flatten(2).transpose(1, 2)#(1,32760,1536)temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) #(1,1536),(1,9216),(1,226,1536),None
timestep_proj = timestep_proj.unflatten(1, (6, -1))#(1,6,1536)if encoder_hidden_states_image is not None:encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)for block in self.blocks:hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)if USE_PEFT_BACKEND:# remove `lora_scale` from each PEFT layerunscale_lora_layers(self, lora_scale)if not return_dict:return (output,)return Transformer2DModelOutput(sample=output)

WanTransformerBlock

1.3B模型有30個WanTransformerBlock,DiT結構。30個WanTransformerBlock是共享temb參數的,在每個Block中學習一個偏差self.scale_shift_table,self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5),通過大量實驗證明,這種設計可使參數數量減少約 25%,并表明在相同參數規模下,該方法能顯著提升性能。

class WanTransformerBlock(nn.Module):def forward(self,hidden_states: torch.Tensor,encoder_hidden_states: torch.Tensor,temb: torch.Tensor,rotary_emb: torch.Tensor,) -> torch.Tensor:shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb.float()).chunk(6, dim=1)# 1. Self-attentionnorm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)# 2. Cross-attentionnorm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)hidden_states = hidden_states + attn_output# 3. Feed-forwardnorm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)ff_output = self.ffn(norm_hidden_states)hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)return hidden_states

vae.decode

遍歷全部的frame。

        x = self.post_quant_conv(z) #(1,16,21,60,104)for i in range(num_frame): #21self._conv_idx = [0]if i == 0:out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)else:out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)out = torch.cat([out, out_], 2)out = torch.clamp(out, min=-1.0, max=1.0)self.clear_cache()

1.3B模型中,feat_cache是一個長度為33的list,在完整的decode過程中,需要用到33個Conv3d。

        def _count_conv3d(model):count = 0for m in model.modules():if isinstance(m, WanCausalConv3d):count += 1return countself._conv_num = _count_conv3d(self.decoder)self._conv_idx = [0]self._feat_map = [None] * self._conv_num

class WanDecoder3d是執行Vae decode的類。根據圖5,這里要運行兩次時間維度的放大和三次空間維度的放大。
CACHE_T = 2,緩存后兩個frame的值。
緩存處理,除了前兩個frame的feat_cache要特殊處理,每個feat_cache元素都含有兩個frame,然后和當前的frame湊成3個frame進行下一步計算。同時取feat_cache元素的最后一個frame和當前的frame,更新feat_cache。如圖6

class WanDecoder3ddef forward(self, x, feat_cache=None, feat_idx=[0]):## conv1if feat_cache is not None:idx = feat_idx[0]cache_x = x[:, :, -CACHE_T:, :, :].clone()if cache_x.shape[2] < 2 and feat_cache[idx] is not None:# cache last frame of last two chunkcache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)x = self.conv_in(x, feat_cache[idx]) #(1,384,1,60,104)feat_cache[idx] = cache_xfeat_idx[0] += 1else:x = self.conv_in(x)## middlex = self.mid_block(x, feat_cache, feat_idx)#(1,384,1,60,104)## upsamplesfor up_block in self.up_blocks:x = up_block(x, feat_cache, feat_idx) #(1,192,2,120,208),(1,192,4,240,416),(1,96,4,480,832),(1,96,4,480,832)## headx = self.norm_out(x)x = self.nonlinearity(x)if feat_cache is not None:idx = feat_idx[0]cache_x = x[:, :, -CACHE_T:, :, :].clone()if cache_x.shape[2] < 2 and feat_cache[idx] is not None:# cache last frame of last two chunkcache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)x = self.conv_out(x, feat_cache[idx])feat_cache[idx] = cache_xfeat_idx[0] += 1else:x = self.conv_out(x)return x

在這里插入圖片描述

WanCausalConv3d

class WanCausalConv3d(nn.Conv3d):r"""A custom 3D causal convolution layer with feature caching support.This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling featurecaching for efficient inference.Args:in_channels (int): Number of channels in the input imageout_channels (int): Number of channels produced by the convolutionkernel_size (int or tuple): Size of the convolving kernelstride (int or tuple, optional): Stride of the convolution. Default: 1padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0"""def __init__(self,in_channels: int,out_channels: int,kernel_size: Union[int, Tuple[int, int, int]],stride: Union[int, Tuple[int, int, int]] = 1,padding: Union[int, Tuple[int, int, int]] = 0,) -> None:super().__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,)# Set up causal paddingself._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)self.padding = (0, 0, 0)def forward(self, x, cache_x=None):padding = list(self._padding)if cache_x is not None and self._padding[4] > 0:cache_x = cache_x.to(x.device)x = torch.cat([cache_x, x], dim=2)padding[4] -= cache_x.shape[2]x = F.pad(x, padding)return super().forward(x)

空間維度,self.resample,nn.Upsample上采樣擴大2倍,然后維度縮小1/2.
時間維度,nn.Conv3d,輸出維度擴大2倍。

        if mode == "upsample2d":self.resample = nn.Sequential(WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1))elif mode == "upsample3d":self.resample = nn.Sequential(WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1))self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))class WanUpsample(nn.Upsample):def forward(self, x):return super().forward(x.float()).type_as(x)

upsample blocks

ModuleList((0): WanUpBlock((resnets): ModuleList((0-2): 3 x WanResidualBlock((nonlinearity): SiLU()(norm1): WanRMS_norm()(conv1): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(norm2): WanRMS_norm()(dropout): Dropout(p=0.0, inplace=False)(conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(conv_shortcut): Identity()))(upsamplers): ModuleList((0): WanResample((resample): Sequential((0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')(1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))(time_conv): WanCausalConv3d(384, 768, kernel_size=(3, 1, 1), stride=(1, 1, 1)))))(1): WanUpBlock((resnets): ModuleList((0): WanResidualBlock((nonlinearity): SiLU()(norm1): WanRMS_norm()(conv1): WanCausalConv3d(192, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(norm2): WanRMS_norm()(dropout): Dropout(p=0.0, inplace=False)(conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(conv_shortcut): WanCausalConv3d(192, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1)))(1-2): 2 x WanResidualBlock((nonlinearity): SiLU()(norm1): WanRMS_norm()(conv1): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(norm2): WanRMS_norm()(dropout): Dropout(p=0.0, inplace=False)(conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))(conv_shortcut): Identity()))(upsamplers): ModuleList((0): WanResample((resample): Sequential((0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')(1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))(time_conv): WanCausalConv3d(384, 768, kernel_size=(3, 1, 1), stride=(1, 1, 1)))))(2): WanUpBlock((resnets): ModuleList((0-2): 3 x WanResidualBlock((nonlinearity): SiLU()(norm1): WanRMS_norm()(conv1): WanCausalConv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))(norm2): WanRMS_norm()(dropout): Dropout(p=0.0, inplace=False)(conv2): WanCausalConv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))(conv_shortcut): Identity()))(upsamplers): ModuleList((0): WanResample((resample): Sequential((0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')(1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))))))(3): WanUpBlock((resnets): ModuleList((0-2): 3 x WanResidualBlock((nonlinearity): SiLU()(norm1): WanRMS_norm()(conv1): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))(norm2): WanRMS_norm()(dropout): Dropout(p=0.0, inplace=False)(conv2): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))(conv_shortcut): Identity())))
)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/84314.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/84314.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/84314.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

環境搭建與工具配置

3.1 本地環境搭建 3.1.1 WAMP環境搭建漏洞靶場&#xff08;一、二&#xff09; WAMP&#xff08;Windows Apache MySQL PHP&#xff09;是搭建本地Web漏洞靶場的基礎環境。 安裝步驟&#xff1a; Apache&#xff1a;下載并安裝最新版Apache HTTP Server&#xff0c;配置監…

STM32F446主時鐘失效時DAC輸出異常現象解析與解決方案

—### 現象概述 在STM32F446微控制器應用中&#xff0c;若主時鐘&#xff08;HSE&#xff09;的晶體信號對地短路&#xff0c;但DAC&#xff08;數模轉換器&#xff09;仍能輸出變化信號&#xff0c;這一現象看似矛盾&#xff0c;實則與系統時鐘切換機制密切相關。本文將從硬件…

React 如何封裝一個可復用的 Ant Design 組件

文章目錄 前言一、為什么需要封裝組件&#xff1f;二、 仿antd組件的Button按鈕三、封裝一個可復用的表格組件 (實戰)1. 明確需求2. 設計組件 API3. 實現組件代碼4. 使用組件 三、封裝組件的最佳實踐四、進階優化 總結 前言 作為一名前端開發工程師&#xff0c;在日常項目中&a…

STC89C52RC/LE52RC

STC89C52RC 芯片手冊原理圖擴展版原理圖 功能示例LED燈LED燈的常亮效果LED燈的閃爍LED燈的跑馬燈效果&#xff1a;從左到右&#xff0c;從右到左 數碼管靜態數碼管數碼管計數mian.cApp.cApp.hCom.cCom.hDir.cDir.hInt.cInt.hMid.cMid.h 模板mian.cApp.cApp.hCom.cCom.hDir.cDir…

踩坑記錄:RecyclerView 局部刷新notifyItemChanged多次調用只觸發一次 onBindViewHolder 的原因

1. 問題背景 在做項目的時候&#xff0c;RecyclerView需要使用局部刷新&#xff0c;使用 notifyItemChanged(position, payload) 實現局部刷新&#xff0c;但發現調用多次只執行了一次&#xff0c;第二個刷新不生效。 2. 錯誤示例&#xff08;只處理 payloads.get(0)&#xff…

OpenLayers 加載鷹眼控件

注&#xff1a;當前使用的是 ol 5.3.0 版本&#xff0c;天地圖使用的key請到天地圖官網申請&#xff0c;并替換為自己的key 地圖控件是一些用來與地圖進行簡單交互的工具&#xff0c;地圖庫預先封裝好&#xff0c;可以供開發者直接使用。OpenLayers具有大部分常用的控件&#x…

WPF···

設置啟動頁 默認最后一個窗口關閉,程序退出,可以設置 修改窗體的icon圖標 修改項目exe圖標 雙擊項目名會看到代碼 其他 在A窗體點擊按鈕打開B窗體,在B窗體設置WindowStartupLocation=“CenterOwner” 在A窗體的代碼設置 B.Owner = this; B.Show(); B窗體生成在A窗體中間…

github公開項目爬取

import requestsdef search_github_repositories(keyword, tokenNone, languageNone, max_results1000):"""通過 GitHub API 搜索倉庫&#xff0c;支持分頁獲取所有結果&#xff08;最多 1000 條&#xff09;:param keyword: 搜索關鍵詞:param token: GitHub To…

防震基座在半導體晶圓制造設備拋光機詳細應用案例-江蘇泊蘇系統集成有限公司

在半導體制造領域&#xff0c;晶圓拋光作為關鍵工序&#xff0c;對設備穩定性要求近乎苛刻。哪怕極其細微的振動&#xff0c;都可能對晶圓表面質量產生嚴重影響&#xff0c;進而左右芯片制造的成敗。以下為您呈現一個防震基座在半導體晶圓制造設備拋光機上的經典應用案例。 企…

S32K開發環境搭建詳細教程(一、S32K IDE安裝注冊)

一、S32K IDE安裝注冊 1、進入恩智浦官網https://www.nxp.com.cn/&#xff08;需要在官網注冊一個賬號&#xff09; 2、直接搜索 “Standard Software”&#xff0c;找到S32K3 Standard Software&#xff0c;點擊進入 3、下載 (1)Automotive SW - S32K3 - S32 Design Studio…

Spring Cloud Gateway 微服務網關實戰指南

上篇文章簡單介紹了SpringCloud系列OpenFeign的基本用法以及Demo搭建&#xff08;Spring Cloud實戰&#xff1a;OpenFeign遠程調用與服務治理-CSDN博客&#xff09;&#xff0c;今天繼續講解下SpringCloud Gateway實戰指南&#xff01;在分享之前繼續回顧下本次SpringCloud的專…

MSP430G2553 USCI模塊串口通信

1.前言 最近需要利用msp430連接藍牙模塊傳遞數據&#xff0c;于是死磕了一段時間串口&#xff0c;在這里記錄一下 2.msp430串口模塊 msp430的串口模塊可以有USCI模塊提供 在異步模式中&#xff0c; USCI_Ax 模塊通過兩個外部引腳&#xff0c; UCAxRXD 和 UCAxTXD&#xff0…

【產品經理從0到1】用戶端產品設計與用戶畫像

思考 xx新聞的第一個版本應該做哪些事情呢&#xff1f; 用戶端核心功能 用戶端通用頁面設計 思考 回想一下&#xff0c;大家在第一次使用一個新下載的App的時候會看到一些什么樣的頁面?這樣的頁面一般都是展示了一些什么內容? 引導頁 概念 第一次安裝App或者更新App后第…

多場景游戲AI新突破!Divide-Fuse-Conquer如何激發大模型“頓悟時刻“?

多場景游戲AI新突破&#xff01;Divide-Fuse-Conquer如何激發大模型"頓悟時刻"&#xff1f; 大語言模型在強化學習中偶現的"頓悟時刻"引人關注&#xff0c;但多場景游戲中訓練不穩定、泛化能力差等問題亟待解決。Divide-Fuse-Conquer方法&#xff0c;通過…

佰力博科技與您探討壓電材料的原理與壓電效應的應用

壓電材料的原理基于正壓電效應和逆壓電效應&#xff0c;即機械能與電能之間的雙向轉換特性。 壓電材料的原理源于其獨特的晶體結構和電-機械耦合效應&#xff0c;具體可分為以下核心要點&#xff1a; 1. ?正壓電效應與逆壓電效應的定義? ?正壓電效應?&#xff1a;當壓電…

算法備案審核周期

&#xff08;一&#xff09;主體備案審核 主體備案審核周期通常為7-10個工作日&#xff0c;監管部門將對企業提交的資質信息進行嚴格審查&#xff0c;審核重點包括&#xff1a; 營業執照的真實性、有效性及與備案主體的一致性。法人及算法安全責任人身份信息的準確性與有效性…

管理系統的接口文檔

一、接口概述 本接口文檔用于描述圖書管理系統中的一系列 Restful 接口&#xff0c;涵蓋圖書的查詢、添加、更新與刪除操作&#xff0c;以及用戶的登錄注冊等功能&#xff0c;方便客戶端與服務器之間進行數據交互。 二、接口基礎信息 接口地址&#xff1a;https://book-manag…

杰發科技AC7801——PWM獲取固定脈沖個數

測試通道6 在初始化時候打開通道中斷 void PWM1_GenerateFrequency(void) {PWM_CombineChConfig combineChConfig[1]; //組合模式相關結構體PWM_IndependentChConfig independentChConfig[2];//獨立模式相關結構體PWM_ModulationConfigType pwmConfig; //PWM模式相關結構體PWM…

RL電路的響應

學完RC電路的響應&#xff0c;又過了一段時間了&#xff0c;想必很多人都忘了RC電路響應的一些內容。我們這次學習RL電路的響應&#xff0c;以此同時&#xff0c;其實也是帶大家一起回憶一些之前所學的RC電路的響應的一些知識點。所以&#xff0c;這次的學習&#xff0c;其實也…

鴻蒙Flutter實戰:21-混合開發詳解-1-概述

引言 在前面的系列文章中&#xff0c;我們從搭建開發環境開始&#xff0c;講到如何使用、集成第三方插件&#xff0c;如何將現有項目進行鴻蒙化改造&#xff0c;以及上架審核等內容&#xff1b;還以高德地圖的 HarmonyOS SDK 的使用為例&#xff0c; 講解了如何將高德地圖集成…