RNN循環神經網絡python實現

import collections
import math
import re
import random
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2ldef read_txt():# 讀取文本數據with open('./A Study in Drowning.txt', 'r', encoding='utf-8') as f:# 讀取每一行lines = f.readlines()# 將不是英文字符的轉換為空格,全部變為小寫字符返回return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):# 將文本以空格進行分割成詞if token == 'word':return [line.split() for line in lines]# 將文本分割成字符elif token == 'char':return [list(line) for line in lines]else:print('錯誤:未知令牌類型:' + token)def count_corpus(tokens):# 如果tokens長度為0或tokens[0]是listif len(tokens) == 0 or isinstance(tokens[0], list):# 將[[,,,],[,,,]]多層結構變為一層結構[,,,,]tokens = [token for line in tokens for token in line]# 統計可迭代對象中元素出現的次數,并返回一個字典(key-value)key 表示元素,value 表示各元素 key 出現的次數return collections.Counter(tokens)# idx_to_token  是一個list  由token作為元素構成['<unk>', ' ', 'e', 't', 'a', 'o', 'h', 'n', 'i', 's', 'r', 'd', 'l', 'u', 'f', 'w', 'g', 'm', 'y', 'c', 'p', 'b', 'k', 'v', 'j', 'x', 'z', 'q']
# token_freqs   是一個list  由token和該token出現的次數構成的元組作為元素構成[(' ', 94824), ('e', 54804), ('t', 38742), ('a', 33172), ('o', 30656), ('h', 29047), ('n', 28667), ('i', 28093), ('s', 27922), ('r', 26121), ('d', 20394), ('l', 17755), ('u', 12267), ('f', 11033), ('w', 10033), ('g', 9837), ('m', 9258), ('y', 9251), ('c', 8872), ('p', 6998), ('b', 6620), ('k', 4817), ('v', 3574), ('j', 500), ('x', 372), ('z', 308), ('q', 285)]
# token_to_idx  是一個dict  由token作為key token在idx_to_token的索引作為value構成{' ': 1, '<unk>': 0, 'a': 4, 'b': 21, 'c': 19, 'd': 11, 'e': 2, 'f': 14, 'g': 16, 'h': 6, 'i': 8, 'j': 24, 'k': 22, 'l': 12, 'm': 17, 'n': 7, 'o': 5, 'p': 20, 'q': 27, 'r': 10, 's': 9, 't': 3, 'u': 13, 'v': 23, 'w': 15, 'x': 25, 'y': 18, 'z': 26}
class Vocab:def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):# 處理特殊情況if tokens is None:tokens = []# 處理特殊情況if reserved_tokens is None:reserved_tokens = []# counter為一個字典(key-value)key 表示元素,value 表示各元素 key 出現的次數counter = count_corpus(tokens)# 排序# iterable:待排序的序列counter.items()# key:排序規則lambda x: x[1]從小到大# reverse:指定排序的方式,默認值False,即升序排列,這是True也就是降序self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化self.unk, uniq_tokens = 0, ['<unk>'] + reserved_tokens# 初始化  token_freqs中 key不在 uniq_tokens中  且  value大于min_freq 返回token放入uniq_tokensuniq_tokens += [token for token, freq in self.token_freqsif freq >= min_freq and token not in uniq_tokens]# 初始化self.idx_to_token, self.token_to_idx = [], dict()# 賦值for token in uniq_tokens:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self.__getitem__(token) for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]def load_corpus_time_machine(max_tokens=-1):# 將文本處理成行lines = read_txt()# print(lines)# 將行tokens化tokens = tokenize(lines, 'char')# print(tokens)# 構建字典表vocab = Vocab(tokens)# vocab的格式為{list:524222}[5, 7, 2, 5, 7, 2, 8, 3, ......, 1, 18, 5, 13]#print(vocab)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus[:max_tokens]return corpus, vocab# 隨機地生成一個小批量數據的特征和標簽以供讀取。 在隨機采樣中,每個樣本都是在原始的長序列上任意捕獲的子序列
def seq_data_iter_random(corpus, batch_size, num_steps):"""使用隨機抽樣生成一個小批量子序列。"""corpus = corpus[random.randint(0, num_steps - 1):]num_subseqs = (len(corpus) - 1) // num_stepsinitial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)def data(pos):return corpus[pos:pos + num_steps]num_batches = num_subseqs // batch_sizefor i in range(0, batch_size * num_batches, batch_size):initial_indices_per_batch = initial_indices[i:i + batch_size]X = [data(j) for j in initial_indices_per_batch]Y = [data(j + 1) for j in initial_indices_per_batch]yield torch.tensor(X), torch.tensor(Y)# 保證兩個相鄰的小批量中的子序列在原始序列上也是相鄰的
def seq_data_iter_sequential(corpus, batch_size, num_steps):"""使用順序分區生成一個小批量子序列。"""offset = random.randint(0, num_steps)num_tokens = ((len(corpus) - offset - 1) // batch_size) * batch_sizeXs = torch.tensor(corpus[offset:offset + num_tokens])Ys = torch.tensor(corpus[offset + 1:offset + 1 + num_tokens])Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)num_batches = Xs.shape[1] // num_stepsfor i in range(0, num_steps * num_batches, num_steps):X = Xs[:, i:i + num_steps]Y = Ys[:, i:i + num_steps]yield X, Yclass SeqDataLoader:"""加載序列數據的迭代器。"""def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):if use_random_iter:self.data_iter_fn = seq_data_iter_randomelse:self.data_iter_fn = seq_data_iter_sequentialself.corpus, self.vocab = load_corpus_time_machine(max_tokens)self.batch_size, self.num_steps = batch_size, num_stepsdef __iter__(self):return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)def load_data_time_machine(batch_size, num_steps,use_random_iter=False, max_tokens=10000):"""返回時光機器數據集的迭代器和詞匯表。"""data_iter = SeqDataLoader(batch_size, num_steps, use_random_iter,max_tokens)return data_iter, data_iter.vocab# 初始化模型參數
def get_params(vocab_size, num_hiddens, device):# 輸入等于輸出等于字典大小num_inputs = num_outputs = vocab_size# 均值為0方差為1的隨機張量*0.01def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 輸入到隱藏層邊緣的WW_xh = normal((num_inputs, num_hiddens))# 隱藏層的WW_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 隱藏層到輸出的WW_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 初始化隱藏狀態
def init_rnn_state(batch_size, num_hiddens, device):# 批量大小,隱藏層大小的全0張量return (torch.zeros((batch_size, num_hiddens), device=device),)# 計算輸出
def rnn(inputs, state, params):W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:# 激活函數是tanh H為初始化隱藏狀態H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)# H為當前隱藏狀態return torch.cat(outputs, dim=0), (H,)class RNNModelScratch:"""從零開始實現的循環神經網絡模型"""def __init__(self, vocab_size, num_hiddens, device, get_params,init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 推理測試
def predict_ch8(prefix, num_preds, net, vocab, device):"""在`prefix`后面生成新字符。"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]:_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds):y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])# 梯度剪裁
def grad_clipping(net, theta):"""裁剪梯度。"""if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm# 訓練函數
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""訓練模型一個迭代周期(定義見第8章)。"""state = Nonemetric = d2l.Accumulator(2)for X, Y in train_iter:if state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(net, nn.Module) and not isinstance(state, tuple):state.detach_()else:for s in state:s.detach_()y = Y.T.reshape(-1)X, y = X.to(device), y.to(device)y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)metric.add(l * y.numel(), y.numel())return math.exp(metric[0] / metric[1])def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""訓練模型(定義見第8章)。"""loss = nn.CrossEntropyLoss()if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)for epoch in range(num_epochs):ppl = train_epoch_ch8(net, train_iter, loss, updater, device,use_random_iter)if (epoch + 1) % 10 == 0:print(predict('But'))print(f'困惑度 {ppl:.1f},  {str(device)}')print(predict('But'))# 批量大小為32  時序序列的長度為35 隱藏層大小512
batch_size, num_steps, num_hiddens = 32, 35, 512
# 獲取迭代數據和字典
train_iter, vocab = load_data_time_machine(batch_size, num_steps)
# 定義網絡
net = RNNModelScratch(len(vocab), num_hiddens, torch.device('cpu'), get_params,init_rnn_state, rnn)
# 訓練500輪 學習率為1
num_epochs, lr = 50, 1
# 訓練
train_ch8(net, train_iter, vocab, lr, num_epochs, torch.device('cpu'),use_random_iter=True)

訓練結果

<unk>ut the the the the the the the the the the the the t
<unk>uthe the the the the the the the the the the the the
<unk>uthe sher and the sher and the sher and the sher and
<unk>uthe sher and the sher and the sher and the sher and
<unk>uthe sher and he her sher and her sher and her sher 
困惑度 8.8,  cpu
<unk>uthe sher and he her sher and her sher and her sher Process finished with exit code 0

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/214194.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/214194.shtml
英文地址,請注明出處:http://en.pswp.cn/news/214194.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

軟件測試之缺陷管理

一、軟件缺陷的基本概念 1、軟件缺陷的基本概念主要分為&#xff1a;缺陷、故障、失效這三種。 &#xff08;1&#xff09;缺陷&#xff08;defect&#xff09;&#xff1a;存在于軟件之中的偏差&#xff0c;可被激活&#xff0c;以靜態的形式存在于軟件內部&#xff0c;相當…

【隱馬爾可夫模型】隱馬爾可夫模型的觀測序列概率計算算法及例題詳解

【隱馬爾可夫模型】用前向算法計算觀測序列概率P&#xff08;O&#xff5c;λ&#xff09;??????? 【隱馬爾可夫模型】用后向算法計算觀測序列概率P&#xff08;O&#xff5c;λ&#xff09; 隱馬爾可夫模型是關于時序的概率模型&#xff0c;描述由一個隱藏的馬爾可夫鏈…

Elbie勒索病毒:最新變種.elbie襲擊了您的計算機?

引言&#xff1a; 在數字時代&#xff0c;.Elbie勒索病毒的威脅越發突出&#xff0c;對個人和組織的數據安全構成了巨大挑戰。本文將深入介紹.Elbie勒索病毒的特征&#xff0c;有效的數據恢復方法&#xff0c;以及一系列預防措施&#xff0c;幫助您更好地保護數字資產。當面對…

線性規劃-單純形法推導

這里寫目錄標題 線性規劃例子啤酒廠問題圖解法 單純形法數學推導將問題標準化并轉為矩陣形式開始推導 實例圖解法單純形法 線性規劃例子 啤酒廠問題 每日銷售上限&#xff1a;100箱啤酒營業時間&#xff1a;14小時生產1箱生啤需1小時生產1箱黑啤需2小時生啤售價&#xff1a;2…

從零開發短視頻電商 AWS OpenSearch Service開發環境申請以及Java客戶端介紹

文章目錄 創建域1.創建域2.輸入配置部署選項數據節點網絡精細訪問控制訪問策略 獲取域端點數據如何插入到OpenSearch ServiceJava連接OpenSearch Servicespring-data-opensearchelasticsearch-rest-high-level-clientopensearch-rest-clientopensearch-java 因為是開發測試使用…

[Linux] nginx的location和rewrite

一、Nginx常用的正則表達式 符號作用^匹配輸入字符串的起始位置$ 匹配輸入字符串的結束位置*匹配前面的字符零次或多次。如“ol*”能匹配“o”及“ol”、“oll” 匹配前面的字符一次或多次。如“ol”能匹配“ol”及“oll”、“olll”&#xff0c;但不能匹配“o”?匹配前面的字…

Vue3 setup 頁面跳轉監聽路由變化調整頁面訪問位置

頁面跳轉后頁面還是停留在上一個頁面的位置&#xff0c;沒有回到頂部 解決 1、router中路由守衛中統一添加 router.beforeEach(async (to, from, next) > {window.scrollTo(0, 0);next(); }); 2、頁面中監聽頁面變化 <script setup> import { ref, onMounted, wat…

@Autowired 找不到Bean的問題

排查思路 檢查包掃描&#xff1a;查詢的Bean是否被spring掃描裝配到檢查該Bean上是否配上注解&#xff08;Service/Component/Repository…&#xff09;如果使用第三方&#xff0c;檢查相關依賴是否已經安裝到當前項目 Autowired和Resource的區別 Autowired 是spring提供的注…

圖像清晰度 和像素、分辨率、鏡頭的關系

關于圖像清晰度的幾個知識點分享。 知識點 清晰度 清晰度指影像上各細部影紋及其邊界的清晰程度。清晰度&#xff0c;一般是從錄像機角度出發&#xff0c;通過看重放圖像的清晰程度來比較圖像質量&#xff0c;所以常用清晰度一詞。 而攝像機一般使用分解力一詞來衡量它“分解被…

linux通過命令切換用戶

在Linux中&#xff0c;你可以使用su&#xff08;substitute user或switch user&#xff09;命令來切換用戶。這個命令允許你臨時或永久地以另一個用戶的身份運行命令。以下是基本的用法&#xff1a; 基本切換到另一個用戶&#xff08;需要密碼&#xff09;&#xff1a;su [用戶…

APIFox:打造高效便捷的API管理工具

隨著互聯網技術的不斷發展&#xff0c;API&#xff08;應用程序接口&#xff09;已經成為了企業間數據交互的重要方式。然而&#xff0c;API的管理和維護卻成為了開發者們面臨的一大挑戰。為了解決這一問題&#xff0c;APIFox應運而生&#xff0c;它是一款專為API管理而生的工具…

【力扣100】189.輪轉數組

添加鏈接描述 class Solution:def rotate(self, nums: List[int], k: int) -> None:"""Do not return anything, modify nums in-place instead."""# 思路&#xff1a;三次數組翻轉nlen(nums)kk%nnums[:] nums[-k:] nums[:-k]思路就是&…

數據科學實踐:探索數據驅動的決策

寫在前面 你是否曾經困擾于如何從海量的數據中提取有價值的信息?你是否想過如何利用數據來指導你的決策,讓你的決策更加科學和精確?如果你有這樣的困擾和疑問,那么你來對了地方。這篇文章將引導你走進數據科學的世界,探索數據驅動的決策。 1.數據科學的基本原則 在我們…

第四屆傳智杯初賽(蓮子的機械動力學)

題目描述 題目背景的問題可以轉化為如下描述&#xff1a; 給定兩個長度分別為 n,m 的整數 a,b&#xff0c;計算它們的和。 但是要注意的是&#xff0c;這里的 a,b 采用了某種特殊的進制表示法。最終的結果也會采用該種表示法。具體而言&#xff0c;從低位往高位數起&#xf…

【linux】yum安裝時: Couldn‘t resolve host name for XXXXX

yum 安裝 sysstat 報錯了&#xff1a; Kylin Linux Advanced Server 10 - Os 0.0 B/s | 0 B 00:00 Errors during downloading metadata for repository ks10-adv-os:- Curl error (6): Couldnt resolve host nam…

在非Spring環境下Main方法中,怎么使用spring的ThreadPoolTaskScheduler啟動Scheduler?

作為Java開發人員&#xff0c;在使用spring框架的時候&#xff0c;如果想要獲取到線程池對象&#xff0c;可以直接使用spring框架提供的ThreadPoolxxx來獲取。那么在非spring環境下&#xff0c;main函數怎么使用ThreadPoolTaskScheduler呢&#xff1f;下面凱哥(凱哥Java:kaigej…

10.vue3項目(十):spu管理頁面的sku的新增和修改

目錄 一、sku靜態頁面的搭建 1.思路分析 2.代碼實現 3.效果展示

微信小程序 長按錄音+錄制視頻

<view class"bigCircle" bindtouchstart"start" bindtouchend"stop"><view class"smallCircle {{startVedio?onVedio:}}"><text>{{startVedio?正在錄音:長按錄音}}</text></view> </view> <…

排序算法:【選擇排序]

一、選擇排序——時間復雜度 定義&#xff1a;第一趟排序&#xff0c;從整個序列中找到最小的數&#xff0c;把它放到序列的第一個位置上&#xff0c;第二趟排序&#xff0c;再從無序區找到最小的數&#xff0c;把它放到序列的第二個位置上&#xff0c;以此類推。 也就是說&am…

軟件項目管理---胡亂復習版

范圍控制的一個重點是避免需求的不合理擴張。(√)一個任務原計劃2個人全職工作2周完成,而實際上只有一個人參與這個任務,到第二周末這個人完成了任務的75%,那么:BCWS = 4人周,ACWP = 2人周,BCWP = 3人周。CV = 1,SV = -1。 【在項目管理中,BCWS、ACWP和BCWP是用來衡量…