深度學習分詞器char-level實戰詳解

一、三種分詞器基本介紹

word-level:將文本按照空格或者標點分割成單詞,但是詞典大小太大

subword-level:詞根分詞(主流)

char-level:將文本按照字母級別分割成token

二、charlevel代碼

導包:

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as Fprint(sys.version_info)
for module in mpl, np, pd, sklearn, torch:print(module.__name__, module.__version__)device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

數據準備(需下載):

# https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
#文件已經下載好了
with open("./shakespeare.txt", "r", encoding="utf8") as file:text = file.read()print("length", len(text))
print(text[0:100])

?構造字典:

# 1. generate vocab
# 2. build mapping char->id
# 3. data -> id_data  把數據都轉為id
# 4. a b c d [EOS] -> [BOS] b c d  預測下一個字符生成的模型,也就是輸入是a,輸出就是b#去重,留下獨立字符,并排序(排序是為了好看)
vocab = sorted(set(text)) # 利用set去重,sorted排序
print(len(vocab))
print(vocab)#每個字符都編好號,enumerate對每一個位置編號,生成的是列表中是元組,下面字典生成式
char2idx = {char:idx for idx, char in enumerate(vocab)}
print(char2idx)# 把vocab從列表變為ndarray
idx2char = np.array(vocab)
print(idx2char)#把字符都轉換為id
text_as_int = np.array([char2idx[c] for c in text])
print(text_as_int.shape)
print(len(text_as_int))
print(text_as_int[0:10])
print(text[0:10])
  • enumerate()?是Python內置函數,用于給可迭代對象添加序號
  • 語法:enumerate(iterable, start=0)
  • 作用:將列表/字符串等轉換為(索引, 元素)元組的序列

一共1115394個字符,這里分為11043個batch,每個樣本101個字符,原因如下:

比如有Jeep四個字符,那么那前三個字母輸入J就預測到e,再輸入e預測到e再預測到p,相當于錯開預測。前100和最后一個錯開,就是上圖的效果。

把text分為樣本:

rom torch.utils.data import Dataset, DataLoaderclass CharDataset(Dataset):#text_as_int是字符的id列表,seq_length是每個樣本的長度def __init__(self, text_as_int, seq_length):self.sub_len = seq_length + 1 #一個樣本的長度self.text_as_int = text_as_intself.num_seq = len(text_as_int) // self.sub_len #樣本的個數def __getitem__(self, index):#index是樣本的索引,返回的是一個樣本,比如第一個,就是0-100的字符,總計101個字符return self.text_as_int[index * self.sub_len: (index + 1) * self.sub_len]def __len__(self): #返回樣本的個數return self.num_seq#batch是一個列表,列表中的每一個元素是一個樣本,有101個字符,前100個是輸入,后100個是輸出
def collat_fct(batch):src_list = [] #輸入trg_list = [] #輸出for part in batch:src_list.append(part[:-1]) #輸入trg_list.append(part[1:]) #輸出src_list = np.array(src_list) #把列表轉換為ndarraytrg_list = np.array(trg_list) #把列表轉換為ndarrayreturn torch.Tensor(src_list).to(dtype=torch.int64), torch.Tensor(trg_list).to(dtype=torch.int64) #返回的是一個元組,元組中的每一個元素是一個torch.Tensor#每個樣本的長度是101,也就是100個字符+1個結束符
train_ds = CharDataset(text_as_int, 100)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collat_fct)
#%%
  • seq_length:模型輸入的序列長度(例如100)

  • sub_len:實際存儲長度 = 輸入長度 + 目標長度(每個樣本多存1個字符用于構造目標)

假設原始文本數字編碼為:[1,2,3,4,5,6,7,8,9,10],當seq_length=3時:樣本1: [1,2,3,4] → 輸入[1,2,3],目標[2,3,4] 樣本2: [5,6,7,8] → 輸入[5,6,7],目標[6,7,8] 剩余字符[9,10]被舍棄

定義模型:

class CharRNN(nn.Module):def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):super(CharRNN, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)#batch_first=True,輸入的數據格式是(batch_size, seq_len, embedding_dim)self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x, hidden=None):x = self.embedding(x) #(batch_size, seq_len) -> (batch_size, seq_len, embedding_dim) (64, 100, 256)#這里和02的差異是沒有只拿最后一個輸出,而是把所有的輸出都拿出來了#(batch_size, seq_len, embedding_dim)->(batch_size, seq_len, hidden_dim)(64, 100, 1024)output, hidden = self.rnn(x, hidden)x = self.fc(output) #[bs, seq_len, hidden_dim]--->[bs, seq_len, vocab_size] (64, 100,65)return x, hidden #x的shape是(batch_size, seq_len, vocab_size)vocab_size = len(vocab)print("{:=^80}".format(" 一層單向 RNN "))       
for key, value in CharRNN(vocab_size).named_parameters():print(f"{key:^40}paramerters num: {np.prod(value.shape)}")

