The complete code is at the end of this article and can be run as-is.
1. Core Principles
A codebook is a discrete representation learning method whose core idea is to map a continuous feature space onto a discrete codebook space. Our implementation consists of three key components:
1.1 ViT Encoder
```python
class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state
        patch_embeddings = outputs[:, 1:, :]  # drop the CLS token
        return self.proj(patch_embeddings)
```
- Uses a pretrained ViT-Base model to extract image features
- Drops the CLS token, keeping the 196 patch features
- A linear projection adapts the feature dimension to the codebook
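As a quick sanity check, the encoder turns a 224x224 image into 196 patch features of width `codebook_dim` (a minimal sketch using the `ViTEncoder` class above):

```python
import torch

encoder = ViTEncoder(codebook_dim=512)
x = torch.randn(1, 3, 224, 224)  # dummy image batch
features = encoder(x)
print(features.shape)  # torch.Size([1, 196, 512]); 196 = (224/16)^2 patches
```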
1.2 Codebook Quantization Layer
```python
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

    def quantize(self, z):
        # z: [N, embedding_dim]; compute squared L2 distances to all codes
        z_norm = torch.sum(z ** 2, dim=1, keepdim=True)
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)
        dot_product = torch.matmul(z, self.codebook.weight.t())
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        # nearest-neighbor lookup
        indices = torch.argmin(distances, dim=1)
        return indices, self.codebook(indices)
```
- A learnable Embedding layer stores the discrete codebook
- Nearest-neighbor lookup is implemented via L2 distances
- EMA updates are supported (commented out in the full code)
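The lookup relies on the standard expansion of the squared L2 distance, which turns the all-pairs distance computation into a single matrix multiplication:

$$\|z - e_k\|_2^2 = \|z\|_2^2 - 2\langle z, e_k\rangle + \|e_k\|_2^2$$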
1.3 ViT Decoder
```python
class ViTDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 4, 2, 1),
            nn.ReLU(),
            # ... more upsampling layers (see the full code)
            nn.Conv2d(48, 3, 1),
        )
```
- Transposed convolutions upsample the feature map step by step
- The final output is a 224x224 image
- Forms a symmetric structure with the encoder
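A matching shape check for the decoder (a minimal sketch; it assumes the full `ViTDecoder` from the complete code at the end, since the snippet above is abbreviated):

```python
import torch

decoder = ViTDecoder()
quantized = torch.randn(1, 196, 512)  # quantized patch features from the codebook
img = decoder(quantized)
print(img.shape)  # torch.Size([1, 3, 224, 224])
```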
2. Training Strategy
2.1 Multi-Objective Loss Function
```python
total_loss = mse_loss + 0.1 * percep_loss + codebook_loss + commitment_loss
```
- MSE Loss: pixel-level reconstruction error
- Perceptual Loss: VGG16 feature matching
- Codebook Loss: pulls codebook vectors toward the encoder outputs
- Commitment Loss: keeps the encoder outputs close to their assigned codes
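For reference, the codebook and commitment terms follow the standard VQ-VAE formulation (van den Oord et al., 2017), where sg[·] is the stop-gradient operator (implemented with `.detach()` in the code) and β is `commitment_cost`:

$$\mathcal{L} = \|x - \hat{x}\|_2^2 + \|\mathrm{sg}[z_e(x)] - e\|_2^2 + \beta\,\|z_e(x) - \mathrm{sg}[e]\|_2^2$$

Here β = 0.25, and the 0.1-weighted perceptual term is added on top.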
2.2 Optimization Techniques
```python
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4},
], lr=3e-4)
```
- Per-group learning rates (a smaller one for the codebook)
- EMA exponential-smoothing updates (provided but commented out in the full code)
- Mixed-precision training support (see the sketch below)
- Dynamic learning-rate scheduling (see the sketch below)
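Mixed precision and learning-rate scheduling are listed above but not wired into the full code at the end. The sketch below shows one way they could be added; the cosine schedule is an assumption, and `encoder`/`codebook`/`decoder`/`opt`/`trainloader`/`device` refer to the objects defined in the full code:

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

for images, _ in trainloader:
    images = images.to(device)
    opt.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        z = encoder(images)
        _, quantized = codebook.quantize(z)
        recon = decoder(quantized)
        loss = F.mse_loss(recon, images)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(opt)               # unscale gradients, then step the optimizer
    scaler.update()
scheduler.step()  # advance the schedule once per epoch
```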
3. Full Training Pipeline
3.1 Data Preparation
```python
transform_train = transforms.Compose([
    transforms.Resize(224),  # resize images to the model's input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(...),
])
```
- CIFAR-10 dataset (32x32 images are upscaled to 224x224 for ViT)
- Random crop + horizontal flip augmentation
- Batch size of 4 to fit in GPU memory
3.2 Training Monitoring
```python
# TensorBoard logging
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)

# Console logging
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")
```
Full Code

```python
from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        # Load the pretrained ViT-Base model
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Project the output dimension to match the codebook
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]
        patch_embeddings = outputs[:, 1:, :]     # drop the CLS token
        return self.proj(patch_embeddings)       # [batch, 196, 512]


class Codebook(nn.Module):
    def __init__(self, num_embeddings=16384, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)  # EMA updates could be hooked in here

    def quantize(self, z):
        """Quantize the input feature vectors.

        Args:
            z: input features [batch, num_patches, embedding_dim]
        Returns:
            indices: nearest codebook indices [batch, num_patches]
            quantized: quantized features [batch, num_patches, embedding_dim]
        """
        # Flatten to a 2-D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)

        # Squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)          # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)          # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)

        # Nearest-neighbor lookup
        indices = torch.argmin(distances, dim=1)       # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore the original shape
        quantized = self.codebook(indices)             # [batch, num_patches, dim]
        return indices, quantized
class ViTDecoder(nn.Module):
    def __init__(self, in_dim=512):
        super().__init__()
        # Map quantized features back to ViT's patch-embedding width
        self.proj = nn.Linear(in_dim, 768)
        # Reuse a (randomly initialized) ViT encoder stack as the decoder transformer
        config = ViTConfig()
        self.transformer = ViTModel(config).encoder
        self.head = nn.Sequential(
            # 14x14 -> 28x28
            nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 28x28 -> 56x56
            nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 56x56 -> 112x112
            nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 112x112 -> 224x224
            nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Final projection to 3 channels
            nn.Conv2d(48, 3, kernel_size=1),
        )

    def forward(self, x):
        x = self.proj(x)  # [batch, 196, 768]
        x = self.transformer(x).last_hidden_state
        # reshape (not view) because permute makes the tensor non-contiguous
        x = x.permute(0, 2, 1).reshape(-1, 768, 14, 14)  # restore the spatial layout
        return self.head(x)  # [batch, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()
# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)

from torchvision import transforms
import torchvision
import torch.nn.functional as F
# Data augmentation and preprocessing
transform_train = transforms.Compose([
    transforms.Resize(224),  # resize images to the model's input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

batch_size = 4  # small batch size to fit in GPU memory
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16

# Initialize TensorBoard
writer = SummaryWriter('runs/codebook_experiment')

# Improved Codebook class (adds optional EMA updates)
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        nn.init.normal_(self.ema_w)

    def quantize(self, z):
        # Flatten to a 2-D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)

        # Squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)          # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)          # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)

        # Nearest-neighbor lookup
        flat_indices = torch.argmin(distances, dim=1)       # [batch*num_patches]
        indices = flat_indices.reshape(batch, num_patches)  # restore the original shape
        quantized = self.codebook(indices)                  # [batch, num_patches, dim]

        # Optional EMA codebook update (uncomment to enable)
        # if self.training:
        #     with torch.no_grad():
        #         encodings = F.one_hot(flat_indices, self.codebook.num_embeddings).float()
        #         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)
        #         n = torch.sum(self.ema_cluster_size)
        #         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)
        #         dw = torch.matmul(encodings.t(), z_flat)
        #         self.ema_w.data = self.ema_w.data * self.decay + (1 - self.decay) * dw
        #         self.codebook.weight.data = self.ema_w.data / self.ema_cluster_size.unsqueeze(1)

        return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize the components
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # used for the perceptual loss

# Optimizer with per-group settings
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4},  # smaller learning rate for the codebook
], lr=3e-4)

# Training loop
for epoch in range(100):
    avg_loss = 0
    start_time = time.time()  # record the epoch start time
    for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):
        images = images.to(device)

        # Forward pass
        z = encoder(images)
        indices, quantized = codebook.quantize(z)
        # Straight-through estimator: copies gradients from the decoder input back to the encoder
        quantized_st = z + (quantized - z).detach()
        recon = decoder(quantized_st)

        # Multi-objective loss
        mse_loss = F.mse_loss(recon, images)

        # Perceptual loss (VGG feature matching); only the target features are detached,
        # so gradients still flow through the reconstruction
        with torch.no_grad():
            real_features = vgg(images)
        recon_features = vgg(recon)
        percep_loss = F.mse_loss(recon_features, real_features)

        # Codebook-related losses (stop-gradient via .detach())
        codebook_loss = F.mse_loss(quantized, z.detach())  # moves codebook vectors toward encoder outputs
        commitment_loss = codebook.commitment_cost * F.mse_loss(z, quantized.detach())  # keeps encoder outputs near their codes

        # Total loss
        total_loss = mse_loss + 0.1 * percep_loss + codebook_loss + commitment_loss

        # Backward pass
        opt.zero_grad()
        total_loss.backward()
        opt.step()

        # Logging
        avg_loss += total_loss.item()
        if batch_idx % 50 == 0:
            global_step = epoch * len(trainloader) + batch_idx
            writer.add_scalar('Loss/total', total_loss.item(), global_step)
            writer.add_scalars('Loss/components', {
                'mse': mse_loss.item(),
                'perceptual': percep_loss.item(),
                'codebook': codebook_loss.item(),
                'commitment': commitment_loss.item()
            }, global_step)
            # Save reconstruction samples (detach so add_image can convert to numpy)
            comparison = torch.cat([images[:4], recon[:4]]).detach()
            grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)
            writer.add_image('Reconstruction', grid, global_step)

    # Epoch statistics
    avg_loss /= len(trainloader)
    print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f}")

    # Save a checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save({
            'encoder': encoder.state_dict(),
            'codebook': codebook.state_dict(),
            'decoder': decoder.state_dict(),
            'opt': opt.state_dict()
        }, f'checkpoint_epoch{epoch+1}.pth')

writer.close()
```
Through this hands-on exercise, we have implemented the complete pipeline from feature extraction to discrete representation learning. Codebook techniques apply broadly to image compression, generative modeling, and related areas; we hope readers will build on this foundation to explore further.
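As a pointer toward the compression use case, a trained model represents each 224x224 image as just 196 integer indices. The following is a minimal inference sketch (it assumes training ran for the full 100 epochs, so the last checkpoint saved by the loop above is `checkpoint_epoch100.pth`):

```python
# Load the trained weights and encode/decode one test image
ckpt = torch.load('checkpoint_epoch100.pth', map_location=device)
encoder.load_state_dict(ckpt['encoder'])
codebook.load_state_dict(ckpt['codebook'])
decoder.load_state_dict(ckpt['decoder'])
encoder.eval(); decoder.eval()

with torch.no_grad():
    image, _ = testset[0]
    z = encoder(image.unsqueeze(0).to(device))
    indices, quantized = codebook.quantize(z)  # 196 integer codes stand in for the image
    recon = decoder(quantized)                 # [1, 3, 224, 224] reconstruction
```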