TensorFlow深度學習實戰——使用Hugging Face構建Transformer模型

TensorFlow深度學習實戰——使用Hugging Face構建Transformer模型

    • 0. 前言
    • 1. 安裝 Hugging Face
    • 2. 文本生成
    • 3. 自動模型選擇和自動分詞
    • 4. 命名實體識別
    • 5. 摘要生成
    • 6. 模型微調
    • 相關鏈接

0. 前言

除了需要實現特定的自定義結構,或者想要了解 Transformer 工作原理外,從零開始實現 Transformer 并不是最佳選擇,和其它編程實踐一樣,通常并不需要從頭開始造輪子。只有想要理解 Transformer 架構的內部細節,或者修改 Transformer 架構以得到新的變體時才需要從零開始構建。有很多優秀的庫提供高質量的 Transformer 解決方案,Hugging Face 是其中的代表之一,它提供了一些構建 Transformer 的高效工具:

  • Hugging Face 提供了一個通用的 API 來處理多種 Transformer 架構
  • Hugging Face 不僅提供了基礎模型,還提供了帶有不同類型“頭”的模型來處理特定任務(例如,對于 BERT 架構,提供了 TFBertModel,用于情感分析的 TFBertForSequenceClassification,用于命名實體識別的 TFBertForTokenClassification,以及用于問答的 TFBertForQuestionAnswering 等)
  • 可以通過使用 Hugging Face 提供的預訓練權重來輕松創建自定義的網絡,例如,使用 TFBertForPreTraining
  • 除了 pipeline() 方法,還可以以常規方式定義模型,使用 fit() 進行訓練,使用 predict() 進行推理,就像普通的 TensorFlow 模型一樣

1. 安裝 Hugging Face

和其它第三方庫一樣,可以使用 pip 命令安裝 Hugging Face 庫:

$ pip install transformers[tf]

然后,通過下載一個用于情感分析的預訓練模型來驗證 Hugging Face 庫是否安裝成功:

$ python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"

如果成功安裝,將顯示如下輸出結果:

[{'label': 'POSITIVE', 'score': 0.9998704791069031}]

接下來,介紹如何使用 Hugging Face 解決具體任務。

2. 文本生成

在本節中,我們將使用 GPT-2 進行自然語言生成,這是一個生成自然語言輸出的過程。

(1) 使用 GPT-2 生成文本:

from transformers import pipeline
generator = pipeline(task="text-generation")

(2) 模型下載完成后,將文本傳遞給生成器,觀察結果:

generator("Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone")generator ("The original theory of relativity is based upon the premise that all coordinate systems in relative uniform translatory motion to each other are equally valid and equivalent ")generator ("It takes a great deal of bravery to stand up to our enemies")

生成結果

3. 自動模型選擇和自動分詞

Hugging Face 能夠盡可能幫助自動化多個步驟。

(1) 可以非常簡單的從數十個可用的預訓練模型中導入可用模型:

from transformers import TFAutoModelForSequenceClassification
model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

可以在下游任務上訓練模型,以便用于預測和推理。

(2) 可以使用 AutoTokenizer 將單詞轉換為模型使用的詞元:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sequence = "The original theory of relativity is based upon the premise that all coordinate systems"
print(tokenizer(sequence))

輸出結果

4. 命名實體識別

命名實體識別 (Named Entity Recognition, NER) 是經典的自然語言處理任務。命名實體識別也稱實體識別 (entity identification)、實體分塊 (entity chunking) 或實體提取 (entity extraction),是信息提取的一個子任務,旨在定位和分類在非結構化文本中提到的命名實體,將其劃分為預定義的類別,例如人名、組織、地點、時間表達、數量、貨幣值和百分比等。接下來,我們使用 Hugging Face 完成命名實體識別任務。

(1) 創建一個 NER 管道:

from transformers import pipeline
ner_pipe = pipeline("ner")
sequence = """Mr. and Mrs. Dursley, of number four, Privet Drive, were
proud to say that they were perfectly normal, thank you very much."""
for entity in ner_pipe(sequence):print(entity)

