深度學習使用Pytorch訓練模型步驟

訓練模型是機器學習和深度學習中的核心過程,旨在通過大量數據學習模型參數,以便模型能夠對新的、未見過的數據做出準確的預測。

訓練模型通常包括以下幾個步驟:

1.數據準備:
收集和處理數據,包括清洗、標準化和歸一化。
將數據分為訓練集、驗證集和測試集。

2.定義模型:
選擇模型架構,例如決策樹、神經網絡等。
初始化模型參數(權重和偏置)。

3.選擇損失函數:
根據任務類型(如分類、回歸)選擇合適的損失函數。

4.選擇優化器:
選擇一個優化算法,如SGD、Adam等,來更新模型參數。

5.前向傳播:
在每次迭代中,將輸入數據通過模型傳遞,計算預測輸出。

6.計算損失:
使用損失函數評估預測輸出與真實標簽之間的差異。

7.反向傳播:
利用自動求導計算損失相對于模型參數的梯度。

8.參數更新:
根據計算出的梯度和優化器的策略更新模型參數。

9.迭代優化:
重復步驟5-8,直到模型在驗證集上的性能不再提升或達到預定的迭代次數。

10.評估和測試:
使用測試集評估模型的最終性能,確保模型沒有過擬合。

11.模型調優:
根據模型在測試集上的表現進行調參,如改變學習率、增加正則化等。

12.部署模型:
將訓練好的模型部署到生產環境中,用于實際的預測任務。


一、PyTorch 數據處理與加載

PyTorch 提供了Dataset 和 DataLoader,幫助管理數據集、批量加載和數據增強等任務。

PyTorch 數據處理與加載:
自定義 Dataset:通過繼承 torch.utils.data.Dataset 來加載自己的數據集。
DataLoader:使用DataLoader 按批次加載數據,支持多線程加載并進行數據打亂。(torch.utils.data.DataLoader

(一)自定義 Dataset

torch.utils.data.Dataset 是一個抽象類,允許你自己的數據源中創建數據集。

使用時需要繼承該類并實現以下兩個方法:
len(self):返回數據集中的樣本數量。
getitem(self, idx):通過索引返回一個樣本。

import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from models.utils import match_seq_lenDATASET_DIR = "D:\EMDKT\datasets\ASSIST2009"class ASSIST2009(Dataset):def __init__(self, seq_len, dataset_dir=DATASET_DIR) -> None:super().__init__()self.dataset_dir = dataset_dirself.dataset_path = os.path.join(self.dataset_dir, "skill_builder_data.csv")# 調用預處理self.q_seqs, self.r_seqs, self.q_list, self.u_list, self.q2idx, \self.u2idx = self.preprocess()self.num_u = self.u_list.shape[0]  # 用戶總數self.num_q = self.q_list.shape[0]  # 題目總數if seq_len:self.q_seqs, self.r_seqs = match_seq_len(self.q_seqs, self.r_seqs, seq_len)self.len = len(self.q_seqs)def __getitem__(self, index):return self.q_seqs[index], self.r_seqs[index]def __len__(self):return self.lendef preprocess(self):# 數據加載與清洗df = pd.read_csv(self.dataset_path, encoding="ISO-8859-15").dropna(subset=["skill_name"]).drop_duplicates(subset=["order_id", "skill_name"]).sort_values(by=["order_id"])u_list = np.unique(df["user_id"].values)q_list = np.unique(df["skill_name"].values)u2idx = {u: idx for idx, u in enumerate(u_list)}q2idx = {q: idx for idx, q in enumerate(q_list)}# 生成序列數據q_seqs = []r_seqs = []for u in u_list:df_u = df[df["user_id"] == u]q_seq = np.array([q2idx[q] for q in df_u["skill_name"]])r_seq = df_u["correct"].valuesq_seqs.append(q_seq)r_seqs.append(r_seq)# 返回結果return q_seqs, r_seqs, q_list, u_list, q2idx, u2idx

(二)使用 DataLoader 加載數據

DataLoader 是 PyTorch 提供的一個重要工具,用于從 Dataset 中按批次(batch)加載數據。
DataLoader 允許批量讀取數據并進行多線程加載,從而提高訓練效率。

from torch.utils.data import DataLoader, random_split
from models.utils import collate_fn# 加載數據集
dataset = ASSIST2009(seq_len)  # seq_len 是序列長度參數# 劃分數據集
train_size = int(len(dataset) * train_ratio)  # train_ratio 是訓練集比例
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size]
)# 創建數據加載器
train_loader = DataLoader(train_dataset,batch_size=batch_size,  # batch_size 是批處理大小shuffle=True,collate_fn=collate_fn  # 使用自定義的collate函數
)test_loader = DataLoader(test_dataset,batch_size=test_size,  # 測試集使用整個測試集作為一個批次shuffle=True,collate_fn=collate_fn  # 使用自定義的collate函數
)

