我自己的原文哦~? ? ?https://blog.51cto.com/whaosoft/13905076
#Executor-Workers架構
圖解Vllm V1系列2
本文詳細介紹了vllm v1的Executor-Workers架構,包括Executor的四種類型(mp、ray、uni、external_launcher)及其適用場景,以及通過圖解展示了Executor與Workers之間的數據傳輸機制和整體架構,幫助讀者理解vllm v1在分布式推理中的核心設計和運作方式。
前文:圖解Vllm V1系列1:整體流程
在前文中,我們討論了 vllm v1 在 offline batching / online serving 這兩種場景下的整體運作流程,以offline batching為例:
整體上來看:
vllm v1將請求的pre-process和輸出結果的post-process與實際的推理過程拆分在2個不同的進程中(process0, process1)。
Client負責請求的pre-process和輸出結果的post-process,EngineCore負責實際的推理過程,不同進程間使用ZMQ來通信數據。
對于offline batching和online serving來說,它們會選取不同類型的Client進行運作,但是它們的EngineCore部分運作基本是一致的,如上圖所示。
通過這樣的進程拆分,在更好實現cpu和gpu運作的overlap的同時,也將各種模型復雜的前置和后置處理模塊化,統一交給processor和output_processor進行管理。
本文我們來關注上圖中的Executor部分,也就是管控模型分布式推理的核心部分,我們關注的是它的整體架構和初始化過程,而它實際執行推理的細節,我們留到后續文章細說。
一、Executor的類型
在vllm中,Executor一共有4種類型,由配置參數--distributed-executor-backend決定,相關的代碼和文檔參見:
代碼:
- 決定executor的類型:https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/config.py#L1465
- 根據executor的類型,import具體的executor:https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/v1/executor/abstract.py#L25
文檔:
- ??https://docs.vllm.ai/en/stable/serving/engine_args.html??
- ??https://docs.vllm.ai/en/stable/serving/distributed_serving.html#distributed-inference-and-serving??
對于--distributed-executor-backend,默認情況下為None,你當然也可以手動指定。在默認情況下,vllm會根據你的分布式配置(world_size)和你所使用的平臺特征(cuda、neuron、是否安裝/初始化/配置了ray等)來自動決定--distributed-executor-backend的取值。
我們來簡單介紹下這4種類型的Executor。
(1)mp:MultiprocExecutor
- 適用場景:單機多卡。當單機的卡數滿足分布式配置,且沒有正在運行的ray pg時,默認使用mp
- 在mp的情況下,Executor成為一個主進程,其下的若干個workers構成了它的子進程們
(2)ray:RayDistributedExecutor
- 適用場景:多機多卡
- 在ray的情況下,Executor成為一個ray driver process,其下管控著若干worker process
(3)uni:UniProcExecutor
- 適用場景:單卡或 Neuron環境
(4)external_launcher:ExecutorWithExternalLauncher
- 適用場景:想要用自定義的外部工具(例如Slurm)來做分布式管理
注意:以上的“適用場景”描述的是一般情況,更具體的細節,請參見“決定executor類型”的相關代碼。
在本文中,我們將以mp: MultiProcExecutor進行講解,并假設這里的分布式配置僅用了tp,其余的細節留給用戶自行閱讀。
二、Executor -> Workers
2.1 整體架構-官網版
我們先來看下官方給出的Executor-Workers架構圖。
??https://blog.vllm.ai/2025/01/27/v1-alpha-release.html??
上圖右側刻畫了V1的架構:
- Scheduler和Executor都位于EngineCoreProc所在的進程上。如本文第一章offline batching的流程圖所示,Scheduler決定單次調度步驟中要送去推理的請求,并將這些請求發送給Executor。
- 一個Executor下管理著若干workers,每個workers位于獨立的進程上,可以理解成一個workers占據著一張卡
- Executor負責把請求broadcast到各個workers上
- 各個workers接收到請求,負責執行實際的推理過程,并將推理結果返回給Executor。
相比于V0,V1這種架構設計的優勢在于:在V0中,worker0既要負責調度和數據分發、又要負責實際的推理計算。如此一來,各個workers間存在負載不均的問題,而worker0將成為性能的瓶頸。而V1通過拆分【調度】和【計算】過程解決了這個問題。
2.2 整體架構-細節版
現在我們已經通過vllm官方的示例圖,初步了解了V1下Executor-Workers的架構,現在我們擴展這張架構圖,來看更多的細節,為了畫圖簡明,這里我們只展示了其中1個worker,沒有畫出全部workers:
上圖展示的是使用MultiprocExecutor下的架構,如前文所說,該類型的Executor常被用于單機多卡的推理場景,我們按照從上到下,從左到右的順序來解讀上圖。
1、MultiprocExecutor和Scheduler都位于EngineCoreProc所在的進程中。Scheduler負責決定單次調度步驟中,要送去推理的reqs,并將這些reqs傳遞給MultiprocExecutor。
2、在MultiprocExecutor上,我們將創建一個rpc_broadcast_mq隊列:
- 該隊列存儲著Executor要broadcast給各個workers的【小數據(<=10MB)】,而【大數據(>10MB)】則不會進入此隊列,而是通過zmq socket進行傳輸
- 每條數據可以被粗糙理解成是(method, data)的形式,data = 數據本身,method=你期望worker上調用什么樣的方法來處理這條數據。
- 針對這個隊列、以及大小數據的傳輸細節,我們將在本文第三部分詳細介紹。
3、在MultiProcExecutor上,通過make_worker_process創建子進程:
- 每個進程占據一張卡,每個進程上運行著一個worker實例
- 在創建每個子進程時,我們會將rpc_broadcast_mq_handler,也就是輸入隊列的句柄也傳遞給子進程,這里你可以粗糙將“handler(句柄)”理解成是一個“地址”,有了這個地址,每個子進程才知道要去哪里找到并【連接】這個隊列,以此讀到隊列中的數據。相關細節我們同樣在后文給出。
4、每個Worker實例又由以下幾部分組成:
- WorkerWrapper,每個Worker實例都有屬于自己的WorkerWrapper。你可以將它形象理解成是一個worker的manager,它負責管控一個worker的生命周期(創建->銷毀)、所占資源、擴展功能(例如在rlhf場景下的一些功能)等等。
- Worker,真正的Worker實例,它上面維護著兩個重要隊列:
- rpc_broadcast_mq:正如3中所說,單個worker通過rpc_broadcast_mq_handler這個句柄,連接上了Executor的rpc_broadcast_mq隊列,這樣它就能從這個隊列中讀取(method, data)數據。注意,這里說的是【連接】而不是創建,為了強調這一點,圖中單worker上的該隊列用虛線表示。
- worker_response_mq:單個worker【創建】的、用于存放這個worker上推理輸出結果的隊列。同樣也會產出worker_response_mq_handler這個句柄。后續這個句柄將通過zmq socket傳送給Executor,讓Executor可以連接上這個隊列,這樣Executor就可以獲得每個worker的輸出結果。
- ModelRunner,一個worker實例下維護著一個ModelRunner實例,這個實例上維護著模型權重分片(model weights sharding)、這塊卡上的kv_caches、attn_backend等一系列的具體模型信息,它將最終負責模型權重的加載(load_model),并執行實際的推理過程。
5、連接Executor和Worker的ZMQ sockets:
- Executor和Worker分屬不同的進程,這里依然采用ZMQ sockets做進程間的通信。
- 這里其實創建了多個不同socket(為了表達簡便,我統一畫成ZMQ sockets),每個socket會用于不同內容的通信,例如:
- ready_socket:worker進程向Executor發送ready信號 + worker_broadcast_mq_handler
- local_socket:如前文所說,**除了使用上述的2個隊列做Executor->Worker間的輸入輸出通信外,我們還會直接使用local_socket做輸入輸出通信。前者用于單機內快速通信較小的數據(<=10MB),后者用于通信大數據(>10MB)**。我們會在后文細說這一點。
- 等等
**6、worker_busy_loop()**:
在worker上啟動busy loop,持續監聽Executor發送的數據、做推理、并將推理結果持續返回給Executor。這樣一來,這個worker就無限運轉起來了,除非收到用戶信號,顯式終止掉這個worker,否則這個busy loop不會停止。
到此為止,我們簡單總結一下在Executor->Workers的初始化環節都做了什么事:
- 首先,按照上圖所示,創建了Executor->Workers架構,特別注意上述2個輸入輸出隊列的初始化和連接。
- 對于每個Worker,我們通過init_device(),將它綁到指定的卡上,并對它做分布式環境初始化(即制定它的分布式通信group)
- 對于每個worker,我們通過load_model(),當ModelRunner實際去加載這個worker所要的模型分片
- 在每個worker上啟動run_busy_loop(),讓worker持續不斷地運轉起來。 更多的細節,請大家自行閱讀源碼。
接下來,我們著重來討論這rpc_broadcast_mq和worker_response_mq這兩個輸入輸出隊列。
三、Executor與Worker間的數據傳輸機制
我們先快速回顧一下上文的內容:
(1)在我們的例子中,Executor的具體類型是MultiprocExecutor,它一般適用于單機多卡推理。
(2)Executor和Worker分屬不同的進程,Executor需要把輸入數據broadcast到Worker上,Worker需要把推理的輸出結果返回給Executor。
(3)對于小數據(<=10MB),vllm使用rpc_broadcast_mq和worker_response_mq來做數據傳輸,這兩個隊列的本質是ShmRingBuffer(環形共享緩存),其中Shm即我們熟知的shared_memory,而ring是使用環形的方式往shm中讀寫數據(看不懂也沒關系,我們馬上來說細節)。
(4)對于大數據(>10MB),vllm使用zmq socket來做數據傳輸。
為什么要設計2種不同的進程間通信機制,來分別處理【小數據】和【大數據】呢?這里簡單說幾個我能想到的原因:
(1)首先,通過shm的方式讀寫數據時,不同的進程都從同一塊共享內存(shm)上直接讀取,這樣數據不需要從一個進程的地址空間復制到另一個進程的的地址空間,也就是可以實現數據的“零拷貝訪問”
(2)其次,通過shm的方式讀寫數據時,可以避免網絡協議棧和數據重復寫入的開銷,可以實現更高效、更快的數據訪問。
(3)那么,既然shm這么好,為什么只讓【小數據】使用它,而讓【大數據】走zmq socket呢?這是因為shm是一塊固定的內存大小,一旦預分配好,就不能被改變了。在實際使用場景中,可能需要傳輸的數據量本身就不大,只是會偶發出現一些【大數據】傳輸的情況,因此我們沒必要預留更大的shm空間,來應對這些只是偶發情況,這樣會造成內存的浪費。所以我們額外使用zmq socket來處理這些偶發情況。
3.1 ShmRingBuffer(共享環形緩存)
我們先來看小數據傳輸的實現機制,相關代碼參見:
??https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/distributed/device_communicators/shm_broadcast.py#L44??
假設我們現在只使用tp,即一個Executor下有若干tp workers,那么:
- Scheduler生成單次調度的結果,將這些結果傳遞給Executor,我們稱單次調度的結果為一個chunk
- Executor拿到單次調度的結果,寫入rpc_broadcast_mq(本質是ShmRingBuffer)中
- 這些tp workers都需要從rpc_broadcast_mq讀取這個chunk(每個tp worker的輸入是相同的)
- 各個tp workers執行推理,并將推理結果寫入各自維護的worker_broadcast_mq(本質是ShmRingBuffer)中。
- Scheduler繼續生成單次調度結果(chunk),重復以上步驟。
- 不難發現,對于一個chunk,我們總有1個writer,和1個或若干個readers,例如:
- rpc_broadcast_mq的chunk中,writer = Executor,readers = tp workers
- worker_broadcast_mq的chunk中,writer = 某個tp worker,reader = Executor
現在,讓我們將ShmRingBuffer想象成是一個存儲站:
- 這個存儲站中有max_chunks個柜子,每個柜子用于存儲一塊數據(chunk),max_chunk默認值為10
- 每個柜子的最大數據存儲量為max_chunk_bytes,該值當前默認為10MB
- 每個柜子上有一個面板(metadata),這個面板上有1 + n_reader個指示燈。其中1這個指示燈代表written_flag,即用于指示writer是否把chunk塞進了柜子(寫入完畢),n_reader個指示燈代表reader_flags,分別表示這些readers是否已經將這個chunk讀取完畢。
- 由此可知,對于一個柜子,只有當writer寫入完畢后,readers才可以去讀。只有當所有readers都讀取完畢后,這個柜子里的chunk才可以被“廢棄”,也就是這個柜子才可以重新回到“可寫入”的狀態,讓writer寫入新數據。
- Scheduler在做一輪又一輪的調度,產出一個又一個的chunk,那么這些chunk就按照順序,依次裝入這些柜子中,當這10個柜子的數據都被輪番用過以后,下一次再來新chunk前,就從0號柜開始復用起(當然要按照上條所說的,檢查該柜子是否達到可復用狀態),這種環形使用的方式,稱之為“ring”。
有了以上這個形象的理解,現在我們再回過頭來看vllm代碼中的這部分注釋,就不難讀懂了,而關于代碼的更多細節,請大家自行閱讀源碼:
3.2 zmq socket
正如前文所說,在Executor和Worker間做大數據(>10MB)的傳輸時,可以使用zmq socket,這塊就是傳統的zmq socket構建流程了,沒有太多可說的。這里我們想結合worker.worker_busy_loop()(也就是一個worker持續讀取輸入、進行推理、寫入輸出)的過程,來具體看一下shm和zmq socket是如何配合運作的。
worker_busy_loop()入口:
- ??https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/v1/executor/multiproc_executor.py#L362??
從代碼中我們發現一件有趣的事:這里好像只從shm上讀取數據,并沒有通過zmq socket呀!不要緊,我們現在深入rpc_broadcast_mq.dequeue()中一探究竟。
rpc_broadcast_mq.dequeue():
- ??https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/distributed/device_communicators/shm_broadcast.py#L462??
整體上說:
- 當一個chunk過來時,我們會先檢查它的大小:
- 如果<=10MB,則裝入shm對應的柜子中
- 如果 > 10MB,則先在shm對應的柜子中記上一筆(buf[0]==1),然后再通過zmq sokcet去send這份數據
- 以上過程不在所提供的代碼截圖中,需要大家自行找相關代碼閱讀
- 接著,我們會根據這個柜子的標志buf[0]是否為1,來檢測對應的chunk是裝在柜子里,還是通過zmq socket發送了。如果是前者,那么直接從shm讀;如果是后者,那么就通過zmq socket做receive。
- 最后,如果當前chunk是大數據,雖然它不會裝在對應的柜子里,但我們也會認為這個柜子已經被使用過。這樣后一個chunk來的時候,它不會使用當前的柜子,而是使用下一個可用的柜子。
你可能會發現,在上述dequeue的代碼中,存在2種zmq socket:local_socket和remote_socket,前者用于單機內的通信(writer和readers在一臺node內),后者用于多機間的通信(writer和readers不在一臺內)。由于我們當前都是以MultiprocExecutor這種單機場景為例的,所以我們提到的zmq socket都是指local_socket。
好,關于Executor->Worker架構的介紹就到這里了,大家可以配合本文,自行閱讀源碼,獲取更多細節。在這個系列后面的文章中,我們還會看到更多Executor-> Worker配合運行的更多例子。
#T2I-R1
文生圖也能這樣玩?T2I-R1:把R1的推理范式應用到文生圖任務!
文生圖進入R1時刻:港中文MMLab發布T2I-R1。
論文:https://arxiv.org/pdf/2505.00703
代碼:https://github.com/CaraJ7/T2I-R1
最近的大語言模型(LLMs)如OpenAI o1和DeepSeek-R1,已經在數學和編程等領域展示了相當強的推理能力。通過強化學習(RL),這些模型在提供答案之前使用全面的思維鏈(CoT)逐步分析問題,顯著提高了輸出準確性。最近也有工作將這種形式拓展到圖片理解的多模態大模型中(LMMs)中。然而,這種CoT推理策略如何應用于自回歸的圖片生成領域仍然處于探索階段,我們之前的工作Image Generation with CoT(https://github.com/ZiyuGuo99/Image-Generation-CoT)對這一領域有過首次初步的嘗試。
與圖片理解不同,圖片生成任務需要跨模態的文本與圖片的對齊以及細粒度的視覺細節的生成。為此,我們提出了適用于圖片生成的兩個不同層次的CoT推理:
Semantic-CoT
Semantic-CoT 是對于要生成的圖像的文本推理,在圖像生成之前進行。
負責設計圖像的全局結構,例如每個對象的外觀和位置。
優化Semantic-CoT可以在圖片Token的生成之前顯式地對于Prompt進行規劃和推理,使生成更容易。
Token-CoT
- Token-CoT是圖片Token的逐塊的生成過程。這個過程可以被視為一種CoT形式,因為它同樣是在離散空間中基于所有先前的Token輸出后續的Token,與文本CoT類似。
- Token-CoT更專注于底層的細節,比如像素的生成和維持相鄰Patch之間的視覺連貫性。
- 優化Token-CoT可以提高生成圖片的質量以及Prompt與生成圖片之間的對齊。
然而,盡管認識到這兩個層次的CoT,一個關鍵問題仍然存在:我們怎么能協調與融合它們??當前主流的自回歸圖片生成模型如VAR完全基于生成目標進行訓練,缺乏Semantic-CoT推理所需的顯式文本理解。雖然引入一個專門用于提示解釋的獨立模型(例如LLM)在技術上是可行的,但這種方法會顯著增加計算成本、復雜性和部署的困難。最近,出現了一種將視覺理解和生成合并到單一模型中的趨勢。在LMMs的基礎上,這些統一LMMs(ULMs)不僅可以理解視覺輸入,還可以從文本提示生成圖像。然而,它們的兩種能力仍然是解耦的,通常在兩個獨立階段進行預訓練,沒有明確證據表明理解能力可以使生成受益。鑒于這些潛力和問題,我們從一個ULM(Janus-Pro)開始,增強它以將Semantic-CoT以及Token-CoT統一到一個框架中用于文本生成圖像:
我們提出了BiCoT-GRPO,一種使用強化學習的方法來聯合優化ULM的兩個層次的CoT:
我們首先指示ULM基于Image Prompt來想象和規劃圖像來獲得Semantic-CoT。然后,我們將Image Prompt和Semantic-CoT重新輸入ULM來生成圖片以獲得Token-CoT。我們對于一個Image Prompt生成多組Semantic-CoT和Token-CoT,對于得到的圖像計算組內的相對獎勵,從而使用GRPO的方法來在一個訓練迭代內,同時優化兩個層次的CoT。
與圖片的理解任務不同,理解任務有明確定義的獎勵規則,圖像生成中不存在這樣的標準化的規則。為此,我們提出使用多個不同的視覺專家模型的集成來作為獎勵模型。這種獎勵設計有兩個關鍵的目的:
- 它從多個維度評估生成的圖像以確保可靠的質量評估
- 作為一種正則化方法來防止ULM過擬合到某個單一的獎勵模型
根據我們提出的方法,我們獲得了T2I-R1,這是第一個基于強化學習的推理增強的文生圖模型。根據T2I-R1生成的圖片,我們發現我們的方法使模型能夠通過推理Image Prompt背后的真實意圖來生成更符合人類期望的結果,并在處理不尋常場景時展現出增強的魯棒性。
同時,定量的實驗結果也表明了我們方法的有效性。T2I-R1在T2I-CompBench和WISE的Benchmark上分別比baseline模型提高了13%和19%的性能,在多個子任務上甚至超越了之前最先進的模型FLUX.1。
#One Step Diffusion Via ShortCut Models論文解讀
AIGC新手,內容理解如有不對請多多指正。
原文:One Step Diffusion via Shortcut Models
github:GitHub - kvfrans/shortcut-models?
摘要
為了緩解目前diffusion架構+flow matching生成速度慢且訓練階段復雜的問題提出了一個叫shortcut model的模型,整個訓練過程采用單一網絡、單一訓練階段。condition包括當前噪聲強度,還取決于stepsize(為了在去噪過程中直接跳過),這個方法在作者的實驗中比蒸餾的方法好,并且在推理的時候可以改變step budgets。
之前SD3采直流匹配訓練,但是仍然需要28步,這篇論文在首頁放了一張效果圖,效果看起來很驚艷。
初步效果圖
整個網絡設置是端到端且只需要一次訓練就可以完成一個one-step模型,不像之前的關于蒸餾的工作(參考Progressive Distillation for Fast Sampling of Diffusion Models、http://arxiv.org/abs/2211.12039、Relational Diffusion Distillation for Efficient Image Generation,這三個工作都基于教師-學生來蒸餾,通過多階段的訓練來逐漸折半DDIM的采樣步數)。?
前置知識?
Flow-matching
流匹配的內容在網絡上的很多博客都有講解,這部分就簡單帶過一下。
流匹配實際上就是通過學習將噪聲轉化為數據的常微分方程(ODE)來解決生成模型問題,在直流匹配中,整個模型就把真實圖像的概率分布和噪聲的概率分布之間的路徑當作一條直線進行傳輸,在給定 x0 和 x1 的情況下,速度 vt 是完全確定的。但是,如果只給定 xt,就會有多個可信的配對(x0、x1),因此速度會有不同的值,這就使得 vt 成為一個隨機變量。Flow-matching模型就是用來估計預期值在xt條件下的vt是多少,然后vt是xt處所有可信速度的平均值。最后可以通過對隨機采樣的噪聲 x0 和數據 x1 對的經驗速度進行回歸來優化流量模型。
例子就是直流匹配
這個速度vt就是直接用xt對t求導,就得到了x1-x0,然后整個模型優化就靠下面這個損失函數。
直流匹配的損失函數
實際上就是用回歸的損失去盡量讓預測的速度能夠符合直流匹配定義的速度。
然后去噪過程就是從流量模型中采樣,首先從正態分布中采樣一個噪聲點 x0。然后根據學習到流模型從x0到x1迭代更新該點,整個過程可通過在較小的離散時間間隔內進行歐拉采樣來近似實現,因為是直線傳輸。?
為什么提出ShortCut models?
作者通過一個實驗去研究了完美訓練的ODE在步數減少之后的缺陷,具體來說就是步長有限的情況下,還是很難做到能夠將噪聲分布確定性地映射到我們需要的數據分布。
作者做的實驗,這個圖示還是很清晰的,僅給定 xt,vt雖然是根據直流的路線去學習的,但是學習得到的vt是存在固有的不確定性的,vt是指向數據點平均值的,直到縮減到一步,一步的話所有的vt幾乎是指向一個點,并不能對應原始數據分布,多樣性完全崩塌了
流量匹配學習預測從 xt 到數據的平均方向,因此跟隨預測的步長越大,將跳轉到多個數據點的平均值。在 t=0 時,模型接收純噪聲輸入,并且(x0,x1)在訓練過程中隨機配對,因此 t=0 時的預測速度指向數據集平均值。因此,即使在流量匹配目標的最優狀態下,對于任何多模式數據分布,一步生成都會失敗。這段是作者原話,感覺說得蠻清晰,就不加個人理解了。?
ShortCut Models
insight:可以訓練一個支持不同sampling budgets的一個模型,以時間步長t和步長d為作為條件。那么就順勢提出了下面這個公式。
shortcut models的核心公式
這個s就是輸入Xt,t,d之后的出來的捷徑,得到這個路徑之后就可以直接讓Xt從這個s出發跳步得到Xt+d,OK,那么整個model的訓練目標就很明確了,就是通過shortcut model去學習這個s,條件是Xt,t,d。其實整個公式就是直流匹配的跳步模式,當d≈0的時候,就是flow-matching的訓練模式,s就直接退化成了v。
那么要學的東西出來了,用什么去約束呢?第一種方法當然就是用小步長去接近flow-matching的forward過程,但是這樣做的話訓練成本也還是很高,尤其是對直接端到端訓練來說,并且小步長實際上對flow-matching的改進不是很大。第二種就是本文用的方法,直接用shortcut model自己的性質,就是一個s步等于兩個s/2步。也就是以下公式。
shortcut等價模型
初步看這個公式可能會疑惑為什么會除以2,請注意,上一個公式在s求出來之后還需要乘d,所以s其實不是最終路程,最終的路程是s*d,而整個式子左邊的步長為2d,路程相同的情況下,兩邊同時除以2d才得出來右邊等式的1/2系數。
d>0的時候就直接用這個公式,d=0就直接用流匹配去訓練。整體流程如下
shortcut對flow-matching的優化
其實就是將flow-matching當作連續的一條線,shortcut直接輸入了步長,然后網絡獲得步長之后直接去獲得應道到路徑上的哪個路徑點,就是上面圖左邊曲線的黃色部分。整個訓練過程把flow-matching綜合起來構成了下面的損失函數:
總體損失函數
上述目標學習的是從噪聲到數據的映射,在任何步長序列下查詢時都是一致的,包括直接在單步中查詢。目標中的流量匹配部分將捷徑模型建立在小步長的基礎上,以匹配經驗速度樣本。這就確保了捷徑模型在多步長查詢時具有基礎生成能力,這與等效的流量匹配模型完全相同。第二部分的話,通過串聯兩個較小shortcut的序列,為較大步長構建適當的目標。這樣,生成能力就從多步到少步再到一步。綜合目標可通過單一模型和單一端到端訓練運行進行聯合訓練。?
訓練細節
名詞定義:經驗目標就是對應損失函數第一項需要的目標,一致性目標就是對應損失函數第二項所需要的目標
當 d → 0 時,s等同于vt。因此可以使用flow-matching的損失來訓練d=0時的捷徑模型,即隨機抽樣 (x0, x1) 對并擬合vt的期望值。這個項可以看作是小步s的基礎,以匹配數據去噪ODE,然后對t ~ U (0, 1) 進行均勻采樣。為了限制復合誤差,并且限制引導路徑的總長度。因此,我們選擇了一種二元遞歸模型,即用兩條捷徑來構建一條兩倍大的捷徑。
然后確定一個步數 M 來表示逼近 ODE 的最小時間單位;在實驗中使用了 128 步。根據 d∈ (1/128, 1/64 ... 1/2, 1),這將產生 log2(128) + 1 = 8 種可能的捷徑長度。在每個訓練步驟中,我們對 xt、t 和隨機 d < 1 進行采樣,然后使用shortcut連續進行兩步。然后將這兩步的并集作為目標,并且在2d處訓練模型。
將 1-k 個經驗目標與k個一致性目標的比例結合起來,構建一個訓練批次。k=1/4是合理的。其實這部分也很好理解,因為這個端到端模型實際上就是需要先訓練一個flow-matching較好的模型,然后第二項只是在flow-matching的基礎上進行優化,如果flow-matching訓練得不好,后一項自然訓練不好,因為s_target是需要從flow-matching模型中采樣的,后一項只能在d=0訓練的基礎模型上去擬合這個模型,本質上shortcut還是一個教師-學生的思路,但是不同于之前教師和學生都是模型,shortcut將教師-學生拆分為兩個損失函數去訓練同一個模型,從而實現了端到端。
CFG設定:評估 d = 0 時的捷徑模型時使用 CFG,而在其他情況下則放棄 CFG。CFG 在捷徑模型中的一個局限性是,必須在訓練前指定 CFG 比例。
EMA:用EMA去從d=0的模型上生成d=1的一致性目標,本質上就是平滑一下誤差。
其他就是一些網絡設置,這里就不一一闡述了,有興趣可以自己查看一下原論文。?
實驗結果
FID-50K分數評估
FID-50K分數評估
可以看到在端到端的訓練框架中,shortcut models的FID-50k是SOTA,但是相對于PD的蒸餾方式來說,在一步蒸餾中效果還是有待提高。
對ShortCut提出需解決問題的驗證
FID下降趨勢
在文章開投我們就提到了這篇論文的insight,他是為了緩解flow-matching在步數極低的情況下的崩塌而提出了,這個實驗也證明了這一點,在1步模型中,Shortcut的表現完全暴打直接用flow-matching訓練的diffusion(但實際上這個對比沒有什么特別大意義,flow-matching確實就不適合一步訓練,這個問題SD3當時也提出來了)。
作者在后續甚至驗證了shortcut在其他領域的魯棒性,確實是一項非常完善的工作,有其他領域的讀者可以去看下原文。?
總結
shortcut models確實提供了一個直接在flow-matching上蒸餾的好辦法,但是訓練過程中的參數設定個人感覺還是靠多種嘗試,例如K的選取或許會較大程度影響shortcut models的發揮。反觀多階段的訓練方法,至少多階段確保了一個訓練得較為完善的教師模型能夠作為參考,而shortcut models如果參數設置不對,flow-mathcing的基礎模型可能會不夠完善,進而倒是損失第二項會出現較大程度的累計誤差。
其次,作者本人也提到了,雖然shortcut能夠抑制flow-matching直接在1步訓練上的崩潰,但是在步數太低的時候仍然和多步采樣存在較大的性能差距(不過1步能做到這個程度已經很好了。。。)。
總的來說,這篇論文的工作很完善,也是一個比較新穎的減少采樣步數的方案,但是本質上也是蒸餾的一種,并且端到端的訓練相比于多階段的訓練確實更依靠經驗,一不注意就會訓練失敗。