【MindSpore學習打卡】應用實踐-自然語言處理-基于RNN的情感分類:使用MindSpore實現IMDB影評分類

情感分類是自然語言處理(NLP)中的一個經典任務,廣泛應用于社交媒體分析、市場調研和客戶反饋等領域。本篇博客將帶領大家使用MindSpore框架,基于RNN(循環神經網絡)實現一個情感分類模型。我們將詳細介紹數據準備、模型構建、訓練與評估等步驟,并最終實現對自定義輸入的情感預測。

數據準備

下載和加載數據集

  • 為什么要使用requeststqdm庫? requests庫提供了簡潔的HTTP請求接口,方便我們從網絡上下載數據。tqdm庫則用于顯示下載進度條,幫助我們實時了解下載進度,提高用戶體驗。
  • 為什么要使用臨時文件和shutil庫? 使用臨時文件可以確保下載的數據在出現意外中斷時不會影響最終保存的文件。shutil庫提供了高效的文件操作方法,確保數據能正確地從臨時文件復制到最終保存路徑。

我們使用IMDB影評數據集,這是一個經典的情感分類數據集,包含積極(Positive)和消極(Negative)兩類影評。為了方便數據下載和處理,我們首先設計一個數據下載模塊。

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路徑
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_path

下載IMDB數據集并進行解壓和加載:

import re
import six
import string
import tarfileclass IMDBData():label_map = {"pos": 1, "neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)

加載訓練數據集進行測試,輸出數據集數量:

imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_train = IMDBData(imdb_path, 'train')
print(len(imdb_train))

加載預訓練詞向量

為什么要使用預訓練詞向量? 預訓練詞向量(如Glove)是基于大規模語料庫訓練得到的,能夠捕捉到豐富的語義信息。使用預訓練詞向量可以提高模型的性能,尤其是在數據量較小的情況下。

  • 為什么要添加<unk><pad>標記符? <unk>標記符用于處理詞匯表中未出現的單詞,避免模型在遇到新詞時無法處理。<pad>標記符用于填充不同長度的文本序列,使其能夠被打包為一個batch進行并行計算,提高訓練效率。

我們使用Glove預訓練詞向量對文本進行編碼。以下是加載Glove詞向量的代碼:

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))embeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddingsglove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
print(len(vocab.vocab()))

數據集預處理

為什么要對文本進行分詞和去除標點? 分詞和去除標點是文本預處理的基礎步驟,有助于模型更好地理解文本的語義。將文本轉為小寫可以減少詞匯表的大小,提高模型的泛化能力。

對加載的數據集進行預處理,包括將文本轉為index id序列,并進行序列填充。

import mindspore as ms
import mindspore.dataset as dslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
imdb_valid = imdb_valid.batch(64, drop_remainder=True)

模型構建

接下來,我們構建用于情感分類的RNN模型。模型主要包括以下幾層:

  1. Embedding層:將單詞的index id轉為詞向量。
  2. RNN層:使用LSTM進行特征提取。
  3. 全連接層:將LSTM的輸出特征映射到分類結果。

為什么選擇LSTM而不是經典RNN? LSTM通過引入門控機制,有效地緩解了經典RNN中存在的梯度消失問題,能夠更好地捕捉長距離依賴信息,提高模型的效果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers, bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

損失函數與優化器

我們使用二分類交叉熵損失函數nn.BCEWithLogitsLoss和Adam優化器。

為什么使用二分類交叉熵損失函數和Adam優化器? 二分類交叉熵損失函數適用于二分類任務,能夠衡量模型預測結果與真實標簽之間的差異。Adam優化器結合了動量和自適應學習率的優點,具有較快的收斂速度和較好的效果,廣泛應用于深度學習模型的訓練。

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

訓練邏輯

設計訓練一個epoch的函數,用于訓練過程和loss的可視化。

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

評估指標和邏輯

設計評估邏輯,計算模型在驗證集上的準確率。

def binary_accuracy(preds, y):rounded_preds = np.around(ops.sigmoid(preds).asnumpy())correct = (rounded_preds == y).astype(np.float32)acc = correct.sum() / len(correct)return accdef evaluate(model, test_dataset, criterion, epoch=0):total = test_dataset.get_dataset_size()epoch_loss = 0epoch_acc = 0step_total = 0model.set_train(False)with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in test_dataset.create_tuple_iterator():predictions = model(i[0])loss = criterion(predictions, i[1])epoch_loss += loss.asnumpy()acc = binary_accuracy(predictions, i[1])epoch_acc += accstep_total += 1t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)t.update(1)return epoch_loss / total

模型訓練與保存

進行模型訓練,并保存最優模型。

num_epochs = 2
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')for epoch in range(num_epochs):train_one_epoch(model, imdb_train, epoch)valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)if valid_loss < best_valid_loss:best_valid_loss = valid_lossms.save_checkpoint(model, ckpt_file_name)

