深度有趣 | 30 快速圖像風格遷移

簡介

使用TensorFlow實現快速圖像風格遷移(Fast Neural Style Transfer)

原理

在之前介紹的圖像風格遷移中,我們根據內容圖片和風格圖片優化輸入圖片,使得內容損失函數和風格損失函數盡可能小

和DeepDream一樣,屬于網絡參數不變,根據損失函數調整輸入數據,因此每生成一張圖片都相當于訓練一個模型,需要很長時間

訓練模型需要很長時間,而使用訓練好的模型進行推斷則很快

使用快速圖像風格遷移可大大縮短生成一張遷移圖片所需的時間,其模型結構如下,包括轉換網絡和損失網絡

風格圖片是固定的,而內容圖片是可變的輸入,因此以上模型用于將任意圖片快速轉換為指定風格的圖片

  • 轉換網絡:參數需要訓練,將內容圖片轉換成遷移圖片
  • 損失網絡:計算遷移圖片和風格圖片之間的風格損失,以及遷移圖片和原始內容圖片之間的內容損失

經過訓練后,轉換網絡所生成的遷移圖片,在內容上和輸入的內容圖片相似,在風格上和指定的風格圖片相似

進行推斷時,僅使用轉換網絡,輸入內容圖片,即可得到對應的遷移圖片

如果有多個風格圖片,對每個風格分別訓練一個模型即可

實現

基于以下兩個項目進行修改,github.com/lengstrom/f…、github.com/hzy46/fast-…

依然通過之前用過的imagenet-vgg-verydeep-19.mat計算內容損失函數和風格損失函數

需要一些圖片作為輸入的內容圖片,對圖片具體內容沒有任何要求,也不需要任何標注,這里選擇使用MSCOCO數據集的train2014部分,cocodataset.org/#download,共82612張圖片

加載庫

# -*- coding: utf-8 -*-import tensorflow as tf
import numpy as np
import cv2
from imageio import imread, imsave
import scipy.io
import os
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
復制代碼

查看風格圖片,共10張

style_images = glob.glob('styles/*.jpg')
print(style_images)
復制代碼

加載內容圖片,去掉黑白圖片,處理成指定大小,暫時不進行歸一化,像素值范圍為0至255之間

def resize_and_crop(image, image_size):h = image.shape[0]w = image.shape[1]if h > w:image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]else:image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]    image = cv2.resize(image, (image_size, image_size))return imageX_data = []
image_size = 256
paths = glob.glob('train2014/*.jpg')
for i in tqdm(range(len(paths))):path = paths[i]image = imread(path)if len(image.shape) < 3:continueX_data.append(resize_and_crop(image, image_size))
X_data = np.array(X_data)
print(X_data.shape)
復制代碼

加載vgg19模型,并定義一個函數,對于給定的輸入,返回vgg19各個層的輸出值,就像在GAN中那樣,通過variable_scope重用實現網絡的重用

vgg = scipy.io.loadmat('imagenet-vgg-verydeep-19.mat')
vgg_layers = vgg['layers']def vgg_endpoints(inputs, reuse=None):with tf.variable_scope('endpoints', reuse=reuse):def _weights(layer, expected_layer_name):W = vgg_layers[0][layer][0][0][2][0][0]b = vgg_layers[0][layer][0][0][2][0][1]layer_name = vgg_layers[0][layer][0][0][0][0]assert layer_name == expected_layer_namereturn W, bdef _conv2d_relu(prev_layer, layer, layer_name):W, b = _weights(layer, layer_name)W = tf.constant(W)b = tf.constant(np.reshape(b, (b.size)))return tf.nn.relu(tf.nn.conv2d(prev_layer, filter=W, strides=[1, 1, 1, 1], padding='SAME') + b)def _avgpool(prev_layer):return tf.nn.avg_pool(prev_layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')graph = {}graph['conv1_1']  = _conv2d_relu(inputs, 0, 'conv1_1')graph['conv1_2']  = _conv2d_relu(graph['conv1_1'], 2, 'conv1_2')graph['avgpool1'] = _avgpool(graph['conv1_2'])graph['conv2_1']  = _conv2d_relu(graph['avgpool1'], 5, 'conv2_1')graph['conv2_2']  = _conv2d_relu(graph['conv2_1'], 7, 'conv2_2')graph['avgpool2'] = _avgpool(graph['conv2_2'])graph['conv3_1']  = _conv2d_relu(graph['avgpool2'], 10, 'conv3_1')graph['conv3_2']  = _conv2d_relu(graph['conv3_1'], 12, 'conv3_2')graph['conv3_3']  = _conv2d_relu(graph['conv3_2'], 14, 'conv3_3')graph['conv3_4']  = _conv2d_relu(graph['conv3_3'], 16, 'conv3_4')graph['avgpool3'] = _avgpool(graph['conv3_4'])graph['conv4_1']  = _conv2d_relu(graph['avgpool3'], 19, 'conv4_1')graph['conv4_2']  = _conv2d_relu(graph['conv4_1'], 21, 'conv4_2')graph['conv4_3']  = _conv2d_relu(graph['conv4_2'], 23, 'conv4_3')graph['conv4_4']  = _conv2d_relu(graph['conv4_3'], 25, 'conv4_4')graph['avgpool4'] = _avgpool(graph['conv4_4'])graph['conv5_1']  = _conv2d_relu(graph['avgpool4'], 28, 'conv5_1')graph['conv5_2']  = _conv2d_relu(graph['conv5_1'], 30, 'conv5_2')graph['conv5_3']  = _conv2d_relu(graph['conv5_2'], 32, 'conv5_3')graph['conv5_4']  = _conv2d_relu(graph['conv5_3'], 34, 'conv5_4')graph['avgpool5'] = _avgpool(graph['conv5_4'])return graph
復制代碼

選擇一張風格圖,減去通道顏色均值后,得到風格圖片在vgg19各個層的輸出值,計算四個風格層對應的Gram矩陣

style_index = 1
X_style_data = resize_and_crop(imread(style_images[style_index]), image_size)
X_style_data = np.expand_dims(X_style_data, 0)
print(X_style_data.shape)MEAN_VALUES = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))X_style = tf.placeholder(dtype=tf.float32, shape=X_style_data.shape, name='X_style')
style_endpoints = vgg_endpoints(X_style - MEAN_VALUES)
STYLE_LAYERS = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']
style_features = {}sess = tf.Session()
for layer_name in STYLE_LAYERS:features = sess.run(style_endpoints[layer_name], feed_dict={X_style: X_style_data})features = np.reshape(features, (-1, features.shape[3]))gram = np.matmul(features.T, features) / features.sizestyle_features[layer_name] = gram
復制代碼

定義轉換網絡,典型的卷積、殘差、逆卷積結構,內容圖片輸入之前也需要減去通道顏色均值

batch_size = 4
X = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3], name='X')
k_initializer = tf.truncated_normal_initializer(0, 0.1)def relu(x):return tf.nn.relu(x)def conv2d(inputs, filters, kernel_size, strides):p = int(kernel_size / 2)h0 = tf.pad(inputs, [[0, 0], [p, p], [p, p], [0, 0]], mode='reflect')return tf.layers.conv2d(inputs=h0, filters=filters, kernel_size=kernel_size, strides=strides, padding='valid', kernel_initializer=k_initializer)def deconv2d(inputs, filters, kernel_size, strides):shape = tf.shape(inputs)height, width = shape[1], shape[2]h0 = tf.image.resize_images(inputs, [height * strides * 2, width * strides * 2], tf.image.ResizeMethod.NEAREST_NEIGHBOR)return conv2d(h0, filters, kernel_size, strides)def instance_norm(inputs):return tf.contrib.layers.instance_norm(inputs)def residual(inputs, filters, kernel_size):h0 = relu(conv2d(inputs, filters, kernel_size, 1))h0 = conv2d(h0, filters, kernel_size, 1)return tf.add(inputs, h0)with tf.variable_scope('transformer', reuse=None):h0 = tf.pad(X - MEAN_VALUES, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='reflect')h0 = relu(instance_norm(conv2d(h0, 32, 9, 1)))h0 = relu(instance_norm(conv2d(h0, 64, 3, 2)))h0 = relu(instance_norm(conv2d(h0, 128, 3, 2)))for i in range(5):h0 = residual(h0, 128, 3)h0 = relu(instance_norm(deconv2d(h0, 64, 3, 2)))h0 = relu(instance_norm(deconv2d(h0, 32, 3, 2)))h0 = tf.nn.tanh(instance_norm(conv2d(h0, 3, 9, 1)))h0 = (h0 + 1) / 2 * 255.shape = tf.shape(h0)g = tf.slice(h0, [0, 10, 10, 0], [-1, shape[1] - 20, shape[2] - 20, -1], name='g')
復制代碼

將轉換網絡的輸出即遷移圖片,以及原始內容圖片都輸入到vgg19,得到各自對應層的輸出,計算內容損失函數

CONTENT_LAYER = 'conv3_3'
content_endpoints = vgg_endpoints(X - MEAN_VALUES, True)
g_endpoints = vgg_endpoints(g - MEAN_VALUES, True)def get_content_loss(endpoints_x, endpoints_y, layer_name):x = endpoints_x[layer_name]y = endpoints_y[layer_name]return 2 * tf.nn.l2_loss(x - y) / tf.to_float(tf.size(x))content_loss = get_content_loss(content_endpoints, g_endpoints, CONTENT_LAYER)
復制代碼

根據遷移圖片和風格圖片在指定風格層的輸出,計算風格損失函數

style_loss = []
for layer_name in STYLE_LAYERS:layer = g_endpoints[layer_name]shape = tf.shape(layer)bs, height, width, channel = shape[0], shape[1], shape[2], shape[3]features = tf.reshape(layer, (bs, height * width, channel))gram = tf.matmul(tf.transpose(features, (0, 2, 1)), features) / tf.to_float(height * width * channel)style_gram = style_features[layer_name]style_loss.append(2 * tf.nn.l2_loss(gram - style_gram) / tf.to_float(tf.size(layer)))style_loss = tf.reduce_sum(style_loss)
復制代碼

計算全變差正則,得到總的損失函數

def get_total_variation_loss(inputs):h = inputs[:, :-1, :, :] - inputs[:, 1:, :, :]w = inputs[:, :, :-1, :] - inputs[:, :, 1:, :]return tf.nn.l2_loss(h) / tf.to_float(tf.size(h)) + tf.nn.l2_loss(w) / tf.to_float(tf.size(w)) total_variation_loss = get_total_variation_loss(g)content_weight = 1
style_weight = 250
total_variation_weight = 0.01loss = content_weight * content_loss + style_weight * style_loss + total_variation_weight * total_variation_loss
復制代碼

定義優化器,通過調整轉換網絡中的參數降低總損失

vars_t = [var for var in tf.trainable_variables() if var.name.startswith('transformer')]
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss, var_list=vars_t)
復制代碼

訓練模型,每輪訓練結束后,用一張測試圖片進行測試,并且將一些tensor的值寫入events文件,便于使用tensorboard查看

style_name = style_images[style_index]
style_name = style_name[style_name.find('/') + 1:].rstrip('.jpg')
OUTPUT_DIR = 'samples_%s' % style_name
if not os.path.exists(OUTPUT_DIR):os.mkdir(OUTPUT_DIR)tf.summary.scalar('losses/content_loss', content_loss)
tf.summary.scalar('losses/style_loss', style_loss)
tf.summary.scalar('losses/total_variation_loss', total_variation_loss)
tf.summary.scalar('losses/loss', loss)
tf.summary.scalar('weighted_losses/weighted_content_loss', content_weight * content_loss)
tf.summary.scalar('weighted_losses/weighted_style_loss', style_weight * style_loss)
tf.summary.scalar('weighted_losses/weighted_total_variation_loss', total_variation_weight * total_variation_loss)
tf.summary.image('transformed', g)
tf.summary.image('origin', X)
summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(OUTPUT_DIR)sess.run(tf.global_variables_initializer())
losses = []
epochs = 2X_sample = imread('sjtu.jpg')
h_sample = X_sample.shape[0]
w_sample = X_sample.shape[1]for e in range(epochs):data_index = np.arange(X_data.shape[0])np.random.shuffle(data_index)X_data = X_data[data_index]for i in tqdm(range(X_data.shape[0] // batch_size)):X_batch = X_data[i * batch_size: i * batch_size + batch_size]ls_, _ = sess.run([loss, optimizer], feed_dict={X: X_batch})losses.append(ls_)if i > 0 and i % 20 == 0:writer.add_summary(sess.run(summary, feed_dict={X: X_batch}), e * X_data.shape[0] // batch_size + i)writer.flush()print('Epoch %d Loss %f' % (e, np.mean(losses)))losses = []gen_img = sess.run(g, feed_dict={X: [X_sample]})[0]gen_img = np.clip(gen_img, 0, 255)result = np.zeros((h_sample, w_sample * 2, 3))result[:, :w_sample, :] = X_sample / 255.result[:, w_sample:, :] = gen_img[:h_sample, :w_sample, :] / 255.plt.axis('off')plt.imshow(result)plt.show()imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % e), result)
復制代碼

保存模型

saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'fast_style_transfer'))
復制代碼

