【人工智能前沿弄潮】——生成式AI系列:Diffusers應用 (2) 訓練擴散模型(無條件圖像生成)

無條件圖像生成是擴散模型的一種流行應用,它生成的圖像看起來像用于訓練的數據集中的圖像。與文本或圖像到圖像模型不同,無條件圖像生成不依賴于任何文本或圖像。它只生成與其訓練數據分布相似的圖像。通常,通過在特定數據集上微調預訓練模型可以獲得最佳結果。

本教程主要來自huggingface官方教程,結合一些自己的修改,以支持訓練本地數據集。我們首先依據官方教程,利用史密森尼蝴蝶數據集的子集上從頭開始訓練UNet2DModel,以生我們自己的的🦋蝴蝶🦋。最后因為我是搞遙感方向的(測繪小卡拉米),所以利用遙感數據進行訓練嘗試,遙感影像使用的是煤礦區的無人機遙感影像,主要就是裸地和枯草,有的還有一些因為煤礦開采導致的地裂縫。

1、Train配置

為方便起見,創建一個包含訓練超參數的TrainingConfig類(請隨意調整它們):

from dataclasses import dataclass@dataclass
class TrainingConfig:image_size = 128  # the generated image resolutiontrain_batch_size = 16eval_batch_size = 16  # how many images to sample during evaluationnum_epochs = 50gradient_accumulation_steps = 1learning_rate = 1e-4lr_warmup_steps = 500save_image_epochs = 10save_model_epochs = 30mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precisionoutput_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hubpush_to_hub = True  # whether to upload the saved model to the HF Hubhub_private_repo = Falseoverwrite_output_dir = True  # overwrite the old model when re-running the notebookseed = 0config = TrainingConfig()

2、加載數據集

對于在hug 倉庫空開的數據集可以使用🤗 Datasets依賴庫輕松加載,比如本次的Smithsonian Butterflies:

from datasets import load_datasetconfig.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")

對于本地數據請用一下代碼進行加載(請根據自己情況進行修改):

from datasets import load_datasetdata_dir = "/home/diffusers/datasets/isprsdataset"
dataset = load_dataset('imagefolder', data_dir=data_dir, split='train')

🤗 Datasets使用圖像功能自動解碼圖像數據并將其加載為PIL. Image,我們可以將其可視化:

import matplotlib.pyplot as pltfig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):axs[i].imshow(image)axs[i].set_axis_off()
fig.show()

在這里插入圖片描述


3、圖像預處理

由于圖像大小不同,所以需要先對其進行預處理,也就是常規的圖像增強:

  • 調整大小將圖像大小更改為配置文件中定義的圖像大小—image_size
  • RandomHorizontalFlip通過隨機鏡像圖像來增強數據集。
  • Normalize對于將像素值重新縮放到[-1,1]范圍內很重要,這是模型所期望的。
from torchvision import transformspreprocess = transforms.Compose([transforms.Resize((config.image_size, config.image_size)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5835, 0.5820, 0.5841], [0.1149, 0.1111, 0.1064]), # isprs# transforms.Normalize([0.5], [0.5]),]
)

這里使用的是Pytorch自帶的數據增強接口,這里我推薦大家使用albumentations數據增強庫。

使用🤗Datasetsset_transform方法在訓練期間動態應用預處理函數:

def transform(examples):images = [preprocess(image.convert("RGB")) for image in examples["image"]]return {"images": images}dataset.set_transform(transform)

現在將數據集包裝在DataLoader中進行訓練:

import torch
python
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)   

4、創建UNet2DModel

🧨 Diffusers 中的預訓練模型可以使用您想要的參數從它們的模型類輕松創建。例如,要創建UNet2DModel

from diffusers import UNet2DModelmodel = UNet2DModel(sample_size=config.image_size,  # the target image resolutionin_channels=3,  # the number of input channels, 3 for RGB imagesout_channels=3,  # the number of output channelslayers_per_block=2,  # how many ResNet layers to use per UNet blockblock_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet blockdown_block_types=("DownBlock2D",  # a regular ResNet downsampling block"DownBlock2D","DownBlock2D","DownBlock2D","AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention"DownBlock2D",),up_block_types=("UpBlock2D",  # a regular ResNet upsampling block"AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention"UpBlock2D","UpBlock2D","UpBlock2D","UpBlock2D",),
)