因為字典太小,所以embedding_dim要放大。輸入形狀(bs,seq)→輸出形狀(bs,seq,emb_dim)。

這樣的話才能把里面的信息分的更清楚,其他情況都是縮小。

生成的時候不能只取最后一個時間步了,全都要。

前向傳播流程:x→Embedding→RNN→Linear

訓練:

class SaveCheckpointsCallback:def __init__(self, save_dir, save_step=5000, save_best_only=True):"""Save checkpoints each save_epoch epoch. We save checkpoint by epoch in this implementation.Usually, training scripts with pytorch evaluating model and save checkpoint by step.Args:save_dir (str): dir to save checkpointsave_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.save_best_only (bool, optional): If True, only save the best model or save each model at every epoch."""self.save_dir = save_dirself.save_step = save_stepself.save_best_only = save_best_onlyself.best_metrics = -1# mkdirif not os.path.exists(self.save_dir):os.mkdir(self.save_dir)def __call__(self, step, state_dict, metric=None):if step % self.save_step > 0:returnif self.save_best_only:assert metric is not Noneif metric >= self.best_metrics:# save checkpointstorch.save(state_dict, os.path.join(self.save_dir, "best.ckpt"))# update best metricsself.best_metrics = metricelse:torch.save(state_dict, os.path.join(self.save_dir, f"{step}.ckpt"))#%%
# 訓練
def training(model, train_loader, epoch, loss_fct, optimizer, save_ckpt_callback=None,stateful=False      # 想用stateful,batch里的數據就必須連續,不能打亂):record_dict = {"train": [],}global_step = 0model.train()hidden = Nonewith tqdm(total=epoch * len(train_loader)) as pbar:for epoch_id in range(epoch):# trainingfor datas, labels in train_loader:datas = datas.to(device)labels = labels.to(device)# 梯度清空optimizer.zero_grad()# 模型前向計算,如果數據集打亂了,stateful=False,hidden就要清空# 如果數據集沒有打亂,stateful=True,hidden就不需要清空logits, hidden = model(datas, hidden=hidden if stateful else None)# 計算損失,交叉熵損失第一個參數要是二階張量,第二個參數要是一階張量,所以要reshapeloss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))# 梯度回傳loss.backward()# 調整優化器,包括學習率的變動等optimizer.step()loss = loss.cpu().item()# recordrecord_dict["train"].append({"loss": loss, "step": global_step})# 保存模型權重 save model checkpointif save_ckpt_callback is not None:save_ckpt_callback(global_step, model.state_dict(), metric=-loss)# udate stepglobal_step += 1pbar.update(1)pbar.set_postfix({"epoch": epoch_id})return record_dictepoch = 100model = CharRNN(vocab_size=vocab_size)# 1. 定義損失函數 采用交叉熵損失 
loss_fct = nn.CrossEntropyLoss()
# 2. 定義優化器 采用 adam
# Optimizers specified in the torch.optim package
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# save best
if not os.path.exists("checkpoints"):os.makedirs("checkpoints")
save_ckpt_callback = SaveCheckpointsCallback("checkpoints/text_generation", save_step=1000, save_best_only=True)model = model.to(device)#%%
record = training(model,train_dl,epoch,loss_fct,optimizer,save_ckpt_callback=save_ckpt_callback,)
#%%
plt.plot([i["step"] for i in record["train"][::50]], [i["loss"] for i in record["train"][::50]], label="train")
plt.grid()
plt.show()
#%% md
## 推理
#%%#下面的例子是為了說明temperature
logits = torch.tensor([400.0,600.0]) #這里是logitsprobs1 = F.softmax(logits, dim=-1)
print(probs1)
#%%
logits = torch.tensor([0.04,0.06])  #現在 temperature是2probs1 = F.softmax(logits, dim=-1)
print(probs1)
#%%
import torch# 創建一個概率分布,表示每個類別被選中的概率
# 這里我們有一個簡單的四個類別的概率分布
prob_dist = torch.tensor([0.1, 0.45, 0.35, 0.1])# 使用 multinomial 進行抽樣
# num_samples 表示要抽取的樣本數量
num_samples = 5# 抽取樣本,隨機抽樣,概率越高,抽到的概率就越高,1代表只抽取一個樣本,replacement=True表示可以重復抽樣
samples_index = torch.multinomial(prob_dist, 1, replacement=True)print("概率分布:", prob_dist)
print("抽取的樣本索引:", samples_index)# 顯示每個樣本對應的概率
print("每個樣本對應的概率:", prob_dist[samples_index])
#%%
def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):input_eval = torch.Tensor([char2idx[char] for char in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1) #bacth_size=1, seq_len長度是多少都可以 (1,5)hidden = Nonetext_generated = [] #用來保存生成的文本model.eval()pbar = tqdm(range(max_len)) # 進度條print(start_string, end="")# no_grad是一個上下文管理器,用于指定在其中的代碼塊中不需要計算梯度。在這個區域內,不會記錄梯度信息,用于在生成文本時不影響模型權重。with torch.no_grad():for i in pbar:#控制進度條logits, hidden = model(input_eval, hidden=hidden)# 溫度采樣,較高的溫度會增加預測結果的多樣性,較低的溫度則更加保守。#取-1的目的是只要最后,拼到原有的輸入上logits = logits[0, -1, :] / temperature #logits變為1維的# using multinomial to samplingprobs = F.softmax(logits, dim=-1) #算為概率分布idx = torch.multinomial(probs, 1).item() #從概率分布中抽取一個樣本,取概率較大的那些input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1) #把idx轉為tensortext_generated.append(idx)if stream:print(idx2char[idx], end="", flush=True)return "".join([idx2char[i] for i in text_generated])# load checkpoints
model.load_state_dict(torch.load("checkpoints/text_generation/best.ckpt", weights_only=True,map_location="cpu"))
start_string = "All: " #這里就是開頭,什么都可以
res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/897464.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/897464.shtml
英文地址,請注明出處:http://en.pswp.cn/news/897464.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于SpringBoot實現旅游酒店平臺功能六

一、前言介紹: 1.1 項目摘要 隨著社會的快速發展和人民生活水平的不斷提高,旅游已經成為人們休閑娛樂的重要方式之一。人們越來越注重生活的品質和精神文化的追求,旅游需求呈現出爆發式增長。這種增長不僅體現在旅游人數的增加上&#xff0…

git規范提交之commitizen conventional-changelog-cli 安裝

一、引言 使用規范的提交信息可以讓項目更加模塊化、易于維護和理解,同時也便于自動化工具(如發布工具或 Changelog 生成器)解析和處理提交記錄。 通過編寫符合規范的提交消息,可以讓團隊和協作者更好地理解項目的變更歷史和版本…

前端實現版本更新自動檢測?

🤖 作者簡介:水煮白菜王,一位資深前端勸退師 👻 👀 文章專欄: 前端專欄 ,記錄一下平時在博客寫作中,總結出的一些開發技巧和知識歸納總結?。 感謝支持💕💕&a…

硬件基礎(4):(5)設置ADC電壓采集中MCU的參考電壓

Vref 引腳通常是 MCU (特別是帶有 ADC 的微控制器) 上用來提供或接收基準電壓的引腳,ADC 會以該基準電壓作為量程參考對輸入模擬信號進行數字化轉換。具體來說: 命名方式 在不同廠家的 MCU 中,Vref 引腳可能會被標記為 VREF / VREF- / VREF_…

postman接口請求中的 Raw是什么

前言 在現代的網絡開發中,API 的使用已經成為數據交換的核心方式之一。然而,在與 API 打交道時,關于如何發送請求體(body)內容類型的問題常常困擾著開發者們,尤其是“raw”和“json”這兩個術語之間的區別…

為什么要使用前綴索引,以及建立前綴索引:sql示例

背景: 你想啊,數據庫里有些字段,它老長了,就像那種 varchar(255) 的字段,這玩意兒要是整個字段都拿來建索引,那可太占地方了。打個比方,這就好比你要在一個超級大的筆記本上記東西,每…

【語料數據爬蟲】Python爬蟲|批量采集會議紀要數據(1)

前言 本文是該專欄的第2篇,后面會持續分享Python爬蟲采集各種語料數據的的干貨知識,值得關注。 在本文中,筆者將主要來介紹基于Python,來實現批量采集“會議紀要”數據。同時,本文也是采集“會議紀要”數據系列的第1篇。 采集相關數據的具體細節部分以及詳細思路邏輯,筆…

Android 線程池實戰指南:高效管理多線程任務

在 Android 開發中,線程池的使用非常重要,尤其是在需要處理大量異步任務時。線程池可以有效地管理線程資源,避免頻繁創建和銷毀線程帶來的性能開銷。以下是線程池的使用方法和最佳實踐。 1. 線程池的基本使用 (1)創建線…

SQL29 計算用戶的平均次日留存率

SQL29 計算用戶的平均次日留存率 計算用戶的平均次日留存率_牛客題霸_牛客網 題目:現在運營想要查看用戶在某天刷題后第二天還會再來刷題的留存率。 示例:question_practice_detail -- 輸入: DROP TABLE IF EXISTS question_practice_detai…

深度學習分類回歸(衣帽數據集)

一、步驟 1 加載數據集fashion_minst 2 搭建class NeuralNetwork模型 3 設置損失函數,優化器 4 編寫評估函數 5 編寫訓練函數 6 開始訓練 7 繪制損失,準確率曲線 二、代碼 導包,打印版本號: import matplotlib as mpl im…

【leetcode hot 100 19】刪除鏈表的第N個節點

解法一:將ListNode放入ArrayList中,要刪除的元素為num list.size()-n。如果num 0則將頭節點刪除;否則利用num-1個元素的next刪除第num個元素。 /*** Definition for singly-linked list.* public class ListNode {* int val;* Lis…

【iOS逆向與安全】sms短信轉發插件與上傳服務器開發

一、目標 一步步分析并編寫一個短信自動轉發的deb插件 二、工具 mac系統已越獄iOS設備:脫殼及frida調試IDA Pro:靜態分析測試設備:iphone6s-ios14.1.1三、步驟 1、守護進程 ? 守護進程(daemon)是一類在后臺運行的特殊進程,用于執行特定的系統任務。例如:推送服務、人…

Midjourney繪圖參數詳解:從基礎到高級的全面指南

引言 Midjourney作為當前最受歡迎的AI繪圖工具之一,其強大的參數系統為用戶提供了豐富的創作可能性。本文將深入解析Midjourney的各項參數,幫助開發者更好地掌握這一工具,提升創作效率和質量。 一、基本參數配置 1. 圖像比例調整 使用--ar…

音頻進階學習十九——逆系統(簡單進行回聲消除)

文章目錄 前言一、可逆系統1.定義2.解卷積3.逆系統恢復原始信號過程4.逆系統與原系統的零極點關系 二、使用逆系統去除回聲獲取原信號的頻譜原系統和逆系統幅頻響應和相頻響應使用逆系統恢復原始信號整體代碼如下 總結 前言 在上一篇音頻進階學習十八——幅頻響應相同系統、全…

vue3 使用sass變量

1. 在<style>中使用scss定義的變量和css變量 1. 在/style/variables.scss文件中定義scss變量 // scss變量 $menuText: #bfcbd9; $menuActiveText: #409eff; $menuBg: #304156; // css變量 :root {--el-menu-active-color: $menuActiveText; // 活動菜單項的文本顏色--el…

gbase8s rss集群通信流程

什么是rss RSS是一種將數據從主服務器復制到備服務器的方法 實例級別的復制 (所有啟用日志記錄功能的數據庫) 基于邏輯日志的復制技術&#xff0c;需要傳輸大量的邏輯日志,數據庫需啟用日志模式 通過網絡持續將數據復制到備節點 如果主服務器發生故障&#xff0c;那么備用服務…

熵與交叉熵詳解

前言 本文隸屬于專欄《機器學習數學通關指南》&#xff0c;該專欄為筆者原創&#xff0c;引用請注明來源&#xff0c;不足和錯誤之處請在評論區幫忙指出&#xff0c;謝謝&#xff01; 本專欄目錄結構和參考文獻請見《機器學習數學通關指南》 ima 知識庫 知識庫廣場搜索&#…

程序化廣告行業(3/89):深度剖析行業知識與數據處理實踐

程序化廣告行業&#xff08;3/89&#xff09;&#xff1a;深度剖析行業知識與數據處理實踐 大家好&#xff01;一直以來&#xff0c;我都希望能和各位技術愛好者一起在學習的道路上共同進步&#xff0c;分享知識、交流經驗。今天&#xff0c;咱們聚焦在程序化廣告這個充滿挑戰…

探索在生成擴散模型中基于RAG增強生成的實現與未來

概述 像 Stable Diffusion、Flux 這樣的生成擴散模型&#xff0c;以及 Hunyuan 等視頻模型&#xff0c;都依賴于在單一、資源密集型的訓練過程中通過固定數據集獲取的知識。任何在訓練之后引入的概念——被稱為 知識截止——除非通過 微調 或外部適應技術&#xff08;如 低秩適…

DeepSeek 助力 Vue3 開發:打造絲滑的表格(Table)之添加列寬調整功能,示例Table14基礎固定表頭示例

前言&#xff1a;哈嘍&#xff0c;大家好&#xff0c;今天給大家分享一篇文章&#xff01;并提供具體代碼幫助大家深入理解&#xff0c;徹底掌握&#xff01;創作不易&#xff0c;如果能幫助到大家或者給大家一些靈感和啟發&#xff0c;歡迎收藏關注哦 &#x1f495; 目錄 Deep…