模型加載與測試

加載已保存的最優模型,并在測試集上進行評估。

param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)
imdb_test = imdb_test.batch(64)
evaluate(model, imdb_test, loss_fn)

自定義輸入測試

設計一個預測函數,實現對自定義輸入的情感預測。

score_map = {1: "Positive",0: "Negative"
}def predict_sentiment(model, vocab, sentence):model.set_train(False)tokenized = sentence.lower().split()indexed = vocab.tokens_to_ids(tokenized)tensor = ms.Tensor(indexed, ms.int32)tensor = tensor.expand_dims(0)prediction = model(tensor)

通過本文的學習,我們成功地使用MindSpore框架實現了一個基于RNN的情感分類模型。我們從數據準備開始,詳細講解了如何加載和處理IMDB影評數據集,以及使用預訓練的Glove詞向量對文本進行編碼。然后,我們構建了一個包含Embedding層、LSTM層和全連接層的情感分類模型,并使用二分類交叉熵損失函數和Adam優化器進行訓練。最后,我們評估了模型在測試集上的性能,并實現了對自定義輸入的情感預測。希望這篇博客能幫助你更好地理解RNN在自然語言處理中的應用,并激發你在NLP領域的更多探索和實踐。

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/41027.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/41027.shtml
英文地址,請注明出處:http://en.pswp.cn/web/41027.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

X86和ARM架構的服務器 的區別

X86和ARM架構的服務器各有其優缺點,并適用于不同的應用場景。 一、X86架構服務器的優缺點及應用場景: 優點: 1. 易于獲取和成本較低:X86服務器在市場上品牌和型號眾多,價格相對較低,適合中小型企業。 2. 處理能力強大:X86服務器通常具有強大的處理器性能,支持多核心…

CLIP-EBC:通過增強的逐塊分類,CLIP能夠準確計數

摘要 https://arxiv.org/pdf/2403.09281v1 CLIP&#xff08;Contrastive Language-Image Pretraining&#xff0c;對比語言-圖像預訓練&#xff09;模型在識別問題中表現出了卓越的性能&#xff0c;如零樣本圖像分類和對象檢測。然而&#xff0c;由于其固有的挑戰——即將計數…

Nettyの參數優化簡單RPC框架實現

本篇介紹Netty調優&#xff0c;在上篇聊天室的案例中進行改造&#xff0c;手寫一個簡單的RPC實現。 1、超時時間參數 CONNECT_TIMEOUT_MILLIS 是Netty的超時時間參數&#xff0c;屬于客戶端SocketChannel的參數&#xff0c;客戶端連接時如果一定時間沒有連接上&#xff0c;就會…

Spring Cloud 是什么?(Spring Cloud 組件介紹)

什么是 Spring Cloud&#xff1f; Spring Cloud 是微服務系統架構的一站式解決方案&#xff0c;是各個微服務架構落地技術的集合體&#xff0c;讓架構師、 開發者在使用微服務理念構建應用系統的時候&#xff0c; 面對各個環節的問題都可以找到相應的組件來處理&#xff0c;比…

二叉樹的遍歷算法:前序、中序與后序遍歷

在數據結構與算法中&#xff0c;二叉樹的遍歷是基礎且重要的操作之一&#xff0c;它允許我們按照某種順序訪問樹中的每個節點。常見的二叉樹遍歷方式有前序遍歷&#xff08;Preorder Traversal&#xff09;、中序遍歷&#xff08;Inorder Traversal&#xff09;和后序遍歷&…

React 19 競態問題解決

競態問題/競態條件 指的是&#xff0c;當我們在交互過程中&#xff0c;由于各種原因導致同一個接口短時間之內連續發送請求&#xff0c;后發送的請求有可能先得到請求結果&#xff0c;從而導致數據渲染出現預期之外的錯誤。 因為防止重復執行可以有效的解決競態問題&#xff0…

聊天廣場(Vue+WebSocket+SpringBoot)

由于心血來潮想要做個聊天室項目 &#xff0c;但是仔細找了一下相關教程&#xff0c;卻發現這么多的WebSocket教程里面&#xff0c;很多都沒有介紹詳細&#xff0c;代碼都有所殘缺&#xff0c;所以這次帶來一個比較完整得使用WebSocket的項目。 目錄 一、效果展示 二、準備工…

html+css+js圖片手動輪播

源代碼在界面圖片后面 輪播演示用的幾張圖片是Bing上的&#xff0c;直接用的幾張圖片的URL&#xff0c;誰加載可能需要等一下&#xff0c;現實中替換成自己的圖片即可 關注一下點個贊吧&#x1f604; 謝謝大佬 界面圖片 源代碼 <!DOCTYPE html> <html lang&quo…

內存對齊宏ALIGN的理解

