深度學習處理文本(10)

保存自定義層

在編寫自定義層時,一定要實現get_config()方法:這樣我們可以利用config字典將該層重新實例化,這對保存和加載模型很有用。該方法返回一個Python字典,其中包含用于創建該層的構造函數的參數值。所有Keras層都可以被序列化(serialize)和反序列化(deserialize)?,如下所示。

config = layer.get_config()
new_layer = layer.__class__.from_config(config)---- config不包含權重值,因此該層的所有權重都是從頭初始化的

來看下面這個例子。

layer = PositionalEmbedding(sequence_length, input_dim, output_dim)
config = layer.get_config()
new_layer = PositionalEmbedding.from_config(config)

在保存包含自定義層的模型時,保存文件中會包含這些config字典。從文件中加載模型時,你應該在加載過程中提供自定義層的類,以便其理解config對象,如下所示。

model = keras.models.load_model(filename, custom_objects={"PositionalEmbedding": PositionalEmbedding})

你會注意到,這里使用的規范化層并不是之前在圖像模型中使用的BatchNormalization層。這是因為BatchNormalization層處理序列數據的效果并不好。相反,我們使用的是LayerNormalization層,它對每個序列分別進行規范化,與批量中的其他序列無關。它類似NumPy的偽代碼如下

def layer_normalization(batch_of_sequences):----輸入形狀:(batch_size, sequence_length, embedding_dim)mean = np.mean(batch_of_sequences, keepdims=True, axis=-1)---- (本行及以下1)計算均值和方差,僅在最后一個軸(?1軸)上匯聚數據variance = np.var(batch_of_sequences, keepdims=True, axis=-1)return (batch_of_sequences - mean) / variance

下面是訓練過程中的BatchNormalization的偽代碼,你可以將二者對比一下。

def batch_normalization(batch_of_images):----輸入形狀:(batch_size, height, width, channels)mean = np.mean(batch_of_images, keepdims=True, axis=(0, 1, 2))---- (本行及以下1)在批量軸(0軸)上匯聚數據,這會在一個批量的樣本之間形成相互作用variance = np.var(batch_of_images, keepdims=True, axis=(0, 1, 2))return (batch_of_images - mean) / variance

BatchNormalization層從多個樣本中收集信息,以獲得特征均值和方差的準確統計信息,而LayerNormalization層則分別匯聚每個序列中的數據,更適用于序列數據。我們已經實現了TransformerEncoder,下面可以用它來構建一個文本分類模型,如代碼清單11-22所示,它與前面的基于GRU的模型類似。代碼清單11-22 將Transformer編碼器用于文本分類

