圖像修復:深度學習實現老照片劃痕修復+老照片上色

第一步:介紹

1)GLCIC-PyTorch是一個基于PyTorch的開源項目,它實現了“全局和局部一致性圖像修復”方法。該方法由Iizuka等人提出,主要用于圖像修復任務,能夠有效地恢復圖像中被遮擋或損壞的部分。項目使用Python編程語言編寫,并依賴于PyTorch深度學習框架。

2)? DDColor 是最新的 SOTA 圖像上色算法,能夠對輸入的黑白圖像生成自然生動的彩色結果,使用 UNet 結構的骨干網絡和圖像解碼器分別實現圖像特征提取和特征圖上采樣,并利用 Transformer 結構的顏色解碼器完成基于視覺語義的顏色查詢,最終聚合輸出彩色通道預測結果。

核心思想:先GLCIC修復劃痕,再DDColor進行上色

第二步:網絡結構

1) GLCIC項目的核心功能是圖像修復,它通過訓練一個生成網絡(Completion Network)和一個判別網絡(Context Discriminator)來實現。生成網絡負責完成圖像修復任務,而判別網絡則用于提高修復質量,確保修復后的圖像在全局和局部上都與原始圖像保持一致性。主要特點如下:

????????圖像修復:利用生成網絡對圖像中缺失的部分進行修復。
????????全局與局部一致性:確保修復后的圖像既在全局上與原圖一致,又在局部細節上保持連貫。
????????判別網絡輔助:通過判別網絡對生成圖像進行評估,以提升修復質量。

2)DDColor算法整體流程如下圖,使用?UNet?結構的骨干網絡和圖像解碼器分別實現圖像特征提取和特征圖上采樣,并利用 Transformer 結構的顏色解碼器完成基于視覺語義的顏色查詢,最終聚合輸出彩色通道預測結果。

第三步:模型代碼展示