(2) 結果如下所示,其中實體已經被識別出來:
識別結果

命名實體識別可以理解九個不同的類別:

  • O: 不屬于命名實體
  • B-MIS: 在另一個雜項實體后開始的雜項實體
  • I-MIS: 雜項實體
  • B-PER: 在另一個人名后面開始的人名
  • I-PER: 人名
  • B-ORG: 在另一個組織后面開始的組織
  • I-ORG: 組織
  • B-LOC: 在另一個地點后面開始的地點
  • I-LOC: 地點

這些實體在 CoNLL-2003 數據集中定義,并由 Hugging Face 自動選擇。

5. 摘要生成

摘要生成,是指用簡短而清晰的形式表達有關某事或某人的最重要事實或觀點。Hugging Face 使用 T5 模型作為完成此任務的默認模型。

(1) 首先,使用默認的 T5 small 模型創建一個摘要生成管道:

from transformers import pipeline
summarizer = pipeline("summarization")
ARTICLE = """Mr. and Mrs.Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr.Dursley was the director of a firm called Grunnings, which made drills.He was a big, beefy man with hardly any neck, although he did have a very large mustache.Mrs.Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors.The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere"""
print(summarizer(ARTICLE, max_length=130, min_length=30, do_sample=False))

輸出結果如下:

輸出結果

(2) 如果想要更換使用不同的模型,只需修改參數 model

summarizer = pipeline("summarization", model='t5-base')

輸出結果如下:

輸出結果

6. 模型微調

一種常見的 Transformer 使用模式是先使用預訓練的大語言模型 (Large Language Model, LLM),然后對模型進行微調以適應特定的下游任務。微調步驟將基于自定義數據集,而預訓練則是在非常大的數據集上進行的。這種策略的優點在于節省計算成本,此外,微調令我們使用最先進的模型,而不需要從頭開始訓練一個模型。接下來,我們介紹如何使用 TensorFlow 進行模型微調,使用的預訓練模型是 bert-base-cased,在 Yelp Reviews 數據集上進行微調。
本節使用 datasets 庫加載數據集,datasets 庫是由 Hugging Face 提供的一個非常強大的工具,專門用于加載、處理和分享數據集,使用 pip 命令安裝 datasets 庫:

$ pip install datasets

(1) 首先,加載并對 Yelp 數據集進行分詞:

from datasets import load_datasetdataset = load_dataset("yelp_review_full")
from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

(2) 然后,將數據集轉換為 TensorFlow 格式:

from transformers import DefaultDataCollator
data_collator = DefaultDataCollator(return_tensors="tf")# convert the tokenized datasets to TensorFlow datasetstf_train_dataset = small_train_dataset.to_tf_dataset(columns=["attention_mask", "input_ids", "token_type_ids"],label_cols=["labels"],shuffle=True,collate_fn=data_collator,batch_size=8,
)tf_validation_dataset = small_eval_dataset.to_tf_dataset(columns=["attention_mask", "input_ids", "token_type_ids"],label_cols=["labels"],shuffle=False,collate_fn=data_collator,batch_size=8,
)

(3) 使用 TFAutoModelForSequenceClassification,選擇 bert-base-cased

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassificationmodel = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

(4) 最后,微調模型的方法是使用 TensorFlow 中的標準訓練方式,通過編譯模型并使用 fit() 進行訓練:

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=tf.metrics.SparseCategoricalAccuracy(),
)model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)

相關鏈接

