第15周:注意力匯聚:Nadaraya-Watson 核回歸

注意力匯聚:Nadaraya-Watson 核回歸

Nadaraya-Watson 核回歸是一個經典的注意力機制模型,它展示了如何通過注意力權重來對輸入數據進行加權平均。以下是該內容的核心總結:

關鍵概念

  1. 注意力機制框架:由查詢(自主提示)、鍵(非自主提示)和值(感官輸入)組成,通過查詢和鍵的交互形成注意力權重,然后加權聚合值。
  2. Nadaraya-Watson核回歸
    • 非參數形式: f ( x ) = ∑ ( s o f t m a x ( ? ( x ? x i ) 2 / 2 ) ? y i ) \color{red}f(x) = ∑(softmax(-(x - x_i)2/2) * y_i) f(x)=(softmax(?(x?xi?)2/2)?yi?)
    • 參數形式:引入可學習參數 w w w f ( x ) = ∑ ( s o f t m a x ( ? ( ( x ? x i ) w ) 2 / 2 ) ? y i ) \color{red}f(x) = ∑(softmax(-((x - x_i)w)2/2) * y_i) f(x)=(softmax(?((x?xi?)w)2/2)?yi?)
  3. 核函數:使用高斯核來衡量查詢和鍵之間的相似度。

主要特點

  1. 非參數模型
    • 直接基于訓練數據進行預測
    • 具有一致性(隨著數據量增加會收斂到最優解)
    • 預測結果平滑
  2. 參數模型
    • 引入可學習參數w
    • 可以調整注意力權重的分布
    • 預測結果可能不如非參數模型平滑
  3. 注意力權重可視化:展示了查詢與鍵之間的關系,距離越近權重越高。

實現要點

  1. 使用批量矩陣乘法高效計算小批量數據的注意力權重
  2. 通過softmax計算歸一化的注意力權重
  3. 訓練時使用平方損失和隨機梯度下降

應用意義

Nadaraya-Watson核回歸提供了一個簡單但完整的例子,展示了注意力機制如何通過加權平均的方式選擇性地聚焦于相關的輸入數據。這種注意力匯聚的思想是現代注意力機制的基礎,后續發展出了更復雜的注意力評分函數和模型結構。

這個模型清楚地演示了注意力機制的核心思想:根據查詢與鍵的相似度來決定對相應值的關注程度,從而實現對輸入數據的有選擇性的聚合。

Nadaraya-Watson 核回歸示例

以下為完整的代碼示例Nadaraya-Watson核回歸的實現和應用,包括非參數和帶參數兩種形式。

1. 生成數據集

首先我們生成一個非線性數據集,加入一些噪聲:

import numpy as np
import matplotlib.pyplot as plt# 生成訓練數據
n_train = 50
x_train = np.sort(np.random.rand(n_train) * 5)
def f(x):return 2 * np.sin(x) + x**0.8y_train = f(x_train) + np.random.normal(0.0, 0.5, n_train)  # 添加噪聲# 生成測試數據
x_test = np.arange(0, 5, 0.1)
y_true = f(x_test)  # 真實函數值# 繪制數據
plt.figure(figsize=(10, 5))
plt.scatter(x_train, y_train, label='Training data', color='blue', alpha=0.5)
plt.plot(x_test, y_true, label='True function', color='green', linewidth=2)
plt.legend()
plt.title('Generated Dataset')
plt.show()

在這里插入圖片描述

2. 非參數Nadaraya-Watson核回歸實現

