strict=False 但還是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
問題
我們知道通過
model.load_state_dict(state_dict, strict=False)
可以暫且忽略掉模型和參數文件中不匹配的參數,先將正常匹配的參數從文件中載入模型。
筆者在使用時遇到了這樣一個報錯:
RuntimeError: Error(s) in loading state_dict for ViT_Aes:size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).
一開始筆者很奇怪,我已經寫明strict=False
了,不匹配參數的不管就是了,為什么還要給我報錯。
原因及解決方案
經過筆者仔細打印模型的鍵和文件中的鍵進行比對,發現是這樣的:strict=False
可以保證模型中的鍵與文件中的鍵不匹配時暫且跳過不管,但是一旦模型中的鍵和文件中的鍵匹配上了,PyTorch就會嘗試幫我們加載參數,就必須要求參數的尺寸相同,所以會有上述報錯。
比如在我們需要將某個預訓練的模型的最后的全連接層的輸出的類別數替換為我們自己的數據集的類別數,再進行微調,有時會遇到上述情況。這時,我們知道全連接層的參數形狀會是不匹配,比如我們加載 ImageNet 1K 1000分類的預訓練模型,它的最后一層全連接的輸出維度是1000,但如果我們自己的數據集是10分類,我們需要將最后一層全鏈接的輸出維度改為10。但是由于鍵名相同,所以PyTorch還是嘗試給我們加載,這時1000和10維度不匹配,就會導致報錯。
解決方案就是我們將 .pth 模型文件讀入后,將其中我們不需要的層(通常是最后的全連接層)的參數pop掉即可。
以 ViT 為例子,假設我們有一個 ViT 模型,并有一個參數文件 vit-in1k.pth
,它里面存儲著 ViT 模型在 ImageNet-1K 1000分類數據集上訓練的參數,而我們要在自己的10分類數據集上微調這個模型。
model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')
msg = model.load_state_dict(ckpt, strict=False)
print(msg)
直接這樣加載會出錯,就是上面的錯誤:
size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).
我們將最后 pth 文件加載進來之后(即 ckpt
) 中全連接層的參數直接pop掉,至于需要pop掉哪些鍵名,就是上面報錯信息中提到了的,在這里就是 head.weight
和 head.bias
ckpt.pop('head.weight')
ckpt.pop('head.bias')
之后在運行,會發現我們打印的 msg
顯示:
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
即缺失了head.weight
和 head.bias
這兩個參數,這是正常的,因為在自己的數據集上微調時,我們本就不需要這兩個參數,并且已經將它們從模型文件字典 ckpt
中pop掉了。現在,模型全連接之前的層(通常即所謂的特征提取層)的參數已經正常加載了,接下來可以在自己的數據集上進行微調。
因為反正我們也不用這些參數,就直接把這個鍵值對從字典中pop掉,以免 PyTorch 在幫我們加載時試圖加載這些維度不匹配,我們也不需要的參數。