注釋:
batch_size: 每次加載的樣本數量。
shuffle: 是否對數據進行洗牌,通常訓練時需要將數據打亂。

二、模型架構實現

通過繼承 nn.Module 來定義模型

class DKT(Module):def __init__(self, num_q, emb_size, hidden_size):super().__init__()self.num_q = num_qself.emb_size = emb_sizeself.hidden_size = hidden_sizeself.interaction_emb = Embedding(self.num_q * 2, self.emb_size)self.lstm_layer = LSTM(self.emb_size, self.hidden_size, batch_first=True)self.out_layer = Linear(self.hidden_size, self.num_q)self.dropout_layer = Dropout()def forward(self, q, r):'''q: [batch_size, n]r: [batch_size, n]'''x = q + self.num_q * rh, _ = self.lstm_layer(self.interaction_emb(x))y = self.out_layer(h)y = self.dropout_layer(y)y = torch.sigmoid(y)return y# 創建模型實例
model = DKT()

三、訓練配置

(一)初始化模型與設備

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = DKT(dataset.num_q, emb_size=100,hidden_size=100).to(device)

(二)定義損失函數與優化器

損失函數用于衡量預測值與真實值之間的差異。PyTorch 中提供了現成的損失函數。
將使用 SGD(隨機梯度下降) 或 Adam 優化器來最小化損失函數。

(1)損失函數

from torch.nn.functional import binary_cross_entropy
criterion = nn.binary_cross_entropy()

(2)優化器

from torch.optim import SGD, Adamif optimizer == "sgd":opt = SGD(model.parameters(), learning_rate, momentum=0.9)elif optimizer == "adam":opt = Adam(model.parameters(), learning_rate)

(三)訓練模型評估模型

在訓練過程中,將執行以下步驟:
使用輸入數據 X 進行前向傳播,得到預測值。
計算損失(預測值與實際值之間的差異)。
使用反向傳播計算梯度。
更新模型參數(權重和偏置)。

# 訓練模型
num_epochs = 1000  # 訓練 1000 輪
for epoch in range(num_epochs):model.train()  # 設置模型為訓練模式# 前向傳播predictions = model(X)  # 模型輸出預測值loss = criterion(predictions.squeeze(), Y)  # 計算損失# 反向傳播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 計算梯度optimizer.step()  # 更新模型參數# 打印損失if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item():.4f}')

注釋:
optimizer.zero_grad():每次反向傳播前需要清空之前的梯度。
loss.backward():計算梯度。
optimizer.step():更新權重和偏置。

(四)評估模型

訓練完成后,可以通過查看模型的權重和偏置來評估模型的效果

with torch.no_grad():  # 評估時不需要計算梯度predictions = model(X)

(五)訓練循環實現

