CNN文本分類(tensorflow實現)

前言
  • 實現步驟
    • 1.安裝tensorflow
    • 2.導入所需要的tensorflow庫和其它相關模塊
    • 3.設置隨機種子
    • 4.定義模型相關超參數
    • 5.加載需要的數據集
    • 6.對加載的文本內容進行填充和截斷
    • 7.構建自己模型
    • 8.訓練構建的模型
    • 9.評估完成的模型
  • CNN(卷積神經網絡)在文本分類任務中具有良好的特征提取能力、位置不變性、參數共享和處理大規模數據的優勢,能夠有效地學習文本的局部和全局特征,提高模型性能和泛化能力,所以本文將以CNN實現文本分類。
    CNN對文本分類的支持主要提現在:

    特征提取:CNN能夠有效地提取文本中的局部特征。卷積層通過應用多個卷積核來捕獲不同大小的n-gram特征,從而能夠識別關鍵詞、短語和句子結構等重要信息。

    位置不變性:對于文本分類任務,特征的位置通常是不重要的。CNN中的池化層(如全局最大池化)能夠保留特征的最顯著信息,同時忽略其具體位置,這對于處理可變長度的文本輸入非常有幫助。

    參數共享:CNN中的卷積核在整個輸入上共享參數,這意味著相同的特征可以在不同位置進行識別。這種參數共享能夠極大地減少模型的參數量,降低過擬合的風險,并加快模型的訓練速度。

    處理大規模數據:CNN可以高效地處理大規模的文本數據。由于卷積和池化操作的局部性質,CNN在處理文本序列時具有較小的計算復雜度和內存消耗,使得它能夠適應大規模的文本分類任務。

    上下文建模:通過使用多個卷積核和不同的大小,CNN可以捕捉不同尺度的上下文信息。這有助于提高模型對文本的理解能力,并能夠捕捉更長范圍的依賴關系。

    實現步驟之前首先安裝完成tensorflow
  • 使用這個代碼安裝的前提是你的深度學習已經環境存在
  • 例如:conda、pytorch、cuda、cudnn等環境
  • conda create -n tf python=3.8
    conda activate tf
    #tensorflow的安裝
    pip install tensorflow-gpu -i https://pypi.douban.com/simple
    

    一. 測試tensorflow是否安裝成功

  • 有三種方法

  • 方法一:

import tensorflow as tf 
print(tf.__version__)
#輸出'2.0.0-alpha0'
print(tf.test.is_gpu_available())
#會輸出True,則證明安裝成功
#新版本的tf把tf.test.is_gpu_available()換成如下命令
import tensorflow as tf 
tf.config.list_physical_devices('GPU')
  • 方法二:
  • import tensorflow as tf 
    with tf.device('/GPU:0'):a = tf.constant(3)
    

    方法三:

  • #輸入python,進入python環境
    import tensorflow as tf
    #查看tensorflow版本
    print(tf.__version__)
    #輸出'2.0.0-alpha0'
    #測試GPU能否調用,先查看顯卡使用情況
    import os 
    os.system("nvidia-smi")
    #調用顯卡
    @tf.function
    def f():pass
    f()
    #這時會打印好多日志
    #再次查詢顯卡
    os.system("nvidia-smi")
    可以對比兩次使用情況
    

    二、打開pycharm倒入你創建的tf環境,新建py文件開始構建代碼

  • 1.導入所需的庫和模塊:

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D

其中提前安裝TensorFlow來用于構建和訓練模型,以及Keras中的各種層和模型類

2.設置隨機種子:

np.random.seed(42)

在CNN(卷積神經網絡)中設置隨機種子主要是為了保證實驗的可重復性。由于深度學習模型中涉及大量的隨機性,如權重的初始化、數據的打亂(shuffle)等,設置隨機種子可以使得每次實驗的隨機過程都保持一致,從而使得實驗結果可以復現

3.定義模型超參數:

max_features = 5000  # 詞匯表大小
max_length = 100  # 文本最大長度
embedding_dims = 50  # 詞嵌入維度
filters = 250  # 卷積核數量
kernel_size = 3  # 卷積核大小
hidden_dims = 250  # 全連接層神經元數量
batch_size = 32  # 批處理大小
epochs = 5  # 訓練迭代次數

超參數影響模型的結構和訓練過程,可自行調整。

4.加載數據集:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)

示例中,使用的IMDB電影評論數據集,其中包含以數字表示的評論文本和相應的情感標簽(正面或負面),使用tf.keras.datasets.imdb.load_data函數可以方便地加載數據集,并指定num_words參數來限制詞匯表的大小。

5.對文本進行填充和截斷:

