基于Pytorch的NLP入門任務思想及代碼實現:判斷文本中是否出現指定字

今天學了第一個基于Pytorch框架的NLP任務:

判斷文本中是否出現指定字

思路:(注意:這是基于字的算法)

任務:判斷文本中是否出現“xyz”,出現其中之一即可

訓練部分:

一,我們要先設計一個模型去訓練數據。
這個Pytorch的模型:
首先通過embedding層:將字符轉化為離散數值(矩陣)
通過線性層:設置網絡的連接層,進行映射
通過dropout層:將一部分輸入設為0(可去掉)
通過激活層:sigmoid激活
通過一個pooling層:降維,將矩陣->向量
通過另一個輸出線性層:使輸出是一維(1或0)
通過一個激活層*:sigmoid激活。

二,設置一個函數:這個函數能將設定的字符變成字符集,將每一個字符設定一個代號,比如說:“我愛你”-> 我:1,愛:2,你:3。當出現"你愛我"時,計算機接受的是:3,2,1。這樣方便計算機處理字符。

三,因為我們沒有訓練樣本和測試樣本,所以我們要自己生成一些隨機樣本。通過random.choice在字符集中隨機順序輸出字符作為輸入,并將輸入中含有"xyz"的樣本的輸出值為“1”,反之為“0”

四,設置一個函數,將隨機得到的樣本,放入數據集中(列表),便于運算。

五,設置測試函數:隨機建立一些樣本,根據樣本的輸出來設定有多少個正樣本,多少個負樣本,再將預測的樣本輸出來與樣本輸出對比,得到正確率。

六,最后的main函數:按照訓練輪數和訓練組數,通過BP反向傳播更新權重進行訓練。然后調取測試函數得到acc等數據。將loss和acc的數值繪制下來,保存模型和詞表。

預測部分

將保存的詞表和模型加載進來,將輸入的字符轉化為列表,然后進入模型forward函數進行預測,最后打印出結果。

代碼實現:

