分布變化的模仿學習算法

與傳統監督學習不同,直接模仿學習在不同時刻所面臨的數據分布可能不同.試設計一個考慮不同時刻數據分布變化的模仿學習算法

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.neighbors import KernelDensity
import matplotlib.pyplot as pltclass TimeAwareImitationLearning:def __init__(self, state_dim, action_dim, hidden_dim=64, device='cpu'):"""初始化時間感知的模仿學習算法state_dim: 狀態維度action_dim: 動作維度hidden_dim: 隱藏層維度"""self.state_dim = state_dimself.action_dim = action_dimself.device = device# 策略網絡 - 模仿專家行為self.policy = nn.Sequential(nn.Linear(state_dim + 1, hidden_dim),  # +1 是為了包含時間信息nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, action_dim)).to(device)# 判別器網絡 - 區分專家和策略生成的軌跡self.discriminator = nn.Sequential(nn.Linear(state_dim + action_dim + 1, hidden_dim),  # +1 是為了包含時間信息nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1),nn.Sigmoid()).to(device)# 優化器self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=1e-3)self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=1e-3)# 記錄訓練過程self.train_losses = []def _compute_time_weights(self, expert_times, current_time, sigma=1.0):"""計算時間權重,距離當前時間越近的樣本權重越大"""time_diffs = np.abs(expert_times - current_time)weights = np.exp(-time_diffs / (2 * sigma**2))return weights / np.sum(weights)def _compute_mmd_loss(self, expert_states, policy_states, times, current_time):"""計算最大均值差異(MMD)損失,衡量分布差異"""# 計算時間權重weights = self._compute_time_weights(times, current_time)# 對專家狀態應用時間權重weighted_expert_states = expert_states * weights.reshape(-1, 1)# 計算MMDexpert_kernel = rbf_kernel(weighted_expert_states, weighted_expert_states)policy_kernel = rbf_kernel(policy_states, policy_states)cross_kernel = rbf_kernel(weighted_expert_states, policy_states)mmd = np.mean(expert_kernel) + np.mean(policy_kernel) - 2 * np.mean(cross_kernel)return mmddef train(self, expert_states, expert_actions, expert_times, epochs=100, batch_size=64):"""訓練時間感知的模仿學習模型expert_states: 專家狀態序列 [num_samples, state_dim]expert_actions: 專家動作序列 [num_samples, action_dim]expert_times: 專家時間戳 [num_samples]"""num_samples = expert_states.shape[0]expert_states_tensor = torch.FloatTensor(expert_states).to(self.device)expert_actions_tensor = torch.FloatTensor(expert_actions).to(self.device)expert_times_tensor = torch.FloatTensor(expert_times).reshape(-1, 1).to(self.device)for epoch in range(epochs):# 當前"時間" - 使用訓練輪次的比例作為時間表示current_time = epoch / epochs# 生成策略動作policy_actions = []for i in range(0, num_samples, batch_size):batch_states = expert_states_tensor[i:i+batch_size]batch_times = torch.full((batch_states.shape[0], 1), current_time).to(self.device)policy_action = self.policy(torch.cat([batch_states, batch_times], dim=1))policy_actions.append(policy_action.detach().cpu().numpy())policy_actions = np.vstack(policy_actions)# 計算MMD損失mmd_loss = self._compute_mmd_loss(expert_states, policy_actions, expert_times, current_time)# 訓練判別器for _ in range(5):  # 判別器訓練多次# 隨機采樣批次indices = np.random.randint(0, num_samples, batch_size)batch_expert_states = expert_states_tensor[indices]batch_expert_actions = expert_actions_tensor[indices]batch_expert_times = expert_times_tensor[indices]# 生成策略動作batch_times = torch.full((batch_size, 1), current_time).to(self.device)batch_policy_actions = self.policy(torch.cat([batch_expert_states, batch_times], dim=1))# 計算判別器損失expert_input = torch.cat([batch_expert_states, batch_expert_actions, batch_expert_times], dim=1)policy_input = torch.cat([batch_expert_states, batch_policy_actions, batch_times], dim=1)expert_output = self.discriminator(expert_input)policy_output = self.discriminator(policy_input)# 判別器損失 (最大化區分能力)d_loss = -torch.mean(torch.log(expert_output + 1e-8) + torch.log(1 - policy_output + 1e-8))self.discriminator_optimizer.zero_grad()d_loss.backward()self.discriminator_optimizer.step()# 訓練策略網絡for _ in range(1):  # 策略網絡訓練較少次數indices = np.random.randint(0, num_samples, batch_size)batch_states = expert_states_tensor[indices]batch_times = torch.full((batch_size, 1), current_time).to(self.device)# 生成策略動作actions = self.policy(torch.cat([batch_states, batch_times], dim=1))# 計算策略損失 (最小化判別器的區分能力)policy_input = torch.cat([batch_states, actions, batch_times], dim=1)policy_output = self.discriminator(policy_input)# 策略損失 + MMD正則化p_loss = -torch.mean(torch.log(policy_output + 1e-8)) + 0.1 * mmd_lossself.policy_optimizer.zero_grad()p_loss.backward()self.policy_optimizer.step()# 記錄損失self.train_losses.append(p_loss.item())if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {p_loss.item():.4f}, MMD: {mmd_loss:.4f}")def predict(self, state, time):"""根據當前狀態和時間預測動作"""state_tensor = torch.FloatTensor(state).reshape(1, -1).to(self.device)time_tensor = torch.FloatTensor([time]).reshape(1, 1).to(self.device)with torch.no_grad():action = self.policy(torch.cat([state_tensor, time_tensor], dim=1))return action.cpu().numpy()[0]def visualize_training(self):"""可視化訓練過程"""plt.figure(figsize=(10, 6))plt.plot(self.train_losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.grid(True)plt.show()# 示例:生成具有時間分布變化的專家數據
def generate_time_varying_expert_data(num_samples=1000, state_dim=2, time_period=1.0):"""生成隨時間變化的數據分布"""times = np.linspace(0, time_period, num_samples)states = []actions = []for t in times:# 狀態分布隨時間變化mean = np.array([np.sin(2 * np.pi * t), np.cos(2 * np.pi * t)])cov = np.diag([0.1 + 0.1 * np.abs(np.sin(np.pi * t)), 0.1 + 0.1 * np.abs(np.cos(np.pi * t))])state = np.random.multivariate_normal(mean, cov)# 動作是狀態的函數,也隨時間變化action = 2.0 * state * (1.0 + 0.5 * np.sin(2 * np.pi * t))states.append(state)actions.append(action)return np.array(states), np.array(actions), times# 測試算法
def test_time_aware_il():# 生成專家數據state_dim = 2action_dim = 2expert_states, expert_actions, expert_times = generate_time_varying_expert_data(num_samples=2000, state_dim=state_dim, time_period=1.0)# 創建并訓練模型model = TimeAwareImitationLearning(state_dim, action_dim)model.train(expert_states, expert_actions, expert_times, epochs=500)# 可視化訓練過程model.visualize_training()# 測試不同時間點的策略test_times = np.linspace(0, 1, 5)test_states = np.random.randn(len(test_times), state_dim)plt.figure(figsize=(12, 8))for i, t in enumerate(test_times):plt.subplot(2, 3, i+1)# 真實專家行為expert_mask = (expert_times >= t - 0.1) & (expert_times <= t + 0.1)plt.scatter(expert_states[expert_mask, 0], expert_states[expert_mask, 1], c='blue', alpha=0.5, label='Expert')# 模型預測行為pred_actions = np.array([model.predict(s, t) for s in expert_states[expert_mask]])plt.scatter(pred_actions[:, 0], pred_actions[:, 1], c='red', alpha=0.5, label='Policy')plt.title(f'Time = {t:.2f}')plt.xlabel('State 1')plt.ylabel('State 2')plt.legend()plt.tight_layout()plt.show()if __name__ == "__main__":test_time_aware_il()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85997.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85997.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85997.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

