[深度學習] Transformer

Transformer是一種深度學習模型,最早由Vaswani等人在2017年的論文《Attention is All You Need》中提出。它最初用于自然語言處理(NLP)任務,但其架構的靈活性使其在許多其他領域也表現出色,如計算機視覺、時間序列分析等。以下是對Transformer模型的詳細介紹。

一、基本結構

Transformer模型主要由兩個部分組成:編碼器(Encoder)和解碼器(Decoder)。

編碼器(Encoder)
  • 輸入嵌入(Input Embedding):將輸入的詞匯轉換為高維向量表示。
  • 位置編碼(Positional Encoding):由于Transformer沒有循環結構或卷積結構,因此需要顯式地加入位置信息。位置編碼可以幫助模型了解序列中各個詞匯的位置。
  • 多頭自注意力機制(Multi-Head Self-Attention):自注意力機制可以捕捉序列中不同位置之間的依賴關系。多頭機制允許模型關注不同的子空間。
  • 前饋神經網絡(Feed-Forward Neural Network):兩個線性變換和一個ReLU激活函數,獨立地應用于每個位置。
  • 層歸一化(Layer Normalization)殘差連接(Residual Connection):每個子層的輸出都進行層歸一化,并通過殘差連接加入子層輸入。

編碼器包含多個(通常是6個)這樣的子層堆疊。

解碼器(Decoder)

解碼器的結構與編碼器類似,但增加了一個用于接收編碼器輸出的注意力層。

  • 輸入嵌入、位置編碼、多頭自注意力機制、前饋神經網絡、層歸一化和殘差連接:與編碼器相同。
  • 掩碼多頭自注意力機制(Masked Multi-Head Self-Attention):防止解碼器當前位置注意到未來位置的信息。
  • 編碼器-解碼器注意力機制(Encoder-Decoder Attention):使解碼器能關注編碼器的輸出,從而將編碼器捕捉到的上下文信息用于生成目標序列。

解碼器也包含多個(通常是6個)這樣的子層堆疊。

在這里插入圖片描述

二、詳細機制

注意力機制(Attention Mechanism)

自注意力機制是Transformer的核心。它的計算過程如下:

  1. 計算查詢(Query)、鍵(Key)、值(Value)矩陣
    在這里插入圖片描述
    其中,X 是輸入序列,WQ、WK、WV是可訓練的權重矩陣。

  2. 計算注意力分數
    在這里插入圖片描述
    其中:

    • dk是鍵向量的維度。
    • KT 是鍵矩陣的轉置
  3. 多頭機制
    多頭注意力機制將輸入映射到多個子空間,通過多個注意力頭來捕捉不同的特征。然后將這些頭的輸出連接起來:
    在這里插入圖片描述
    其中,每個頭是獨立的注意力機制,WO 是可訓練的線性變換矩陣。

三、Transformer的優點

  1. 并行計算:不同于RNN的序列處理方式,Transformer允許并行計算,提高了訓練速度。
  2. 長程依賴:通過自注意力機制,Transformer能夠直接捕捉序列中任意位置之間的依賴關系。
  3. 靈活性:Transformer架構可以輕松擴展到不同任務,如語言翻譯、文本生成、圖像處理等。

四、變種和改進

自從Transformer被提出以來,已經出現了許多改進和變種,例如:

  • BERT(Bidirectional Encoder Representations from Transformers):雙向編碼器,適用于多種NLP任務。
  • GPT(Generative Pre-trained Transformer):生成模型,專注于文本生成任務。
  • T5(Text-to-Text Transfer Transformer):將所有NLP任務統一為文本到文本的形式。
  • Vision Transformer(ViT):將Transformer應用于圖像分類任務。

五、應用領域

Transformer模型在以下領域表現出色:

  1. 自然語言處理(NLP):如機器翻譯、文本生成、問答系統等。
  2. 計算機視覺:如圖像分類、目標檢測等。
  3. 時間序列分析:如股票預測、天氣預報等。
  4. 推薦系統:通過捕捉用戶與物品之間的復雜關系來提供個性化推薦。

