第36周———— RNN實現阿爾茨海默病診斷

目錄

前言

1.檢查GPU

2.查看數據

3.劃分數據集

4.創建模型與編譯訓練

????5.編譯及訓練模型?

6.結果可視化

7.模型預測?

8.總結:

前言

🍨 本文為🔗365天深度學習訓練營中的學習記錄博客
🍖 原作者:K同學啊

1.檢查GPU

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns#設置GPU訓練,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.查看數據

df = pd.read_csv("DATA/alzheimers_disease_data.csv")
# 刪除第一列和最后一列
df = df.iloc[:, 1:-1]
df

3.劃分數據集

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]# 將每一列特征標準化為標準正太分布,注意,標準化是針對每一列而言的
sc = StandardScaler()
X  = sc.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)X_train.shape, y_train.shapefrom torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64, shuffle=False)test_dl  = DataLoader(TensorDataset(X_test, y_test),batch_size=64, shuffle=False)

4.創建模型與編譯訓練

class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200, num_layers=1, batch_first=True)self.fc0   = nn.Linear(200, 50)self.fc1   = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x) out    = self.fc0(out) out    = self.fc1(out) return out   model = model_rnn().to(device)
model

????5.編譯及訓練模型?

# 訓練循環
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 訓練集的大小num_batches = len(dataloader) # 批次數目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0 # 初始化訓練損失和正確率for X, y in dataloader: # 獲取圖片及其標簽X, y = X.to(device), y.to(device)# 計算預測誤差pred = model(X) # 網絡輸出loss = loss_fn(pred, y) # 計算網絡輸出和真實值之間的差距,targets為真實值,計算二者差值即為損失# 反向傳播optimizer.zero_grad() # grad屬性歸零loss.backward() # 反向傳播optimizer.step() # 每一步自動更新# 記錄acc與losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_lossdef test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 測試集的大小num_batches = len(dataloader) # 批次數目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 當不進行訓練時,停止梯度更新,節省計算內存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 計算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_lossloss_fn = nn.CrossEntropyLoss() # 創建損失函數
learn_rate = 1e-4 # 學習率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 獲取當前的學習率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
print("="*20, 'Done', "="*20)

6.結果可視化

import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False # 用來正常顯示負號
plt.rcParams['figure.dpi'] = 100 #分辨率from datetime import datetime
current_time = datetime.now() # 獲取當前時間epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡請帶上時間戳,否則代碼截圖無效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay# 計算混淆矩陣
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")# 修改字體大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)# 顯示圖
plt.tight_layout()  # 調整布局防止重疊
plt.show()

??

?

7.模型預測?

test_X = X_test[0].reshape(1, -1) # X_test[0]即我們的輸入數據pred = model(test_X.to(device)).argmax(1).item()
print("模型預測結果為:",pred)
print("=="*20)
print("0:未患病")
print("1:已患病")

?

?

8.總結:

代碼展示了如何使用PyTorch框架進行阿爾茨海默病數據集的分類任務。以下是該代碼的主要步驟和功能總結:

檢查GPU:首先,代碼檢查是否有可用的GPU,并設置相應的設備(cuda或cpu)。

查看數據:通過Pandas庫加載數據集,并刪除第一列和最后一列,這可能是為了去除非特征信息(如ID)或冗余信息。

劃分數據集:對數據進行預處理,包括標準化以及將數據劃分為訓練集和測試集。接著,使用PyTorch的DataLoader創建數據加載器以便于后續模型訓練時的數據批次處理。

創建模型與編譯訓練:定義了一個基于RNN的神經網絡模型model_rnn,包含RNN層和兩個全連接層。模型被移動到之前設定的設備(GPU或CPU)上。

編譯及訓練模型:定義了訓練和測試函數,分別用于執行模型的訓練過程和評估過程。采用交叉熵損失作為損失函數,Adam優化器作為優化算法。經過30個epoch的訓練后,記錄并打印出每個epoch的訓練和測試準確率及損失值。

