2.1 視覺的“大模型”時代:ViT的誕生與革新
在計算機視覺領域,卷積神經網絡(CNN)曾是當之無愧的霸主。從LeNet到ResNet,CNN在圖像分類、目標檢測等任務上取得了巨大成功。然而,隨著Transformer模型在自然語言處理(NLP)領域的崛起(如BERT、GPT),研究者們開始思考:Transformer能否也像處理文本序列一樣處理圖像呢?
答案是肯定的,這正是Vision Transformer (ViT) 的核心思想。ViT的出現,打破了CNN在視覺領域的主導地位,開啟了視覺領域“大模型”的新紀元。
核心思想: ViT將圖像視為一系列圖像塊(patches),然后將這些圖像塊序列輸入到標準的Transformer編碼器中進行處理。就像Transformer處理文本中的單詞序列一樣,它處理圖像中的圖像塊序列。
原理詳解:
- 圖像分塊(Patch Embedding): 將輸入圖像分割成固定大小的、不重疊的圖像塊(例如,一個224x224的圖像被分成16x16的圖像塊,會得到14x14=196個圖像塊)。每個圖像塊被展平(flatten)并線性投影到一個嵌入維度,形成“圖像塊嵌入”。
- 類別Token (CLS Token): 在圖像塊嵌入序列的開頭添加一個特殊的
[CLS]
token(可學習的嵌入),它的最終輸出作為整個圖像的表示,用于分類等下游任務。 - 位置編碼(Positional Embedding): 為了保留圖像塊的空間信息,為每個圖像塊嵌入添加可學習的“位置編碼”。這樣,Transformer就能知道每個圖像塊在原始圖像中的相對位置。
- Transformer編碼器: 將帶有位置編碼的圖像塊序列輸入到標準的Transformer編碼器中。編碼器由多個自注意力層和前饋網絡層堆疊而成。自注意力機制允許模型在處理每個圖像塊時,能夠考慮到圖像中所有其他圖像塊的信息,從而捕捉全局依賴。
- 分類頭: 編碼器輸出的
[CLS]
token的特征向量被送入一個簡單的多層感知機(MLP)分類頭,用于最終的圖像分類任務。
ViT的革新之處:
- 全局感受野: 相較于CNN通過層層堆疊逐步擴大感受野,Transformer的自注意力機制天生就具備全局感受野,可以一步到位地捕捉圖像中的長距離依賴關系。
- 可擴展性: 事實證明,Transformer在擁有足夠大的數據集進行預訓練時,其性能會隨著模型規模的增大而顯著提升,展現出強大的可擴展性。
- 統一架構: 將NLP和CV任務統一到Transformer架構下,為多模態學習提供了新的可能。
Python示例:簡化版ViT實現
我們將實現一個非常簡化的ViT模型,用于圖像分類。為了可運行和理解,這里會用線性層代替多頭自注意力,并使用簡單的前饋網絡。
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange # 方便圖像分塊操作
import torchvision.transforms as T
from PIL import Image# 輔助函數:生成一個虛擬圖像
def create_dummy_image(size=(224, 224)):return Image.new('RGB', size, color='red')# 1. Patch Embedding 和 Positional Embedding
class PatchEmbedding(nn.Module):def __init__(self, img_size, patch_size, in_channels, embed_dim):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size // patch_size) ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):# x: (B, C, H, W)x = self.proj(x) # (B, embed_dim, H_new, W_new)x = x.flatten(2) # (B, embed_dim, num_patches)x = x.transpose(1, 2) # (B, num_patches, embed_dim)return x# 2. 簡化的Transformer Encoder Block
class SimplifiedTransformerBlock(nn.Module):def __init__(self, embed_dim, num_heads=8, mlp_ratio=4.):super().__init__()self.norm1 = nn.LayerNorm(embed_dim)# 簡化注意力:用一個線性層來模擬多頭自注意力的一部分功能# 實際的MultiheadAttention會更復雜self.attn = nn.Linear(embed_dim, embed_dim) self.norm2 = nn.LayerNorm(embed_dim)# MLP Blockhidden_features = int(embed_dim * mlp_ratio)self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden_features),nn.GELU(),nn.Linear(hidden_features, embed_dim))def forward(self, x):# 簡化版自注意力attn_output = self.attn(self.norm1(x))x = x + attn_output # Residual connection# MLPx = x + self.mlp(self.norm2(x)) # Residual connectionreturn x# 3. 簡化版Vision Transformer
class SimpleViT(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4.):super().__init__()self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # +1 for CLS tokenself.blocks = nn.Sequential(*[SimplifiedTransformerBlock(embed_dim, num_heads, mlp_ratio)for _ in range(num_layers)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes) # 分類頭# 初始化位置編碼(這里只是簡單初始化,實際會更復雜)nn.init.trunc_normal_(self.pos_embed, std=.02)nn.init.trunc_normal_(self.cls_token, std=.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward(self, x):B = x.shape[0]x = self.patch_embed(x) # (B, num_patches, embed_dim)cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, embed_dim)x = torch.cat((cls_tokens, x), dim=1) # (B, num_patches + 1, embed_dim)x = x + self.pos_embed # Add positional embeddingfor blk in self.blocks:x = blk(x)x = self.norm(x)# 取出 CLS token 的特征進行分類return self.head(x[:, 0])# --- 運行示例 ---
# 圖像預處理
transform = T.Compose([T.Resize((224, 224)),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 創建一個虛擬圖像
dummy_img = create_dummy_image()
input_tensor = transform(dummy_img).unsqueeze(0) # Add batch dimension (1, 3, 224, 224)# 初始化模型
# 注意:這里為了可運行,將 num_layers 和 num_heads 設小
# 實際ViT模型非常大
model = SimpleViT(img_size=224, patch_size=16, in_channels=3, num_classes=10, embed_dim=128, num_layers=2, num_heads=4) print("--- 簡化版ViT示例 ---")
print(f"Input tensor shape: {input_tensor.shape}")output = model(input_tensor)
print(f"Output logits shape: {output.shape}") # (Batch_size, num_classes)# 模擬訓練
dummy_labels = torch.tensor([0]) # 假設標簽是0
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)for epoch in range(5):optimizer.zero_grad()outputs = model(input_tensor)loss = criterion(outputs, dummy_labels)loss.backward()optimizer.step()print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
代碼說明:
PatchEmbedding
: 將輸入圖像分割成不重疊的patch_size
xpatch_size
大小的圖像塊,然后通過卷積層將其投影到embed_dim
維度,并展平、轉置,形成Transformer所需的序列。SimplifiedTransformerBlock
: 這是一個高度簡化的Transformer編碼器層。真正的Transformer Block包含Multi-Head Self-Attention,這里為了演示將注意力部分簡化為一個線性層。核心思想是殘差連接和LayerNorm。SimpleViT
:cls_token
:一個可學習的特殊嵌入,用于聚合整個圖像的信息。pos_embed
:可學習的位置編碼,為每個圖像塊(包括CLS token)提供位置信息。blocks
:堆疊的簡化Transformer編碼器塊。head
:最終的分類頭,用于從CLS token的輸出中預測類別。
- 運行示例: 創建一個虛擬圖像,經過預處理后輸入模型,并模擬一個簡單的訓練過程。
2.2 自監督視覺學習的先鋒:MoCo與DINO
ViT雖然強大,但它通常需要海量的標注數據進行預訓練。為了解決標注數據的依賴,自監督學習在視覺領域也大放異彩。MoCo (Momentum Contrast) 和 DINO (Self-Distillation with No Labels) 是其中的佼佼者。
2.2.1 MoCo (Momentum Contrast)
核心思想: MoCo通過維護一個動態的“負樣本隊列”和“動量編碼器”,將對比學習推向了大規模預訓練。它解決了對比學習中負樣本數量的限制,使得模型能學到更好的特征表示。
原理詳解:?
MoCo的優勢:
- 高效利用負樣本: 解決了單批次負樣本數量不足的問題。
- 穩定訓練: 動量更新使得鍵編碼器輸出的特征更加穩定,避免了因為負樣本更新過快而導致的訓練不穩定。
- 高性能: 預訓練后的模型在多種下游任務上表現出色。
2.2.2 DINO (Self-Distillation with No Labels)
核心思想: DINO通過自蒸餾(Self-Distillation)的方式進行自監督學習,而無需負樣本。它訓練一個“學生”網絡去匹配一個“教師”網絡的輸出,而教師網絡的參數則是學生網絡參數的動量平均。
原理詳解: DINO的核心是“學生-教師”范式:
- 學生網絡 (Student Network): 接受一個圖像增強后的視圖作為輸入,并進行常規的梯度更新。
- 教師網絡 (Teacher Network): 接受同一圖像的另一個(通常是更強的)增強視圖作為輸入。教師網絡的參數是學生網絡參數的指數移動平均,與MoCo的動量編碼器類似。教師網絡不進行梯度更新。
- 目標: 學生網絡的目標是預測教師網絡的輸出。損失函數通常是交叉熵(或者更精確地說是Kullback-Leibler散度),它鼓勵學生網絡輸出與教師網絡輸出的概率分布相匹配。
- 中心化 (Centering) 和銳化 (Sharpening): DINO還引入了中心化(防止模型崩潰為常數輸出)和銳化(增加輸出分布的區分度)的操作,進一步提升訓練效果。
DINO的優勢:
- 無需負樣本: 簡化了訓練流程,避免了負樣本選擇的復雜性。
- 性能優異: 在許多基準測試中取得了與MoCo相當甚至更好的性能。
- 強大的特征可視化能力: 預訓練的DINO模型學習到的特征具有出色的語義聚類和可解釋性,即使不經過微調也能在圖像分割等任務上表現出強大的零樣本能力。
Python示例:非常簡化版的DINO
我們使用一個極簡的DINO框架來演示其核心概念,省略了大部分細節。
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image# 輔助函數:生成一個虛擬圖像
def create_dummy_image(size=(32, 32)):return Image.new('RGB', size, color='green')# 1. 定義簡單的編碼器(作為學生和教師的骨干)
class SimpleDINOEncoder(nn.Module):def __init__(self, in_channels=3, out_dim=128):super().__init__()self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.flatten = nn.Flatten()self.fc1 = nn.Linear(64 * 8 * 8, 256) # Assuming 32x32 input -> 8x8 after 2 poolsself.fc2 = nn.Linear(256, out_dim)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = self.flatten(x)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 2. 定義DINO的Projection Head
class DINOHead(nn.Module):def __init__(self, in_dim, out_dim, hidden_dim=256, bottleneck_dim=64):super().__init__()self.mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim),nn.GELU(),nn.Linear(hidden_dim, bottleneck_dim),nn.GELU(),nn.Linear(bottleneck_dim, out_dim) # 輸出的維度是教師/學生網絡的輸出維度)self.temp = nn.Parameter(torch.ones(1) * 0.07) # Temperature parameter for sharpeningdef forward(self, x):x = self.mlp(x)# 注意:這里我們省略了DINO中的centering和sharpening的復雜邏輯# 實際DINO會有一個中心化(moving average of teacher output)和溫度調整return x / self.temp.exp() # 簡單模擬溫度縮放# 3. 更新教師網絡參數的動量機制
@torch.no_grad()
def update_teacher_params(student_model, teacher_model, m):"""m: momentum coefficient"""for param_s, param_t in zip(student_model.parameters(), teacher_model.parameters()):param_t.data = param_t.data * m + param_s.data * (1. - m)# 圖像增強
transform_dino = T.Compose([T.RandomResizedCrop(32, scale=(0.4, 1.0)), # DINO常用不同尺度的裁剪T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 初始化模型
out_dim = 128 # 輸出特征維度
student_encoder = SimpleDINOEncoder(out_dim=out_dim)
teacher_encoder = SimpleDINOEncoder(out_dim=out_dim)
student_head = DINOHead(out_dim, out_dim) # 頭的輸出維度與教師網絡輸出維度一致
teacher_head = DINOHead(out_dim, out_dim)# 初始化教師網絡參數與學生網絡一致
teacher_encoder.load_state_dict(student_encoder.state_dict())
teacher_head.load_state_dict(student_head.state_dict())# 教師網絡不計算梯度
for p in teacher_encoder.parameters():p.requires_grad = False
for p in teacher_head.parameters():p.requires_grad = Falseoptimizer = torch.optim.Adam(list(student_encoder.parameters()) + list(student_head.parameters()), lr=1e-3)# 模擬數據
dummy_img_dino = create_dummy_image()print("\n--- 簡化版DINO示例 ---")
for epoch in range(10): # 簡化訓練10個epochoptimizer.zero_grad()# 生成兩個增強視圖view_1 = transform_dino(dummy_img_dino).unsqueeze(0)view_2 = transform_dino(dummy_img_dino).unsqueeze(0)# 學生網絡輸出student_output_v1 = student_head(student_encoder(view_1)) # (B, out_dim)student_output_v2 = student_head(student_encoder(view_2))# 教師網絡輸出 (在無梯度模式下)with torch.no_grad():teacher_output_v1 = teacher_head(teacher_encoder(view_1))teacher_output_v2 = teacher_head(teacher_encoder(view_2))# DINO損失:學生網絡預測教師網絡輸出的交叉熵# F.log_softmax 后面接 F.softmax 實際上是KL散度# DINO實際會使用更復雜的交叉熵,這里簡化loss = 0.5 * (F.cross_entropy(student_output_v1, teacher_output_v2.detach().softmax(dim=-1)) +F.cross_entropy(student_output_v2, teacher_output_v1.detach().softmax(dim=-1)))loss.backward()optimizer.step()# 更新教師網絡參數update_teacher_params(student_encoder, teacher_encoder, m=0.996)update_teacher_params(student_head, teacher_head, m=0.996)print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
代碼說明:
SimpleDINOEncoder
: 簡單的CNN骨干,作為學生和教師的編碼器。DINOHead
: 預測頭,包含MLP和溫度參數。實際DINO中會有中心化(centering)操作,防止模型崩潰,這里為簡化省略。update_teacher_params
: 實現教師網絡參數的動量更新。- DINO損失: 學生網絡對教師網絡輸出的交叉熵損失。注意教師網絡的輸出
detach().softmax()
,表示教師網絡不參與梯度計算,且其輸出經過softmax后作為學生網絡的目標分布。
2.3 掩碼自編碼器:MAE引領視覺預訓練新范式
在上一篇中我們提到了掩碼重建。MAE (Masked Autoencoders) 是將掩碼重建思想成功引入視覺領域的代表性工作。
核心思想: MAE通過對圖像進行高比例的隨機掩碼(例如75%),然后訓練編碼器僅處理可見的圖像塊,解碼器則負責根據編碼器的輸出和掩碼信息重建原始圖像的像素。
MAE的優勢:
- 高效訓練: 編碼器只處理少部分可見的圖像塊,大大降低了計算成本,使得在大型數據集上進行預訓練變得經濟高效。
- 強大的表示學習: 高比例的掩碼迫使模型學習到更全局、更高層次的語義特征,因為要重建大量缺失信息需要模型對圖像內容有深刻的理解。
- 通用性: 預訓練后的MAE模型可以輕松遷移到各種下游任務,并且表現出色。
Python示例:見專欄第一篇1.3節,其“簡化版掩碼重建(MAE啟發)”已經涵蓋了MAE的核心原理與實現。
2.4 普適的視覺分割利器:Segment Anything Model (SAM)
核心思想: SAM (Segment Anything Model) 是一個提示驅動(promptable) 的圖像分割基礎模型。它能夠對圖像中的任何物體進行分割,并且其核心能力在于能夠對各種交互提示(如點擊點、邊框、文本提示) 做出響應,生成高質量的分割掩碼。
原理詳解: SAM的設計理念旨在解決圖像分割的通用性問題,即訓練一個模型能夠分割任何圖像中的任何物體。其核心架構包括:
- 圖像編碼器(Image Encoder): 預訓練的Vision Transformer(如MAE)用于將圖像編碼成高質量的特征嵌入。這個編碼器在推理時是固定的,只計算一次。
- 提示編碼器(Prompt Encoder): 負責將各種提示(點、框、文本、掩碼)編碼成嵌入向量。
- 稀疏提示(Sparse Prompts): 如點(X, Y坐標)、框(邊界框),通過位置編碼和學習到的嵌入進行編碼。
- 密集提示(Dense Prompts): 如粗糙的掩碼,通過卷積層進行編碼。
- 掩碼解碼器(Mask Decoder): 這是一個輕量級的Transformer解碼器,它接收圖像編碼器輸出的圖像嵌入、提示編碼器輸出的提示嵌入,然后預測出高質量的物體掩碼。解碼器設計為高效且可以多次運行,以生成多個可能的分割結果。
SAM的創新之處:
- 提示工程: 將傳統的“分割一切”任務轉化為“給定提示,分割提示所指的一切”的范式,極大地提升了模型的靈活性和用戶交互性。
- 龐大數據集: SAM伴隨一個巨大的、高質量的SA-1B數據集發布,該數據集包含了11億個掩碼,是模型泛化能力的關鍵。
- 零樣本遷移: 憑借其強大的預訓練能力和提示機制,SAM在許多新的分割任務上表現出卓越的零樣本和少樣本能力。
Python示例:SAM的推理流程(基于預訓練模型)
由于SAM模型龐大且復雜,從頭訓練不切實際。這里我們將演示如何使用Meta官方提供的預訓練SAM模型進行推理,展示其核心功能。
步驟:
- 安裝必要的庫。
- 下載預訓練的SAM模型權重。
- 加載圖像并轉換為模型輸入。
- 使用
SamAutomaticMaskGenerator
進行全圖分割(無需提示)。 - 使用
SamPredictor
進行提示驅動分割(點或框提示)。
首先,確保你安裝了必要的庫:
Bash
pip install opencv-python pycocotools matplotlib
pip install segment-anything
然后,你需要下載SAM的預訓練權重文件。你可以從Meta的GitHub頁面下載sam_vit_h_4b8939.pth
(或sam_vit_b
等較小版本)并放在你的代碼目錄下。
Python
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from PIL import Image# 輔助函數:顯示圖片和掩碼
def show_mask(mask, ax, random_color=False):if random_color:color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)else:color = np.array([30/255, 144/255, 255/255, 0.6])h, w = mask.shape[-2:]mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_points(coords, labels, ax, marker_size=375):pos_points = coords[labels==1]neg_points = coords[labels==0]ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) def show_box(box, ax):x0, y0, x1, y1 = boxax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor='green', facecolor=(0,0,0,0), lw=2)) # --- SAM 運行示例 ---
print("\n--- SAM模型推理示例 ---")# 1. 加載模型
sam_checkpoint = "sam_vit_b.pth" # 確保你已下載此文件
model_type = "vit_b" # 對應下載的模型類型device = "cuda" if torch.cuda.is_available() else "cpu"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)# 2. 加載圖像 (這里使用一張簡單的圖像)
# 你可以用自己的圖片路徑替換
image_path = "path/to/your/image.jpg" # 替換為你的圖片路徑
# 如果沒有圖片,可以生成一張簡單的圖像
try:image = cv2.imread(image_path)if image is None: # 如果圖片路徑無效,使用虛擬圖片print(f"Warning: Image not found at {image_path}. Using a dummy image.")dummy_img_sam = Image.new('RGB', (500, 500), color = 'yellow')# 在虛擬圖片上畫個圈或方塊,方便分割draw = ImageDraw.Draw(dummy_img_sam)draw.ellipse((100, 100, 400, 400), fill='blue', outline='black')draw.rectangle((50, 50, 150, 150), fill='red', outline='black')image = np.array(dummy_img_sam)image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert to BGR for opencvexcept FileNotFoundError:print(f"Warning: Image not found at {image_path}. Using a dummy image.")from PIL import ImageDrawdummy_img_sam = Image.new('RGB', (500, 500), color = 'yellow')draw = ImageDraw.Draw(dummy_img_sam)draw.ellipse((100, 100, 400, 400), fill='blue', outline='black')draw.rectangle((50, 50, 150, 150), fill='red', outline='black')image = np.array(dummy_img_sam)image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert to BGR for opencvimage_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # SAM期望RGB圖像# 3. 自動掩碼生成器 (Segment Anything Anywhere)
print("\n--- 自動分割所有物體 ---")
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image_rgb)print(f"Found {len(masks)} masks.")# 可視化所有自動生成的掩碼
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image_rgb)
ax[0].set_title("Original Image")
ax[0].axis('off')ax[1].imshow(image_rgb)
for mask_data in masks:show_mask(mask_data["segmentation"], ax[1], random_color=True)
ax[1].set_title("Automatic Segmentation")
ax[1].axis('off')
plt.show()# 4. 提示驅動的分割 (Segment Anything with Prompts)
print("\n--- 提示驅動分割 (點提示) ---")
predictor = SamPredictor(sam)
predictor.set_image(image_rgb)# 示例:通過點提示進行分割
# 假設我們想分割一個在圖像中心附近的物體
input_point = np.array([[250, 250]]) # 坐標 (x, y)
input_label = np.array([1]) # 1表示前景點masks_prompt, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True, # 可以生成多個可能的掩碼
)# 可視化點提示分割結果
for i, (mask, score) in enumerate(zip(masks_prompt, scores)):plt.figure(figsize=(6,6))plt.imshow(image_rgb)show_mask(mask, plt.gca())show_points(input_point, input_label, plt.gca())plt.title(f"Mask {i+1}, Score: {score:.3f}")plt.axis('off')plt.show()print("\n--- 提示驅動分割 (框提示) ---")
# 示例:通過框提示進行分割
# 假設我們想分割一個框住的物體 (x_min, y_min, x_max, y_max)
input_box = np.array([100, 100, 400, 400]) masks_box, scores_box, logits_box = predictor.predict(point_coords=None,point_labels=None,box=input_box[None, :], # 注意需要加batch維度multimask_output=True,
)# 可視化框提示分割結果
for i, (mask, score) in enumerate(zip(masks_box, scores_box)):plt.figure(figsize=(6,6))plt.imshow(image_rgb)show_mask(mask, plt.gca())show_box(input_box, plt.gca())plt.title(f"Mask {i+1}, Score: {score:.3f}")plt.axis('off')plt.show()
代碼說明:
- 模型加載: 通過
sam_model_registry
加載預訓練的SAM模型。確保sam_checkpoint
路徑正確且模型權重文件已下載。 - 圖像處理: SAM期望RGB格式的圖像。
SamAutomaticMaskGenerator
: 這是SAM的一個高級接口,可以自動在整張圖像中檢測并分割出所有潛在的物體,無需任何提示。它會生成一系列字典,每個字典包含一個分割掩碼、其置信度等信息。SamPredictor
: 這是SAM的核心推理接口,允許你根據提供的提示(點、框)來預測分割掩碼。predictor.set_image()
:在進行預測前,需要先將圖像輸入編碼器,生成圖像嵌入。predictor.predict()
:根據輸入的點坐標point_coords
、點標簽point_labels
(1表示前景,0表示背景)、邊界框box
來生成掩碼。multimask_output=True
表示可以返回多個可能的掩碼。
- 可視化函數:
show_mask
、show_points
、show_box
是輔助函數,用于將分割結果、點和框可視化在圖像上。
總結
本篇專欄深入探討了視覺基礎模型,從革命性的ViT開始,到自監督學習領域的MoCo和DINO,再到普適性的分割模型SAM。我們看到:
- ViT 成功將Transformer引入視覺領域,開啟了圖像處理的“大模型”時代。
- MoCo和DINO 作為自監督學習的典范,解決了視覺大模型對海量標注數據的依賴,使得模型能夠從無標簽數據中學習到強大的視覺表示。
- MAE 進一步提升了自監督學習的效率和效果,尤其在高掩碼比例下表現出色。
- SAM 憑借其提示驅動的通用分割能力和龐大的數據集,極大地推動了圖像分割的邊界,實現了對“任何物體”的分割。
這些視覺基礎模型為未來的計算機視覺應用奠定了堅實的基礎。它們不僅在傳統任務上取得了SOTA性能,更開啟了零樣本、少樣本學習以及更靈活、更智能的視覺交互模式。