六、代碼示例

以下是一個使用TensorFlow實現簡單Transformer的代碼示例:

import tensorflow as tf
import numpy as np# 注意力機制
def scaled_dot_product_attention(q, k, v, mask):matmul_qk = tf.matmul(q, k, transpose_b=True)dk = tf.cast(tf.shape(k)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)if mask is not None:scaled_attention_logits += (mask * -1e9)attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)output = tf.matmul(attention_weights, v)return output, attention_weights# 多頭注意力
class MultiHeadAttention(tf.keras.layers.Layer):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelassert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = tf.keras.layers.Dense(d_model)self.wk = tf.keras.layers.Dense(d_model)self.wv = tf.keras.layers.Dense(d_model)self.dense = tf.keras.layers.Dense(d_model)def split_heads(self, x, batch_size):x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))return tf.transpose(x, perm=[0, 2, 1, 3])def call(self, v, k, q, mask):batch_size = tf.shape(q)[0]q = self.wq(q)k = self.wk(k)v = self.wv(v)q = self.split_heads(q, batch_size)k = self.split_heads(k, batch_size)v = self.split_heads(v, batch_size)scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))output = self.dense(concat_attention)return output, attention_weights# 前饋神經網絡
def point_wise_feed_forward_network(d_model, dff):return tf.keras.Sequential([tf.keras.layers.Dense(dff, activation='relu'),tf.keras.layers.Dense(d_model)])# 編碼器層
class EncoderLayer(tf.keras.layers.Layer):def __init__(self, d_model, num_heads, dff, rate=0.1):super(EncoderLayer, self).__init__()self.mha = MultiHeadAttention(d_model, num_heads)self.ffn = point_wise_feed_forward_network(d_model, dff)self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)self.dropout1 = tf.keras.layers.Dropout(rate)self.dropout2 = tf.keras.layers.Dropout(rate)def call(self, x, training, mask):attn_output, _ = self.mha(x, x, x, mask)attn_output = self.dropout1(attn_output, training=training)out1 = self.layernorm1(x + attn_output)ffn_output = self.ffn(out1)ffn_output = self.dropout2(ffn_output, training=training)out2 = self.layernorm2(out1 + ffn_output)return out2# 解碼器層
class DecoderLayer(tf.keras.layers.Layer):def __init__(self, d_model, num_heads, dff, rate=0.1):super(DecoderLayer, self).__init__()self.mha1 = MultiHeadAttention(d_model, num_heads)self.mha2 = MultiHeadAttention(d_model, num_heads)self.ffn = point_wise_feed_forward_network(d_model, dff)self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)self.dropout1 = tf.keras.layers.Dropout(rate)self.dropout2 = tf.keras.layers.Dropout(rate)self.dropout3 = tf.keras.layers.Dropout(rate)def call(self, x, enc_output, training, look_ahead_mask, padding_mask):attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)attn1 = self.dropout1(attn1, training=training)out1 = self.layernorm1(x + attn1)attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)attn2 = self.dropout2(attn2, training=training)out2 = self.layernorm2(out1 + attn2)ffn_output = self.ffn(out2)ffn_output = self.dropout3(ffn_output, training=training)out3 = self.layernorm3(out2 + ffn_output)return out3, attn_weights_block1, attn_weights_block2# 編碼器
class Encoder(tf.keras.layers.Layer):def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, rate=0.1):super(Encoder, self).__init__()self.d_model = d_modelself.num_layers = num_layersself.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)self.pos_encoding = positional_encoding(1000, self.d_model)self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]self.dropout = tf.keras.layers.Dropout(rate)def call(self, x, training, mask):seq_len = tf.shape(x)[1]x = self.embedding(x)x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))x += self.pos_encoding[:, :seq_len, :]x = self.dropout(x, training=training)for i in range(self.num_layers):x = self.enc_layers[i](x, training, mask)return x# 解碼器
class Decoder(tf.keras.layers.Layer):def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, rate=0.1):super(Decoder, self).__init__()self.d_model = d_modelself.num_layers = num_layersself.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)self.pos_encoding = positional_encoding(1000, self.d_model)self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]self.dropout = tf.keras.layers.Dropout(rate)def call(self, x, enc_output, training, look_ahead_mask, padding_mask):seq_len = tf.shape(x)[1]attention_weights = {}x = self.embedding(x)x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))x += self.pos_encoding[:, :seq_len, :]x = self.dropout(x, training=training)for i in range(self.num_layers):x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)attention_weights[f'decoder_layer{i+1}_block1'] = block1attention_weights[f'decoder_layer{i+1}_block2'] = block2return x, attention_weights# Transformer模型
class Transformer(tf.keras.Model):def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, rate=0.1):super(Transformer, self).__init__()self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, rate)self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, rate)self.final_layer = tf.keras.layers.Dense(target_vocab_size)def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):enc_output = self.encoder(inp, training, enc_padding_mask)dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)final_output = self.final_layer(dec_output)return final_output, attention_weights# 位置編碼
def positional_encoding(position, d_model):angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model)angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])pos_encoding = angle_rads[np.newaxis, ...]return tf.cast(pos_encoding, dtype=tf.float32)def get_angles(pos, i, d_model):angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))return pos * angle_rates# 掩碼
def create_padding_mask(seq):seq = tf.cast(tf.math.equal(seq, 0), tf.float32)return seq[:, tf.newaxis, tf.newaxis, :]def create_look_ahead_mask(size):mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)return mask# 超參數
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
input_vocab_size = 8500
target_vocab_size = 8000
dropout_rate = 0.1# 創建Transformer模型
transformer = Transformer(num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, dropout_rate)# 損失函數和優化器
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss_function(real, pred):mask = tf.math.logical_not(tf.math.equal(real, 0))loss_ = loss_object(real, pred)mask = tf.cast(mask, dtype=loss_.dtype)loss_ *= maskreturn tf.reduce_sum(loss_) / tf.reduce_sum(mask)learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(1e-4, decay_steps=100000, decay_rate=0.9, staircase=True)
optimizer = tf.keras.optimizers.Adam(learning_rate)# 編譯模型
transformer.compile(optimizer=optimizer, loss=loss_function)# 示例輸入
sample_input = tf.constant([[1, 2, 3, 4, 0, 0]])
sample_target = tf.constant([[1, 2, 3, 4, 0, 0]])# 訓練模型
transformer.fit([sample_input, sample_target], epochs=10)
解釋
  1. 注意力機制:定義了計算注意力權重的函數和多頭注意力機制。
  2. 前饋神經網絡:實現了前饋神經網絡的部分。
  3. 編碼器和解碼器層:定義了編碼器和解碼器的基本層。
  4. 編碼器和解碼器:實現了編碼器和解碼器的堆疊。
  5. Transformer模型:集成了編碼器和解碼器,定義了完整的Transformer模型。
  6. 位置編碼:為輸入序列添加位置信息。
  7. 掩碼:定義了填充掩碼和前瞻掩碼,用于處理輸入和目標序列中的填充和防止信息泄露。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/37870.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/37870.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/37870.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL高級-SQL優化- limit優化(覆蓋索引加子查詢)