arm-none-eabi-ld: cannot find -lm

arm-none-eabi-ld -Tuser/hc32l13x.lds -o grbl_hc32l13x.elf user/interrupts_hc32l13x.o user/system_hc32l13x.o user/main.o user/startup_hc32l13x.o -lm -Mapgrbl_hc32l13x.map arm-none-eabi-ld: cannot find -lm makefile:33: recipe for target link failed 改為在gcc…

【Python辦公】Excel文件批量樣式修改器

目錄 專欄導讀1. 背景介紹2. 項目概述3. 庫的安裝4. 核心架構設計① 類結構設計② 數據模型1) 文件管理2) 樣式配置5. 界面設計與實現① 布局結構② 動態組件生成6. 核心功能實現① 文件選擇與管理② 顏色選擇功能③ Excel文件處理核心邏輯完整代碼結尾專欄導讀 ?? 歡迎來到P…

QT的一些介紹

//雖然下面一行代碼進行widget和ui的窗口關聯&#xff0c;但是如果發生窗口大小變化的時候&#xff0c;里面的布局不會隨之變化ui->setupUi(this);//通過下面這行代碼進行顯示說明&#xff0c;讓窗口變化時&#xff0c;布局及其子控件隨之變化this->setLayout(ui->ver…

RISC-V向量擴展與GPU協處理:開源加速器設計新范式——對比NVDLA與香山架構的指令集融合方案