import os
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
import numpy as npfrom basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.img_util import tensor_lab2rgb
from basicsr.utils.dist_util import master_only
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
from basicsr.utils.color_enhance import color_enhacne_blend@MODEL_REGISTRY.register()
class ColorModel(BaseModel):"""Colorization model for single image colorization."""def __init__(self, opt):super(ColorModel, self).__init__(opt)# define network net_gself.net_g = build_network(opt['network_g'])self.net_g = self.model_to_device(self.net_g)self.print_network(self.net_g)# load pretrained model for net_gload_path = self.opt['path'].get('pretrain_network_g', None)if load_path is not None:param_key = self.opt['path'].get('param_key_g', 'params')self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)if self.is_train:self.init_training_settings()def init_training_settings(self):train_opt = self.opt['train']self.ema_decay = train_opt.get('ema_decay', 0)if self.ema_decay > 0:logger = get_root_logger()logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')# define network net_g with Exponential Moving Average (EMA)# net_g_ema is used only for testing on one GPU and saving# There is no need to wrap with DistributedDataParallelself.net_g_ema = build_network(self.opt['network_g']).to(self.device)# load pretrained modelload_path = self.opt['path'].get('pretrain_network_g', None)if load_path is not None:self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')else:self.model_ema(0)  # copy net_g weightself.net_g_ema.eval()# define network net_dself.net_d = build_network(self.opt['network_d'])self.net_d = self.model_to_device(self.net_d)self.print_network(self.net_d)# load pretrained model for net_dload_path = self.opt['path'].get('pretrain_network_d', None)if load_path is not None:param_key = self.opt['path'].get('param_key_d', 'params')self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)self.net_g.train()self.net_d.train()# define lossesif train_opt.get('pixel_opt'):self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)else:self.cri_pix = Noneif train_opt.get('perceptual_opt'):self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)else:self.cri_perceptual = Noneif train_opt.get('gan_opt'):self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)else:self.cri_gan = Noneif self.cri_pix is None and self.cri_perceptual is None:raise ValueError('Both pixel and perceptual losses are None.')if train_opt.get('colorfulness_opt'):self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)else:self.cri_colorfulness = None# set up optimizers and schedulersself.setup_optimizers()self.setup_schedulers()# set real dataset cache for fid metric computingself.real_mu, self.real_sigma = None, Noneif self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:self._prepare_inception_model_fid()def setup_optimizers(self):train_opt = self.opt['train']# optim_params_g = []# for k, v in self.net_g.named_parameters():#     if v.requires_grad:#         optim_params_g.append(v)#     else:#         logger = get_root_logger()#         logger.warning(f'Params {k} will not be optimized.')optim_params_g = self.net_g.parameters()# optimizer goptim_type = train_opt['optim_g'].pop('type')self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])self.optimizers.append(self.optimizer_g)# optimizer doptim_type = train_opt['optim_d'].pop('type')self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])self.optimizers.append(self.optimizer_d)def feed_data(self, data):self.lq = data['lq'].to(self.device)self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))if 'gt' in data:self.gt = data['gt'].to(self.device)self.gt_lab = torch.cat([self.lq, self.gt], dim=1)self.gt_rgb = tensor_lab2rgb(self.gt_lab)if self.opt['train'].get('color_enhance', False):for i in range(self.gt_rgb.shape[0]):self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))def optimize_parameters(self, current_iter):# optimize net_gfor p in self.net_d.parameters():p.requires_grad = Falseself.optimizer_g.zero_grad()self.output_ab = self.net_g(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)l_g_total = 0loss_dict = OrderedDict()# pixel lossif self.cri_pix:l_g_pix = self.cri_pix(self.output_ab, self.gt)l_g_total += l_g_pixloss_dict['l_g_pix'] = l_g_pix# perceptual lossif self.cri_perceptual:l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)if l_g_percep is not None:l_g_total += l_g_perceploss_dict['l_g_percep'] = l_g_percepif l_g_style is not None:l_g_total += l_g_styleloss_dict['l_g_style'] = l_g_style# gan lossif self.cri_gan:fake_g_pred = self.net_d(self.output_rgb)l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)l_g_total += l_g_ganloss_dict['l_g_gan'] = l_g_gan# colorfulness lossif self.cri_colorfulness:l_g_color = self.cri_colorfulness(self.output_rgb)l_g_total += l_g_colorloss_dict['l_g_color'] = l_g_colorl_g_total.backward()self.optimizer_g.step()# optimize net_dfor p in self.net_d.parameters():p.requires_grad = Trueself.optimizer_d.zero_grad()real_d_pred = self.net_d(self.gt_rgb)fake_d_pred = self.net_d(self.output_rgb.detach())l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)loss_dict['l_d'] = l_dloss_dict['real_score'] = real_d_pred.detach().mean()loss_dict['fake_score'] = fake_d_pred.detach().mean()l_d.backward()self.optimizer_d.step()self.log_dict = self.reduce_loss_dict(loss_dict)if self.ema_decay > 0:self.model_ema(decay=self.ema_decay)def get_current_visuals(self):out_dict = OrderedDict()out_dict['lq'] = self.lq_rgb.detach().cpu()out_dict['result'] = self.output_rgb.detach().cpu()if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verboseself.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()if hasattr(self, 'gt'):out_dict['gt'] = self.gt_rgb.detach().cpu()if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verboseself.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()return out_dictdef test(self):if hasattr(self, 'net_g_ema'):self.net_g_ema.eval()with torch.no_grad():self.output_ab = self.net_g_ema(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)else:self.net_g.eval()with torch.no_grad():self.output_ab = self.net_g(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)self.net_g.train()def dist_validation(self, dataloader, current_iter, tb_logger, save_img):if self.opt['rank'] == 0:self.nondist_validation(dataloader, current_iter, tb_logger, save_img)def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):dataset_name = dataloader.dataset.opt['name']with_metrics = self.opt['val'].get('metrics') is not Noneuse_pbar = self.opt['val'].get('pbar', False)if with_metrics and not hasattr(self, 'metric_results'):  # only execute in the first runself.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}# initialize the best metric results for each dataset_name (supporting multiple validation datasets)if with_metrics:self._initialize_best_metric_results(dataset_name)# zero self.metric_resultsif with_metrics:self.metric_results = {metric: 0 for metric in self.metric_results}metric_data = dict()if use_pbar:pbar = tqdm(total=len(dataloader), unit='image')if self.opt['val']['metrics'].get('fid') is not None:fake_acts_set, acts_set = [], []for idx, val_data in enumerate(dataloader):# if idx == 100:#     breakimg_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]if hasattr(self, 'gt'):del self.gtself.feed_data(val_data)self.test()visuals = self.get_current_visuals()sr_img = tensor2img([visuals['result']])metric_data['img'] = sr_imgif 'gt' in visuals:gt_img = tensor2img([visuals['gt']])metric_data['img2'] = gt_imgtorch.cuda.empty_cache()if save_img:if self.opt['is_train']:save_dir = osp.join(self.opt['path']['visualization'], img_name)for key in visuals:save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))img = tensor2img(visuals[key])imwrite(img, save_path)else:if self.opt['val']['suffix']:save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,f'{img_name}_{self.opt["val"]["suffix"]}.png')else:save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,f'{img_name}_{self.opt["name"]}.png')imwrite(sr_img, save_img_path)if with_metrics:# calculate metricsfor name, opt_ in self.opt['val']['metrics'].items():if name == 'fid':pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()fake_act = get_activations(pred, self.inception_model_fid, 1)fake_acts_set.append(fake_act)if self.real_mu is None:real_act = get_activations(gt, self.inception_model_fid, 1)acts_set.append(real_act)else:self.metric_results[name] += calculate_metric(metric_data, opt_)if use_pbar:pbar.update(1)pbar.set_description(f'Test {img_name}')if use_pbar:pbar.close()if with_metrics:if self.opt['val']['metrics'].get('fid') is not None:if self.real_mu is None:acts_set = np.concatenate(acts_set, 0)self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)fake_acts_set = np.concatenate(fake_acts_set, 0)fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)self.metric_results['fid'] = fid_scorefor metric in self.metric_results.keys():if metric != 'fid':self.metric_results[metric] /= (idx + 1)# update the best metric resultself._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)self._log_validation_metric_values(current_iter, dataset_name, tb_logger)def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):log_str = f'Validation {dataset_name}\n'for metric, value in self.metric_results.items():log_str += f'\t # {metric}: {value:.4f}'if hasattr(self, 'best_metric_results'):log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ 'f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')log_str += '\n'logger = get_root_logger()logger.info(log_str)if tb_logger:for metric, value in self.metric_results.items():tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):incep_state_dict = torch.load(path, map_location='cpu')block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])self.inception_model_fid.cuda()self.inception_model_fid.eval()@master_onlydef save_training_images(self, current_iter):visuals = self.get_current_visuals()save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')os.makedirs(save_dir, exist_ok=True)for key in visuals:save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))img = tensor2img(visuals[key])imwrite(img, save_path)def save(self, epoch, current_iter):if hasattr(self, 'net_g_ema'):self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])else:self.save_network(self.net_g, 'net_g', current_iter)self.save_network(self.net_d, 'net_d', current_iter)self.save_training_state(epoch, current_iter)

