使用 BERT 進行文本分類 (02/3)

?

一、說明

????????在使用BERT(1)進行文本分類中,我向您展示了一個BERT如何標記文本的示例。在下面的文章中,讓我們更深入地研究是否可以使用 BERT 來預測文本是使用 PyTorch 傳達積極還是消極的情緒。首先,我們需要準備數據,以便使用 PyTorch 框架進行分析。

二、什么是 PyTorch

????????PyTorch 是用于構建深度學習模型的框架,深度學習模型是一種機器學習,通常用于圖像識別和語言處理等應用程序。它由Facebook的人工智能研究小組于2016年開發,由于其靈活性,易用性和動態計算圖構建而廣受歡迎。

????????PyTorch 提供了一個基于 Python 的科學計算包,它使用圖形處理單元 (GPU)?的強大功能來加速張量運算的計算。它具有簡單直觀的API,允許開發人員快速構建和訓練深度學習模型。PyTorch 還支持自動微分,使用戶能夠計算任意函數的梯度。

三、準備我們的數據集

????????首先,讓我們從Github下載我們的數據。這里有一個關于如何從Github下載CSV文件的小提醒。只需繼續并單擊以下鏈接:

github.com

????????然后,右鍵單擊“原始”,然后左鍵單擊“將鏈接文件下載為...”。您將看到“垃圾郵件.csv”并下載它。下載后,將其保存到您的首選文件夾中以供以后使用。

????????現在,讓我們導入數據。我們看到一條錯誤消息,告訴我們部分數據未采用 UTF-8 編碼。

import pandas as pd
df = pd.read_csv("spam.csv")ERROR: 
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 606-607: invalid continuation byte

我們可以通過了解數據包含的字符編碼并在讀取數據時調用該編碼來修復此錯誤。

# Use chardet to know the character encoding 
import chardet
with open("spam.csv", 'rb') as rawdata:result = chardet.detect(rawdata.read(100000))
resultOutput: 
{'encoding': 'Windows-1252', 'confidence': 0.7270322499829184, 'language': ''}

似乎我們的數據是在“Windows-1252”中編碼的。那讓我們再讀一遍。它奏效了!

df = pd.read_csv("spam.csv", encoding = 'Windows-1252')
df.head()

?

????????如我們所見,我們實際上并不需要“v1”和“v2”以外的列。此外,如果我們將“v1”和“v2”重命名為“類別”和“消息”,則更容易理解。

df = df.loc[:, ['v1', 'v2']]
df = df.rename(columns={'v1': 'Category', 'v2': 'Message'})
df.head()

?

????????現在,我們應該看看我們的數據集,看看每個類別中有多少條消息。

df['Category'].value_counts()Output: 
ham     4825
spam     747
Name: Category, dtype: int64

四、創建平衡數據集

????????事實證明,正常郵件比垃圾郵件多。構建機器學習模型時,如果數據集不平衡,其中一個類中的數據數量明顯多于另一個類,則可能會對模型的性能產生各種影響。一些潛在的后果。例如:

-1 有偏差模型:如果數據集不平衡,模型可能會偏向多數類,而對少數類表現不佳。這是因為模型更有可能預測多數類,這將導致少數類的準確性較差。

-2 泛化不良:不平衡的數據集可能導致模型泛化不良。這是因為該模型將在不代表數據真實世界分布的數據集上進行訓練,因此它可能無法很好地概括看不見的數據。

-3?評估不準確:如果使用準確性作為指標評估模型,則可能會產生誤導性結果。例如,始終預測不平衡數據集中多數類的模型可能具有很高的準確性,但對少數類沒有用。

-4 過擬合:由于數據點數量較多,模型可能會過度擬合多數類,從而導致測試數據的性能不佳。

為了解決這些問題,可以使用各種技術來平衡數據集,例如對少數類進行過采樣,對多數類進行欠采樣,或同時使用兩者的組合。在這篇文章中,我將使用欠采樣方法。

df_spam = df[df['Category']=='spam']
df_ham = df[df['Category']=='ham']
df_ham_downsampled = df_ham.sample(df_spam.shape[0])
df_balanced = pd.concat([df_ham_downsampled, df_spam])
df_balanced['Category'].value_counts()Output: 
ham     747
spam    747
Name: Category, dtype: int64

五、標記數據

????????當數據表示為數字而不是分類為用于訓練和測試的模型時,機器學習算法在準確性和其他性能指標方面表現更好。我們需要用數值對分類值進行標簽編碼。在這里,我們創建了一個新列“標簽”,如果郵件是垃圾郵件,我們將其標記為 1,否則為 0。

df_balanced['Label']=df_balanced['Category'].apply(lambda x: 1 if x=='spam' else 0)
df_balanced = df_balanced.reset_index(drop=True)display(df_balanced)

?

由作者創建

六、訓練、驗證和測試數據集:誰是誰

