CLIP在Github上的使用教程

CLIP的github鏈接:https://github.com/openai/CLIP

CLIP

Blog,Paper,Model Card,Colab
CLIP(對比語言-圖像預訓練)是一個在各種(圖像、文本)對上進行訓練的神經網絡。可以用自然語言指示它在給定圖像的情況下預測最相關的文本片段,而無需直接對任務進行優化,這與 GPT-2 和 3 的零鏡頭功能類似。我們發現,CLIP 無需使用任何 128 萬個原始標注示例,就能在 ImageNet "零拍攝 "上達到原始 ResNet50 的性能,克服了計算機視覺領域的幾大挑戰。

Usage用法

首先,安裝 PyTorch 1.7.1(或更高版本)和 torchvision,以及少量其他依賴項,然后將此 repo 作為 Python 軟件包安裝。在 CUDA GPU 機器上,完成以下步驟即可:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

將上面的 cudatoolkit=11.0 替換為機器上相應的 CUDA 版本,如果在沒有 GPU 的機器上安裝,則替換為 cpuonly

import torch
import clip
from PIL import Imagedevice = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)with torch.no_grad():image_features = model.encode_image(image)text_features = model.encode_text(text)logits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

CLIP 模塊提供以下方法:

clip.available_models()

返回可用 CLIP 模型的名稱。例如下面就是我執行的結果。
在這里插入圖片描述

clip.load(name, device=..., jit=False)

返回模型和模型所需的 TorchVision 變換(由 clip.available_models() 返回的模型名稱指定)。它將根據需要下載模型。name參數也可以是本地檢查點的路徑。
可以選擇指定運行模型的設備,默認情況下,如果有第一個 CUDA 設備,則使用該設備,否則使用 CPU。當 jitFalse 時,將加載模型的非 JIT 版本。

clip.tokenize(text: Union[str, List[str]], context_length=77)

返回包含給定文本輸入的標記化序列的 LongTensor。這可用作模型的輸入。

clip.load() 返回的模型支持以下方法:

model.encode_image(image: Tensor)

給定一批圖像,返回 CLIP 模型視覺部分編碼的圖像特征。

model.encode_text(text: Tensor)

給定一批文本標記,返回 CLIP 模型語言部分編碼的文本特征。

model(image: Tensor, text: Tensor)

給定一批圖像和一批文本標記,返回兩個張量,其中包含與每張圖像和每個文本輸入相對應的 logit 分數。這些值是相應圖像和文本特征之間的余弦相似度乘以 100。

More Examples更多實例

Zero-Shot預測

下面的代碼使用 CLIP 執行零點預測,如論文附錄 B 所示。該示例從 CIFAR-100 數據集中獲取一張圖片,并預測數據集中 100 個文本標簽中最有可能出現的標簽。

import os
import clip
import torch
from torchvision.datasets import CIFAR100# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)# Calculate features
with torch.no_grad():image_features = model.encode_image(image_input)text_features = model.encode_text(text_inputs)# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

輸出結果如下(具體數字可能因計算設備而略有不同):

Top predictions:snake: 65.31%turtle: 12.29%sweet_pepper: 3.83%lizard: 1.88%crocodile: 1.75%

請注意,本示例使用的 encode_image()encode_text() 方法可返回給定輸入的編碼特征。

Linear-probe evaluation線性探針評估

下面的示例使用 scikit-learn 對圖像特征進行邏輯回歸。

import os
import clip
import torchimport numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)def get_features(dataset):all_features = []all_labels = []with torch.no_grad():for images, labels in tqdm(DataLoader(dataset, batch_size=100)):features = model.encode_image(images.to(device))all_features.append(features)all_labels.append(labels)return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

請注意,C 值應通過使用驗證分割進行超參數掃描來確定。

See Also

OpenCLIP:包括更大的、獨立訓練的 CLIP 模型,最高可達 ViT-G/14
Hugging Face implementation of CLIP:更易于與高頻生態系統集成

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/214235.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/214235.shtml
英文地址,請注明出處:http://en.pswp.cn/news/214235.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