第四步:運行

第五步:整個工程的內容

??項目完整文件下載請見演示與介紹視頻的簡介處給出:???

圖像修復:深度學習實現老照片劃痕修復+老照片上色_嗶哩嗶哩_bilibili

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/88929.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/88929.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/88929.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

css 邊框顏色漸變

border-image: linear-gradient(90deg, rgba(207, 194, 195, 1), rgba(189, 189, 189, 0.2),rgba(207, 194, 195, 1)) 1;

本地 LLM API Python 項目分步指南

分步過程 需要Python 3.9 或更高版本。 安裝 Ollama 并在本地下載 LLM 根據您的操作系統,您可以從其網站下載一個或另一個版本的 Ollama 。下載并啟動后,打開終端并輸入以下命令: ollama run llama3此命令將在本地拉取(下載&…

日本的所得稅計算方式

? 【1】所得稅的計算步驟(概要) 日本的所得稅大致按照以下順序來計算: 1?? 統計收入(銷售額、工資等) 2?? 扣除必要經費等,得到「所得金額」 3?? 扣除各類「所得控除」(所得扣除&#xf…

【langchain4j篇01】:5分鐘上手langchain4j 1.1.0(SpringBoot整合使用)

目錄 一、環境準備 二、創建項目、導入依賴 三、配置 application.yml 四、注入Bean,開箱即用 五、日志觀察 一、環境準備 首先和快速上手 Spring AI 框架一樣的前置條件:先申請一個 apikey ,此部分步驟參考:【SpringAI篇01…

js運算符

運算符 jarringslee*賦值運算符 - / 對變量進行賦值的運算符,用于簡化代碼。左邊是容器,右邊是值一元運算符正號 符號- 賦予數據正值、負值自增 自減– 前置和后置:i和i:一般情況下習慣使用后置i,兩者在單獨…

next.js 登錄認證:使用 github 賬號授權登錄。

1. 起因, 目的: 一直是這個報錯。2. 最終效果, 解決問題,能成功登錄、體驗地址:https://next-js-gist-app.vercel.app/代碼地址: https://github.com/buxuele/next-js-gist-app3. 過程: 根本原因: github 的設置&…

深入理解設計模式:原型模式(Prototype Pattern)

在軟件開發中,對象的創建是一個永恒的話題。當我們需要創建大量相似對象,或者對象創建成本較高時,傳統的new操作符可能不是最佳選擇。原型模式(Prototype Pattern)為我們提供了一種優雅的解決方案——通過克隆現有對象…

Rocky Linux 9 源碼包安裝php8

Rocky Linux 9 源碼包安裝php8大家好,我是星哥!今天咱們不聊yum一鍵安裝的“快餐式”部署,來點兒硬核的——源碼編譯安裝PHP 8.3。為什么要折騰源碼?因為它能讓你深度定制PHP功能、啟用最新特性,還能避開系統默認源的版…

Django母嬰商城項目實踐(四)

4、路由規劃與設計 1、概述 介紹 路由稱為 URL(Uniform Resource Locator,統一資源定位符),也稱為 URLconf,對互聯網上得到的資源位置和訪問方式的一種簡潔表示,是互聯網上標準梓源的地址。互聯網上的每個文件都有一個唯一的路由,用于指出網站文件的路由位置,也可以理…

論文閱讀:arxiv 2025 A Survey of Large Language Model Agents for Question Answering

https://arxiv.org/pdf/2503.19213 https://www.doubao.com/chat/12038636966213122 A Survey of Large Language Model Agents for Question Answering 文章目錄速覽論文翻譯面向問答的大型語言模型代理綜述摘要一、引言速覽 這篇文檔主要是對基于大型語言模型(…

ONNX 是什么

ONNX 是什么? ONNX,全稱 Open Neural Network Exchange,是微軟和 Facebook(現在的 Meta)聯合發起的一個開放的神經網絡模型交換格式。簡單理解:ONNX 是一個通用的「AI 模型存檔格式」。用 PyTorch、TensorF…

【Python3】掌握DRF核心裝飾器:提升API開發效率

在 Django REST Framework (DRF) 中,裝飾器(Decorators)通常用于視圖函數或類視圖,以控制訪問權限、請求方法、認證等行為。以下是 DRF 中常用的裝飾器及其功能說明: 1. api_view 用途: 用于基于函數的視圖&#xff0c…

Datawhale AI 夏令營第一期(機器學習方向)Task2 筆記:用戶新增預測挑戰賽 —— 從業務理解到技術實現

Datawhale AI夏令營第一期(機器學習方向)Task2筆記:用戶新增預測挑戰賽——從業務理解到技術實現 一、任務核心:業務與技術的“翻譯” 本次Task聚焦“用戶新增預測挑戰賽”的核心邏輯,核心目標是鍛煉“將業務問題轉化為…

【人工智能】華為昇騰NPU-MindIE鏡像制作

本文通過不使用官方鏡像,自己在910b 進行華為mindie的鏡像制作,可離線安裝部署。 硬件:cann 8.0 1. 部署參考文檔: 安裝依賴-安裝開發環境-MindIE安裝指南-MindIE1.0.0開發文檔-昇騰社區 2. 參數說明文檔:https://www.hiascend.com/document/detail/zh/mindie/100/min…

關于我用AI編寫了一個聊天機器人……(番外1)

極大地精簡了1.3.6版本的邏輯。 不會作為正式版發布。 未填充數據。核心結構代碼包含兩個主要部分&#xff1a;數據結構&#xff1a;使用map<string, string>存儲問答對&#xff0c;其中鍵是問題&#xff0c;值是答案主程序流程&#xff1a;初始化預定義的問答對進入無限…

全球鈉離子電池市場研究,市場占有率及市場規模

鈉離子電池是一種新興的儲能技術&#xff0c;利用鈉離子&#xff08;Na?&#xff09;代替鋰離子作為電荷載體&#xff0c;為鋰離子電池提供了一種經濟高效且可持續的替代品。它們的工作原理類似&#xff0c;在充電和放電循環過程中&#xff0c;鈉離子在陽極和陰極之間移動。關…

SwiftUI 全面介紹與使用指南

目錄一、SwiftUI 核心優勢二、基礎組件與布局2.1、基本視圖組件2.2、布局系統2.3、列表與導航三、狀態管理與數據流3.1、狀態管理基礎3.2、數據綁定與共享四、高級功能與技巧4.1、動畫效果4.2、繪圖與自定義形狀4.3、網絡請求與異步數據五、SwiftUI 最佳實踐六、SwiftUI 開發環…

ADC采集、緩存

FPGA學習筆記_李敏兒oc的博客-CSDN博客 TLV5618.v&#xff1a;實現DAC數模轉換&#xff0c;產生模擬信號&#xff0c;輸出指定電壓值 時序圖 FPGA學習筆記&#xff1a;數據采集傳輸系統設計&#xff08;二&#xff09;&#xff1a;TLV5618型DAC驅動-CSDN博客 ADC128S052.v&…

(C++)STL:stack、queue簡單使用解析

stack 棧 簡介 stack 棧——容器適配器 container adapter 與前面學的容器vector、list的底層實現不同&#xff0c;stack功能的實現是要借助其他容器的功能的&#xff0c;所以看stack的第二個模板參數是容器。 最大特點&#xff1a;LIFO&#xff1a;Last In, First Out&#xf…

在Adobe Substance 3D Painter中,已經有基礎圖層,如何新建一個圖層A,clone基礎圖層的紋理和內容到A圖層

在Adobe Substance 3D Painter中&#xff0c;已經有基礎圖層&#xff0c;如何新建一個圖層A&#xff0c;clone基礎圖層的紋理和內容到A圖層 在 Substance 3D Painter 中克隆底層紋理到新圖層的最快做法 操作步驟 添加空白 Paint Layer 在 Layer Stack 頂部點擊 → Paint La…