Paper: Learning Transferable Visual Models From Natural Language Supervision
Link: Learning Transferable Visual Models From Natural Language Supervision
1. About CLIP
Feature learning via image-text matching: the paper shows that the simple pre-training task of predicting which caption goes with which image is an effective and scalable way to learn SOTA image representations from scratch, using a dataset of 400 million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference the learned visual concepts (or to describe new ones), which enables zero-shot transfer of the model to downstream tasks.
How does it work? Whereas traditional models jointly train an image feature extractor and a linear classifier to predict a fixed set of labels, CLIP jointly trains an image encoder and a text encoder to predict the correct pairings within a batch of (image, text) training examples. At test time, the learned text encoder synthesizes a zero-shot linear classifier by embedding the names or descriptions of the target dataset's classes.
- Image Encoder: its job is to turn any image into a feature vector the model can understand; this vector condenses the image's core information.
- Text Encoder: its job is to turn any piece of text into a feature vector as well; this vector condenses the text's core meaning.
- In principle, if an image and a piece of text match, the feature vectors produced by the two encoders should be very similar.
The training objective, then, is to compute the similarity matrix between the text and image features of a batch and push the similarities on the diagonal (matched pairs) up while pushing the off-diagonal similarities down. In this way CLIP learns directly from natural-language descriptions, exploiting the huge amount of ready-made data on the web; and because it learns the deeper meaning of the text, it can extract more abstract semantics from images.
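As a minimal sketch of this objective (not the paper's exact code; here the temperature is a fixed constant, whereas CLIP actually learns the logit scale), assuming image_features and text_features are the L2-normalized embeddings of a matched batch, the symmetric cross-entropy over the similarity matrix looks roughly like this:

import torch
import torch.nn.functional as F

def clip_style_loss(image_features: torch.Tensor,
                    text_features: torch.Tensor,
                    temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE loss over a batch of matched (image, text) pairs.

    Both inputs are assumed L2-normalized with shape (B, D); image i matches
    text i, so the targets are simply the diagonal indices 0..B-1.
    """
    logits = image_features @ text_features.t() / temperature  # (B, B) similarity matrix
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, targets)      # image -> text direction
    loss_t = F.cross_entropy(logits.t(), targets)  # text -> image direction
    return (loss_i + loss_t) / 2

This is the same diagonal-matching idea that the contrastive loss in the full script below implements on top of the logits returned by CLIPModel.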
2. Testing the Official Code
First, install PyTorch 1.7.1 (or later) and torchvision, plus a few small extra dependencies, and then install the repo as a Python package. On a machine with a CUDA GPU, the following is enough:
conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
When installing on a machine without a GPU, replace cudatoolkit=11.0 above with the CUDA version appropriate for your machine, or with cpuonly.
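For example, on a CPU-only machine the first command simply becomes:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cpuonly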
Now let's run a test. The logic: compute the similarity between the image's feature vector and the text vectors of all class names, and take the 5 highest.
import os
import clip
import torch
from torchvision.datasets import CIFAR100
from PIL import Image
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset (we only need its class names here)
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
# image, class_id = cifar100[3637]
image = Image.open('rabbit.jpg')  # point this at your own image, e.g. a rabbit
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
I fed it a picture of a rabbit.
The output looked like this:
Top predictions:

          rabbit: 99.61%
      lawn_mower: 0.16%
           mouse: 0.04%
        kangaroo: 0.03%
        squirrel: 0.03%
3. Building and Training the Model
Next I want to put together a simple CLIP model myself. CLIP is really just two encoders; the official implementation uses a ResNet (or ViT) image encoder and a Transformer text encoder. Here I build it directly with the transformers library's CLIPModel. Note that by default this downloads the pretrained weights from Hugging Face; if your network is a problem, you can download the model to a local directory first and load it from there (a snippet showing this follows the wrapper class below).
class CLIPForCIFAR(nn.Module):
    """Thin wrapper around Hugging Face CLIPModel to expose forward and projection features."""

    # load from the Hugging Face Hub by name, or pass a local directory path instead
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def forward(self, batch: CLIPBatch):
        outputs = self.model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            pixel_values=batch.pixel_values,
            return_dict=True,
        )
        return outputs  # contains logits_per_image (B, B), logits_per_text (B, B)

    @torch.no_grad()
    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return F.normalize(out, dim=-1)

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        out = self.model.get_image_features(pixel_values=pixel_values)
        return F.normalize(out, dim=-1)
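For the local-loading path mentioned above, one option (a sketch; the directory name ./clip-vit-base-patch32 is just an example) is to save the pretrained weights once on a machine with network access and then point from_pretrained at that directory:

from transformers import CLIPModel, CLIPProcessor

# Run once with network access, then copy the folder to the offline machine.
CLIPModel.from_pretrained("openai/clip-vit-base-patch32").save_pretrained("./clip-vit-base-patch32")
CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32").save_pretrained("./clip-vit-base-patch32")

# Later, load entirely from disk:
model = CLIPModel.from_pretrained("./clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("./clip-vit-base-patch32")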
Next, create a Dataset class to read the data. Using one of the ready-made datasets from torchvision as the example, the reading code looks like this:
class CIFAR100CLIPDataset(torch.utils.data.Dataset):
    """Pairs each image with a single prompt built from its class label."""

    def __init__(self, root: str, split: str, processor: CLIPProcessor, templates: List[str] = None):
        assert split in {"train", "test"}
        self.ds = tvdatasets.CIFAR100(root=root, train=(split == "train"), download=True)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes  # list of 100 class names

        # CLIP ViT-B/32 expects 224x224 inputs by default
        image_size = processor.image_processor.crop_size["height"]  # "width" is identical
        self.img_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        img, label = self.ds[idx]
        # preprocess the image the way CLIP expects
        img = self.img_transform(img)
        # randomly pick a prompt template for variety during training
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(
            label_text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        item = {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }
        return item
Define the dataloaders:
def build_dataloaders(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):  # combine a list of samples into a batch
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        # labels are simply 0..B-1 (diagonal matching)
        bsz = pixel_values.size(0)
        labels = torch.arange(bsz)
        return CLIPBatch(pixel_values, input_ids, attention_mask, labels)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes
If you want to fine-tune on your own dataset, you need a different dataset-loading class. The class below (a short usage sketch follows it) works with any image dataset laid out in the ImageFolder format:
dataset_root/
    train/
        class1/
            img1.jpg
            img2.jpg
            ...
        class2/
            ...
    val/
        class1/
            ...
        class2/
            ...
from torchvision.datasets import ImageFolder

class ImageFolderCLIPDataset(torch.utils.data.Dataset):
    def __init__(self, root, processor, templates=None):
        self.ds = ImageFolder(root)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes
        self.img_transform = T.Compose([
            T.Resize(processor.image_processor.crop_size["height"], interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(processor.image_processor.crop_size["height"]),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __getitem__(self, idx):
        img, label = self.ds[idx]
        img = self.img_transform(img)
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(label_text, padding="max_length", truncation=True, return_tensors="pt")
        return {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }

    def __len__(self):
        return len(self.ds)
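As a quick usage sketch (the directory names here just follow the hypothetical layout shown above), constructing the dataset is straightforward:

from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
train_set = ImageFolderCLIPDataset(root="./dataset_root/train", processor=processor)
val_set = ImageFolderCLIPDataset(root="./dataset_root/val", processor=processor)
print(len(train_set), train_set.classes[:5])  # sanity check: sample count and first few class names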
***Full code***
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Train (optional fine-tune) and evaluate a CLIP model for zero-shot classification on CIFAR-100.

Requirements:
- torch, torchvision
- transformers >= 4.41
- accelerate (optional)

Example usage:
    python clip_cifar100_zeroshot.py --epochs 5 --batch-size 256 --lr 5e-6 \
        --model openai/clip-vit-base-patch32 --data-root ./data --amp

Evaluate only (no training):
    python clip_cifar100_zeroshot.py --eval-only --model openai/clip-vit-base-patch32

Save / load:
    python clip_cifar100_zeroshot.py --epochs 2 --save-path ./clip_cifar100.pt
    python clip_cifar100_zeroshot.py --eval-only --load-path ./clip_cifar100.pt

This script includes:
- Dataset and dataloader for CIFAR-100
- Prompt engineering templates
- CLIP model/processor setup
- Contrastive training loop (image-text)
- Zero-shot evaluation using class-name prompts
- (Optional) linear-probing head for supervised classification (off by default)
"""

import argparse
import math
import os
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets as tvdatasets
from torchvision import transforms as T

from transformers import (
    CLIPModel,
    CLIPProcessor,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

# -------------------------
# Utilities
# -------------------------
# Random seed for reproducibility
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def exists(x):
    return x is not None


# -------------------------
# Prompt engineering: turns class names into text prompts
# -------------------------
CIFAR100_TEMPLATES = [
    "a photo of a {label}.",
    "a blurry photo of a {label}.",
    "a photo of the {label}.",
    "a close-up photo of a {label}.",
    "a bright photo of a {label}.",
    "a cropped photo of a {label}.",
    "a photo of a small {label}.",
    "a photo of a big {label}.",
    "a low contrast photo of a {label}.",
    "a high contrast photo of a {label}.",
]

# -------------------------
# Data: CIFAR-100
# -------------------------
@dataclass
class CLIPBatch:
    pixel_values: torch.Tensor      # (B, C, H, W)
    input_ids: torch.Tensor         # (B, L)
    attention_mask: torch.Tensor    # (B, L)
    labels: torch.Tensor            # (B,) image-text matching along the batch diagonal


class CIFAR100CLIPDataset(torch.utils.data.Dataset):
    """Pairs each image with a single prompt built from its class label."""

    def __init__(self, root: str, split: str, processor: CLIPProcessor, templates: List[str] = None):
        assert split in {"train", "test"}
        self.ds = tvdatasets.CIFAR100(root=root, train=(split == "train"), download=True)
        self.processor = processor
        self.templates = templates or ["a photo of a {label}."]
        self.classes = self.ds.classes  # list of 100 class names

        # CLIP ViT-B/32 expects 224x224 inputs by default
        image_size = processor.image_processor.crop_size["height"]  # "width" is identical
        self.img_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ToTensor(),
            T.Normalize(mean=processor.image_processor.image_mean,
                        std=processor.image_processor.image_std),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        img, label = self.ds[idx]
        # preprocess the image the way CLIP expects
        img = self.img_transform(img)
        # randomly pick a prompt template for variety during training
        label_text = random.choice(self.templates).format(label=self.classes[label])
        enc = self.processor.tokenizer(
            label_text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        item = {
            "pixel_values": img,
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "target": label,
        }
        return item
# from torchvision.datasets import ImageFolder
#
# class ImageFolderCLIPDataset(torch.utils.data.Dataset):
#     def __init__(self, root, processor, templates=None):
#         self.ds = ImageFolder(root)
#         self.processor = processor
#         self.templates = templates or ["a photo of a {label}."]
#         self.classes = self.ds.classes
#         self.img_transform = T.Compose([
#             T.Resize(processor.image_processor.crop_size["height"], interpolation=T.InterpolationMode.BICUBIC),
#             T.CenterCrop(processor.image_processor.crop_size["height"]),
#             T.ToTensor(),
#             T.Normalize(mean=processor.image_processor.image_mean,
#                         std=processor.image_processor.image_std),
#         ])
#
#     def __getitem__(self, idx):
#         img, label = self.ds[idx]
#         img = self.img_transform(img)
#         label_text = random.choice(self.templates).format(label=self.classes[label])
#         enc = self.processor.tokenizer(label_text, padding="max_length", truncation=True, return_tensors="pt")
#         return {
#             "pixel_values": img,
#             "input_ids": enc["input_ids"].squeeze(0),
#             "attention_mask": enc["attention_mask"].squeeze(0),
#             "target": label,
#         }
#
#     def __len__(self):
#         return len(self.ds)


# Dataset loading
def build_dataloaders(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):  # combine a list of samples into a batch
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        # labels are simply 0..B-1 (diagonal matching)
        bsz = pixel_values.size(0)
        labels = torch.arange(bsz)
        return CLIPBatch(pixel_values, input_ids, attention_mask, labels)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes


# -------------------------
# Model wrapper
# -------------------------
class CLIPForCIFAR(nn.Module):
    """Thin wrapper around Hugging Face CLIPModel to expose forward and projection features."""

    # load from the Hugging Face Hub by name, or pass a local directory path instead
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        super().__init__()
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    def forward(self, batch: CLIPBatch):
        outputs = self.model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            pixel_values=batch.pixel_values,
            return_dict=True,
        )
        return outputs  # contains logits_per_image (B, B), logits_per_text (B, B)

    @torch.no_grad()
    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return F.normalize(out, dim=-1)

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        out = self.model.get_image_features(pixel_values=pixel_values)
        return F.normalize(out, dim=-1)


# -------------------------
# Loss (InfoNCE over CLIP logits)
# -------------------------

def clip_contrastive_loss(logits_per_image: torch.Tensor, logits_per_text: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standard symmetric contrastive (InfoNCE) loss."""
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2


# -------------------------
# Training loop
# -------------------------

def train(
    model: CLIPForCIFAR,
    train_loader: DataLoader,
    device: torch.device,
    epochs: int = 5,
    lr: float = 5e-6,
    weight_decay: float = 0.2,
    amp: bool = False,
    freeze_vision: bool = False,
    freeze_text: bool = False,
    grad_accum_steps: int = 1,
    save_path: Optional[str] = None,
):
    model.train()

    # Optionally freeze encoders (useful for quick runs)
    if freeze_vision:
        for p in model.model.vision_model.parameters():
            p.requires_grad = False
    if freeze_text:
        for p in model.model.text_model.parameters():
            p.requires_grad = False

    # Only optimize trainable params
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        for step, batch in enumerate(train_loader):
            batch = CLIPBatch(
                pixel_values=batch.pixel_values.to(device, non_blocking=True),
                input_ids=batch.input_ids.to(device, non_blocking=True),
                attention_mask=batch.attention_mask.to(device, non_blocking=True),
                labels=batch.labels.to(device, non_blocking=True),
            )
            with torch.cuda.amp.autocast(enabled=amp):
                outputs = model(batch)
                loss = clip_contrastive_loss(outputs.logits_per_image, outputs.logits_per_text, batch.labels)
                loss = loss / grad_accum_steps

            scaler.scale(loss).backward()
            if (step + 1) % grad_accum_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

            running_loss += loss.item() * grad_accum_steps
            if (step + 1) % 50 == 0:
                avg = running_loss / 50
                print(f"Epoch {epoch+1} | Step {step+1}/{len(train_loader)} | loss {avg:.4f}")
                running_loss = 0.0

        if exists(save_path):
            ckpt = {
                "model_state": model.state_dict(),
                "epoch": epoch + 1,
            }
            torch.save(ckpt, save_path)
            print(f"[Saved] {save_path} at epoch {epoch+1}")


# -------------------------
# Zero-shot evaluation
# -------------------------
@torch.no_grad()
def build_text_classifier(model: CLIPForCIFAR, classnames: List[str], templates: List[str], device: torch.device):
    """Generate several prompts per class, encode them, and average to get one class embedding.

    Returns: text_features (C, D), where each row is the normalized class embedding.
    """
    tokenizer = model.processor.tokenizer
    all_class_embeds = []
    for cls in classnames:
        # Encode multiple prompts per class and average
        texts = [template.format(label=cls) for template in templates]
        enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        class_feats = model.encode_text(enc["input_ids"], enc["attention_mask"])  # (T, D)
        class_feats = class_feats.mean(dim=0)
        class_feats = F.normalize(class_feats, dim=-1)
        all_class_embeds.append(class_feats)
    text_features = torch.stack(all_class_embeds, dim=0)  # (C, D)
    return text_features


@torch.no_grad()
def zero_shot_eval(model: CLIPForCIFAR, loader: DataLoader, classnames: List[str], device: torch.device) -> float:
    # NOTE: placeholder version; superseded by the full implementation further below
    model.eval()
    text_features = build_text_classifier(model, classnames, CIFAR100_TEMPLATES, device)
    correct = 0
    total = 0
    for batch in loader:
        pixel_values = batch.pixel_values.to(device)
        targets = batch.labels.to(device)
        pass

# We need a more informative collate_fn that also returns the true class indices for evaluation
@dataclass
class CLIPBatchFull(CLIPBatch):
    targets: torch.Tensor  # (B,) true CIFAR-100 labels


def build_dataloaders_full(
    root: str,
    processor: CLIPProcessor,
    batch_size: int = 256,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    train_set = CIFAR100CLIPDataset(root=root, split='train', processor=processor, templates=CIFAR100_TEMPLATES)
    test_set = CIFAR100CLIPDataset(root=root, split='test', processor=processor, templates=CIFAR100_TEMPLATES)

    def collate_fn(batch):
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        diagonal = torch.arange(pixel_values.size(0))
        targets = torch.tensor([b["target"] for b in batch], dtype=torch.long)
        return CLIPBatchFull(pixel_values, input_ids, attention_mask, diagonal, targets)

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
    return train_loader, test_loader, train_set.classes


@torch.no_grad()
def zero_shot_eval(model: CLIPForCIFAR, loader: DataLoader, classnames: List[str], device: torch.device) -> float:
    model.eval()
    text_features = build_text_classifier(model, classnames, CIFAR100_TEMPLATES, device)  # (C, D)
    correct = 0
    total = 0
    for batch in loader:
        pixel_values = batch.pixel_values.to(device)
        targets = batch.targets.to(device)
        image_features = model.encode_image(pixel_values)  # (B, D)
        # similarity (B, C)
        logits = image_features @ text_features.t()
        preds = logits.argmax(dim=-1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    acc = correct / total
    return acc


from PIL import Image


@torch.no_grad()
def predict_single_image(model, image_path, classnames, device):
    model.eval()
    # Build the zero-shot classifier from the class names
    text_features = build_text_classifier(model, classnames, CIFAR100_TEMPLATES, device)

    # Load the image
    image = Image.open(image_path).convert("RGB")
    inputs = model.processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    # Extract image features
    image_features = model.encode_image(pixel_values)

    # Make sure the features are normalized (encode_image already normalizes;
    # doing it again explicitly is just for clarity)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Compute similarities and convert to percentages
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(5)

    # Collect the results
    results = []
    for value, index in zip(values, indices):
        class_name = classnames[index]
        percent_prob = value.item() * 100  # probability as a percentage
        results.append((class_name, percent_prob))
    return results


# -------------------------
# Main
# -------------------------

def main():
    parser = argparse.ArgumentParser(description="CLIP zero-shot on CIFAR-100")
    parser.add_argument("--model", type=str, default="openai/clip-vit-base-patch32", help="CLIP model name")
    parser.add_argument("--data-root", type=str, default="./data", help="Directory for CIFAR-100")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=0, help="Training epochs (0 = skip training)")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--weight-decay", type=float, default=0.2)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--amp", action="store_true", help="Use mixed precision")
    parser.add_argument("--freeze-vision", action="store_true")
    parser.add_argument("--freeze-text", action="store_true")
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--save-path", type=str, default=None)
    parser.add_argument("--load-path", type=str, default=None)
    parser.add_argument("--eval-only", action="store_true")
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build model and processor
    model = CLIPForCIFAR(model_name=args.model)
    model.to(device)

    train_root = os.path.join(args.data_root, "train")  # only needed for the ImageFolder variant

    # Build data
    train_loader, test_loader, classnames = build_dataloaders_full(
        root=args.data_root,
        processor=model.processor,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Optional loading
    if exists(args.load_path) and os.path.isfile(args.load_path):
        ckpt = torch.load(args.load_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"], strict=False)
        print(f"[Loaded] {args.load_path}")

    # Train (optional)
    if not args.eval_only and args.epochs > 0:
        train(
            model=model,
            train_loader=train_loader,
            device=device,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            amp=args.amp,
            freeze_vision=args.freeze_vision,
            freeze_text=args.freeze_text,
            grad_accum_steps=args.grad_accum_steps,
            save_path=args.save_path,
        )

    # Zero-shot evaluation
    # acc = zero_shot_eval(model, test_loader, classnames, device)
    # print(f"Zero-shot Top-1 Accuracy on CIFAR-100: {acc * 100:.2f}%")

    # ======================================================================
    image_path = r"90.jpg"  # path of the image to predict
    results = predict_single_image(model, image_path, classnames, device)

    print("\nTop-5 predictions:\n")
    for cls, prob_percent in results:
        print(f"{cls:>16s}: {prob_percent:.2f}%")
    # ======================================================================


if __name__ == "__main__":
    main()
If you just want to use the official clip-vit-base-patch32 model, set epochs to 0 so training is skipped and point image_path at your own picture (an example command is shown after the output below). I used the same rabbit picture as above, and the result was:
Top-5 predictions:

          rabbit: 99.85%
      lawn_mower: 0.04%
           mouse: 0.03%
        kangaroo: 0.02%
        squirrel: 0.01%
The prediction is correct!
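For reference, a run along these lines (assuming the script is saved as clip_cifar100_zeroshot.py, as in its docstring, and image_path in main() has been edited to point at your picture) reproduces the result above without any training:

python clip_cifar100_zeroshot.py --eval-only --model openai/clip-vit-base-patch32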
4. Fine-tuning on Your Own Dataset
Switch the dataset class to the ImageFolderCLIPDataset form and make the following change:
train_root = os.path.join(args.data_root, "train")

# Build data
train_loader, test_loader, classnames = build_dataloaders_full(
    root=args.data_root,  # <-- change this to root=train_root
    processor=model.processor,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)
Then adjust the number of epochs and the checkpoint save path, and you are done. A sketch of what the modified data-building function could look like is shown below.
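This sketch is not part of the original script: the function name build_dataloaders_imagefolder and the train/val subfolder names are assumptions, and it reuses the ImageFolderCLIPDataset, CLIPBatchFull, and DataLoader definitions from the full code above, keeping the same diagonal contrastive targets plus true class indices.

def build_dataloaders_imagefolder(root, processor, batch_size=64, num_workers=0):
    # root is assumed to contain train/ and val/ subfolders in ImageFolder layout
    train_set = ImageFolderCLIPDataset(os.path.join(root, "train"), processor)
    val_set = ImageFolderCLIPDataset(os.path.join(root, "val"), processor)

    def collate_fn(batch):
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        input_ids = torch.stack([b["input_ids"] for b in batch])
        attention_mask = torch.stack([b["attention_mask"] for b in batch])
        diagonal = torch.arange(pixel_values.size(0))                           # contrastive targets (0..B-1)
        targets = torch.tensor([b["target"] for b in batch], dtype=torch.long)  # true class indices
        return CLIPBatchFull(pixel_values, input_ids, attention_mask, diagonal, targets)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, collate_fn=collate_fn)
    return train_loader, val_loader, train_set.classes

In main(), swapping the build_dataloaders_full call for this function, and then setting --epochs and --save-path as needed, is the only other change required.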
That's everything! The most impressive thing about CLIP is its zero-shot ability: it turns a fixed set of classes into an open-ended one defined entirely by natural-language understanding.