目錄
- CLIP原理簡介
- 代碼實現
- 參考鏈接
CLIP原理簡介
論文鏈接,源碼鏈接
CLIP模型由OpenAI在2021年提出,利用雙Decoder(Dual Encoder)的架構來學習圖像和文本之間的對應關系,是多模態大模型的開創之作,為后續許多高效的多模態模型的提出打下基礎。CLIP是一個預訓練模型(Pre-trained Model),在學習到圖像–文本特征之間的關聯后可以遷移到各種下游任務中,如圖像分類,文本引導圖像分割和目標檢測,圖像文本檢索等。由于模型學習到的是文本語義和圖像語義之間的關聯,使得其zero-shot能力非常強大,根據論文中的描述,CLIP在很多數據集上zero-shot的結果甚至超越了許多訓練好的模型的效果。CLIP的訓練范式如下:
CLIP的結構非常簡單,數據集包含大量的圖像文本對,圖像經過圖像編碼器得到圖像特征,文本經過文本編碼器得到文本特征,將圖像特征和文本特征按照數據集中的對應關系進行配對,不配對的特征給予懲罰,從上圖中可以看出,我們希望矩陣中藍色的值趨近于1,其余值趨近于0,采用對比學習的方式對模型進行訓練,算法的偽代碼如下:
從損失函數中可以看出,分別對特征對比矩陣的行和列進行交叉熵損失函數計算,并取平均得到最終的loss。圖像編碼器一般有兩種選擇:ResNet和ViT;文本編碼器采用Transformer Encoder,均是各自領域中優秀的特征提取網絡。
CLIP的推理范式如下:
在推理階段,圖像編碼器中輸入圖像獲取圖像特征,文本編碼器中輸入文本獲取文本特征,將圖像特征向量和文本特征向量的轉置相乘得到每張圖像對每個文本的特征相似度,相似度最高的文本即描述了該圖像中物體所屬的類別。
代碼實現
Flickr8k數據集下載,提取碼:fbfz
DistilBert模型文件下載
我的運行環境:
?CUDA 11.8
?pytorch 2.2.2
?transformers 4.44.0?# 用于從HuggingFace上加載預訓練模型
數據集預覽:
由于作者的顯卡算力有限,選取Flickr8k數據集進行模型訓練,其中包含8k個圖像文本對,其中一張圖像對應5條文本。圖像編碼器采用ResNet50,直接從timm庫中導入;文本編碼器采用DistilBert,即輕量化的Bert模型,從HuggingFace上下載。閑話少說,小二,上菜!
### 模型參數配置 ###
import argparse
from dataclasses import dataclassparser = argparse.ArgumentParser(description="CLIP from zero")
parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder') # 存放圖像的文件路徑
parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder') # 存放文本的文件路徑
parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight') # 存放訓練權重的文件路徑
args = parser.parse_args()@dataclass
class CLIPConfig:image_path: str = args.image_dir # 圖像存放路徑image_size: int = 224 # resize后的圖像尺寸,便于構建Dataloadercaption_path: str = args.caption_dir # 文本存放路徑batch_size: int = 8 # 一個批次中的數據數量epochs: int = 3 # 訓練世代image_encoder_model: str = "resnet50" # 圖像編碼器的名稱image_embedding_dim: int = 2048 # 圖像特征的維度text_encoder_model: str = "distilbert-base-uncased" # 文本編碼器的名稱text_embedding_dim: int = 768 # 文本特征的維度text_tokenizer: str = text_encoder_model # 文本分詞器模型的名稱max_length: int = 200 # 文本編碼器可輸入的最長文本長度pretrained: bool = False # 是否加載預訓練好的編碼器trainable: bool = True # 在訓練過程中是否更新編碼器的參數temperature: float = 1.0 # 計算loss時的正則化系數proj_dim: int = 256 # 圖像特征和文本特征統一后的維度dropout_rate: float = 0.1 # dropout系數,避免過擬合### 載入數據集并初始化 ###
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import pandas as pd
import cv2class CLIPDataset(Dataset):def __init__(self, config, image_path, caption_path, transforms=True):"""圖片文件名和標題的長度必須相同如果一個圖片對應多個標題,該圖片文件名需要重復多次"""self.image_path = image_path # 圖像路徑self.caption_path = caption_path # 文本路徑self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv") # 讀取文本self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer) # 載入分詞器self.image_filenames = self.dataframe["image"].values # 獲取圖像文件名self.captions = list(self.dataframe["caption"].values) # 獲取圖像對應的描述文本self.encoded_captions = self.tokenizer(self.captions, padding=True, truncation=True, max_length=config.max_length) # 文本分詞self.transforms = transforms # 對輸入圖像進行預處理def __getitem__(self, idx): # 獲取數據集中第idx個數據,其中包含圖片名稱和對應的標題(可能不止一個)item = {key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()}image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}") # 獲取原始圖像image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transforms:image = self.get_transforms(mode="train")(image=image)["image"] # 對圖像進行預處理item["image"] = torch.tensor(image).permute(2, 0, 1).float() # 將圖片轉換為tensor格式,并調整為RGB順序item["caption"] = self.captions[idx] # 獲取標題return itemdef __len__(self):return len(self.captions) # 獲取文本長度def get_transforms(self, mode="train"):if mode == "train":return A.Compose([A.Resize(config.image_size, config.image_size, always_apply=True), # 對圖像進行resizeA.Normalize(max_pixel_value=255.0, always_apply=True) # 對像素值進行歸一化])### 圖像編碼器 ###
import torch.nn as nn
import timmclass ImageEncoder(nn.Module):"""圖像編碼器,采用ResNet50"""def __init__(self, config):super().__init__()self.model = timm.create_model(config.image_encoder_model, pretrained=config.pretrained, num_classes=0, global_pool="avg") # 創建ResNet50for p in self.model.parameters():p.requires_grad = config.trainable # 設置參數可訓練def forward(self, x):image_encoded = self.model(x) # 獲得圖像特征編碼,形狀為[batch_size, image_embedding_dim]return image_encoded### 文本編碼器 ###
class TextEncoder(nn.Module):"""文本編碼器,采用DistilBERT"""def __init__(self, config):super().__init__()if config.pretrained:self.model = DistilBertModel.from_pretrained(config.text_encoder_model) # 導入下載好的模型文件else:self.model = DistilBertModel(DistilBertConfig())for p in self.model.parameters():p.requires_grad = config.trainable # 設置參數可訓練self.target_token_idx = 0# 提取出和圖像對應的文本特征向量def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)text_encoded = output.last_hidden_state[:, self.target_token_idx, :] # [batch_size, text_embedding_dim]return text_encoded### 投影層 (MLP) ###
class ProjectionHead(nn.Module):"""將圖像編碼和文本編碼映射到相同維度"""def __init__(self, config, input_embedding_dim):super().__init__()self.proj = nn.Linear(input_embedding_dim, config.proj_dim)self.act_fn = nn.GELU()self.fc = nn.Linear(config.proj_dim, config.proj_dim)self.dropout = nn.Dropout(config.dropout_rate)self.layer_norm = nn.LayerNorm(config.proj_dim)def forward(self, x):x_proj = self.proj(x)x = self.act_fn(x_proj)x = self.fc(x)x = self.dropout(x)x = x + x_projx = self.layer_norm(x)return x### 定義損失函數 ###
def cross_entropy(logits, labels, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-labels * log_softmax(logits)).sum(dim=1)if reduction == 'mean':return loss.mean()else:return loss.sum()### 模型主體 ###
import torch.nn.functional as Fclass CLIP(nn.Module):def __init__(self, config):super().__init__()self.image_encoder = ImageEncoder(config) # 實例化圖像編碼器self.text_encoder = TextEncoder(config) # 實例化文本編碼器self.image_proj = ProjectionHead(config, config.image_embedding_dim) # 圖像特征投影self.text_proj = ProjectionHead(config, config.text_embedding_dim) # 文本特征投影self.temperature = config.temperaturedef forward(self, batch):image_features = self.image_encoder(batch["image"]) # 圖像編碼# 文本編碼,tokenizer處理后的文本序列自帶input_ids和attention_masktext_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])image_embeddings = self.image_proj(image_features) # 圖像特征投影text_embeddings = self.text_proj(text_features) # 文本特征投影logits = (text_embeddings @ image_embeddings.T) / self.temperature # tensor形狀為[batch_size, batch_size]images_similarity = image_embeddings @ image_embeddings.T # tensor形狀為[batch_size, batch_size]text_similarity = text_embeddings @ text_embeddings.T # tensor形狀為[batch_size, batch_size]# 軟標簽,不配對的位置設置為較小的數,而非0labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1) loss_T = cross_entropy(logits, labels) # 計算文本損失loss_I = cross_entropy(logits.T, labels.T) # 計算圖像損失total_loss = (loss_T + loss_I) / 2 # 對比學習平均損失return total_loss, logits### 訓練函數 ###
def train(model, optimizer, scheduler, train_loader, device):model.train() # 模型設置為訓練模式train_loss = 0train_loader = tqdm(train_loader, total=len(train_loader)) # 顯示訓練進度條cnt = 0for batch in train_loader:# print(batch.keys())cnt += 1batch = {k: v.to(device) for k, v in batch.items() if k != "caption"} # 將dataloader中一個batch的數據轉換為字典形式loss, _ = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step(metrics=loss.item()) # 根據上次訓練的損失更新學習率train_loss += loss.item()# 訓練100個batch顯示一次lossif cnt % 100 == 0:print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')return train_loss / len(train_loader) # 平均訓練損失### 測試函數 ###
def eval(model, val_loader, device):model.eval() # 模型設置為測試模式val_loss = 0val_loader = tqdm(val_loader, total=len(val_loader))with torch.no_grad():for batch in val_loader:batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}loss, _ = model(batch)val_loss += loss.item()return val_loss / len(val_loader) # 平均測試損失if __name__ == '__main__':config = CLIPConfig() # 實例化配置信息model = CLIP(config) # 實例化CLIP模型device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 查看模型的總參數量total_params = sum(p.numel() for p in model.parameters())print(f"Total parameters: {total_params / 1e6} M")optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)dataset = CLIPDataset(config, args.image_dir, args.caption_dir) # 讀取并預處理數據train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2]) # 80%為訓練數據,20%為測試數據dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)# 開始訓練best_loss = float("inf")for epoch in range(config.epochs):print(f"Epoch: {epoch + 1}")train_loss_avg = train(model, optimizer, scheduler, train_loader, device)val_loss_avg = eval(model, val_loader, device)if val_loss_avg < best_loss:best_loss = val_loss_avgtorch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')print("Best model saved!")# 圖像文本檢索推理并可視化# dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")# tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)# model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))# model.eval()# # image_embeddings = []# with torch.no_grad():# for batch in tqdm(dataloader):# image_features = model.image_encoder(batch["image"].to(device)) # 獲取圖像特征# cur_image_embeddings = model.image_proj(image_features) # [batch_size, proj_dim] # 圖像特征投影# image_embeddings.append(cur_image_embeddings) # 將一個batch的圖像特征保存# # image_embeddings = torch.cat(image_embeddings, dim=0) # [image_number, proj_dim]# input_query = "two dogs sitting on the grass" # 輸入文本# image_filenames = dataframe["image"].values # 待檢索的圖片# # encoded_query = tokenizer([input_query]) # 對輸入文本進行分詞# batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}# # with torch.no_grad():# text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"]) # 獲取文本特征# text_embeddings = model.text_proj(text_features) # 文本特征投影,與圖像特征維度一致# # image_embeddings_n = F.normalize(image_embeddings, dim=-1) # [image_number, proj_dim]# text_embeddings_n = F.normalize(text_embeddings, dim=-1) # [1, proj_dim]# dot_similarity = text_embeddings_n @ image_embeddings_n.T # 輸入文本的特征和數據集中每張圖像特征之間的相似度# # values, indices = torch.topk(dot_similarity.squeeze(0), k=45) # 獲取前45個相似度最高的圖像# matches = [image_filenames[idx] for idx in indices[::5]] # 獲取對應的圖像文件名(9張圖像)# # f, axes = plt.subplots(3, 3, figsize=(10, 10))# f.suptitle(f"Retrieving text: {input_query}") # 設置主標題# for match, ax in zip(matches, axes.flatten()): # 顯示檢索出的圖像# image = cv2.imread(f"{args.image_dir}/{match}")# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# ax.imshow(image)# ax.axis("off")# # plt.show()
理想結果:
參考鏈接
https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/