loss低但精確度低_低光照圖像增強網絡-RetinexNet(model.py解析【2】)

51b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

53b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

論文地址:https://arxiv.org/pdf/1808.04560.pdf

代碼地址:https://github.com/weichen582/RetinexNet

解析目錄:https://zhuanlan.zhihu.com/p/88761829


整個模型架構被實現為一個類:

class lowlight_enhance(object):

其構造函數實現了網絡結構的搭建、損失函數的定義、訓練的配置和參數的初始化,具體如下。

網絡結構的搭建(該部分包括低/正常光照圖像輸入的定義以及Decom-Net、Enhance-Net和重建這三部分的對接,注意這里并沒有對Rlow進行去噪的部分):

# build the model
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)I_delta = RelightNet(I_low, R_low)I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])self.output_R_low = R_low
self.output_I_low = I_low_3
self.output_I_delta = I_delta_3
self.output_S = R_low * I_delta_3

損失函數的定義(該部分包括低/正常光照圖像的重建損失、反射分量一致性損失、光照分量平滑損失以及最后分別計算的Decom-Net和Enhance-Net的總損失):

# loss
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

訓練的配置(該部分包括學習率以及Decom-Net和Enhance-Net的優化器設置):

self.lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)
self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list = self.var_Relight)

訓練參數的初始化:

self.sess.run(tf.global_variables_initializer())self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)
self.saver_Relight = tf.train.Saver(var_list = self.var_Relight)print("[*] Initialize model successfully...")

接下來是該類的一些成員函數。

def gradient(self, input_tensor, direction):self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])if direction == "x":kernel = self.smooth_kernel_xelif direction == "y":kernel = self.smooth_kernel_yreturn tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

該函數實現的是通過與指定梯度算子進行卷積的方式求圖像的水平/垂直梯度圖。

def ave_gradient(self, input_tensor, direction):return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

該函數實現的是通過平均池化的方式來對圖像的水平/垂直梯度圖進行平滑。

def smooth(self, input_I, input_R):input_R = tf.image.rgb_to_grayscale(input_R)return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

該函數是對光照分量平滑損失的具體實現(可對應原論文中的公式來看)。

def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))for idx in range(len(eval_low_data)):input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)if train_phase == "Decom":result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})if train_phase == "Relight":result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)

該函數是對訓練epoch_num次后的Decom-Net/Enhance-Net模型進行評估,并保存評估結果圖。

接下來是關于模型的訓練:

def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):

該函數中包含了預訓練模型的加載、數據的讀取與處理、模型的訓練、評估和保存這幾個部分。

assert len(train_low_data) == len(train_high_data)
numBatch = len(train_low_data) // int(batch_size)

檢查所有需要參與訓練的低/正常光照樣本數量是否一致,若一致則計算訓練集含有的batch數量。

# load pretrained model
if train_phase == "Decom":train_op = self.train_op_Decomtrain_loss = self.loss_Decomsaver = self.saver_Decom
elif train_phase == "Relight":train_op = self.train_op_Relighttrain_loss = self.loss_Relightsaver = self.saver_Relightload_model_status, global_step = self.load(saver, ckpt_dir)
if load_model_status:iter_num = global_stepstart_epoch = global_step // numBatchstart_step = global_step % numBatchprint("[*] Model restore success!")
else:iter_num = 0start_epoch = 0start_step = 0
print("[*] Not find pretrained model!")

若存在Decom-Net/Enhance-Net對應的預訓練模型,則進行加載;否則從頭開始訓練。

# generate data for a batch
batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
for patch_id in range(batch_size):h, w, _ = train_low_data[image_id].shapex = random.randint(0, h - patch_size)y = random.randint(0, w - patch_size)rand_mode = random.randint(0, 7)batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)image_id = (image_id + 1) % len(train_low_data)if image_id == 0:tmp = list(zip(train_low_data, train_high_data))random.shuffle(list(tmp))train_low_data, train_high_data = zip(*tmp)

順序讀取訓練圖像,在每次讀取的低/正常光照圖像對上隨機取patch,并進行數據擴增(具體見 中對函數data_augmentation的描述)。這里,應當注意的是,訓練數據每滿一個batch時將會重新打亂整個訓練集。

# train
_, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low, self.input_high: batch_input_high, self.lr: lr[epoch]})print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
iter_num += 1

訓練一個iter并打印相關信息。

# evalutate the model and save a checkpoint file for it
if (epoch + 1) % eval_every_epoch == 0:self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)

每訓練eval_every_epoch次評估并保存一次模型。

保存指定iter的模型:

def save(self, saver, iter_num, ckpt_dir, model_name):if not os.path.exists(ckpt_dir):os.makedirs(ckpt_dir)print("[*] Saving model %s" % model_name)saver.save(self.sess, os.path.join(ckpt_dir, model_name), global_step=iter_num)

