從零開始訓練Codebook:基于ViT的圖像重建實踐

完整代碼在文末,可以一鍵運行。

在這里插入圖片描述

1. 核心原理

Codebook是一種離散表征學習方法,其核心思想是將連續特征空間映射到離散的碼本空間。我們的實現方案包含三個關鍵組件:

1.1 ViT編碼器

class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")self.proj = nn.Linear(768, codebook_dim)def forward(self, x):outputs = self.vit(x).last_hidden_statepatch_embeddings = outputs[:, 1:, :]  # 移除CLS tokenreturn self.proj(patch_embeddings)
  • 使用預訓練的ViT-Base模型提取圖像特征
  • 移除CLS token,保留196個圖像塊特征
  • 線性投影調整特征維度適配Codebook

1.2 Codebook量化層

class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)def quantize(self, z):# 計算L2距離distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 最近鄰查找indices = torch.argmin(distances, dim=1)return indices, self.codebook(indices)
  • 使用可學習的Embedding層存儲離散碼本
  • 通過L2距離計算實現最近鄰查找
  • 支持EMA更新(代碼中已注釋部分)

1.3 ViT解碼器

class ViTDecoder(nn.Module):def __init__(self):self.head = nn.Sequential(nn.ConvTranspose2d(768, 384, 4, 2, 1),nn.ReLU(),... # 更多上采樣層nn.Conv2d(48, 3, 1))
  • 使用轉置卷積逐步上采樣
  • 最終輸出224x224分辨率圖像
  • 與編碼器形成對稱結構

2. 訓練策略

2.1 多目標損失函數

total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
  • MSE Loss: 像素級重建誤差
  • Perceptual Loss: VGG16特征匹配
  • Codebook Loss: 碼本向量優化
  • Commitment Loss: 編碼器輸出穩定性

2.2 優化技巧

opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
  • 分層學習率設置
  • EMA指數平滑更新
  • 混合精度訓練支持
  • 動態學習率調整

3. 完整訓練流程

3.1 數據準備

transform_train = transforms.Compose([transforms.Resize(224),transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(...)
])
  • CIFAR-10數據集
  • 隨機裁剪+翻轉增強
  • Batch Size=4適配顯存

3.2 訓練監控

# TensorBoard記錄
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)# 控制臺日志
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

完整代碼

from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()# 加載預訓練ViT-Base模型self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")# 調整輸出維度匹配Codebookself.proj = nn.Linear(768, codebook_dim)  # 網頁2/6中的線性嵌入策略def forward(self, x):outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]patch_embeddings = outputs[:, 1:, :]     # 移除CLS tokenreturn self.proj(patch_embeddings)       # [batch, 196, 512]class Codebook(nn.Module):def __init__(self, num_embeddings=16384, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)  # 網頁1的EMA更新可在此擴展def quantize(self, z):"""量化輸入特征向量參數:z: 輸入特征 [batch, num_patches, embedding_dim]返回:indices: 最近鄰碼本索引 [batch, num_patches]quantized: 量化后的特征 [batch, num_patches, embedding_dim]"""# 重塑輸入為二維矩陣 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 計算L2距離 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近鄰indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢復原始形狀quantized = self.codebook(indices)  # [batch, num_patches, dim]return indices, quantized
class ViTDecoder(nn.Module):def __init__(self, in_dim=512):super().__init__()# 反向映射ViT的patch嵌入self.proj = nn.Linear(in_dim, 768)config = ViTConfig()config.is_decoder = True  # 網頁7中的解碼器模式self.transformer = ViTModel(config).encoder  self.head = nn.Sequential(# 14x14 -> 28x28nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 28x28 -> 56x56nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 56x56 -> 112x112 nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 112x112 -> 224x224nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 最終調整到3通道nn.Conv2d(48, 3, kernel_size=1))def forward(self, x):x = self.proj(x)  # [batch, 196, 768]x = self.transformer(x).last_hidden_statex = x.permute(0, 2, 1).view(-1, 768, 14, 14)  # 恢復空間布局return self.head(x)  # 輸出[1, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)from torchvision import transforms
import torchvision
import torch.nn.functional as F
# 數據增強和預處理
transform_train = transforms.Compose([transforms.Resize(224),  # 調整圖像尺寸適配模型transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.DataLoader(trainset, batch_size=64, shuffle=True)
# 加載CIFAR-10數據集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)batch_size = 4  # 增大batch size加速訓練
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16# 初始化TensorBoard
writer = SummaryWriter('runs/codebook_experiment')# 改進的Codebook類(增加EMA更新)
class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)self.commitment_cost = commitment_costself.decay = decayself.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))nn.init.normal_(self.ema_w)def quantize(self, z):# 重塑輸入為二維矩陣 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 計算L2距離 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近鄰indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢復原始形狀quantized = self.codebook(indices)  # [batch, num_patches, dim]# 新增EMA更新# if self.training:#     with torch.no_grad():#         encodings = F.one_hot(indices, self.codebook.num_embeddings).float()#         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)#         n = torch.sum(self.ema_cluster_size)#         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)#         dw = torch.matmul(encodings.t(), z_flat)#         self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)#         self.codebook.weight.data = self.ema_w / self.ema_cluster_size.unsqueeze(1)return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化組件
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # 用于感知損失# 優化器分開設置
opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}  # 更小的學習率
], lr=3e-4)# 訓練循環
for epoch in range(100):avg_loss = 0start_time = time.time()  # 記錄epoch開始時間for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):images = images.to(device)# 前向傳播z = encoder(images)indices, quantized = codebook.quantize(z)recon = decoder(quantized)# 多目標損失計算mse_loss = F.mse_loss(recon, images)# 感知損失(VGG特征匹配)with torch.no_grad():real_features = vgg(images)recon_features = vgg(recon)percep_loss = F.mse_loss(recon_features, real_features)# Codebook相關損失commitment_loss = codebook.commitment_cost * F.mse_loss(z.detach(), quantized)codebook_loss = F.mse_loss(z, quantized.detach())# 總損失total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss# 反向傳播opt.zero_grad()total_loss.backward()opt.step()# 記錄數據avg_loss += total_loss.item()if batch_idx % 50 == 0:# 記錄TensorBoard數據writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)writer.add_scalars('Loss/components', {'mse': mse_loss.item(),'perceptual': percep_loss.item(),'codebook': codebook_loss.item(),'commitment': commitment_loss.item()}, epoch*len(trainloader)+batch_idx)# 保存重建樣本comparison = torch.cat([images[:4], recon[:4]])grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)# 打印epoch統計信息avg_loss /= len(trainloader)print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f}")# 保存模型檢查點if (epoch+1) % 10 == 0:torch.save({'encoder': encoder.state_dict(),'codebook': codebook.state_dict(),'decoder': decoder.state_dict(),'opt': opt.state_dict()}, f'checkpoint_epoch{epoch+1}.pth')writer.close()

