FFaceNeRF模塊
論文《FFaceNeRF: Few-shot Face Editing in Neural Radiance Fields》
深度交流Q裙:1051849847
全網同名 【大嘴帶你水論文
】 B站定時發布詳細講解視頻
視頻地址,點擊查看論文詳細講解,每日更新:
https://b23.tv/zdapaC3
詳細代碼見文章最后
1、作用
FFaceNeRF旨在解決現有基于NeRF的3D人臉編輯方法嚴重依賴固定布局的預訓練分割蒙版、導致用戶控制能力有限的問題。它使用戶能夠根據特定的編輯需求(如虛擬試妝、醫療整形預覽等)自由定義和使用新的蒙版布局,而無需收集和標注大規模數據集,極大地提升了3D人臉編輯的靈活性和實用性。
2、機制
- 幾何適配器 (Geometry Adapter) : 在預訓練的幾何解碼器(用于生成固定布局的分割圖)之后,添加一個輕量級的MLP網絡作為幾何適配器。該適配器負責將固定布局的輸出調整為用戶期望的、任意布局的分割蒙版。
- 特征注入 (Feature Injection) : 為了在適應新蒙版時保留豐富的幾何細節,模型將預訓練模型中的三平面特征(tri-plane feature)和視角方向(view direction)直接注入到幾何適配器中,彌補了預訓練解碼器可能丟失的信息。
- 三平面增強的潛在混合 (LMTA) : 這是一種為小樣本學習設計的數據增強策略。通過在生成器的潛在空間中混合不同層的潛在編碼,可以在保持核心語義信息(如人臉結構)不變的同時,生成多樣化的訓練樣本(如改變色調、飽和度),有效避免了在僅有10個樣本的情況下發生的過擬合。
- 基于重疊的優化 (Overlap-based Optimization) : 在訓練和推理過程中,除了使用傳統的交叉熵損失外,還引入了基于DICE系數的重疊損失(overlap loss)。這種損失函數對小區域的變化更敏感,確保了即使在編輯精細區域(如瞳孔、鼻翼)時也能實現精確對齊和穩定生成。
3、獨特優勢
- 小樣本高效學習 : 最顯著的優勢是其小樣本(Few-shot)能力,僅需約10個帶有自定義蒙版的樣本即可完成訓練,快速適應新的編輯任務,極大降低了數據和時間成本。
- 高度的靈活性和控制力 : 用戶不再受限于固定的分割類別,可以為任何感興趣的面部區域創建蒙版并進行編輯,實現了前所未有的控制自由度。
- 精細區域的精準編輯 : 通過基于重疊的優化策略,模型能夠精確地處理和編輯微小面部特征,解決了傳統方法在小區域編輯上容易失敗的痛點。
- 保持身份和非編輯區域 : 在編輯特定區域時,能夠很好地保持人臉的身份特征以及未編輯區域的圖像內容,生成結果自然且保真度高。
4、代碼實現
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MockTriPlaneGenerator(nn.Module):""" NeRF三平面生成器,用于生成基礎的分割圖。 """def __init__(self, num_classes=10):super().__init__()self.num_classes = num_classes# 一個簡單的卷積層,模擬從三平面特征到分割圖的解碼過程self.decoder = nn.Conv2d(32, num_classes, kernel_size=1)def forward(self, triplane_features, view_direction):# 解碼過程,返回一個隨機的分割圖batch_size = triplane_features.shape[0]# 返回一個隨機的、固定布局的分割圖return torch.randn(batch_size, self.num_classes, 64, 64)class MockDataset(torch.utils.data.Dataset):def __init__(self, num_samples=10, num_classes=12):self.num_samples = num_samplesself.num_classes = num_classesdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 生成模擬數據triplane_features = torch.randn(1, 32, 256, 256) # 模擬三平面特征view_direction = torch.randn(1, 3) # 模擬視角方向# 模擬用戶自定義的目標蒙版target_mask = torch.randint(0, self.num_classes, (1, 64, 64), dtype=torch.long)return triplane_features, view_direction, target_mask# ----------------------------------------------------------------------------
# 核心代碼實現 (Core Implementation)
# ----------------------------------------------------------------------------class GeometryAdapter(nn.Module):""" 幾何適配器,將固定布局的分割圖調整為用戶自定義的布局。 """def __init__(self, input_channels, output_channels):super().__init__()# 一個輕量級的MLP,用于適配幾何特征self.adapter = nn.Sequential(nn.Conv2d(input_channels, 64, kernel_size=1),nn.ReLU(),nn.Conv2d(64, output_channels, kernel_size=1))def forward(self, x):return self.adapter(x)class FFaceNeRF(nn.Module):""" FFaceNeRF 模塊,集成了預訓練生成器和幾何適配器。 """def __init__(self, pretrained_generator, num_custom_classes):super().__init__()self.pretrained_generator = pretrained_generator# 幾何適配器的輸入通道數等于預訓練生成器的輸出類別數self.adapter = GeometryAdapter(pretrained_generator.num_classes, num_custom_classes)def forward(self, triplane_features, view_direction):# 首先,使用預訓練生成器獲取固定布局的分割圖fixed_layout_seg = self.pretrained_generator(triplane_features, view_direction)# 然后,通過幾何適配器將其調整為自定義布局custom_layout_seg = self.adapter(fixed_layout_seg.detach()) # detach以凍結預訓練部分return custom_layout_segdef dice_loss(pred, target, smooth=1e-5):""" 計算DICE損失,對小區域優化更有效。 """pred = F.softmax(pred, dim=1)# 將target轉換為one-hot編碼target_one_hot = F.one_hot(target.squeeze(1), num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()intersection = torch.sum(pred * target_one_hot, dim=(2, 3))union = torch.sum(pred, dim=(2, 3)) + torch.sum(target_one_hot, dim=(2, 3))dice = (2. * intersection + smooth) / (union + smooth)return 1 - dice.mean()# ----------------------------------------------------------------------------
# 運行示例 (Runnable Example)
# ----------------------------------------------------------------------------if __name__ == '__main__':# --- 1. 初始化參數和模型 ---num_fixed_classes = 10 # 預訓練模型支持的固定類別數num_custom_classes = 12 # 用戶自定義的類別數(例如,眉毛、上唇、下唇等)num_samples = 10 # 小樣本數量epochs = 50 # 訓練輪次# 初始化模擬的預訓練生成器和FFaceNeRF模型pretrained_generator = MockTriPlaneGenerator(num_classes=num_fixed_classes)fface_nerf = FFaceNeRF(pretrained_generator, num_custom_classes)optimizer = torch.optim.Adam(fface_nerf.adapter.parameters(), lr=0.001)# --- 2. 創建模擬數據集 ---dataset = MockDataset(num_samples=num_samples, num_classes=num_custom_classes)dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)# --- 3. 訓練循環 ---for epoch in range(epochs):total_loss = 0for triplane_features, view_direction, target_mask in dataloader:optimizer.zero_grad()# 獲取模型預測的自定義分割圖predicted_seg = fface_nerf(triplane_features, view_direction)# 計算損失(交叉熵 + DICE損失)loss_ce = F.cross_entropy(predicted_seg, target_mask.squeeze(1))loss_dice = dice_loss(predicted_seg, target_mask)loss = loss_ce + 0.5 * loss_dice # 組合損失loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/{epochs}], 平均損失: {total_loss / len(dataloader):.4f}")# --- 4. 模擬推理 ---print("\n進行一次模擬推理...")with torch.no_grad():# 取一個樣本進行測試test_features, test_view_dir, ground_truth_mask = next(iter(dataloader))predicted_mask_logits = fface_nerf(test_features, test_view_dir)predicted_mask = torch.argmax(predicted_mask_logits, dim=1)print(f"輸入特征尺寸: {test_features.shape}")print(f"預測蒙版尺寸: {predicted_mask.shape}")print(f"預測蒙版中的類別: {torch.unique(predicted_mask)}")
詳細代碼 gitcode地址:https://gitcode.com/2301_80107842/research