昇思訓練營打卡第二十五天(RNN實現情感分類)

RNN,即循環神經網絡(Recurrent Neural Network),是一種深度學習模型,特別適用于處理序列數據。以下是對RNN的簡要介紹:

RNN的特點:

  1. 記憶性:與傳統的前饋神經網絡不同,RNN具有內部狀態(記憶),可以捕獲到目前為止觀察到的序列信息。
  2. 參數共享:在處理序列的不同時間步時,RNN使用相同的權重,這意味著模型的參數數量不會隨著輸入序列長度的增加而增加。
  3. 靈活性:RNN能夠處理任意長度的輸入序列。

RNN的結構:

  • 輸入層:接收序列中的單個元素。
  • 隱藏層:包含循環單元,這些單元具有記憶功能,能夠存儲之前的信息。
  • 輸出層:根據當前輸入和隱藏層的狀態輸出結果。

RNN的類型:

  1. 簡單RNN:基礎模型,但容易受到梯度消失和梯度爆炸問題的影響。
  2. LSTM(長短期記憶網絡):通過引入門控機制,解決了簡單RNN的長期依賴問題。
  3. GRU(門控循環單元):LSTM的變體,結構更簡單,但性能相似。

應用場景:

  • 自然語言處理:如語言模型、機器翻譯、文本生成等。
  • 語音識別:將語音信號轉換為文本。
  • 時間序列預測:如股票價格預測、天氣預報等。

數據下載模塊

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路徑為 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):"""使用requests庫下載數據,并使用tqdm庫進行流程可視化"""req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):"""下載數據并存為指定名稱"""if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_pathimdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

加載IMDB數據集

import re
import six
import string
import tarfileclass IMDBData():"""IMDB數據集加載器加載IMDB數據集并處理為一個Python迭代對象。"""label_map = {"pos": 1,"neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))# 將數據加載至內存with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):# 對文本進行分詞、去除標點和特殊字符、小寫處理self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)
imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
import mindspore.dataset as dsdef load_imdb(imdb_path):imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)return imdb_train, imdb_test
imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train

加載預訓練詞向量

預訓練詞向量是對輸入單詞的數值化表示,通過nn.Embedding層,采用查表的方式,輸入單詞對應詞表中的index,獲得對應的表達向量。 因此進行模型構造前,需要將Embedding層所需的詞向量和詞表進行構造。

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))# 添加 <unk>, <pad> 兩個特殊占位符對應的embeddingembeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddings
glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding

數據集預處理

通過加載器加載的IMDB數據集進行了分詞處理,但不滿足構造訓練數據的需要,因此要對其進行額外的預處理。其中包含的預處理如下:

  • 通過Vocab將所有的Token處理為index id。
  • 將文本序列統一長度,不足的使用<pad>補齊,超出的進行截斷。
  • import mindspore as mslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
    type_cast_op = ds.transforms.TypeCast(ms.float32)
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    Embedding

    Embedding層又可稱為EmbeddingLookup層,其作用是使用index id對權重矩陣對應id的向量進行查找,當輸入為一個由index id組成的序列時,則查找并返回一個相同長度的矩陣

  • RNN(循環神經網絡)

    循環神經網絡(Recurrent Neural Network, RNN)是一類以序列(sequence)數據為輸入,在序列的演進方向進行遞歸(recursion)且所有節點(循環單元)按鏈式連接的神經網絡。

Dense

在經過LSTM編碼獲取句子特征后,將其送入一個全連接層,即nn.Dense,將特征維度變換為二分類所需的維度1,經過Dense層后的輸出即為模型預測結果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/45596.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/45596.shtml
英文地址,請注明出處:http://en.pswp.cn/web/45596.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

老板新招的牛人,竟然用1天搭建了一套完整的倉庫管理系統!

倉儲管理系統是什么&#xff1f; 倉儲管理系統&#xff08;WMS&#xff09;是一個全面的軟件解決方案&#xff0c;旨在幫助企業優化倉庫管理流程、管理和控制日常倉庫運營。通過數學模型和信息手段&#xff0c;對倉庫管理的各個環節進行優化和調控&#xff0c;涵蓋了從貨物入庫…