結果可視化:使用Matplotlib繪制訓練和測試的準確率與損失的變化曲線圖,直觀地展示模型的學習效果。同時,還生成了混淆矩陣以進一步分析模型性能。

模型預測:最后,選取了一條測試數據進行模型預測,輸出預測結果,并解釋了預測結果的意義(是否患病)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/91773.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/91773.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/91773.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

equals和hashcode方法重寫

在 Java 中,當你需要基于對象的內容而非引用地址來判斷兩個對象是否相等時,就需要重寫equals和hashCode方法。以下是具體場景和實現原則:一、為什么需要同時重寫這兩個方法?equals方法:默認比較對象的內存地址&#xf…

Excel批量生成SQL語句 Excel批量生成SQL腳本 Excel拼接sql

Excel批量生成SQL語句 Excel批量生成SQL腳本 Excel拼接sql一、情境描述在Excel中有標準的格式化數據,如何快速導入到數據庫中呢?有些工具支持Excel導入的,則可以快速導入數據---例如Navicat;如果不支持呢,如果將Excel表…

金和OA C6 DelTemp.aspx 存在XML實體注入漏洞(CVE-2025-7523)

免責聲明 本文檔所述漏洞詳情及復現方法僅限用于合法授權的安全研究和學術教育用途。任何個人或組織不得利用本文內容從事未經許可的滲透測試、網絡攻擊或其他違法行為。 前言:我們建立了一個更多,更全的知識庫。每日追蹤最新的安全漏洞,追中25HW情報。 更多詳情: http…

Android性能優化之啟動優化

一、啟動性能瓶頸深度分析 1. 冷啟動階段耗時分布階段耗時占比關鍵阻塞點進程創建15%fork進程 加載ZygoteApplication初始化40%ContentProvider/庫初始化Activity創建30%布局inflate 視圖渲染首幀繪制15%VSync信號等待 GPU渲染2. 高頻性能問題 初始化風暴:多個庫…

中國優秀開源軟件及企業調研報告

中國優秀開源軟件及企業調研報告 引言 當前中國開源生態呈現蓬勃發展態勢,技術創新領域尤為活躍,其中人工智能大模型成為開源動作的核心聚焦方向。2025年上半年,國內AI領域開源生態迎來密集爆發,頭部科技企業相繼推出重要開源舉…

C++語法 匿名對象 與 命名對象 的詳細區分

目錄一、匿名對象的本質定義二、匿名對象的調用邏輯:即生即用的設計三、與命名對象的核心差異四、匿名對象的典型應用場景五、匿名對象的潛在風險與規避六、總結:匿名對象的價值定位在 C 類與對象的知識體系中,匿名對象是一種容易被咱們忽略&…

【Fedora 42】Linux內核升級后,鼠標滾輪失靈,libinput的鍋?

解決: 最近在玩Fedora 42,升級了一次給俺鼠標滾輪干失靈了。原因可能是 libinput 升級后與Fedora升級后的某些配置有沖突?(搞不懂) sudo dnf downgrade libinput降級 libinput (1.28.901-1.fc42 -> 1.28.0-1.fc42) …

虛擬機centos服務器安裝

創建虛擬機選擇鏡像啟動 移除舊的repo文件: sudo rm -f /etc/yum.repos.d/CentOS-Base.repo下載阿里云的repo文件: 對于CentOS 7: sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo清除緩存并生…

【js(1)一文解決】var let const

var let const!在 ES6 之前,JavaScript 只有兩種作用域: 全局變量 與 函數內的局部變量一、var1. 函數級作用域,有變量提升二、let(ES6新增)1. 塊級作用域,不會影響外部作用域2.let 關鍵字在不同…

論螺旋矩陣

