【Andrej Karpathy 神經網絡從Zero到Hero】--2.語言模型的兩種實現方式 (Bigram 和 神經網絡)

目錄

  • 統計 Bigram 語言模型
    • 質量評價方法
  • 神經網絡語言模型

【系列筆記】
【Andrej Karpathy 神經網絡從Zero到Hero】–1. 自動微分autograd實踐要點

本文主要參考 大神Andrej Karpathy 大模型講座 | 構建makemore 系列之一:講解語言建模的明確入門,演示

  1. 如何利用統計數值構建一個簡單的 Bigram 語言模型
  2. 如何用一個神經網絡來復現前面 Bigram 語言模型的結果,以此來展示神經網絡相對于傳統 n-gram 模型的拓展性。

統計 Bigram 語言模型

首先給定一批數據,每個數據是一個英文名字,例如:

['emma','olivia','ava','isabella','sophia','charlotte','mia','amelia','harper','evelyn']

Bigram語言模型的做法很簡單,首先將數據中的英文名字都做成一個個bigram的數據

其中每個格子中是對應的二元組,eg: “rh” ,在所有數據中出現的次數。那么一個自然的想法是對于給定的字母,取其對應的行,將次數歸一化轉成概率值,然后根據概率分布抽取下一個可能的字母:

g = torch.Generator().manual_seed(2147483647)
P = N.float() # N 即為上述 counts 矩陣
P = P / P.sum(1, keepdims=True) # P是每行歸一化后的概率值for i in range(5):out = []ix = 0  ## start符和end符都用 id=0 表示,這里是startwhile True:p = P[ix] # 當前字符為 ix 時,預測下一個字符的概率分布,實質是一個多項分布(即可能抽到的值有多個,eg: 擲色子是六項分布)ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()out.append(itos[ix])if ix == 0: ## 當運行到end符,停止生成breakprint(''.join(out))

輸出類似于:

mor.
axx.
minaymoryles.
kondlaisah.
anchshizarie.

質量評價方法

我們還需要方法來評估語言模型的質量,一個直觀的想法是:
P ( s 1 s 2 . . . s n ) = P ( s 1 ) P ( s 2 ∣ s 1 ) ? P ( s n ∣ s n ? 1 ) P(s_1s_2...s_n) = P(s_1)P(s_2|s_1)\cdots P(s_n|s_{n-1}) P(s1?s2?...sn?)=P(s1?)P(s2?s1?)?P(sn?sn?1?)
但上述計算方式有一個問題,概率值都是小于1的,當序列的長度比較長時,上述數值會趨于0,計算時容易下溢。因此實踐中往往使用 l o g ( P ) log(P) log(P)來代替,為了可以對比不同長度的序列的預測效果,再進一步使用 l o g ( P ) / n log(P)/n log(P)/n 表示一個序列平均的質量

上述統計 Bigram 模型在訓練數據上的平均質量為:

log_likelihood = 0.0
n = 0for w in words: # 所有word里的二元組概率疊加chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]prob = P[ix1, ix2]logprob = torch.log(prob)log_likelihood += logprobn += 1 # 所有word里的二元組數量之和nll = -log_likelihood
print(f'{nll/n}') ## 值為 2.4764,表示前面做的bigram模型,對現有訓練數據的置信度## 這個值越低表示當前模型越認可訓練數據的質量,而由于訓練數據是我們認為“好”的數據,因此反過來就說明這個模型好

但這里有一個問題是,例如:

log_likelihood = 0.0
n = 0#for w in words:
for w in ["andrejz"]:chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]prob = P[ix1, ix2]logprob = torch.log(prob)log_likelihood += logprobn += 1print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')

輸出是

.a: 0.1377 -1.9829
an: 0.1605 -1.8296
nd: 0.0384 -3.2594
dr: 0.0771 -2.5620
re: 0.1336 -2.0127
ej: 0.0027 -5.9171
jz: 0.0000 -inf
z.: 0.0667 -2.7072
log_likelihood=tensor(-inf)
nll=tensor(inf)
inf

可以發現由于,jz 在計數矩陣 N 中為0,即數據中沒有出現過,導致 log(loss) 變成了負無窮,這里為了避免這樣的情況,需要做 平滑處理,即 P = N.float() 改成 P = (N+1).float(),這樣上述代碼輸出變成:

.a: 0.1376 -1.9835
an: 0.1604 -1.8302
nd: 0.0384 -3.2594
dr: 0.0770 -2.5646
re: 0.1334 -2.0143
ej: 0.0027 -5.9004
jz: 0.0003 -7.9817
z.: 0.0664 -2.7122
log_likelihood=tensor(-28.2463)
nll=tensor(28.2463)
3.5307815074920654

避免了出現 inf 這種數據溢出問題。


神經網絡語言模型

接下來嘗試用神經網絡的方式構建上述bigram語言模型:

# 構建訓練數據
xs, ys = [], [] # 分別是前一個字符和要預測的下一個字符的id
for w in words[:5]:chs = ['.'] + list(w) + ['.']for ch1, ch2 in zip(chs, chs[1:]):ix1 = stoi[ch1]ix2 = stoi[ch2]print(ch1, ch2)xs.append(ix1)ys.append(ix2)    xs = torch.tensor(xs)
ys = torch.tensor(ys)
# 輸出示例:. e
#          e m
#          m m
#          m a
#          a .
#       xs: tensor([ 0,  5, 13, 13,  1])
#       ys: tensor([ 5, 13, 13,  1,  0])# 隨機初始化一個 27*27 的參數矩陣
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True) # 基于正態分布隨機初始化
# 前向傳播
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=27).float() # 將輸入數據xs做成one-hot embedding
logits = xenc @ W # 用于模擬統計模型中的統計數值矩陣,由于 W 是基于正態分布采樣,logits 并非直接是計數值,可以認為是 log(counts)
## tensor([[-0.5288, -0.5967, -0.7431,  ...,  0.5990, -1.5881,  1.1731],
##        [-0.3065, -0.1569, -0.8672,  ...,  0.0821,  0.0672, -0.3943],
##        [ 0.4942,  1.5439, -0.2300,  ..., -2.0636, -0.8923, -1.6962],
##        ...,
##        [-0.1936, -0.2342,  0.5450,  ..., -0.0578,  0.7762,  1.9665],
##        [-0.4965, -1.5579,  2.6435,  ...,  0.9274,  0.3591, -0.3198],
##        [ 1.5803, -1.1465, -1.2724,  ...,  0.8207,  0.0131,  0.4530]])
counts = logits.exp() # 將 log(counts) 還原成可以看作是 counts 的矩陣
## tensor([[ 0.5893,  0.5507,  0.4756,  ...,  1.8203,  0.2043,  3.2321],
##        [ 0.7360,  0.8548,  0.4201,  ...,  1.0856,  1.0695,  0.6741],
##        [ 1.6391,  4.6828,  0.7945,  ...,  0.1270,  0.4097,  0.1834],
##        ...,
##        [ 0.8240,  0.7912,  1.7245,  ...,  0.9438,  2.1732,  7.1459],
##        [ 0.6086,  0.2106, 14.0621,  ...,  2.5279,  1.4320,  0.7263],
##        [ 4.8566,  0.3177,  0.2802,  ...,  2.2722,  1.0132,  1.5730]])
probs = counts / counts.sum(1, keepdims=True) # 用于模擬統計模型中的概率矩陣,這其實即是 softmax 的實現
loss = -probs[torch.arange(5), ys].log().mean() # loss = log(P)/n, 這其實即是 cross-entropy 的實現

接下來可以通過loss.backward()來更新參數 W:

for k in range(100):# forward passxenc = F.one_hot(xs, num_classes=27).float() logits = xenc @ W # predict log-countscounts = logits.exp()probs = counts / counts.sum(1, keepdims=True) loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() ## 這里加上了L2正則,防止過擬合print(loss.item())# backward passW.grad = None # 每次反向傳播前置為Noneloss.backward()# updateW.data += -50 * W.grad  

注意這里 logits = xenc @ W 由于 xenc 是 one-hot 向量,因此這里 logits 相當于是抽出了 W 中的某一行,而結合 bigram 模型中,loss 實際上是在計算實際的 log(P[x_i, y_i]),那么可以認為這里 W 其實是在擬合 bigram 中的計數矩陣 N(不過實際是 logW 在擬合 N)