文章目錄 0、limit 優化0.1、從表 tb_sku 中按照 id 列進行排序,然后跳過前 9000000 條記錄0.2、通過子查詢獲取按照 id 排序后的第 9000000 條開始的 10 條記錄的 id 值,然后在原表中根據這些 id 值獲取對應的完整記錄 1、上傳5個sql文件到 /root2、查看…

libctk shared library的設計及編碼實踐記錄

一、引言 1.1 <libctk>的由來 1.2 <libctk>的設計理論依據 1.3 <libctk>的設計理念 二、<libctk>的依賴庫 三、<libctk>的目錄說明 四、<libctk>的功能模塊及使用實例說明 4.1 日志模塊 4.2 mysql client模塊 4.3 ftp client模塊 4…

鴻蒙開發設備管理:【@ohos.geolocation (位置服務)】

位置服務 說明&#xff1a; 本模塊首批接口從API version 7開始支持。后續版本的新增接口&#xff0c;采用上角標單獨標記接口的起始版本。 導入模塊 import geolocation from ohos.geolocation;geolocation.on(‘locationChange’) on(type: ‘locationChange’, request: L…

安卓開發自定義時間日期顯示組件

安卓開發自定義時間日期顯示組件 問題背景 實現時間和日期顯示&#xff0c;左對齊和對齊兩種效果&#xff0c;如下圖所示&#xff1a; 問題分析 自定義view實現一般思路&#xff1a; &#xff08;1&#xff09;自定義一個View &#xff08;2&#xff09;編寫values/attrs.…

