Pytorch-07 如何快速把已經有的視覺模型權重扒拉過來為己所用

下載,保存,加載,使用模型權重

在這一節里面我們會過一遍對模型權重的常用操作,比如:

  • 如何下載常用模型的預訓練權重
  • 如何下載常用模型的無訓練權重(只下載網絡結構)
  • 如何加載模型權重
  • 如何保存權重
  • 加載模型權重后進行推理的注意事項

寫在開頭:權重是以什么形式存在的?

光用,是肯定夠的,但是如果你能稍微懂一點原理,那么你很有可能在某一日突然融會貫通,做出非常牛逼的優化。

Pytorch模型會將學習到的參數存儲在稱為state_dict的內部狀態字典中。為了深入探究,我們可以創建一個單線性層簡單模型,然后看看它的狀態字典長啥樣:
在這里插入圖片描述

import torch
import torch.nn as nn
import jsonclass SimpleLinearModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(5, 2) # 五個輸入,兩個輸出def forward(self, x):return self.linear(x)model = SimpleLinearModel()
model_state_dict = model.state_dict()
print(model_state_dict)

以上代碼塊的輸出如下:
在這里插入圖片描述
雖然它是個字典,但是由于張量的存在讓他不是一個友好的鍵值對結構,而是元組結構,但是我們還是可以把他轉換JSON來直觀感受一下:

{"linear.weight": [[0.1941000074148178,0.22420001029968262,-0.3236999809741974,-0.1558000087738037,0.2337000072002411],[0.15130000114440918,0.11470000594854355,0.3953999876976013,-0.33970001339912415,-0.20650000870227814]],"linear.bias": [0.046799998730421066,-0.03530000150203705]
}

這下直觀多了,可以看到,state_dict存儲了兩個學習的參數,其中包括了一個W2×5W_{2\times5}W2×5?的全連接矩陣和一個長度為2的偏置向量bbb

不過需要注意的是這是一個有序字典,這樣才能保證數據在流經權重文件的時候才能一層一層的被處理。

下載并保存常用預訓練模型權重

torchvision.models包內置了很多的不同任務模型權重,包括但不限于圖像分類,語義分割,實例分割,關鍵點檢測,視頻分類,光流等等,你可以逛逛這個權重菜市場,這里我就不放圖介紹了。

一般情況下,你可能要使用這些預訓練的模型來進行 遷移學習, pytorch加載這些權重相當簡單,一般的調用公式為:

model = model.模型名(weights='用什么數據集預訓練的')

舉個例子,我們加載一下vgg的在IMAGENET1K_V1上的權重,看看它結構如何

model = models.vgg16(weights='IMAGENET1K_V1')
print(model)

在這里插入圖片描述
可以看到這個包含預訓練權重的網絡模型已經被我們保存到state_dict中了,但是它目前還只是在內存里面,沒有寫入外存(硬盤),如果你想把它保存到本地,你可以這樣做:

torch.save(model.state_dict(), 'vgg1-_model_weights.pth')

這樣就會把你當前的權重保存為一個.pth文件。

加載模型權重-僅加載權重

OK, 現在你已經有了一個vgg模型權重,那你要怎么把它加載到對應的網絡上呢?

正常的步驟是這樣的:

  1. 創建同一模型的實例 (不指定數據集的時候說明只要結構不要預訓練權重)
  2. 使用load_state_dict()方法加載參數
model = models.vgg16() # 這里沒有指定數據集,說明只要結構
model.load_state_dict(torch.load('model_weights.pth', weights_only=True)) # 這里加載權重,該.pth文件只包含權重,不包含結構。

注意,如果你使用torch.save(model.state_dict(), 'path'),只有權重會被保存!如果你想在保存權重的同時也保存模型結構,你可以這么做:

torch.save(model, 'model.pth')

這個做法的優點是可以在加載被這樣保存的權重的時候無需初始化對應的網絡:

model = torch.load('model.pth', weights_only=False) # 說明這個model.pth并不是只保存了權重,還有模型架構,所以不需要先實例化再加載權重

但是,模型和權重一起加載并不是pytorch官方推薦的最佳實踐! pytorch官方推薦的方式還是只保存模型權重,要加載的時候先實例化網絡再加載權重(后一段就是講這個的)。這是因為.pth文件的解析基于pickle協議實現,而pickle文件不僅僅是數據存儲,它還可以包含可執行代碼。當 torch.load() 反序列化一個 pickle 文件時,它會執行文件中的字節碼來重新創建對象。

這代表:如果一個 .pth 文件是惡意創建的,它可能包含惡意代碼。當你在不知情的情況下加載這個文件時,這些惡意代碼會被執行,從而導致你的系統受到攻擊,例如被植入病毒、竊取數據等。weights_only=True 的作用就是切斷這個風險鏈。它告訴 PyTorch:“我只信任文件中的張量數據。不要執行任何其他的 Python 對象代碼,即使它們存在于文件中。”這樣就有效地防止了潛在的惡意代碼被執行。