使用網關和Spring Security進行認證和授權

個人名片 &#x1f393;作者簡介&#xff1a;java領域優質創作者 &#x1f310;個人主頁&#xff1a;碼農阿豪 &#x1f4de;工作室&#xff1a;新空間代碼工作室&#xff08;提供各種軟件服務&#xff09; &#x1f48c;個人郵箱&#xff1a;[2435024119qq.com] &#x1f4f1…

jquery發送jsonp請求

使用 jQuery 發送 JSONP 請求相對來說比較簡單&#xff0c;以下是示例代碼&#xff1a; $.ajax({url: "http://example.com/data",dataType: "jsonp",jsonp: "callback",jsonpCallback: "myCallback" }).done(function(response) {//…

Linux命令更新-sort 和 uniq 命令

簡介 sort 和 uniq 都是 Linux 系統中常用的文本處理命令。 sort 命令用于對文件內容進行排序。 uniq 命令用于去除文件中重復出現的行。 1. sort 命令 命令格式 sort [選項] [文件]選項&#xff1a; -n: 按照數字進行排序 -r: 反向排序 -c: 統計每個元素出現的次數 -…

怎么錄制視頻?電腦錄制,試試這3種方法

在數字化快速發展的時代&#xff0c;視頻已經成為我們傳遞信息、分享生活、表達情感的重要載體。每一個人都希望自己能夠掌握視頻錄制技巧&#xff0c;輕松駕馭影像的力量&#xff0c;創造出屬于自己的視覺盛宴。 那么&#xff0c;怎么錄制視頻呢&#xff1f;首先選擇一款好用…

vue腳手架配置代理請求

在 Vue 腳手架中&#xff0c;可以通過配置vue.config.js文件來設置代理請求&#xff0c;以解決跨域問題或實現其他代理需求。以下是兩種常見的配置方式&#xff1a; 方法一&#xff1a; 在vue.config.js中添加如下配置&#xff1a; module.exports {devServer: {proxy: http…

《信息與電腦(理論版)》是什么級別的期刊?是正規期刊嗎?能評職稱嗎?

問題解答 問&#xff1a;《信息與電腦(理論版)》是不是核心期刊&#xff1f; 答&#xff1a;不是&#xff0c;是知網收錄的正規學術期刊。 問&#xff1a;《信息與電腦(理論版)》級別&#xff1f; 答&#xff1a;省級。主管單位&#xff1a;北京電子控股有限責任公司 主辦…

AI安全入門-人工智能數據與模型安全

參考 人工智能數據與模型安全 from 復旦大學視覺與學習實驗室 文章目錄 0. 計算機安全學術知名公眾號1. 概述數據安全模型安全 3. 人工智能安全基礎3.1 基本概念攻擊者攻擊方法受害者受害數據受害模型防御者防御方法威脅模型目標數據替代數據替代模型 3.2 威脅模型3.2.1 白盒威…

實踐致知第16享:設置Word中某一頁橫著的效果及操作

一、背景需求 小姑電話說&#xff1a;現在有個word文檔,里面有個表格太長&#xff08;如下圖所示&#xff09;&#xff0c;希望這一個設置成橫的&#xff0c;其余頁還是保持豎的&#xff01; 二、解決方案 1、將鼠標放置在該頁的最前面閃爍&#xff0c;然后選擇“頁面”》“↘…

Python面經

文章目錄 Python基本概念1. Python是**解釋型**語言還是**編譯型**語言2. Python是**面向對象**語言還是面向過程語言3. Python基本數據類型4.append和 extend區別5.del、pop和remove區別6. sort和sorted區別介紹一下Python 中的字符串編碼is 和 的區別*arg 和**kwarg作用淺拷…

Electron 進程間通信