檢查樣本圖像形狀與模型輸出形狀是否匹配:

sample_image = dataset[0]["images"].unsqueeze(0)
print("Input shape:", sample_image.shape)print("Output shape:", model(sample_image, timestep=0).sample.shape)

接下來創建一個scheduler為圖像添加一些噪點。


5、創建scheduler

根據您是使用模型進行訓練還是推理,scheduler的行為會有所不同。在推理期間,scheduler從噪聲中生成圖像。在訓練期間,scheduler從擴散過程中的特定點獲取模型輸出或樣本,并根據噪聲時間表和更新規則(比如我們本系列第一張所說的step)將噪聲應用于圖像。(我們可以看到,遙感影像生成的結果還行,已經能明顯的看清楚地表和枯草,甚至能夠出現可看清的地裂縫!)

讓我們看看DDPMScheduler并使用add_noise方法向之前的sample_image添加一些隨機噪聲:

import torch
from PIL import Image
from diffusers import DDPMSchedulernoise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])


在這里插入圖片描述

模型的訓練目標是預測添加到圖像中的噪聲。該步驟的損失可以通過以下方式計算,這里官方教程使用的是mse損失函數

import torch.nn.functional as Fnoise_pred = model(noisy_image, timesteps).sample
loss = F.mse_loss(noise_pred, noise)

6、訓練模型

到目前為止,已經有了開始訓練模型的大部分部分,剩下的就是把所有東西放在一起。 首先,您需要一個優化器和一個學習率調度器:

from diffusers.optimization import get_cosine_schedule_with_warmupoptimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=config.lr_warmup_steps,num_training_steps=(len(train_dataloader) * config.num_epochs),
)

然后,您需要一種評估模型的方法。對于評估,您可以使用DDPMPipeline生成一批示例圖像并將其保存為網格格式(官方輸出為格網,大家也可自行修改為單張保存):

from diffusers import DDPMPipeline
import math
import osdef make_grid(images, rows, cols):w, h = images[0].sizegrid = Image.new("RGB", size=(cols * w, rows * h))for i, image in enumerate(images):grid.paste(image, box=(i % cols * w, i // cols * h))return griddef evaluate(config, epoch, pipeline):# Sample some images from random noise (this is the backward diffusion process).# The default pipeline output type is `List[PIL.Image]`images = pipeline(batch_size=config.eval_batch_size,generator=torch.manual_seed(config.seed),).images# Make a grid out of the imagesimage_grid = make_grid(images, rows=2, cols=3)# Save the imagestest_dir = os.path.join(config.output_dir, "samples")os.makedirs(test_dir, exist_ok=True)image_grid.save(f"{test_dir}/{epoch + 1:04d}.png")

現在,可以使用🤗Accelerate將所有這些組件包裝在一個訓練循環中,以便于TensorBoard日志記錄、梯度累積混合精度訓練

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):# Initialize accelerator and tensorboard loggingaccelerator = Accelerator(mixed_precision=config.mixed_precision,gradient_accumulation_steps=config.gradient_accumulation_steps,log_with="tensorboard",project_dir=os.path.join(config.output_dir, "logs"),)if accelerator.is_main_process:if config.push_to_hub:repo_name = get_full_repo_name(Path(config.output_dir).name)repo = Repository(config.output_dir, clone_from=repo_name)elif config.output_dir is not None:os.makedirs(config.output_dir, exist_ok=True)accelerator.init_trackers("train_example")# Prepare everything# There is no specific order to remember, you just need to unpack the# objects in the same order you gave them to the prepare method.model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)global_step = 0# Now you train the modelfor epoch in range(config.num_epochs):progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)progress_bar.set_description(f"Epoch {epoch + 1}")for step, batch in enumerate(train_dataloader):clean_images = batch["images"]# Sample noise to add to the imagesnoise = torch.randn(clean_images.shape).to(clean_images.device)bs = clean_images.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()# Add noise to the clean images according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)with accelerator.accumulate(model):# Predict the noise residualnoise_pred = model(noisy_images, timesteps, return_dict=False)[0]loss = F.mse_loss(noise_pred, noise)accelerator.backward(loss)accelerator.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}progress_bar.set_postfix(**logs)accelerator.log(logs, step=global_step)global_step += 1# After each epoch you optionally sample some demo images with evaluate() and save the modelif accelerator.is_main_process:pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:print(f'----------------------------------------------------- Evaluate Iter [{(epoch + 1) // config.save_image_epochs}] ------------------------------------------------------------------')evaluate(config, epoch, pipeline)if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:pipeline.save_pretrained(config.output_dir)

接下來使用🤗Acceleratenotebook_launcher函數啟動訓練了。將訓練循環、所有訓練參數和進程數(可以將此值更改為可用于訓練的GPU數)傳遞給該函數:

from accelerate import notebook_launcherargs = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)notebook_launcher(train_loop, args, num_processes=1)