另外上述神經網絡的 loss 最終也是達到差不多 2.47 的最低 loss。這是合理的,因為從上面的分析可知,這個神經網絡是完全在擬合 bigram 計數矩陣的,沒有使用更復雜的特征提取方法,因此效果最終也會差不多。

這里 loss 中還加了一個 L2 正則,主要目的是壓縮 W,使得它向全 0 靠近,這里的效果非常類似于 bigram 中的平滑手段,想象給一個極大的平滑:P = (N+10000).float()`,那么 P 會趨于一個均勻分布,而 W 全為 0 會導致 counts = logits.exp() 全為 1,即也在擬合一個均勻分布。這里前面的參數 0.01 即是用來調整平滑強度的,如果這個給的太大,那么平滑太大了,就會學成一個均勻分布(當然實際不會希望這樣,所以不會給很大)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/897345.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/897345.shtml
英文地址,請注明出處:http://en.pswp.cn/news/897345.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

(二 十 二)趣學設計模式 之 備忘錄模式!

目錄 一、 啥是備忘錄模式?二、 為什么要用備忘錄模式?三、 備忘錄模式的實現方式四、 備忘錄模式的優缺點五、 備忘錄模式的應用場景六、 總結 🌟我的其他文章也講解的比較有趣😁,如果喜歡博主的講解方式,…

安裝SPSS后啟動顯示應用程序無法啟動,因為應用程序的并行配置不正確的解決方案

軟件安裝報錯問題有需要遠程文章末尾獲取聯系方式,可以幫你遠程處理各類安裝報錯。 一、安裝SPSS后啟動顯示應用程序無法啟動,因為應用程序的并行配置不正確報錯 在成功安裝 SPSS 軟件后,嘗試啟動應用程序時,系統彈出錯誤提示窗…

IP,MAC,ARP 筆記

1.什么是IP地址 IP 地址是一串由句點分隔的數字。IP 地址表示為一組四個數字,比如 192.158.1.38 就是一個例子。該組合中的每個數字都可以在 0 到 255 的范圍內。因此,完整的 IP 尋址范圍從 0.0.0.0 到 255.255.255.255。 IP 地址不是隨機的。它們由互…

C++11中的Condition_variable

C11中的condition_variable 在C11中,條件變量(std::condition_variable)是線程同步機制之一,用于在多線程環境中實現線程間的通信和協調。它允許一個或多個線程在某個條件尚未滿足時等待,直到其他線程通知條件已經滿足…

IO多路復用實現并發服務器

一.select函數 select 的調用注意事項 在使用 select 函數時,需要注意以下幾個關鍵點: 1. 參數的修改與拷貝 readfds 等參數是結果參數 : select 函數會直接修改傳入的 fd_set(如 readfds、writefds 和 exceptfds&#xf…

_二級繼電器程控放大倍數自動設置

簡介 在開發項目中,有時會遇到需要使用程控放大的情況,如果沒有opa那種可編程放大器,那么就需要通過繼電器來控制放大倍數。而在繼電器程控中,常用的是二級程控,三級程控相較于二級就復雜了許多。 在二級程控中&#x…

電腦總顯示串口正在被占用處理方法

1.現象 在嵌入式開發過程中,有很多情況下要使用串口調試,其中485/422/232轉usb串口是非常常見的做法。 根據協議,接口芯片不同,需要安裝對應的驅動程序,比如ch340,cp2102,CDM212364等驅動。可…

優雅拼接字符串:StringJoiner 的完整指南

在Java開發中,字符串拼接是高頻操作。無論是日志格式化、構建CSV數據,還是生成動態SQL,開發者常需處理分隔符、前綴和后綴的組合。傳統的StringBuilder雖然靈活,但代碼冗余且易出錯。Java 8推出的StringJoiner類,以簡潔…

LabVIEW閉環控制系統硬件選型與實時性能

在LabVIEW閉環控制系統的開發中,硬件選型直接影響系統的實時性、精度與穩定性。需綜合考慮數據采集速度(采樣率、接口帶寬)、計算延遲(算法復雜度、處理器性能)、輸出響應時間(執行器延遲、控制周期&#x…

Hive的架構

1. 概念 Hive 是建立在 Hadoop 上的數據倉庫工具,旨在簡化大規模數據集的查詢與管理。它通過類 SQL 語言(HiveQL)將結構化數據映射為 Hadoop 的 MapReduce,適合離線批處理,尤其適用于數據倉庫場景。 2. 數據模型 表&a…

深入解析:Linux中KVM虛擬化技術

這篇文章將深入分析Linux中虛擬化技術的實現----KVM技術,從KVM技術的簡介、技術架構、以及虛擬機和宿主機交互的重要處理邏輯出發,深入探究KVM技術的實現。 一、KVM簡介: 首先,我們先查看一下KVM架構,看看它的整體工…

golang學習筆記——go語言安裝及系統環境變量設置

文章目錄 go語言安裝go envgo getgoproxy測試安裝 Go 插件安裝 Go 插件依賴工具參考資料用戶環境變量和系統環境變量用戶環境變量系統環境變量示例設置環境變量的步驟設置用戶環境變量設置系統環境變量 驗證環境變量總結 2024年最火的5大Go框架1. Gin:高并發接口的“…

3.6c語言

#define _CRT_SECURE_NO_WARNINGS #include <math.h> #include <stdio.h> int main() {int sum 0,i,j;for (j 1; j < 1000; j){sum 0;for (i 1; i < j; i){if (j % i 0){sum i;} }if (sum j){printf("%d是完數\n", j);}}return 0; }#de…

【TI】如何更改 CCS20.1.0 的 WORKSPACE 默認路徑

參考鏈接&#xff1a; 如何更改 CCS Theia 中工作區的默認位置&#xff1f;- Code Composer Studio 論壇 - Code Composer Studio?? - TI E2E 支持論壇 --- How to change the default location for the workspace in CCS Theia? - Code Composer Studio forum - Code Comp…

Vue3中動態Ref的魔法:綁定與妙用

前言 在Vue 3的開發過程中,動態綁定Ref是一項非常實用的技術,特別是在處理復雜組件結構和動態數據時。通過動態綁定Ref,我們可以更靈活地訪問和操作DOM元素或組件實例,實現更高效的交互和狀態管理。本文將詳細介紹如何在Vue 3中實現動態Ref的綁定,并通過實例展示其妙用。…

CarPlanner:用于自動駕駛大規模強化學習的一致性自回歸軌跡規劃

25年2月來自浙大和菜鳥網絡的論文“CarPlanner: Consistent Auto-regressive Trajectory Planning for Large-scale Reinforcement Learning in Autonomous Driving”。 軌跡規劃對于自動駕駛至關重要&#xff0c;可確保在復雜環境中安全高效地導航。雖然最近基于學習的方法&a…

VS Code連接服務器教程

VS Code是什么 VS Code&#xff08;全稱 Visual Studio Code&#xff09;是一款由微軟推出的免費、開源、跨平臺的代碼編輯神器。VS Code 支持 所有主流操作系統&#xff0c;擁有強大的功能和靈活的擴展性。 官網&#xff1a;https://code.visualstudio.com/插件市場&#xff1…

【JavaWeb】Web基礎概念

文章目錄 1、服務器與客戶端2、服務器端應用程序3、請求和響應4、項目的邏輯構成5、架構5.1 概念5.2 發展演變歷程單一架構分布式架構 5.3 單一架構技術體系 6、本階段技術體系 1、服務器與客戶端 ①線下的服務器與客戶端 ②線上的服務器與客戶端 2、服務器端應用程序 我…

安徽省考計算機專業科目2025(持續更新)

目錄 第一部分 計算機科學技術基礎 第一章 計算機及其應用基礎知識 1.1 計算機的特點、分類及其應用 1.2 信息編碼與數據表示&#xff1b;數制及其轉換方法&#xff1b;算術運算和邏輯運算的過程 第一部分 計算機科學技術基礎 第一章 計算機及其應用基礎知識 1.1 計算機…

前端知識點---路由模式-實例模式和單例模式(ts)

在 ArkTS&#xff08;Ark UI 框架&#xff09;中&#xff0c;路由實例模式&#xff08;Standard Instance Mode&#xff09;主要用于管理頁面跳轉。當創建一個新頁面時&#xff0c;可以選擇標準實例模式&#xff08;Standard Mode&#xff09;或單實例模式&#xff08;Single M…