import os
import numpy as np
import torch
from torch.nn.functional import one_hot, binary_cross_entropy
from sklearn import metricsdef train_dkt_model(model, train_loader, test_loader, num_epochs, optimizer, ckpt_path):"""訓練DKT模型的獨立函數參數:model: 要訓練的DKT模型實例train_loader: 訓練數據加載器test_loader: 測試數據加載器num_epochs: 訓練輪數optimizer: 優化器實例ckpt_path: 模型檢查點保存路徑"""aucs = []       # 存儲每輪測試AUCloss_means = [] # 存儲每輪平均訓練損失max_auc = 0     # 記錄最佳AUC# 開始訓練循環for epoch in range(1, num_epochs + 1):epoch_losses = []  # 存儲當前epoch的訓練損失# 訓練階段model.train()for data in train_loader:# 解包數據q, r, qshft, rshft, mask = data# 前向傳播y_pred = model(q.long(), r.long())y_pred = (y_pred * one_hot(qshft.long(), model.num_q)).sum(-1)# 應用掩碼選擇有效預測valid_pred = torch.masked_select(y_pred, mask)valid_target = torch.masked_select(rshft, mask)# 計算損失loss = binary_cross_entropy(valid_pred, valid_target)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()# 記錄損失epoch_losses.append(loss.detach().cpu().item())# 計算本輪平均訓練損失epoch_loss_mean = np.mean(epoch_losses)loss_means.append(epoch_loss_mean)# 驗證階段model.eval()all_preds = []all_targets = []with torch.no_grad():for data in test_loader:q, r, qshft, rshft, mask = data# 預測并選擇有效結果y_pred = model(q.long(), r.long())y_pred = (y_pred * one_hot(qshft.long(), model.num_q)).sum(-1)valid_pred = torch.masked_select(y_pred, mask).cpu().numpy()valid_target = torch.masked_select(rshft, mask).cpu().numpy()all_preds.extend(valid_pred)all_targets.extend(valid_target)# 計算整體AUCauc = metrics.roc_auc_score(all_targets, all_preds)aucs.append(auc)# 打印訓練信息print(f"Epoch: {epoch}, AUC: {auc:.4f}, Loss Mean: {epoch_loss_mean:.4f}")# 保存最佳模型if auc > max_auc:torch.save(model.state_dict(), os.path.join(ckpt_path, "model.ckpt"))max_auc = aucprint(f"保存最佳模型,AUC = {auc:.4f}")return aucs, loss_means

理解 y = self(q.long(), r.long())

這行代碼是知識追蹤模型的核心,表示將輸入數據傳入模型進行前向傳播。

1. 代碼結構解析

y = self(q.long(), r.long())

self: 指當前模型實例(EM_DKT)

q.long(): 將問題ID序列轉換為長整型(整數類型)

r.long(): 將響應序列(0/1)轉換為長整型

y: 模型輸出(預測概率)

2. 數據流分析

輸入數據
變量 含義 維度 示例
q 問題ID序列 (batch_size, seq_len) [[101, 102, 0], [201, 0, 0]]
r 響應序列 (batch_size, seq_len) [[1, 0, 0], [0, 0, 0]]

3. 模型內部處理(在EM_DKT.forward()中)

步驟1: 交互編碼
x = q + self.num_q * r

目的: 創建唯一的交互ID
邏輯:
-正確響應: ID = q + num_q * 1
-錯誤響應: ID = q + num_q * 0 = q

示例:
問題101正確: 101 + 100 * 1 = 201
問題101錯誤: 101 + 100 * 0= 101

步驟2: 嵌入層
emb = self.interaction_emb(x)

輸入: 交互ID (batch_size, seq_len)
輸出: 嵌入向量 (batch_size, seq_len, emb_size)
作用: 將離散ID映射為連續向量表示

步驟3: XLSTM處理
for t in range(seq_len):x_t = emb[:, t, :]  # 當前時間步h_t, states = self.xlstm(x_t, states)y_t = self.out_layer(h_t)

XLSTM結構:
-7層MLSTM: 處理長期知識狀態
-1層ELSTM: 處理近期動態

狀態傳遞: 每個時間步更新內部狀態

輸出: 每個時間步的隱藏表示 (hidden_size)

步驟4: 輸出層
y = torch.stack(outputs, dim=1)
y = torch.sigmoid(y)

維度變化: (batch_size, seq_len, num_q)
sigmoid激活: 將輸出轉換為概率[0,1]

4. 輸出

輸出 y 的結構[batch_size,seq_len,num_q]
輸出示例 y[0, 2, 101] = 0.85
表示:批次0中,第2個時間步后,學生答對問題101的概率是85%

5. 實際應用場景

訓練時
# 預測下一個問題的正確概率
y_next = (y * one_hot(qshft)).sum(-1)
預測時
# 獲取學生當前知識狀態
current_state = self.xlstm.states# 預測下一個問題
next_q = 105
next_input = create_input(next_q)
next_pred = self(next_input, current_state)

6. 數學表示

模型本質上學習了一個函數:

P ( r t + 1 = 1 ∣ q 1 : t , r 1 : t ) = f ( q 1 : t , r 1 : t ) P(r_{t+1}=1 | q_{1:t}, r_{1:t}) = f(q_{1:t}, r_{1:t}) P(rt+1?=1∣q1:t?,r1:t?)=f(q1:t?,r1:t?)
其中:

