13.2 微調
為了防止在訓練集上過擬合,有兩種辦法:第一種是擴大訓練集數量,但是需要大量的成本;第二種就是應用遷移學習,將源數據集上學習到的知識遷移到目標數據集,即把在源數據集上訓練好的模型參數(除去輸出層)直接復制到目標數據集上繼續訓練。
# IPython魔法函數,可以不用執行plt .show()
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
13.2.1 獲取數據集
#@save
# Register the hotdog dataset (URL + SHA-1 hash) with the d2l data hub.
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                          'fba480ffa8aa7e0febbb511d181409f899b9baa5')

# Download and extract the archive; `data_dir` is the extracted root folder.
# (The original source fused this statement onto the previous line, which is
# a syntax error.)
data_dir = d2l.download_extract('hotdog')

# Load the raw (untransformed) train/test images; ImageFolder infers the two
# classes (hotdog / not-hotdog) from the subdirectory names.
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

# First 8 positive examples and last 8 negative examples from the train set.
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# Display the 16 images as a 2x8 grid.
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.5)
# Standardize each RGB channel with the ImageNet per-channel means and
# standard deviations (the pretrained backbone expects this normalization).
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training-time augmentation: random resized crop to 224x224 plus a random
# horizontal flip, then convert to tensor and normalize.
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize,
])

# Evaluation-time pipeline: deterministic resize to 256x256 followed by a
# center crop to 224x224, then the same tensor conversion and normalization.
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize,
])
13.2.2 初始化模型
# Download a ResNet-18 pretrained on ImageNet.
finetune_net = torchvision.models.resnet18(pretrained=True)
# Replace the classification head: keep the feature dimension the backbone
# produces, but emit 2 classes (hotdog vs. not-hotdog) instead of 1000.
# (Only the output size changes; the input feature size stays the same.)
num_features = finetune_net.fc.in_features
finetune_net.fc = nn.Linear(num_features, 2)
# The new head starts from random values; re-initialize its weights with
# Xavier-uniform for stable fine-tuning.
nn.init.xavier_uniform_(finetune_net.fc.weight)
13.2.3 微調模型
# If param_group=True, the (freshly initialized) output layer is trained with
# a learning rate ten times larger than the pretrained layers; otherwise all
# parameters share the same learning rate.
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    """Fine-tune ``net`` on the hotdog dataset.

    Args:
        net: model whose ``fc`` attribute is the output layer.
        learning_rate: base learning rate for the pretrained layers.
        batch_size: mini-batch size for both loaders.
        num_epochs: number of training epochs.
        param_group: if True, train ``net.fc`` with a 10x learning rate.
    """
    # NOTE: relies on module-level `data_dir`, `train_augs`, `test_augs`.
    train_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        # Every parameter except the output layer's weight and bias.
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        # Pretrained parameters use `learning_rate`; the new output layer
        # trains from scratch, so it gets a 10x larger rate.
        trainer = torch.optim.SGD(
            [{'params': params_1x},
             {'params': net.fc.parameters(), 'lr': learning_rate * 10}],
            lr=learning_rate, weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)


# Use a small base learning rate since most parameters are already trained.
train_fine_tuning(finetune_net, 5e-5)
13.3 目標檢測和邊界框
有時候不僅要識別圖像的類別,還需要識別圖像的位置。在計算機視覺中叫做目標識別或者目標檢測。這小節是介紹目標檢測的深度學習方法。
%matplotlib inline
import torch
from d2l import torch as d2l
#@save
def box_corner_to_center(boxes):
    """Convert boxes from (upper-left, lower-right) to (center, width, height).

    Args:
        boxes: tensor of shape (n, 4) with rows (x1, y1, x2, y2).

    Returns:
        Tensor of shape (n, 4) with rows (cx, cy, w, h).
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # cx, cy, w, h each have shape (n,).
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    # torch.stack joins the four (n,) vectors along a new last axis,
    # recovering shape (n, 4).
    boxes = torch.stack((cx, cy, w, h), dim=-1)
    return boxes


#@save
def box_center_to_corner(boxes):
    """Convert boxes from (center, width, height) to (upper-left, lower-right).

    Args:
        boxes: tensor of shape (n, 4) with rows (cx, cy, w, h).

    Returns:
        Tensor of shape (n, 4) with rows (x1, y1, x2, y2).
    """
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # Corners sit half a width/height away from the center on each side.
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), dim=-1)
    return boxes