動手學深度學習(Pytorch版)代碼實踐 -計算機視覺-49風格遷移

49風格遷移

在這里插入圖片描述
在這里插入圖片描述

讀入內容圖像:

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 讀取內容圖像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
plt.imshow(content_img)
plt.show()

在這里插入圖片描述

讀取風格圖像:

# 讀取風格圖像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
plt.imshow(style_img)
plt.show()

在這里插入圖片描述

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 讀取內容圖像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
# plt.imshow(content_img)
# plt.show()# 讀取風格圖像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
# plt.imshow(style_img)
# plt.show()# 預處理和后處理
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])# 函數preprocess對輸入圖像在RGB三個通道分別做標準化,
# 并將結果變換成卷積神經網絡接受的輸入格式
def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0) # 增加一個通道# 后處理函數postprocess則將輸出圖像中的像素值還原回標準化之前的值。 
# 由于圖像打印函數要求每個像素的浮點數值在0~1之間,我們對小于0和大于1的值分別取0和1。
def postprocess(img):# img[0] 表示移除批次維度,從批次中提取出第一個圖像img = img[0].to(rgb_std.device) # 移除批次維度,并將圖像張量移動到與 rgb_std 相同的設備img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 反轉標準化過程return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))# ToPILImage() 期望的輸入是 [C, H, W] 形式,因此需要再次將張量的通道維度移動到第一個位置。# 抽取圖像特征
# 使用基于ImageNet數據集預訓練的VGG-19模型
# VGG19包含了19個隱藏層(16個卷積層和3個全連接層)
pretrained_net = torchvision.models.vgg19(pretrained=True)"""一般來說,越靠近輸入層,越容易抽取圖像的細節信息;反之,則越容易抽取圖像的全局信息。 為了避免合成圖像過多保留內容圖像的細節,我們選擇VGG較靠近輸出的層,即內容層,來輸出圖像的內容特征。 我們還從VGG中選擇不同層的輸出來匹配局部和全局的風格,這些圖層也稱為風格層。
"""
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# net 模型包含了 VGG-19 從第 0 層到第 28 層的所有層
net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])# 由于我們還需要中間層的輸出,
# 因此這里我們逐層計算,并保留內容層和風格層的輸出
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles# 對內容圖像抽取內容特征
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Y# 對風格圖像抽取風格特征
def get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y# 定義損失函數
# 由內容損失、風格損失和全變分損失3部分組成# 內容損失
# 內容損失通過平方誤差函數衡量合成圖像與內容圖像在內容特征上的差異
# 平方誤差函數的兩個輸入均為extract_features函數計算所得到的內容層的輸出。
def content_loss(Y_hat, Y):# 我們從動態計算梯度的樹中分離目標:# 這是一個規定的值,而不是一個變量。return torch.square(Y_hat - Y.detach()).mean()# 風格損失
def gram(X): # 基于風格圖像的格拉姆矩陣num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()# 全變分損失
# 合成圖像里面有大量高頻噪點,即有特別亮或者特別暗的顆粒像素。 
# 一種常見的去噪方法是全變分去噪total variation denoising
def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())"""
風格轉移的損失函數是內容損失、風格損失和總變化損失的加權和。
通過調節這些權重超參數,我們可以權衡合成圖像在保留內容、遷移風格以及去噪三方面的相對重要性。
"""
content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分別計算內容損失、風格損失和全變分損失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 對所有損失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l# 初始化合成圖像
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight# 函數創建了合成圖像的模型實例,并將其初始化為圖像X
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer# 訓練模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)  # 初始化合成圖像和優化器scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = lp.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()  # 梯度清零contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)  # 提取特征contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)  # 計算損失l.backward()  # 反向傳播計算梯度trainer.step()  # 更新模型參數scheduler.step()  # 更新學習率if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return Xdevice, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
plt.show()

運行結果:
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/38531.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/38531.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/38531.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用 Swift 遞歸搜索目錄中文件的內容,同時支持 Glob 模式和正則表達式