poi-tl 生成 word 文件(插入文字、圖片、表格、圖表)

文章說明 本篇文章主要通過代碼案例的方式&#xff0c;展示 poi-tl 生成 docx 文件的一些常用操作&#xff0c;主要涵蓋以下內容 &#xff1a; 插入文本字符&#xff08;含樣式、超鏈接&#xff09;插入圖片插入表格引入標簽&#xff08;通過可選文字的方式&#xff0c;這種方…

俄羅斯防空系統

俄羅斯的S系列防空系統是一系列先進的地對空導彈系統&#xff0c;旨在防御各類空中威脅&#xff0c;包括飛機、無人機、巡航導彈和彈道導彈。以下是幾種主要的S系列防空系統&#xff1a; 1. **S-300系統**&#xff1a; - **S-300P**&#xff1a;最早期的版本&#xff0c;用…

翻譯造句練習

翻譯練習 翻譯 1&#xff1a;經常做運動會提高人的自信 翻譯 2&#xff1a;教學的質量對學生成績有很大的影響。 翻譯 3&#xff1a;家長和老師應該努力去減少小孩看電視的時間。 翻譯 4&#xff1a;經濟的下滑&#xff08;economic slowdown&#xff09;導致失業率的上升 翻譯…

大模型和數據庫最新結合進展

寫在前面 本文主要內容是上次接受 infoQ 訪談&#xff0c;百度智能云朱潔老師介紹了大模型和 AI 結合相關話題&#xff0c;這次整體再刷新下&#xff0c;給到對這個領域感興趣的同學。 當前&#xff0c;百度智能云云數據庫特惠專場開始&#xff01;熱銷規格新用戶免費使用&am…

Android中ViewModel+LiveData+DataBinding的配合使用(kotlin)

Android 中 ViewModel、LiveData 和 Data Binding 的配合使用&#xff08;Kotlin&#xff09; 摘要 本文將介紹如何在 Android 開發中結合使用 ViewModel、LiveData 和 Data Binding 進行數據綁定和狀態更新。我們將詳細探討這三者之間的關系&#xff0c;并展示如何在 Kotlin…

最逼真的簡易交通燈設計

最逼真的簡易交通燈設計 需要資料的請在文章末尾獲取&#xff08;有問題可以私信我哦~~&#xff09; 01 資料內容 Proteus仿真文件程序源碼實物制作&#xff0c;代碼修改&#xff0c;功能定制&#xff08;需額外收費&#xff0c;價格實惠&#xff0c;歡迎咨詢&#xff09; …

實驗場:在幾分鐘內使用 Elasticsearch 進行 RAG 應用程序實驗

作者&#xff1a;來自 Elastic Joe McElroy, Serena Chou 什么是 Playground&#xff08;實驗場&#xff09;&#xff1f; 我們很高興發布我們的 Playground 體驗 —- 一個低代碼界面&#xff0c;開發人員可以在幾分鐘內使用自己的私人數據探索他們選擇的 LLM。 在對對話式搜…

