具身系列——Diffusion Policy算法實現CartPole游戲

代碼原理分析

1. 核心思想

該代碼實現了一個基于擴散模型(Diffusion Model)的強化學習策略網絡。擴散模型通過逐步去噪過程生成動作,核心思想是:
? 前向過程:通過T步逐漸將專家動作添加高斯噪聲,最終變成純噪聲
? 逆向過程:訓練神經網絡預測噪聲,通過T步逐步去噪生成動作
? 數學基礎:基于DDPM(Denoising Diffusion Probabilistic Models)框架

算法步驟
1.1 前向加噪:在動作空間逐步添加高斯噪聲,將真實動作分布轉化為高斯分布
q ( a t ∣ a t ? 1 ) = N ( a t ; 1 ? β t a t ? 1 , β t I ) q(\mathbf{a}_t|\mathbf{a}_{t-1}) = \mathcal{N}(\mathbf{a}_t; \sqrt{1-\beta_t}\mathbf{a}_{t-1}, \beta_t\mathbf{I}) q(at?at?1?)=N(at?;1?βt? ?at?1?,βt?I)
其中 β t \beta_t βt? 為噪聲調度參數(網頁4][網頁5][網頁8])。

1.2 逆向去噪:基于觀測 o t \mathbf{o}_t ot? 條件去噪生成動作
p θ ( a t ? 1 ∣ a t , o t ) = N ( a t ? 1 ; μ θ ( a t , o t , t ) , Σ t ) p_\theta(\mathbf{a}_{t-1}|\mathbf{a}_t, \mathbf{o}_t) = \mathcal{N}(\mathbf{a}_{t-1}; \mu_\theta(\mathbf{a}_t, \mathbf{o}_t, t), \Sigma_t) pθ?(at?1?at?,ot?)=N(at?1?;μθ?(at?,ot?,t),Σt?)
去噪網絡 μ θ \mu_\theta μθ? 預測噪聲殘差(網頁5][網頁6][網頁8])。

