深入學習pytorch筆記

兩個重要的函數

  • dir(): 一個內置函數,用于列出對象的所有屬性和方法
    在這里插入圖片描述

  • help():一個內置函數,用于獲取關于Python對象、模塊、函數、類等的詳細信息
    在這里插入圖片描述

Dateset類

  • Dataset:pytorch中的一個類,開發者在訓練和測試時,用一個子類去繼承Dataset類,繼承和重寫Dataset類中方法和屬性,以加載數據集。
class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])
  • def getitem(self, index):必須重寫,用于以加載數據集。
  • def len(self):可不重寫,用于計算數據集中樣本個數。
    在這里插入圖片描述

TensorBoard

  • TensorBoard 是pytorch中一組用于數據可視化的工具,包含在TensorFlow庫。
  • SummaryWriter類:用于在給定目錄中創建事件文件,在訓練時,將數據添加到文件中,用于顯示。使用SummaryWriter類創建對象時,若沒有給出事件文件名,則默認的事件文件名為run。

損失函數

  • torch.nn.loss():PyTorch 中的一個類,用于計算L1 損失函數,即計算了預測值與實際值之間的L1范數(即絕對差值)。
  • 在創建torch.nn.L1Loss(reduction)對象時,可以傳入一個可選的參數reduction,它決定了如何從每個樣本的損失中聚合得到最終的損失。
    1. reduction=‘mean’:計算所有樣本損失的平均值作為最終損失。默認情況下,reduction參數的值為’mean’,即計算所有樣本損失的平均值作為最終損失。
    2. reduction=‘none’:不進行任何聚合操作,直接返回每個樣本的損失。
    3. reduction=‘sum’:計算所有樣本損失的總和作為最終損失。
    4. reduction= ‘mean_none’: 計算所有樣本損失的平均值,但是不除以樣本數,即不進行歸一化。
    5. reduction=‘sum_none’:計算所有樣本損失的總和,但是不乘以樣本數,即不進行歸一化。
  • 在調用torch.nn.L1Loss()對象時,要傳入預測值和實際值。
    在這里插入圖片描述
  • torch.nn.MSELoss():PyTorch庫中的一個類,用于計算均方誤差。MSE損失函數的計算方式是:對于每個樣本,計算預測值與真實值之間的平方差,然后取這些平方差的平均值。具體公式為:loss = 1/n Σ (y_pred - y_true)^2,其中n是樣本數量。
    在這里插入圖片描述
  • torch.nn.CrossEntropyLoss:是PyTorch庫中的一個類,用于計算交叉熵損失。
  • 在創建對象時,torch.nn.CrossEntropyLoss()參數:
    1. weight: 類別權重。這是一個一維的tensor,用于為每個類別指定不同的權重。默認值是None,這時所有的類別權重都相等。如果指定了類別權重,那么在計算損失時,每個類別的損失將會根據其對應的權重進行加權平均。
    2. reduction: 損失的歸約方式。這個參數決定了如何將交叉熵損失的值從樣本級別降低到批次級別。可能的值有:‘none’(不進行歸約,返回每個樣本的交叉熵損失),‘mean’(對所有樣本的交叉熵損失取平均),‘sum’(將所有樣本的交叉熵損失相加)。默認值是’mean’。
    3. ignore_index: 被忽略的類別索引。如果設置了該參數,那么在計算交叉熵損失時,該類別對應的損失將被忽略。這個參數主要用于處理數據集中的無效類別或不需要分類的類別。默認值是-100。
  • 在調用torch.nn.CrossEntropyLoss的對象時,需要傳入兩個參數:
    1. input:這是一個一維或二維張量,表示模型的輸出。對于每個輸入樣本,輸出應該是一個長度為類別數量的向量,每個元素表示該類別與輸入樣本的相似度。
    2. target:這是一個一維張量,表示每個輸入樣本的正確類別標簽。
      在這里插入圖片描述