vocab_size = 20000
embed_dim = 256
num_heads = 2
dense_dim = 32inputs = keras.Input(shape=(None,), dtype="int64")
x = layers.Embedding(vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalMaxPooling1D()(x)---- TransformerEncoder返回的是完整序列,所以我們需要用全局匯聚層將每個序列轉換為單個向量,以便進行分類
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop",loss="binary_crossentropy",metrics=["accuracy"])
model.summary()

我們來訓練這個模型,如代碼清單11-23所示。模型的測試精度為87.5%,比GRU模型略低。代碼清單11-23 訓練并評估基于Transformer編碼器的模型

callbacks = [keras.callbacks.ModelCheckpoint("transformer_encoder.keras",save_best_only=True)
]
model.fit(int_train_ds, validation_data=int_val_ds, epochs=20,callbacks=callbacks)
model = keras.models.load_model("transformer_encoder.keras",custom_objects={"TransformerEncoder": TransformerEncoder})----在模型加載過程中提供自定義的TransformerEncoder類
print(f"Test acc: {model.evaluate(int_test_ds)[1]:.3f}")

現在你應該已經開始感到有些不對勁了。你能看出是哪里不對勁嗎?本節的主題是“序列模型”?。我一開始就強調了詞序的重要性。我說過,Transformer是一種序列處理架構,最初是為機器翻譯而開發的。然而……你剛剛見到的Transformer編碼器根本就不是一個序列模型。你注意到了嗎?它由密集層和注意力層組成,前者獨立處理序列中的詞元,后者則將詞元視為一個集合。你可以改變序列中的詞元順序,并得到完全相同的成對注意力分數和完全相同的上下文感知表示。如果將每篇影評中的單詞完全打亂,模型也不會注意到,得到的精度也完全相同。自注意力是一種集合處理機制,它關注的是序列元素對之間的關系,如圖11-10所示,它并不知道這些元素出現在序列的開頭、結尾還是中間。既然是這樣,為什么說Transformer是序列模型呢?如果它不查看詞序,又怎么能很好地進行機器翻譯呢?

在這里插入圖片描述

Transformer是一種混合方法,它在技術上是不考慮順序的,但將順序信息手動注入數據表示中。這就是缺失的那部分,它叫作位置編碼(positional encoding)?。我們來看一下。

使用位置編碼重新注入順序信息

位置編碼背后的想法非常簡單:為了讓模型獲取詞序信息,我們將每個單詞在句子中的位置添加到詞嵌入中。這樣一來,輸入詞嵌入將包含兩部分:普通的詞向量,它表示與上下文無關的單詞;位置向量,它表示該單詞在當前句子中的位置。我們希望模型能夠充分利用這一額外信息。你能想到的最簡單的方法就是將單詞位置與它的嵌入向量拼接在一起。你可以向這個向量添加一個“位置”軸。在該軸上,序列中的第一個單詞對應的元素為0,第二個單詞為1,以此類推。然而,這種做法可能并不理想,因為位置可能是非常大的整數,這會破壞嵌入向量的取值范圍。如你所知,神經網絡不喜歡非常大的輸入值或離散的輸入分布。

在“Attention Is All You Need”這篇原始論文中,作者使用了一個有趣的技巧來編碼單詞位置:將詞嵌入加上一個向量,這個向量的取值范圍是[-1, 1],取值根據位置的不同而周期性變化(利用余弦函數來實現)?。這個技巧提供了一種思路,通過一個小數值向量來唯一地描述較大范圍內的任意整數。這種做法很聰明,但并不是本例中要用的。我們的方法更加簡單,也更加有效:我們將學習位置嵌入向量,其學習方式與學習嵌入詞索引相同。然后,我們將位置嵌入與相應的詞嵌入相加,得到位置感知的詞嵌入。這種方法叫作位置嵌入(positional embedding)?。我們來實現這種方法,如代碼清單11-24所示。代碼清單11-24 將位置嵌入實現為Layer子類

class PositionalEmbedding(layers.Layer):def __init__(self, sequence_length, input_dim, output_dim, **kwargs):----位置嵌入的一個缺點是,需要事先知道序列長度super().__init__(**kwargs)self.token_embeddings = layers.Embedding(----準備一個Embedding層,用于保存詞元索引input_dim=input_dim, output_dim=output_dim)self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)----另準備一個Embedding層,用于保存詞元位置self.sequence_length = sequence_lengthself.input_dim = input_dimself.output_dim = output_dimdef call(self, inputs):length = tf.shape(inputs)[-1]positions = tf.range(start=0, limit=length, delta=1)embedded_tokens = self.token_embeddings(inputs)embedded_positions = self.position_embeddings(positions)return embedded_tokens + embedded_positions  ←----將兩個嵌入向量相加def compute_mask(self, inputs, mask=None):---- (本行及以下1)與Embedding層一樣,該層應該能夠生成掩碼,從而可以忽略輸入中填充的0。框架會自動調用compute_mask方法,并將掩碼傳遞給下一層return tf.math.not_equal(inputs, 0)def get_config(self):----實現序列化,以便保存模型config = super().get_config()config.update({"output_dim": self.output_dim,"sequence_length": self.sequence_length,"input_dim": self.input_dim,})return config

你可以像使用普通Embedding層一樣使用這個PositionEmbedding層。我們來看一下它的實際效果。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/75800.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/75800.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/75800.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

機器視覺3D中激光偏鏡的優點

