【PYG】Cora數據集分類任務計算損失,cross_entropy為什么不能直接替換成mse_loss

  • cross_entropy計算誤差方式,輸入向量z為[1,2,3],預測y為[1],選擇數為2,計算出一大坨e的式子為3.405,再用-2+3.405計算得到1.405
  • MSE計算誤差方式,輸入z為[1,2,3],預測向量應該是[1,0,0],和輸入向量維度相同

在這里插入圖片描述
將cross_entropy直接替換成mse_loss報錯RuntimeError: The size of tensor a (7) must match the size of tensor b (140) at non-singleton dimension 1

cross_entropy 換成 mse_loss 會報錯的原因是,這兩個損失函數的輸入和輸出形狀要求不同。cross_entropy 是一個分類損失函數,它期望輸入是未歸一化的logits(形狀為 [batch_size, num_classes]),而標簽是整數類別(形狀為 [batch_size])。mse_loss 是一個回歸損失函數,它期望輸入和標簽的形狀相同。

如果你想使用 mse_loss 來替代 cross_entropy,你需要對標簽進行one-hot編碼,使它們與模型的輸出形狀匹配。下面是如何修改代碼以使用 mse_loss 的示例:

修改代碼以使用 mse_loss

  1. 加載必要的庫
    你需要一個工具來將標簽轉換為one-hot編碼。這里我們使用 torch.nn.functional.one_hot

  2. 修改訓練函數
    在訓練函數中,將標簽轉換為one-hot編碼,然后計算 mse_loss

核心測試代碼講解

out=model(data)模型輸出形狀為torch.Size([140, 7])
data.y中測試數據輸出形狀為torch.Size([140]),打印第一個數據為3,7個類別中的第4個類別
將3轉化為7位置獨熱碼計算MSE,對應train_labels_one_hot第一個數據[0., 0., 0., 1., 0., 0., 0.]為4
out形狀為torch.Size([140, 7]),train_labels_one_hot的形狀為[140, 7]

torch.Size([140, 7]) torch.Size([140])
tensor([-0.0166,  0.0191, -0.0036, -0.0053, -0.0160,  0.0071, -0.0042],device='cuda:0', grad_fn=<SelectBackward0>) tensor(3, device='cuda:0')
tensor([[0., 0., 0., 1., 0., 0., 0.],...[0., 1., 0., 0., 0., 0., 0.]], device='cuda:0')
train_labels_one_hot shape torch.Size([140, 7])
test out torch.Size([2708, 7])
train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
print(out[data.train_mask].shape, data.y[data.train_mask].shape)
print(out[data.train_mask][0], data.y[data.train_mask][0])
print(train_labels_one_hot)
print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)

解釋

  1. 加載庫:我們使用 torch.nn.functional.one_hot 將標簽轉換為one-hot編碼。
  2. 修改訓練函數
    • 將標簽 train_labels 轉換為one-hot編碼,train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()
    • 使用 mse_loss 計算均方誤差損失 loss = F.mse_loss(train_out, train_labels_one_hot)
  3. 保持評估函數不變:評估函數仍然使用 argmax 提取預測類別,并計算準確性。

魔改完整代碼

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加載Cora數據集
dataset = Planetoid(root='/tmp/Cora', name='Cora',  transform=NormalizeFeatures())
data = dataset[0]# 定義GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return x# return F.log_softmax(x, dim=1)# 初始化模型和優化器
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')# 打印歸一化后的特征
print(data.x[0])print(f"data.train_mask{data.train_mask}")# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data)# print(f"out[data.train_mask] {data.train_mask.shape} {out[data.train_mask].shape} {out[data.train_mask]}")# loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()print(out[data.train_mask].shape, data.y[data.train_mask].shape)print(out[data.train_mask][0], data.y[data.train_mask][0])print(train_labels_one_hot)print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)loss.backward()optimizer.step()return loss.item()# 評估模型
def test():model.eval()out = model(data)print(f"test out {out.shape}")print(f"test out[0] {out[0].shape} {out[0]}")print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")pred = out.argmax(dim=1)print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accfor epoch in range(1):loss = train()acc = test()print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

