strict=False 但還是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

strict=False 但還是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

問題

我們知道通過

model.load_state_dict(state_dict, strict=False)

可以暫且忽略掉模型和參數文件中不匹配的參數,先將正常匹配的參數從文件中載入模型。

筆者在使用時遇到了這樣一個報錯:

RuntimeError: Error(s) in loading state_dict for ViT_Aes:size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

一開始筆者很奇怪,我已經寫明strict=False了,不匹配參數的不管就是了,為什么還要給我報錯。

原因及解決方案

經過筆者仔細打印模型的鍵和文件中的鍵進行比對,發現是這樣的:strict=False可以保證模型中的鍵與文件中的鍵不匹配時暫且跳過不管,但是一旦模型中的鍵和文件中的鍵匹配上了,PyTorch就會嘗試幫我們加載參數,就必須要求參數的尺寸相同,所以會有上述報錯。

比如在我們需要將某個預訓練的模型的最后的全連接層的輸出的類別數替換為我們自己的數據集的類別數,再進行微調,有時會遇到上述情況。這時,我們知道全連接層的參數形狀會是不匹配,比如我們加載 ImageNet 1K 1000分類的預訓練模型,它的最后一層全連接的輸出維度是1000,但如果我們自己的數據集是10分類,我們需要將最后一層全鏈接的輸出維度改為10。但是由于鍵名相同,所以PyTorch還是嘗試給我們加載,這時1000和10維度不匹配,就會導致報錯。

解決方案就是我們將 .pth 模型文件讀入后,將其中我們不需要的層(通常是最后的全連接層)的參數pop掉即可。

以 ViT 為例子,假設我們有一個 ViT 模型,并有一個參數文件 vit-in1k.pth,它里面存儲著 ViT 模型在 ImageNet-1K 1000分類數據集上訓練的參數,而我們要在自己的10分類數據集上微調這個模型。

model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')
msg = model.load_state_dict(ckpt, strict=False)
print(msg)

直接這樣加載會出錯,就是上面的錯誤:

	size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

我們將最后 pth 文件加載進來之后(即 ckpt) 中全連接層的參數直接pop掉,至于需要pop掉哪些鍵名,就是上面報錯信息中提到了的,在這里就是 head.weighthead.bias

ckpt.pop('head.weight')
ckpt.pop('head.bias')

之后在運行,會發現我們打印的 msg 顯示:

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])

即缺失了head.weighthead.bias 這兩個參數,這是正常的,因為在自己的數據集上微調時,我們本就不需要這兩個參數,并且已經將它們從模型文件字典 ckpt 中pop掉了。現在,模型全連接之前的層(通常即所謂的特征提取層)的參數已經正常加載了,接下來可以在自己的數據集上進行微調。

因為反正我們也不用這些參數,就直接把這個鍵值對從字典中pop掉,以免 PyTorch 在幫我們加載時試圖加載這些維度不匹配,我們也不需要的參數。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532802.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532802.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532802.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux中權限765啥意思,Linux中的文件權限

Linux系統中的每一個文件都與多種權限類型相關聯。在這些權限中,我們主要和三類權限打交道:用戶(user)、用戶組(group)和其他用戶(others)。用戶是文件的所有者;用戶組是指和文件所有者在同一組的其他多個用戶的集合;其他用戶是除…

CV中的色彩空間大全

HSI、HSV、RGB、CMY、CMYK、HSL、HSB、Ycc、XYZ、Lab、YUV顏色模型 HSV顏色空間 HSV(hue,saturation,value)顏色空間的模型對應于圓柱坐標系中的一個圓錐形子集,圓錐的頂面對應于V1. 它包含RGB模型中的R1,G1,B1 三個面,所代表的…

linux 系統調用時怎么知道當前上下文屬于那個進程,linux – 編寫系統調用來計算進程的上下文切換...

如果您的系統調用只應報告統計信息,則可以使用內核中已有的上下文切換計數代碼.struct rusage {...long ru_nvcsw; /* voluntary context switches */long ru_nivcsw; /* involuntary context switches */};您可以通過運行來嘗試:$/usr/bin/time -v /bin/ls -R....V…

linux串口緩沖區的大小,linux-----------串口設置緩沖器的大小

轉自:http://stackoverflow.com/questions/10815811/linux-serial-port-reading-can-i-change-size-of-input-bufferYou want to use the serial IOCTL TIOCSSERIAL which allows changing both receive buffer depth and send buffer depth (among other things). The maximum…

FLOPs、FLOPS、Params的含義及PyTorch中的計算方法

FLOPs、FLOPS、Params的含義及PyTorch中的計算方法 含義解釋 FLOPS:注意全大寫,是floating point operations per second的縮寫(這里的大S表示second秒),表示每秒浮點運算次數,理解為計算速度。是一個衡量…

設置中文linux輸入ubuntu,Linux_ubuntu怎么設置成中文?ubuntu中文設置圖文方法,  很多朋友安裝ubuntu后,發 - phpStudy...

ubuntu怎么設置成中文?ubuntu中文設置圖文方法很多朋友安裝ubuntu后,發現都是英文,看不懂要怎么辦?其實ubuntu是可以設置成中文的,下文小編就為大家帶來ubuntu中文的設置方法,一起去看下設置方法吧。ubuntu中文設置方…

