49風格遷移
讀入內容圖像:
import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 讀取內容圖像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
plt.imshow(content_img)
plt.show()
讀取風格圖像:
# 讀取風格圖像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
plt.imshow(style_img)
plt.show()
import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 讀取內容圖像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
# plt.imshow(content_img)
# plt.show()# 讀取風格圖像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
# plt.imshow(style_img)
# plt.show()# 預處理和后處理
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])# 函數preprocess對輸入圖像在RGB三個通道分別做標準化,
# 并將結果變換成卷積神經網絡接受的輸入格式
def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0) # 增加一個通道# 后處理函數postprocess則將輸出圖像中的像素值還原回標準化之前的值。
# 由于圖像打印函數要求每個像素的浮點數值在0~1之間,我們對小于0和大于1的值分別取0和1。
def postprocess(img):# img[0] 表示移除批次維度,從批次中提取出第一個圖像img = img[0].to(rgb_std.device) # 移除批次維度,并將圖像張量移動到與 rgb_std 相同的設備img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 反轉標準化過程return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))# ToPILImage() 期望的輸入是 [C, H, W] 形式,因此需要再次將張量的通道維度移動到第一個位置。# 抽取圖像特征
# 使用基于ImageNet數據集預訓練的VGG-19模型
# VGG19包含了19個隱藏層(16個卷積層和3個全連接層)
pretrained_net = torchvision.models.vgg19(pretrained=True)"""一般來說,越靠近輸入層,越容易抽取圖像的細節信息;反之,則越容易抽取圖像的全局信息。 為了避免合成圖像過多保留內容圖像的細節,我們選擇VGG較靠近輸出的層,即內容層,來輸出圖像的內容特征。 我們還從VGG中選擇不同層的輸出來匹配局部和全局的風格,這些圖層也稱為風格層。
"""
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# net 模型包含了 VGG-19 從第 0 層到第 28 層的所有層
net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])# 由于我們還需要中間層的輸出,
# 因此這里我們逐層計算,并保留內容層和風格層的輸出
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles# 對內容圖像抽取內容特征
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Y# 對風格圖像抽取風格特征
def get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y# 定義損失函數
# 由內容損失、風格損失和全變分損失3部分組成# 內容損失
# 內容損失通過平方誤差函數衡量合成圖像與內容圖像在內容特征上的差異
# 平方誤差函數的兩個輸入均為extract_features函數計算所得到的內容層的輸出。
def content_loss(Y_hat, Y):# 我們從動態計算梯度的樹中分離目標:# 這是一個規定的值,而不是一個變量。return torch.square(Y_hat - Y.detach()).mean()# 風格損失
def gram(X): # 基于風格圖像的格拉姆矩陣num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()# 全變分損失
# 合成圖像里面有大量高頻噪點,即有特別亮或者特別暗的顆粒像素。
# 一種常見的去噪方法是全變分去噪total variation denoising
def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())"""
風格轉移的損失函數是內容損失、風格損失和總變化損失的加權和。
通過調節這些權重超參數,我們可以權衡合成圖像在保留內容、遷移風格以及去噪三方面的相對重要性。
"""
content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分別計算內容損失、風格損失和全變分損失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 對所有損失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l# 初始化合成圖像
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight# 函數創建了合成圖像的模型實例,并將其初始化為圖像X
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer# 訓練模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y) # 初始化合成圖像和優化器scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = lp.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad() # 梯度清零contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers) # 提取特征contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) # 計算損失l.backward() # 反向傳播計算梯度trainer.step() # 更新模型參數scheduler.step() # 更新學習率if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return Xdevice, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
plt.show()
運行結果: