[machine learning] Transformer - Attention (一)

Attention是Transformer的核心,本系列先通過介紹Attention來學習Transformer。本文先介紹簡單版的Attention。

在Attention出現之前,通常使用recurrent neural networds (RNNs)來處理長序列數據。模型架構上,又通常使用encoder-decoder的結構。
以機器翻譯為例,當輸入文本序列一個一個進入encoder時,encoder也在一步一步地更新它的hidden state(即隱藏層的值)。通過這種方式,encoder在最后一次更新完hidden state后,盡可能多地把整個輸入文本序列的含義捕捉存儲到最終的hidden state中。decoder以encoder的最終hidden state作為輸入,一次一個字地開始翻譯。同樣,decoder也是一步一步地更新它的hidden state,每一次更新后的hidden state都包含了預測下一個字的必要上下文信息。下面的圖就是整個流程:
在這里插入圖片描述
整個流程的關鍵點是encoder將整個文本序列處理成最終的hidden state(memory cell)。decoder以encoder最終的hidden state為輸入,生成輸出。encoder-decoder RNN架構最大的問題和局限性是在decoding階段,RNN無法直接訪問encoder中比較靠前的hidden state。結果decoder只能基于包含了所有相關信息的最終hidden state。當遇到復雜的前后依賴距離跨度大的長句時,這個問題會導致上下文信息的丟失。

Transformer的提出解決了RNN的缺陷并逐漸取代了RNN在NLP領域的位置。而Transformer的核心正是self-attention。self-attention是這樣一套機制,當計算一個序列的表示時,它允許序列中每個位置上的元素去關注序列中其他位置(包括它本身)的元素。self-attention中的"self"是指這套機制能夠通過比較當前位置元素與其他所有位置元素的相關性來計算出關注度權重(attention weights)。它評估學習了輸入自己本身不同部分之間的聯系和依賴關系。

有了上面的基本概念之后,下面先看一個沒有訓練權重的simple self-attention。

假設現在有一個6個單詞的序列,每個單詞的embedding維度是3。

import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

現在我們想要計算第二個單詞跟其他所有單詞的attention weight。首先我們先計算attention score:

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
# 輸出:
# tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

這里我們使用的點乘(dot product)。點乘是將兩個向量element-wise地相乘然后相加。點乘結果越大表明兩個向量越相關。這部分過程如下圖所示:
在這里插入圖片描述
接下來,我們對attention score做歸一化求attention weight。

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
# 輸出:
# Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
# Sum: tensor(1.0000)

實際中,主要是用PyTorch自帶的softmax函數來做這個歸一化:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
# 輸出:
# Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
# Sum: tensor(1.)

在這里插入圖片描述

有了attention weight之后,就是最后一步,求第二個單詞的上下文向量。

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)
# 輸出:
# tensor([0.4419, 0.6515, 0.5683])

在這里插入圖片描述

上面是計算第二個單詞的上下文向量,下面我們一次性地求所有單詞的上下文向量:

# @ 是矩陣乘法,inputs.T 是轉置操作
attn_scores = inputs @ inputs.T
print(attn_scores)
# tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
# [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
# [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
# [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
# [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
# [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])# 注意在第一維上做softmax
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)
# 可以看到每一行加起來都是1
# tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
# [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
# [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
# [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
# [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
# [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
# tensor([[0.4421, 0.5931, 0.5790],
# [0.4419, 0.6515, 0.5683],
# [0.4431, 0.6496, 0.5671],
# [0.4304, 0.6298, 0.5510],
# [0.4671, 0.5910, 0.5266],
# [0.4177, 0.6503, 0.5645]])

本文介紹了一種簡單版的self-attention,即通過兩個詞向量點乘求出兩個詞向量的相關性。下一篇文章,我們將介紹帶可訓練參數的self-attention。

參考資料:
《Build a Large Language Model from scratch》

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/81862.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/81862.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/81862.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Android 輸入控件事件使用示例

一 前端 <EditTextandroid:id="@+id/editTextText2"android:layout_width="match_parent"android:layout_height="wrap_content"android:ems="10"android:inputType="text"android:text="Name" />二 后臺代…

