前言?
ChatGPT出來后的兩年多,也是我瘋狂寫博的兩年多(年初deepseek更引爆了下),比如從創業起步時的15年到后來22年之間 每年2-6篇的,干到了23年30篇、24年65篇、25年前兩月18篇,成了我在大模型和具身的原始技術積累
如今一轉眼已到25年3月初,時光走得太快,近期和團隊接了好幾個大客戶訂單,使得3月起 不得不全力加速落地,自己也得每天摳paper、搞代碼
雖然今年可能沒法像去年24年那樣干65篇,不過,我還是爭取保持月月更新
- 一方面,有些文章是之前既定計劃中的,比如如此文《π0開源了且推出自回歸版π0-FAST——打造機器人動作專用的高效Tokenizer:比擴散π0的訓練速度快5倍但效果相當》最后所說的,對π0源碼的解讀
「至于什么是π0,詳見此文《π0——用于通用機器人控制的VLA模型:一套框架控制7種機械臂(基于PaliGemma和流匹配的3B模型)》」 - 二方面,我司「七月在線」在做一系列工廠落地場景的過程中,我們也希望團結到可以和我們一塊做的朋友,而若想團結,便需要對外分享我們每個季度在重點做的業務場景
比如過去一周,我把lerobot、reflect vlm、π0的仿真環境都在我自己本地電腦上跑了下(過程中,GitHub copilot這種AI編程工具在環境的安裝上幫了我很大的忙——各種環境 只要幾句命令,直接幫我裝好,真心不錯)
如此硬著頭皮冥思苦想、摸索了好幾天,隨后使得我自己知道怎么帶隊完成『太多工廠希望實現的一個生產線任務』了,3月初先仿真訓練,2-3個月內部署到真機
當然了,也不單純只是「這幾天的想」就能想出來的,?這幾天之前
- 有把過去一年當三年用的具身技術積累
- 有一年多來,和同事們 如姚博士,以及朋友們許多的討論
- 有去年十幾個工廠對我們的支持與信任
我們正在不斷壯大隊伍
- 有我司內部同事,亦有我帶的北理、中南等985的具身研究生,及一塊合作開發的朋友,很快會把多個生產線任務并行開發起來
- 且無論哪個項目,都是不斷長期迭代的,故過程中少不了科研層面的突破,歡迎更多伙伴加入我們(全、兼、實習皆可,有意者,敬請私我),和我們一塊開發
話休絮煩,本文便按照如下圖所示的源碼結構,重點解讀一下π的整個源碼
- π0的源碼結構非常清晰、可讀性高,不愧是成熟的商業化公司,是我司七月的學習榜樣之一
- 我身邊的很多朋友目前都在做π0的微調及二次開發,相信本文無論對我身邊的朋友,還是對更多人的學習與工作,都會起到比較大的提升
目錄
前言?
第一部分?examples、packages、scripts等結構的分析
1.1 examples :各種機器人平臺的示例實現
1.2 packages
1.3?scripts:包含數據處理、模型訓練/推理的多個腳本
1.3.1 __init__.py
1.3.2?compute_norm_stats.py:計算數據的歸一化統計信息
1.3.3?serve_policy.py:啟動策略服務,用于模型推理
1.3.4?train_test.py:訓練和測試模型
1.3.5?train.py:訓練模型
1.3.6 scripts/docker
第二部分 核心模塊src下models的全面分析與解讀
2.1 models/pi0.py的實現
2.1.1 make_attn_mask:注意力掩碼生成函數
2.1.2 posemb_sincos:位置編碼函數
2.1.3 class Pi0Config:含inputs_spec、get_freeze_filter
2.1.3.1?模型配置參數的定義
2.1.3.2?inputs_spec:定義了π0模型本身接收的輸入數據格式?編輯
2.1.3.3?get_freeze_filter:針對是否LoRA的處理
2.1.4 class Pi0:初始化、特征嵌入、損失函數、推理(去噪生成動作)
2.1.4.1 初始化方法 `__init__`
2.1.4.2 特征嵌入方法:embed_prefix(圖像和文本輸入)、embed_suffix(狀態和動作信息)?編輯
2.1.4.3 損失函數 `compute_loss`
2.1.4.4 推理函數 `sample_actions`:基于擴散模型逆向采樣,生成機器人動作序列
第一部分?examples、packages、scripts等結構的分析
1.1 examples :各種機器人平臺的示例實現
根據π0對應examples模塊的結構
其涉及以下模塊
- aloha_real/:真實機器人ALOHA的示例
- aloha_sim/:ALOHA模擬器的示例
- droid/:DROID機器人的示例
- libero/:LIBERO基準測試的示例
- simple_client/:簡單客戶端的示例
- ur5/:UR5機器人的示例
- inference.ipynb:推理示例的Jupyter Notebook
- policy_records.ipynb:策略記錄示例的Jupyter Notebook
1.2 packages
該模塊的目錄結構如下
1.3?scripts:包含數據處理、模型訓練/推理的多個腳本
根據下圖
可知,scripts 目錄包含多個 Python 腳本,這些腳本用于數據處理、模型訓練和服務部署等任務,每個腳本通常對應一個特定的功能或任務
- __init__.py
- compute_norm_stats.py: 計算數據的歸一化統計信息
- serve_policy.py: 啟動策略服務,提供模型推理接口
- train_test.py: 訓練和測試模型
- train.py: 訓練模型
1.3.1 __init__.py
1.3.2?compute_norm_stats.py:計算數據的歸一化統計信息
1.3.3?serve_policy.py:啟動策略服務,用于模型推理
- 在這個代碼片段中,首先導入了一些必要的模塊和庫,包括 `policy`、`policy_config`、`websocket_policy_server` 和 `config`,這些模塊來自 `openpi` 項目
接下來定義了一個枚舉類 `EnvMode`,它表示支持的環境類型,包括 `ALOHA`、`ALOHA_SIM`、`DROID` 和 `LIBERO`from openpi.policies import policy as _policy # 導入 openpi.policies.policy 模塊并重命名為 _policy from openpi.policies import policy_config as _policy_config # 導入 openpi.policies.policy_config 模塊并重命名為 _policy_config from openpi.serving import websocket_policy_server # 導入 openpi.serving.websocket_policy_server 模塊 from openpi.training import config as _config # 導入 openpi.training.config 模塊并重命名為 _config
class EnvMode(enum.Enum):"""支持的環境。"""ALOHA = "aloha" # ALOHA 環境ALOHA_SIM = "aloha_sim" # ALOHA 模擬環境DROID = "droid" # DROID 環境LIBERO = "libero" # LIBERO 環境
- 然后定義了幾個數據類
`Checkpoint` 類用于從訓練好的檢查點加載策略,包含兩個字段:`config`(訓練配置名稱)和 `dir`(檢查點目錄)
`Default` 類表示使用默認策略
`Args` 類定義了腳本的參數,包括環境類型、默認提示、端口、是否記錄策略行為以及如何加載策略 - 接下來定義了一個字典 `DEFAULT_CHECKPOINT`,它為每個環境類型指定了默認的檢查點配置
`create_default_policy` 函數根據環境類型創建默認策略,如果環境類型不支持,則拋出異常# 每個環境應使用的默認檢查點 DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {EnvMode.ALOHA: Checkpoint(config="pi0_aloha",dir="s3://openpi-assets/checkpoints/pi0_base",),EnvMode.ALOHA_SIM: Checkpoint(config="pi0_aloha_sim",dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",),EnvMode.DROID: Checkpoint(config="pi0_fast_droid",dir="s3://openpi-assets/checkpoints/pi0_fast_droid",),EnvMode.LIBERO: Checkpoint(config="pi0_fast_libero",dir="s3://openpi-assets/checkpoints/pi0_fast_libero",), }
`create_policy` 函數根據傳入的參數創建策略,如果參數中指定了檢查點,則從檢查點加載策略,否則使用默認策略def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:"""為給定環境創建默認策略 """if checkpoint := DEFAULT_CHECKPOINT.get(env): # 獲取環境對應的默認檢查點return _policy_config.create_trained_policy(_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt) # 創建訓練好的策略raise ValueError(f"Unsupported environment mode: {env}") # 如果環境不支持,拋出異常
def create_policy(args: Args) -> _policy.Policy:"""根據給定的參數創建策略 """match args.policy: # 匹配策略類型case Checkpoint(): # 如果是 Checkpoint 類型return _policy_config.create_trained_policy(_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt) # 創建訓練好的策略case Default(): # 如果是 Default 類型return create_default_policy(args.env, default_prompt=args.default_prompt) # 創建默認策略
- `main` 函數是腳本的入口點,它首先調用 `create_policy` 函數創建策略,然后記錄策略的元數據
如果參數中指定了記錄策略行為,則使用 `PolicyRecorder` 包裝策略def main(args: Args) -> None:policy = create_policy(args) # 創建策略policy_metadata = policy.metadata # 獲取策略的元數據
接著獲取主機名和本地 IP 地址# 記錄策略的行為if args.record:# 使用 PolicyRecorder 記錄策略行為policy = _policy.PolicyRecorder(policy, "policy_records")
并創建一個 WebSocket 服務器來提供策略服務,最后調用 `serve_forever` 方法啟動服務器hostname = socket.gethostname() # 獲取主機名local_ip = socket.gethostbyname(hostname) # 獲取本地 IP 地址logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) # 記錄服務器創建信息
server = websocket_policy_server.WebsocketPolicyServer(policy=policy,host="0.0.0.0",port=args.port,metadata=policy_metadata,) # 創建 WebSocket 策略服務器server.serve_forever() # 啟動服務器,永遠運行
- 在腳本的最后,使用 `logging` 模塊配置日志記錄,并調用 `main` 函數啟動腳本,參數通過 `tyro.cli` 解析
1.3.4?train_test.py:訓練和測試模型
1.3.5?train.py:訓練模型
1.3.6 scripts/docker
好的,下面是對 `openpi-main/scripts/docker` 目錄的詳細分析。這個目錄通包含與 Docker 相關的腳本和配置文件,用于構建和管理 Docker 容器,具體而言,包含以下文件和子目錄:
主要文件和功能如下所示
- docker/compose.yml
- docker/install_docker_ubuntu22.sh
- docker/install_nvidia_container_toolkit.sh
- docker/serve_policy.Dockerfile
// 待更
第二部分 核心模塊src下models的全面分析與解讀
接下來,我們來看核心src下的各個模塊
首先是其中的src/openpi/models
2.1 models/pi0.py的實現
它結合了多模態輸入(圖像和文本)來生成機器人動作序列。下面是對代碼的詳細解析:
2.1.1 make_attn_mask:注意力掩碼生成函數
這個函數生成transformer中使用的注意力掩碼,控制 token 之間的注意力流動方式
def make_attn_mask(input_mask, mask_ar):"""從big_vision項目改編的注意力掩碼生成函數Token可以關注那些累積mask_ar小于等于自己的有效輸入token。這樣`mask_ar` bool[?B, N]可用于設置幾種類型的注意力,例如:[[1 1 1 1 1 1]]: 純因果注意力。[[0 0 0 1 1 1]]: 前綴語言模型注意力。前3個token之間可以互相關注,后3個token有因果注意力。第一個條目也可以是1,不改變行為。[[1 0 1 0 1 0 0 1 0 0]]: 4個塊之間的因果注意力。一個塊的token可以關注所有之前的塊和同一塊內的所有token。參數:input_mask: bool[B, N] 如果是輸入的一部分則為true,如果是填充則為falsemask_ar: bool[?B, N] 如果前面的token不能依賴于它則為true,如果它共享與前一個token相同的注意力掩碼則為false"""# 將mask_ar廣播到與input_mask相同的形狀mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape) # 計算mask_ar在序列維度上的累積和cumsum = jnp.cumsum(mask_ar, axis=1) # 創建注意力掩碼:當目標位置的累積值<=查詢位置的累積值時,允許注意力流動attn_mask = cumsum[:, None, :] <= cumsum[:, :, None] # 創建有效掩碼:只有有效的輸入位置之間才能有注意力valid_mask = input_mask[:, None, :] * input_mask[:, :, None] # 結合注意力掩碼和有效掩碼return jnp.logical_and(attn_mask, valid_mask)
它支持多種注意力模式:
- 純因果注意力(每個 token 只能關注自己和之前的 token)
- 前綴語言模型注意力(允許前綴內部自由注意,后綴部分使用因果注意力)
- 塊狀因果注意力(在塊內自由注意,塊之間是因果的)
2.1.2 posemb_sincos:位置編碼函數
使用正弦余弦函數實現位置編碼
def posemb_sincos(pos: at.Real[at.Array, Any], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, f"b {embedding_dim}"]:"""計算標量位置的正弦余弦位置嵌入向量"""if embedding_dim % 2 != 0: # 檢查嵌入維度是否為偶數raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2) # 創建均勻分布的分數值period = min_period * (max_period / min_period) ** fraction # 計算周期值,對數空間中均勻分布sinusoid_input = jnp.einsum("i,j->ij",pos,1.0 / period * 2 * jnp.pi, # 計算角頻率precision=jax.lax.Precision.HIGHEST, # 使用最高精度進行計算)# 連接sin和cos值,形成完整的位置編碼return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
2.1.3 class Pi0Config:含inputs_spec、get_freeze_filter
2.1.3.1?模型配置參數的定義
首先,這個類定義了模型的配置參數,比如PaLI-Gemma 變體:`gemma_2b
class Pi0Config(_model.BaseModelConfig):dtype: str = "bfloat16" # 設置數據類型為bfloat16paligemma_variant: _gemma.Variant = "gemma_2b" # 設置PaLI-Gemma變體為2B參數版本action_expert_variant: _gemma.Variant = "gemma_300m" # 設置動作專家變體為300M參數版本# 設置模型特定的默認值action_dim: int = 32 # 設置動作維度為32action_horizon: int = 50 # 設置動作序列長度為50步max_token_len: int = 48 # 設置最大token長度為48
2.1.3.2?inputs_spec:定義了π0模型本身接收的輸入數據格式
其次,通過inputs_spec函數定義了π0模型本身接收的輸入數據格式,函數采用關鍵字參數 `batch_size`(默認為1),返回一個包含觀察規格和動作規格的元組
def inputs_spec(self, *, batch_size: int = 1) -> Tuple[Type[_model.Observation], Type[_model.Actions]]
- 其支持多種輸入,比如
視覺輸入(三個不同視角的RGB圖像)、語言輸入(分詞后的文本prompt)、狀態輸入(當前機器人狀態) - 輸出上
則是一個時序動作序列(包含50個連續的動作向量,每個動作向量有32個維度,可能對應關節角度或其他控制信號)
具體而言該函數先
創建圖像規格
image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
其中的
- `[batch_size, *_model.IMAGE_RESOLUTION, 3]` 定義了圖像張量的形狀:比如
? 批次大小
? 圖像分辨率(從 `_model.IMAGE_RESOLUTION` 獲取,可能是如 [224, 224] 這樣的值)
? 3 個顏色通道 (RGB)
- `jnp.float32` 指定了數據類型為 32 位浮點數
創建圖像掩碼規格
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
其定義了圖像掩碼規格,每個批次中的每個圖像都有一個布爾值,這個掩碼用于指示哪些圖像是有效的(`True`)或無效的(`False`)
創建觀察規格:包含視覺輸入、機器人狀態、指令輸入
`at.disable_typechecking()` 臨時禁用類型檢查,可能是因為這里創建的是類型規格而不是實際的數據,且觀察規格包含多個組件:
- 多視角圖像
base_0_rgb: 機器人底座/身體視角的RGB圖像
left_wrist_0_rgb: 左手腕視角的RGB圖像
right_wrist_0_rgb: 右手腕視角的RGB圖像with at.disable_typechecking():observation_spec = _model.Observation(images={"base_0_rgb": image_spec,"left_wrist_0_rgb": image_spec,"right_wrist_0_rgb": image_spec,},
- 圖像掩碼
對應每個視角圖像的有效性掩碼 - 機器人狀態:
形狀為 `[batch_size, self.action_dim]` 的浮點數張量
`self.action_dim` 默認為32,表示狀態向量的維度state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
- 分詞后的文本prompt
形狀為 `[batch_size, self.max_token_len]` 的整數張量
`self.max_token_len` 默認為48,表示最大token數量
數據類型為 `jnp.int32`,表示token ID - 提示掩碼
與分詞提示相同形狀的布爾張量,用于指示哪些位置有有效的tokenstate=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),)
創建動作規格
action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
其定義了動作數據的形狀和類型:
- `batch_size`: 批次大小
- `self.action_horizon`: 動作序列長度,默認為50
- ?`self.action_dim`: 每個動作的維度,默認為32
- `jnp.float32` 指定了數據類型為32位浮點數
然后返回
return observation_spec, action_spec
2.1.3.3?get_freeze_filter:針對是否LoRA的處理
此外,該配置類還實現了get_freeze_filter這個函數,作用是如果選擇LoRA微調(凍結原始預訓練模型的參數,只更新新添加的低秩適應層參數),則需要對模型中的某些參數做凍結
三種可能的情況:
- 只對 PaLI-Gemma 使用 LoRA:凍結 Gemma 參數(但排除動作專家參數)
- 只對動作專家使用 LoRA:凍結動作專家參數
- 對兩者都使用 LoRA:凍結兩者的基礎參數
如此,可以選擇性地微調模型的特定部分(語言部分或動作預測部分)
具體而言
- 首先,定義函數
def get_freeze_filter(self) -> nnx.filterlib.Filter:"""返回基于模型配置的凍結過濾器"""
- 其次,初始化變量
filters = [] # 初始化過濾器列表has_lora = False # 初始化LoRA標志
- 接著,創建參數過濾器
# 匹配所有LLM參數的正則表達式,用于選擇 Gemma 語言模型的參數gemma_params_filter = nnx_utils.PathRegex(".*llm.*") # 匹配動作專家參數的正則表達式action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
- 接下來是對PaLI-Gemma變體的處理
# 如果PaLI-Gemma使用LoRAif "lora" in self.paligemma_variant:filters.append(gemma_params_filter, # 添加Gemma參數過濾器)if "lora" not in self.action_expert_variant:# 如果只凍結Gemma參數,排除動作專家參數filters.append(nnx.Not(action_expert_params_filter),)has_lora = True
- 再下來是對動作專家變體的處理
elif "lora" in self.action_expert_variant:# 如果動作專家使用LoRAfilters.append(action_expert_params_filter,)has_lora = True
2.1.4 class Pi0:初始化、特征嵌入、損失函數、推理(去噪生成動作)
核心模型類,繼承自 `_model.BaseModel`,實現了:
- 多模態輸入處理
處理多視角圖像(基礎視角、左手腕視角、右手腕視角)
處理文本提示(如指令)
處理機器人當前狀態 - 擴散過程
訓練時:將干凈動作添加噪聲,讓模型學習去噪
推理時:從純噪聲開始,逐步降噪生成動作序列 - 注意力機制
使用精心設計的注意力掩碼控制信息流動
前綴(圖像和文本)內部使用全注意力
后綴(狀態和動作)使用特殊的注意力模式
2.1.4.1 初始化方法 `__init__`
class Pi0(_model.BaseModel):def __init__(self, config: Pi0Config, rngs: nnx.Rngs):# 初始化基類super().__init__(config.action_dim, config.action_horizon, config.max_token_len)# 獲取PaLI-Gemma和動作專家配置paligemma_config = _gemma.get_config(config.paligemma_variant)action_expert_config = _gemma.get_config(config.action_expert_variant)
其組合了多個核心組件:
一個是PaLI-Gemma 模型:結合了 Gemma 語言模型和 SigLIP 視覺模型
- 先是對語言模型的初始化
# 創建并初始化語言模型# TODO: 用NNX重寫Gemma,目前使用橋接llm = nnx_bridge.ToNNX(_gemma.Module(configs=[paligemma_config, action_expert_config], # 配置兩個Gemma模型embed_dtype=config.dtype, # 設置嵌入數據類型))llm.lazy_init(rngs=rngs, method="init") # 延遲初始化LLM
- 然后是對視覺模型的初始化
# 創建并初始化圖像模型img = nnx_bridge.ToNNX(_siglip.Module(num_classes=paligemma_config.width, # 設置圖像特征維度與語言模型寬度相匹配variant="So400m/14", # 使用400M參數SigLIP模型pool_type="none", # 不使用池化,保留所有圖像標記scan=True, # 啟用掃描優化dtype_mm=config.dtype, # 設置矩陣乘法數據類型))# 使用假觀察中的圖像初始化圖像模型img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
- 最后,把語言模型和視覺模型組合成PaLI-Gemma多模態模型
# 組合LLM和圖像模型為PaLI-Gemma多模態模型self.PaliGemma = nnx.Dict(llm=llm, img=img)
另一個是線性投影層:用于
- 狀態投影
# 狀態投影層:將機器人狀態投影到模型維度self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
- 動作投影
# 動作輸入投影層:將動作投影到模型維度self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
- 時間-動作混合等
# 動作-時間MLP輸入層:將連接的動作和時間特征投影到模型維度self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)# 動作-時間MLP輸出層self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)# 動作輸出投影層:將模型輸出投影回動作維度self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
2.1.4.2 特征嵌入方法:embed_prefix(圖像和文本輸入)、embed_suffix(狀態和動作信息)
- `embed_prefix`:處理圖像和文本輸入(圖像通過SigLip模型編碼,文本通過Gemma LLM編碼),創建前綴 token,皆為雙向注意力,用ar_mask = false表示
- `embed_suffix`:處理機器人狀態信息
、噪聲化的動作信息
(狀態和噪聲動作經過線性投影和MLP處理),創建后綴 token
其中
狀態為單個token,和第一個動作token均設置為單向注意力,用ar_mask = true表示
其余動作tokens之間設置為雙向注意力,用ar_mask = false表示
對于前者embed_prefix
def embed_prefix(self, obs: _model.Observation) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:"""嵌入前綴部分(圖像和文本)"""input_mask = [] # 初始化輸入掩碼列表ar_mask = [] # 初始化自回歸掩碼列表tokens = [] # 初始化token列表
其工作流程為
- 圖像處理:說白了,就是把圖像token化
使用SigLip視覺模型處理每個圖像,生成圖像tokens序列# 嵌入圖像for name in obs.images:# 通過圖像模型獲取圖像tokenimage_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)tokens.append(image_tokens) # 添加圖像token
- 圖像掩碼擴展
將圖像掩碼擴展到與圖像tokens相同的序列長度,使用einops.repeat進行形狀變換,這些掩碼會指示哪些圖像是有效的,而哪些是填充的# 重復圖像掩碼以匹配token維度input_mask.append(einops.repeat(obs.image_masks[name],"b -> b s", # 調整形狀:批次維度保持不變,添加序列維度s=image_tokens.shape[1], # 序列長度等于圖像token數))
- 自回歸掩碼設置
設置圖像tokens之間的注意力為雙向(False表示雙向注意力),原因在于圖像內容通常是非時序性的數據# 圖像token之間互相關注(非自回歸)ar_mask += [False] * image_tokens.shape[1]
- 文本處理
使用LLM模型對文本輸入tokenized_inputs進行嵌入
且同樣設置為雙向注意力,相當于語言token可以關注圖像token,圖像token反過來亦可關注語言token,最終實現多模態融合# 添加語言(即分詞后的輸入)if obs.tokenized_prompt is not None:# 通過語言模型嵌入分詞后的提示tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")tokens.append(tokenized_inputs) # 添加文本tokeninput_mask.append(obs.tokenized_prompt_mask) # 添加提示掩碼
# 圖像和語言輸入之間完全關注(非自回歸)ar_mask += [False] * tokenized_inputs.shape[1]
- 最后,連接所有token和掩碼,其中包含了
??多模態信息的融合表示tokens——圖像token和語言token
? 以及指示哪些token是有效信息的input_mask
? 和如何在這些token之間進行注意力計算規則的ar_mask——相當于控制信息流動的方向
# 連接所有token和掩碼tokens = jnp.concatenate(tokens, axis=1) # 在序列維度上連接tokeninput_mask = jnp.concatenate(input_mask, axis=1) # 在序列維度上連接輸入掩碼ar_mask = jnp.array(ar_mask) # 轉換自回歸掩碼為數組return tokens, input_mask, ar_mask # 返回token、輸入掩碼和自回歸掩碼
順便,再回顧下此圖
對于后者embed_suffix
def embed_suffix(self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, Any]) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:"""嵌入后綴部分(狀態和動作)"""input_mask = [] # 初始化輸入掩碼列表ar_mask = [] # 初始化自回歸掩碼列表tokens = [] # 初始化token列表
其工作流程為
- 狀態處理
將狀態信息投影到embedding空間
并設置為單向注意力(True),表明圖像和語言輸入不能關注狀態信息,因為image/language do not attend to state or actions# 添加單個狀態tokenstate_token = self.state_proj(obs.state)[:, None, :] # 投影狀態并添加序列維度tokens.append(state_token) # 添加狀態token# 添加狀態掩碼(全為1),表示這個狀態token是有效的input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
# 圖像/語言輸入不關注狀態或動作(自回歸)ar_mask += [True]
- 時間步嵌入,使用正弦-余弦位置編碼生成時間步嵌入
# 使用正弦余弦位置編碼嵌入時間步,敏感度范圍為[0, 1]time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
- 動作和時間信息融合
# 混合時間步+動作信息,使用MLPaction_tokens = self.action_in_proj(noisy_actions) # 投影帶噪聲的動作# 重復時間嵌入以匹配動作序列長度time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)# 連接動作和時間tokenaction_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
- MLP處理
使用兩層MLP和swish激活函數對「動作和時間的組合表示」進行非線性變換,以進一步融合:動作和時間信息# 通過MLP處理action_time_tokens = self.action_time_mlp_in(action_time_tokens) # 輸入層action_time_tokens = nnx.swish(action_time_tokens) # Swish激活函數action_time_tokens = self.action_time_mlp_out(action_time_tokens) # 輸出層
- 注意力掩碼設置
第一個動作token設置為單向注意力「上面說過了的,單向注意力,用ar_mask = true表示」,其余動作tokens之間設置為雙向注意力# 添加動作時間tokentokens.append(action_time_tokens)# 添加掩碼(全為1),表示所有動作token都是有效的input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_)) # 圖像/語言/狀態輸入不關注動作token(動作第一個是自回歸的——單向,其余不是——雙向)ar_mask += [True] + ([False] * (self.action_horizon - 1))
- 最后連接所有token和掩碼
# 連接所有token和掩碼tokens = jnp.concatenate(tokens, axis=1) # 在序列維度上連接tokeninput_mask = jnp.concatenate(input_mask, axis=1) # 在序列維度上連接輸入掩碼ar_mask = jnp.array(ar_mask) # 轉換自回歸掩碼為數組return tokens, input_mask, ar_mask # 返回token、輸入掩碼和自回歸掩碼
2.1.4.3 損失函數 `compute_loss`
實現了擴散模型的訓練損失計算
- 對輸入觀察進行預處理,其中
preprocess_rng用于觀察預處理(比如圖像增強等)
noise_rng用于生成噪聲
time_rng用于從beta分布采樣時間步def compute_loss(self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False) -> at.Float[at.Array, Any]:"""計算擴散模型的損失函數"""# 分割隨機數生成器為三部分,用于不同的隨機操作preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
- 生成隨機噪聲并采樣時間點 t
# 獲取動作的批次形狀batch_shape = actions.shape[:-2]# 生成與動作相同形狀的高斯噪聲noise = jax.random.normal(noise_rng, actions.shape)# 從Beta分布采樣時間點,范圍為[0.001, 1],Beta(1.5, 1)偏向較低的值time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001# 擴展時間維度以匹配動作形狀time_expanded = time[..., None, None]
- 創建帶噪動作序列 x_t,相當于x_t是噪聲化的動作,隨著時間從0到1,原始動作逐漸加噪,變為純噪聲
而u_t代表所加的真實噪聲,而咱們就是要預測所添加的噪聲(而所添加的噪聲即等于加滿噪聲的動作 - 原始動作)
擴散策略diffusion policy的靈感來源于圖像生成中的擴散模型DDPM,通過逐步去除噪聲來生成目標數據(比如機器人的動作序列),如果對DDPM原理不太明白的,詳見此文《圖像生成發展起源:從VAE、擴散模型DDPM、DDIM到DETR、ViT、Swin transformer》# 創建帶噪聲的動作:t*noise + (1-t)*actionsx_t = time_expanded * noise + (1 - time_expanded) * actions# 計算真實噪聲減去動作的差異,這是模型需要預測的目標u_t = noise - actions
- 嵌入前綴和后綴
# 一次性前向傳遞前綴+后綴# 嵌入前綴(圖像和文本)prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)# 嵌入后綴(狀態和帶噪聲的動作)suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)
- 構建注意力掩碼和位置編碼
根據下圖
可得# 連接掩碼:通過鏈接前綴和后綴的掩碼,從而創建完整的輸入掩碼input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)# 創建注意力掩碼make_attn_mask,從而控制不同token之間的可見性attn_mask = make_attn_mask(input_mask, ar_mask)# 計算位置編碼positions = jnp.cumsum(input_mask, axis=1) - 1
- 模型前向傳播,即使用PaliGemma進行推理,處理前綴和后綴token
當然了,輸出中我們只關注與后綴相關的部分,因為其中包含了我們想要的動作預測的部分# 通過PaLI-Gemma模型處理token_, suffix_out = self.PaliGemma.llm([prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions)
- 預測噪聲v_t
# 將模型輸出投影回動作空間v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
- 計算預測噪聲與實際噪聲間的均方誤差
# 返回預測噪聲和真實噪聲之間的均方誤差return jnp.mean(jnp.square(v_t - u_t), axis=-1)
2.1.4.4 推理函數 `sample_actions`:基于擴散模型逆向采樣,生成機器人動作序列
sample_actions函數是Pi0模型的核心推理方法,實現了基于擴散模型的逆向采樣過程——說白了 就是去噪,它從純噪聲開始,通過多步驟逐漸"去噪",最終生成符合條件分布的機器人動作序列
函數的核心是一個基于while循環的迭代過程,每一步都使用訓練好的神經網絡預測從當前噪聲化動作到目標動作的方向——從噪聲到目標的方向 代表速度場,畢竟咱們去噪的方向得對 不然就去歪了
總之,這個函數將觀察數據(圖像和可選的文本提示)轉換為具體的動作軌跡,是模型部署時的主要接口,簡言之,其包含以下流程
- 首先從純噪聲開始 (t=1)
- 通過重復迭代降噪步驟,逐步將噪聲轉化為有意義的動作序列
- 使用KV緩存優化推理速度
- 實現了一個迭代降噪過程:
- 最終返回完全降噪后的動作序列 x_0
具體而言,包含如下步驟
第一,初始化
首先,函數對輸入觀察數據進行預處理,包括標準化圖像大小等操作
def sample_actions(self,rng: at.KeyArrayLike, # 隨機數生成器observation: _model.Observation, # 觀察輸入,包含圖像和文本等*,num_steps: int = 10, # 擴散過程的步數,默認為10步
) -> _model.Actions: # 返回生成的動作序列# 對觀察數據進行預處理,不進行訓練時的數據增強observation = _model.preprocess_observation(None, observation, train=False)
然后設置時間步長`dt`為負值(因為我們是從t=1向t=0方向演化),生成初始隨機噪聲作為起點,且時間上約定:"t=1是噪聲,t=0是目標分布",這是擴散文獻中常見的約定,不過與Pi0論文相反
# 注意:這里使用擴散模型文獻中更常見的約定,t=1是噪聲,t=0是目標分布# 這與pi0論文相反dt = -1.0 / num_steps # 計算時間步長,從1到0batch_size = observation.state.shape[0] # 獲取批次大小# 生成初始噪聲,形狀為[批次大小, 動作序列長度, 動作維度]noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
第二,Key-Value緩存初始化(預計算并存儲前綴表示,減少冗余計算)
處理觀察數據,得到前綴表示和相關掩碼
# 首先通過前綴的前向傳遞填充KV緩存# 獲取前綴的token表示和掩碼prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)# 創建前綴的注意力掩碼prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)# 計算位置編碼positions = jnp.cumsum(prefix_mask, axis=1) - 1
然后使用PaliGemma語言模型進行一次前向傳遞,生成Key-Value緩存(`kv_cache`)——這是一個性能優化:因為前綴部分在整個采樣過程中保持不變,預先計算并緩存它們的表示可以避免重復計算
# 進行前向傳遞,獲取KV緩存_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
第三,通過step函數構建注意力掩碼系統并讓PaliGemma做推理
核心迭代通過 `jax.lax.while_loop` 實現
根據源碼
可知,該class Pi0(_model.BaseModel)類的最后兩行是
# 使用while循環進行迭代采樣,從t=1(噪聲)開始x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))# 返回最終的去噪結果(生成的動作序列)return x_0
具體而言,包含 `step` 函數和 `cond` 函數,其中,`step` 函數是每次迭代的核心
首先,step函數通過 `embed_suffix` 處理當前狀態,包括狀態信息嵌入、噪聲化動作、時間步編碼
def step(carry):"""定義單步去噪函數"""x_t, time = carry # carry數組包含當前狀態和時間# 將時間廣播到批次維度,并嵌入后綴(狀態和動作)suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, jnp.broadcast_to(time, batch_size))
其次,構建復雜的注意力掩碼系統,處理前綴-后綴之間的注意力關系——這個復雜的掩碼系統允許后綴token(包括狀態和動作)有選擇地關注前綴token(圖像和文本),實現了條件生成,具體而言,其構建了三層注意力掩碼:
- 后綴內部注意力掩碼,控制后綴token(狀態和動作)之間的注意力關系
# 創建后綴內部的注意力掩碼,形狀為(批次, 后綴長度, 后綴長度)suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
- 前綴-后綴注意力掩碼,控制后綴token如何關注前綴token(圖像和文本輸入)
# 創建后綴對前綴的注意力掩碼,形狀為(批次, 后綴長度, 前綴長度)prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
- 完整注意力掩碼,將前兩個掩碼組合,形成完整的注意力控制機制
# 組合掩碼,形狀為(批次, 后綴長度, 前綴長度+后綴長度)# 控制后綴token(生成查詢)如何關注完整序列(生成鍵和值)full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
當然了,過程中還做了形狀檢查,確保張量維度正確
# 驗證掩碼形狀正確assert full_attn_mask.shape == (batch_size,suffix_tokens.shape[1],prefix_tokens.shape[1] + suffix_tokens.shape[1],)
接著,計算位置編碼,為后綴token計算其在完整序列中的位置,這對于Transformer模型理解序列順序很重要
# 計算后綴token的位置編碼positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
之后,模型推理,使用PaliGemma語言模型進行推理,利用緩存的前綴信息(`kv_cache`)提高效率
# 使用KV緩存進行高效的前向傳遞(prefix_out, suffix_out), _ = self.PaliGemma.llm([None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache)# 且確保前綴輸出為None(因為使用了KV緩存)assert prefix_out is None
第四,step函數中做最后的速度預測與動作更新(去噪)
在每一步中,模型預測速度場 `v_t`(從噪聲到目標的方向),并通過類歐拉法更新動作表示——使用簡單而有效的歐拉方法進行軌跡采樣
具體而言
- 一方面,提取模型輸出并預測速度場`v_t`——相當于本質是通過PaliGemma模型預測去噪方向 `v_t`
# 預測噪聲v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
- 二方面,使用歐拉法更新動作狀態和時間步
# 使用歐拉方法更新狀態和時間return x_t + dt * v_t, time + dt
至于cond函數確定何時停止迭代,通過檢查時間是否接近零(當然,要考慮浮點精讀可能存在的誤差)
def cond(carry):"""定義循環終止條件"""x_t, time = carry# 考慮浮點誤差,當時間接近0時停止return time >= -dt / 2
// 待更