文章目錄 前言項目設置查找文件讀取CODEOWNERS文件解析規則搜索匹配的文件確定文件所有者輸出結果總結前言 如果你新加入一個團隊,想要快速的了解團隊的領域和團隊中擁有的代碼庫的詳細信息。 如果新團隊中的代碼庫在 GitHub / GitLab 中并且你不熟悉代碼所有權模型的概念或…

Unity開箱即用的UGUI面板的拖拽移動功能

文章目錄 👉一、背景👉二、效果圖👉三、原理👉四、核心代碼👉五,總結 👉一、背景 之前做PC項目時常常有面板拖拽移動的需求,今天總結封裝一下,做成一個隨時隨地可復用的…

More Effective C++ 35個改善編程與設計的有效方法筆記與心得 3

三. 異常 條款9: 利用destructors避免泄露資源 ????  在編程中,"資源"可以指任何系統級的有限資源,如內存、文件句柄、網絡套接字等。"泄露"則是指在應用程序中分配了資源,但在不再需要這些資源時沒有…

Linux 安裝 Redis 教程

優質博文:IT-BLOG-CN 一、準備工作 配置gcc:安裝Redis前需要配置gcc: yum install gcc如果配置gcc出現依賴包問題,在安裝時提示需要的依賴包版本和本地版本不一致,本地版本過高,出現如下問題&#xff1a…

Jupyter無法導入庫,但能在終端導入的問題

Jupyter無法導入庫,但能在終端導入 ?錯誤問題描述:conda activate LLMs激活某個Conda的環境后,盡管已經通過conda或者pip在這個環境中安裝了一些🐍Python的庫,但無法在Jupyter中導入,卻能在終端成功導入。…

京東商品詳情數據接口(JD.item_get)丨京東API實時接口指南

京東商品詳情API接口(JD.item_get)是京東開放平臺提供的一個數據接口,用于獲取京東平臺上單個商品的詳細信息。 通過這個接口,開發者可以獲取到包括商品名稱、品牌、產地、規格參數、價格信息、銷量、評價、圖片、描述等在內的詳…

Node.js開發實戰 視頻教程 下載

ode.js開發實戰 視頻教程 下載 下載地址 https://download.csdn.net/download/m0_67912929/89487510 01-課程介紹.mp4 02-內容綜述.mp4 03-Node.js是什么? .mp4 04-Node.js可以用來做什么?.mp4 05-課程實戰項目介紹.mp4 06-什么是技術預研? .mp4 07-Node.js開發環境…

Windows 11 安裝 安卓子系統 (WSA)

How to Install Windows Subsystem for Android (WSA) on Windows 11 新手教程:如何安裝Windows 11 安卓子系統 說明 Windows Subsystem for Android 或 WSA 是由 Hyper-V 提供支持的虛擬機,可在 Windows 11 操作系統上運行 Android 應用程序。雖然它需…

【JS】注意考點

1.聲明變量時所遵循的規則: (1)可以使用一個保留關鍵字var同時聲明多個變量 (2)可以在聲明變量的同時對其賦值, (3)如果只是聲明了變量,并未對其賦值,其值就默認為 Undefined。 (4)保留關鍵字var可以用作for語句和for…in語句…

python基礎_類

在Python中,類(Class)是面向對象編程(OOP)的核心概念之一。類提供了一種創建新對象的模板,這些對象通常被稱為類的實例或對象。以下是關于Python類的一些關鍵點和特性: 定義類 類通過class關鍵…

PostgreSQL的系統視圖pg_stat_wal

PostgreSQL的系統視圖pg_stat_wal 在 PostgreSQL 數據庫中,pg_stat_wal 視圖提供了與 WAL(Write-Ahead Logging)日志有關的統計信息。WAL 是 PostgreSQL 用于確保數據一致性和持久性的重要機制。因此,監控和分析 WAL 活動對于數據…

ctfshow-web入門-命令執行(web71-web74)