x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)

由于每條評論的長度可能不同,需要將它們統一到相同的長度。sequence.pad_sequences函數用于在文本序列前后進行填充或截斷,使它們具有相同的長度。

6.構建模型:

model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

這個模型使用Sequential模型類構建,依次添加了嵌入層(Embedding)、卷積層(Conv1D)、全局最大池化層(GlobalMaxPooling1D)和兩個全連接層(Dense)。嵌入層將輸入的整數序列轉換為固定維度的詞嵌入表示,卷積層通過應用多個卷積核來提取特征,全局最大池化層獲取每個特征通道的最大值,而兩個全連接層用于分類任務。

7.編譯模型:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

在編譯模型之前,需要指定損失函數、優化器和評估指標。使用二元交叉熵作為損失函數,Adam優化器進行參數優化,并使用準確率作為評估指標。

8.訓練模型:

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))

使用fit函數對模型進行訓練。需要傳入訓練數據、標簽,批處理大小、訓練迭代次數,并可以指定驗證集進行模型性能評估。

9.評估模型:

scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

使用evaluate函數評估模型在測試集上的性能,計算并打印出測試準確率。

完整代碼

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D# 設置隨機種子
np.random.seed(42)# 定義模型超參數
max_features = 5000  # 詞匯表大小
max_length = 100  # 文本最大長度
embedding_dims = 50  # 詞嵌入維度
filters = 250  # 卷積核數量
kernel_size = 3  # 卷積核大小
hidden_dims = 250  # 全連接層神經元數量
batch_size = 32  # 批處理大小
epochs = 5  # 訓練迭代次數# 加載數據集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)# 對文本進行填充和截斷,使其具有相同的長度
x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)# 構建模型
model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))# 編譯模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# 訓練模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))# 評估模型
scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/715932.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/715932.shtml
英文地址,請注明出處:http://en.pswp.cn/news/715932.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【GPU驅動開發】-mesa簡介

前言 不必害怕未知,無需恐懼犯錯,做一個Creator! 一、mesa介紹 Mesa 是一個開源的3D圖形庫,它實現了多種圖形API,包括 OpenGL、Vulkan 和 OpenCL。Mesa 的目標是提供一個開源、跨平臺的圖形庫,使得開發者…

ABAP - SALV教程08 列設置熱點及綁定點擊事件

實現思路:將列設置成熱點,熱點列是可點擊的,再給SALV實例對象注冊點擊事件即可,一般作用于點擊單號跳轉到前臺等功能 "設置熱點方法METHODS:set_hotspot CHANGING co_alv TYPE REF TO cl_salv_table...."事件處理方法M…

SMART原則

在軟件研發領域,項目管理和目標設定尤為關鍵。一個成功的軟件項目不僅需要先進的技術支持,還需要一個清晰、明確且可實現的目標。SMART原則,作為一種高效的目標設定和管理方法,為軟件研發提供了有力的指導。SMART是五個英文單詞首…

合宙esp32-c3 進入深度睡眠無法喚醒解決一例