測試圖片依舊是之前用過的交大廟門

風格遷移結果

訓練過程中可以使用tensorboard查看訓練過程

tensorboard --logdir=samples_starry
復制代碼

在單機上使用以下代碼即可快速完成風格遷移,在CPU上也只需要10秒左右

# -*- coding: utf-8 -*-import tensorflow as tf
import numpy as np
from imageio import imread, imsave
import os
import timedef the_current_time():print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))))style = 'wave'
model = 'samples_%s' % style
content_image = 'sjtu.jpg'
result_image = 'sjtu_%s.jpg' % style
X_image = imread(content_image)sess = tf.Session()
sess.run(tf.global_variables_initializer())saver = tf.train.import_meta_graph(os.path.join(model, 'fast_style_transfer.meta'))
saver.restore(sess, tf.train.latest_checkpoint(model))graph = tf.get_default_graph()
X = graph.get_tensor_by_name('X:0')
g = graph.get_tensor_by_name('transformer/g:0')the_current_time()gen_img = sess.run(g, feed_dict={X: [X_image]})[0]
gen_img = np.clip(gen_img, 0, 255) / 255.
imsave(result_image, gen_img)the_current_time()
復制代碼

對于其他風格圖片,用相同方法訓練對應模型即可

參考

  • Perceptual Losses for Real-Time Style Transfer and Super-Resolution:arxiv.org/abs/1603.08…
  • Fast Style Transfer in TensorFlow:github.com/lengstrom/f…
  • A Tensorflow Implementation for Fast Neural Style:github.com/hzy46/fast-…

視頻講解課程

深度有趣(一)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/450622.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/450622.shtml
英文地址,請注明出處:http://en.pswp.cn/news/450622.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

轉型從思維習慣的轉變開始

摘要&#xff1a;首先建議大家不要輕易轉向管理崗位&#xff0c;要認清自己是否適合做管理。轉型過程中應把握好幾點&#xff1a;良好的技術基礎&#xff0c;它是贏得團隊信任的前提&#xff0c;是把握團隊整體方向的關鍵&#xff1b;培養大局觀&#xff0c;只有站得高才能看得…

數據庫小知識點(一直更新)

一、mysql查詢是否含有某字段&#xff1a; mysql數據庫查詢帶有某個字段的所有表名 SELECT * FROM information_schema.columns WHERE column_namecolumn_name; oracle數據庫查詢帶有某個字段的所有表名 select column_name,table_name,from user_tab_columns where column_n…

其他運算符

原文地址&#xff1a;https://wangdoc.com/javascript/ void運算符 void運算符的作用是執行一個表達式&#xff0c;然后不返回任何值&#xff0c;或者說返回undefined。 void 0 // undefined void(0) // undefined 上面是void運算符的兩種寫法&#xff0c;都正確。建議采用后一…

git pull --rebase 做了什么? 以及 Cannot rebase: You have unstaged changes 解決辦法

前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到教程。 最近剛學 git rebase&#xff0c;覺得很牛逼的樣子&#xff0c; 結果今天就被打臉了。 git pull --rebase 1 報錯&#xff1a; Cann…

vue如何實現單頁緩存方案分析

實現全站的頁面緩存&#xff0c;前進刷新&#xff0c;返回走緩存&#xff0c;并且能記住上一頁的滾動位置&#xff0c;參考了很多技術實現&#xff0c;github上的導航組件實現的原理要么使用的keep-alive&#xff0c;要么參考了keep-alive的源碼&#xff0c;但是只用keep-alive…

C語言常用函數簡介

一、字符測試函數 isupper()測試字符是否為大寫英文字ispunct()測試字符是否為標點符號或特殊符號isspace()測試字符是否為空格字符isprint()測試字符是否為可打印字符islower()測試字符是否為小寫字母isgraphis()測試字符是否為可打印字符isdigit()測試字符是否為阿拉伯數字i…

thinkphp如何增加session的過期時間

原理&#xff1a;我們都知道session是建立在cookie的基礎上的&#xff0c;如果瀏覽器cookie清楚了&#xff0c;則tp就會重新建立一個session。 操作&#xff1a;直接增加瀏覽器的cookie的到期時間&#xff0c;就可以使tp的session增加。

需求心得

