寫在前面:有些地方和視頻里不一樣的是因為官方文檔更新了,一些參數用法不一樣也很正常,包括我現在的也是我這個時間節點最新的,誰知道過段時間會不會更新呢= =建議大家不要一味看視頻/博客,多看看官方文檔才是正道(
加載現有的網絡模型
加載有兩種方式加載,一種是直接加載固有的網絡結構,這種比較簡單,還有一種是將原有的網絡訓練好的參數也下載下來,這種加載的時候如果原來沒有的話會自動下載,如下:
對應的用法如下:
#只加載網絡結構
vgg16_false = torchvision.models.vgg16(weights=None)
print(vgg16_false)#加載網絡結構and參數
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)
VGG16原有結構(圖太長了,開頭沒截全,重點關注最后的就ok)
在最后加入新層(以修改為10分類為例)
#在最后加入新層
vgg16_true.add_module('my_add_linear1',nn.Linear(1000,10))
print(vgg16_true)
在原有區域塊中加入新層
#在原有區域塊中加入新層
vgg16_true.classifier.add_module('my_add_linear2',nn.Linear(1000,10))
print(vgg16_true)
對原有層進行修改
#對原有層進行修改
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)