PyTorch RNN 名字分類器

PyTorch RNN 名字分類器詳解

使用PyTorch實現的字符級RNN(循環神經網絡)項目,用于根據人名預測其所屬的語言/國家。該模型通過學習不同語言名字的字符模式,夠識別名字的語言起源。

環境設置

import torch
import string
import unicodedata
import glob
import os
import time
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

1. 數據預處理

1.1 字符編碼處理

# 定義允許的字符集(ASCII字母 + 標點符號 + 占位符)
allowed_characters = string.ascii_letters + " .,;'" + "_"
n_letters = len(allowed_characters)  # 58個字符def unicodeToAscii(s):"""將Unicode字符串轉換為ASCII"""return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn' and c in allowed_characters)

關鍵點:

  • 使用One-hot編碼表示每個字符
  • 將非ASCII字符規范化(如 ‘?lusàrski’ → ‘Slusarski’)
  • 未知字符用 “_” 表示

1.2 張量轉換

def letterToIndex(letter):"""將字母轉換為索引"""if letter not in allowed_characters:return allowed_characters.find("_")return allowed_characters.find(letter)def lineToTensor(line):"""將名字轉換為張量 <line_length x 1 x n_letters>"""tensor = torch.zeros(len(line), 1, n_letters)for li, letter in enumerate(line):tensor[li][0][letterToIndex(letter)] = 1return tensor

張量維度說明:

  • 每個名字表示為3D張量:[序列長度, 批次大小=1, 字符數=58]
  • 使用One-hot編碼:每個字符位置只有一個1,其余為0

2. 數據集構建

2.1 自定義Dataset類

class NamesDataset(Dataset):def __init__(self, data_dir):self.data = []           # 原始名字self.data_tensors = []   # 名字的張量表示self.labels = []         # 語言標簽self.labels_tensors = [] # 標簽的張量表示# 讀取所有.txt文件(每個文件代表一種語言)text_files = glob.glob(os.path.join(data_dir, '*.txt'))for filename in text_files:label = os.path.splitext(os.path.basename(filename))[0]lines = open(filename, encoding='utf-8').read().strip().split('\n')for name in lines:self.data.append(name)self.data_tensors.append(lineToTensor(name))self.labels.append(label)

2.2 數據集劃分

# 85/15 訓練/測試集劃分
train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024)
)

3. RNN模型架構

3.1 模型定義

class CharRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(CharRNN, self).__init__()# RNN層:輸入大小 → 隱藏層大小self.rnn = nn.RNN(input_size, hidden_size)# 輸出層:隱藏層 → 輸出類別self.h2o = nn.Linear(hidden_size, output_size)# LogSoftmax用于分類self.softmax = nn.LogSoftmax(dim=1)def forward(self, line_tensor):rnn_out, hidden = self.rnn(line_tensor)output = self.h2o(hidden[0])output = self.softmax(output)return output

模型參數:

  • 輸入大小:58(字符數)
  • 隱藏層大小:128
  • 輸出大小:18(語言類別數)

4. 訓練過程

4.1 訓練函數

def train(rnn, training_data, n_epoch=10, n_batch_size=64, learning_rate=0.2, criterion=nn.NLLLoss()):rnn.train()optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)for iter in range(1, n_epoch + 1):# 創建小批量batches = list(range(len(training_data)))random.shuffle(batches)batches = np.array_split(batches, len(batches)//n_batch_size)for batch in batches:batch_loss = 0for i in batch:label_tensor, text_tensor, label, text = training_data[i]output = rnn.forward(text_tensor)loss = criterion(output, label_tensor)batch_loss += loss# 反向傳播和優化batch_loss.backward()nn.utils.clip_grad_norm_(rnn.parameters(), 3)  # 梯度裁剪optimizer.step()optimizer.zero_grad()

訓練技巧:

  • 使用SGD優化器,學習率0.15
  • 梯度裁剪防止梯度爆炸
  • 批量大小:64

5. 模型評估

5.1 混淆矩陣可視化

def evaluate(rnn, testing_data, classes):confusion = torch.zeros(len(classes), len(classes))rnn.eval()with torch.no_grad():for i in range(len(testing_data)):label_tensor, text_tensor, label, text = testing_data[i]output = rnn(text_tensor)guess, guess_i = label_from_output(output, classes)label_i = classes.index(label)confusion[label_i][guess_i] += 1# 歸一化并可視化# ...