原始代碼

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加載Cora數據集,并應用NormalizeFeatures變換
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]# 計算訓練、驗證和測試集的大小
num_train = data.train_mask.sum().item()
num_val = data.val_mask.sum().item()
num_test = data.test_mask.sum().item()print(f'Number of training nodes: {num_train}')
print(f'Number of validation nodes: {num_val}')
print(f'Number of test nodes: {num_test}')# 定義GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return x  # 返回未歸一化的logits# 初始化模型和優化器
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')# 訓練模型
def train():model.train()optimizer.zero_grad()out = model(data)  # out 的形狀是 [num_nodes, num_classes]train_out = out[data.train_mask]  # 選擇訓練集節點的輸出train_labels = data.y[data.train_mask]  # 選擇訓練集節點的標簽# 將標簽轉換為one-hot編碼train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()# 計算均方誤差損失loss = F.mse_loss(train_out, train_labels_one_hot)loss.backward()optimizer.step()return loss.item()# 評估模型
def test():model.eval()out = model(data)pred = out.argmax(dim=1)  # 提取預測類別correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accfor epoch in range(200):loss = train()acc = test()print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

通過這些修改,你可以將交叉熵損失函數替換為均方誤差損失函數,并確保輸入和標簽的形狀匹配,從而避免報錯。

  • 簡單版本的的答案

Cross Entropy vs. MSE Loss

  1. Cross Entropy Loss:

    • 輸入:模型的logits,形狀為 ([N, C]),其中 (N) 是批次大小,(C) 是類別數量。
    • 目標:目標類別的索引,形狀為 ([N])。
  2. MSE Loss:

    • 輸入:模型的預測值,形狀為 ([N, C])。
    • 目標:實際值,形狀為 ([N, C])(通常是 one-hot 編碼)。

要將 cross_entropy 換成 mse_loss,需要確保輸入和目標的形狀匹配。具體來說,你需要將目標類別索引轉換為 one-hot 編碼。

示例代碼

假設你有一個分類任務,其中模型輸出的是 logits,目標是類別索引。我們將這個設置轉換為使用 MSE Loss。

import torch
import torch.nn.functional as F# 假設有一個批次的模型輸出和目標標簽
logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], requires_grad=True)  # 模型輸出
target = torch.tensor([0, 2])  # 目標類別# 使用 cross_entropy
cross_entropy_loss = F.cross_entropy(logits, target)
print("Cross-Entropy Loss:")
print(cross_entropy_loss)# 轉換目標類別為 one-hot 編碼
target_one_hot = F.one_hot(target, num_classes=logits.size(1)).float()
print("One-Hot Encoded Targets:")
print(target_one_hot)# 計算 MSE Loss
mse_loss = F.mse_loss(F.softmax(logits, dim=1), target_one_hot)
print("MSE Loss:")
print(mse_loss)

輸出

Cross-Entropy Loss:
tensor(1.4076, grad_fn=<NllLossBackward>)
One-Hot Encoded Targets:
tensor([[1., 0., 0.],[0., 0., 1.]])
MSE Loss:
tensor(0.2181, grad_fn=<MseLossBackward>)

解釋

  1. logits: 模型的原始輸出,形狀為 ([N, C])。
  2. target: 原始目標類別索引,形狀為 ([N])。
  3. target_one_hot: 將目標類別索引轉換為 one-hot 編碼,形狀為 ([N, C])。
  4. F.mse_loss: 使用 F.softmax(logits, dim=1) 計算模型的概率分布,然后與 target_one_hot 計算 MSE 損失。

通過將目標類別轉換為 one-hot 編碼并確保輸入和目標的形狀匹配,可以成功地將 cross_entropy 換成 mse_loss

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/39555.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/39555.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/39555.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Dify入門指南

一.Dify介紹 生成式 AI 應用創新引擎&#xff0c;開源的 LLM 應用開發平臺。提供從 Agent 構建到 AI workflow 編排、RAG 檢索、模型管理等能力&#xff0c;輕松構建和運營生成式 AI 原生應用&#xff0c;比 LangChain 更易用。一個平臺&#xff0c;接入全球大型語言模型。不同…