加載最新的模型:

def load(self, saver, ckpt_dir):ckpt = tf.train.get_checkpoint_state(ckpt_dir)if ckpt and ckpt.model_checkpoint_path:full_path = tf.train.latest_checkpoint(ckpt_dir)try:global_step = int(full_path.split('/')[-1].split('-')[-1])except ValueError:global_step = Nonesaver.restore(self.sess, full_path)return True, global_stepelse:print("[*] Failed to load model from %s" % ckpt_dir)return False, 0

最后是關于模型的測試(其中test_high_data并沒有用到):

def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):

該函數中包含了模型的加載、模型的測試和結果圖的保存這幾個部分。

tf.global_variables_initializer().run()print("[*] Reading checkpoint...")
load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
if load_model_status_Decom and load_model_status_Relight:print("[*] Load weights successfully...")

初始化所有參數并加載最新的Decom-Net和Enhance-Net模型。

print("[*] Testing...")
for idx in range(len(test_low_data)):print(test_low_data_names[idx])[_, name] = os.path.split(test_low_data_names[idx])suffix = name[name.find('.') + 1:]name = name[:name.find('.')]input_low_test = np.expand_dims(test_low_data[idx], axis=0)[R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict = {self.input_low: input_low_test})if decom_flag == 1:save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)save_images(os.path.join(save_dir, name + "_S." + suffix), S)

遍歷測試樣本進行測試,并保存最終結果圖(可自行指定是否保存Decom-Net的分解結果)。

歡迎關注公眾號:huangxiaobai880

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/397330.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/397330.shtml
英文地址,請注明出處:http://en.pswp.cn/news/397330.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

計算機應用發表論文,計算機應用論文發表.docx

計算機應用論文發表1在工程項目管理中應用計算機技術存在的問題計算機軟件是計算機運行的重要保障,一個好的計算機軟件直接決定計算機技術在工程項目管理的高效應用。但由于市場上計算機軟件種類繁多,質量好壞不一,質量好的價格高&#xff0c…

添加dubbo xsd的支持

使用dubbo時遇到問題: org.xml.sax.SAXParseException: schema_reference.4: Failed to read schema document http://code.alibabatech.com/schema/dubbo/dubbo.xsd, because 1) could not find the document; 2) the document could not be read; 3) the root ele…

byte數組穿換成pcm格式_形象地介紹DSD的編解碼原理及和PCM的區別

一直有人不清楚DSD到底是啥原理,和MP3, FLAC, APE, WAV等基于PCM編碼技術的音頻格式又有啥區別。特意做了兩張圖說明一下。圖一是是由很多黑點構成的蒙娜麗莎頭像,點擊看大圖就知道是沒有灰階只有黑白兩色。但是人眼是可以看到有豐富的灰階的。這和DSD一…

最大熵對應的概率分布