機器視覺的3D應用中,激光偏鏡(如偏振片、波片、偏振分束器等)通過其獨特的偏振控制能力,顯著提升了系統的測量精度、抗干擾能力和適應性。以下是其核心優點: 1. 提升3D成像精度 抑制環境光干擾:偏振片可濾除非偏振的環境雜光(如日光、室內照明),僅保留激光偏振信號,大…

線程同步的學習與應用

1.多線程并發 1).多線程并發引例 #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <assert.h> #include <pthread.h>int wg0; void *fun(void *arg) {for(int i0;i<1000;i){wg;printf("wg%d\n",wg);} } i…

寫.NET可以指定運行SUB MAIN嗎?調用任意一個里面的類時,如何先執行某段初始化代碼?

VB.NET 寫.NET可以指定運行SUB MAIN嗎?調用任意一個里面的類時,如何先執行某段初始化代碼? 分享 1. 在 VB.NET 中指定運行 Sub Main 在 VB.NET 里&#xff0c;你能夠指定 Sub Main 作為程序的入口點。下面為你介紹兩種實現方式&#xff1a; 方式一&#xff1a;在項目屬性…

【AI插件開發】Notepad++ AI插件開發實踐(代碼篇):從Dock窗口集成到功能菜單實現

一、引言 上篇文章已經在Notepad的插件開發中集成了選中即問AI的功能&#xff0c;這一篇文章將在此基礎上進一步集成&#xff0c;支持AI對話窗口以及常見的代碼功能菜單&#xff1a; 顯示AI的Dock窗口&#xff0c;可以用自然語言向 AI 提問或要求執行任務選中代碼后使用&…

關聯容器-模板類pair數對

關聯容器 關聯容器和順序容器有著根本的不同:關聯容器中的元素是按關鍵字來保存和訪問的,而順序容器中的元素是按它們在容器中的位置來順序保存和訪問的。 關聯容器支持高效的關鍵字查找和訪問。 兩個主要的關聯容器(associative-container),set和map。 set 中每個元素只包…

京東運維面試題及參考答案

目錄 OSPF 實現原理是什么? 請描述 TCP 三次握手的過程。 LVS 的原理是什么? 闡述 Nginx 七層負載均衡的原理。 Nginx 與 Apache 有什么區別? 如何查看監聽在 8080 端口的是哪個進程(可舉例:netstat -tnlp | grep 8080)? OSI 七層模型是什么,請寫出各層的協議。 …

輸入框輸入數字且保持精度

在項目中如果涉及到金額等需要數字輸入且保持精度的情況下&#xff0c;由于輸入框是可以隨意輸入文本的&#xff0c;所以一般情況下可能需要監聽輸入框的change事件&#xff0c;然后通過正則表達式去替換掉不匹配的文本部分。 由于每次文本改變都會被監聽&#xff0c;包括替換…

使用 requests 和 BeautifulSoup 解析淘寶商品

以下將詳細解釋如何通過這兩個庫來實現按關鍵字搜索并解析淘寶商品信息。 一、準備工作 1. 安裝必要的庫 在開始之前&#xff0c;確保已經安裝了 requests 和 BeautifulSoup 庫。如果尚未安裝&#xff0c;可以通過以下命令進行安裝&#xff1a; bash pip install requests…

C#調用ACCESS數據庫,解決“Microsoft.ACE.OLEDB.12.0”未注冊問題

C#調用ACCESS數據庫&#xff0c;解決“Microsoft.ACE.OLEDB.12.0”未注冊問題 解決方法&#xff1a; 1.將C#采用的平臺從AnyCpu改成X64 2.將官網下載的“Microsoft Access 2010 數據庫引擎可再發行程序包AccessDatabaseEngine_X64”文件解壓 3.安裝解壓后的文件 點擊下載安…

【文獻閱讀】Vision-Language Models for Vision Tasks: A Survey

發表于2024年2月 TPAMI 摘要 大多數視覺識別研究在深度神經網絡&#xff08;DNN&#xff09;訓練中嚴重依賴標注數據&#xff0c;并且通常為每個單一視覺識別任務訓練一個DNN&#xff0c;這導致了一種費力且耗時的視覺識別范式。為應對這兩個挑戰&#xff0c;視覺語言模型&am…

