從零開始搭建CLIP模型實現基于文本的圖像檢索

目錄

  • CLIP原理簡介
  • 代碼實現
  • 參考鏈接

CLIP原理簡介

論文鏈接,源碼鏈接

CLIP模型由OpenAI在2021年提出,利用雙Decoder(Dual Encoder)的架構來學習圖像和文本之間的對應關系,是多模態大模型的開創之作,為后續許多高效的多模態模型的提出打下基礎。CLIP是一個預訓練模型(Pre-trained Model),在學習到圖像–文本特征之間的關聯后可以遷移到各種下游任務中,如圖像分類,文本引導圖像分割和目標檢測,圖像文本檢索等。由于模型學習到的是文本語義和圖像語義之間的關聯,使得其zero-shot能力非常強大,根據論文中的描述,CLIP在很多數據集上zero-shot的結果甚至超越了許多訓練好的模型的效果。CLIP的訓練范式如下:

![在這里插入圖片描述](https://i-blog.csdnimg.cn/direct/1d112d364a60434bba8dd07d42d2a1c6.png

CLIP的結構非常簡單,數據集包含大量的圖像文本對,圖像經過圖像編碼器得到圖像特征,文本經過文本編碼器得到文本特征,將圖像特征和文本特征按照數據集中的對應關系進行配對,不配對的特征給予懲罰,從上圖中可以看出,我們希望矩陣中藍色的值趨近于1,其余值趨近于0,采用對比學習的方式對模型進行訓練,算法的偽代碼如下:

在這里插入圖片描述
從損失函數中可以看出,分別對特征對比矩陣的行和列進行交叉熵損失函數計算,并取平均得到最終的loss。圖像編碼器一般有兩種選擇:ResNet和ViT;文本編碼器采用Transformer Encoder,均是各自領域中優秀的特征提取網絡。
CLIP的推理范式如下:

在這里插入圖片描述
在推理階段,圖像編碼器中輸入圖像獲取圖像特征,文本編碼器中輸入文本獲取文本特征,將圖像特征向量和文本特征向量的轉置相乘得到每張圖像對每個文本的特征相似度,相似度最高的文本即描述了該圖像中物體所屬的類別。

代碼實現

Flickr8k數據集下載,提取碼:fbfz
DistilBert模型文件下載

我的運行環境:
?CUDA 11.8
?pytorch 2.2.2
?transformers 4.44.0?# 用于從HuggingFace上加載預訓練模型


數據集預覽:
圖片示例

圖片示例

在這里插入圖片描述

文本示例

由于作者的顯卡算力有限,選取Flickr8k數據集進行模型訓練,其中包含8k個圖像文本對,其中一張圖像對應5條文本。圖像編碼器采用ResNet50,直接從timm庫中導入;文本編碼器采用DistilBert,即輕量化的Bert模型,從HuggingFace上下載。閑話少說,小二,上菜!

### 模型參數配置 ###
import argparse
from dataclasses import dataclassparser = argparse.ArgumentParser(description="CLIP from zero")
parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder')  # 存放圖像的文件路徑
parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder')  # 存放文本的文件路徑
parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight')  # 存放訓練權重的文件路徑
args = parser.parse_args()@dataclass
class CLIPConfig:image_path: str = args.image_dir  # 圖像存放路徑image_size: int = 224  # resize后的圖像尺寸,便于構建Dataloadercaption_path: str = args.caption_dir  # 文本存放路徑batch_size: int = 8  # 一個批次中的數據數量epochs: int = 3  # 訓練世代image_encoder_model: str = "resnet50"  # 圖像編碼器的名稱image_embedding_dim: int = 2048  # 圖像特征的維度text_encoder_model: str = "distilbert-base-uncased"  # 文本編碼器的名稱text_embedding_dim: int = 768  # 文本特征的維度text_tokenizer: str = text_encoder_model  # 文本分詞器模型的名稱max_length: int = 200  # 文本編碼器可輸入的最長文本長度pretrained: bool = False  # 是否加載預訓練好的編碼器trainable: bool = True  # 在訓練過程中是否更新編碼器的參數temperature: float = 1.0  # 計算loss時的正則化系數proj_dim: int = 256  # 圖像特征和文本特征統一后的維度dropout_rate: float = 0.1  # dropout系數,避免過擬合### 載入數據集并初始化 ###
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import pandas as pd
import cv2class CLIPDataset(Dataset):def __init__(self, config, image_path, caption_path, transforms=True):"""圖片文件名和標題的長度必須相同如果一個圖片對應多個標題,該圖片文件名需要重復多次"""self.image_path = image_path  # 圖像路徑self.caption_path = caption_path  # 文本路徑self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv")  # 讀取文本self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)  # 載入分詞器self.image_filenames = self.dataframe["image"].values  # 獲取圖像文件名self.captions = list(self.dataframe["caption"].values)   # 獲取圖像對應的描述文本self.encoded_captions = self.tokenizer(self.captions, padding=True, truncation=True, max_length=config.max_length)  # 文本分詞self.transforms = transforms  # 對輸入圖像進行預處理def __getitem__(self, idx):  # 獲取數據集中第idx個數據,其中包含圖片名稱和對應的標題(可能不止一個)item = {key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()}image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}")  # 獲取原始圖像image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transforms:image = self.get_transforms(mode="train")(image=image)["image"]  # 對圖像進行預處理item["image"] = torch.tensor(image).permute(2, 0, 1).float()  # 將圖片轉換為tensor格式,并調整為RGB順序item["caption"] = self.captions[idx]  # 獲取標題return itemdef __len__(self):return len(self.captions)  # 獲取文本長度def get_transforms(self, mode="train"):if mode == "train":return A.Compose([A.Resize(config.image_size, config.image_size, always_apply=True),  # 對圖像進行resizeA.Normalize(max_pixel_value=255.0, always_apply=True)  # 對像素值進行歸一化])### 圖像編碼器 ###
import torch.nn as nn
import timmclass ImageEncoder(nn.Module):"""圖像編碼器,采用ResNet50"""def __init__(self, config):super().__init__()self.model = timm.create_model(config.image_encoder_model, pretrained=config.pretrained, num_classes=0, global_pool="avg")  # 創建ResNet50for p in self.model.parameters():p.requires_grad = config.trainable  # 設置參數可訓練def forward(self, x):image_encoded = self.model(x)  # 獲得圖像特征編碼,形狀為[batch_size, image_embedding_dim]return image_encoded### 文本編碼器 ###
class TextEncoder(nn.Module):"""文本編碼器,采用DistilBERT"""def __init__(self, config):super().__init__()if config.pretrained:self.model = DistilBertModel.from_pretrained(config.text_encoder_model)  # 導入下載好的模型文件else:self.model = DistilBertModel(DistilBertConfig())for p in self.model.parameters():p.requires_grad = config.trainable  # 設置參數可訓練self.target_token_idx = 0# 提取出和圖像對應的文本特征向量def forward(self, input_ids, attention_mask):output = self.model(input_ids=input_ids, attention_mask=attention_mask)text_encoded = output.last_hidden_state[:, self.target_token_idx, :]  # [batch_size, text_embedding_dim]return text_encoded### 投影層 (MLP) ###
class ProjectionHead(nn.Module):"""將圖像編碼和文本編碼映射到相同維度"""def __init__(self, config, input_embedding_dim):super().__init__()self.proj = nn.Linear(input_embedding_dim, config.proj_dim)self.act_fn = nn.GELU()self.fc = nn.Linear(config.proj_dim, config.proj_dim)self.dropout = nn.Dropout(config.dropout_rate)self.layer_norm = nn.LayerNorm(config.proj_dim)def forward(self, x):x_proj = self.proj(x)x = self.act_fn(x_proj)x = self.fc(x)x = self.dropout(x)x = x + x_projx = self.layer_norm(x)return x### 定義損失函數 ###
def cross_entropy(logits, labels, reduction='none'):log_softmax = nn.LogSoftmax(dim=-1)loss = (-labels * log_softmax(logits)).sum(dim=1)if reduction == 'mean':return loss.mean()else:return loss.sum()### 模型主體 ###
import torch.nn.functional as Fclass CLIP(nn.Module):def __init__(self, config):super().__init__()self.image_encoder = ImageEncoder(config)  # 實例化圖像編碼器self.text_encoder = TextEncoder(config)  # 實例化文本編碼器self.image_proj = ProjectionHead(config, config.image_embedding_dim)  # 圖像特征投影self.text_proj = ProjectionHead(config, config.text_embedding_dim)  # 文本特征投影self.temperature = config.temperaturedef forward(self, batch):image_features = self.image_encoder(batch["image"])  # 圖像編碼# 文本編碼,tokenizer處理后的文本序列自帶input_ids和attention_masktext_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])image_embeddings = self.image_proj(image_features)  # 圖像特征投影text_embeddings = self.text_proj(text_features)  # 文本特征投影logits = (text_embeddings @ image_embeddings.T) / self.temperature  # tensor形狀為[batch_size, batch_size]images_similarity = image_embeddings @ image_embeddings.T  # tensor形狀為[batch_size, batch_size]text_similarity = text_embeddings @ text_embeddings.T  # tensor形狀為[batch_size, batch_size]# 軟標簽,不配對的位置設置為較小的數,而非0labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1)  loss_T = cross_entropy(logits, labels)  # 計算文本損失loss_I = cross_entropy(logits.T, labels.T)  # 計算圖像損失total_loss = (loss_T + loss_I) / 2  # 對比學習平均損失return total_loss, logits### 訓練函數 ###
def train(model, optimizer, scheduler, train_loader, device):model.train()  # 模型設置為訓練模式train_loss = 0train_loader = tqdm(train_loader, total=len(train_loader))  # 顯示訓練進度條cnt = 0for batch in train_loader:# print(batch.keys())cnt += 1batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}  # 將dataloader中一個batch的數據轉換為字典形式loss, _ = model(batch)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step(metrics=loss.item())  # 根據上次訓練的損失更新學習率train_loss += loss.item()# 訓練100個batch顯示一次lossif cnt % 100 == 0:print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')return train_loss / len(train_loader)  # 平均訓練損失### 測試函數 ###
def eval(model, val_loader, device):model.eval()  # 模型設置為測試模式val_loss = 0val_loader = tqdm(val_loader, total=len(val_loader))with torch.no_grad():for batch in val_loader:batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}loss, _ = model(batch)val_loss += loss.item()return val_loss / len(val_loader)  # 平均測試損失if __name__ == '__main__':config = CLIPConfig()  # 實例化配置信息model = CLIP(config)  # 實例化CLIP模型device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 查看模型的總參數量total_params = sum(p.numel() for p in model.parameters())print(f"Total parameters: {total_params / 1e6} M")optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)dataset = CLIPDataset(config, args.image_dir, args.caption_dir)  # 讀取并預處理數據train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])  # 80%為訓練數據,20%為測試數據dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)# 開始訓練best_loss = float("inf")for epoch in range(config.epochs):print(f"Epoch: {epoch + 1}")train_loss_avg = train(model, optimizer, scheduler, train_loader, device)val_loss_avg = eval(model, val_loader, device)if val_loss_avg < best_loss:best_loss = val_loss_avgtorch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')print("Best model saved!")# 圖像文本檢索推理并可視化# dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")# tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)# model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))# model.eval()# # image_embeddings = []# with torch.no_grad():#     for batch in tqdm(dataloader):#         image_features = model.image_encoder(batch["image"].to(device))  # 獲取圖像特征#         cur_image_embeddings = model.image_proj(image_features)  # [batch_size, proj_dim]  # 圖像特征投影#         image_embeddings.append(cur_image_embeddings)  # 將一個batch的圖像特征保存# # image_embeddings = torch.cat(image_embeddings, dim=0)  # [image_number, proj_dim]# input_query = "two dogs sitting on the grass"  # 輸入文本# image_filenames = dataframe["image"].values  # 待檢索的圖片# # encoded_query = tokenizer([input_query])  # 對輸入文本進行分詞# batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}# # with torch.no_grad():#     text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"])  # 獲取文本特征#     text_embeddings = model.text_proj(text_features)  # 文本特征投影,與圖像特征維度一致# # image_embeddings_n = F.normalize(image_embeddings, dim=-1)  # [image_number, proj_dim]# text_embeddings_n = F.normalize(text_embeddings, dim=-1)  # [1, proj_dim]# dot_similarity = text_embeddings_n @ image_embeddings_n.T  # 輸入文本的特征和數據集中每張圖像特征之間的相似度# # values, indices = torch.topk(dot_similarity.squeeze(0), k=45)  # 獲取前45個相似度最高的圖像# matches = [image_filenames[idx] for idx in indices[::5]]  # 獲取對應的圖像文件名(9張圖像)# # f, axes = plt.subplots(3, 3, figsize=(10, 10))# f.suptitle(f"Retrieving text: {input_query}")  # 設置主標題# for match, ax in zip(matches, axes.flatten()):  # 顯示檢索出的圖像#     image = cv2.imread(f"{args.image_dir}/{match}")#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)#     ax.imshow(image)#     ax.axis("off")# # plt.show()

理想結果:

在這里插入圖片描述

參考鏈接

https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/77551.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/77551.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/77551.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

熊海cms代碼審計

目錄 sql注入 1. admin/files/login.php 2. admin/files/columnlist.php 3. admin/files/editcolumn.php 4. admin/files/editlink.php 5. admin/files/editsoft.php 6. admin/files/editwz.php 7. admin/files/linklist.php 8. files/software.php 9. files…

[Java微服務組件]注冊中心P3-Nacos中的設計模式1-觀察者模式

在P1-簡單注冊中心實現和P2-Nacos解析中&#xff0c;我們分別實現了簡單的注冊中心并總結了Nacos的一些設計。 本篇繼續看Nacos源碼&#xff0c;了解一下Nacos中的設計模式。 目錄 Nacos 觀察者模式 Observer Pattern觀察者模式總結 Nacos 觀察者模式 Observer Pattern 模式定…

電腦 訪問 github提示 找不到網頁,處理方案

1、找到 本機的 host文件 例如 windows 的 一般在 C:\Windows\System32\drivers\etc\hosts 用管理員身份打開 hosts 文件 如果文件中沒有 github的配置&#xff0c;需要自己手動添加上去&#xff1b; 如果有&#xff0c;則需要 檢查 github.com 與 github.global.ssl.fastly.…

Linux系統中的網絡管理

1.RHEL9版本中&#xff0c;使用nm進行網絡配置&#xff0c;ifcfg不再是網絡配置文件的主存儲&#xff0c;樣式仍然可用&#xff0c;但它不再是NetworkManger存儲新網絡配置文件的默認位置&#xff0c;RHEL以key-file格式在etc/NetworkManger/system-connections/中存儲新的網絡…

AI技術深度解析:從移動芯片到AIoT的全面突破

作為全球無線通信技術和半導體解決方案的重要參與者,高通始終將技術創新作為核心驅動力,在移動通信、物聯網(IoT)、汽車電子、AI計算等領域占據關鍵地位。本文將從其核心產品線、技術突破、應用場景及未來布局四個維度,客觀解析高通的技術積累與行業角色。 一、核心產品線…

使用CS Roofline Toolkit測量帶寬

使用CS Roofline Toolkit測量帶寬 工程下載&#xff1a;使用CS Roofline Toolkit測量帶寬-案例工程文件&#xff0c;也可以按照下面的說明使用git clone下載 目錄 使用CS Roofline Toolkit測量帶寬0、Roofline模型理解1、CS Roofline Toolkit下載1.1、設置代理1.2、git clone下…

EAGLE代碼研讀+模型復現

要對代碼下手了&#xff0c;加油(? ?_?)? 作者在他們自己的設備上展現了推理的評估結果&#xff0c;受第三方評估認證&#xff0c;EAGLE為目前最快的投機方法&#xff08;雖然加速度是評估投機解碼方法的主要指標&#xff0c;但其他點也值得關注。比如PLD和Lookahead無需額…

基于SFC的windows修復程序,修復絕大部分系統損壞

效果:可以自動修復大部分由系統文件損壞而導致的錯誤 例如:系統應用無法打開 系統窗口(例如開始菜單)無法使用 電腦藍屏或者卡死.....文章 01技術背景 Windows自帶了一個SFC命令行應用程序,可以檢查大部分的系統文件錯誤,以及復這些文件 其中自動檢查所有系統文件&#x…

liunx日志問題

一、日志定向 Linux 系統的日志配置文件&#xff08;如/etc/syslog.conf或/etc/rsyslog.conf &#xff09;中&#xff0c;用于定義系統日志的記錄規則&#xff0c;決定哪些類型的日志消息會被記錄到特定的日志文件中。 *.info;mail.none;authpriv.none;cron.none /va…

2.凸包優化求解

1.減而治之(Decrease and Conquer) 插入排序 典型的減而治之算法就是插入排序方法 插入排序法: 在未排序中選擇一個元素&#xff0c;插入到已經排序號的序列中 將凸包也采用減而治之的方法 2.In-Convex-Polygon Test 怎么判斷引入的極點存在于多邊形里面還是外面&#xff1…

系統思考:危機中的轉型機遇

“危機不僅是挑戰&#xff0c;更是轉型的機會” 每當大事發生&#xff0c;很多企業老板常常被眼前的困境壓得喘不過氣&#xff0c;焦慮與壓力讓人難以思考長遠。特別是在危機面前&#xff0c;大家忙于應對眼前的風險&#xff0c;卻忽略了背后隱藏的機遇。而危機&#xff0c;恰…

大模型Rag - 如何評估Rag

一.RAG流程與評估標準補充 RAG&#xff08;Retrieval-Augmented Generation&#xff09;是一種結合檢索與生成的問答架構。為了確保系統效果&#xff0c;需要從以下三個角度對其評估&#xff1a; 回顧RAG流程 用戶提出問題 → 系統檢索相關上下文 → 基于上下文由大語言模型…

Linux RT RT RT

RT的最終目的是盡可能多的讓原來系統不可搶占的部分變成可搶占&#xff0c;讓高優先級的程序先跑。這里的rt引入了一個deadline的說法&#xff0c;此時的實時性是保證在最大一個時間間隔內&#xff0c;程序被執行。比如每100ms算法做一次決策。 所以此時面臨著幾座大山…

演員柳琦正式加入創星演員出道計劃,開創演藝事業新天地

4月18日&#xff0c;演員柳琦正式加入“創星演員出道計劃”&#xff0c;不僅得到參演都市愛情喜劇《和我結婚吧》角色的機會&#xff0c;還獲得文旅精品網劇《醉夢靈州》的出演機會&#xff0c;自此開啟全新影視之路。對表演藝術極具天賦的柳琦&#xff0c;相信未來可以憑借自身…

16.Chromium指紋瀏覽器開發教程之WebGPU指紋定制

WebGPU指紋概述 WebGPU是下一代的Web圖形和計算API&#xff0c;旨在提供高性能的圖形渲染和計算能力。它是WebGL的后繼者&#xff0c;旨在利用現代GPU的強大功能&#xff0c;使得Web應用能夠實現接近原生應用的圖形和計算性能。而且它是一個低級別的API&#xff0c;可以直接與…

HTTP:九.WEB機器人

概念 Web機器人是能夠在無需人類干預的情況下自動進行一系列Web事務處理的軟件程序。人們根據這些機器人探查web站點的方式,形象的給它們取了一個飽含特色的名字,比如“爬蟲”、“蜘蛛”、“蠕蟲”以及“機器人”等!爬蟲概述 網絡爬蟲(英語:web crawler),也叫網絡蜘蛛(…

Vue3+TS中svg圖標的使用

安裝依賴 pnpm i vite-plugin-svg-icons -D配置引入 vite.config.ts ... import { createSvgIconsPlugin } from vite-plugin-svg-icons import path from node:pathconst svgIconsPlugin createSvgIconsPlugin({iconDirs: [path.resolve(process.cwd(), src/assets/icons)]…

【java實現+4種變體完整例子】排序算法中【堆排序】的詳細解析,包含基礎實現、常見變體的完整代碼示例,以及各變體的對比表格

以下是堆排序的詳細解析&#xff0c;包含基礎實現、常見變體的完整代碼示例&#xff0c;以及各變體的對比表格&#xff1a; 一、堆排序基礎實現 原理 基于二叉堆結構&#xff08;最大堆&#xff09;&#xff0c;通過以下步驟實現排序&#xff1a; 構建最大堆&#xff1a;將…

論文閱讀筆記:Generative Modeling by Estimating Gradients of the Data Distribution

1、參考來源 論文《Generative Modeling by Estimating Gradients of the Data Distribution》 來源&#xff1a;NeurIPS 2019 論文鏈接&#xff1a;https://arxiv.org/abs/1907.05600 參考鏈接&#xff1a; 【AI知識分享】真正搞懂擴散模型Score Matching一定要理解的三大核心…

Kubernetes相關的名詞解釋CNI插件(1)

&#xff08;一&#xff09;什么是CNI插件&#xff1f; 在 Kubernetes 中&#xff0c;CNI 插件&#xff08;Container Network Interface Plugin&#xff09; 是一種用于配置容器網絡接口的標準工具&#xff0c;負責為 Pod 分配網絡資源&#xff08;如 IP 地址&#xff09;并建…