內存對齊宏ALIGN的理解 在Android Camera HAL代碼中經常看到ALIGN這個宏&#xff0c;主要用來進行內存對齊&#xff0c;下面是v4l2_wrapper.cpp中ALIGN的一些定義 //v4l2_wrapper.cpp中內存分配進行對其的操作和定義#define ALIGN( num, to ) (((num) (to-1)) & (~(to-1)…

【Android】自定義換膚框架03之自定義LayoutInflaterFactory

AppCompatActivity是如何創建View的 Activity通過LayoutInflater解析出XmlLayout相關信息LayoutInflater內部維護了一個InflaterFactory對象InflaterFactory接口包含了一個onCreateView方法&#xff0c;用于創建View將解析出的Xml信息轉為AttributeSet&#xff0c;交給Inflate…

安全測試之使用Docker搭建SQL注入安全測試平臺sqli-labs

1 搜索鏡像 docker search sqli-labs 2 拉取鏡像 docker pull acgpiano/sqli-labs 3 創建docker容器 docker run -d --name sqli-labs -p 10012:80 acgpiano/sqli-labs 4 訪問測試平臺網站 若直接使用虛擬機&#xff0c;則直接通過ip端口號訪問若通過配置域名&#xff0…

PyQt多線程詳解

PyQt多線程是在PyQt框架中利用多線程技術來提高應用程序的響應性和性能的一種方法。下面將詳細說明PyQt多線程的基本概念、應用場景以及實現方式。 一、PyQt多線程的基本概念 在PyQt中&#xff0c;多線程指的是在單個程序實例內同時運行多個線程。每個線程都可以執行不同的任…

第十五章 Nest Pipe(內置及自定義)

NestJS的Pipe是一個用于數據轉換和驗證的特殊裝飾器。Pipe可以應用于控制器&#xff08;Controller&#xff09;的處理方法&#xff08;Handler&#xff09;和中間件&#xff08;Middleware&#xff09;&#xff0c;用于處理傳入的數據。它可以用來轉換和驗證數據&#xff0c;確…

【Linux進階】文件系統5——ext2文件系統(inode)

1.再談inode (1) 理解inode&#xff0c;要從文件儲存說起。 文件儲存在硬盤上&#xff0c;硬盤的最小存儲單位叫做"扇區"&#xff08;Sector&#xff09;。每個扇區儲存512字節&#xff08;相當于0.5KB&#xff09;。操作系統讀取硬盤的時候&#xff0c;不會一個個…

記錄excel表生成一列按七天一個周期的方法

使用excel生成每七天一個周期的列。如下圖所示&#xff1a; 針對第一列的生成辦法&#xff0c;使用如下函數&#xff1a; TEXT(DATE(2024,1,1)(ROW()-2)*7,"yyyy/m/d")&" - "&TEXT(DATE(2024,1,1)(ROW()-1)*7-1,"yyyy/m/d") 特此記錄。…

charles使用教程

安裝與配置 下載鏈接&#xff1a;https://www.charlesproxy.com/download/ 進行移動端抓包&#xff1a; 電腦端配置&#xff1a; 關閉防火墻 Proxy–>勾選 macOS Proxy Proxy–>Proxy Setting–>填入代理端口8888–>勾選Enable transparent http proxying 安裝c…

俄羅斯方塊的python實現

俄羅斯方塊游戲是一種經典的拼圖游戲&#xff0c;玩家需要將不同形狀的方塊拼接在一起&#xff0c;使得每一行都被完全填滿&#xff0c;從而清除這一行并獲得積分。以下是該游戲的算法描述&#xff1a; 1. 初始化 初始化游戲界面&#xff0c;設置屏幕大小、方塊大小、網格大小…

昇思25天學習打卡營第1天|初識MindSpore

# 打卡 day1 目錄 # 打卡 day1 初識MindSpore 昇思 MindSpore 是什么&#xff1f; 昇思 MindSpore 優勢|特點 昇思 MindSpore 不足 官方生態學習地址 初識MindSpore 昇思 MindSpore 是什么&#xff1f; 昇思MindSpore 是全場景深度學習架構&#xff0c;為開發者提供了全…

女生學計算機好不好?感覺計算機分有點高……?

眾所周知&#xff0c;在國內的高校里&#xff0c;計算機專業的女生是非常少的&#xff0c;很多小班30人左右&#xff0c;但是每個班女生人數只有個位數。這就給很多人一個感覺&#xff0c;是不是女生天生就不適合學這個東西呢&#xff1f;女生是不是也應該放棄呢&#xff1f;當…

ubuntu 進入命令行

在Ubuntu中&#xff0c;有幾種方法可以進入命令行界面&#xff1a; 啟動時選擇命令行模式&#xff1a; 在計算機啟動時&#xff0c;如果安裝了GRUB引導加載器&#xff0c;可以通過GRUB菜單選擇進入命令行模式。這通常涉及到在啟動時按下Shift鍵或其他指定鍵來顯示GRUB菜單&…