鴻蒙HarmonyOS(ArkTS)語法 聲明變量及注意事項

好 今天我們來看一個基礎的harmonyOS語法 變量聲明 這里 我們還是用 ArkTS項目 我們聲明變量的語法并不是ArkTS的 而是 javaScript 和 TypeScript的 可以看一下下面一張圖 js是最初弱類型語言 于是TS作為js的副類 是一種更嚴謹的數據限定語法 而ArkTS 是TS的改良版 其實我們…

算法通關村第十八關 | 白銀 | 回溯熱門問題

1.組合總和問題 原題&#xff1a;力扣39. 元素可以重復拿取&#xff0c;且題目的測試用例保證了組合數少于 150 個。 class CombinationSum {List<List<Integer>> res new ArrayList<>();List<Integer> path new ArrayList<>();public List…

一篇文章教你快速弄懂 web自動化測試中的三種等待方式

前言 現在的網頁很多都是動態加載的&#xff0c;如果頁面的內容發生了改變&#xff0c;就需要時間來渲染。在咱們做web自動化測試的時候&#xff0c;由于代碼是自動執行的&#xff0c;代碼在執行的時候&#xff0c;有可能上一步操作而加載的元素還沒加載出來&#xff0c;就會報…

配置本地端口鏡像示例(1:1)

本地端口鏡像簡介 本地端口鏡像是指觀察端口與監控設備直接相連&#xff0c;觀察端口直接將鏡像端口復制來的報文轉發到與其相連的監控設備進行故障定位和業務監測。 配置注意事項 觀察端口專門用于鏡像報文的轉發&#xff0c;因此不要在上面配置其他業務&#xff0c;防止鏡像…

建筑學VR虛擬仿真情景實訓教學

首先&#xff0c;建筑學VR虛擬仿真情景實訓教學為建筑學專業的學生提供了一個身臨其境的學習環境。通過使用VR仿真技術&#xff0c;學生可以在虛擬環境中觀察和理解建筑結構、材料、設計以及施工等方面的知識。這種教學方法不僅能幫助學生更直觀地理解復雜的建筑理論&#xff0…

記錄 | ubuntu源碼編譯安裝/更新boost版本

一、卸載當前的版本 1、查看當前安裝的boost版本 dpkg -S /usr/include/boost/version.hpp通過上面的命令&#xff0c;你就可以發現boost的版本了&#xff0c;查看結果可能如下&#xff1a; libboost1.54-dev: /usr/include/boost/version.hpp 2、刪除當前安裝的boost sudo …

記錄 | 使用samba將ubuntu文件夾映射到windows實現共享文件夾

一、ubuntu配置 1. 安裝 samba samba 是在 Linux 和 UNIX 系統上實現 SMB 協議的一個免費軟件&#xff0c;由服務器及客戶端程序構成。SMB&#xff08;Server Messages Block&#xff0c;信息服務塊&#xff09;是一種在局域網上共享文件和打印機的一種通信協議。 sudo apt-…

Excel COUNT類函數使用

目錄 一. COUNT二. COUNTA三. COUNTBLANK四. COUNTIF五. COUNTIFS 一. COUNT ?用于計算指定范圍內包含數字的單元格數量。 基本語法 COUNT(value1, [value2], ...)?統計A2到A7所有數字單元格的數量 ?統計A2到A7&#xff0c;B2到B7的所有數字單元格的數量 二. COUNTA ?計…

大數據分析與應用實驗任務十一

大數據分析與應用實驗任務十一 實驗目的 通過實驗掌握spark Streaming相關對象的創建方法&#xff1b; 熟悉spark Streaming對文件流、套接字流和RDD隊列流的數據接收處理方法&#xff1b; 熟悉spark Streaming的轉換操作&#xff0c;包括無狀態和有狀態轉換。 熟悉spark S…

Linux 驅動開發需要掌握哪些編程語言和技術?

Linux 驅動開發需要掌握哪些編程語言和技術&#xff1f; 在開始前我有一些資料&#xff0c;是我根據自己從業十年經驗&#xff0c;熬夜搞了幾個通宵&#xff0c;精心整理了一份「Linux從專業入門到高級教程工具包」&#xff0c;點個關注&#xff0c;全部無償共享給大家&#xf…

