GWM: Towards Scalable Gaussian World Models for Robotic Manipulation
- 文章概括
- 摘要
- 1. 引言
- 2. 相關工作
- 3. 高斯世界模型(Gaussian World Model)
- 3.1. 世界狀態編碼(World State Encoding)
- 3.2. 基于擴散的動態建模(Diffusion-based Dynamics Modeling)
- 4. 實驗(Experiments)
- 4.1. 動作條件場景預測(Action-conditioned Scene Prediction)
- 4.2. 基于GWM的模仿學習(GWM-based Imitation Learning)
- 4.3. 基于GWM的強化學習(GWM-based Reinforcement Learning)
- 4.4. 真實世界部署(Real-world Deployment)
- 4.5. 消融分析(Ablation Analysis)
- 5. 結論(Conclusion)
- GWM: Towards Scalable Gaussian World Models for Robotic Manipulation (補充材料)
- A. 數據集與基準(Datasets and Benchmarks)
- B. 實現細節(Implementation Details)
- B.1. EDM 預處理(EDM Preconditioning)
- B.2. 架構設計(Architectural Design)
- B.3. 超參數(Hyper-parameters)
文章概括
引用:
@article{lu2025gwm,title={GWM: Towards Scalable Gaussian World Models for Robotic Manipulation},author={Lu, Guanxing and Jia, Baoxiong and Li, Puhao and Chen, Yixin and Wang, Ziwei and Tang, Yansong and Huang, Siyuan},journal={arXiv preprint arXiv:2508.17600},year={2025}
}
Lu, G., Jia, B., Li, P., Chen, Y., Wang, Z., Tang, Y. and Huang, S., 2025. GWM: Towards Scalable Gaussian World Models for Robotic Manipulation. arXiv preprint arXiv:2508.17600.
主頁: https://gaussian-world-model.github.io
原文: https://arxiv.org/abs/2508.17600
代碼、數據和視頻:
系列文章:
請在 《《《文章》》》 專欄中查找
宇宙聲明!
引用解析部分屬于自我理解補充,如有錯誤可以評論討論然后改正!
摘要
在學習得到的世界模型中訓練機器人策略正成為一種趨勢,這是由于真實世界交互的低效性。現有的基于圖像的世界模型和策略雖然已經展示了早期的成功,但缺乏健壯的幾何信息,而這種信息需要對三維世界保持一致的空間和物理理解,即便在互聯網規模的視頻源上進行過預訓練,也依然不足。為此,我們提出了一種新的世界模型分支,稱為高斯世界模型(Gaussian World Model, GWM),用于機器人操作,它通過推斷在機器人動作作用下高斯基元(Gaussian primitives)的傳播來重建未來狀態。其核心是一個與三維變分自編碼器(3D variational autoencoder)結合的潛在擴散Transformer(latent Diffusion Transformer, DiT),能夠利用高斯點繪(Gaussian Splatting)實現精細的場景級未來狀態重建。GWM不僅可以通過自監督的未來預測訓練來增強模仿學習智能體的視覺表征,還可以作為一種神經模擬器支持基于模型的強化學習。無論在仿真還是現實實驗中,結果都表明GWM能夠在多樣化的機器人動作條件下精確預測未來場景,并且可以進一步用于訓練出顯著優于當前最先進方法的策略,展現了三維世界模型在初始數據擴展方面的潛力。
圖1. 高斯世界模型(Gaussian World Model, GWM)是一種新穎的世界模型分支,它基于三維高斯點繪(3D Gaussian Splatting)表示來預測動態的未來狀態,并支持機器人操作。它促進了基于動作條件的三維視頻預測,提升了模仿學習中的視覺表示學習能力,并作為一種穩健的神經模擬器服務于基于模型的強化學習。
1. 引言
人類能夠從有限的感官輸入中構建預測性世界模型,使其能夠預見未來的結果并適應新的情境 [12, 18]。受到這一能力的啟發,世界模型學習推動了智能體的重大進展,使其在自動駕駛 [14, 26, 27, 80, 98] 和游戲 [1, 18–22, 66, 89] 等領域表現出色。隨著智能體日益與物理世界互動,推進面向機器人操作的世界模型學習成為一項重要的研究方向,因為它理想情況下能夠使機器人具備關于交互進行推理、預測物理動力學、并適應多樣化未知環境的能力。
這自然引出了以下問題:如何有效地表示、構建并利用世界模型來增強機器人操作?這樣的需求對現有的表示方法和模型提出了重大挑戰。
-
三維表示的必要性
高容量的架構 [25, 77] 和互聯網規模的預訓練,使基于視頻的生成模型成為捕捉世界動態信息的強大工具,這極大地提升了策略學習 [82, 87]。然而,它們依賴圖像輸入,使其容易受到未見過的視覺變化(例如光照、相機姿態、紋理等)[40] 的影響,因為它們缺乏三維幾何和空間理解。盡管RGB-D和多視角 [16, 17] 方案試圖緩解這一差距,但在一致的三維空間中隱式對齊圖像補丁特征仍然具有挑戰性 [62, 100],這使得穩健性問題依然沒有解決。這凸顯了需要一種能夠將精細視覺細節與三維空間信息相結合的表示方式,以提升面向機器人操作的世界建模。 -
效率與可擴展性
為了從二維圖像中識別出一種既能保留三維幾何結構又能保留精細視覺細節的三維表示,多視角三維重建方法(例如神經輻射場 NeRF [57] 和三維高斯點繪 3D Gaussian Splatting, 3D-GS [35])提供了自然的解決方案。其中,3D-GS 尤其具有吸引力,因為它對三維場景進行了顯式的逐高斯建模,將點云等高效的三維表示與高保真渲染相結合。 然而,由于這些方法主要依賴于離線的逐場景重建,它們的計算需求在應用到機器人操作時帶來了重大挑戰 [49, 91],尤其是在基于模型的強化學習(Model-based Reinforcement Learning, MBRL)中,從而限制了它們的可擴展性。
為此,我們提出了高斯世界模型(Gaussian World Model, GWM),這是一種新穎的三維世界模型,它將3D-GS與高容量生成模型結合,用于機器人操作。具體來說,我們的方法結合了前饋式3D-GS重建的最新進展與擴散Transformer(Diffusion Transformers, DiTs),使得在當前觀測和機器人動作條件下,通過高斯渲染實現精細的未來場景重建。 為了實現實時訓練和推理,我們設計了一種三維高斯變分自編碼器(3D Gaussian Variational Autoencoder, VAE),用于從三維高斯中提取潛在表示,使基于擴散的世界模型能夠在緊湊的潛在空間中高效運行。通過這種新穎的設計,我們證明了GWM能夠增強視覺表示學習,提升其作為模仿學習視覺編碼器的作用,同時還可以作為一種穩健的神經模擬器服務于基于模型的強化學習(RL)。
為了全面評估GWM,我們在動作條件視頻預測、模仿學習和基于模型的強化學習設置下進行了廣泛的實驗,涵蓋了跨越三個領域的31個多樣化機器人任務。針對現實場景的評估,我們引入了一個包含20種變體的Franka PnP任務套件,涵蓋了域內和域外的設置。在消融實驗中,我們同時評估了感知指標和成功率,以驗證每個組成模塊的有效性。GWM持續優于之前的基線方法,包括最先進的基于圖像的世界模型,展現了顯著的優勢,并突出了其數據擴展潛力。
總而言之,我們的主要貢獻有三點:
- 我們提出了GWM,這是一種新穎的三維世界模型,由高斯擴散Transformer和高斯VAE實現高效的動態建模。GWM能夠以可擴展的端到端方式學習預測準確的未來狀態和動力學,而無需人工干預。
- GWM可以輕松集成到離線模仿學習和在線強化學習中,并具備卓越的效率,展現出在基于學習的機器人操作中令人印象深刻的擴展潛力。
- 我們通過在兩個具有挑戰性的仿真環境中的大量實驗驗證了GWM的有效性,其性能較之前的最先進基線提高了16.25%的巨大幅度。此外,我們在現實場景中驗證了其實用性,在20次試驗中,GWM將典型的擴散策略提升了30%。
2. 相關工作
世界模型(World Models)
世界模型捕捉場景動態,并通過基于當前觀測和動作預測未來狀態,從而實現高效學習。它們已在自動駕駛 [14, 26, 27, 80, 98, 102]、游戲智能體 [1, 18–22, 66, 89] 和機器人操作 [23, 67, 83] 中得到了廣泛研究。早期的工作 [18–23, 56, 65–67, 89, 96] 學習了一種用于未來預測的潛在空間,并在仿真和真實環境中都取得了強有力的結果 [83]。然而,雖然潛在表示簡化了建模,但其難以捕捉世界的精細細節。近期在擴散模型 [24, 71, 72] 和Transformer [64, 77] 的進展推動了世界建模向直接像素空間建模 [1, 50, 51, 87] 轉變,從而能夠捕捉精細細節并從互聯網視頻中實現大規模學習。然而,基于圖像的模型往往缺乏物理常識 [4],因此限制了它們在機器人操作中的適用性。
高斯點繪(Gaussian Splatting)
3D-GS [35] 使用三維高斯來表示場景,并通過可微分的投射高效地映射到二維平面。與隱式表示(如NeRF [57])相比,它具有更高的效率,因此受益于一些應用,例如侵入性手術 [46]、SLAM [34] 和自動駕駛 [99]。這種優勢擴展到四維動態建模 [28, 52, 84],因為三維高斯與點云類似,具有空間意義。然而,這些方法所需的離線逐場景重建為實時應用(如機器人操作)帶來了計算挑戰。近期的研究 [6, 13, 74, 85, 86, 94, 97, 101] 通過使用大規模數據集學習從像素到高斯的生成映射來解決這一問題,但仍然依賴已知的相機位姿,從而限制了可擴展性。另一條并行研究路徑 [8, 37, 70, 79] 探索了從無位姿圖像進行前饋的新視角合成,利用預測的點圖(point map)作為顯式多視角對齊的代理。在這些進展的基礎上,本工作開發了一種從無位姿圖像構建的可擴展高斯世界模型,從而保證空間感知與可擴展性,以支持策略訓練。
視覺操作(Visual Manipulation)
構建具有人類般能力的視覺驅動機器人一直是一項長期挑戰。視覺模仿學習方法 [5, 36, 39, 45, 75] 通過使用各種視覺表示模仿專家演示,例如點云 [7, 15]、體素 [44, 69]、NeRFs [11, 32, 41, 43, 68, 91] 和 3D-GS [49]。盡管這些模型在已學習的任務中有效,但它們在未見過的真實場景中表現不佳 [53, 54]。強化學習(RL)通過試錯來優化策略,從而彌補了這一缺陷,但它需要昂貴的現實世界執行過程。因此,許多方法采用“仿真到現實遷移”(sim-to-real transfer),即在世界的數字孿生體中學習RL策略并將其部署到任務執行中。然而,由于這些方法依賴于預定義資產 [3, 58, 78] 或將現實世界物體轉化為仿真的勞動密集型過程 [10, 38, 42, 47, 48, 63],可擴展性仍然是一個挑戰。為了解決這些局限性,GWM專注于同時為模仿學習提供更強的視覺表示,并為視覺強化學習提供一種高效的神經模擬器,從而實現更有效且更具可擴展性的機器人操作。
3. 高斯世界模型(Gaussian World Model)
我們的方法的整體流程如圖2所示,其中我們構建了一個高斯世界模型,用來推斷由三維高斯基元(3D Gaussian primitives)表示的未來場景重建。具體來說,我們首先將真實世界的視覺輸入編碼為潛在的三維高斯表示(第3.1節),然后利用基于擴散的條件生成模型,在給定機器人狀態和動作的情況下,學習表示的動態變化(第3.2節)。我們展示了GWM可以靈活地集成到離線模仿學習和在線基于模型的強化學習中,以應對多樣化的機器人操作任務(第3.3節)。
圖2. GWM的整體流程,主要由一個三維變分編碼器和一個潛在擴散Transformer組成。三維變分編碼器將由基礎重建模型估計得到的高斯點(Gaussian Splats)嵌入到一個緊湊的潛在空間中,而擴散Transformer則在這些潛在補丁(latent patches)上進行操作,在給定機器人動作和去噪時間步的條件下,交互式地“想象”未來的高斯點。
1. 輸入圖像 → 高斯點繪 (Splatt3R)
輸入可以是單張或多張未配準圖像(Unposed Images)。
通過 Splatt3R [70] 將輸入圖像轉化為 3D高斯點云 (Gaussian Splats GtG_tGt?),這一步得到的是場景在當前時刻的三維結構和外觀表示。
2. 3D VAE 編碼器 → 潛在表示
高斯點云 GtG_tGt? 被送入 3D VAE,壓縮成一個緊湊的潛在表示(Compact Representation)。
在潛在空間中引入了隨機噪聲(Random Noise),用于擴散建模。
3. 位置嵌入 & 條件信息
對潛在特征加入 位置嵌入 (Positional Embedding, RoPE),讓模型感知空間關系。
條件信息包括:
時間步 τ\tauτ(擴散噪聲步數),通過 AdaLN(自適應層歸一化) 融合到特征中;
機器人動作 ata_tat?,作為 Cross Attention 的鍵值 (KV),讓預測與動作相關聯。
4. 潛在擴散Transformer
在潛在空間上運行擴散Transformer,包含:
Cross Attention:將機器人動作與潛在表示對齊;
Feed-Forward 層:進一步建模時序和空間特征;
AdaLN:對每一層的輸入做自適應調制,提高穩定性。
所有注意力機制采用 RMSNorm 歸一化,保證訓練穩定。
5. 解碼器 → 未來高斯點云 Gt+1G_{t+1}Gt+1?
預測得到的潛在表示通過 3D VAE 解碼器 轉換回三維高斯點云。
得到的是 未來時刻 t+1t+1t+1 的場景重建,即模型根據當前狀態和機器人動作想象出的未來畫面。
3.1. 世界狀態編碼(World State Encoding)
前饋式三維高斯點繪(Feed-forward 3D Gaussian Splatting)
給定一個世界狀態的單視角或雙視角圖像輸入 I={I}i={1,2}\mathcal{I}=\{I\}_{i=\{1,2\}}I={I}i={1,2}?,我們的目標是首先將場景編碼為三維高斯表示,以便進行動力學學習和預測。三維高斯點繪(3D-GS)使用多個非結構化的三維高斯核來表示一個三維場景:
G={xp,σp,Σp,Cp}p∈P,G=\{x_p, \sigma_p, \Sigma_p, \mathcal{C}_p\}_{p\in \mathcal{P}}, G={xp?,σp?,Σp?,Cp?}p∈P?,
其中,xp,σp,Σp,Cpx_p, \sigma_p, \Sigma_p, \mathcal{C}_pxp?,σp?,Σp?,Cp? 分別表示高斯核的中心、透明度、不相關矩陣以及球諧函數系數。
- 它指的是,這些三維高斯核(3D Gaussian kernels)在空間中的排列、數量和形狀是不規則的、沒有預設網格或拓撲結構的。
- 非結構化的表示:這更像是一團漂浮在空中的彩色點云。每個點都是獨立的個體,它們沒有預設的鄰居關系,也沒有固定的排列順序。三維高斯核就是這種“非結構化點”的升級版。每個“點”不僅僅是一個位置,它還是一個高斯球,擁有自己的位置、尺寸、形狀和顏色。
想象一下你正在用成千上萬個“彩色光球”來重建一個三維世界。這四個參數就是用來描述每一個光球的屬性。
xpx_pxp? ? : 高斯核的中心(Position)
直觀理解:這就像光球的位置坐標,比如 (1.5,2.3,0.8)。它告訴我們這個光球在三維空間中的具體位置。
作用:決定了高斯核在場景中的落腳點。在訓練過程中,模型會不斷地調整這些位置,讓它們最有效地“填滿”整個場景,特別是在物體的表面。
σp\sigma_pσp? ? : 透明度(Opacity)
直觀理解:這就像光球的透明度或不透明度。它的值通常在 0 到 1 之間。
作用:決定了該高斯核對最終圖像的貢獻程度。
如果透明度接近 1,它就是一個幾乎不透明的“實心球”,會強烈地影響它所在位置的顏色。
如果透明度接近 0,它就是一個“虛影”,幾乎不會對圖像產生影響。
這對于表示半透明物體(如玻璃、煙霧)或者在遮擋關系中至關重要。模型會學習讓被遮擋的高斯核具有較低的透明度,從而讓它們“消失”在背景中。
Σp\Sigma_pΣp? ? : 協方差矩陣(Covariance Matrix)
直觀理解:這是最復雜的一個,但我們可以用一個簡單的比喻:它決定了光球的尺寸、形狀和旋轉方向。
一個簡單的球形高斯核,其協方差矩陣決定了它的半徑。
一個橢球形高斯核,其協方差矩陣決定了它的長軸、短軸以及它在空間中的旋轉角度。
作用:讓高斯核擁有了彈性。
在平坦、大面積的表面(如墻壁、地面)上,模型可以生成一個又大又扁平的橢球形高斯核,用很少的核就覆蓋很大一片區域,這大大提高了效率。
在物體邊緣、角落或細節豐富的區域,模型則會生成更小、更接近球形的高斯核,以精確地捕捉這些細節。
為什么叫“不相關矩陣”?
- 這實際上是對協方差矩陣的簡化描述。在 3D-GS 的原始論文中,協方差矩陣可以被分解為兩個部分:一個縮放矩陣(scaling matrix)和一個旋轉矩陣(rotation matrix)。“不相關矩陣”這個詞可能指的是一種簡化的表示方式,但其核心作用始終是控制高斯核的形狀和方向。
Cp\mathcal{C}_pCp? ? : 球諧函數系數(Spherical Harmonics Coefficients)
直觀理解:這就像光球的 “彩色涂料”。但它不僅僅是單一的顏色,它是一種更復雜的“涂料”,可以根據觀察視角的不同而改變顏色。
作用:決定了高斯核的顏色和光照效果。
簡單來說,球諧函數(Spherical Harmonics)是一種數學工具,可以非常有效地表示三維空間中的復雜函數,比如光照。
通過這些系數,模型能夠學習到當從不同角度觀察一個物體時,它的顏色會有什么變化(比如高光、陰影)。這使得 3D-GS
渲染出來的圖像具有非常真實的光照效果,而不僅僅是簡單的貼圖。舉例:想象一個閃亮的紅色蘋果。
從正面看,它可能大部分是紅色。
但從某個角度看,你會看到一個白色的高光點。
這并不是因為蘋果表面有不同的顏色,而是因為它對光線的反射方式不同。
球諧函數系數正是用來捕捉這種視角依賴的顏色變化,讓渲染結果看起來更加逼真。
總結
這四個參數共同工作,就像一個藝術家的工具箱:
xpx_pxp? ? :決定了你的筆觸落在哪里。
σp\sigma_pσp? ? :決定了你的筆觸的透明度。
Σp\Sigma_pΣp? ? :決定了你的筆觸的大小和形狀。
Cp\mathcal{C}_pCp? ? :決定了你的筆觸的顏色和光影效果。
通過對成千上萬個這些“筆觸”(高斯核)進行精心的優化和組合,3D-GS 就能夠從幾張二維圖片中,神奇地重建出一個高質量的三維場景。
為了從給定視角獲得每個像素的顏色,3D-GS將三維高斯投影到圖像平面,并計算像素顏色如下:
C(G)=∑p∈PαpSH(dp;Cp)∏j=1p?1(1?αj),(1)C(G) = \sum_{p\in \mathcal{P}} \alpha_p \, \text{SH}(d_p; \mathcal{C}_p) \prod_{j=1}^{p-1}(1-\alpha_j), \tag{1} C(G)=p∈P∑?αp?SH(dp?;Cp?)j=1∏p?1?(1?αj?),(1)
其中:
-
αp\alpha_pαp? 表示按照深度順序排列的有效透明度,即由 Σp\Sigma_pΣp? 推導出的二維高斯權重與其整體透明度 σp\sigma_pσp? 的乘積;
-
dpd_pdp? 表示從相機到 xpx_pxp? 的視角方向;
-
SH(?)\text{SH}(\cdot)SH(?) 是球諧函數。
由于原始的3D-GS依賴于耗時的逐場景離線優化,我們采用可泛化的3D-GS來學習從圖像到三維高斯的前饋映射,以加速這一過程。具體來說,我們使用 Splatt3R [70] 獲取三維高斯世界狀態 GGG:該方法首先利用立體重建模型 Mast3R [37] 從輸入圖像生成三維點圖(3D point maps),然后使用一個額外的預測頭,在這些點圖的基礎上預測每個三維高斯的參數。
這部分的核心是理解公式 (1),它描述了如何計算圖像中一個特定像素的顏色。這個過程被稱為 “Splating”,就像把顏料“潑灑”到畫布上。 想象你面前有一個場景,由成千上萬個三維高斯(那些彩色光球)組成。現在,你想用一個相機去拍攝它,得到一張二維照片。這個公式就是相機捕捉顏色的數學描述。
C(G)=∑p∈PαpSH(dp;Cp)∏j=1p?1(1?αj),(1)C(G) = \sum_{p\in \mathcal{P}} \alpha_p \, \text{SH}(d_p; \mathcal{C}_p) \prod_{j=1}^{p-1}(1-\alpha_j), \tag{1} C(G)=p∈P∑?αp?SH(dp?;Cp?)j=1∏p?1?(1?αj?),(1)
- C(G)C(G)C(G): 這是最終得到的像素顏色。
- ∑p∈P\sum_{p\in \mathcal{P}}∑p∈P? : 這表示對所有影響這個像素的三維高斯進行累加求和。
- αp\alpha_pαp?: 這是高斯核的有效透明度。它不是單純的 σp\sigma_pσp? (整體透明度),而是由兩部分相乘得到的:
- 高斯二維投影權重: 當一個三維高斯被投影到二維圖像平面時,它的能量分布是呈高斯曲線的。離高斯中心越近的像素,得到的“能量”或權重就越大。
- σp\sigma_pσp? 這個高斯核自身的整體透明度。
- 例子:想象一個半透明的藍色氣球。它的 σp\sigma_pσp? 可能只有 0.5。當它投影到圖像上時,只有它中心位置的像素會獲得最大的權重,而邊緣的像素權重較小。最終的 αp\alpha_pαp? 結合了這兩者,告訴我們這個高斯對這個特定像素的貢獻有多大。
- SH(dp;Cp)\text{SH}(d_p; \mathcal{C}_p)SH(dp?;Cp?): 這部分是高斯核在特定視角下的顏色。
- Cp\mathcal{C}_pCp?: 前面我們提到的球諧函數系數,它包含了這個高斯的顏色信息和光照信息。
- dpd_pdp?: 觀察方向,即從相機到這個高斯中心 xpx_pxp? 的方向。
- 作用:球諧函數利用 dpd_pdp? 和 Cp\mathcal{C}_pCp? 計算出,在當前這個視角下,這個高斯應該呈現出什么顏色。這正是為什么 3D-GS 能夠渲染出帶有高光、陰影等真實光影效果的原因。
- ∏j=1p?1(1?αj)\prod_{j=1}^{p-1}(1-\alpha_j)∏j=1p?1?(1?αj?): 這部分是累積的透明度,它處理了遮擋關系。
- 直觀理解:當你從一個方向看物體時,離你近的物體會遮擋住后面的物體。
- 公式解釋:3D-GS 會將所有高斯按深度(從遠到近或從近到遠,這里是從遠到近)進行排序。這個乘積項計算的是在當前高斯 ppp 前面的所有高斯 jjj 的累積透明度。
- 1?αj1?α_j1?αj? 表示高斯 jjj 的“透明度”部分(即沒有被它遮擋的光線)。
- 將前面的所有 (1?αj)(1?α_j)(1?αj?) 相乘,就得到了光線在到達當前高斯 ppp 之前,還剩下多少能量。如果前面的高斯非常不透明(αjα_jαj? 接近 1),那么這個乘積就會接近 0,意味著光線基本都被前面的高斯擋住了,后面的高斯 ppp 對最終顏色的貢獻就會很小。
從耗時離線優化到前饋式學習
原始的 3D-GS 有一個很大的缺點:它需要對每一個新場景都從頭開始進行數小時的離線優化。這就像每次想渲染一個新物體,你都必須讓藝術家從零開始雕刻它。 為了解決這個問題,研究者提出了 “可泛化的 3D-GS” 。它的目標是:學習一個通用的模型,能夠直接從輸入圖像快速預測出三維高斯表示,而不需要逐場景的優化。
方法的核心:
- 立體重建 (Mast3R):這個模型首先從給定的單視角或雙視角圖像中,生成一個三維點圖。你可以把這看作一個“粗略”的三維點云,它已經捕捉了場景的基本幾何形狀。
- 額外的預測頭 (Prediction Head):這個神經網絡是真正的“魔術”所在。它接收前面生成的三維點圖作為輸入,然后預測出每個點對應的完整三維高斯參數(xp,σp,Σp,Cpx_p, \sigma_p, \Sigma_p, \mathcal{C}_pxp?,σp?,Σp?,Cp?)。
優勢:
一旦這個“可泛化”的模型訓練完成,它就可以像一個“速寫大師”一樣,瞬間從新的輸入圖像中生成三維高斯,省去了漫長的離線優化過程。
這使得 3D-GS 的應用場景大大擴展,例如實時三維重建、快速虛擬現實內容生成等。
總而言之,前饋式的 3D-GS 將原本耗時費力的 “優化問題” 轉換成了一個更高效的 “預測問題”。它通過學習一個通用的映射關系,實現了從二維圖像到三維高斯表示的快速、直接的轉換。
三維高斯 VAE
想象一下,你用 3D-GS 重建了兩個不同的場景:一個簡單的立方體和一個復雜的雕像。
重建立方體可能只需要 1000 個高斯。
重建雕像可能需要 100 萬個高斯。
這就帶來了問題:如果你想用一個神經網絡來處理這些三維數據,比如讓機器人根據這些數據做決策,該怎么辦?這個網絡的輸入層必須是固定大小的,但你的高斯數量卻是可變的。
三維高斯VAE(變分自編碼器) 就是為了解決這個“可變大小”的問題而設計的。它的作用是:
將一個可變數量的三維高斯 GGG 壓縮成一個固定長度的潛在編碼 x\text{x}x,然后再將這個編碼解壓回一個三維高斯表示 G^\hat{G}G^。
通過這個過程,我們就可以將三維數據作為固定大小的向量輸入到其他網絡中,比如一個策略網絡(policy network),讓機器人來學習如何操作。
由于每個世界狀態中學習到的三維高斯數量在不同場景和任務中可能存在顯著差異,我們采用一個三維高斯 VAE(Eθ,Dθ)\text{VAE}(E_\theta, D_\theta)VAE(Eθ?,Dθ?) 來將重建的三維高斯 GGG 編碼為一個固定長度的 NNN 個潛在嵌入 x∈RN×D\text{x} \in \mathbb{R}^{N\times D}x∈RN×D。具體來說,我們首先使用最遠點采樣(Farthest Point Sampling, FPS)將重建的三維高斯 GGG 下采樣為固定數量 NNN 個高斯 GNG_NGN?:
GN=FPS(G).G_N = \text{FPS}(G). GN?=FPS(G).
最遠點采樣 (FPS): GN=FPS(G).G_N = \text{FPS}(G). GN?=FPS(G).
直觀理解:這是一種智能的“抽樣”方法。想象你有一大堆三維點,你想從中選出 NNN 個點來代表整體。FPS 的做法是:
隨機選擇第一個點。
從剩下的點中,選擇離所有已選點最遠的點作為下一個。
重復這個過程,直到選出 N 個點。
作用:這種方法能確保選出的 NNN 個點均勻地分布在原始的高斯集合 GGG 中,從而保證這 NNN 個點具有代表性,不會集中在某個角落。這解決了數量不固定的問題,將我們感興趣的高斯從 GGG 變成了固定數量的 GNG_NGN? ? 。
接下來,我們使用這些采樣到的高斯 GNG_NGN? 作為查詢項(queries),通過一個基于交叉注意力的 LLL 層編碼器 EθE_\thetaEθ? 從所有高斯 GGG 中聚合信息,得到潛在嵌入 x\text{x}x [93]:
X=Eθ(GN,G)=Eθ(L)°?°Eθ(1)(GN,G),(2)X = E_\theta(G_N, G) = E_\theta^{(L)} \circ \cdots \circ E_\theta^{(1)}(G_N, G), \tag{2} X=Eθ?(GN?,G)=Eθ(L)?°?°Eθ(1)?(GN?,G),(2)
其中
Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))).E_\theta^{(l)}(Q, G) = \text{LayerNorm}(\text{CrossAttn}(Q, \text{PosEmbed}(G))). Eθ(l)?(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))).
交叉注意力編碼器 (Cross-Attention Encoder): X=Eθ(GN,G)X = E_\theta(G_N, G)X=Eθ?(GN?,G)
直觀理解:現在我們有了 NNN 個具有代表性的“查詢”高斯 GNG_NGN? ? 。編碼器要做的是讓這 NNN 個查詢去“問”所有原始高斯 GGG:“你們各自是什么樣的?”,然后把得到的信息聚合起來,形成 NNN 個固定長度的向量。
工作方式:
查詢(Queries): 采樣得到的 NNN 個高斯 GNG_NGN? ? 作為查詢項。
鍵值(Keys & Values): 所有原始高斯 GGG 作為鍵值。
交叉注意力: 編碼器中的交叉注意力 (Cross-Attention) 機制會計算每個查詢高斯 GNG_NGN? 與所有原始高斯 GGG 之間的關系。這個機制能夠讓每個查詢高斯有效地 “看到”并聚合 整個場景中所有高斯的信息,而不僅僅是它自己的信息。
結果:經過多層(LLL 層)這樣的交叉注意力處理后,每個查詢高斯都獲得了整個場景的上下文信息。最終,我們得到一個固定大小的潛在嵌入 x∈RN×D\text{x}∈\mathbb{R}^{N×D}x∈RN×D 。
在獲得潛在編碼 x\text{x}x 之后,我們采用一個對稱的基于Transformer的解碼器 DθD_\thetaDθ?,在潛在編碼集合內部傳播并聚合信息,從而得到重建的高斯 G^\hat{G}G^:
G^=Dθ(x)=LayerNorm(CrossAttn(x,x)).(3)\hat{G} = D_\theta(\text{x}) = \text{LayerNorm}(\text{CrossAttn}(\text{x}, \text{x})). \tag{3} G^=Dθ?(x)=LayerNorm(CrossAttn(x,x)).(3)
為了訓練三維高斯 VAE(Eθ,Dθ)\text{VAE}(E_\theta, D_\theta)VAE(Eθ?,Dθ?) ,我們使用重建高斯 G^\hat{G}G^ 的中心與原始高斯 GGG 的中心之間的 Chamfer 損失作為監督。同時,我們還添加了一個基于渲染的損失,以確保我們的重建高斯 G^\hat{G}G^ 能夠實現高保真的圖像渲染,從而服務于基于圖像的策略學習:
LVAE=Chamfer(G^,G)+∥C(G^)?C(G)∥1.(4)\mathcal{L}_{\text{VAE}} = \text{Chamfer}(\hat{G}, G) + \| C(\hat{G}) - C(G)\|_1. \tag{4} LVAE?=Chamfer(G^,G)+∥C(G^)?C(G)∥1?.(4)
Chamfer\text{Chamfer}Chamfer:
- 原始點云中的每個點,到重建點云中最近點的距離。
- 重建點云中的每個點,到原始點云中最近點的距離。
三維高斯VAE 結構總結表
模塊 輸入 處理方式 輸出 公式 編碼器 EθE_\thetaEθ? - 原始三維高斯集合 GGG
- 通過FPS采樣得到的 GNG_NGN?基于交叉注意力(Cross-Attn)的 LLL 層Transformer編碼器,將 GNG_NGN? 作為查詢,從 GGG 中聚合信息 潛在嵌入 X∈RN×DX \in \mathbb{R}^{N\times D}X∈RN×D X=Eθ(GN,G)X = E_\theta(G_N, G)X=Eθ?(GN?,G)
Eθ(l)(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G)))E_\theta^{(l)}(Q,G) = LayerNorm(CrossAttn(Q, PosEmbed(G)))Eθ(l)?(Q,G)=LayerNorm(CrossAttn(Q,PosEmbed(G))) (公式2)解碼器 DθD_\thetaDθ? 潛在嵌入 XXX 基于自注意力(Self-Attn)的Transformer解碼器,在潛在集合內部傳播并聚合信息 重建高斯 G^\hat{G}G^ G^=Dθ(X)=LayerNorm(SelfAttn(X,X))\hat{G} = D_\theta(X) = LayerNorm(SelfAttn(X,X))G^=Dθ?(X)=LayerNorm(SelfAttn(X,X)) (公式3) 損失函數 LVAE\mathcal{L}_{VAE}LVAE? - 原始高斯 GGG
- 重建高斯 G^\hat{G}G^Chamfer距離(中心點級別) + 渲染損失(像素級別差異) 優化目標,用于訓練 Eθ,DθE_\theta, D_\thetaEθ?,Dθ? LVAE=Chamfer(G^,G)+∣C(G^)?C(G)∣1\mathcal{L}_{VAE} = Chamfer(\hat{G}, G) + |C(\hat{G}) - C(G)|_1LVAE?=Chamfer(G^,G)+∣C(G^)?C(G)∣1? (公式4)
3.2. 基于擴散的動態建模(Diffusion-based Dynamics Modeling)
在時刻 ttt 的編碼世界狀態嵌入 xt\text{x}_txt? 以及其未來狀態 xt+1\text{x}_{t+1}xt+1? 已知的情況下,我們的目標是學習世界動力學 p(xt+1∣x≤t,a≤t)p(\text{x}_{t+1}\mid \text{x}_{\leq t}, a_{\leq t})p(xt+1?∣x≤t?,a≤t?),其中 x≤t\text{x}_{\leq t}x≤t? 和 a≤ta_{\leq t}a≤t? 分別表示歷史狀態和歷史動作。具體來說,我們采用基于擴散的動力學模型,將動力學學習轉化為一個條件生成問題,即從噪聲中生成未來狀態 xt+1\text{x}_{t+1}xt+1?,條件為歷史狀態和動作 yt=(x≤t,a≤t)\text{y}_t = (\text{x}_{\leq t}, a_{\leq t})yt?=(x≤t?,a≤t?)。
擴散公式(Diffusion Formulation)
為了生成未來狀態,我們從擴散過程的表述開始。具體來說,我們首先向真實的未來狀態 xt+10=xt+1\text{x}^0_{t+1} = \text{x}_{t+1}xt+10?=xt+1? 添加噪聲,以通過高斯擾動核獲得帶噪的未來狀態樣本 xt+1τ\text{x}^\tau_{t+1}xt+1τ?:
p0→τ(xt+1τ∣xt+10)=N(xt+1τ;xt+10,σ2(τ)I),(5)p^{0 \to \tau}(\text{x}^\tau_{t+1}\mid \text{x}^0_{t+1}) = \mathcal{N}(\text{x}^\tau_{t+1}; \text{x}^0_{t+1}, \sigma^2(\tau) I), \tag{5} p0→τ(xt+1τ?∣xt+10?)=N(xt+1τ?;xt+10?,σ2(τ)I),(5)
其中,τ\tauτ 是噪聲步索引,σ(τ)\sigma(\tau)σ(τ) 是噪聲調度函數。該擴散過程可以通過以下隨機微分方程(SDE)的解來描述 [72]:
dx=f(x,τ)dτ+g(τ)dw,(6)d\mathbf{x} = \mathbf{f}(\mathbf{x}, \tau)\, d\tau + g(\tau)\, d\mathbf{w}, \tag{6} dx=f(x,τ)dτ+g(τ)dw,(6)
其中,w\mathbf{w}w 表示標準維納過程(Wiener process),f\mathbf{f}f 是漂移系數(drift coefficient),ggg 是擴散系數(diffusion coefficient)。在這種表述下,高斯擾動核的作用等價于設置 f(x,τ)=0\mathbf{f}(\mathbf{x},\tau)=0f(x,τ)=0,并令 g(τ)=2σ˙(τ)σ(τ)g(\tau)=\sqrt{2 \dot{\sigma}(\tau)\sigma(\tau)}g(τ)=2σ˙(τ)σ(τ)?。
直觀理解:這是一個描述 “隨機過程隨時間演變” 的方程。它將公式 (5) 的離散加噪過程,用一個連續的數學工具來描述。
具體含義:
dxd\mathbf{x}dx:表示狀態 x 的微小變化。
f(x,τ)dτ\mathbf{f}(\mathbf{x}, \tau)d\tauf(x,τ)dτ:漂移項。它描述了狀態 x\mathbf{x}x 的確定性變化,就像一個物體在水流中漂移。
g(τ)dwg(\tau)\, d\mathbf{w}g(τ)dw:擴散項。它描述了狀態 x\mathbf{x}x 的隨機性變化,就像布朗運動一樣。w\mathbf{w}w 是維納過程,代表著隨機噪聲。
簡化:在你的例子中,f(x,τ)=0\mathbf{f}(\mathbf{x}, \tau)=0f(x,τ)=0,這意味著沒有“漂移”,整個變化過程完全由隨機噪聲($g(τ)$0)驅動,這與公式 (5) 的純高斯加噪是等價的。
為了從噪聲中生成樣本,我們可以使用逆時間SDE [2] 對式(6)進行反演,從而得到采樣公式:
dx=[f(x,τ)?g(τ)2?xlog?pτ(x)]dτ+g(τ)dwˉ,(7)d\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, \tau) - g(\tau)^2 \nabla_\mathbf{x} \log p^\tau(\mathbf{x})\big] d\tau + g(\tau)\, d\bar{\mathbf{w}}, \tag{7} dx=[f(x,τ)?g(τ)2?x?logpτ(x)]dτ+g(τ)dwˉ,(7)
其中,wˉ\bar{\mathbf{w}}wˉ 表示逆時間維納過程,?xlog?pτ(x)\nabla_\mathbf{x} \log p^\tau(\mathbf{x})?x?logpτ(x) 是得分函數(score function),即關于 x\mathbf{x}x 的對數邊際概率的梯度 [30]。
直觀理解:這是對公式 (6) 的反演。它告訴我們,如果想讓一個隨機過程“倒著走”,需要額外加上一個“導向力”來抵消隨機性。
關鍵項:?xlog?pτ(x)\nabla_\mathbf{x} \log p^\tau(\mathbf{x})?x?logpτ(x)
這被稱為得分函數 (score function)。
log?pτ(x)\log p^\tau(\mathbf{x})logpτ(x):表示帶噪數據 x\mathbf{x}x 的概率密度函數的對數。
?x\nabla_\mathbf{x}?x? :表示對 x\mathbf{x}x 求梯度。
作用:這個梯度項指向概率密度增加最快的方向。想象一下你在一座山(概率分布)上,得分函數就像一個指南針,永遠指向山頂的方向。它告訴我們,如果想從一個隨機點(帶噪數據)回到真實數據(概率山頂),應該往哪個方向走。
與神經網絡的聯系:由于我們無法直接計算這個得分函數,所以我們用一個神經網絡來近似估計它。這就是擴散模型訓練的核心。
由于得分函數可以通過神經網絡來估計,我們通過最小化采樣到的未來狀態 x^t+10=Dθ(xt+1τ,yt)\hat{\mathbf{x}}^0_{t+1} = \mathcal{D}_\theta(\mathbf{x}^\tau_{t+1}, \mathbf{y}_t)x^t+10?=Dθ?(xt+1τ?,yt?) 與真實未來狀態 xt+10\mathbf{x}^0_{t+1}xt+10? 之間的差異來學習條件去噪模型 Dθ\mathcal{D}_\thetaDθ?:
L(θ)=E[∥Dθ(xt+1τ,ytτ)?xt+10∥22].(8)\mathcal{L}(\theta) = \mathbb{E}\left[\left\| \mathcal{D}_\theta(x^\tau_{t+1}, y^\tau_t) - x^0_{t+1}\right\|_2^2\right]. \tag{8} L(θ)=E[?Dθ?(xt+1τ?,ytτ?)?xt+10??22?].(8)
yt=(x≤t,a≤t)\text{y}_t = (\text{x}_{\leq t}, a_{\leq t})yt?=(x≤t?,a≤t?)
基于EDM的學習(Learning with EDM)
正如 [33] 中指出的,直接學習去噪器 Dθ(xt+1τ,yt)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}_t)Dθ?(xt+1τ?,yt?) 可能會受到噪聲幅度變化等問題的影響。因此,我們遵循 [1],并采用EDM [33] 中的做法,改為學習一個帶有預條件的網絡 Fθ\mathcal{F}_\thetaFθ?。具體來說,我們將去噪器 Dθ(xt+1τ,yt+1τ)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}^\tau_{t+1})Dθ?(xt+1τ?,yt+1τ?) 參數化為:
Dθ(xt+1τ,ytτ)=cskipτxt+1τ+coutτFθ(cinτxt+1τ,ytτ;cnoiseτ),(9)\mathcal{D}_\theta(\text{x}^\tau_{t+1}, \text{y}^\tau_t) = c^\tau_{\text{skip}} \text{x}^\tau_{t+1} + c^\tau_{\text{out}} \, \mathcal{F}_\theta\big(c^\tau_{\text{in}} \text{x}^\tau_{t+1}, \text{y}^\tau_t; c^\tau_{\text{noise}}\big), \tag{9} Dθ?(xt+1τ?,ytτ?)=cskipτ?xt+1τ?+coutτ?Fθ?(cinτ?xt+1τ?,ytτ?;cnoiseτ?),(9)
其中:
-
預條件器 cinτc^\tau_{\text{in}}cinτ? 和 coutτc^\tau_{\text{out}}coutτ? 用于縮放輸入與輸出的幅度,
-
cskipτc^\tau_{\text{skip}}cskipτ? 調節跳躍連接(skip connection),
-
cnoiseτc^\tau_{\text{noise}}cnoiseτ? 將噪聲水平映射為額外的條件輸入,送入 Fθ\mathcal{F}_\thetaFθ?。
這些預條件器的細節在附錄 B.1 中給出。
基于 EDM 的學習:為什么要改變訓練方式?
首先,讓我們回到最初的問題:為什么要引入 EDM (Elucidating Diffusion Models) 這個框架? 原始的擴散模型訓練有一個潛在問題:如果直接學習去噪器 Dθ\mathcal{D}_\thetaDθ?,它需要處理各種不同噪聲水平的輸入。當噪聲非常多或非常少時,網絡的行為可能變得不穩定,難以收斂到最佳解。
高噪聲:輸入幾乎是純噪聲,網絡很難識別出其中的微弱信號,訓練效率低。
低噪聲:輸入和真實數據幾乎一模一樣,網絡很容易學會 “什么也不做”,直接輸出輸入(即平凡解),這導致它沒有真正學到去噪的能力。
EDM 的核心思想是:我們不直接訓練去噪器 Dθ\mathcal{D}_\thetaDθ? ? ,而是訓練一個 “預處理”過的網絡 Fθ\mathcal{F}_\thetaFθ? 。這個網絡經過精心設計,它的輸入和輸出都經過了縮放(scaling),從而使得無論噪聲水平如何,它面臨的訓練任務都更加穩定和一致。 這就像一個廚師,他不會直接處理各種形狀、大小不一的食材,而是先將所有食材切成標準化的塊狀,然后再進行烹飪。這樣,他的烹飪過程就變得更加穩定和可控。
網絡 Fθ\mathcal{F}_\thetaFθ? :這是真正需要學習的神經網絡。它接收預處理后的輸入,并執行核心的去噪操作。
cskipτc^\tau_{\text{skip}}cskipτ? (跳躍連接):這個參數控制原始帶噪數據 xt+1τ\text{x}^\tau_{t+1}xt+1τ? ? 在最終去噪結果中的占比。
當噪聲很小時,σ(τ)σ(τ)σ(τ) 接近 0,數據和信號的差異很小,我們希望去噪器主要依賴原始數據,這時 cskipτc^\tau_{\text{skip}}cskipτ? 接近 1。
當噪聲很大時,σ(τ)σ(τ)σ(τ) 很大,原始數據幾乎被噪聲淹沒,我們希望去噪器主要依賴網絡 Fθ\mathcal{F}_\thetaFθ? 的輸出,這時 cskipτc^\tau_{\text{skip}}cskipτ? 接近 0。
cinτc^\tau_{\text{in}}cinτ? (輸入預條件):這個參數用于縮放網絡 Fθ\mathcal{F}_\thetaFθ? 的輸入。它的設計目標是讓網絡的輸入在任何噪聲水平下都具有相似的幅度。這使得網絡不必處理幅度差異巨大的輸入,從而穩定了訓練。
coutτc^\tau_{\text{out}}coutτ? (輸出預條件):這個參數用于縮放網絡 Fθ\mathcal{F}_\thetaFθ? 的輸出,以確保它與跳躍連接的部分正確組合。
cnoiseτc^\tau_{\text{noise}}cnoiseτ? (噪聲條件):這是一個將噪聲水平 τ 轉換為一個可供網絡 Fθ\mathcal{F}_\thetaFθ? 理解的額外輸入。它通常是一個簡單的映射函數,比如將 log(σ(τ))\text{log}(σ(τ))log(σ(τ)) 轉換為一個嵌入向量,這個向量會通過 AdaLN 或其他方式注入到網絡中。
通過這種轉換,我們可以將公式 (8) 的目標改寫為:
L(θ)=E[∥Fθ(cinτxt+1τ,ytτ)?1coutτ(xt+10?cskipτxt+1τ)∥22].(10)\mathcal{L}(\theta) = \mathbb{E}\left[\Big\|\mathbf{F}_\theta(c^\tau_{\text{in}} \mathbf{x}^\tau_{t+1}, y^\tau_t) - \tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - c^\tau_{\text{skip}} \mathbf{x}^\tau_{t+1}\big)\Big\|_2^2\right]. \tag{10} L(θ)=E[?Fθ?(cinτ?xt+1τ?,ytτ?)?coutτ?1?(xt+10??cskipτ?xt+1τ?)?22?].(10)
這一轉換的一個關鍵見解在于:為更好地訓練網絡 Fθ\mathcal{F}_\thetaFθ? 創建了一個新的訓練目標,它能夠根據噪聲調度 σ(τ)\sigma(\tau)σ(τ) 自適應地混合信號與噪聲。直觀來說:在高噪聲水平下(σ(τ)?σdata\sigma(\tau) \gg \sigma_{\text{data}}σ(τ)?σdata?),cskipτ→0c^\tau_{\text{skip}} \to 0cskipτ?→0,網絡主要學習預測干凈信號。相反,在低噪聲水平下(σ(τ)→0\sigma(\tau) \to 0σ(τ)→0),cskipτ→1c^\tau_{\text{skip}} \to 1cskipτ?→1,訓練目標變為噪聲部分,從而避免目標退化為平凡解。
這個公式是將公式 (8) 中的 Dθ\mathcal{D}_\thetaDθ? 替換為公式 (9) 之后,對整個損失函數進行的數學變換。雖然看起來復雜,但它背后的思想非常直觀:
訓練目標不再是直接讓網絡預測 xt+10\mathbf{x}^0_{t+1}xt+10? ? ,而是讓它預測一個經過精心設計的 “目標值” 。
這個“目標值”是 1coutτ(xt+10?cskipτxt+1τ)\tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - c^\tau_{\text{skip}} \mathbf{x}^\tau_{t+1}\big)coutτ?1?(xt+10??cskipτ?xt+1τ?),它根據噪聲水平 (τττ) 動態變化。
高噪聲時 (cskipτ→0c^\tau_{\text{skip}}→0cskipτ?→0):訓練目標接近 1coutτxt+10\tfrac{1}{c^\tau_{\text{out}}} \mathbf{x}^0_{t+1}coutτ?1?xt+10? ,網絡主要學習預測干凈的信號。
低噪聲時 (cskipτ→1c^\tau_{\text{skip}}→1cskipτ?→1):訓練目標接近 1coutτ(xt+10?xt+1τ)\tfrac{1}{c^\tau_{\text{out}}}\big(\mathbf{x}^0_{t+1} - \mathbf{x}^\tau_{t+1}\big)coutτ?1?(xt+10??xt+1τ?),這正是噪聲本身!這時,網絡主要學習預測噪聲。
這種自適應的目標讓網絡在任何噪聲水平下都能夠學習到有用的信息,從而極大地提高了訓練的穩定性和效率。
實現(Implementation)
在技術實現上,我們使用 DiT [60] 來實現網絡 Fθ\mathcal{F}_\thetaFθ?。給定一系列實際世界狀態的潛在嵌入 {xt0=xt}t=1T\{\mathbf{x}^0_t = \mathbf{x}_t\}_{t=1}^T{xt0?=xt?}t=1T?,我們首先根據公式 (5) 中描述的高斯擾動生成帶噪潛在嵌入 {xtτ}t=1T\{\mathbf{x}^\tau_t\}_{t=1}^T{xtτ?}t=1T?。接下來,我們將這些帶噪潛在嵌入與旋轉位置編碼(Rotary Position Embedding, RoPE [73])拼接,并作為輸入傳遞給 DiT。關于條件 yt=(x≤t0,a≤t,cnoiseτ)\mathbf{y}_t = (\mathbf{x}^0_{\leq t}, a_{\leq t}, c^\tau_{\text{noise}})yt?=(x≤t0?,a≤t?,cnoiseτ?),時間嵌入通過自適應層歸一化(Adaptive Layer Normalization, AdaLN [61])進行調制,而當前的機器人動作則作為鍵(keys)和值(values)輸入到 DiT 內部的交叉注意力層中,用于條件生成。為了在所有注意力機制中保持穩定性和效率,我們采用具有可學習縮放因子的均方根歸一化(Root Mean Square Normalization, RMSNorm [92]),以在處理空間表示的同時,結合時間動作序列作為條件,從而穩定訓練。
1. 初始化
- 策略 π(at?∣st?)π(a_t ?∣s_t ? )π(at??∣st??): 這是機器人的大腦,它決定了在當前狀態 sts_tst? 下應該采取什么動作 ata_tat? 。
- 高斯世界模型 pθ?(st+1?,rt?∣st?,at?)p_θ ? (s_{t+1} ? ,r_t ? ∣s_t ? ,a_t ? )pθ??(st+1??,rt??∣st??,at??): 這是機器人的“虛擬世界”模型。它接收當前狀態 sts_tst? 和動作 ata_tat? 作為輸入,并預測下一步的狀態 st+1s_{t+1}st+1? 和獎勵 rtr_trt? 。這里的 pθp_θpθ? ? 就是我們之前討論的高斯世界模型(GWM)。
- 回放緩沖區 B\mathcal{B}B: 這是一個存儲機器人過去經驗的數據庫。
2. 循環 N 個周期 (for N epochs do)
- 算法會重復執行以下步驟,直到學習完成。
3. 收集數據(Collect data with π in real environment)
機器人使用當前版本的策略 π 在真實世界中與環境進行交互。
它觀察自己的狀態 sts_tst?,執行一個動作 ata_tat?,然后觀察環境如何變化到新狀態 st+1s_{t+1}st+1? ,并獲得獎勵rtr_trt? 。
這個經驗元組 (st?,at?,st+1?,rt?)(s_t ? ,a_t ? ,s_{t+1} ? ,r_t ? )(st??,at??,st+1??,rt??) 被收集起來,并添加到回放緩沖區 B\mathcal{B}B中。
這一步的目的是確保模型和策略能夠接觸到真實世界的數據,從而避免只在虛擬世界中“閉門造車”。
4. 訓練世界模型(Train Gaussian world model pθp_θpθ? ? on dataset B\mathcal{B}B via maximum likelihood)
現在,我們使用回放緩沖區 B 中收集到的真實數據,來訓練高斯世界模型 p θ ? 。
最大似然(maximum likelihood) 是訓練目標。這表示我們希望高斯世界模型能夠盡可能準確地預測出與真實數據相符的結果。
公式中的 arg?max?θ?EB[log?pθ?(st+1?,rt?∣st?,at?)]\argmax_θ ? \mathbb{E}_{\mathcal{B}} [\log p_θ ? (s_{t+1} ? ,r_t ? ∣s_t ? ,a_t ? )]argmaxθ??EB?[logpθ??(st+1??,rt??∣st??,at??)] 意味著,我們找到一組最優的參數 θ,使得在給定的真實數據上,模型預測出這個真實結果的概率最大。
這一步讓機器人的“虛擬世界”模型變得越來越逼真。
5. 優化策略(Optimize policy π inside predictive model)
這是最關鍵的一步。現在我們有了一個相當準確的“虛擬世界”模型 pθp_θpθ? 。我們可以利用這個模型,讓機器人進行大量的虛擬練習,而不需要再消耗寶貴的真實世界經驗。
公式中的 arg?max?π?Eπ?[∑t≥0?γtrt?]\argmax_π ? \mathbb{E}_π ? [∑_{t≥0} ? γ^t r_t ?]argmaxπ??Eπ??[∑t≥0??γtrt??] 是強化學習的標準目標,它的含義是:找到一個最優的策略 πππ,能夠最大化機器人未來獲得的累積獎勵。
機器人會不斷地在它的 “虛擬世界” 中模擬行動,嘗試各種策略,找到能夠獲得最高獎勵的行動序列。這使得它能在短時間內進行數百萬次模擬,從而快速提升自己的決策能力。
6. 循環 (end)
- 當這一輪的策略優化完成后,機器人會回到步驟 3,帶著更新后的、更強大的策略 πππ,再次進入真實世界收集數據,然后繼續訓練世界模型并優化策略。
4. 實驗(Experiments)
在實驗中,我們主要聚焦于以下問題:
- 在不同領域下,基于動作條件的視頻預測結果質量如何?
- 高斯世界模型(GWM)是否能對下游的模仿學習和強化學習帶來益處?它是否比基于圖像的世界模型表現出更強的魯棒性?
- 在真實世界的機器人操作任務中,高斯世界模型如何幫助典型的策略(例如擴散策略 [9])?
在以下小節中,我們將詳細描述模型在這些關鍵問題上的性能表現。具體來說,我們在實驗中利用了以下三個測試環境和四個任務:
環境(Environments)
為了對GWM的能力進行全面分析,我們在兩個合成環境和一個真實環境中評估了我們的方法:
- META-WORLD [90]:一個合成環境,用于學習機器人操作的強化學習策略;
- ROBOCASA [59]:一個大規模、多尺度的合成模仿學習基準,涵蓋廚房環境中的多樣化機器人操作任務;
- FRANKA-PNP:一個真實世界的抓取與放置環境,使用 Franka Emika FR3 機械臂。
任務(Tasks)
我們精心設計了四個任務,以系統性地在不同測試環境中評估GWM:
- 動作條件場景預測(Action-conditioned scene prediction):評估GWM在世界建模和未來預測中的有效性;
- 基于GWM的模仿學習(GWM-based imitation learning):考察其表示質量及其對基于模仿學習的機器人操作的益處;
- 基于GWM的強化學習(GWM-based RL):探索其在基于模型的強化學習中的潛力;
- 真實任務部署(Real-world task deployment):評估GWM在真實世界機器人操作中的魯棒性。
4.1. 動作條件場景預測(Action-conditioned Scene Prediction)
實驗設置(Experiment Setup) 一個世界模型生成高保真且與動作對齊的 rollout 的能力,對有效的策略優化至關重要。為了評估這一能力,我們在所有考慮的真實與合成環境中,使用人類演示來訓練GWM,并在評估時將模型條件化在驗證集中采樣得到的、從未見過的動作軌跡上,以進行未來預測質量測試。在定量評估方面,我們采用常見的生成質量指標,包括 FVD [76] 來衡量時間一致性,基于圖像的指標 PSNR [29] 用于像素級精度,同時還使用 SSIM [81] 和 LPIPS [95] 來評估感知質量。
結果與分析(Results and Analyses) 我們在表1中提供了本方法與 iVideoGPT 的定量比較。如表1所示,我們的方法在合成和真實環境中始終優于當前最先進的基于圖像的世界建模方法 iVideoGPT,這表明我們的基于擴散的高斯世界模型學習流程的有效性。值得注意的是,如圖3所示,像 iVideoGPT 這樣的基于圖像的模型容易在捕捉動態細節時出現失敗(例如機械手夾爪的動作)。盡管這些細節可能不會在視覺指標上造成大的差異,但它們會顯著影響策略學習,這一點我們將在第4.3節進一步討論。我們在圖4中提供了GWM在 ROBOCASA 和 FRANKA-PNP 上的預測結果的更多定性可視化。
4.2. 基于GWM的模仿學習(GWM-based Imitation Learning)
實驗設置(Experiment Setup) 如第3.3節所討論,GWM可以從圖像觀測中提取信息量豐富的表示,這有望為模仿學習帶來益處。我們通過在 ROBOCASA 上測試GWM在模仿學習中的有效性來驗證這一性質。ROBOCASA中的任務集包含24個廚房環境的原子任務,并配有相關的語言指令,包括諸如抓取與放置(pick-and-place)、打開(open)和關閉(close)等動作。每個任務都提供了一組有限的 50個真人演示,以及一組 3000個來自 MimicGen [55] 生成的演示。我們在這些演示上訓練GWM,并將其作為狀態編碼傳遞給最先進的 BC-transformer [59],以在成功率指標上進行定量比較。
結果與分析(Results and Analyses) 我們在 ROBOCASA 基準上的實驗結果展示于表2,結果證明了我們的方法在多任務模仿學習場景中的有效性。在24個廚房操作任務中,我們的方法始終優于 BC-Transformer 基線。
- 在有限真人演示(H-50)的情況下,我們的方法在成功率上平均提升了 10.5%;
- 在使用生成演示(G-3000)訓練時,我們的方法依然保持了可擴展的性能,平均增益為 7.6%。
值得注意的是,我們的方法在復雜操作任務(如抓取與放置)以及交互性任務(如打開/關閉電器)中表現出特別的優勢,這些場景中的性能提升最為顯著。這些結果確認了我們的方法能夠從視覺觀測中提取信息量豐富的表示,從而在實際的機器人操作場景中,有效增強模仿學習的能力。
4.3. 基于GWM的強化學習(GWM-based Reinforcement Learning)
實驗設置(Experiment Setup) 我們在 Meta-World [90] 中的六個機器人操作任務上評估了GWM對強化學習策略的支持能力,這些任務具有遞增的復雜性。我們實現了一種受 MBPO [31] 啟發的基于模型的強化學習方法,使用GWM生成的合成rollout來增強 DrQ-v2 [88] actor-critic 算法的回放緩沖區。 我們將最先進的基于圖像的世界模型 iVideoGPT [82] 作為強基線方法。為了公平比較,我們沒有對兩種方法使用預訓練初始化。同時,為了保證公平性,所有比較方法使用相同的上下文長度、預測范圍,并且最大訓練步數為 1 × 10^5。
結果與分析(Results and Analyses) 如圖5所示,GWM在所有六個Meta-World任務上都始終優于iVideoGPT。平均而言,GWM的收斂速度大約比iVideoGPT快 2倍,并且在復雜操作任務上達到了更高的漸近性能。 其優越性能的來源在于:GWM的 三維高斯表示 能夠相比純粹基于圖像的方法,更準確地預測操作中的接觸動力學與物體運動。這些結果證實了:顯式的三維表示在需要精確空間推理的機器人控制任務中,提供了顯著的優勢。
4.4. 真實世界部署(Real-world Deployment)
實驗設置(Experiment Setup) 我們在真實機器人實驗中部署了一臺 Franka Emika FR3 機械臂 和一個 Panda 夾爪。實驗聚焦于現實世界中的一個任務:抓取一個有顏色的杯子,并將其放置到桌子上的盤子上。我們使用 Mujoco AR 遠程操作接口 收集了30個演示數據。此外,我們還設置了一臺第三視角的 Realsense D435i 相機,用于提供未配準的僅RGB圖像作為觀測輸入。我們在圖6中給出了該真實世界任務設置的概覽。類似于第4.2節的實驗設置,我們將最先進的基于RGB的策略 Diffusion Policy [9] 與“是否使用GWM表示”進行比較,以任務成功率為指標進行定量分析。
結果與分析(Results and Analysis) 如表3所示,在20次試驗中,GWM在任務成功率上優于Diffusion Policy(65% 對 35%),這些試驗包含了不同的初始起始位置和物體位置(即干擾物)。當出現新的干擾物時,兩者的性能差距進一步擴大,這表明GWM具有更強的泛化能力。我們的方法在不同任務變體中始終保持一致的性能,這是因為其高效的世界模型能夠捕捉任務相關的動態特征,同時對視覺差異具有魯棒性。在補充文件中展示了真實世界的rollout,GWM的優勢主要源于 更精確的物體定位和更準確的放置操作。這些結果表明,GWM在真實世界機器人操作任務中具備穩健的時空理解能力。
4.5. 消融分析(Ablation Analysis)
我們在 ROBOCASA 上進行了額外實驗,以進一步驗證我們的設計選擇。
高斯點繪的選擇(Choice of Gaussian Splatting) 如表4所示,與直接使用擴散Transformer構建基于圖像的世界模型(類似于 [1])相比,引入高斯點繪(Gaussian Splatting)顯著提升了成功率(SR),從 4% 提高到 18%。雖然PSNR略有下降,但SSIM和LPIPS指標都有所提升,這表明高斯點繪在不同時間步之間提供了更好的三維一致性。這驗證了我們的假設:相比于純二維方法,顯式的三維表示能夠增強機器人學習的空間理解能力。
三維高斯VAE的選擇(Choice of 3D Gaussian VAE) 進一步引入三維VAE組件,使得所有指標(包括PSNR)都持續提升。成功率從 18% 提高到 24%。結果表明,我們的三維高斯VAE能夠高效捕捉場景的潛在結構,實現更緊湊的場景表示,同時保持空間理解。
5. 結論(Conclusion)
在本文中,我們提出了一種新穎的 高斯世界模型(Gaussian World Model, GWM),用于機器人操作。該模型通過引入穩健的幾何信息,解決了基于圖像的世界模型的局限性。我們的方法通過建模機器人動作下高斯基元的傳播來重建未來狀態。該方法將 擴散Transformer (DiT) 與 三維感知的變分自編碼器 相結合,并通過 高斯點繪(Gaussian Splatting) 實現了精確的場景級未來狀態重建。我們開發了一個可擴展的數據處理流程,以便在基于模型的強化學習框架中支持測試時更新,從未配準圖像中提取對齊的高斯點。在仿真和真實環境中的實驗均表明,GWM在未來場景預測和訓練更優策略方面具有有效性。
GWM: Towards Scalable Gaussian World Models for Robotic Manipulation (補充材料)
A. 數據集與基準(Datasets and Benchmarks)
Robocasa. 該數據集由機器人操作數據組成,這些數據來自 MuJoCo 仿真環境,使用 Franka Emika Panda 機械臂 收集,主要聚焦于廚房場景。
在我們的實驗中,我們使用了 Human-50 (H-50) 和 Generated-3000 (G-3000) 兩個數據集,它們由 RoboCasa 提供,并基于人類演示使用 MimicGen [55] 自動生成。該基準包含 24個原子任務,詳細信息見表2。
Metaworld. MetaWorld 是一個常用的基準,用于 元強化學習 和 多任務學習。它包含 50個不同的機器人操作任務,這些任務均在仿真環境中使用 Sawyer 機械臂 完成。觀測輸入是大小為 64 × 64 的RGB圖像,動作為一個 4維連續向量。
B. 實現細節(Implementation Details)
B.1. EDM 預處理(EDM Preconditioning)
如第3.2節所述,我們在此列出了為改進網絡訓練而設計的預處理器 [33]:
cinτ=1σ(τ)2+σdata2(A1)c^\tau_{in} = \frac{1}{\sqrt{\sigma(\tau)^2 + \sigma^2_{data}}} \tag{A1} cinτ?=σ(τ)2+σdata2??1?(A1)
coutτ=σ(τ)σdataσ(τ)2+σdata2(A2)c^\tau_{out} = \frac{\sigma(\tau)\sigma_{data}}{\sqrt{\sigma(\tau)^2 + \sigma^2_{data}}} \tag{A2} coutτ?=σ(τ)2+σdata2??σ(τ)σdata??(A2)
cnoiseτ=14log?(σ(τ))(A3)c^\tau_{noise} = \frac{1}{4} \log(\sigma(\tau)) \tag{A3} cnoiseτ?=41?log(σ(τ))(A3)
cskipτ=σdata2σdata2+σ(τ)2(A4)c^\tau_{skip} = \frac{\sigma^2_{data}}{\sigma^2_{data} + \sigma(\tau)^2} \tag{A4} cskipτ?=σdata2?+σ(τ)2σdata2??(A4)
其中,σdata=0.5\sigma_{data} = 0.5σdata?=0.5。
噪聲參數 σ(τ)\sigma(\tau)σ(τ) 的采樣方式如下,以最大化訓練的有效性:
log?(σ(τ))~N(Pmean,Pstd2),(A5)\log(\sigma(\tau)) \sim \mathcal{N}(P_{mean}, P^2_{std}), \tag{A5} log(σ(τ))~N(Pmean?,Pstd2?),(A5)
其中,Pmean=?0.4P_{mean} = -0.4Pmean?=?0.4,Pstd=1.2P_{std} = 1.2Pstd?=1.2。
B.2. 架構設計(Architectural Design)
變分自編碼器(VAE)采用 基于Transformer的架構,并使用 點嵌入(point embedding) 來編碼點云輸入。它通過 最遠點采樣(farthest point sampling) 將原始點云從 N=2048N = 2048N=2048 下采樣到可管理的潛在點數量 M=512M = 512M=512,隨后通過一系列 自注意力(self-attention) 和 交叉注意力(cross-attention) 模塊進行處理。
對于概率變體,編碼器輸出 均值(mean) 和 對數方差(logvar) 參數,通過 重參數化技巧(reparameterization trick) 來采樣潛在向量,同時可以選擇性地引入 KL散度正則項。
擴散模型 Dθ\mathcal{D}_\thetaDθ? 采用 視覺Transformer (Vision Transformer, DiT) 結構,通過多個Transformer模塊處理點圖(pointmap)補丁,并使用 自適應層歸一化(adaLN) 來對時間步和動作進行條件化。其輸入由以下部分組成:
- 當前觀測(current observation)、
- 加噪的下一個觀測(noisy next observation)、
- 時間嵌入(time embedding)、
- 當前動作嵌入(current action embedding)。
該模型根據 EDM(Elucidated Diffusion Models) 的公式來預測去噪后的下一個狀態。
獎勵模型 RψR_\psiRψ? 結合了 卷積編碼 和 序列建模,由帶有可選注意力層的 殘差塊(ResBlocks) 以及后續的 LSTM 組成。編碼器處理 一對觀測(當前狀態與下一個狀態),同時以嵌入的動作作為條件,而LSTM則捕獲時間依賴關系,最終通過一個 MLP頭 來預測獎勵。
在推理之前,LSTM的隱藏狀態通過一個 熱身過程(burn-in procedure) 使用條件幀進行初始化。
B.3. 超參數(Hyper-parameters)
RoboCasa 和 MetaWorld 實驗的超參數分別列于表 A2 和表 A1 中。