ALOHA ACT算法與源碼筆記

算法

一文通透動作分塊算法ACT:斯坦福ALOHA團隊推出的動作序列預測算法(Action Chunking with Transformers)
比較簡單,算法題目里就寫了:Action Chunking with Transformers,比較有特色的地方就是Action Chunking,核心就是不浪費之前做過的推理預測,統統拿過來加權一下,得到最終的答案。
在這里插入圖片描述

源碼

逐行解讀ALOHA ACT的實現:機器人動作分塊算法ACT的代碼剖析、訓練部署(含真機上的智能分揀復現)
代碼寫得很優雅,讀起來很流暢

1.1.1 模仿學習及其挑戰:Action Chunking with Transformers(ACT)

預測動作中的小誤差會引起狀態的大差異,加劇模仿學習的“復合誤差”問題。為了解決這個問題,他們從動作分塊(action chunking)中獲得靈感,這是心理學中的一個概念,描述了如何將一系列動作組合在一起作為一個塊,最終作為一個單元執行

他們使用Transformers實現動作分塊策略,并將其訓練為條件VAE (CVAE),以捕獲人類數據中的可變性。他們將該方法命名為Action Chunking with Transformers(ACT),并發現它在一系列模擬和現實世界的精細操作任務上顯著優于以前的模仿學習算法

2.2.2 第二步 推斷z,以獲得CVAE解碼器輸入中的風格變量z

一文通透動作分塊算法ACT:斯坦福ALOHA團隊推出的動作序列預測算法(Action Chunking with Transformers)的這句話啥意思?

最后
只取第一個輸出,它對應于**[CLS]標記**,并使用另一個線性網絡來預測z分布均值方差,將其參數化為對角高斯分布
且使用重新參數化獲得z的樣本,這是一種允許在采樣過程中反向傳播的標準方法,以便編碼器和解碼器可以聯合優化[33]

看detr_vae.py的代碼就知道了:
在DETRVAE的if is_training頭上有個注釋:Obtain latent z from action sequence,
意思是風格變量z就是latent_input
[CLS]標記:encoder_output = encoder_output[0] # take cls output only
均值:mu = latent_info[:, :self.latent_dim]
方差:logvar = latent_info[:, self.latent_dim:]
使用重新參數化獲得z的樣本:latent_sample = reparametrize(mu, logvar)
最后:latent_input = self.latent_out_proj(latent_sample)

2.3 優勢特征:ACT與其他模仿學習方法的比較

一方面,transformer解碼器的“query”是第一層固定的正弦位置嵌入,即如上圖右下角所示的position embeddings(fixed),其維度為k ×512
二方面,transformer解碼器的交叉注意力(cross-attention)層中的“keys”和“values”來自上述transformer編碼器的輸出

eval_bc(評估一個行為克隆(behavior cloning)模型)和train_bc(訓練行為克隆BC模型)的區別

我看到train_bc里頭有個eval的,但這個eval應該和eval_bc不一樣,雖然兩者都要用到policy.eval()
注:policy里頭就會調用

model, optimizer = build_ACT_model_and_optimizer(args_override)
self.model = model

1.8.3.2 根據觀察結果查詢策略、獲取動作

這里的train_bc的policy調用參數是(qpos_data, image_data, action_data, is_pad)
eval_bc的policy調用參數是(qpos, curr_image)
根據參數來判斷是訓練還是推理
在這里插入圖片描述
在訓練模式下,會計算出一系列的損失并返回一個包含這些損失的字典
在推理模式下,會從模型中獲取預測的動作并返回

aloha act代碼里頭的qpos和action有什么區別?

https://metaso.cn/s/IOAGn1O

那mu, logvar是啥

https://metaso.cn/s/IOAGn1O
在變分自編碼器(VAE)中,mu 和 logvar 是兩個關鍵參數,它們分別代表潛在變量的均值和對數方差,用于生成潛在空間的樣本。
這段代碼是 變分自編碼器(VAE) 中的 重參數化技巧(Reparameterization Trick) 的實現,其作用是 從潛在變量的分布中采樣,同時保證 梯度可以連續傳播,從而實現端到端的訓練。

def reparametrize(mu, logvar):std = logvar.div(2).exp()eps = Variable(std.data.new(std.size()).normal_())return mu + std * eps

編碼器和編碼器的輸入與輸出

backbone + encoder 等等輸入到 self.transformer,其實self.transformer就是decoder部分
核心代碼是detr_vae.pyclass DETRVAE(nn.Module):def forwardif is_training:部分
前提:detr_vae.pyclass DETRVAE(nn.Module):def forward的參數:qpos, image, env_state, actions, is_pad,都來自于imitate_episodes.pydef forward_pass(data, policy)data

編碼器的輸入與輸出

編碼器的核心調用語句:self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
參數的來源:

# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim) # qpos來自于forward_pass(data, policy):的image_data, qpos_data, action_data, is_pad = data
qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token 輸出形狀為(bs, 2)的二維張量,里面元素全部填充為False
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only

編碼器的輸入與輸出

編碼器的的核心調用語句為
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
其中:

  1. src
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):features, pos = self.backbones[0](image[:, cam_id]) # HARDCODEDfeatures = features[0] # take the last layer featurepos = pos[0]all_cam_features.append(self.input_proj(features))all_cam_pos.append(pos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
  1. pos
for cam_id, cam_name in enumerate(self.camera_names):features, pos = self.backbones[0](image[:, cam_id]) # HARDCODEDfeatures = features[0] # take the last layer featurepos = pos[0]all_cam_features.append(self.input_proj(features))all_cam_pos.append(pos)
pos = torch.cat(all_cam_pos, axis=3)
  1. latent_input 【Obtain latent z from action sequence】里的latent z
self.latent_dim = 32
latent_info = self.latent_proj(encoder_output) # 來自于編碼器的輸出
mu = latent_info[:, :self.latent_dim] # 潛在變量的均值
logvar = latent_info[:, self.latent_dim:] # 潛在變量的對數方差
latent_sample = reparametrize(mu, logvar) 
latent_input = self.latent_out_proj(latent_sample)
  1. proprio_input = self.input_proj_robot_state(qpos) # qpos來自于forward_pass(data, policy):的image_data, qpos_data, action_data, is_pad = data

為什么env_max_reward 設成0 ?

可能真機不需要看模擬出來的精度?

# load environment
if real_robot:from aloha_scripts.robot_utils import move_grippers # requires alohafrom aloha_scripts.real_env import make_real_env # requires alohaenv = make_real_env(init_node=True)env_max_reward = 0 # 為什么設成0 ?
success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
avg_return = np.mean(episode_returns)
summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
for r in range(env_max_reward+1):more_or_equal_r = (np.array(highest_rewards) >= r).sum()more_or_equal_r_rate = more_or_equal_r / num_rolloutssummary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'print(summary_str)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84247.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84247.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84247.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數字ic后端設計從入門到精通6(含fusion compiler, tcl教學)repeater詳解

Repeaters RC延遲與導線長度的關系: 導線的電阻(R)和電容(C)都會隨著導線長度(l)的增加而增大。RC延遲是電阻和電容共同作用導致的信號延遲。由于RC延遲與R和C的乘積有關,因此它會隨…

Data Warebase 成功押注 PostgreSQL 生態,或成 AI 時代數據底座

本文內容整理自 ProtonBase CEO 王紹翾在 AICon 的主題演講《Data Warebase: Instant Ingest-Transform-Explore-Retrieve for AI Applications》。作者的職業經歷貫穿了 AI 1.0、2.0 和 3.0 的時代,從搜索推薦,到視覺 / 語音 / NLP 智能,再到…

【電力電子】基于STM32F103C8T6單片機雙極性SPWM逆變(硬件篇)

本項目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脈寬調制)電源模塊,能夠生成可調頻率和幅值的正弦波交流電源輸出。該項目適用于逆變器、UPS電源、變頻器等應用場景。 供電電源 輸入電壓采集 上圖為本設計的電源電路,圖中 D1 為二極管, 其目的是防止正負極電源反接, …

Kubernetes (k8s)版本發布情況

Kubernetes (k8s)版本發布情況 代碼放在 GitHub - kubernetes/kubernetes: Production-Grade Container Scheduling and Management https://github.com/kubernetes/kubernetes/releases 文檔放在 kubernetes.io各個版本變更等: https://github.com/kubernetes/kubernet…

Python 接口:從協議到抽象基 類(Python使用register的方式)

Python使用register的方式 示例 11-14 把 Tombola.register 當作類裝飾器使用。在 Python 3.3 之 前的版本中不能這樣使用 register,必須在定義類之后像普通函數那 樣調用,如示例 11-14 中最后那行注釋所述。 雖然現在可以把 register 當作裝飾器使用了…

GRU 參數梯度推導與梯度消失分析

GRU 參數梯度推導與梯度消失分析 1. GRU 前向計算回顧 GRU 單元的核心計算步驟(忽略偏置項): 更新門: z_t σ(W_z [h_{t-1}, x_t]) 重置門: r_t σ(W_r [h_{t-1}, x_t]) 候選狀態: ?h_t tanh(W_h [r_t ⊙ h_{t-1}, x_t]) 新…

【字節擁抱開源】字節團隊開源視頻模型 ContentV: 有限算力下的視頻生成模型高效訓練

本項目提出了ContentV框架,通過三項關鍵創新高效加速基于DiT的視頻生成模型訓練: 極簡架構設計,最大化復用預訓練圖像生成模型進行視頻合成系統化的多階段訓練策略,利用流匹配技術提升效率經濟高效的人類反饋強化學習框架&#x…