目錄 1、web71 2、web72 3、web73 4、web74 1、web71 像上一題那樣掃描但是輸出全是問號 查看提示:我們可以結合 exit() 函數執行php代碼讓后面的匹配緩沖區不執行直接退出。 payload: cvar_export(scandir(/));exit(); 同理讀取 flag.txt cinclud…

文華財經博易大師盤立方多空波段止損畫線指標公式

TT:PERIOD7; EMA120:EMA(C,120); RSV:(CLOSE-LLV(LOW,9))/(HHV(HIGH,9)-LLV(LOW,9))*100; K:SMA(RSV,3,1); D:SMA(K,3,1); J:3*K-2*D; DRAWTEXT(TT&&J<0,L,多),VALIGN0; DRAWTEXT(TT&&J>100,H,空),VALIGN2; IF(TT,EMA(C,60),NULL),RGB(255,255,2…

JavaScript數組對象 , 正則對象 , String對象以及自定義對象介紹

1. Array數組對象 數組對象是使用單獨的變量名來存儲一系列的值。 1.1創建一個數組 創建一個數組&#xff0c;有三種方法。 【1】常規方式: let 數組名 new Array();【2】簡潔方式: 推薦使用 let 數組名 new Array(數值1,數值2,...);【3】字面:在js中創建數組使用中括號…

【ubuntu 】使用samba配置共享用戶home目錄和其他具體路徑

目錄 1 安裝samba 2 修改Samba配置文件 3 增加Rose用戶的samba帳號 4 重啟samba 5 測試 1 安裝samba 使用如下命令安裝samba&#xff1a; sudo apt-get updatesudo apt-get install samba openssh-server 2 修改Samba配置文件 sudo cp /etc/samba/smb.conf /etc/samba…

試用筆記之-收錢吧安卓版演示源代碼,收錢吧手機版感受

首先下載&#xff1a; https://download.csdn.net/download/tjsoft/89499105 安卓手機安裝 如果有收錢吧帳號輸入收錢吧帳號和密碼。 如果沒有收錢吧帳號點我的注冊 登錄收錢吧帳號后就可以把手機當成收錢吧POS機用了&#xff0c;還可以掃客服的付款碼哦 源代碼技術交流QQ:42…

Docker安裝MySQL5

Docker安裝MySQL5 前言 MySQL 是一個開源的關系型數據庫管理系統&#xff0c;廣泛用于各種 Web 應用程序的開發和生產環境中。MySQL 5 是 MySQL 數據庫的一個較早版本&#xff0c;雖然不再是最新版本&#xff0c;但仍然被一些項目所使用和支持。 在 Docker 中安裝 MySQL 5 可…

Docker 手冊

幫助命令 docker 命令 --help鏡像命令 docker images (-a所有 &#xff5c; -q只顯示容器的ID) docker search 鏡像名 docker pull 鏡像名&#xff1a;版本號 docker rmi -f ID&#xff5c;鏡像名&#xff1a;版本號 // 刪除本地一個或多個鏡像 docker rmi -f $(docker …

U盤數據恢復實戰指南:原因、方案與預防措施

一、引言&#xff1a;U盤數據恢復概述 在數字化時代&#xff0c;U盤作為一種便攜式存儲設備&#xff0c;廣泛應用于個人和企業中。然而&#xff0c;由于各種原因&#xff0c;U盤數據丟失的問題時有發生。U盤數據恢復技術便是在這種情況下應運而生&#xff0c;它幫助用戶在數據…

TPS61085非同步650kHz,1.2MHz, 18.5V升壓DCDC芯片

1 特點 TPS61085外觀和絲印PMKI 2.3 V 至 6 V 輸入電壓范圍 具有 2.0A 開關電流的 18.5V 升壓轉換器 650kHz/1.2MHz 可選開關頻率 可調軟啟動 熱關斷 欠壓閉鎖 8引腳VSSOP封裝 8引腳TSSOP封裝 2 應用 手持設備 GPS接收器 數碼相機 便攜式應用 DSL調制解調器 PCMCIA卡 TFT LCD…