論文學習16:Learning Transferable Visual Models From Natural Language Supervision

代碼來源

Learning Transferable Visual Models From Natural Language Supervisionhttps://arxiv.org/pdf/2103.00020

模塊作用

當前最先進的計算機視覺系統被訓練用于預測一組固定的、預先定義的目標類別。這種受限的監督方式限制了它們的通用性和可用性,因為要識別其他視覺概念時,仍然需要額外的標注數據。直接從關于圖像的原始文本中學習是一種更具潛力的替代方案,它可以利用更廣泛的監督信息。

模塊結構

  • 圖像編碼器
    • 支持多種架構,包括ResNet(如ResNet-50、ResNet-101、RN50x4、RN50x16、RN50x64)和Vision Transformer(ViT,如ViT-B/32、ViT-B/16、ViT-L/14)。
    • ResNet版本包括ResNet-D改進、抗混疊rect-2模糊池化和注意力池化(多頭QKV注意力)。
    • ViT版本在Transformer前增加層歸一化,訓練模型包括ViT-B/32、ViT-B/16和ViT-L/14,其中ViT-L/14@336px(在336像素分辨率下預訓練額外一輪)表現最佳。
  • 文本編碼器
    • 使用63M參數的12層Transformer,寬度512,8個注意力頭,處理小寫字節對編碼(BPE),詞匯表大小49,152,最大序列長度76,用[SOS]和[EOS]標記括住,使用屏蔽自注意力。

代碼