TensorFlow深度學習實戰(1)——神經網絡與模型訓練過程詳解
TensorFlow深度學習實戰(2)——使用TensorFlow構建神經網絡
TensorFlow深度學習實戰(3)——深度學習中常用激活函數詳解
TensorFlow深度學習實戰(4)——正則化技術詳解
TensorFlow深度學習實戰(5)——神經網絡性能優化技術詳解
TensorFlow深度學習實戰(6)——回歸分析詳解
TensorFlow深度學習實戰(7)——分類任務詳解
TensorFlow深度學習實戰(8)——卷積神經網絡
TensorFlow深度學習實戰(12)——詞嵌入技術詳解
TensorFlow深度學習實戰(13)——神經嵌入詳解
TensorFlow深度學習實戰(14)——循環神經網絡詳解
TensorFlow深度學習實戰(15)——編碼器-解碼器架構
TensorFlow深度學習實戰(16)——注意力機制詳解
TensorFlow深度學習實戰(21)——Transformer架構詳解與實現
TensorFlow深度學習實戰(22)——從零開始實現Transformer機器翻譯
TensorFlow深度學習實戰——Transformer變體模型
TensorFlow深度學習實戰——Transformer模型評價指標

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/84884.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/84884.shtml
英文地址,請注明出處:http://en.pswp.cn/web/84884.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SAP-ABAP:SAP全模塊的架構化解析,涵蓋核心功能、行業方案及技術平臺

一、核心業務模塊(Logistics & Operations) 模塊代號核心功能典型流程關鍵事務碼物料管理MM采購/庫存/發票校驗采購到付款 (P2P)ME21N(采購訂單), MI31(庫存盤點)銷售與分銷SD訂單/定價/發貨/開票訂單…

實時預警!機場機坪井室無線智能液位監測系統助力安全降本

某沿海機場因地處多雨區域,每年雨季均面臨排水系統超負荷運行壓力。經勘測發現,5個井室因長期遭受地下水滲透侵蝕,井壁出現細微結構性裂縫,導致內部水位異常升高。作為機坪地下管網系統的核心節點,這些井室承擔著雨水導…

邊云協同 AI 視頻分析系統設計方案