點擊 “AladdinEdu&#xff0c;同學們用得起的【H卡】算力平臺”&#xff0c;H卡級別算力&#xff0c;按量計費&#xff0c;靈活彈性&#xff0c;頂級配置&#xff0c;學生專屬優惠 當開源指令集遇上異構計算&#xff0c;RISC-V向量擴展&#xff08;RVV&#xff09;正重塑加速…

自動恢復網絡路由配置的安全腳本說明

背景 兩個文章 看了就明白 Ubuntu 多網卡路由配置筆記&#xff08;內網 外網同時通 可能有問題&#xff0c;以防萬一可以按照個來恢復 sudo ip route replace 192.168.1.0/24 dev eno8403 proto kernel scope link src <你的IP>或者恢復腳本! 如下 誤操作路由時&…

創建 Vue 3.0 項目的兩種方法對比:npm init vue@latest vs npm init vite@latest

創建 Vue 3.0 項目的兩種方法對比&#xff1a;npm init vuelatest vs npm init vitelatest Vue 3.0 作為當前主流的前端框架&#xff0c;官方提供了多種項目創建方式。本文將詳細介紹兩種最常用的創建方法&#xff1a;Vue CLI 方式 (npm init vuelatest) 和 Vite 方式 (npm in…

Java求職者面試指南:Spring, Spring Boot, Spring MVC, MyBatis技術點深度解析

Java求職者面試指南&#xff1a;Spring, Spring Boot, Spring MVC, MyBatis技術點深度解析 面試官與程序員JY的三輪提問 第一輪&#xff1a;基礎概念問題 1. 請解釋一下Spring框架的核心容器是什么&#xff1f;它有哪些主要功能&#xff1f; JY回答&#xff1a;Spring框架的…

【修復MySQL 主從Last_Errno:1051報錯的幾種解決方案】

當MySQL主從集群遇到Last_Errno:1051報錯后不要著急&#xff0c;主要有三種解決方案&#xff1a; 方案1: 使用GTID場景&#xff1a; mysql> STOP SLAVE;(2)設置事務號&#xff0c;事務號從Retrieved_Gtid_Set獲取 在session里設置gtid_next&#xff0c;即跳過這個GTID …

定位接口偶發超時的實戰分析:iOS抓包流程的完整復現

我們通常把“請求超時”歸結為網絡不穩定、服務器慢響應&#xff0c;但在一次產品灰度發布中&#xff0c;我們遇到的一個“偶發接口超時”問題完全打破了這些常規判斷。 這類Bug最大的問題不在于表現&#xff0c;而在于極難重現、不可預測、無法復盤。它不像邏輯Bug那樣能從代…

【網工】華為配置專題進階篇②

目錄 ■DHCP NAT BFD 策略路由 ▲掩碼與反掩碼總結 ▲綜合實驗 ■DHCP NAT BFD 策略路由 ▲掩碼與反掩碼總結 使用掩碼的場景&#xff1a;IP地址強相關 場景一&#xff1a;IP地址配置 ip address 192.168.1.1 255.255.255.0 或ip address 192.168.1.1 24 場景二&#x…

基于STM32電子密碼鎖