q 1 : t q_{1:t} q1:t?: 到時間t為止的問題序列

r 1 : t r_{1:t} r1:t?: 到時間t為止的響應序列

f f f: 由EM_DKT模型參數化的復雜非線性函數

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/86575.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/86575.shtml
英文地址,請注明出處:http://en.pswp.cn/web/86575.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Unity_導航操作(鼠標控制人物移動)_運動動畫

文章目錄 前言一、Navigation 智能導航地圖烘焙1.創建Plan和NavMesh Surface2.智能導航地圖烘焙 二、MouseManager 鼠標控制人物移動1.給場景添加人物,并給人物添加導航組件2.編寫腳本管理鼠標控制3.給人物編寫腳本,訂閱事件(添加方法給Mouse…

6. 接口分布式測試pytest-xdist

pytest-xdist實戰指南:解鎖分布式測試的高效之道 隨著測試規模擴大,執行時間成為瓶頸。本文將帶你深入掌握pytest-xdist插件,利用分布式測試將執行速度提升300%。 一、核心命令解析 加速安裝(國內鏡像) pip install …

預訓練語言模型

預訓練語言模型 1.1Encoder-only PLM ? Transformer結構主要由Encoder、Decoder組成,根據特點引入了ELMo的預訓練思路。 ELMo(Embeddings from Language Models)是一種深度上下文化詞表示方法, 該模型由一個**前向語言模型&…

Altera PCI IP target設計分享

最近調試也有關于使用Altera 家的PCI IP,然后分享一下代碼: 主要實現:主控作為主設備,FPGA作為從設備,主控對FPGA IO讀寫的功能 后續會分享FPGA作為主設備, 從 FPGA通過 memory寫到主控內存,會…

基于機器學習的智能文本分類技術研究與應用

在當今數字化時代,文本數據的爆炸式增長給信息管理和知識發現帶來了巨大的挑戰。從新聞文章、社交媒體帖子到企業文檔和學術論文,海量的文本數據需要高效地分類和管理,以便用戶能夠快速找到所需信息。傳統的文本分類方法主要依賴于人工規則和…

前端項目3-01:登錄頁面

一、效果圖 二、全部代碼 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>碼農魔盒</title><style>.bg{position: fixed;top: 0;left:0;object-fit: cover;width: 100vw;height: 100vh;}.box{width: 950px;he…

Nexus CLI:簡化你的分布式計算貢獻之旅

探索分布式證明網絡的力量&#xff1a;Nexus CLI 項目深入解析 在今天的數字時代&#xff0c;分布式計算和去中心化技術正成為互聯網發展的前沿。Nexus CLI 是一個為 Nexus 網絡提供證明的高性能命令行界面&#xff0c;它不僅在概念上先進&#xff0c;更是在具體實現中為開發者…

IBW 2025: CertiK首席商務官出席,探討AI與Web3融合帶來的安全挑戰

6月26日至27日&#xff0c;全球最大的Web3安全公司CertiK亮相伊斯坦布爾區塊鏈周&#xff08;IBW 2025&#xff09;&#xff0c;首席商務官Jason Jiang出席兩場圓桌論壇&#xff0c;分享了CertiK在AI與Web3融合領域的前沿觀察與安全見解。他與普華永道土耳其網絡安全服務主管Nu…

Vivado 五種仿真類型的區別

Vivado 五種仿真類型的區別 我們還是用“建房子”的例子來類比。您已經有了“建筑藍圖”&#xff08;HLS 生成的 RTL 代碼&#xff09;&#xff0c;現在要把它建成真正的房子&#xff08;FPGA 電路&#xff09;。這五種仿真就是在這個過程中不同階段的“質量檢查”。 1. 行為…

小程序快速獲取url link方法,短信里面快速打開鏈接

獲取小程序鏈接方法 uni.request({url:https://api.weixin.qq.com/cgi-bin/token?grant_typeclient_credential&appidwxxxxxxxxxxxx&secret111111111111111111111111111111111,method:GET,success(res) {console.log(res.data)let d {"path": "/xxx/…

Spring 框架(1-4)

第一章&#xff1a;Spring 框架概述 1.1 Spring 框架的定義與背景 Spring 是一個開源的輕量級 Java 開發框架&#xff0c;于 2003 年由 Rod Johnson 創立&#xff0c;旨在解決企業級應用開發的復雜性。其核心設計思想是面向接口編程和松耦合架構&#xff0c;通過分層設計&…

RabitQ 量化:既省內存又提性能

突破高維向量內存瓶頸:Mlivus Cloud RaBitQ量化技術的工程實踐與調優指南 作為大禹智庫高級研究員,擁有三十余年向量數據庫與AI系統架構經驗的我發現,在當今多模態AI落地的核心場景中,高維向量引發的內存資源消耗問題已成為制約系統規模化部署的“卡脖子”因素。特別是在大…

創客匠人:創始人 IP 打造的得力助手

在當今競爭激烈的商業環境中&#xff0c;創始人 IP 的打造對于企業的發展愈發重要。一個鮮明且具有影響力的創始人 IP&#xff0c;能夠為企業帶來獨特的競爭優勢&#xff0c;提升品牌知名度與美譽度。創客匠人在創始人 IP 打造過程中扮演著不可或缺的角色&#xff0c;為創始人提…

如何為虛擬機上的 Manjaro Linux啟用 VMware 拖放功能

如果你的Manjaro 發行版本是安裝在 VMware Workstation Player 上使用的 &#xff0c;而且希望可以通過拖放功能將文件或文件夾從宿主機復制到客戶端的Manjaro 里面&#xff0c;那么可以按照以下的步驟進行操作&#xff0c;開啟拖放功能。 在 VMware 虛擬機上安裝 Manjaro 后&…

【C/C++】單元測試實戰:Stub與Mock框架解析

C 單元測試中的 Stub/Mock 框架詳解 在單元測試中&#xff0c;Stub&#xff08;打樁&#xff09;和Mock都是替代真實依賴以簡化測試的技術。通常&#xff0c;Stub&#xff08;或 Fake&#xff09;提供了一個簡化實現&#xff0c;用于替代生產代碼中的真實對象&#xff08;例如…

工廠 + 策略設計模式(實戰教程)

在軟件開發中&#xff0c;設計模式是解決特定問題的通用方案&#xff0c;而工廠模式與策略模式的結合使用&#xff0c;能在特定業務場景下發揮強大的威力。本文將基于新增題目&#xff08;題目類型有單選、多選、判斷、解答&#xff09;這一業務場景&#xff0c;詳細闡述如何運…

Nuxt3中使用 Ant-Design-Vue 的BackTop 組件實現自動返回頁面頂部

在現代 Web 應用中&#xff0c;提供一個方便用戶返回頁面頂部的功能是非常重要的。Ant Design Vue 提供了 BackTop 組件&#xff0c;可以輕松實現這一功能。本文將詳細介紹如何在 Nuxt 3 項目中使用 <a-back-top/> 組件&#xff0c;并通過按需引入的方式加載組件及其樣式…

在統信UOS(Linux)中構建SQLite3桌面應用筆記

目錄 1 下載lazarus 2 下載sqlite3源碼編譯生成庫文件 3 新建項目 4 設置并編譯 一次極簡單的測試&#xff0c;記錄一下。 操作系統&#xff1a;統信UOS&#xff0c; 內核&#xff1a;4.19.0-arm64-desktop 處理器&#xff1a;D3000 整個流程難點是生成so庫文件并正確加…

Host ‘db01‘ is not allowed to connect to this MariaDB server 怎么解決?

出現錯誤 ERROR 1130 (HY000): Host db01 is not allowed to connect to this MariaDB server&#xff0c;表示當前用戶 test 沒有足夠的權限從主機 db01 連接到 MariaDB 服務器。以下是逐步解決方案&#xff1a; 1. 檢查用戶權限 登錄 MariaDB 服務器&#xff08;需本地或通過…

打造高可用的大模型推理服務:基于 DeepSeek 的企業級部署實戰

&#x1f4dd;個人主頁&#x1f339;&#xff1a;一ge科研小菜雞-CSDN博客 &#x1f339;&#x1f339;期待您的關注 &#x1f339;&#x1f339; 一、引言&#xff1a;從“能部署”到“可用、好用、能擴展” 近年來&#xff0c;隨著 DeepSeek、Qwen、Yi 等開源大模型的持續發…