通過本實踐,我們實現了從特征提取到離散表征學習的完整流程。Codebook技術可廣泛應用于圖像壓縮、生成模型等領域,期待讀者在此基礎上探索更多可能性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/75513.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/75513.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/75513.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

大數據筆試題_第一階段配套筆試題02

已知一個字符類型的日期&#xff1a;2022-01-20&#xff0c;請用SQL顯示出此日期對應的下個月的月份&#xff0c;結果要求為Number類型&#xff08;202201&#xff09;。 參考答案 sql SELECT to_date(2022-01-20, yyyy-mm-dd) a1,add_months(to_date(2022-01-20, yyyy-mm-d…

C++實現對象單例模式

在 C 中實現單例模式有多種方法&#xff0c;以下是線程安全的現代 C 實現方式&#xff08;推薦 C11 及以上版本&#xff09;&#xff1a; 1. Meyers’ Singleton&#xff08;推薦&#xff09; class Singleton { public:// 刪除拷貝構造和賦值運算符Singleton(const Singleto…

企業常用Linux服務搭建

1.需要兩臺centos 7服務器&#xff0c;一臺部署DNS服務器&#xff0c;另一臺部署ftp和Samba服務器。 2. 部署DNS 服務器? #!/bin/bash# 更新系統 echo "更新系統..." sudo yum update -y# 安裝 BIND 和相關工具 echo "安裝 BIND 和相關工具..." sudo y…

UE5Actor模塊源碼深度剖析:從核心架構到實踐應用

UE5 Actor模塊源碼深度剖析:從核心架構到實踐應用 a. UE5 Actor模塊架構概述 在UE5引擎中,Actor扮演著至關重要的角色,它是整個游戲世界中各類可交互對象的基礎抽象。從本質上來說,所有能夠被放置到關卡中的對象都屬于Actor的范疇,像攝像機、靜態網格體以及玩家起始位置…

DreamDiffusion代碼學習及復現

論文解讀在這里 File path | Description /pretrains ┣ &#x1f4c2; models ┃ ┗ &#x1f4dc; config.yaml ┃ ┗ &#x1f4dc; v1-5-pruned.ckpt┣ &#x1f4c2; generation ┃ ┗ &#x1f4dc; checkpoint_best.pth ┣ &#x1f4c2; eeg_pretain ┃ ┗ …

用Python實現TCP代理

依舊是Python黑帽子這本書 先附上代碼&#xff0c;我在原書代碼上加了注釋&#xff0c;更好理解 import sys import socket import threading#生成可打印字符映射 HEX_FILTER.join([(len(repr(chr(i)))3) and chr(i) or . for i in range(256)])#接收bytes或string類型的輸入…

Pyinstaller 打包flask_socketio為exe程序后出現:ValueError: Invalid async_mode specified

Pyinstaller 打包flask_socketio為exe程序后出現&#xff1a;ValueError: Invalid async_mode specified 一、詳細描述問題描述 Traceback (most recent call last): File "app_3.py", line 22, in <module> File "flask_socketio\__init__.py"…

django REST framework(DRF)教程