手賤,昨天收到了嘉立創最新的esp32 s3,想測試一下電流功耗,于是順便測試了一下以前的合宙esp32 c3 無串口芯片的版本 打算對比一下c3和s3的功耗相差多少,結果把自己玩死了: void setup() {esp_deep_sleep_start();// esp_light_s…

oppo手機備忘錄記錄怎么轉移到華為手機?

oppo手機備忘錄記錄怎么轉移到華為手機?使用oppo手機已經有三四年了,因為平時習慣,在手機系統的備忘錄中記錄了很多重要的筆記,比如工作會議的要點、讀書筆記、購物清單、朋友的生日提醒等。這些記錄對我來說非常重要,我可以通過…

STM32 HAL庫 串口使用問題記錄

文章目錄 STM32 HAL庫 串口使用問題記錄情況一:串口導致程序假死機情況二:其它程序正常運行,串口不再接收數據 STM32 HAL庫 串口使用問題記錄 情況一:串口導致程序假死機 多數應該出現在未開啟DMA模式使用中斷方式接收數據的情況…

鉀是人體內重要的電解質之一

鉀是人體內重要的電解質之一,是維持細胞生理活動的主要陽離子,在保持機體的正常滲透壓及酸堿平衡,維持內環境的穩定性,參與糖及蛋白質代謝,保證神經肌肉的正常功能,在興奮性等方面具有重要的作用。人體內的…

2000-2021年300+地級市進出口總額數據

2000-2021年300地級市進出口總額數據 1、時間:2000-2021年 2、指標:進出口總額 3、單位:萬美元 4、來源:城市年鑒、各省年鑒、城市公報、2021年為城市統計年鑒中進口額出口額加總之后換算成萬美元,已盡最大可能進行…

20240303

1.在優勢、劣勢、機會與威脅(SWOT)的分析期間,團隊發現另一個項目通過與該團隊合作可能從規模經濟中獲益。兩個項目的成本都可能大幅降低,并可能實現公司的利益,項目經理應該怎么做? A.在風險登記冊中記錄該發現 B.詢問項目發起人的意見 …

1.億級積分數據分庫分表:總體方案設計

項目背景 以一個積分系統為例,積分系統最核心的有積分賬戶表和積分明細表: 積分賬戶表:每個用戶在一個品牌下有一個積分賬戶記錄,記錄了用戶的積分余額,數據量在千萬級積分明細表:用戶每次積分發放、積分扣…

數據結構——Top-k問題

Top-k問題 方法一:堆排序(升序)(時間復雜度O(N*logN))向上調整建堆(時間復雜度:O(N * logN) )向下調整建堆(時間復雜度:O(N) )堆排序代碼 方法二&…

LeetCode---386周賽

題目列表 3046. 分割數組 3047. 求交集區域內的最大正方形面積 3048. 標記所有下標的最早秒數 I 3049. 標記所有下標的最早秒數 II 一、分割數組 這題簡單的思維題,要想將數組分為兩個數組,且分出的兩個數組中數字不會重復,很顯然一個數…

Redis 的哨兵模式配置

1.配置 vim sentinel.conf# mymaster 給主機起的名字 # 192.168.205.128 主機的ip地址 # 6379 端口號 # 2 當幾個哨兵發現主觀宕機,則判定為客觀宕機。 原則上是大于一半。比如三個哨兵,則設置為 2 sentinel monitor mymaster 192.168.205.128 63…

【動態規劃入門】01背包問題

每日一道算法題之01背包問題 一、題目描述二、思路三、C++代碼四、結語一、題目描述 題目來源:Acwing 有N件物品和一個容量是 V的背包。每件物品只能使用一次。第 i件物品的體積是 vi,價值是 wi。 求解將哪些物品裝入背包,可使這些物品的總體積不超過背包容量,且總價值最大…

LeetCode題練習與總結:合并K個升序鏈表

一、題目 給你一個鏈表數組,每個鏈表都已經按升序排列。 請你將所有鏈表合并到一個升序鏈表中,返回合并后的鏈表。 二、解題思路 創建一個最小堆(優先隊列)來存儲所有鏈表的頭節點。這樣我們可以始終取出當前所有鏈表中值最小…

人工智能指數報告2023

人工智能指數報告2023 主要要點第 1 章 研究與開發第 2 章 技術性能第 3 章 人工智能技術倫理第 4 章 經濟第 5 章 教育第 6 章 政策與治理第 7 章 多樣性第 8 章 輿論 人工智能指數是斯坦福大學以人為本的人工智能研究所(HAI)的一項獨立倡議&#xff0c…

Java 石頭剪刀布小游戲

一、任務 編寫一個剪刀石頭布游戲的程序。程序啟動后會隨機生成1~3的隨機數,分別代表剪刀、石頭和布,玩家通過鍵盤輸入剪刀、石頭和布與電腦進行5輪的游戲,贏的次數多的一方為贏家。若五局皆為平局,則最終結果判為平局。 二、實…

redis 為什么會阻塞

目錄 前言 客戶端交換時的阻塞 redis 磁盤交換的阻塞 主從節點交互的阻塞 切片集群交互時的阻塞 異步執行的演變 redis 異步執行如何實現的 前言 大家對redis 比較熟悉吧,只要做項目都會用到redis,提高系統的吞吐。小米商城搶購高峰18k的qps&…

KubeSphere平臺安裝系列之三【Linux多節點部署KubeSphere】(3/3)

**《KubeSphere平臺安裝系列》** 【Kubernetes上安裝KubeSphere(親測–實操完整版)】(1/3) 【Linux單節點部署KubeSphere】(2/3) 【Linux多節點部署KubeSphere】(3/3) **《KubeS…

一句話講清楚數據庫中事務的隔離級別(通俗易懂版)

為什么我只說通俗易懂版不說嚴謹版? 因為嚴謹版遍地都是, 但是他們卻有一個缺點就是讓人看得云里霧里, 所以這就是我寫通俗易懂版的初衷! 但是既然是通俗易懂版就必然有缺陷, 只為了各位在開發過程中頭腦更加清晰, 如有錯誤還望兄弟們不吝賜教! 在MySQL數據庫中,事務一共有4…