????????要記住的一件事是,當我們使用 train_test_split 庫來訓練模型時,我們實際上是將數據集拆分為 TRAINING 數據集和 VALIDATION 數據集,而不是 TRAINING 數據集和 TESTING 數據集。下面提醒一下這些數據集的含義。

  1. 訓練集:用于構建我們的模型。我們將使用訓練集來找到具有反向傳播規則的“最佳”權重和偏差。在此階段,我們通常會創建多個算法,以便在交叉驗證階段比較它們的性能。
  2. 交叉驗證集:此數據集用于比較基于訓練集創建的預測算法的性能。我們選擇性能最佳的算法。
  3. 測試集:這是“未來”數據集。現在我們已經選擇了我們喜歡的預測算法,但我們還不知道它將如何在完全看不見的真實世界數據上執行。因此,我們將我們選擇的預測算法應用于我們的測試集,以查看它將如何執行,以便我們可以了解我們的算法在野外的性能。

????????因此,在測試集中,我們沒有數據的標簽,而是使用我們的模型來預測標簽。我們只能將手頭的數據集拆分為訓練集和驗證集,因為我們還沒有“未來”數據。

七、拆分為訓練數據集和驗證數據集

????????現在我們了解了這三種類型的數據的真正含義,我們可以使用scikit-learn的train_test_split來拆分數據。

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)X_train.head()Output: 
708                      ;-) ok. I feel like john lennon.
1386    Cashbin.co.uk (Get lots of cash this weekend!)...
1492    REMINDER FROM O2: To get 2.50 pounds free call...
119     Back in brum! Thanks for putting us up and kee...
89                       Sorry, I can't help you on this.
Name: Message, dtype: object

八、總結

????????我們已經學會了如何下載和拆分數據。在下一篇文章中,我們將首先對其進行標記,并使用DistilBERT訓練分類器。達門·

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/39627.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/39627.shtml
英文地址,請注明出處:http://en.pswp.cn/news/39627.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

3.1 Qt樣式選擇器

本期內容 3.1 樣式選擇器 3.1.1 Universal Selector (通用選擇器) 3.1.2 Type Selector (類型選擇器) 3.1.3 Property Selector (屬性選擇器) 3.1.4 Class Selector (類選擇器) 3.1.5 ID Selector (ID選擇器) 3.1.6 Descendant Selector (后裔選擇器) 3.1.7 Chil…

前端跨域的原因以及解決方案(vue),一文讓你真正理解跨域

跨域這個問題,可以說是前端的必需了解的,但是多少人是知其然不知所以然呢? 下面我們來梳理一下vue解決跨域的思路。 什么情況會跨域? ? 跨域的本質就是瀏覽器基于同源策略的一種安全手段。所謂同源就是必須有以下三個相同點:協議相同、域名…

WinCC V7.5 中的C腳本對話框不可見,將編輯窗口移動到可見區域的具體方法

WinCC V7.5 中的C腳本對話框不可見,將編輯窗口移動到可見區域的具體方法 由于 Windows 系統更新或使用不同的顯示器,在配置C動作時,有可能會出現C腳本編輯窗口被移動到不可見區域的現象。 由于該窗口無法被關閉,故無法進行進一步…

KafkaStream:Springboot中集成

1、在kafka-demo中創建配置類 配置kafka參數 package com.heima.kafkademo.config;import lombok.Data; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.StreamsConfig; import org.springframework.boot.context.properties.Configu…

8月11日上課內容 nginx的多實例和動靜分離

多實例部署 在一臺服務器上有多個tomcat的服務。 配置多實例之前,看單個實例是否訪問正常。 1.安裝好 jdk 2.安裝 tomcat cd /opt tar zxvf apache-tomcat-9.0.16.tar.gz mkdir /usr/local/tomcat mv apache-tomcat-9.0.16 /usr/local/tomcat/tomcat1 cp -a /u…

Linux系統管理:虛擬機ESXi安裝

目錄 一、理論 1.VMware Workstation 2.VMware vSphere Client 3.ESXi 二、實驗 1.ESXi 7安裝 一、理論 1.VMware Workstation 它是一款專業的虛擬機軟件,可以在一臺物理機上運行多個操作系統,支持Windows、Linux等操作系統,可以模擬…

使用selenium如何實現自動登錄

回顧使用requests如何實現自動登錄一文中,提到好多網站在我們登錄過后,在之后的某段時間內訪問該網頁時,不會給出請登錄的提示,時間到期后就會提示請登錄!這樣在使用爬蟲訪問網頁時還要登錄,打亂我們的節奏…

item_get_sales-獲取商品銷量詳情

一、接口參數說明: item_get_sales-獲取商品銷量詳情,點擊更多API調試,請移步注冊API賬號點擊獲取測試key和secret 公共參數 請求地址: https://api-gw.onebound.cn/taobao/item_get_sales 名稱類型必須描述keyString是調用key&#xff08…

ACM模式刷Leetcode題目

139題單詞拆分 鏈接: link #include<iostream> #include<sstream> #include<string> #include<vector> #include<algorithm> #include<unordered_set> using namespace std;int main() {// 實現輸入第一行為s字符串。// 第二行為wordDic…