訓練完成后,看看擴散模型生成的最終🦋圖像(🦋我隔10個epoch生成一次,在下面給大家瞅瞅)和遙感影像(因為我電腦的原因,遙感影像跑了一半停了,不過也保存了一些,感慨一下,擴散模型太吃顯存了,比之前跑分割檢測啥的更加依賴,可能是我圖像整的太大了,之后裁小一點試一試,感覺生成模型用于遙感領域,又困難,也有無限可能!這只是一個簡單的擴散生成示例模型,還得再深入研究研究,以后再和大家分享其他更新又有意思的生成模型。

import globsample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])

在這里插入圖片描述

請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述
請添加圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/36676.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/36676.shtml
英文地址,請注明出處:http://en.pswp.cn/news/36676.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

OpenLayers入門,OpenLayers入門文檔,OpenLayers入門手冊,OpenLayers相比其他地圖引擎有哪些優點?

專欄目錄: OpenLayers入門教程匯總目錄 前言 在學習OpenLayers之前,總是需要了解OpenLayers,知道OpenLayers是什么,OpenLayers能夠做什么,OpenLayers有哪些用途和特性,然后OpenLayers相比其他地圖引擎又有…

數學運算1

正確答案:F 你的答案:E 參考答案:最大排列為100 1 99 2 98 3…51 49 50 所以和為999897…1(100-50)因為是一個圈所以,100和50相接,所以等于5000 知識點:數學運算

MySQL 慢查詢探究分析

目錄 背景: mysql 整體結構: SQL查詢語句執行過程是怎樣的: 知道了mysql的整體架構,那么一條查詢語句是怎么被執行的呢: 什么是索引: 建立索引越多越好嗎:   如何發現慢查詢&#xff1…

樹結構--介紹--二叉樹遍歷的遞歸實現

目錄 樹 樹的學術名詞 樹的種類 二叉樹的遍歷 算法實現 遍歷命名 二叉樹的中序遍歷 二叉樹的后序遍歷 二叉樹的后序遍歷迭代算法 二叉樹的前序遍歷 二叉樹的前序遍歷迭代算法 樹 樹是一種非線性的數據結構,它是由n(n≥0)個有限節點組成一個具有層次關系…

Docker安裝 elasticsearch-head

目錄 前言安裝elasticsearch-head步驟1:準備1. 安裝docker2. 搜索可以使用的鏡像。3. 也可從docker hub上搜索鏡像。4. 選擇合適的redis鏡像。 步驟2:拉取elasticsearch-head鏡像拉取鏡像查看已拉取的鏡像 步驟3:創建容器創建容器方式1&#…

SpringBoot復習:(28)【前后端不分離】自定義View

一、自定義View package cn.edu.tju.view;import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.stereotype.Comp…

C# --- Case Study

C# --- Case Study C# — Mongo數據庫事務的應用 C# — 如何解析Json文件并注入MongoDB C# — MongoDB如何安全的替換Collection

百度翻譯API整合SpringBoot