6. 訓練結果

  • 訓練樣本數:17,063
  • 測試樣本數:3,011
  • 訓練輪數:27
  • 最終損失:約0.43

損失曲線顯示模型收斂良好,從初始的0.88降至0.43。

在這里插入圖片描述
在這里插入圖片描述
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/917941.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/917941.shtml
英文地址,請注明出處:http://en.pswp.cn/news/917941.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

面向對象之類方法,成員變量和局部變量

1.類的方法必須包含幾個部分&#xff1f;2.成員變量和局部變量類的方法必須包含哪幾個部分&#xff1f;.方法名&#xff1a;用于標識方法的名稱&#xff0c;遵循標識符命名規則&#xff0c;通常采用駝峰命名法。返回值類型&#xff1a;指定方法返回的數據類型。如果方法不返回任…

古法筆記 | 通過查表進行ASCII字符編碼轉換

ASCII字符集是比較早期的一種字符編碼&#xff0c;只能表示英文字符&#xff0c;最多能表示128個字符。 字符集規定了每個字符和二進制數之間的對應關系&#xff0c;可以通過查表完成二進制數到字符的轉換ASCII字符占用的存儲空間是定長的1字節 ASCII字符的官方碼點表見下圖&…

Linux C實現單生產者多消費者環形緩沖區

使用C11里的原子變量實現&#xff0c;沒有用互斥鎖&#xff0c;效率更高。ring_buffer.h:/*** file ring_buffer.h* author tl* brief 單生產者多消費者環形緩沖區&#xff0c;每條數據被所有消費者讀后才釋放。讀線程安全&#xff0c;寫僅單線程。* version* date 2025-08-06*…

復雜場景識別率↑31%!陌訊多模態融合算法在智慧環衛的實戰解析

摘要&#xff1a;針對邊緣計算優化的垃圾堆放識別場景&#xff0c;本文解析了基于動態決策機制的視覺算法如何提升復雜環境的魯棒性。實測數據顯示在遮擋/光照干擾下&#xff0c;mAP0.5較基線提升28.3%&#xff0c;誤報率降低至行業1/5水平。一、行業痛點&#xff1a;智慧環衛的…

MyBatis-Plus Service 接口:如何在 MyBatis-Plus 中實現業務邏輯層??

全文目錄&#xff1a;開篇語前言1. MyBatis-Plus 的 IService 接口1.1 基本使用示例&#xff1a;創建實體類 User 和 UserService1.2 創建 IService 接口1.3 創建 ServiceImpl 類1.4 典型的數據庫操作方法1.4.1 save()&#xff1a;保存數據1.4.2 remove()&#xff1a;刪除數據1…

[激光原理與應用-168]:光源 - 常見光源的分類、特性及應用場景的詳細解析,涵蓋技術原理、優缺點及典型應用領域

一、半導體光源1. LED光源&#xff08;發光二極管&#xff09;原理&#xff1a;通過半導體PN結的電子-空穴復合發光&#xff0c;波長由材料帶隙決定&#xff08;如GaN發藍光、AlGaInP發紅光&#xff09;。特性&#xff1a;優點&#xff1a;壽命長&#xff08;>5萬小時&#…

Metronic v.7.1.7企業級Web應用前端框架全攻略

本文還有配套的精品資源&#xff0c;點擊獲取 簡介&#xff1a;Metronic是一款專注于構建響應式、高性能企業級Web應用的前端開發框架。最新版本v.7.1.7引入了多種功能和優化&#xff0c;以增強開發效率和用戶體驗。詳細介紹了其核心特性&#xff0c;包括響應式設計、多種模…

鴻蒙開發--Notification Kit(用戶通知服務)

通知是手機系統中很重要的信息展示方式&#xff0c;通知不僅可以展示文字&#xff0c;也可以展示圖片&#xff0c;甚至可以將組件加到通知中&#xff0c;只要用戶不清空&#xff0c;通知的信息可以永久保留在狀態欄上通知的介紹 通知 Notification通知&#xff0c;即在一個應用…

鴻蒙 - 分享功能

文章目錄一、背景二、app發起分享1. 通過分享面板進行分享2. 使用其他應用打開二、處理分享的內容1. module.json5 配置可接收分享2. 解析分享的數據一、背景 在App開發中&#xff0c;分享是常用功能&#xff0c;這里介紹鴻蒙開發中&#xff0c;其他應用分享到自己的app中&…