優化器(參數更新)

  • torch.optim.SGD:PyTorch 中的一個類,它實現了隨機梯度下降(Stochastic Gradient Descent)算法。
  • 創建類對象時,torch.optim.SGD(params,lr,momentum,dampening,weight_decay,nesterov)的參數:
    1. params:要優化的參數,通常是模型中的參數。
    2. lr:學習率。控制參數更新的步長。默認值是0.01。
    3. momentum:動量。這個參數會考慮之前梯度的方向,使得優化器具有一定的"慣性",有助于加速訓練。默認值是0。
    4. dampening:阻尼。這個參數可以防止動量過大導致震蕩。默認值是0。
    5. weight_decay:權重衰減。可以防止過擬合,通過對參數本身進行懲罰來控制模型的復雜度。默認值是0,表示不進行權重衰減。
    6. nesterov:是否使用 Nesterov 動量。如果為 True,會使用 Nesterov 動量,否則使用標準 momentum。默認值是False
  • 創建優化器后,我們可以通過調用 optimizer.zero_grad() 清除之前的梯度,然后通過反向傳播計算新的梯度,最后使用 optimizer.step() 更新模型的參數。

import torch
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten
from torch.nn import Linear
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)class MY_Dodule(nn.Module):def __init__(self):super(MY_Dodule,self).__init__()self.model = Sequential(Conv2d(3, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,input):output = self.model(input)return outputmy_module = MY_Dodule()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(my_module.parameters(),lr=0.1)
for epoch in range(20):running_loss = 0.0for data in dataloader:images,targets = datainput = imagesoutput = my_module(input)  # 前向轉播result_loss = loss(output,targets)  # 計算損失optim.zero_grad()  # 清除之前的梯度result_loss.backward() # 反向轉播optim.step() #梯度更新running_loss += result_losspassprint(running_loss)pass

網絡模型的使用和修改

  • torchvision.models.vgg16(pretrained,progress):PyTorch 中的一個類,是用來加載預訓練的 VGG-16 模型的函數。

    1. pretrained:布爾型,決定是否從 PyTorch 的預訓練模型庫中加載訓練好的權重。如果設為 True,則返回的模型會包含在大規模圖像分類任務上訓練得到的權重。如果設為 False,則模型不包含預訓練的權重,你需要自己訓練模型。默認為False。
    2. progress:布爾型,決定是否顯示下載預訓練模型過程的進度條。如果設為 True,則在下載預訓練模型時會顯示進度條。默認為True。
  • 在 VGG-16 模型中添加層:model是torchvision.models.vgg16()示例化對象,model.classifier.add_module(str,nn.Module)這個函數接受兩個參數。

    1. 模塊名稱(str):這是你想要添加的模塊的名稱。你可以自己定義一個有意義的名稱,以便在后續的代碼中引用這個模塊。
    2. 模塊對象(nn.Module):這是你想要添加的模塊本身。這個模塊可以是任何PyTorch定義的神經網絡層或者你自己定義的層。
  • 在 VGG-16 模型中修改層:model是torchvision.models.vgg16()示例化對象,model.classifier[n] = nn.Module

    1. n:VGG-16 模型中修改層的層號
    2. nn.Module:修改后的模塊本身。這個模塊可以是任何PyTorch定義的神經網絡層或者你自己定義的層。
      在這里插入圖片描述

網絡模型的保存與讀取

  • torch.save(model, ‘model.pth’):PyTorch 中的一個函數,模型model的權重和參數,保存在指定文件model.pth中。
  • model = torch.load(‘model.pth’):PyTorch 中的一個函數,根據model.pth文件,加載保存的模型并返回給變量 model
  • torch.save(model.state_dict(), ‘model.pth’): 將模型model參數(權重和偏置等,不包括模型的結構),以字典的形式保存到指定的文件 ‘model.pth’ 中。
  • model.load_state_dict(torch.load(‘model.pth’)):torch.load()函數讀取文件中模型的參數信息,加載到model模型中。請注意,這種方式要求你在加載模型時已經知道模型model的結構。

模型訓練流程(以CIFAR10為例)

  • 第一步:準備數據集,包括訓練集和測試集
import torchvision# 準備訓練集
train_data = torchvision.datasets.CIFAR10("dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)# 準備測試集
test_data = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
  • 第二步:計算數據長度
# 計算數據集長度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("訓練數據集的長度:{}".format(train_data_size))
print("測試數據集的長度:{}".format(test_data_size))
  • 第三步:用dataloader()加載數據集,將數據集劃分為批量子集
# dataloader()加載數據集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • 第四步:搭建神經網絡,一般用一個單獨python文件保存
import torch
from torch import nnclass My_Module(nn.Module):def __init__(self):super(My_Module,self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32 ,32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10),)def forward(self,input):output = self.model(input)return outputif __name__ == '__main__':my_module = My_Module()input = torch.ones((64, 3, 32, 32))output = my_module(input)print(output.shape)
  • 第五步:創建網絡模型
# 創建網絡模型
my_module = My_Module()
  • 第六步:定義損失函數
loss_f = nn.CrossEntropyLoss()
  • 第七步:定義優化器,進行梯度下降
# 定義優化器,進行梯度下降
learning_rate = 0.01  # 學習效率
optimizer = torch.optim.SGD(my_module, lr=learning_rate)
  • 第八步:設置訓練網絡模型的一些參數
# 設置訓練網絡模型的一些參數
total_train_step = 0  # 記錄訓練次數
total_test_step = 0  # 記錄測試次數
epoch = 10 # 訓練的輪次
writer = SummaryWriter("P27")  # 添加tensorboard
  • 第九步:訓練網絡模型
# 訓練網絡模型
for i in range(epoch):print("------第{}輪訓練開始------".format(i + 1))for data in train_dataloader:images ,targets = datainput = imagesoutput = my_module(input)  # 前向傳播loss = loss_f(output, targets)  # 計算損失loss.backward()  # 反向轉播optimizer.zero_grad()  #optimizer.step() # 梯度下降total_train_step = total_train_step + 1print("訓練次數:{},loss:{}".format(total_train_step, loss.item()))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/166476.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/166476.shtml
英文地址,請注明出處:http://en.pswp.cn/news/166476.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

抖音電商品牌力不足咋辦?如何升級或強開旗艦店、官方旗艦店?我們有妙招!

隨著抖音電商的發展,越來越多的商家蜂擁而至,入駐經營抖音小店... 然而我們在開店的時候,選擇開通官方旗艦店、旗艦店、專營店或專賣店,卻被系統提示為你的商標品牌力不足,無法開通官方旗艦店、旗艦店、專營店、專賣店…

Android手電筒、閃光燈、torch、flash

1. 僅開啟手電筒 單純的開啟手電筒我們可以使用CameraManager的.setTorchMode()方法。 cameraCharacteristics.get(CameraCharacteristics.FLASH_INFO_AVAILABLE)獲取該相機特征是否可獲取閃光燈。 CameraManager cameraManager (CameraManager) getSystemService(CAMERA_SE…

在 vscode 中的json文件寫注釋,不報錯的解決辦法

打開 vscode 的「設置」,搜索:files: associations,然后添加 *.json jsonc最后

Nginx 配置錯誤導致的漏洞

目錄 1. CRLF注入漏洞 Bottle HTTP頭注入漏洞 2.目錄穿越漏洞 3. http add_header被覆蓋 本篇要復現的漏洞實驗有一個網站直接為我們提供了Docker的環境,我們只需要下載下來就可以使用: Docker環境的安裝可以參考:Docker安裝 漏洞環境的…

Docker rm 命令

docker rm:刪除一個或多個容器。 語法: docker rm [OPTIONS] CONTAINER [CONTAINER...]OPTIONS說明: -f:通過SIGKILL信號強制刪除一個運行中的容器。 -l :移除容器間的網絡連接,而非容器本身。 -v &…

2023亞太杯數學建模A題思路代碼分析

已經完成A題完整思路代碼,文末名片查看獲取 A題就是我們機器學習中的一個圖像識別,他是水果圖像識別,就是蘋果識別的一個問題,我們用到的方法基本是使用深度學習中的卷積神經網絡來進行識別和分類 問題一:基于附件1中…

展現天津援疆工作成果 “團結村里看振興”媒體采風團走進和田

央廣網天津11月19日消息(記者周思楊)11月18日,由媒體記者、書法和攝影家、旅行社企業代表等40余人組成的“團結村里看振興”媒體采風團走進新疆和田。在接下來的一周時間里,采風團將走訪天津援疆和田地區策勒縣、于田縣、民豐縣鄉村振興示范村&#xff0…

HTML CSS登錄網頁設計

一、效果圖: 二、HTML代碼: <!DOCTYPE html> <!-- 定義HTML5文檔 --> <html lang="en"> …

在全球碳市場中嶄露頭角的中碳CCNG

在全球氣候治理的大背景下&#xff0c;中國碳中和發展集團有限公司&#xff08;簡稱中國碳中和&#xff09;正在成為全球碳交易市場的一個重要參與者。隨著國際社會對碳排放的日益關注&#xff0c;中國碳中和憑借其在碳資產開發、咨詢與管理等領域的深厚積累&#xff0c;正成為…

acedInitGet 函數

acedInitGet 函數是 AutoCAD 的 C++ API(ObjectARX)中用于初始化下一次用戶輸入操作選項的函數。以下是該函數簽名及其組成部分的中文翻譯和解釋: extern "C" int acedInitGet(int val,const ACHAR * kwl );cpp 復制 extern “C”:指定函數使用 C 語言鏈接(lin…

LeetCode93. Restore IP Addresses

文章目錄 一、題目二、題解 一、題目 A valid IP address consists of exactly four integers separated by single dots. Each integer is between 0 and 255 (inclusive) and cannot have leading zeros. For example, “0.1.2.201” and “192.168.1.1” are valid IP add…

視頻剪輯新招:批量隨機分割,分享精彩瞬間

隨著社交媒體的普及&#xff0c;短視頻已經成為分享生活、交流信息的重要方式。為制作出吸引的短視頻&#xff0c;許多創作者都投入了大量的時間和精力進行剪輯。然而&#xff0c;對于一些沒有剪輯經驗的新手來說&#xff0c;這個過程可能會非常繁瑣。現在一起來看云炫AI智剪批…

楊傳輝:從一體化架構,到一體化產品,為關鍵業務負載打造一體化數據庫

在剛剛結束的年度發布會上&#xff0c;OceanBase正式推出一體化數據庫的首個長期支持版本 4.2.1 LTS&#xff0c;這是面向 OLTP 核心場景的全功能里程碑版本&#xff0c;相比上一個 3.2.4 LTS 版本&#xff0c;新版本能力全面提升&#xff0c;適應場景更加豐富&#xff0c;有更…

web前端之若依框架圖標對照表、node獲取文件夾中的文件名,并通過數組返回文件名、在html文件中引入.svg文件、require、icon

MENU 前言效果圖htmlJavaScripstylenode獲取文件夾中的文件名 前言 需要把若依原有的icon的svg文件拿到哦&#xff01; 注意看生成svg的路徑。 效果圖 html <div id"idSvg" class"svg_box"></div>JavaScrip let listSvg [404, bug, build, …

02 如何快速讀懂一個C++程序

系列文章目錄 02 如何快速讀懂一個C程序 目錄 系列文章目錄 文章目錄 前言 一、C 的基本語法 二、如何看懂一個c程序&#xff1f; 1.了解程序結構 2.C 中的分號 & 語句塊 3.C 注釋 總結 前言 C 是一種高級編程語言&#xff0c;它具有豐富的特性&#xff0c;用于…

CentOS7安裝Docker運行環境

1 引言 Docker 是一個用于開發&#xff0c;交付和運行應用程序的開放平臺。Docker 使您能夠將應用程序與基礎架構分開&#xff0c;從而可以快速交付軟件。借助 Docker&#xff0c;您可以與管理應用程序相同的方式來管理基礎架構。通過利用 Docker 的方法來快速交付&#xff0c;…

11.前綴和、異或前綴和、差分數組練習題

前綴和 前綴和可以用來求滿足條件的子數組的和、個數、長度 更多前綴和題目&#xff1a; 560. 和為 K 的子數組 974. 和可被 K 整除的子數組 1590. 使數組和能被 P 整除 523. 連續的子數組和 525. 連續數組 560. 和為 K 的子數組 中等 給你一個整數數組 nums 和一個整數…

在新疆烏魯木齊的汽車托運

在新疆烏魯木齊要托運的寶! 看過來了 找汽車托運公司了 連夜吐血給你們整理了攻略!! ??以下&#xff1a; 1 網上搜索 可以在搜索引擎或專業的貨運平臺上搜索相關的汽車托運公司信息。在網站上可以了解到公司的服務范圍、托運價格、運輸時效等信息&#xff0c;也可以參考其他車…

2024年的云趨勢:云計算的前景如何?

本文討論了2024年云計算的發展趨勢。 適應復雜的生態系統、提供實時功能、優先考慮安全性和確保可持續性的需求正在引領云計算之船。多樣化的工作負載允許探索通用的公共云基礎設施范例之外的選項。由于需要降低成本、提高靈活性和降低風險&#xff0c;混合云和多云系統越來越受…

RabbitMQ 消息隊列編程

安裝與配置 安裝 RabbitMQ 讀者可以在 RabbitMQ 官方文檔中找到完整的安裝教程&#xff1a;Downloading and Installing RabbitMQ — RabbitMQ 本文使用 Docker 的方式部署。 RabbitMQ 社區鏡像列表&#xff1a;https://hub.docker.com/_/rabbitmq 創建目錄用于映射存儲卷…