科普 | 單精度、雙精度、多精度和混合精度計算的區別是什么?

科普 | 單精度、雙精度、多精度和混合精度計算的區別是什么? 轉自:https://zhuanlan.zhihu.com/p/93812784 我們提到圓周率 π 的時候,它有很多種表達方式,既可以用數學常數3.14159表示,也可以用一長串1和0的二進制長串表示。 …

linux 磁盤分配 簡書,linux 磁盤分區

1物理磁盤的構成: 盤面:由一圈一圈的磁道組成機械手臂:讀取數據主軸馬達:幫助機械手臂轉動2 扇區:磁盤上存取數據的最小單位512字節按照扇區分配大小,如果數據只有一字節也會占用512字節簇:用若…

條件控制與條件傳送詳解

條件控制與條件傳送詳解 提要 CSAPP3e中文譯本 3.6.5 用條件控制來實現條件分支 3.6.6 用條件傳送來實現條件分支 CSAPP3e第三章前面主要是介紹了機器級代碼的二進制形式和匯編形式、反匯編、x86匯編的基礎指令、條件碼及其訪問方式等。 在介紹到匯編語言的條件分支時分了兩…

聯合體(union)的使用方法及其本質

聯合體(union)的使用方法及其本質 轉自:https://blog.csdn.net/huqinwei987/article/details/23597091 有些基礎知識快淡忘了,所以有必要復習一遍,在不借助課本死知識的前提下做些推理判斷,溫故知新。 1…

linux設備驅動之串口移植,Linux設備驅動之UART驅動結構

一、對于串口驅動Linux系統中UART驅動屬于終端設備驅動,應該說是實現串口驅動和終端驅動來實現串口終端設備的驅動。要了解串口終端的驅動在Linux系統的結構就先要了解終端設備驅動在Linux系統中的結構體系,一方面自己了解的不夠,另一發面關于…

linux python復制安裝,復制一個Python全部環境到另一個環境,python另一個,導出此環境下安裝的包...

復制一個Python全部環境到另一個環境,python另一個,導出此環境下安裝的包導出此環境下安裝的包的版本信息清單pipfreeze>requirements.txt聯網,下載清單中的包到all-packet文件夾[[email protected] ~]# pip download -d ./all-packet -r requirement…

NVIDIA英偉達的Multi-GPU多卡通信框架NCCL

NVIDIA英偉達的Multi-GPU多卡通信框架NCCL 筆者注:NCCL 開源項目地址:https://github.com/NVIDIA/nccl 轉自:https://www.zhihu.com/question/63219175/answer/206697974 NCCL是Nvidia Collective multi-GPU Communication Library的簡稱&…

C語言n個坐標點間的最大距離,c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。...

c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。以下文字資料是由(歷史新知網www.lishixinzhi.com)小編為大家搜集整理后發布的內容,讓我們趕快一起來看一下吧!c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。#…

[分布式訓練] 單機多卡的正確打開方式:理論基礎

[分布式訓練] 單機多卡的正確打開方式:理論基礎 轉自:https://fyubang.com/2019/07/08/distributed-training/ 瓦礫由于最近bert-large用的比較多,踩了很多分布式訓練的坑,加上在TensorFlow和PyTorch之間更換,算是熟…

s3c2416開發板 linux,S3C2416移植內核Linux3.1的wm9713聲卡過程

移植內核的聲卡驅動。原因沒有聲卡驅動,WM9713聲卡驅動移植(原來的內核有UDA1341聲卡驅動,我們再次基礎上直接修改)1、直接復制內核得到三個文件:s3c2416_wm9713.c , wm9713.c , s3c2416_ac97.c.linux-3.1\sound\soc\codecs\Wm9713.c---->wm9713.c;li…

Linux查看文件內容命令:cat, tail, head, more, less

Linux查看文件內容命令:cat, tail, head, more, less cat 直接顯示整個文件。 cat直接顯示全部文件內容,沒有換頁等交互。 cat filenamemore more命令,功能類似 cat ,cat命令是整個文件的內容從上到下顯示在屏幕上。 more會…

linux查看隊列 msg,linux第10天 msg消息隊列

cat /proc/sys/kernel/msgmax最大消息長度限制cat /proc/sys/kernel/msgmnb消息隊列總的字節數cat /proc/sys/kernel/msgmni消息條目數消息隊列綜合案例//server#include #include #include #include #include #include #include #include #define ERR_EXIT(m)do{perror(m);}wh…

Linux中 C++ main函數參數argc和argv含義及用法

Linux中 C main函數參數argc和argv含義及用法 簡介 argc 是 argument count的縮寫,表示傳入main函數的參數個數; argv 是 argument vector的縮寫,表示傳入main函數的參數序列或指針,并且第一個參數argv[0]一定是程序的名稱&…

c語言六位搶答器課程設計,51單片機八路搶答器課程設計

;說明:本人的這個設計改進后解決了前一個版本中1號搶答優先的問題,并增加了錦囊的設置,當參賽選手在回答問題時要求使用錦囊,則主持人按下搶答開始鍵,計時重新開始。;八路搶答器電路請看下圖是用ps仿真的,已…