【Agent 系統設計】基于大語言模型的智能Agent系統

一篇阿里博文引發的思考和探索。基于大語言模型的智能Agent系統 1. 系統核心思想 核心思想是構建一個以大語言模型&#xff08;LLM&#xff09;為“大腦”的智能代理&#xff08;Agent&#xff09;&#xff0c;旨在解決將人類的自然語言指令高效、準確地轉化為機器可執行的自動…

企業級Web框架性能對決:Spring Boot、Django、Node.js與ASP.NET深度測評

企業級Web應用的開發效率與運行性能直接關系到業務的成敗。本文通過構建標準化的待辦事項&#xff08;Todo&#xff09;應用&#xff0c;對四大主流框架——Spring Boot、Django、Node.js和ASP.NET展開全面的性能較量。我們將從底層架構特性出發&#xff0c;結合實測數據與數據…

為什么 `source ~/.bashrc` 在 systemd 或 crontab 中不生效

摘要&#xff1a;你是否遇到過這樣的問題&#xff1a;在終端里運行腳本能正常工作&#xff0c;但用 systemd 或 crontab 自動啟動時卻報錯“命令找不到”、“模塊導入失敗”&#xff1f; 本文將揭示一個深藏在 ~/.bashrc 中的“陷阱”&#xff1a;非交互式 shell 會直接退出&am…

Linux 磁盤中的文件

1.磁盤結構 Linux中的文件加載到內存上之前是放到哪的&#xff1f; 放在磁盤上的文件——>訪問文件&#xff0c;打開它——>找到這個文件——>路徑 但文件是怎樣存儲在磁盤上的 1.1物理結構磁盤可以理解為上百億個小磁鐵&#xff08;如N為1&#xff0c;S為0&#xff0…

【方法】Git本地倉庫的文件夾不顯示紅色感嘆號、綠色對號等圖標

文章目錄前言開始操作winr&#xff0c;輸入regedit&#xff0c;打開注冊表重啟資源管理器前言 這個綠色對號圖標表示本地倉庫和遠程的GitHub倉庫內容保持一致&#xff0c;紅色則是相反咯&#xff0c;給你們瞅一下。 首先這兩個東西你一定要安裝配置好了&#xff0c;安裝順序不…

量化交易與主觀交易:哪種方式更勝一籌?

文章概要 在投資的世界里&#xff0c;量化交易和主觀交易如同冰與火&#xff0c;各自擁有獨特的優勢與挑戰。作為一名投資者&#xff0c;了解這兩種交易方式的差異和各自的優缺點至關重要。本文將從決策依據、執行方式、風險管理等方面深入探討量化交易的精確性與主觀交易的靈活…

【JS】扁平樹數據轉為樹結構

扁平數據轉為最終效果[{"label":"疼遜有限公司","code":"1212","disabled":false,"parentId":"none","children":[{"label":"財務部","code":"34343&quo…

數據結構4-棧、隊列

摘要&#xff1a;本文系統介紹了棧和隊列兩種基礎數據結構。棧采用"先進后出"原則&#xff0c;分為順序棧和鏈式棧&#xff0c;詳細說明了壓棧、出棧等基本操作及其實現方法。隊列遵循"先進先出"規則&#xff0c;同樣分為順序隊列和鏈式隊列&#xff0c;重…

大數據spark、hasdoop 深度學習、機器學習算法的音樂平臺用戶情感分析系統設計與實現

大數據spark、hasdoop 深度學習、機器學習算法的音樂平臺用戶情感分析系統設計與實現

視頻匯聚系統EasyCVR調用設備錄像保活時視頻流不連貫問題解決方案

在使用EasyCVR過程中&#xff0c;有用戶反饋調用設備錄像保活功能時&#xff0c;出現視頻流不連貫的情況。針對這一問題&#xff0c;我們經過排查與測試&#xff0c;整理出如下解決步驟&#xff0c;供開發者參考&#xff1a;具體解決步驟1&#xff09;先調用登錄接口完成鑒權確…

【保姆級喂飯教程】python基于mysql-connector-python的數據庫操作通用封裝類(連接池版)

目錄項目環境一、db_config.py二、mysql_executor.py三、test/main.py在使用mysql-connector-python連接MySQL數據庫的時候&#xff0c;如同Java中的jdbc一般&#xff0c;每條sql需要創建和刪除連接&#xff0c;很自然就想到寫一個抽象方法&#xff0c;但是找了找沒有官方標準的…