螺旋矩陣題型總結。我刷了幾道螺旋矩陣相關的題目,這里我們介紹一下一些常見的解法。 螺旋矩陣 方形矩陣 當我們遇到n*n的方形矩陣時,可以用一種特殊的解法來遍歷實現,以下面這道題為例: 59. 螺旋矩陣 II 我們可以定義幾個變…

數學金融與金融工程:學科差異與選擇指南

在金融領域的學習中,數學金融與金融工程常被混淆。兩者雖同屬 “金融 量化” 交叉方向,但在研究側重、培養路徑上有顯著區別。結合學科特點與行業實踐,幫大家理清兩者的核心差異,以便更精準地選擇方向。一、核心差異:…

包管理工具npm cnpm yarn的使用

包管理工具 1. 什么是包管理工具? 包管理工具是用于管理和安裝 Node.js 項目依賴的工具。它們提供了一種結構化的方式來管理項目的依賴關系,使得項目的依賴管理變得更加便捷和可靠。 2. 常見的包管理工具有哪些? npm(Node Package Manager):是 Node.js 的默認包管理工…

網絡基礎13--鏈路聚合技術

一、鏈路聚合概述定義將多條物理鏈路捆綁為一條邏輯鏈路,提升帶寬與可靠性。2. 應用場景交換機/路由器/服務器之間的互聯,支持二層(數據鏈路層)和三層(網絡層)聚合。二、核心作用增加帶寬聚合鏈路的總帶寬 …

一文講清楚React性能優化

文章目錄一文講清楚React性能優化1. React性能優化概述2. React性能優化2.1 render優化2.2 較少使用內聯函數2.3 使用React Fragments避免額外標記2.4 使用Immutable上代碼2.5 組件懶加載2.6 服務端渲染2.7 其他優化手段一文講清楚React性能優化 1. React性能優化概述 React通…

3.0 - 指針-序列化

一、關于Serialize的使用 可以使用該指令臨時將用戶程序的多個結構化數據項保存到緩沖區中(最好位于全局數據塊中)。用于保存轉換后數據的存儲區的數據類型必需為 ARRAY of BYTE 或 ARRAY of CHAR 相當于把一個struct或其他自定義類型變成一個字節數組。 比如我有好幾個結構體…

【論文精讀】基于共識的分布式量子分解算法用于考慮最優傳輸線切換的安全約束機組組合

本次分析的論文《Consensus‐Based Distributed Quantum Decomposition Algorithm for Security‐Constrained Unit Commitment Considering Optimal Transmission Switching》于2025年6月25日在《Advanced Quantum Technologies》期刊上公開發表。本文提出了一個新的基于共識的…

MyBatis-Flex代碼生成

引入依賴 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId> </dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok<…

知網論文批量下載pdf格式論文,油猴腳本

任務描述 今天收到一個任務&#xff0c;在知網上&#xff0c;把一位專家所有的論文全都下載下來&#xff0c;要保存為PDF格式。 知網不支持批量導出PDF格式論文。一個一個下載PDF&#xff0c;太繁瑣了。 解決方案&#xff1a;找到一個油猴腳本&#xff0c;這個腳本可以從知網…

低代碼平臺:驅動項目管理敏捷開發新范式

隨著企業數字化轉型加速&#xff0c;項目管理系統已從單一任務跟蹤工具到集成流程自動化、資源調度、跨團隊協作與風險監控的綜合平臺&#xff0c;項目管理系統的功能復雜度持續提升。然而&#xff0c;根據Gartner 2024年研究報告顯示&#xff0c;約60%的項目管理系統因未能有效…

圖機器學習(11)——鏈接預測

圖機器學習&#xff08;11&#xff09;——鏈接預測0. 鏈接預測1. 基于相似性的方法1.1 基于指標的方法1.2 基于社區的方法2. 基于嵌入的方法0. 鏈接預測 鏈接預測 (link prediction)&#xff0c;也稱為圖補全&#xff0c;是處理圖時常見的問題。具體而言&#xff0c;給定一個…