CesiumJS【Basic】- #050 繪制掃描線(Primitive方式)

文章目錄 繪制掃描線(Primitive方式)- 需要自定義著色器1 目標2 代碼2.1 main.ts繪制掃描線(Primitive方式)- 需要自定義著色器 1 目標 使用Primitive方式繪制掃描線 2 代碼 2.1 main.ts import * as Cesium from cesium;const viewer = new Cesium.Viewer(cesiumConta…

自我反思與暑假及大三上學期規劃

又要放暑假了&#xff0c;依稀記得上個暑假一邊練車&#xff0c;一邊試圖拿捏C語言&#xff0c;第一次感覺暑假也可以如此忙碌。但是開學以后&#xff0c;我并沒有把重心放在期望自己應該做的事情上&#xff0c;更多的時間花費在了處理學院的相關事務。現在看來&#xff0c;大二…

《昇思 25 天學習打卡營第 9 天 | FCN 圖像語義分割 》

活動地址&#xff1a;https://xihe.mindspore.cn/events/mindspore-training-camp 簽名&#xff1a;Sam9029 這一章節 出現了一個 深度學習 中經常出現的概念 全卷積網絡&#xff08;Fully Convolutional Networks&#xff09; : 官話&#xff1a;FCN 主要用于圖像分割領域&…

德璞資本:橋水公司如何利用AI實現投資決策的精準提升?

摘要&#xff1a; 在金融科技的浪潮中&#xff0c;橋水公司推出了一只依靠機器學習決策的創新基金&#xff0c;吸引了大量投資者的關注。本文將深入探討該基金的背景、AI技術的應用、對橋水公司轉型的影響&#xff0c;以及未來發展的前景。 新基金背景&#xff1a;橋水公司的創…

2024年7月2日 (周二) 葉子游戲新聞

老板鍵工具來喚去: 它可以為常用程序自定義快捷鍵&#xff0c;實現一鍵喚起、一鍵隱藏的 Windows 工具&#xff0c;并且支持窗口動態綁定快捷鍵&#xff08;無需設置自動實現&#xff09;。 卸載工具 HiBitUninstaller: Windows上的軟件卸載工具 經典名作30周年新篇《恐怖驚魂夜…

MyBatis入門案例

實施前的準備工作&#xff1a; 1.準備數據庫表2.創建一個新的springboot工程&#xff0c;選擇引入對應的起步依賴&#xff08;mybatis、mysql驅動、lombok&#xff09;3.在application.properties文件中引入數據庫連接信息4.創建對應的實體類Emp&#xff08;實體類屬性采用駝峰…

throw 和return的區別,A函數里面執行B函數 B函數異常后 不再執行A函數

function aFun() {try {bFun();console.log(22222222222);} catch (e) {// 如果bFun中拋出異常&#xff0c;中止aFun的執行console.log(e.message);} }function bFun() {let a 1, b 1;if (a b) {throw new Error(Stopped by bFun); // 拋出異常&#xff0c;停止aFun}// bFun…

python3遞歸目錄刪除N天前的文件(帶有日志記錄)

本來想用linux find去處理,為了裝逼,寫了py玩玩,刪除2w個文件總共用了2毫秒。因為這個腳本有記錄刪除時間,你可以看到開始時間和最后刪除的時間。由于只用了2毫秒,把我嚇了一跳以為刪錯文件了!! #!/usr/bin/env python3 # -*- encoding: utf-8 -*-@File : del_N…

補瀏覽器環境

一&#xff0c;導言 // global是node中的關鍵字&#xff08;全局變量&#xff09;&#xff0c;在node中調用其中的元素時&#xff0c;可以直接引用&#xff0c;不用加global前綴&#xff0c;和瀏覽器中的window類似&#xff1b;在瀏覽器中可能會使用window前綴&#xff1a;win…

校園水質信息化監管系統——水質監管物聯網系統

隨著物聯網技術的發展越來越成熟&#xff0c;它不斷地與人們的日常生活和工作深入融合&#xff0c;推動著社會的進步。其中物聯網系統集成在高校實踐課程中可以應用到許多項目&#xff0c;如環境氣象檢測、花卉種植信息化監管、水質信息化監管、校園設施物聯網信息化改造、停車…

C++編程(八)多態

文章目錄 一、多態&#xff08;一&#xff09;概念1. 多態2. 函數重寫3. 虛函數 &#xff08;二&#xff09;實現多態的條件1. 繼承關系2. 父類中寫虛函數3. 在子類中重寫父類的虛函數4.父類的指針或引用指向子類的對象5. 使用示例 &#xff08;三&#xff09;虛析構函數&#…

springboot項目jar包修改數據庫配置運行時異常

一、背景 我將軟件成功打好jar包了&#xff0c;到部署的時候發現jar包中數據庫配置寫的有問題&#xff0c;不想再重新打包了&#xff0c;打算直接修改配置文件&#xff0c;結果修改配置后&#xff0c;再通過java -jar運行時就報錯了。 二、問題描述 本地項目是springBoot項目…

【計算機圖形學 | 基于MFC三維圖形開發】期末考試知識點匯總(上)

文章目錄 視頻教程第一章 計算機圖形學概述計算機圖形學的定義計算機圖形學的應用計算機圖形學 vs 圖像處理 vs模式識別圖形顯示器的發展及工作原理理解三維渲染管線 第二章 基本圖元的掃描轉換掃描轉換直線的掃描轉換DDA算法Bresenham算法中點畫線算法圓的掃描轉換中點畫圓算法…

Java中的持續集成與持續部署

Java中的持續集成與持續部署 大家好&#xff0c;我是免費搭建查券返利機器人省錢賺傭金就用微賺淘客系統3.0的小編&#xff0c;也是冬天不穿秋褲&#xff0c;天冷也要風度的程序猿&#xff01;今天我們將深入探討Java中的持續集成&#xff08;Continuous Integration&#xff…

熟練掌握Docker及linux常用命令排查線上問題。熟悉Git, Maven等項目管理及構建工具,熟悉微服務中基于Jenkins的CI/CD

掌握Docker、Linux命令、項目管理及構建工具&#xff0c;以及CI/CD流程是現代軟件開發和運維的關鍵技能。以下是對這些技能的概述和一些實踐建議&#xff1a; ### Docker - **概述**&#xff1a;Docker是一個開源的容器化平臺&#xff0c;允許開發者打包應用及其依賴到一個可移…

【Godot4.2】Godot中的貝塞爾曲線

概述 通過指定平面上的多個點&#xff0c;然后順次連接&#xff0c;我們可以得到折線段&#xff0c;如果閉合圖形&#xff0c;就可以獲得多邊形。通過向量旋轉我們可以獲得圓等特殊圖形。 但是對于任意曲線&#xff0c;我們無法使用簡單的方式來獲取其頂點&#xff0c;好在計…

mac上使用finder時候,顯示隱藏的文件或者文件夾

默認在finder中是不顯示隱藏的文件和文件夾的&#xff0c;但是想創建.gitignore文件&#xff0c;并向里面寫入內容&#xff0c;即便是打開xcode也是不顯示這幾個隱藏文件的&#xff0c;那有什么辦法呢&#xff1f; 使用快捷鍵&#xff1a; 使用finder打開包含隱藏文件的文件夾…

Linux如何安裝openjdk1.8

文章目錄 Centosyum安裝jdk和JRE配置全局環境變量驗證ubuntu使用APT(適用于Ubuntu 16.04及以上版本)使用PPA(可選,適用于需要特定版本或舊版Ubuntu)Centos yum安裝jdk和JRE yum install java-1.8.0-openjdk-devel.x86_64 安裝后的目錄 配置全局環境變量 vim /etc/pr…

ISP IC/FPGA設計-第一部分-SC130GS攝像頭分析-IIC通信(1)

1.攝像頭模組 SC130GS通過一個引腳&#xff08;SPI_I2C_MODE&#xff09;選擇使用IIC或SPI配置接口&#xff0c;通過查看攝像頭模組的原理圖&#xff0c;可知是使用IIC接口&#xff1b; 通過手冊可知IIC設備地址通過一個引腳控制&#xff0c;查看攝像頭模組的原理圖&#xff…