最大熵對應的概率分布 最大熵定理 設 \(X \sim p(x)\) 是一個連續型隨機變量,其微分熵定義為\[ h(X) - \int p(x)\log p(x) dx \]其中,\(\log\) 一般取自然對數 \(\ln\), 單位為 奈特(nats)。 考慮如下優化問題:\[ \b…

UBUNTU : Destination Host Unreachable

介紹我的系統的搭建的方式: WIN7 64 VMWARE STATION,方式是進行橋接的方式。最近突然出現了問題,Ubuntu ping 外網或者 PING WIN 7 的時候,出現 Destination Host Unreachable 的錯誤;想著去修改網卡的鏈接形式: 編輯…

焦作師范高等專科學校對口計算機分數線,焦作師范高等專科學校錄取分數線2018...

焦作師范高等專科學校錄取分數線20182018年 電子信息工程技術 理科 332 3602018年 物聯網應用技術 文科 391 4082018年 物聯網應用技術 理科 328 3692018年 學前教育 文科 388 4022018年 學前教育 理科 324 3512018年 移動應用開發 文科 02018年 移動應用開發 理科 305 3322018…

在Spring boot 配置過濾器(filter)

在spring boot 配置servlet filter 邏輯上與配置spring 是一樣的。 不過相比spring 更加簡化配置的難度。 這里只需要兩步1 創建一個自定義顧慮器并繼承spring filter 例如OncePerRequestFilterpublic class AuthenticationFilter extends OncePerRequestFilter{private final …

Flink之狀態之狀態存儲 state backends

流計算中可能有各種方式來保存狀態: 窗口操作使用 了KV操作的函數繼承了CheckpointedFunction的函數當開始做checkpointing的時候,狀態會被持久化到checkpoints里來規避數據丟失和狀態恢復。選擇的狀態存儲策略不同,會導致狀態持久化如何和ch…

怎么把分開的pdf放在一起_糖和鹽混在一起了要怎么分開?| 趣問萬物

趣 問 萬 物來源:把科學帶回家撰文:Mirror如何分離糖和鹽?圖源:Pixabay小手一抖,不小心把糖(蔗糖)和鹽(氯化鈉)混在一塊兒了該怎么辦?趁著光棍節,就讓我們吃飽了撐著研究研究把糖和鹽拆散的N種方…

《JavaScript DOM編程藝術》筆記

1. 把<script>標簽放到HTML文檔的最后&#xff0c;<body>標簽之前能使瀏覽器更快地加載頁面。 2. nodeType的常見取值 元素節點(1) 屬性節點(2) 文本節點(3) 3. <a href"http://www.baidu.com" οnclick"popUp(this.href);return false;"&g…

maven POM.xml內的標簽大全詳解

<project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0http://maven.apache.org/maven-v4_0_0.xsd"><!--父項目的坐標。如果…

常熟理工學院計算機考研,2018江蘇專轉本考生必看-常熟理工學院介紹

原標題&#xff1a;2018江蘇專轉本考生必看-常熟理工學院介紹這次輪到默默學介紹常熟理工學院啦&#xff01;今年常熟理工學院有個專轉本的學生&#xff0c;也是默默學專轉本視頻課程考上常熟理工的一個學生&#xff0c;叫黃群超&#xff0c;當年專轉本計算機也考了八九十分吧&…

.net中調用esb_大型ESB服務總線平臺服務運行分析和監控預警實踐

今天準備談下ESB總線平臺建設項目中的服務運行統計分析&#xff0c;服務心跳監測&#xff0c;服務監控預警方面的設計和實現。可以看到&#xff0c;在一個ESB服務總線平臺上線后&#xff0c;SOA治理管控就變得相當重要&#xff0c;而這些運行監控分析本身也是提升ESB總線平臺高…

使用Maven創建Web項目后,jsp引入靜態文件提示報錯。JSP 報錯:javax.servlet.ServletException cannot be resolved to a type...

用maven創建多模塊的web工程后&#xff0c;不同于直接創建普通的web工程。 1、在普通的web工程創建后&#xff0c;在項目中會有tomcat等服務器的jar包&#xff0c;這時創建JSP文件肯定是沒有錯的&#xff1b; 2、即使是使用maven創建的單模塊的web工程&#xff0c;也會自動的在…

ES6之路第十三篇:Iterator和for...of循環

Iterator(遍歷器)的概念 JavaScript 原有的表示“集合”的數據結構&#xff0c;主要是數組&#xff08;Array&#xff09;和對象&#xff08;Object&#xff09;&#xff0c;ES6 又添加了Map和Set。這樣就有了四種數據集合&#xff0c;用戶還可以組合使用它們&#xff0c;定義自…

MyBatis 特殊字符處理

http://blog.csdn.net/zheng0518/article/details/10449549

計算機操作系統實驗銀行家算法,實驗六 銀行家算法(下)

實驗六 銀行家算法(下)一、實驗說明實驗說明&#xff1a;本次實驗主要是對銀行家算法進行進一步的實踐學習&#xff0c;掌握銀行家算法的整體流程&#xff0c;理解程序測試時每一步的當前狀態&#xff0c;能對當前的資源分配進行預判斷。二、實驗要求1、獲取源代碼2、看懂大致框…

什么原因導致芯片短路_華為為什么突然大量用起了聯發科芯片,或是這三個產品策略原因...

經常關注數碼圈的都知道&#xff0c;近幾年來&#xff0c;隨著華為自研能力的提升&#xff0c;華為幾乎很少采購第三方芯片&#xff0c;近幾年來的絕大多數華為手機&#xff0c;幾乎都是用的自研芯片麒麟系列。并沒有像其它國產品牌那樣用聯發科或者高通的芯片。不過今年卻大不…

如何運行vue項目(維護他人的項目)

假如你是個小白&#xff0c;在公司接手他人的項目&#xff0c;這個時候&#xff0c;該怎么將這個項目跑通&#xff1f; 前提&#xff1a; 首先&#xff0c;這個教程主要針對vue小白&#xff0c;并且不知道安裝node.js環境的。言歸正傳&#xff0c;下面開始教程&#xff1a;在維…

進程操作

2019獨角獸企業重金招聘Python工程師標準>>> 一、創建一個進程 進程是系統中最基本的執行單位。Linux系統允許任何一個用戶進程創建一個子進程&#xff0c;創建之后&#xff0c;子進程存在于系統之中并獨立于父進程。 關于父進程與子進程這兩個概念&#xff0c;除了…