def nadaraya_watson(x_query, x_keys, y_values, bandwidth=1.0):"""非參數Nadaraya-Watson核回歸:param x_query: 查詢點:param x_keys: 訓練數據鍵:param y_values: 訓練數據值:param bandwidth: 核帶寬:return: 預測值"""predictions = []for x in x_query:# 計算高斯核權重weights = np.exp(-0.5 * ((x - x_keys) / bandwidth)**2)# 歸一化權重weights /= np.sum(weights)# 加權平均prediction = np.sum(weights * y_values)predictions.append(prediction)return np.array(predictions)# 使用不同帶寬進行預測
bandwidths = [0.1, 0.5, 1.0]
plt.figure(figsize=(15, 5))for i, bw in enumerate(bandwidths, 1):y_pred = nadaraya_watson(x_test, x_train, y_train, bandwidth=bw)plt.subplot(1, 3, i)plt.scatter(x_train, y_train, color='blue', alpha=0.3)plt.plot(x_test, y_true, label='True', color='green')plt.plot(x_test, y_pred, label=f'Pred (bw={bw})', color='red')plt.legend()plt.title(f'Bandwidth = {bw}')plt.tight_layout()
plt.show()

在這里插入圖片描述

3. 帶參數Nadaraya-Watson核回歸實現

class ParametricNWKernelRegression:def __init__(self, learning_rate=0.1, n_epochs=100):self.w = None  # 可學習參數self.lr = learning_rateself.epochs = n_epochsdef fit(self, x_train, y_train):# 初始化參數self.w = np.random.randn(1)# 訓練過程losses = []for epoch in range(self.epochs):# 前向傳播weights = np.exp(-0.5 * (self.w * (x_train[:, None] - x_train[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)y_pred = np.sum(weights * y_train[None, :], axis=1)# 計算損失loss = np.mean((y_pred - y_train)**2)losses.append(loss)# 反向傳播# (這里簡化了梯度計算,實際實現可能需要更精確的梯度)grad = np.random.randn(1) * 0.1  # 簡化的梯度self.w -= self.lr * gradif epoch % 10 == 0:print(f'Epoch {epoch}, Loss: {loss:.4f}')return lossesdef predict(self, x_query, x_keys, y_values):weights = np.exp(-0.5 * (self.w * (x_query[:, None] - x_keys[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)return np.sum(weights * y_values[None, :], axis=1)# 訓練帶參數模型
model = ParametricNWKernelRegression(learning_rate=0.1, n_epochs=100)
losses = model.fit(x_train, y_train)# 預測并繪制結果
y_pred_param = model.predict(x_test, x_train, y_train)plt.figure(figsize=(10, 5))
plt.scatter(x_train, y_train, color='blue', alpha=0.3, label='Training data')
plt.plot(x_test, y_true, label='True function', color='green')
plt.plot(x_test, y_pred_param, label='Parametric NW', color='red')
plt.legend()
plt.title('Parametric Nadaraya-Watson Regression')
plt.show()# 繪制訓練損失
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

在這里插入圖片描述
在這里插入圖片描述

4. 注意力權重可視化

# 計算注意力權重
def compute_attention(x_query, x_keys, w=1.0):weights = np.exp(-0.5 * (w * (x_query[:, None] - x_keys[None, :]))**2)weights /= np.sum(weights, axis=1, keepdims=True)return weights# 非參數模型注意力權重
attn_nonparam = compute_attention(x_test, x_train)# 帶參數模型注意力權重
attn_param = compute_attention(x_test, x_train, w=model.w)# 可視化
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)
plt.imshow(attn_nonparam, cmap='Reds', aspect='auto')
plt.colorbar()
plt.title('Non-parametric Attention Weights')
plt.xlabel('Training samples')
plt.ylabel('Test samples')plt.subplot(1, 2, 2)
plt.imshow(attn_param, cmap='Reds', aspect='auto')
plt.colorbar()
plt.title('Parametric Attention Weights')
plt.xlabel('Training samples')
plt.ylabel('Test samples')plt.tight_layout()
plt.show()

在這里插入圖片描述