目錄 一、項目背景與目標 二、系統架構概述 總體架構圖 三、ER 圖(核心數據庫設計) 實體關系圖簡述 數據表設計(簡要) 四、模型結構圖(邊緣云端AI推理架構) 邊緣模型(YOLOv5-tiny/PP-YO…

vue3整合element-plus

為項目命名 選擇vue 框架 選擇TS 啟動測試: npm run dev 開始整合 element-plus npm install element-plus --save npm install unplugin-vue-components unplugin vitejs/plugin-vue --save-dev 修改main.ts import { createApp } from vue import ./style.cs…

【AI 測試】測試用例設計:人工智能語言大模型性能測試用例設計

目錄 一、性能測試可視化架構圖 (1)測試整體架構圖 (2)測試體系架構圖 (3)測試流程時序圖 二、性能測試架構總覽 (1)性能測試功能點 (2)測試環境要…

Windsurf SWE-1模型評析:軟件工程的AI革命

引言 軟件開發領域正經歷著前所未有的變革,AI輔助編程工具層出不窮,但大多數僅專注于代碼生成這一環節。Windsurf公司近期推出的SWE-1系列模型打破了這一局限,首次將AI應用擴展至軟件工程的全流程。這一舉措不僅反映了行業對AI工具認知的深化…

Qt for OpenHarmony 編譯鴻蒙調用的動態庫

簡介 Qt for Harmony? 是跨平臺開發框架 ?Qt? 與華為 ?OpenHarmony? 操作系統的深度集成方案,由 Qt Group 與華為聯合推動。其核心目標是為開發者提供一套高效工具鏈,實現 ??“一次開發,多端部署”?,加速 OpenHarmony 生…

退休時,按最低基數補繳醫療保險15年大概需要多少錢

在南京退休時,如果醫保繳費年限不足(男需滿25年/女需滿20年),需補繳差額年限。若按最低基數一次性補繳15年醫保,費用估算如下(以2024年政策為例): 一、補繳金額計算公式 總補繳費用…

wireshark過濾顯示rtmp協議

wireshark中抓包顯示的數據報文中,明明可以看到有 rtmp 協議的報文,但是過濾的時候卻顯示一條都沒有 查看選項中的配置,已經沒有 RTMP 這個協議了,已經被 RTMPT 替換了,過濾框中輸入 rtmpt 過濾即可

《哈希表》K倍區間(解題報告)

文章目錄 零、題目描述一、算法概述二、算法思路三、代碼實現四、算法解釋五、復雜度分析 零、題目描述 題目鏈接:K倍區間 一、算法概述 計算子數組和能被k整除的子數組數量的算法。通過前綴和與哈希表的結合,高效地統計滿足條件的子數組。??需要注…

OpenShift 在 Kubernetes 多出的功能中,哪些開源?

OpenShift 在 Kubernetes 基礎上增加的功能中,部分組件是開源的(代碼可公開訪問),而另一些則是 Red Hat 專有(閉源)。以下是詳細分類: 1. 完全開源的功能(代碼可查) 這些…

【每天一個知識點】CITE-seq 技術

一、技術背景 單細胞RNA測序(scRNA-seq)自問世以來,極大推動了細胞異質性和組織復雜性的研究。但RNA水平并不能完全代表蛋白質水平,因為蛋白質的表達受轉錄后調控、翻譯效率及蛋白降解等多種因素影響。此外,許多細胞類…

中文Windows系統下程序輸出重定向亂碼問題解決方案

導言 最近我在用 Rust 開發時,遇到了一個讓人頭疼的問題:運行 cargo run -- version Cargo.toml > output.txt 將輸出重定向到文件后,打開 output.txt 卻發現里面全是亂碼!我的程序確實是UTF8但是輸出的文件卻是UTF16LE編碼的…

Python管理工具UV

常用 UV 命令 安裝 pip install uv 版本相關 uv python list 打印所有uv支持的python版本uv python install cpython-3.12 安裝指定的python版本uv run -p 3.12 test.py 用指定的python版本運行python代碼uv run -p 3.12 python 進入python執行環境。假如輸入的版本是一個本…

論文略讀:ASurvey on Intent-aware Recommender Systems

202406 arxiv 推薦系統在許多現代在線服務中發揮著關鍵作用,例如電子商務或媒體流服務,它們能夠為消費者和服務提供商創造巨大的價值。因此,過去幾十年來,研究人員提出了大量生成個性化推薦的技術方法。傳統算法——從早期的 Gro…

Neo4j 中存儲和查詢數組數據的完整指南

Neo4j 中存儲和查詢數組數據的完整指南 圖形數據庫 Neo4j 不僅擅長處理節點和關系,還提供了強大的數組(Array)存儲和操作能力。本文將全面介紹如何在 Neo4j 中高效地使用數組,包括存儲、查詢、優化以及實際應用場景。 數組在 Neo4j 中的基本使用 數組…

Android 編譯和打包image鏡像流程

1. 編譯命令 source build/envsetup.sh lunch aosp_car_arm64-userdebug make2. 編譯流程 source build/envsetup.sh 定義一些函數的環境變量,如 lunchvalidate_current_shell,確認 shell 環境set_global_paths,設置環境變量 ANDROID_GLOB…

MySQL:SQL 慢查詢優化的技術指南

1、簡述 在 Java 后端開發中,數據庫是系統性能瓶頸的高發地帶,而 慢 SQL 查詢 往往是系統響應遲緩的“罪魁禍首”。本文將全面梳理慢 SQL 的優化思路,并結合 Java 示例進行實戰演練。 2、慢查詢的常見表現 慢查詢通常表現為: 接…

leetcode543-二叉樹的直徑

leetcode 543 思路 路徑長度計算:任意兩個節點之間的路徑長度,等于它們的最低公共祖先到它們各自的深度之和遞歸遍歷:通過后序遍歷(左右根)計算每個節點的左右子樹深度,并更新全局最大直徑深度與直徑的關…

詳解main的參數并實現讀取文件

在 C 語言中,main函數的參數argc和argv用于接收命令行傳入的參數 main 函數的兩個參數 int main(int argc, char* argv[]) 假設顧客通過手機 APP 點餐,訂單信息會被傳遞給餐廳的處理系統(也就是你的程序)。 訂單信息結構 argc…