Django DRF API Django 基本使用Django DRF序列化器Django DRF視圖Django DRF常用功能Django 基本使用 前后端分離開發模式認識RestFulAPI回顧Django開發模式Django REST Framework初探前后端分離開發模式 前后端分離前:前端頁面看到的效果都是由后端控制,即后端渲染HTML頁面…

【Linux】Orin NX + Ubuntu22.04配置國內源

1、獲取源 清華源 arm 系統的源,可以在如下地址獲取到 https://mirror.tuna.tsinghua.edu.cn/help/ubuntu-ports/ 選擇HTTPS,否則可能報錯: 明文簽署文件不可用,結果為‘NOSPLIT’(您的網絡需要認證嗎?)查看Orin NX系統版本 選擇jammy的源 2、更新源 1)備份原配…

【含文檔+PPT+源碼】基于微信小程序的社交攝影約拍平臺的設計與實現

項目介紹 本課程演示的是一款基于微信小程序的社交攝影約拍平臺的設計與實現&#xff0c;主要針對計算機相關專業的正在做畢設的學生與需要項目實戰練習的 Java 學習者。 1.包含&#xff1a;項目源碼、項目文檔、數據庫腳本、軟件工具等所有資料 2.帶你從零開始部署運行本套系…

JDBC常用的接口

一、什么是JDBC JDBC是Java語言連接數據庫的接口規范。 二、JDBC的體系 1、Java官方提供一個操作數據庫的抽象接口 抽象接口有很多的接口和抽象類。 例如&#xff1a;Driver、Connection、Statement。 2、各個數據庫廠商提供各自的Java實現類 需要各自實現具體的細節。 例如&am…

容器適配器-stack棧

C標準庫不只是包含了順序容器&#xff0c;還包含一些為滿足特殊需求而設計的容器&#xff0c;它們提供簡單的接口。 這些容器可被歸類為容器適配器(container adapter)&#xff0c;它們是改造別的標準順序容器&#xff0c;使之滿足特殊需求的新容器。 適配器:也稱配置器,把一…

[250403] HuggingFace 新增檢查模型與電腦兼容性的功能 | Firefox 發布137.0 支持標簽組

目錄 Hugging Face 讓尋找兼容的 AI 模型變得更容易Firefox 137 版本更新摘要 Hugging Face 讓尋找兼容的 AI 模型變得更容易 Hugging Face 是一個流行的在線平臺&#xff0c;用于訪問開源人工智能 (AI) 工具和模型。該平臺推出了一項有用的新功能&#xff0c;允許個人輕松檢查…

.NET 創建MCP使用大模型對話二:調用遠程MCP服務

在上一篇文章.NET 創建MCP使用大模型對話-CSDN博客中&#xff0c;我們簡述了如何使用mcp client使用StdIo模式調用本地mcp server。本次實例將會展示如何使用mcp client模式調用遠程mcp server。 一&#xff1a;創建mcp server 我們創建一個天氣服務。 新建WebApi項目&#x…

Redis 中 Set(例如標簽) 和 ZSet(例如排行榜) 的詳細對比,涵蓋定義、特性、命令、適用場景及總結表格

以下是 Redis 中 Set 和 ZSet 的詳細對比&#xff0c;涵蓋定義、特性、命令、適用場景及總結表格&#xff1a; 1. 核心定義 數據類型SetZSet&#xff08;Sorted Set&#xff09;定義無序的、唯一的字符串集合&#xff0c;元素不重復。有序的、唯一的字符串集合&#xff0c;每個…

解決Spring參數解析異常:Name for argument of type XXX not specified

前言 在開發 Spring Boot 應用時&#xff0c;我們常遇到類似 java.lang.IllegalArgumentException: Name for argument not specified 的報錯。這類問題通常與方法參數名稱的解析機制相關&#xff0c;尤其在使用 RequestParam、PathVariable 等注解時更為常見。 一、問題現象與…

剛剛,OpenAI開源PaperBench,重塑頂級AI Agent評測

今天凌晨1點&#xff0c;OpenAI開源了一個全新的AI Agent評測基準——PaperBench。 這個基準主要考核智能體的搜索、整合、執行等能力&#xff0c;需要對2024年國際機器學習大會上頂尖論文的復現&#xff0c;包括對論文內容的理解、代碼編寫以及實驗執行等方面的能力。 根據O…

Golang封裝Consul 服務發現庫

以下是一個經過生產驗證的 Consul 服務發現封裝庫,支持注冊/注銷、健康檢查、智能發現等核心功能,可直接集成到項目中: package consulimport ("context""fmt""log""math/rand""net""os""sync"&quo…

自適應信號處理任務(過濾,預測,重建,分類)

自適應濾波 # signals creation: u, v, d N = 5000 n = 10 u = np.sin(np.arange(0, N/10., N/50000

PyTorch深度學習框架 的基礎知識

目錄 1.pyTorch檢查是否安裝成功 2.PyTorch的張量tensor 基礎創建方式&#xff08;三種&#xff09; 2.2用列表創建tensor 2.2使用元組創建 tensor 2.3使用ndarray創建創建 tensor 2.4 快速創建tensor的常用方法 3.pyTorch中的張量tensor的常用屬性 4. tensor中的基礎數據…