案例背景,按照官方給的Demo,實在是太啰嗦了, 大致步驟 封裝數據>簽名>發送請求, 仔細一看劈里啪啦一大堆,最后還要手動關流關連接,難道整合到SpringBoot項目里面我還得為內存管理考慮 所以就有了如下需求 使用 RestTemplate的對象進行發送請求數據,RestTemplate由s…

Redis緩存刪除略和內存淘汰策略及LRU

1、Redis內存若在配置文件中未設置,內存會無限制增長,直到超出物理內存,拋出out of memory內存耗盡異常 解決方法,調整maxmemory參數,一般設置為物理內存的3/4,并且添加緩存刪除策略 2、Redis對于設置了過…

項目經理的會議之道:全參與還是精選參與?

引言 在項目管理中,會議是一個常見的工具,用于溝通信息、解決問題、做出決策等。然而,項目經理是否需要參加所有的會議呢?這是一個值得深思的問題。作為項目經理,我們需要權衡會議的重要性和我們的時間管理。我們不能…

【第一階段】kotlin的函數

函數頭 fun main() {getMethod("zhangsan",22) }//kotlin語言默認是public,kotlin更規范,先有輸入( getMethod(name:String,age:Int))再有輸出(Int[返回值]) private fun getMethod(name:String,age:Int): Int{println("我叫…

Elasticsearch集群shard過多后導致的性能問題分析

1.問題現象 上午上班以后發現ES日志集群狀態不正確,集群頻繁地重新發起選主操作。對外不能正常提供數據查詢服務,相關日志數據入庫也產生較大延時 2.問題原因 相關日志 查看ES集群日志如下: 00:00:51開始集群各個節點與當時的master節點…

Playwright快速上手-1

前言 隨著近年來對UI自動化測試的要求越來越高,,功能強大的測試框架也不斷的涌現。本系列主講的Playwright作為一款新興的端到端測試框架,憑借其獨特優勢,正在逐漸成為測試工程師的熱門選擇。 本系列文章將著重通過示例講解 Playwright python開發環境的搭建 …

Linux Day07

一、僵死進程 1.1僵死進程產生的原因 子進程先于父進程結束, 而父進程沒有獲取子進程退出碼,釋放子進程占用的資源,此時子進程將成為一個僵死進程。 在第一個框這里時父進程子進程都沒有結束,顯示其pid 父進程是2349,子進程是235…

【Nginx】Nginx網站服務

國外主流還是使用apache;國內現在主流是nginx(并發能力強,相對穩定) nginx:高性能、輕量級的web服務軟件 特點: 1.穩定性高(沒apache穩); 2.系統資源消耗比較低&#xf…

Failed to set locale, defaulting to C.UTF-8 或者中文系統語言轉英文系統語言

CentOS 8中執行命令,出現報錯:Failed to set locale, defaulting to C.UTF-8報錯原因: 1、沒有安裝相應的語言包。2、沒有設置正確的語言環境。 解決方法1:安裝語言包 設置語言環境需使用命令 localelocale -a 命令,查…

代碼隨想錄day02

977.有序數組的平方 ● 力扣題目鏈接 ● 給你一個按 非遞減順序 排序的整數數組 nums,返回 每個數字的平方 組成的新數組,要求也按 非遞減順序 排序。 思路 ● 暴力排序,時間復雜度O(n nlogn) ● 使用雙指針,時間復雜度O(n) …

Vue中使用v-bind:class動態綁定多個類名

Vue.js是一個流行的前端框架,它可以幫助開發者構建動態交互的UI界面。在Vue.js開發中,經常需要動態綁定HTML元素的class(類名)屬性,以改變元素的外觀和行為。本文將介紹采用v-bind:class指令在Vue中如何動態綁定多個類…

【大數據】-- 本地部署 Flink kubernetes operator

目錄 1.說明 1.1 版本 1.2 kubernetes 環境 1.3 參考 2.安裝步驟 2.1 安裝本地 kubernetes 環境

判斷鏈表有環的證明

目錄 1.問題 2.證明 3.代碼實現 1.問題 給你一個鏈表的頭節點 head ,判斷鏈表中是否有環。 如果鏈表中有某個節點,可以通過連續跟蹤 next 指針再次到達,則鏈表中存在環。 為了表示給定鏈表中的環,評測系統內部使用…