【向量數據庫】用披薩點餐解釋向量數據庫:一個美味的技術類比

文章目錄 前言場景設定&#xff1a;披薩特征向量化顧客到來&#xff1a;生成查詢向量相似度計算實戰1. 歐氏距離計算&#xff08;值越小越相似&#xff09;2. 余弦相似度計算&#xff08;值越大越相似&#xff09; 關鍵發現&#xff1a;度量選擇影響結果現實啟示結語 前言 想象…

人工智能和機器學習在包裝仿真中的應用與價值

引言 隨著包裝成為消費品關鍵的差異化因素&#xff0c;對智能設計、可持續性和高性能的要求比以往任何時候都更高 。為了滿足這些復雜的期望&#xff0c;公司越來越多地采用先進的仿真方法&#xff0c;而現在人工智能 (AI) 和機器學習 (ML) 又極大地增強了這些方法 。本文探討…

【人工智能】深入探索Python中的自然語言理解:實現實體識別系統

《Python OpenCV從菜鳥到高手》帶你進入圖像處理與計算機視覺的大門! 解鎖Python編程的無限可能:《奇妙的Python》帶你漫游代碼世界 自然語言理解(NLU)是人工智能(AI)領域中的重要研究方向之一,其目標是讓計算機理解和處理人類語言。在NLU的眾多應用中,實體識別(Nam…

個人健康中樞的多元化AI硬件革新與精準健康路徑探析

在醫療信息化領域,個人健康中樞正經歷著一場由硬件技術革新驅動的深刻變革。隨著可穿戴設備、傳感器技術和人工智能算法的快速發展,新一代健康監測硬件能夠采集前所未有的多維度生物數據,并通過智能分析提供精準的健康建議。本文將深入探討構成個人健康中樞的最新硬件技術,…

深入了解Linux系統—— 進程切換和調度

前言&#xff1a; 了解了進程的狀態和進程的優先級&#xff0c;我們現在來看進程是如何被CPU調度執行的。 在單CPU的系統在&#xff0c;程序是并發執行的&#xff1b;也就是說在一段時間呢&#xff0c;進程是輪番執行的&#xff1b; 這也是說一個進程在運行時不會一直占用CPU直…

阿里云服務遷移實戰: 06-切換DNS

概述 按前面的步驟&#xff0c;所有服務遷移完畢之后&#xff0c;最后就剩下 DNS 解析修改了。 修改解析 在域名解析處&#xff0c;修改域名的解析地址即可。 如果 IP 已經過戶到了新賬號&#xff0c;則不需要修改解析。 何確保業務穩定 域名解析更換時&#xff0c;由于 D…

uni-app 中封裝全局音頻播放器

在開發移動應用時&#xff0c;音頻播放功能是一個常見的需求。無論是背景音樂、音效還是語音消息&#xff0c;音頻播放都需要一個穩定且易于管理的解決方案。在 uni-app 中&#xff0c;雖然原生提供了 uni.createInnerAudioContext 方法用于音頻播放&#xff0c;但直接使用它可…

golang常用庫之-標準庫text/template

文章目錄 golang常用庫之-標準庫text/template背景什么是text/templatetext/template庫的使用 golang常用庫之-標準庫text/template 背景 在許多編程場景中&#xff0c;我們經常需要把數據按照某種格式進行輸出&#xff0c;比如生成HTML頁面&#xff0c;或者生成配置文件。這…

Linux btop 使用教程

簡介 btop 是一個基于終端的現代系統資源監控器&#xff0c;具有美觀的圖形界面、響應快、功能豐富等特點。它支持查看 CPU、內存、磁盤、網絡、進程&#xff0c;并可以方便地篩選和管理進程。 功能總覽 啟動命令&#xff1a; btop界面分為以下幾部分&#xff1a; CPU 區域…

Vue3調度器錯誤解析,完美解決Unhandled error during execution of scheduler flush.

目錄 Vue3調度器錯誤解析&#xff0c;完美解決Unhandled error during execution of scheduler flush. 一、問題現象與本質 二、七大高頻錯誤場景與解決方案 1、Setup初始化陷阱 2、模板中的"幽靈屬性" 3、異步操作的"定時炸彈" 4、組件嵌套黑洞 5…

使用DeepSeek定制Python小游戲——以“俄羅斯方塊”為例

前言 本來想再發幾個小游戲后在整理一下流程的&#xff0c;但是今天試了一下這個俄羅斯方塊的游戲結果發現本來修改的好好的的&#xff0c;結果后面越改越亂&#xff0c;前面的版本也沒保存&#xff0c;根據AI修改他是在幾個版本改來改去&#xff0c;想著要求還是不能這么高。…

Kotlin帶接收者的Lambda介紹和應用(封裝DialogFragment)

先來看一個具體應用&#xff1a;假設我們有一個App&#xff0c;App中有一個退出應用的按鈕&#xff0c;點擊該按鈕后并不是立即退出&#xff0c;而是先彈出一個對話框&#xff0c;詢問用戶是否確定要退出&#xff0c;用戶點了確定再退出&#xff0c;點取消則不退出&#xff0c;…

ES6/ES11知識點 續一

模板字符串 在 ECMAScript&#xff08;ES&#xff09;中&#xff0c;模板字符串&#xff08;Template Literals&#xff09;是一種非常強大的字符串表示方式&#xff0c;它為我們提供了比傳統字符串更靈活的功能&#xff0c;尤其是在處理動態內容時。模板字符串通過反引號&…

【C++】智能指針RALL實現shared_ptr

個人主頁 &#xff1a; zxctscl 專欄 【C】、 【C語言】、 【Linux】、 【數據結構】、 【算法】 如有轉載請先通知 文章目錄 1. 為什么需要智能指針&#xff1f;2. 內存泄漏2.1 什么是內存泄漏&#xff0c;內存泄漏的危害2.2 內存泄漏分類&#xff08;了解&#xff09;2.3 如何…

ROS2 開發踩坑記錄(持續更新...)

1. 從find_package(xxx REQUIRED)說起&#xff0c;如何引用其他package(包&#xff09; 查看包的安裝位置和include路徑詳細文件列表 例如&#xff0c;xxx包名為pluginlib # 查看 pluginlib 的安裝位置 dpkg -L ros-${ROS_DISTRO}-pluginlib | grep include 這條指令的目的是…

系統思考:困惑源于內心假設

不要懷疑&#xff0c;你的困惑來自你的假設。 你是否曾經陷入過無解的困境&#xff0c;覺得外部環境太復雜&#xff0c;自己的處境無法突破&#xff1f;很多時候&#xff0c;答案并不在于外部的局勢&#xff0c;而是來自我們內心深處的假設——那些我們理所當然、從未質疑過的…

GitHub修煉法則:第一次提交代碼教學(Liunx系統)

前言 github是廣大程序員們必須要掌握的一個技能&#xff0c;萬事開頭難&#xff0c;如果成功提交了第一次代碼&#xff0c;那么后來就會簡單很多。網上的相關資料往往都不是從第一次開始&#xff0c;導致很多新手們會在過程中遇到很多權限認證相關的問題&#xff0c;進而被卡…

瀝青路面裂縫的目標檢測與圖像分類任務

文章題目是《A grid‐based classification and box‐based detection fusion model for asphalt pavement crack》 于2023年發表在《Computer‐Aided Civil and Infrastructure Engineering》 論文采用了一種基于網格分類和基于框的檢測&#xff08;GCBD&#xff09;&#xff…

【Flask】ORM模型以及數據庫遷移的兩種方法(flask-migrate、Alembic)

ORM模型 在Flask中&#xff0c;ORM&#xff08;Object-Relational Mapping&#xff0c;對象關系映射&#xff09;模型是指使用面向對象的方式來操作數據庫的編程技術。它允許開發者使用Python類和對象來操作數據庫&#xff0c;而不需要直接編寫SQL語句。 核心概念 1. ORM模型…