import torch
import torch.nn as nn
import numpy as np
import random
import json
import matplotlib.pyplot as plt"""
基于pytorch的網絡編寫
實現一個網絡完成一個簡單nlp任務
判斷文本中是否有某些特定字符出現
"""class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()self.embedding = nn.Embedding(len(vocab) + 1, input_dim)    #embedding層:將字符轉化為離散數值self.layer = nn.Linear(input_dim, input_dim)    #對輸入數據做線性變換self.classify = nn.Linear(input_dim, 1)     #映射到一維self.pool = nn.AvgPool1d(sentence_length)   #pooling層:降維self.activation = torch.sigmoid     #sigmoid做激活函數self.dropout = nn.Dropout(0.1)  #一部分輸入為0self.loss = nn.functional.mse_loss  #loss采用均方差損失#當輸入真實標簽,返回loss值;無真實標簽,返回預測值def forward(self, x, y=None):x = self.embedding(x)#輸入維度:(batch_size, sen_len)輸出維度:(batch_size, sen_len, input_dim)將文本->矩陣x = self.layer(x)#輸入維度:(batch_size, sen_len, input_dim)輸出維度:(batch_size, sen_len, input_dim)x = self.dropout(x)#輸入維度:(batch_size, sen_len, input_dim)輸出維度:(batch_size, sen_len, input_dim)x = self.activation(x)#輸入維度:(batch_size, sen_len, input_dim)輸出維度:(batch_size, sen_len, input_dim)x = self.pool(x.transpose(1,2)).squeeze()#輸入維度:(batch_size, sen_len, input_dim)輸出維度:(batch_size, input_dim)將矩陣->向量x = self.classify(x)#輸入維度:(batch_size, input_dim)輸出維度:(batch_size, 1)y_pred = self.activation(x)#輸入維度:(batch_size, 1)輸出維度:(batch_size, 1)if y is not None:return self.loss(y_pred, y)else:return y_pred#字符集隨便挑了一些漢字,實際上還可以擴充
#為每個漢字生成一個標號
#{"不":1, "東":2, "個":3...}
#不東個->[1,2,3]
def build_vocab():chars = "不東個么買五你兒幾發可同名呢方人上額旅法xyz"  #隨便設置一個字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1   #每個字對應一個序號,+1是序號從1開始vocab['unk'] = len(vocab)+1   #不在表中的值設為前一個+1return vocab#隨機生成一個樣本
#從所有字中選取sentence_length個字
#如果vocab中的xyz出現在樣本中,則為正樣本
#反之為負樣本
def build_sample(vocab, sentence_length):#將vacab轉化為字表,隨機從字表選取sentence_length個字,可能重復x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#指定哪些字必須在正樣本出現if set("xyz") & set(x):     #若xyz與x中的字符相匹配,則為1,為正樣本y = 1else:y = 0x = [vocab.get(word,vocab['unk']) for word in x]   #將字轉換成序號return x, y#建立數據集
#輸入需要的樣本數量。需要多少生成多少
def build_dataset(sample_length,vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.FloatTensor(dataset_y)#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model#測試代碼
#用來測試每輪模型的準確率
def evaluate(model, vocab, sample_length):model.eval()x, y = build_dataset(200, vocab, sample_length)#建立200個用于測試的樣本(因為測試樣本是隨機生成的,所以不存在過擬合)print("本次預測集中共有%d個正樣本,%d個負樣本"%(sum(y), 200 - sum(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #調用Pytorch模型預測for y_p, y_t in zip(y_pred, y):  #與真實標簽進行對比if float(y_p) < 0.5 and int(y_t) == 0:correct += 1   #負樣本判斷正確elif float(y_p) >= 0.5 and int(y_t) == 1:correct += 1   #正樣本判斷正確else:wrong += 1print("正確預測個數:%d, 正確率:%f"%(correct, correct/(correct+wrong)))return correct/(correct+wrong)def main():epoch_num = 10        #訓練輪數batch_size = 20       #每次訓練樣本個數train_sample = 1000   #每輪訓練總共訓練的樣本總數char_dim = 20         #每個字的維度sentence_length = 10   #樣本文本長度vocab = build_vocab()       #建立字表model = build_model(vocab, char_dim, sentence_length)    #建立模型optim = torch.optim.Adam(model.parameters(), lr=0.005)   #建立優化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #每次訓練構建一組訓練樣本optim.zero_grad()    #梯度歸零loss = model(x, y)   #計算losswatch_loss.append(loss.item())  #將loss存下來,方便畫圖loss.backward()      #計算梯度optim.step()         #更新權重print("=========\n第%d輪平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #測試本輪模型結果log.append([acc, np.mean(watch_loss)])plt.plot(range(len(log)), [l[0] for l in log])  #畫acc曲線:藍色的plt.plot(range(len(log)), [l[1] for l in log])  #畫loss曲線:黃色的plt.show()#保存模型torch.save(model.state_dict(), "model.pth")writer = open("vocab.json", "w", encoding="utf8")#保存詞表writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return#最終預測
def predict(model_path, vocab_path, input_strings):char_dim = 20  # 每個字的維度sentence_length = 10  # 樣本文本長度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length)    #建立模型model.load_state_dict(torch.load(model_path))   #將模型文件加載進來x = []for input_string in input_strings:  #轉化輸入x.append([vocab[char] for char in input_string])model.eval()    #在torch中預測注意這個:停止dropoutwith torch.no_grad():   #在torch中預測注意這個:停止梯度result = model.forward(torch.LongTensor(x)) #根據自己設計的函數定義,只輸入x就會輸出預測值for i, input_string in enumerate(input_strings):print(round(float(result[i])), input_string, result[i])#round(float(result))是將預測結果四舍五入得到0或1的預測值if __name__ == "__main__":main()#如果是進行預測,將下面兩行解除注釋,將main()注釋掉,即可調用最終預測函數進行預測# test_strings = ["個么買不東五你兒x發", "不東東么兒幾買五你發", "不東個么買五你個么買", "不z個五么買你兒幾發"]# predict("model.pth", "vocab.json", test_strings)

運行結果展示:

訓練部分:

在這里插入圖片描述
(藍色是acc曲線,黃色是loss曲線)

預測部分:

在這里插入圖片描述

一些補充:

一:model.eval()或者model.train()的作用

如果模型中有BN層(Batch Normalization)和Dropout,需要在訓練時添加model.train(),在測試時添加model.eval()。其中model.train()是保證BN層用每一批數據的均值和方差,而model.eval()是保證BN用全部訓練數據的均值和方差;而對于Dropout,model.train()是隨機取一部分網絡連接來訓練更新參數,而model.eval()是利用到了所有網絡連接。

二:Pytorch模型中用了兩次激活函數

在每一個網絡層后使用一個激活層是一種比較常見的模型搭建方式,但不是必要的。這個只是舉例,去掉也是可行的。在具體任務中,帶著好還是不帶好也跟數據和任務本身有關,沒有確定答案(如果在代碼 中把第一個激活層注釋掉 反而性能更好)

三:對x = self.pool(x.transpose(1,2)).squeeze()代碼的解讀

通過shape方法我們能知道,在pool前,x的維度輸出是[20,10,20],代表20個10×20的矩陣,代表著[這一批的個數,樣本文本長度,輸入維度],transpose(1,2)是將x中行和列調換(轉置),然后通過pooling層將[20,20,10]->[20,20,1],最后通過squeeze()進行降維變成[20,20]。
(池化層的作用及理解)

四:embedding層的理解

embedding層并不是單純的單詞映射,而是將單詞表中每個單詞的數值與權重相乘。在第一次時有默認權重,然后在接下來的訓練中,embedding層的權重與分類權重一起經過訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389571.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389571.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389571.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

erlang下lists模塊sort(排序)方法源碼解析(二)

上接erlang下lists模塊sort&#xff08;排序&#xff09;方法源碼解析(一)&#xff0c;到目前為止&#xff0c;list列表已經被分割成N個列表&#xff0c;而且每個列表的元素是有序的&#xff08;從大到小&#xff09; 下面我們重點來看看mergel和rmergel模塊&#xff0c;因為我…

洛谷P4841 城市規劃(多項式求逆)

傳送門 這題太珂怕了……如果是我的話完全想不出來…… 題解 1 //minamoto2 #include<iostream>3 #include<cstdio>4 #include<algorithm>5 #define ll long long6 #define swap(x,y) (x^y,y^x,x^y)7 #define mul(x,y) (1ll*(x)*(y)%P)8 #define add(x,y) (x…

支撐阻力指標_使用k表示聚類以創建支撐和阻力

支撐阻力指標Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

高版本(3.9版本)python在anaconda安裝opencv庫及skimage庫(scikit_image庫)諸多問題解決辦法

今天開始CV方向的學習&#xff0c;然而剛拿到基礎代碼的時候發現 from skimage.color import rgb2gray 和 import cv2標紅&#xff08;這里是因為我已經配置成功了&#xff0c;所以沒有紅標&#xff09;&#xff0c;我以為是單純兩個庫沒有下載&#xff0c;去pycharm中下載ski…

python 實現斐波那契數列

# coding:utf8 __author__ blueslidef fun(arg1,arg2,stop):if arg10:print(arg1,arg2)arg3 arg1arg2print(arg3)if arg3<stop:arg3 fun(arg2,arg3,stop)fun(0,1,100)轉載于:https://www.cnblogs.com/bluesl/p/9079705.html

單機安裝ZooKeeper

2019獨角獸企業重金招聘Python工程師標準>>> zookeeper下載、安裝以及配置環境變量 本節介紹單機的zookeeper安裝&#xff0c;官方下載地址如下&#xff1a; https://archive.apache.org/dist/zookeeper/ 我這里使用的是3.4.11版本&#xff0c;所以找到相應的版本點…

均線交易策略的回測 r_使用r創建交易策略并進行回測

均線交易策略的回測 rR Programming language is an open-source software developed by statisticians and it is widely used among Data Miners for developing Data Analysis. R can be best programmed and developed in RStudio which is an IDE (Integrated Development…

opencv入門課程:彩色圖像灰度化和二值化(采用skimage庫和opencv庫兩種方法)

用最簡單的辦法實現彩色圖像灰度化和二值化&#xff1a; 首先采用skimage庫&#xff08;skimage庫現在在scikit_image庫中&#xff09;實現&#xff1a; from skimage.color import rgb2gray import numpy as np import matplotlib.pyplot as plt""" skimage庫…

SVN中Revert changes from this revision 跟Revert to this revision

譬如有個文件&#xff0c;有十個版本&#xff0c;假定版本號是1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5&#xff0c;6&#xff0c;7&#xff0c;8&#xff0c;9&#xff0c;10。Revert to this revision&#xff1a; 如果是在版本6這里點擊“Revert to this rev…

歸 [拾葉集]

歸 心歸故鄉 想象行走在 鄉間恬靜小路上 讓那些疲憊的夢 都隨風飛散吧&#xff01; 不去想那些世俗 人來人往 熙熙攘攘 秋日午后 陽光下 細數落葉 來日方長 世上的路 有詩人、浪子 歌詠吟唱 世上的人 在欲望、信仰中 彷徨 彷徨又迷茫 親愛的人兒 快結束那 無休止的獨自流浪 莫要…

instagram分析以預測與安的限量版運動鞋轉售價格

Being a sneakerhead is a culture on its own and has its own industry. Every month Biggest brands introduce few select Limited Edition Sneakers which are sold in the markets according to Lottery System called ‘Raffle’. Which have created a new market of i…

opencv:用最鄰近插值和雙線性插值法實現上采樣(放大圖像)與下采樣(縮小圖像)

上采樣與下采樣 概念&#xff1a; 上采樣&#xff1a; 放大圖像&#xff08;或稱為上采樣&#xff08;upsampling&#xff09;或圖像插值&#xff08;interpolating&#xff09;&#xff09;的主要目的 是放大原圖像,從而可以顯示在更高分辨率的顯示設備上。 下采樣&#xff…

CSS魔法堂:那個被我們忽略的outline

前言 在CSS魔法堂&#xff1a;改變單選框顏色就這么吹毛求疵&#xff01;中我們要模擬原生單選框通過Tab鍵獲得焦點的效果&#xff0c;這里涉及到一個常常被忽略的屬性——outline&#xff0c;由于之前對其印象確實有些模糊&#xff0c;于是本文打算對其進行稍微深入的研究^_^ …

初創公司怎么做銷售數據分析_初創公司與Faang公司的數據科學

初創公司怎么做銷售數據分析介紹 (Introduction) In an increasingly technological world, data scientist and analyst roles have emerged, with responsibilities ranging from optimizing Yelp ratings to filtering Amazon recommendations and designing Facebook featu…

opencv:灰色和彩色圖像的像素直方圖及直方圖均值化的實現與展示

直方圖及直方圖均值化的理論&#xff0c;實現及展示 直方圖&#xff1a; 首先&#xff0c;我們來看看什么是直方圖&#xff1a; 理論概念&#xff1a; 在圖像處理中&#xff0c;經常用到直方圖&#xff0c;如顏色直方圖、灰度直方圖等。 圖像的灰度直方圖就描述了圖像中灰度分…

mysql.sock問題

Cant connect to local MySQL server through socket /tmp/mysql.sock 上述提示可能在啟動mysql時遇到&#xff0c;即在/tmp/mysql.sock位置找不到所需要的mysql.sock文件&#xff0c;主要是由于my.cnf文件里對mysql.sock的位置設定導致。 mysql.sock默認的是在/var/lib/mysql,…

交換機的基本原理配置(一)

1、配置主機名 在全局模式下輸入hostname 名字 然后回車即可立馬生效&#xff08;在生產環境交換機必須有自己唯一的名字&#xff09; Switch(config)#hostname jsh-sw1jsh-sw1(config)#2、顯示系統OS名稱及版本信息 特權模式下&#xff0c;輸入命令 show version Switch#show …

opencv:卷積涉及的基礎概念,Sobel邊緣檢測代碼實現及Same(相同)填充與Vaild(有效)填充

濾波 線性濾波可以說是圖像處理最基本的方法&#xff0c;它可以允許我們對圖像進行處理&#xff0c;產生很多不同的效果。 卷積 卷積的概念&#xff1a; 卷積的原理與濾波類似。但是卷積卻有著細小的差別。 卷積操作也是卷積核與圖像對應位置的乘積和。但是卷積操作在做乘…

機器學習股票_使用概率機器學習來改善您的股票交易

機器學習股票Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

BZOJ 2818 Gcd

傳送門 題解&#xff1a;設p為素數 &#xff0c;則gcd(x/p,y/p)1也就是說求 x&#xff0f;p以及 y&#xff0f;p的歐拉函數。歐拉篩前綴和就可以解決 #include <iostream> #include <cstdio> #include <cmath> #include <algorithm> #include <map&…