假設你的類別只有10個,而torchvision.models中Vgg16的輸出類別為1000,這時應該如何調整呢?
方法一,直接修改模型中類別的輸出。
from torch.nn import Linear
import torchvision
import torchVgg16=torchvision.models.vgg16(pretrained=True)
Vgg16.classifier[6]=Linear(in_features=4096,out_features=10)
if torch.cuda.is_available():T=Vgg16.cuda()
方法二,再模型的最后增加全連接層,改變輸出類別。
from torch.nn import Linear
import torchvision
import torchres=torchvision.models.resnet101(pretrained=True,progress=True)
res.fc.add_module('linelayer',Linear(in_features=1000,out_features=10))
if torch.cuda.is_available():T=res.cuda()
?