class CLIP(nn.Module):def __init__(self,embed_dim: int,# visionimage_resolution: int,vision_layers: Union[Tuple[int, int, int, int], int],vision_width: int,vision_patch_size: int,# textcontext_length: int,vocab_size: int,transformer_width: int,transformer_heads: int,transformer_layers: int):super().__init__()self.context_length = context_lengthif isinstance(vision_layers, (tuple, list)):vision_heads = vision_width * 32 // 64self.visual = ModifiedResNet(layers=vision_layers,output_dim=embed_dim,heads=vision_heads,input_resolution=image_resolution,width=vision_width)else:vision_heads = vision_width // 64self.visual = VisionTransformer(input_resolution=image_resolution,patch_size=vision_patch_size,width=vision_width,layers=vision_layers,heads=vision_heads,output_dim=embed_dim)self.transformer = Transformer(width=transformer_width,layers=transformer_layers,heads=transformer_heads,attn_mask=self.build_attention_mask())self.vocab_size = vocab_sizeself.token_embedding = nn.Embedding(vocab_size, transformer_width)self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))self.ln_final = LayerNorm(transformer_width)self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))self.initialize_parameters()def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)def build_attention_mask(self):# lazily create causal attention mask, with full attention between the vision tokens# pytorch uses additive attention mask; fill with -infmask = torch.empty(self.context_length, self.context_length)mask.fill_(float("-inf"))mask.triu_(1)  # zero out the lower diagonalreturn mask@propertydef dtype(self):return self.visual.conv1.weight.dtypedef encode_image(self, image):return self.visual(image.type(self.dtype))def encode_text(self, text):x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]x = x + self.positional_embedding.type(self.dtype)x = x.permute(1, 0, 2)  # NLD -> LNDx = self.transformer(x)x = x.permute(1, 0, 2)  # LND -> NLDx = self.ln_final(x).type(self.dtype)# x.shape = [batch_size, n_ctx, transformer.width]# take features from the eot embedding (eot_token is the highest number in each sequence)x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projectionreturn xdef forward(self, image, text):image_features = self.encode_image(image)text_features = self.encode_text(text)# normalized featuresimage_features = image_features / image_features.norm(dim=1, keepdim=True)text_features = text_features / text_features.norm(dim=1, keepdim=True)# cosine similarity as logitslogit_scale = self.logit_scale.exp()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()# shape = [global_batch_size, global_batch_size]return logits_per_image, logits_per_text

總結

本文研究了在自然語言處理(NLP)領域取得成功的、與具體任務無關的大規模網絡預訓練方法,是否可以遷移到另一個領域。研究表明,采用這一方法后,在計算機視覺領域會出現類似的行為,我們也探討了這一研究方向的社會影響。為了優化訓練目標,CLIP 模型在預訓練過程中學習執行多種不同的任務。這種任務學習可以通過自然語言提示(prompting)加以利用,從而實現對許多現有數據集的零樣本(zero-shot)遷移。在足夠大的規模下,這種方法的性能可以與特定任務的監督學習模型相競爭,盡管仍有很大的改進空間。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/74722.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/74722.shtml
英文地址,請注明出處:http://en.pswp.cn/web/74722.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[MySQL初階]MySQL(9)事務機制

標題:[MySQL初階]MySQL(9)事物機制 水墨不寫bug 文章目錄 一、認識事務1、多線程訪問數據庫出現的問題2、對CURD的限制是通過事務機制實現的3、事務的四個屬性4、哪些引擎支持事務 二、事務的提交與autocommit設置三、事務的隔離性和隔離級別…

spring-cloud-alibaba-nacos-config使用說明

一、核心功能與定位 Spring Cloud Alibaba Nacos Config 是 Spring Cloud Alibaba 生態中的核心組件之一,專為微服務架構提供動態配置管理能力。它通過整合 Nacos 的配置中心功能,替代傳統的 Spring Cloud Config,提供更高效的配置集中化管理…

SonarQube數據庫配置

SonarQube部署完成后,在瀏覽器地址欄輸入http://IP:9000可以進入登錄頁面,以本機運行為例,地址為http://127.0.0.1:9000/,默認登錄名:admin,登錄密碼也是admin。登錄后會要求設置密碼: 按要求設…

醫藥檔案區塊鏈系統

1. 醫生用戶模塊?? ??目標用戶??:醫護人員 ??核心功能??: ??檢索檔案??:通過關鍵詞或篩選條件快速定位患者健康檔案。??請求授權??:向個人用戶發起檔案訪問權限申請,需經對方確認。??查看檔案?…

CSS3學習教程,從入門到精通, 化妝品網站 HTML5 + CSS3 完整項目(26)

化妝品網站 HTML5 CSS3 完整項目 下面是一個完整的化妝品網站項目,包含主頁、登錄頁面和注冊頁面。我將按照您的要求提供詳細的代碼和注釋。 1. 網站規劃與需求分析 需求分析 展示化妝品產品信息提供用戶注冊和登錄功能響應式設計,適配不同設備美觀…

ROS2 多機時間同步(Chrony配置簡明指南)

適用場景: 主機運行 ROS2 Humble(發布 /scan 等),板子運行 ROS2 Foxy(發布 /tf 等),兩邊通過 ROS_DOMAIN_ID 跨平臺通訊。需要保證系統時間對齊,避免 TF 插值失敗、建圖抖動等問題。…

Nginx配置偽靜態,URL重寫

Nginx配置偽靜態,URL重寫 [ Nginx ] 在Nginx低版本中,是不支持PATHINFO的,但是可以通過在Nginx.conf中配置轉發規則實現: location / { // …..省略部分代碼if (!-e $request_filename) {rewrite ^(.*)$ /index.php?s/$1 l…

電路筆記(元器件):ADC LTC系列模數轉換器的輸出范圍+滿量程和偏移調整

LTC1740(LTC1740官方文檔)是Analog Devices(原Linear Technology)公司生產的一款高性能、低功耗的14位模數轉換器(ADC)。它通常用于需要高精度和快速采樣率的應用中,如通信系統、數據采集設備等。同類產品 LTC1746:一款14位、40Ms…

續-算法-數學知識

3、歐拉函數 1、定義: 1~n 中與 n 互質的數的個數 例如:6 的有 1 2 3 4 5 6 其中,與 n 互質 的 數的個數為 2個分別是:1、5 2、計算: $ N p_1^{a1} p_2^{a2} p_3^{a3} … p_k^{ak} $(例如&#x…

C/C++測試框架googletest使用示例

文章目錄 文檔編譯安裝示例參考文章 文檔 https://github.com/google/googletest https://google.github.io/googletest/ 編譯安裝 googletest是cmake項目,可以用cmake指令編譯 cmake -B build && cmake --build build將編譯產物lib和include 兩個文件夾…

LintCode第974題-求矩陣各節點的最短路徑(以0為標準)

描述 給定一個由0和1組成的矩陣,求每個單元格最近的0的距離。 兩個相鄰細胞之間的距離是1。 給定矩陣的元素數不超過10,000。 在給定的矩陣中至少有一個0。 單元格在四個方向上相鄰:上,下,左和右。 樣例 例1: 輸入: [[0,0,0],[0,0,0],[0…

Redis核心機制-緩存、分布式鎖

目錄 緩存 緩存更新策略 定期生成 實時生成 緩存問題 緩存預熱(Cache preheating) 緩存穿透(Cache penetration) 緩存雪崩(Cache avalanche) 緩存擊穿(Cache breakdown) 分…

CF每日5題(1300-1500)

最近急速補練藍橋杯中,疏于cf練習。 感覺自己過題還是太慢了。 今日水題,我水水水水。 1- 1979C lcm 水 1400 第 i i i局贏了,1個硬幣頂 k [ i ] k[i] k[i]個貢獻,所以每局分硬幣 x i 1 k [ i ] x_i{1\over k[i]} xi?k[i]1?個…

從代碼學習深度學習 - LSTM PyTorch版

文章目錄 前言一、數據加載與預處理1.1 代碼實現1.2 功能解析二、LSTM介紹2.1 LSTM原理2.2 模型定義代碼解析三、訓練與預測3.1 訓練邏輯代碼解析3.2 可視化工具功能解析功能結果總結前言 深度學習中的循環神經網絡(RNN)及其變種長短期記憶網絡(LSTM)在處理序列數據(如文…

easy-poi 一對多導出

1. 需求: 某一列上下兩行單元格A,B值一樣且這兩個單元格, 前面所有列對應單元格值一樣的話, 就對A,B 兩個單元格進行縱向合并單元格 1. 核心思路: 先對數據集的國家,省份,城市...... id 身份證進行排序…

AI比人腦更強,因為被植入思維模型【42】思維投影思維模型

giszz的理解:本質和外在。我們的行為舉止,都是我們的內心的表現。從外邊可以看內心,從內心可以判斷外在。曾國藩有7個識人的方法,大部分的人在他的面前如同沒穿衣服一樣。對于我們自身的啟迪,我認為有四點&…

Spring Boot 打印日志

1.通過slf4j包中的logger對象打印日志 Spring Boot內置了日志框架slf4j,在程序中調用slf4j來輸出日志 通過創建logger對象打印日志,Logger 對象是屬于 org.slf4j 包下的不要導錯包。 2.日志級別 日志級別從高到低依次為: FATAL:致命信息,表…

【IOS webview】源代碼映射錯誤,頁面卡住不動

報錯場景 safari頁面報源代碼映射錯誤,頁面卡住不動。 機型:IOS13 技術棧:react 其他IOS也會報錯,但不影響頁面顯示。 debug webpack配置不要GENERATE_SOURCEMAP。 解決方法: GENERATE_SOURCEMAPfalse react-app…

ES中經緯度查詢geo_point

0. ES版本 6.x版本 1. 創建索引 PUT /location {"settings": {"number_of_shards": 1,"number_of_replicas": 0},"mappings": {"location": {"properties": {"id": {"type": "keywor…

OpenCV界面編程

《OpenCV計算機視覺開發實踐:基于Python(人工智能技術叢書)》(朱文偉,李建英)【摘要 書評 試讀】- 京東圖書 OpenCV的Python開發環境搭建(Windows)-CSDN博客 OpenCV也支持有限的界面編程,主要是針對窗口、控件和鼠標…