分布式增量爬蟲實現方案

之前我們在討論的是分布式爬蟲如何實現增量爬取。增量爬蟲的目標是只爬取新產生或發生變化的頁面,避免重復抓取,以節省資源和時間。 在分布式環境下,增量爬蟲的實現需要考慮多個爬蟲節點之間的協調和去重。 另一種思路:將增量判…

單片機0-10V電壓輸出電路分享

一、原理圖 二、芯片介紹 GP8101是一個PWM信號轉模擬信號轉換器,相當于一個PWM信號輸入,模擬信號輸出的DAC。此 芯片可以將占空比為0%到100%的PWM信號線性轉換成0-5V或者0-10V的模擬電壓,并且輸出電壓 精度小于1%。GP8101M可以處理高頻調制的…

Spring AMQP

在現代分布式系統中,消息隊列是一種非常重要的通信機制,它能夠實現服務之間的異步通信、負載均衡以及解耦。Spring AMQP 是 Spring 框架對 AMQP(高級消息隊列協議)的支持,而 RabbitMQ 是 AMQP 協議的最流行實現之一。通…

第6章:Neo4j數據導入與導出

在實際應用中,數據的導入與導出是使用Neo4j的重要環節。無論是初始數據加載、系統遷移還是數據備份,都需要高效可靠的數據傳輸機制。本章將詳細介紹Neo4j中的各種數據導入與導出方法,幫助讀者掌握不同場景下的最佳實踐。 6.1 數據導入策略 …

RKNN開發環境搭建1-基于Ubuntu 18.04系統使用Docker安裝rknn-toolkit2

目錄 寫在最前面Docker 方式安裝rknn-toolkit2寫在最前面 瑞芯微在RKNN的環境搭建方面的資料很多,但是在搭建過程中發現很多問題教程中并未提及,對初學者不友好。所以博主做了這個系列的文章,從開始搭建環境到對于RKNN Model Zoo的示例進行實踐,希望能對初學者有幫助。堅持…

【實施指南】Android客戶端HTTPS雙向認證實施指南

🔐 一、所需準備材料 證書文件(6類核心文件) 類型 格式 作用 Android端要求 CA根證書 .crt/.pem 驗證服務器/客戶端證書合法性 需預置到Android信任庫 服務器證書 .crt 服務器身份證明 客戶端需持有以驗證服務器 客戶端證書 .crt 客戶端身份…

FPGA管腳類型,及選擇

fpga的IO Type選擇,如下: 具體的定義:

SELinux是什么以及如何編寫SELinux策略

目錄 一、SELinux 是什么? 二、SELinux 的兩種模式 如何查看當前 SELinux 狀態? 三、SELinux 在 Android 中的作用 四、為什么Root之后很多設備是 Permissive? 五、開發與調試場景 總結 🧩 一、什么是 SELinux 策略&#x…

MQTT示例體驗(C)

1、通用依賴準備 安裝編譯工具? Linux/macOS 需安裝: sudo apt update && sudo apt install build-essential cmake git # Ubuntu/Debian:ml-citation{ref"6" data"citationList"} brew install cmake # macOSWindows 需安裝 CMake…

MySQL中的系統庫(簡介、performance_schema)

文章目錄 性能監控performance_schema1、performance schema入門2、performance_schema表的分類3、performance_schema的簡單配置與使用4、常用配置項的參數說明5、重要配置表的相關說明6、performance_schema實踐操作 Show processlist 性能監控 每次你提交完一個 sql 語句之…

【Ftrace 專欄】Ftrace 參考博文

ftrace、perf、bcc、bpftrace、ply、simple_perf的使用Ftrace 基本用法Linux 利用 ftrace 分析內核調用如何利用ftrace精確跟蹤特定進程調度信息使用 ftrace 進行追蹤延遲Linux-培訓筆記-ftracehttps://www.kernel.org/doc/html/v4.18/trace/events.htmlhttps://blog.csdn.net/…

bug 記錄 - 使用 el-dialog 的 before-close 的坑

需求說明 彈窗中內嵌一個 form 表單 原始代碼 <script setup lang"ts"> import { reactive, ref } from "vue" import type { FormRules } from element-plus const ruleFormRef ref() interface RuleForm {name: stringregion: number | null } …

關鍵領域軟件測試的突圍之路:如何破解安全與效率的平衡難題

在數字化浪潮席卷全球的今天&#xff0c;軟件系統已成為國家關鍵領域的核心戰斗力。不同于普通商業軟件&#xff0c;這些承載著國家安全使命的軟件系統面臨著前所未有的質量挑戰——如何在確保絕對安全的前提下&#xff0c;實現高效測試與快速迭代&#xff1f;這一命題正考驗著…