1. mycat入門

1、mycat介紹 Mycat 是一個開源的分布式數據庫系統&#xff0c;但是由于真正的數據庫需要存儲引擎&#xff0c;而 Mycat 并沒有存 儲引擎&#xff0c;所以并不是完全意義的分布式數據庫系統。MyCat是目前最流行的基于Java語言編寫的數據庫中間件&#xff0c;也可以理解為是數據…

鴻蒙HarmonyOS4.0 入門與實戰

一、開發準備: 熟悉鴻蒙官網安裝DevEco Studio熟悉鴻蒙官網 HarmonyOS應用開發官網 - 華為HarmonyOS打造全場景新服務 應用設計相關資源: 開發相關資源: 例如開發工具 DevEco Studio 的下載 應用發布: 開發文檔:

3易懂AI深度學習算法:長短期記憶網絡(Long Short-Term Memory, LSTM)生成對抗網絡 優化算法進化算法

繼續寫&#xff1a;https://blog.csdn.net/chenhao0568/article/details/134920391?spm1001.2014.3001.5502 1.https://blog.csdn.net/chenhao0568/article/details/134931993?spm1001.2014.3001.5502 2.https://blog.csdn.net/chenhao0568/article/details/134932800?spm10…

LeetCode 1631. 最小體力消耗路徑:廣度優先搜索BFS

【LetMeFly】1631.最小體力消耗路徑&#xff1a;廣度優先搜索BFS 力扣題目鏈接&#xff1a;https://leetcode.cn/problems/path-with-minimum-effort/ 你準備參加一場遠足活動。給你一個二維 rows x columns 的地圖 heights &#xff0c;其中 heights[row][col] 表示格子 (ro…

視頻如何提取文字?這四個方法一鍵提取視頻文案

視頻如何提取文字&#xff1f;你用過哪些視頻提取工具&#xff1f;視頻轉文字工具&#xff0c;又稱為語音識別軟件&#xff0c;是一款能夠將視頻中的語音或對話轉化為文字的實用工具。它運用了尖端的聲音識別和語言理解技術&#xff0c;能精準地捕捉視頻中的音頻&#xff0c;并…

弧形導軌的工作原理

弧形導軌是一種能夠將物體沿著弧形軌道運動的裝置&#xff0c;它由個弧形軌道和沿著軌道運動的物體組成&#xff0c;弧形導軌的工作原理是利用軌道的形狀和物體的運動方式來實現運動&#xff0c;當物體處于軌道上時&#xff0c;它會受到軌道的引導&#xff0c;從而沿著軌道的弧…

Nginx正則表達式

目錄 1.nginx常用的正則表達式 2.location location 大致可以分為三類 location 常用的匹配規則 location 優先級 location 示例說明 優先級總結 3.rewrite rewrite功能 rewrite跳轉實現 rewrite執行順序 語法格式 rewrite示例 實例1&#xff1a; 實例2&#xf…

生活小記錄

上個月項目總算上線了&#xff0c;節奏也慢慢調整正常。發現自己好久沒有記錄生活點滴了&#xff0c;正好寫寫。其實&#xff0c;最近這段日子發生的事情還是挺多的。 流感 媳婦11.24得流感&#xff0c;這件事情特別好笑&#xff0c;大晚上她和我妹妹想喝酒試試&#xff0c;結…

【Python必做100題】之第六題(求圓的周長)

圓的周長公式&#xff1a;C 2 * pi * r 代碼如下&#xff1a; pi 3.14 r float(input("請輸入圓的半徑&#xff1a;")) c 2 * pi *r print(f"圓的周長為{c}") 運行截圖&#xff1a; 總結 1、圓周長的公式&#xff1a;C 2 * pi * r 2、輸出結果注意…

webrtc 工具類

直接上代碼&#xff1b;webrtc 工具類 package com.example.mqttdome;import android.app.Activity; import android.content.Context; import android.content.Intent; import android.media.projection.MediaProjection; import android.media.projection.MediaProjectionMa…