rust-candle學習筆記11-實現一個簡單的自注意力

參考:about-pytorch

定義ScaledDotProductAttention結構體:

use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: Linear,d_model: Tensor,device: Device,
}

為ScaledDotProductAttention結構體實現new方法:

impl ScaledDotProductAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?, wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?, wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

為結構體實現Module的forward trait:

impl Module for ScaledDotProductAttention {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let q = self.wq.forward(xs)?;let k = self.wk.forward(xs)?;let v = self.wv.forward(xs)?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

融合qkv實現:

定義ScaledDotProductAttentionFusedQKV結構體:

struct ScaledDotProductAttentionFusedQKV {w_qkv: Linear,d_model: Tensor,device: Device,
}

為結構體實現new方法:

impl ScaledDotProductAttentionFusedQKV {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

為結構體實現forward trait:

impl Module for ScaledDotProductAttentionFusedQKV {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let qkv = self.w_qkv.forward(xs)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

測試:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;// let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;let output = model.forward(&input)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/80724.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/80724.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/80724.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

spark MySQL數據庫配置

Spark 連接 MySQL 數據庫的配置 要讓 Spark 與 MySQL 數據庫實現連接&#xff0c;需要進行以下配置步驟。下面為你提供詳細的操作指南和示例代碼&#xff1a; 1. 添加 MySQL JDBC 驅動依賴 你得把 MySQL 的 JDBC 驅動添加到 Spark 的類路徑中。可以通過以下兩種方式來完成&a…

web 自動化之 KDT 關鍵字驅動詳解

一、什么是關鍵字驅動&#xff1f; 1、什么是關鍵字驅動&#xff1f;&#xff08;以關鍵字函數驅動測試&#xff09; 關鍵字驅動又叫動作字驅動&#xff0c;把項目業務封裝成關鍵字函數&#xff0c;再基于關鍵字函數實現自動化測試 2、關鍵字驅動測試原理 關鍵字驅動測試是一…

Java使用POI+反射靈活的控制字段導出Excel

前端傳入哪些字段&#xff0c;后端就導出哪些到Excel表格中&#xff0c;具體代碼實現如下 controller /*** 用戶導出* param dto*/PostMapping("/exportUser")public void exportCharterOrder(RequestBody UserExportDTO dto){userService.exportUser(dto);} serv…

Qt/C++面試【速通筆記八】—Qt的事件處理機制

在Qt中&#xff0c;事件處理機制是應用程序與用戶或系統交互的核心。通過事件處理&#xff0c;Qt能夠響應用戶的輸入、窗口的變化、定時器的觸發等各種情況。 1. 事件循環&#xff08;Event Loop&#xff09; 在Qt應用程序中&#xff0c;事件循環是事件處理機制的基礎。事件循…

TTL (Time-To-Live) 解析

文章目錄 TTL (Time-To-Live) 解析&#xff1a;網絡與Java中的應用一、TTL的定義二、TTL在網絡中的應用1. **路由和數據包的生命周期**2. **DNS中的TTL**3. **防止環路** 三、TTL在Java中的應用1. **緩存管理**2. **Java中的ThreadLocal**3. **網絡通信中的TTL** 四、TTL的注意…

HDFS的客戶端操作(2)文件上傳

我們向/maven下上傳一個文件。 要用到的api是put (或者copyFormLocalFile&#xff09;。核心代碼如下。 public void testCopyFromLocalFile() throws IOException, InterruptedException, URISyntaxException {// 1 獲取文件系統Configuration configuration new Configurati…

光譜相機的光電信號轉換

光譜相機的光電信號轉換是將分光后的光學信息轉化為可處理的數字信號的核心環節&#xff0c;具體分為以下關鍵步驟&#xff1a; 一、分光后光信號接收與光電轉換 ?分光元件作用? 光柵/棱鏡/濾光片等分光元件將入射光分解為不同波長單色光&#xff0c;投射至探測器陣列表面…

網絡協議分析 實驗二 IP分片與IPv6

文章目錄 索引及重要內容實驗2 IP 高級實驗實驗2.1 IPv4協議分片實驗實驗2.2 IPV6協議實驗2.3 ARP初級 索引及重要內容 實驗2 IP 高級實驗 實驗2.1 IPv4協議分片實驗 icmp的不可達報文 實驗2.2 IPV6協議 實驗2.3 ARP初級 arp –a 查看ARP緩存表內容 arp –s IP地址(格式&…

20、map和set、unordered_map、un_ordered_set的復現

一、map 1、了解 map的使用和常考面試題等等&#xff0c;看這篇文章 map的key是有序的 &#xff0c;值不可重復 。插入使用 insert的效率更高&#xff0c;而在"更新map的鍵值對時&#xff0c;使用 [ ]運算符效率更高 。" 注意 map 的lower和upper那2個函數&#x…

基于 Amazon Bedrock 和 Amazon Connect 打造智能客服自助服務 – 設計篇

隨著 GenAI 技術不斷的發展和演進&#xff0c;人工智能技術廣泛地被應用在呼叫中心服務領域&#xff0c;主要包括虛擬坐席&#xff08;即自助服務&#xff09;、坐席助手和呼叫中心運營的數據洞察和智能分析。本博客主要針對自助服務應用場景的實現。 1. 傳統自助服務系統瓶頸 …

java高效實現爬蟲

一、前言 在Web爬蟲技術中&#xff0c;Selenium作為一款強大的瀏覽器自動化工具&#xff0c;能夠模擬真實用戶操作&#xff0c;有效應對JavaScript渲染、Ajax加載等復雜場景。而集成代理服務則能夠解決IP限制、地域訪問限制等問題。本文將詳細介紹如何利用JavaSelenium快代理實…

【計算機視覺】OpenCV實戰項目:基于OpenCV的車牌識別系統深度解析

基于OpenCV的車牌識別系統深度解析 1. 項目概述2. 技術原理與算法設計2.1 圖像預處理1) 自適應光照補償2) 邊緣增強 2.2 車牌定位1) 顏色空間篩選2) 形態學操作3) 輪廓分析 2.3 字符分割1) 投影分析2) 連通域篩選 2.4 字符識別 3. 實戰部署指南3.1 環境配置3.2 項目代碼解析 4.…