【代碼隨想錄day22】爬樓梯

題目 假設你正在爬樓梯。需要 n 階你才能到達樓頂。 每次你可以爬 1 或 2 個臺階。你有多少種不同的方法可以爬到樓頂呢&#xff1f; 示例 1&#xff1a; 輸入&#xff1a;n 2 輸出&#xff1a;2 解釋&#xff1a;有兩種方法可以爬到樓頂。 1. 1 階 1 階 2. 2 階 示例 2…

Spring的三種異常處理方式

1.SpringMVC 異常的處理流程 異常分為編譯時異常和運行時異常&#xff0c;編譯時異常我們 try-cache 進行捕獲&#xff0c;捕獲后自行處理&#xff0c;而運行時異常是不 可預期的&#xff0c;就需要規范編碼來避免&#xff0c;在SpringMVC 中&#xff0c;不管是編譯異常還是運行…

java:JDBC

文章目錄 什么是JDBCJDBC使用步驟詳解各個對象DriverManagerConnectionStatementResultSetPreparedStatement JDBC控制事務操作步驟示例 什么是JDBC 我們知道&#xff0c;數據庫有很多種&#xff0c;比如 mysql&#xff0c;Oracle&#xff0c;DB2等等&#xff0c;如果每一種數…

C# WPF 中 外部圖標引入iconfont,無法正常顯示問題 【小白記錄】

wpf iconfont 外部圖標引入&#xff0c;無法正常顯示問題。 1. 檢查資源路徑和引入格式是否正確2. 檢查資源是否包含在程序集中 1. 檢查資源路徑和引入格式是否正確 正確的格式&#xff0c;注意字體文件 “xxxx.ttf” 應寫為 “#xxxx” <TextBlock Text"&#xe7ae;…

不重啟Docker能添加自簽SSL證書鏡像倉庫嗎?

應用背景 在企業應用Docker規劃初期配置非安全鏡像倉庫時&#xff0c;有時會遺漏一些倉庫沒配置&#xff0c;但此時應用程序已經在Docker平臺上部署起來了&#xff0c;體量越大就越不會讓人去直接重啟Docker。 那么&#xff0c;不重啟Docker能添加自簽SSL證書鏡像倉庫嗎&…

經典人體模型SMPL介紹(一)

SMPL是馬普所提出的經典人體模型&#xff0c;目前已成為姿態估計、人體重建等領域必不可少的基礎先驗。SMPL基于蒙皮和BlendShape實現&#xff0c;從數千個三維人體掃描結果得來&#xff0c;后通過PCA統計學習得來。 論文&#xff1a;SMPL: A Skinned Multi-Person Linear Mode…

Python讀取svn版本

本文將詳細介紹如何使用Python讀取svn版本。 一、安裝svn庫 首先&#xff0c;我們需要使用Python來連接svn服務器&#xff0c;并獲取版本號。這里我們使用pysvn庫來完成這個工作。 pip install pysvn需要注意的是&#xff0c;如果你需要安裝特定版本的pysvn&#xff0c;你可…

2023連鎖收銀系統該如何選?值得推薦的5款連鎖收銀系統

現在不管是連鎖店還是零售店&#xff0c;只要是開店做生意賺錢的&#xff0c;都少不了要和錢打交道&#xff0c;尤其是對連鎖店來說&#xff0c;收銀工作更是重中之重。 連鎖店涉及的門店較多&#xff0c;必須要有一套足夠優秀的連鎖收銀系統&#xff0c;才能做好每個門店的收銀…

【ARM 嵌入式 編譯系列 5 -- GCC 內建函數 __builtin 詳細介紹】

文章目錄 什么是GCC內建函數?GCC 常見內建函數GCC內建函數使用示例上篇文章:ARM 嵌入式 編譯系列 4.2 – GCC 鏈接規范 extern “C“ 介紹 下篇文章:ARM 嵌入式 編譯系列 6 – GCC objcopy, objdump, readelf, nm 介紹 什么是GCC內建函數? GCC提供了一些專門的功能,用于…

使用 `tailwindcss-patch@2` 來提取你的類名吧

使用 tailwindcss-patch2 來提取你的類名吧 使用 tailwindcss-patch2 來提取你的類名吧 安裝使用方式 命令行 Cli 開始提取吧 Nodejs API 的方式來使用 配置 初始化 What’s next? tailwindcss-patch 是一個 tailwindcss 生態的擴展項目。也是 tailwindcss-mangle 項目重要…

redis的Key的過期策略是如何實現的?

Key的過期策略 一個redis中可能同時存在很多很多key&#xff0c;這些key可能有很大一部分都有過期時間&#xff0c;此時&#xff0c;redis服務器咋知道哪些key已經過期要被刪除&#xff0c;哪些key還沒有過期&#xff1f; 如果直接遍歷所有的key&#xff0c;顯然是行不通的&am…