使用預訓練權重進行推理

這里要說的不多,用預訓練權重加載好模型之后記得打開.eval評估模式再開始推理:

model.eval()

.eval() 方法的作用是將模型切換到評估(evaluation)模式。這個模式會關閉一些在訓練時才需要的特殊層,以確保模型在推理時能夠產生一致且可預測的結果。

具體來說,他會關閉這兩個層的以下作用:

.eval() 是 PyTorch 模型推理時一個非常重要的步驟,你提到的這一點非常關鍵。

為什么需要在推理前調用 model.eval()

.eval() 方法的作用是將模型切換到評估(evaluation)模式。這個模式會關閉一些在訓練時才需要的特殊層,以確保模型在推理時能夠產生一致且可預測的結果。

具體來說,.eval() 主要影響以下兩種類型的層:

  1. Dropout

    • 訓練模式 (model.train())Dropout 層會以一定的概率隨機“丟棄”一些神經元的輸出,以防止模型過擬合。這意味著每次前向傳播(forward pass)時,網絡結構都是不一樣的。
    • 評估模式 (model.eval())Dropout 層會被關閉。所有神經元都參與計算,不再隨機丟棄。這確保了在推理時,每次對同一輸入進行預測,都會得到完全相同的結果。
  2. BatchNorm(批量歸一化)層

    • 訓練模式 (model.train())BatchNorm 層會根據當前批次(batch)的輸入數據來計算均值和方差,并用這些統計量進行歸一化。
    • 評估模式 (model.eval())BatchNorm 層會停止更新均值和方差。它會使用在訓練階段已經學到的、全局的、固定的均值和方差來進行歸一化。這同樣是為了確保推理結果的穩定性,因為在推理時,我們通常只處理單個樣本或小批量的樣本,它們的統計量沒有代表性。

如果沒有調用 model.eval(),模型將保持在訓練模式。這將導致:

  1. 結果不穩定:因為 Dropout 層會隨機丟棄神經元,即使輸入相同,每次推理的結果也可能不同。
  2. 結果不準確BatchNorm 層會使用不穩定的批次統計量進行歸一化,而不是使用訓練時學到的穩定統計量,這會導致推理結果的準確性下降。

所以,為了得到穩定、準確且可復現的推理結果,在使用預訓練模型進行預測時,必須在推理循環之前調用 model.eval()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/92167.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/92167.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/92167.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C語言零基礎第9講:指針基礎

目錄 1.內存和地址 2.指針變量和地址 2.1 取地址操作符(&) 2.2 指針變量 2.3 解引用操作符(*) 2.4 指針變量的大小 3.指針變量類型的意義 3.1 指針的解引用 3.2 指針 - 整數 3.3 void*指針 4.指針運算 4.1 指針…

013 HTTP篇

3.1 HTTP常見面試題 1、HTTP基本概念: 超文本傳輸協議:在計算機世界里專門在「兩點」之間「傳輸」文字、圖片、音頻、視頻等「超文本」數據的「約定和規范」HTTP常見的狀態碼 [[Pasted image 20250705140705.png]]HTTP常見字段 Host 字段:客戶…

每日面試題20:spring和spring boot的區別

我曾經寫過一道面試題,題目是為什么springboot項目可以直接打包給別人運行?其實這涉及到的就是springboot的特點。今天來簡單了解一下springboot和spring的區別, Spring 與 Spring Boot:從“全能框架”到“開箱即用”的進化之路 …

ClickHouse數據遷移

ClickHouse實例是阿里云上的云實例,想同步數據到本地,本地部署有ClickHouse實例,下面為單庫單表 源實例:阿里云cc-gs5xxxxxxx.public.clickhouse.ads.aliyuncs.com:8123 目標實例:本地172.16.22.10:8123 1、目標實例建…

sqli-labs-master/Less-41~Less-50

Less-41這一關還是用堆疊注入,這關數字型不需要閉合了。用堆疊的話,我們就不爆信息了。我們直接用堆疊,往進去寫一條數據?id-1 union select 1,2,3;insert into users (id,username,password) values(666,zk,180)--看一下插進去了沒?id-1 u…

Tiger任務管理系統-10

十是個很好美好的數字,十全十美,確實沒讓人失望,收獲還是很大的。 溫習了前端知識,鞏固了jQuery,thymeleaf等被忽視的框架,意外將之前的所學所用的知識都連起來了,感覺有點像打通了任督二脈一樣…

ora-01658 無法為表空間 users中的段創建initial區