Python核心數據類型全解析:字符串、列表、元組、字典與集合

導讀&#xff1a; Python 是一門功能強大且靈活的編程語言&#xff0c;而其核心數據類型是構建高效程序的基礎。本文深入剖析了 Python 的五大核心數據類型——字符串、列表、元組、字典和集合&#xff0c;結合實際應用場景與最佳實踐&#xff0c;幫助讀者全面掌握這些數據類型…

GPT-4.1和GPT-4.1-mini系列模型支持微調功能,助力企業級智能應用深度契合業務需求

微軟繼不久前發布GPT-4.1系列模型后&#xff0c;Azure OpenAI服務&#xff08;國際版&#xff09;現已正式開放對GPT-4.1和GPT-4.1-mini的微調功能&#xff0c;并通過Azure AI Foundry&#xff08;國際版&#xff09;提供完整的部署和管理解決方案。這一重大升級標志著企業級AI…

構造+簡單樹狀

昨日的牛客周賽算是比較簡單的&#xff0c;其中最后一道構造題目屬實眼前一亮。 倒數第二個題目也是一個很好的模擬題目&#xff08;考驗對二叉樹的理解和代碼的細節&#xff09; 給定每一層的節點個數&#xff0c;自己擬定一個父親節點&#xff0c;構造一個滿足條件的二叉樹。…

apache2的默認html修改

使用127.0.0.1的時候&#xff0c;默認打開的是index.html&#xff0c;可以通過配置文件修改成我們想要的html vi /etc/apache2/mods-enabled/dir.conf <IfModule mod_dir.c>DirectoryIndex WS.html index.html index.cgi index.pl index.php index.xhtml index.htm <…

mysql性能提升方法大匯總

前言 最近在開發自己的小程序的時候&#xff0c;由于業務功能對系統性能的要求很高&#xff0c;系統性能損耗又主要在mysql上&#xff0c;而業務功能的數據表很多&#xff0c;單表數據量也很大&#xff0c;又涉及到很多場景的數據查詢&#xff0c;所以我針對mysql調用做了優化…

多模態RAG與LlamaIndex——1.deepresearch調研

摘要 關鍵點&#xff1a; 多模態RAG技術通過結合文本、圖像、表格和視頻等多種數據類型&#xff0c;擴展了傳統RAG&#xff08;檢索增強生成&#xff09;的功能。LlamaIndex是一個開源框架&#xff0c;支持多模態RAG&#xff0c;提供處理文本和圖像的模型、嵌入和索引功能。研…

LabVIEW中算法開發的系統化解決方案與優化

在 LabVIEW 開發環境中&#xff0c;算法實現是連接硬件數據采集與上層應用的核心環節。由于圖形化編程范式與傳統文本語言存在差異&#xff0c;LabVIEW 中的算法開發需要特別關注執行效率、代碼可維護性以及與硬件資源的適配性。本文從算法架構設計、性能優化到工程實現&#x…

OpenCV中的光流估計方法詳解

文章目錄 一、引言二、核心算法原理1. 光流法基本概念2. 算法實現步驟 三、代碼實現詳解1. 初始化設置2. 特征點檢測3. 光流計算與軌跡繪制 四、實際應用效果五、優化方向六、結語 一、引言 在計算機視覺領域&#xff0c;運動目標跟蹤是一個重要的研究方向&#xff0c;廣泛應用…