【Kubernetes】StorageClass 的作用是什么?如何實現動態存儲供應?

StorageClass 使得用戶能夠根據不同的存儲需求動態地申請和管理存儲資源。 StorageClass 定義了如何創建存儲資源&#xff0c;并指定了存儲供應的配置&#xff0c;例如存儲類型、質量、訪問模式等。為動態存儲供應提供了基礎&#xff0c;使得 Kubernetes 可以在用戶創建 PVC 時…

Muduo網絡庫介紹

1.Reactor介紹 1.回調函數 **回調&#xff08;Callback&#xff09;**是一種編程技術&#xff0c;允許將一個函數作為參數傳遞給另一個函數&#xff0c;并在適當的時候調用該函數 1.工作原理 定義回調函數 注冊回調函數 觸發回調 2.優點 異步編程 回調函數允許在事件發生時…

Debian編譯安裝mysql8.0.41源碼包 筆記250401

Debian編譯安裝mysql8.0.41源碼包 以下是在Debian系統上通過編譯源碼安裝MySQL 8.0.41的完整步驟&#xff0c;包含依賴管理、編譯參數優化和常見問題處理&#xff1a; 準備工作 1. 安裝編譯依賴 sudo apt update sudo apt install -y \cmake gcc g make libssl-dev …

Git常用問題收集

gitignore 忽略文件夾 不生效 有時候我們接手別人的項目時&#xff0c;發現有的忽略不對想要修改&#xff0c;但發現修改忽略.gitignore后無效。原因是如果某些文件已經被納入版本管理在.gitignore中忽略路徑是不起作用的&#xff0c;這時候需要先清除本地緩存&#xff0c;然后…

編程哲學——TCP可靠傳輸

TCP TCP可靠傳輸 TCP的可靠傳輸表現在 &#xff08;1&#xff09;建立連接時三次握手&#xff0c;四次揮手 有點像是這樣對話&#xff1a; ”我們開始對話吧“ ”收到“ ”好的&#xff0c;我收到你收到了“ &#xff08;2&#xff09;數據傳輸時ACK應答和超時重傳 ”我們去吃…

【MediaPlayer】基于libvlc+awtk的媒體播放器

基于libvlcawtk的媒體播放器 libvlc下載地址 awtk下載地址 代碼實現libvlc相關邏輯接口UI媒體接口實例化媒體播放器注意事項 libvlc 下載地址 可以到https://download.videolan.org/pub/videolan/vlc/去下載一個vlc版本&#xff0c;下載后其實是vlc的windows客戶端&#xff0…

pulsar中的延遲隊列使用詳解

Apache Pulsar的延遲隊列支持任意時間精度的延遲消息投遞&#xff0c;適用于金融交易、定時提醒等高時效性場景。其核心設計通過堆外內存索引隊列與持久化分片存儲實現&#xff0c;兼顧靈活性與可擴展性。以下從實現原理、使用方式、優化策略及挑戰展開解析&#xff1a; 一、核…

單鏈表的實現 | 附學生信息管理系統的實現

目錄 1.前言&#xff1a; 2.單鏈表的相關概念&#xff1a; 2.1定義&#xff1a; 2.2形式&#xff1a; 2.3特點&#xff1a; 3.常見功能及代碼 &#xff1a; 3.1創建節點&#xff1a; 3.2頭插&#xff1a; 3.3尾插&#xff1a; 3.4頭刪&#xff1a; 3.5尾刪&#xff1a; 3.6插入…

java實用工具類Localstorage

public class LocalStorageUtil {//提供ThreadLocal對象,private static ThreadLocal threadLocalnew ThreadLocal();public static Object get(){return threadLocal.get();}public static void set(Object o){threadLocal.set(o);}public static void remove(){threadLocal.r…

LLM-大語言模型淺談

目錄 核心定義 典型代表 核心原理 用途 優勢與局限 未來發展方向 LLM&#xff08;Large Language Model&#xff09;大語言模型&#xff0c;指通過海量文本數據訓練 能夠理解和生成人類語言的深度學習模型。 核心定義 一種基于深度神經網絡&#xff08;如Transformer架…