基于STM32電子密碼鎖 &#xff08;程序&#xff0b;原理圖&#xff0b;PCB&#xff0b;設計報告&#xff09; 功能介紹 具體功能&#xff1a; 1.正確輸入密碼前提下&#xff0c;開鎖并有正確提示&#xff1b; 2.錯誤輸入密碼情況下&#xff0c;蜂鳴器報警并短暫鎖定鍵盤&…

前端基礎知識CSS系列 - 14(CSS提高性能的方法)

一、前言 每一個網頁都離不開css&#xff0c;但是很多人又認為&#xff0c;css主要是用來完成頁面布局的&#xff0c;像一些細節或者優化&#xff0c;就不需要怎么考慮&#xff0c;實際上這種想法是不正確的 作為頁面渲染和內容展現的重要環節&#xff0c;css影響著用戶對整個…

判斷 NI Package Manager (NIPM) 版本與 LabVIEW 2019 兼容性

?判斷依據 1. 查閱 LabVIEW 2019 自述文件 LabVIEW 2019 自述文件中包含系統要求&#xff0c;可通過 NI 官網訪問。文件提到使用 NIPM 安裝&#xff0c;但未明確最低版本要求&#xff0c;需結合其他信息判斷。 2. 參考 NI 官方兼容性文檔 NI 官方文檔指出 LabVIEW 運行引擎與…

Django 安裝指南

Django 安裝指南 引言 Django 是一個高級的 Python Web 框架,用于快速開發安全且實用的網站。本文將詳細介紹如何在您的計算機上安裝 Django,以便您能夠開始使用這個強大的工具。 安裝前的準備 在開始安裝 Django 之前,請確保您的計算機滿足以下條件: 操作系統:Django…

Spring MVC參數綁定終極手冊:單多參對象集合JSON文件上傳精講

我們通過瀏覽器訪問不同的路徑&#xff0c;就是在發送不同的請求&#xff0c;在發送請求時&#xff0c;可能會帶一些參數&#xff0c;本文將介紹了Spring MVC中處理不同請求參數的多種方式 一、傳遞單個參數 接收單個參數&#xff0c;在Spring MVC中直接用方法中的參數就可以…

synchronized 做了哪些優化?

Java 中的 synchronized 關鍵字是保證線程安全的基本機制&#xff0c;隨著 JVM 的發展&#xff0c;它經歷了多次優化以提高性能。 1. 鎖升級機制&#xff08;鎖膨脹&#xff09; JDK 1.6 引入了偏向鎖→輕量級鎖→重量級鎖的升級機制&#xff0c;避免了一開始就使用重量級鎖&…

三甲醫院AI醫療樣本數據集分類與收集全流程節點分析(下)

3.3 典型案例分析 —— 以某三甲醫院為例 為了更深入地了解三甲醫院 AI 醫療樣本數據收集的實際情況,本研究選取了具有代表性的某三甲醫院作為案例進行詳細分析。該醫院作為區域醫療中心,在醫療技術、設備和人才方面具有顯著優勢,同時在醫療信息化建設和 AI 應用方面也進行…

設置程序開機自動啟動

在Windows系統中&#xff0c;有幾種方法可以將程序設置為開機自動啟動。下面我將介紹最常用的三種方法&#xff0c;并提供一個C#實現示例。 方法一&#xff1a;使用啟動文件夾&#xff08;最簡單&#xff09; 按下 Win R 鍵打開運行對話框 輸入 shell:startup 并回車 將你的…

多源異構數據接入與實時分析:衡石科技的技術突破

在數字化轉型的浪潮中&#xff0c;企業每天產生的數據量呈指數級增長。這些數據來自CRM系統、IoT設備、日志文件、社交媒體、交易平臺等眾多源頭&#xff0c;格式各異、結構混亂、流速不一。傳統的數據處理方式如同在無數孤立的島嶼間劃著小船傳遞信息&#xff0c;效率低下且無…

JVM——Synchronized:同步鎖的原理及應用

引入 在多線程編程的世界里&#xff0c;共享資源的訪問控制就像一場精心設計的交通管制&#xff0c;而Synchronized作為Java并發編程的基礎同步機制&#xff0c;扮演著"交通警察"的關鍵角色。 并發編程的核心矛盾 當多個線程同時訪問共享資源時&#xff0c;"…