1.3 訓練目標:最小化噪聲預測誤差
L = E t , a 0 , ? [ ∥ ? ? ? θ ( α t a 0 + 1 ? α t ? , o t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{t,\mathbf{a}_0,\epsilon}\left[ \|\epsilon - \epsilon_\theta(\sqrt{\alpha_t}\mathbf{a}_0 + \sqrt{1-\alpha_t}\epsilon, \mathbf{o}_t, t)\|^2 \right] L=Et,a0?,??[???θ?(αt? ?a0?+1?αt? ??,ot?,t)2]
其中 α t = ∏ s = 1 t ( 1 ? β s ) \alpha_t = \prod_{s=1}^t (1-\beta_s) αt?=s=1t?(1?βs?)(網頁4][網頁8][網頁11])。

2. 關鍵數學公式

? 前向過程(擴散過程):

q(a_t|a_{t-1}) = N(a_t; √(α_t)a_{t-1}, (1-α_t)I)
α_t = 1 - β_t,α?_t = ∏_{i=1}^t α_i
a_t = √α?_t a_0 + √(1-α?_t)ε,其中ε ~ N(0,I)

? 訓練目標(噪聲預測):

L = ||ε - ε_θ(a_t, s, t)||^2

? 逆向過程(采樣過程):

p_θ(a_{t-1}|a_t) = N(a_{t-1}; μ_θ(a_t, s, t), Σ_t)
μ_θ = 1/√α_t (a_t - β_t/√(1-α?_t) ε_θ)

逐行代碼注釋

import torch
import gymnasium as gym
import numpy as npclass DiffusionPolicy(torch.nn.Module):def __init__(self, state_dim=4, action_dim=2, T=20):super().__init__()self.T = T  # 擴散過程總步數self.betas = torch.linspace(1e-4, 0.02, T)  # 噪聲方差調度self.alphas = 1 - self.betas  # 前向過程參數self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # 累積乘積α?# 去噪網絡(輸入維度:state(4) + action(2) + timestep(1) = 7)self.denoiser = torch.nn.Sequential(torch.nn.Linear(7, 64),  # 輸入層torch.nn.ReLU(),          # 激活函數torch.nn.Linear(64, 2)    # 輸出預測的噪聲)self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)def train_step(self, states, expert_actions):batch_size = states.size(0)t = torch.randint(0, self.T, (batch_size,))  # 隨機采樣時間步alpha_bar_t = self.alpha_bars[t].unsqueeze(1)  # 獲取對應α?_t# 前向加噪(公式實現)noise = torch.randn_like(expert_actions)  # 生成高斯噪聲noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + \torch.sqrt(1 - alpha_bar_t) * noise  # 公式(2)# 輸入拼接(狀態、加噪動作、歸一化時間步)inputs = torch.cat([states, noisy_actions,(t.float() / self.T).unsqueeze(1)  # 時間步歸一化到[0,1]], dim=1)  # 最終維度:batch_size x 7pred_noise = self.denoiser(inputs)  # 預測噪聲loss = torch.mean((noise - pred_noise)**2)  # MSE損失return lossdef sample_action(self, state):state_tensor = torch.FloatTensor(state).unsqueeze(0)a_t = torch.randn(1, 2)  # 初始化為隨機噪聲(動作維度2)# 逆向去噪過程(需要補全)for t in reversed(range(self.T)):# 應實現的步驟:# 1. 獲取當前時間步參數# 2. 拼接輸入(狀態,當前動作,時間步)# 3. 預測噪聲ε_θ# 4. 根據公式計算均值μ# 5. 采樣新動作(最后一步不添加噪聲)passreturn a_t.detach().numpy()[0]  # 返回最終動作

執行過程詳解

訓練流程
  1. 隨機采樣時間步:為每個樣本隨機選擇擴散步t ∈ [0, T-1]
  2. 前向加噪:根據公式將專家動作添加對應程度的噪聲
  3. 輸入構造:拼接狀態、加噪動作和歸一化時間步
  4. 噪聲預測:神經網絡預測添加的噪聲
  5. 損失計算:最小化預測噪聲與真實噪聲的MSE
采樣流程(需補全)
  1. 初始化:從高斯噪聲開始
  2. 迭代去噪:從t=T到t=1逐步去噪
    ? 根據當前動作和狀態預測噪聲
    ? 計算前一步的均值
    ? 添加隨機噪聲(最后一步除外)
  3. 輸出:得到最終去噪后的動作

關鍵改進建議

  1. 實現逆向過程:需要補充時間步循環和去噪公式
  2. 添加方差調度:在采樣時使用更復雜的方差計算
  3. 時間步嵌入:可以使用正弦位置編碼代替簡單歸一化
  4. 網絡結構優化:考慮使用Transformer或條件批歸一化

該實現展示了擴散策略的核心思想,但完整的擴散策略還需要實現完整的逆向采樣過程,并可能需要調整噪聲調度參數以獲得更好的性能。

最終可執行代碼:

import torch
import gymnasium as gym
import numpy as npclass DiffusionPolicy(torch.nn.Module):def __init__(self, state_dim=4, action_dim=2, T=20):super().__init__()self.T = Tself.betas = torch.linspace(1e-4, 0.02, T)self.alphas = 1 - self.betasself.alpha_bars = torch.cumprod(self.alphas, dim=0)# 去噪網絡(輸入維度:4+2+1=7)self.denoiser = torch.nn.Sequential(torch.nn.Linear(7, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2))self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)def train_step(self, states, expert_actions):batch_size = states.size(0)t = torch.randint(0, self.T, (batch_size,))alpha_bar_t = self.alpha_bars[t].unsqueeze(1)# 前向加噪公式[2](@ref)noise = torch.randn_like(expert_actions)noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + torch.sqrt(1 - alpha_bar_t) * noise# 輸入拼接(維度對齊)[1](@ref)inputs = torch.cat([states, noisy_actions,(t.float() / self.T).unsqueeze(1)], dim=1)  # 最終維度:batch_size x 7pred_noise = self.denoiser(inputs)loss = torch.mean((noise - pred_noise)**2)return lossdef sample_action(self, state):state_tensor = torch.FloatTensor(state).unsqueeze(0)a_t = torch.randn(1, 2)  # 二維動作空間[2](@ref)# 逆向去噪過程[2](@ref)for t in reversed(range(self.T)):alpha_t = self.alphas[t]alpha_bar_t = self.alpha_bars[t]inputs = torch.cat([state_tensor,a_t,torch.tensor([[t / self.T]], dtype=torch.float32)], dim=1)pred_noise = self.denoiser(inputs)a_t = (a_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)if t > 0:a_t += torch.sqrt(self.betas[t]) * torch.randn_like(a_t)return torch.argmax(a_t).item()  # 離散動作選擇[1](@ref)if __name__ == "__main__":env = gym.make('CartPole-v1')policy = DiffusionPolicy()# 關鍵修復:確保狀態數據維度統一[1,2](@ref)states, actions = [], []state, _ = env.reset()for _ in range(1000):action = env.action_space.sample()next_state, _, terminated, truncated, _ = env.step(action)done = terminated or truncated# 強制轉換狀態為numpy數組并檢查維度[2](@ref)state = np.array(state, dtype=np.float32).flatten()if len(state) != 4:raise ValueError(f"Invalid state shape: {state.shape}")states.append(state)  # 確保每個狀態是(4,)的數組actions.append(action)if done:state, _ = env.reset()else:state = next_state# 維度驗證與轉換[1](@ref)states_array = np.stack(states)  # 強制轉換為(1000,4)if states_array.shape != (1000,4):raise ValueError(f"States shape error: {states_array.shape}")actions_onehot = np.eye(2)[np.array(actions)]  # 轉換為one-hot編碼[2](@ref)states_tensor = torch.FloatTensor(states_array)actions_tensor = torch.FloatTensor(actions_onehot)# 訓練循環for epoch in range(100):loss = policy.train_step(states_tensor, actions_tensor)policy.optimizer.zero_grad()loss.backward()policy.optimizer.step()print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 測試state, _ = env.reset()for _ in range(200):action = policy.sample_action(state)state, _, done, _, _ = env.step(action)if done: break

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/899375.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/899375.shtml
英文地址,請注明出處:http://en.pswp.cn/news/899375.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

DeepSeek 本地化部署教程

1 概述 1.1 配置參考圖 科普: B,Billion(十億),是 “參數量” 的單位。 模型量超過 一億,可稱之為 “大模型”。 2 軟件安裝 2.1 下載 Ollama 官方主頁:https://ollama.com/download主頁截圖…

matlab打開兩個工程

1、問題描述 寫代碼時,需要實時參考別人的代碼,需要同時打開2個模型,當模型在同一個工程內時,這是可以直接打開的,如圖所示 2、解決方案 再打開一個MATLAB主窗口 這個時候就可以同時打開多個模型了 3、正確的打開方…

mac 下配置flutter 總是失敗,請參考文章重新配置flutter 環境MacOS Flutter環境配置和安裝

一、安裝和運行Flutter的系統環境要求 想要安裝并運行 Flutter,你的開發環境需要最低滿足以下要求: 操作系統:macOS磁盤空間:2.8 GB(不包括IDE/tools的磁盤空間)。工具:Flutter使用git進行安裝和升級。我們建議安裝Xcode,其中包括git&#x…

第4.1節:使用正則表達式

1 第4.1節:使用正則表達式 將正則表達式用斜杠括起來,就能用作模式。隨后,該正則表達式會與每條輸入記錄的完整文本進行比對。(通常情況下,它只需匹配文本的部分內容就能視作匹配成功。)例如,以…

Java 代理(一) 靜態代理

學習代理的設計模式的時候,經常碰到的一個經典場景就是想統計某個方法的執行時間。 1 靜態代理模式的產生 需求1. 統計方法執行時間 統計方法執行時間,在很多API/性能監控中都有這個需求。 下面以簡單的計算器為例子,計算加法耗時。代碼如下…

每日總結3.28

藍橋刷題 3227 找到最多的數 方法一&#xff1a;摩爾投票法 #include <bits/stdc.h> using namespace std; #define int long long signed main() { int n,m; cin>>n>>m; int a[m*n]; for(int i0;i<n*m;i) { cin>>a[i]; } int cand…

Flutter快速搭建聊天

之前項目中使用的環信聊天&#xff0c;我們的App使用的Flutter開發的 。 所以&#xff0c;就使用的 em_chat_uikit &#xff0c;這個是環信開發的Flutter版本的聊天。 一開始&#xff0c;我們也用的環信的聊天&#xff0c;是收費的&#xff0c;但是&#xff0c;后面就發現&…

Sa-Token

簡介 Sa-Token 是一個輕量級 Java 權限認證框架&#xff0c;主要解決&#xff1a;登錄認證、權限認證、單點登錄、OAuth2.0、分布式Session會話、微服務網關鑒權 等一系列權限相關問題。 官方文檔 常見功能 登錄認證 本框架 用戶提交 name password 參數&#xff0c;調用登…

基于javaweb的SSM航班機票預訂平臺系統設計與實現(源碼+文檔+部署講解)

技術范圍&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬蟲、數據可視化、小程序、安卓app、大數據、物聯網、機器學習等設計與開發。 主要內容&#xff1a;免費功能設計、開題報告、任務書、中期檢查PPT、系統功能實現、代碼編寫、論文編寫和輔導、論…

格雷碼、漢明碼,CRC校驗的區別

格雷碼、漢明碼和CRC校驗都是用于數據傳輸和存儲中的編碼技術。 它們在原理、功能和應用場景上存在顯著區別。 1.格雷碼&#xff08;Gray Code&#xff09; ? 定義&#xff1a;格雷碼是一種特殊的二進制編碼&#xff0c;任意兩個相鄰的碼字之間僅有一位不同。 ? 功能&#x…

【報錯】 /root/anaconda3/conda.exe: cannot execute binary file: Exec format error

背景: 安裝Anaconda3 bash Anaconda3-****-Linux-x86_64.sh 報錯: /root/anaconda3/conda.exe: cannot execute binary file: Exec format error 原因分析: 安裝包(如

JAVA實現動態IP黑名單過濾

一些惡意用戶(可能是黑客、爬蟲、DDoS 攻擊者)可能頻繁請求服務器資源&#xff0c;導致資源占用過高。因此需要一定的手段實時阻止可疑或惡意的用戶&#xff0c;減少攻擊風險。 通過 IP 封禁&#xff0c;可以有效拉黑攻擊者&#xff0c;防止資源被濫用&#xff0c;保障合法用戶…

開源的CMS建站系統可以隨便用嗎?有什么需要注意的?

開源CMS建站系統雖然具有許多優點&#xff0c;但并非完全“隨便用”。無論選哪個CMS系統&#xff0c;大家在使用的時候&#xff0c;可以盡可能地多注意以下幾點&#xff1a; 1、版權問題 了解開源許可證&#xff1a;不同的開源CMS系統采用不同的開源許可證&#xff0c;如GPL、…

故障識別 | 基于改進螂優化算法(MSADBO)優化變分模態提取(VME)結合稀疏最大諧波噪聲比解卷積(SMHD)進行故障診斷識別,matlab代碼

基于改進螂優化算法&#xff08;MSADBO&#xff09;優化變分模態提取&#xff08;VME&#xff09;結合稀疏最大諧波噪聲比解卷積&#xff08;SMHD&#xff09;進行故障診斷識別 一、引言 1.1 機械故障診斷的背景和意義 在工業生產的宏大畫卷中&#xff0c;機械設備的穩定運行…

探究 CSS 如何在HTML中工作

2025/3/28 向全棧工程師邁進&#xff01; 一、CSS的作用 簡單一句話——美化網頁 <p>Lets use:<span>Cascading</span><span>Style</span><span>Sheets</span> </p> 對于如上代碼來說&#xff0c;其顯示效果如下&#xff1…

硬件老化測試方案的設計誤區

硬件老化測試方案設計中的常見誤區主要包括測試周期不足、測試條件過于單一、樣品選擇不當等方面。其中&#xff0c;測試周期不足尤為突出&#xff0c;容易導致潛在缺陷未被完全暴露。老化測試本質上是通過加速產品老化來模擬長期使用狀況&#xff0c;因此測試周期不足會嚴重削…

無錫零碳園區“三年突圍”安科瑞源網荷儲充系統如何破解“綠電難、儲能貴、調度亂”困局?

零碳園區建設如火如荼&#xff0c;為何企業“不敢投、不會用”&#xff1f; 無錫市政府3月27日發布《零碳園區建設三年行動方案》&#xff0c;目標到2027年建成10家以上零碳園區、20家零碳工廠、10個源網荷儲一體化項目。但企業仍存疑慮&#xff1a; 綠電消納難&#xff1a;光…

docker torcherve打包mar包并部署模型

使用Docker打包深度網絡模型mar包到服務端 參考鏈接&#xff1a;Docker torchserve 部署模型流程——以WSL部署YOLO-FaceV2為例_class myhandler(basehandler): def initialize(self,-CSDN博客 1、docker拉取環境鏡像命令 docker images出現此提示為沒有權限取執行命令&…

Redis 分布式鎖實現深度解析

Redis 分布式鎖是分布式系統中協調多進程/服務對共享資源訪問的核心機制。以下從基礎概念到高級實現進行全面剖析。 一、基礎實現原理 1. 最簡實現&#xff08;SETNX 命令&#xff09; # 加鎖 SET resource_name my_random_value NX PX 30000# 解鎖&#xff08;Lua腳本保證原…

kubernetes》》k8s》》 kubeadm、kubectl、kubelet

kubeadm 、kubectl 、kubelet kubeadm、kubectl和kubelet是Kubernetes中不可或缺的三個組件。kubeadm負責集群的快速構建和初始化&#xff0c;為后續的容器部署和管理提供基礎&#xff1b;kubectl作為命令行工具&#xff0c;提供了與Kubernetes集群交互的便捷方式&#xff1b;而…