電路圖是人們為研究、工程規劃的需要。我們組項目需要設計實現一個矢量圖編輯器。在通過對變電站的電路圖進行矢量繪圖后&#xff0c;就可以通過矢量圖的縮放詳細信息。在分析需求后&#xff0c;寫下心得&#xff01; 分析需求主要有一下幾個步驟&#xff1a; 1. 獲取和引導需求…

IT部門不應該是一個后勤部門

管理上最大的問題在于不重視預算與核算的管理。從管理層到員工&#xff0c;很少有經營的念頭&#xff0c;只是一味地埋頭做事。西方企業總結了當今幾百年的經營理念&#xff0c;最終把企業一切活動的評價都歸結到唯一的、可度量的標準上&#xff1a;錢來度量。 by——華為 作為…

you need to resolve your current index first 解決辦法

前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到教程。 從一個分支A切換到另一個分支B后&#xff0c;對切換后的B分支進行pull操作&#xff0c;因為pull操作實際上包含了fetchmerge操作&#x…

C語言,一種如此美麗的語言

人們說足球是一種優美的體育運動&#xff0c;而當我們在綠茵場上看到羅納爾多那行云流水的帶球動作時&#xff0c;我們不能不承認這種說法。然而&#xff0c;對于我來說&#xff0c;這種運動之所以如此的賞心悅目&#xff0c;跟那些乖張的天才球星們關系并不是那么大&#xff0…

基于websocket的聊天實現邏輯(springboot)

websocket的知識點&#xff1a;當用戶建立socket連接請求之后&#xff0c;服務器會給客戶段建一個session&#xff08;非httpsession&#xff09;,這是是對客戶端的唯一識別碼&#xff0c;用于消息通信 第二上流程圖&#xff0c;流程圖解釋&#xff1a;用戶1要給用戶2發送消息…

Elasticsearch就這么簡單

Elasticsearch就這么簡單 Lucene就這么簡單轉載于:https://www.cnblogs.com/gaogaoyanjiu/p/9908520.html

大學生學編程系列」第五篇:自學編程需要多久才能找到工作?

很多編程初學者都會有這種疑問&#xff0c;自學學到什么程度或者學多久能夠找到工作&#xff0c;這種問題沒有統一答案&#xff0c;因為每個人的出發時候的基礎以及在學習過程中掌握的程度不盡相同&#xff0c;也會導致結果不一樣&#xff0c;只能說要看個人的造化了&#xff0…

chrome 谷歌瀏覽器怎么添加Axure擴展

前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到教程。 工具/原料 谷歌瀏覽器Axure RP Extension for Chrome方法/步驟 百度搜索Axure RP&#xff0c;下載Axure RP&#xff0c;并進行安裝 安裝后…

配置nginx-rtmp流媒體服務器(寶塔面板配置教程)

參考文檔&#xff1a;https://www.kancloud.cn/jiangguowu/kfjsdkfjskd/1209896 1.在寶塔面板中安裝帶nginx的服務器 2.在寶塔面板中卸載nginx&#xff08;因為nginx-rtmp和nginx的配置不同&#xff0c;并且寶塔面板中不支持安裝nginx-rtmp&#xff09; 3.開始預下載nginx &a…

C語言的應用范圍和發展前途簡介

C一般用來底層開發&#xff0c;如操作系統&#xff0c;嵌入式開發&#xff0c;或者要求效率&#xff0c;高可移植性的地方。C對人要求很高&#xff0c;程序員要考慮的地方太多。他的特點就是每一個字節都可以精確控制&#xff0c;不象C&#xff0c;編譯器為你自動加的東西太多&…

css控制div等比高度

在移動端開發中&#xff0c;在banner輪播圖未加載出來之前&#xff0c;banner層是不占文檔流高度的&#xff0c;當從服務器獲取完banner數據&#xff0c;展示的時候&#xff0c;banner層因為有了內容 所以會撐開&#xff0c;導致banner層下面的內容也隨之移動&#xff0c;為解決…

2018杭州云棲大會,梁勝博士的演講PPT來啦!

2019獨角獸企業重金招聘Python工程師標準>>> 2018杭州云棲大會已經結束&#xff0c;Rancher作為阿里云的緊密合作伙伴&#xff0c;Rancher Labs聯合創始人兼CEO梁勝博士&#xff0c;在9月21日上午受邀出席大會并作題為**“如何能讓每個人都用Kubernetes和Service Me…

No Identifier specified for entity的解決辦法

見&#xff1a;http://blog.csdn.net/u011617875/article/details/18550305 前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到教程。 No Identifier specified for entity的錯誤IdGeneratedVal…