41割隊伍

上海市計算機學會競賽平臺 | YACSYACS 是由上海市計算機學會于2019年發起的活動,旨在激發青少年對學習人工智能與算法設計的熱情與興趣,提升青少年科學素養,引導青少年投身創新發現和科研實踐活動。https://www.iai.sh.cn/problem/387 題目描述 給定 ??n 個數字 ??1,?…

一周小計(1):實習初體驗

實習的第一周&#xff0c;從最開始的配環境做好準備工作&#xff0c;到拉項目熟悉項目&#xff0c;然后自己去寫需求&#xff0c;每一步都有很大收獲&#xff0c;得到很多人幫助真的好感謝&#xff0c;以下是個人這幾天的記錄與感想。 &#xff08;這個其實是我寫的周報&#x…

Hi3861 OpenHarmony嵌入式應用入門--LiteOS Semaphore做同步使用

信號量作為同步使用 創建一個Semaphore對象&#xff0c;并指定一個初始的計數值&#xff08;通常稱為“許可”或“令牌”的數量&#xff09;。這個計數值表示當前可用的資源數量或可以同時訪問共享資源的線程數。當一個線程需要訪問共享資源時&#xff0c;它會嘗試從Semaphore…

加油站可視化:打造智能化運營與管理新模式

智慧加油站可視化通過圖撲 HT 構建仿真的三維模型&#xff0c;將加油站的布局、設備狀態、人員活動等信息動態呈現。管理者可以通過直觀的可視化界面實時監控和分析運營狀況&#xff0c;快速做出決策&#xff0c;提高管理效率和安全水平&#xff0c;推動加油站向智能化管理轉型…

后端之路第三站(Mybatis)——結合案例講Mybatis怎么操作sql

先講一下準備工作整體流程要做什么 我們要基于一個員工管理系統作為案例&#xff0c;進行員工信息的【增、刪、改、查】 原理就是用Mybatis通過java語言來執行sql語句&#xff0c;來達到【增、刪、改、查】 一、準備工作 1、引入數據庫數據 首先我們把一個員工、部門表的數…

【51單片機入門】速通定時器

文章目錄 前言定時器是什么初始化定時器初始化的大概步驟TMOD寄存器C/T寄存器 觸發定時器中斷是什么中斷函數定時器點亮led 總結 前言 在嵌入式系統的開發中&#xff0c;定時器是一個非常重要的組成部分。它們可以用于產生精確的時間延遲&#xff0c;或者在特定的時間間隔內觸…

對外發布的PDF文檔進行數字證書簽名的重要性?

對外發布的PDF文檔進行數字證書簽名具有以下幾個重要性&#xff1a; 身份驗證&#xff1a;數字簽名可以證明文檔的來源&#xff0c;即確認文檔的簽署者身份。這如同在紙質文檔上手寫簽名或加蓋公章&#xff0c;但更安全可靠&#xff0c;因為數字簽名是基于加密技術&#xff0c;…

Java--常用類APl(復習總結)

前言: Java是一種強大而靈活的編程語言&#xff0c;具有廣泛的應用范圍&#xff0c;從桌面應用程序到企業級應用程序都能夠使用Java進行開發。在Java的編程過程中&#xff0c;使用標準類庫是非常重要的&#xff0c;因為標準類庫提供了豐富的類和API&#xff0c;可以簡化開發過…

【接口自動化測試】第三節.實現項目核心業務接口自動化

文章目錄 前言一、實現登錄接口對象封裝和調用 1.0 登錄接口的接口測試文檔 1.1 接口對象層&#xff08;封裝&#xff09; 1.2 測試腳本層&#xff08;調用&#xff09;二、課程新增接口對象封裝和調用 2.0 課程新增接口的接口測試文檔 2.1 接口對象層…