文章目錄 渲染進程到主進程&#xff08;單向&#xff09;渲染進程到主進程&#xff08;雙向&#xff09;主進程到渲染進程 &#xff08;單向&#xff0c;可模擬雙向&#xff09; 渲染進程到主進程&#xff08;單向&#xff09; send &#xff08;render 發送&#xff09;on &a…

【Stable Diffusion】(基礎篇三)—— 圖生圖基礎

圖生圖基礎 本系列筆記主要參考B站nenly同學的視頻教程&#xff0c;傳送門&#xff1a;B站第一套系統的AI繪畫課&#xff01;零基礎學會Stable Diffusion&#xff0c;這絕對是你看過的最容易上手的AI繪畫教程 | SD WebUI 保姆級攻略_嗶哩嗶哩_bilibili 本文主要講解如何使用S…

客戶端與服務端之間的通信連接

目錄 那什么是Socket? 什么是ServerSocket? 代碼展示&#xff1a; 代碼解析&#xff1a; 補充&#xff1a; 輸入流&#xff08;InputStream&#xff09;&#xff1a; 輸出流&#xff08;OutputStream&#xff09;&#xff1a; BufferedReader 是如何提高讀取效率的&a…

K8s集群初始化遇到的問題

kubectl describe pod coredns-545d6fc579-s9g5s -n kube-system 找到原因1&#xff1a;CoreDNS Pod 處于 Pending 狀態的原因是集群中的節點都帶有 node.kubernetes.io/not-ready 污點 journalctl -u kubelet -f 14:57:59.178592 3553 remote_image.go:114] "PullIma…

《簡歷寶典》12 - 簡歷中“項目經歷”,內功學習 - 下篇

這一小節呢&#xff0c;我們繼續說簡歷中 “項目經歷” 的一些內功心法。因為項目經歷比較核心&#xff0c;所以說完了&#xff0c;內功呢&#xff0c;我們會著重說一下 實戰部分。 目錄 1 所用技術的考慮 2 自我成長的突出 3 綜合使用STAR法則 4 小節 1 所用技術的考慮 …

如何評估AI模型:評估指標的分類、方法及案例解析

如何評估AI模型&#xff1a;評估指標的分類、方法及案例解析 引言第一部分&#xff1a;評估指標的分類第二部分&#xff1a;評估指標的數學基礎第三部分&#xff1a;評估指標的選擇與應用第四部分&#xff1a;評估指標的局限性第五部分&#xff1a;案例研究第六部分&#xff1a…

pear-admin-fast項目修改為集成PostgreSQL啟動

全局搜索代碼中的sysdate()&#xff0c;修改為now() 【前者是mysql特有的&#xff0c;后者是postgre特有的】修改application-dev.yml中的數據庫url使用DBeaver把mysql中的數據庫表導出csv&#xff0c;再從postgre中導入csv腳本轉換后出現了bpchar(xx)類型&#xff0c;那么一定…

用友U8存貨分類按層級取數SQL語句

SELECT cInvCCode 分類編碼, cInvCName 分類名稱, iInvCGrade 分類層級, ss.bInvCEnd 是否是末級, aa.* FROM InventoryClass ss LEFT JOIN ( SELECT * FROM ( SELECT cInvCCode AS 一級分類編碼, …

python數據可視化(6)——繪制散點圖

課程學習來源&#xff1a;b站up&#xff1a;【螞蟻學python】 【課程鏈接&#xff1a;【【數據可視化】Python數據圖表可視化入門到實戰】】 【課程資料鏈接&#xff1a;【鏈接】】 Python繪制散點圖查看BMI與保險費的關系 散點圖: 用兩組數據構成多個坐標點&#xff0c;考察…

如何降低老年人患帕金森病的風險?

降低老年人患帕金森病風險的方法 避免接觸有害物質&#xff1a;長期接觸某些化學物質、農藥或其他有害物質可能會增加患帕金森病的風險。應減少這些物質的暴露&#xff0c;例如在工作或生活中采取防護措施。 健康飲食&#xff1a;均衡飲食&#xff0c;多吃富含抗氧化劑的食物&a…