注意

  1. 帶寬影響:在非參數模型中,帶寬參數控制著平滑程度:
    • 小帶寬(0.1)導致過擬合,預測曲線波動大
    • 大帶寬(1.0)導致欠擬合,預測曲線過于平滑
    • 中等帶寬(0.5)通常效果最好
  2. 參數模型:通過學習參數w,模型可以自動調整注意力權重的分布:
    • 通常比固定帶寬的非參數模型更靈活
    • 但需要足夠的訓練數據來學習合適的參數
  3. 注意力模式:從注意力權重圖中可以看到:
    • 查詢點附近的鍵會獲得更高的注意力權重
    • 參數模型通常會學習到更集中的注意力分布

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/900554.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/900554.shtml
英文地址,請注明出處:http://en.pswp.cn/news/900554.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

adb devices報錯 ADB server didn‘t ACK

ubuntu下連接手機首次使用adb devices 報錯ADB server didn’t ACK adb devices * daemon not running; starting now at tcp:5037 ADB server didnt ACK Full server startup log: /tmp/adb.1000.log Server had pid: 52986 --- adb starting (pid 52986) --- 04-03 17:23:23…

Mac下Homebrew的安裝與使用

Mac下Homebrew的安裝與使用 一蓑煙羽 關注 2017.10.19 11:59* 字數 515 閱讀 7684評論 0喜歡 3 Homebrew簡介,安裝與使用 簡介 Homebrew 官方網站 Homebrew是一個包管理器,用于安裝Apple沒有預裝但你需要的UNIX工具。(比如著名的wget&am…

非常適合做后臺項目的go腳手架

分享一個非常適合做后臺腳手架的go項目,該項目使用gin作為mvc框架搭建。她就是Gin-vue-admin。該一個基于 vue 和 gin 開發的全棧前后端分離的開發基礎平臺,集成jwt鑒權,動態路由,動態菜單,casbin鑒權,表單…

優化 Django 數據庫查詢

優化 Django 數據庫查詢 推薦超級課程: 本地離線DeepSeek AI方案部署實戰教程【完全版】Docker快速入門到精通Kubernetes入門到大師通關課AWS云服務快速入門實戰目錄 優化 Django 數據庫查詢**理解 N+1 查詢問題****`select_related`:外鍵的急加載**示例何時使用 `select_re…

大數據(5)Spark部署核彈級避坑指南:從高并發集群調優到源碼級安全加固(附萬億級日志分析實戰+智能運維巡檢系統)

目錄 背景一、Spark核心架構拆解1. 分布式計算五層模型 二、五步軍工級部署階段1:環境核彈級校驗階段2:集群拓撲構建階段3:黃金配置模板階段4:高可用啟停階段5:安全加固方案 三、萬億級日志分析實戰1. 案例背景&#x…

【學Rust寫CAD】36 顏色插值函數(alpha256.rs補充方法)

源碼 pub fn alpha_lerp(self,src: Argb, dst: Argb, clip: u32) -> Argb {self.alpha_mul_256(clip).lerp(src, dst)}這個函數 alpha_lerp 是一個顏色插值(線性插值,lerp)函數,它結合了透明度混合(alpha_mul_256&…

解決Ubuntu系統鼠標不流暢的問題

電腦是聯想的臺式組裝機,安裝ubuntu系統(不管是16、18、20、22)后,鼠標都不流暢。最近幾天想解決這個問題,于是懷疑到了顯卡驅動上。懷疑之前一直用的是集成顯卡,而不是獨立顯卡,畢竟2060的顯卡…

oracle asm 相關命令和查詢視圖

有關asm磁盤的命令 添加磁盤 alter diskgroup data1 add disk /devices/diska*;---runs with a rebalance power of 5 , and dose not return until the rebalance operation is completealter diskgroup data1 add disk /devices/diskd* rebalance power 5 wait;查詢 select …

C++基于rapidjson的Json與結構體互相轉換

簡介 使用rapidjson庫進行封裝,實現了使用C對結構體數據和json字符串進行互相轉換的功能。最短只需要使用兩行代碼即可無痛完成結構體數據轉換為Json字符串。 支持std::string、數組、POD數據(int,float,double等)、std::vector、嵌套結構體…

Python爬蟲HTTP代理使用教程:突破反爬的實戰指南

目錄 一、代理原理:給爬蟲穿上"隱身衣" 二、代理類型選擇指南 三、代碼實戰:三行代碼實現代理設置 四、代理池管理:打造智能IP倉庫 代理驗證機制 動態切換策略 自動重試裝飾器 五、反反爬對抗技巧 請求頭偽裝 訪問頻率控…

STM32江科大----IIC

聲明:本人跟隨b站江科大學習,本文章是觀看完視頻后的一些個人總結和經驗分享,也同時為了方便日后的復習,如果有錯誤請各位大佬指出,如果對你有幫助可以點個贊小小鼓勵一下,本文章建議配合原視頻使用?? 如…

使用 React 和 Konva 實現一個在線畫板組件

文章目錄 一、前言二、Konva.js 介紹三、創建 React 畫板項目3.1 安裝依賴3.2 創建 CanvasBoard 組件 四、增加畫布控制功能4.1 清空畫布4.2 撤銷 & 重做功能 五、增加顏色和畫筆大小選擇5.1 選擇顏色5.2 選擇畫筆大小 六、最終效果七、總結 一、前言 在線畫板是許多應用&…

服務器配置虛擬IP

服務器配置虛擬IP的核心步驟取決于具體場景,主要包括本地單機多IP配置和高可用集群下的虛擬IP管理兩種模式。? 一、本地虛擬IP配置(單服務器多IP) ?基于Linux系統?: ?確認網絡接口?:使用 ip addr 或 ifconfig 查…

C++ —— 文件操作(流式操作)

C —— 文件操作(流式操作) ofstream文件創建文件寫入 ofstream 文件打開模式std::ios::out 寫入模式std::ios::app 追加模式std::ios::trunc 截斷std::ios::binary 二進制std::ios::ate at the end模式 ifstreamstd::ios::in 讀取模式(默認&…

【Cursor】打開Vscode設置

在這里打開設置界面 打開設置json

智能指針和STL庫學習思維導圖和練習

思維導圖&#xff1a; #include <iostream> #include <vector> #include <string> using namespace std;// 用戶結構體 struct User {string username;string password; };vector<User> users; // 存儲所有注冊用戶// 使用迭代器查找用戶名是否存在 ve…

前端工具方法整理

文章目錄 1.在數組中找到匹配項&#xff0c;然后創建新對象2.對象轉JSON字符串3.JSON字符串轉JSON對象4.有個響應式對象&#xff0c;然后想清空所有屬性5.判斷參數不為空6.格式化字符串7.解析數組內容用逗號拼接8.刷新整個頁面 1.在數組中找到匹配項&#xff0c;然后創建新對象…

狀態空間建模與極點配置 —— 理論、案例與交互式 GUI 實現

目錄 狀態空間建模與極點配置 —— 理論、案例與交互式 GUI 實現一、引言二、狀態空間建模的基本理論2.1 狀態空間模型的優勢2.2 狀態空間模型的物理意義三、極點配置的理論與方法3.1 閉環系統的狀態反饋3.2 極點配置條件與方法3.3 設計流程四、狀態空間建模與極點配置的優缺點…

仿modou庫one thread one loop式并發服務器

源碼&#xff1a;田某super/moduo 目錄 SERVER模塊&#xff1a; Buffer模塊&#xff1a; Socket模塊&#xff1a; Channel模塊&#xff1a; Connection模塊&#xff1a; Acceptor模塊&#xff1a; TimerQueue模塊&#xff1a; Poller模塊&#xff1a; EventLoop模塊&a…

Oracle中的UNION原理

Oracle中的UNION操作用于合并多個SELECT語句的結果集&#xff0c;并自動去除重復行。其核心原理可分為以下幾個步驟&#xff1a; 1. 執行各個子查詢 每個SELECT語句獨立執行&#xff0c;生成各自的結果集。 如果子查詢包含過濾條件&#xff08;如WHERE&#xff09;、排序&…