ora-01658 無法為表空間 users中的段創建initial區 參考1 參考2 參考3 參考4 給用戶新增表空間 alter tablespace system add datafile D:\APP\ADMINISTRATOR\ORADATA\ORCL\SYSTEM03.DBF size 5G autoextend on next 10M;設置表空間文件自動擴展 ALTER DATABASE DATAFILE /…

lodash的替代品es-toolkit詳解

一、es-toolkit簡介 es-toolkit 是一款先進的高性能 JavaScript 實用程序庫,體積小巧,并支持強類型注釋,典型特征包括: 提供各種日常實用函數并采用現代實現,例如: debounce、delay、chunk、sum 和 pick 等 設計充分考慮了性能,在現代 JavaScript 環境中實現了 2-3 倍…

【原創】基于gemini-2.5-flash-preview-05-20多模態模型實現短視頻的自動化二創

畫面和解說保持一致,這個模型就是NB[16:57:37] [*] 正在從視頻中提取幀和時長 (頻率: 1.0 幀/秒)... [16:57:55] [] 提取完成。視頻時長: 83.40秒, 提取了 84 幀。 [16:57:55] [*] 使用AI供應商: gemini [16:57:55] [*] 正在進行視覺分析... [16:57:55] L-> 正…

數倉架構 數據表建模

數倉架構 主要用來描述 數據加工的實時鏈路 和 離線鏈路之間的關系,即 流批 關系; lamda 架構, 是兩條路, 實時計算式的, 維護數據的實時性。然后每天經過批計算后, 覆蓋實時的計算結果。 保證數據準確性。 kappa架構, 即流批一體了 數據建模 星型模型是數據倉庫中最…

vscode調試python腳本時無法進入函數內部的解決方法

只需在launch.json配置文件中添加“justMyCode”:false.

Python day37

浙大疏錦行 python day37. 內容: 保存模型只需要保存模型的參數即可,使用的時候直接構建模型再導入參數即可 # 保存模型參數 torch.save(model.state_dict(), "model_weights.pth")# 加載參數(需先定義模型結構) mod…

ORACLE進階操作

1 事務 事務的任務便是使數據庫從一種狀態變換成為另一種狀態,這不同于文件系統,它是數據庫所特用的。 所有的數據庫中,事務只針對DML(增刪改),不針對select select只能查看其他事務提交或回滾的數據,不能查…

Modbus 的一些理解

疑問:(使用的是Modbustcp)我在 Modbus slave 上面設置了slave地址為1,位置為40001的位置的值為1,40001這個位置上面的值是怎么存儲的,存儲在哪里的?他們是怎么進行交互的?在Modbus協…

【運動控制框架】WPF運動控制框架源碼,可用于激光切割機,雕刻機,分板機,點膠機,插件機等設備,開箱即用

WPF運動控制框架源碼,可用于激光切割機,雕刻機,分板機,點膠機,插件機等設備,考慮到各運動控制硬件不同,視覺應用功能(應用視覺軟件)也不同,所以只開發各路徑編…

RabbitMQ-日常運維命令

作者介紹:簡歷上沒有一個精通的運維工程師。請點擊上方的藍色《運維小路》關注我,下面的思維導圖也是預計更新的內容和當前進度(不定時更新)。中間件,我給它的定義就是為了實現某系業務功能依賴的軟件,包括如下部分:Web服務器代理…

【Linux基礎知識系列】第九十篇 - 使用awk進行文本處理

在Linux系統中,文本處理是一個常見的任務,尤其是在處理日志文件、配置文件和數據文件時。awk是一個功能強大的文本處理工具,廣泛用于數據提取、分析和格式化。它不僅可以處理簡單的文本文件,還可以處理復雜的結構化數據&#xff0…

第二十七天(數據結構:圖)

圖:是一種非線性結構形式化的描述: G{V,R}V:圖中各個頂點元素(如果這個圖代表的是地圖,這個頂點就是各個點的地址)R:關系集合,圖中頂點與頂點之間的關系(如果是地圖,這個關系集合可能就代表的是各個地點之間的距離)在頂點與頂點…

數據賦能(386)——數據挖掘——迭代過程

概述重要性如下:提升挖掘效果:迭代過程能不斷優化數據挖掘模型,提高挖掘結果的準確性和有效性,從而更好地滿足業務需求。適應復雜數據:數據往往具有復雜性和多樣性,通過迭代可以逐步探索和適應數據的特點&a…

什么是鍵值緩存?讓 LLM 閃電般快速

一、為什么 LLMs 需要 KV 緩存?大語言模型(LLMs)的文本生成遵循 “自回歸” 模式 —— 每次僅輸出一個 token(如詞語、字符或子詞),再將該 token 與歷史序列拼接,作為下一輪輸入,直到…