RAM模型從數據準備到pretrain、finetune與推理全過程詳細說明

提示:RAM++模型:環境安裝、數據準備與說明、模型推理、模型finetune、模型pretrain等

文章目錄

  • 前言
  • 一、環境安裝
  • 二、數據準備與解讀
    • 1.數據下載
    • 2.數據標簽內容解讀
    • 3.標簽map內容解讀
  • 三、finetune訓練
    • 1.微調訓練命令
    • 2.load載入參數問題
    • 3.權重載入
    • 4.數據加載問題
    • 5.設備不匹配報錯
    • 6.運行結果
  • 四、pretrain預訓練
    • 1.預訓練命令
    • 2.swin_large_patch4_window12_384_22k.pth權重
      • a.下載
      • b.權重加載修改
    • 3.ram_plus_tag_embedding_class_4585_des_51.pth權重
      • a.下載
      • b.權重加載修改
    • 4.變量設備匹配問題
    • 5. 預訓練成功顯示
  • 五、數據加載源碼簡單解讀
  • 六、推理


前言

隨著SAM模型分割一切大火之后,又有RAM模型識別一切,RAM模型由來可有三篇模型構成,TAG2TEXT為首篇將tag引入VL模型中,由tagging、generation、alignment分支構成,隨后才是RAM模型,主要借助CLIP模型輔助與annotation處理trick,由tagging、generation分支構成,最后才是RAM++模型,該模型引入semantic concepts到圖像tagging訓練框架,RAM++模型能夠利用圖像-標簽-文本三者之間的關系,整合image-text alignment 和 image-tagging 到一個統一的交互框架里。作者也將三個模型整合成一套代碼,本文將介紹RAM++模型,主要內容包含環境安裝、數據準備與說明、模型推理、模型finetune、模型pretrain等內容,并逐過程解讀,也幫讀者踩完所有坑,只要按照我我步驟將會實現RAM流暢運行。


TAG2TEXT論文鏈接:點擊這里
RAM論文鏈接:點擊這里
RAM++論文鏈接:點擊這里
github官網鏈接:點擊這里

一、環境安裝

說實話,環境安裝按照官網來,沒有報什么錯,可直接推理運行,但是訓練可能會缺一些東西,后續將介紹,環境安裝如下:

Install recognize-anything as a package:

pip install git+https://github.com/xinyu1205/recognize-anything.git

Or, for development, you may build from source

git clone https://github.com/xinyu1205/recognize-anything.git
cd recognize-anything
pip install -e .

二、數據準備與解讀

1.數據下載

圖像數據需要根據相應內容去下載,而數據標簽下載可以去github代碼官網鏈接,點擊下面紅框即可。當然你也可轉到下面網頁鏈接。
數據標簽下載:https://huggingface.co/datasets/xinyu1205/recognize-anything-dataset-14m/tree/main
在這里插入圖片描述
當進入標簽頁面如下:
在這里插入圖片描述

2.數據標簽內容解讀

當你下載了標簽后,你能發現標簽實際是列表,列表中每個數據又是一個字典,包含image_path、caption、union_label_id、parse_label_id字典,以vg_ram.json標簽舉列,我們取第一個元素,如下圖所示:
在這里插入圖片描述
我們進一步展開該數據,你會重點發現parse_label_id是一個二維列表,每一行是對對應caption描述取的tag而union_label_id是一維列表,parse_label_id中tag都能在union_label_id找到,反之不行。如下圖:
在這里插入圖片描述

3.標簽map內容解讀

我們在上面可看到parse_label_id與union_label_id是數字,那么這些數字如何得到,必然有一個映射表,該表是ram_tag_list_4585_llm_tag_descriptions.json文件中,該文件也是一個列表,列表中每個元素是一個字典,該字典key就是tag,value是一個列表,是對key的描述,我查看value的列表有50個描述。其中該文件列表位置(索引)就代表key(tag),這也是parse_label_id與union_label_id的數字。如下:

在這里插入圖片描述

當然,RAM模型數據可以一個元素的一張圖有多個描述如下左圖,也可以多個元素表示同一張圖,進行多個描述,如下:
在這里插入圖片描述

三、finetune訓練

1.微調訓練命令

可看出訓練使用finetune.py文件,參數配置是finetune.yaml文件,模型類型選擇是ram_plus文件,如下:

python -m torch.distributed.run --nproc_per_node=8 finetune.py \   --model-type ram_plus \   --config ram/configs/finetune.yaml  \   --checkpoint outputs/ram_plus/checkpoint_04.pth \   --output-dir outputs/ram_plus_ft

我是直接運行finetune.py文件,使用遠程鏈接方式運行的!

2.load載入參數問題

當執行微調命令時,我遇到yaml載入問題,如下圖:
在這里插入圖片描述
當然這個是個小問題,與環境相關,可能你們不會遇到,若遇到可嘗試我的解決方法:
導入包ruamel.yaml,更改原有代碼

config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

為以下代碼即可:

import ruamel.yaml  yaml = ruamel.yaml.YAML(typ='rt') config = yaml.load(open(args.config, 'r'))

注:該問題pretrain可能也會遇到。

3.權重載入

第二個問題,模型權重 模型權重載入需要修改,根據你的需求可修改權重路徑,如下圖:  我使用ram++,將model_clip, _ = clip.load("/home/notebook/data/group/huangxinyu/clip/ViT-B-16.pt")中的地址替換即可。

權重下載地址如下:

_MODELS = {     
"RN50":"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",     
"RN101":"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",     
"RN50x4":"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",   
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",   
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",   
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",    
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",     
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",   
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 
}  

4.數據加載問題

需在finetune.yaml文件中設定image_path_root: “” 參數,使得該參數與下圖image_path合并為圖像絕對路徑,我設定如下:

 image_path_root: "/home/Project/recognize-anything/datasets/train"  

圖像路徑如下圖所示:
在這里插入圖片描述

5.設備不匹配報錯

運行預訓練命令依然會報錯,如下:
在這里插入圖片描述
該問題也是小問題,就是變量設備不匹配問題,在finetune.py文件,為image_tag變量指定設備,添加一句代碼:

image_tag = image_tag.to(device,non_blocking=True) 

修改后整體代碼如下:

for i, (image, image_224, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, 	  print_freq, header)):          optimizer.zero_grad()          batch_text_embed = build_text_embed(model_clip,caption)                  image = image.to(device,non_blocking=True)         image_224 = image_224.to(device,non_blocking=True)                  image_tag = image_tag.to(device,non_blocking=True)                  clip_image_feature = model_clip.encode_image(image_224)          loss_tag, loss_dis, loss_alignment = model(image, caption, image_tag, clip_image_feature, batch_text_embed)           loss = loss_tag + loss_dis + loss_alignment  

6.運行結果

之后運行結果如下: 在這里插入圖片描述

我們進一步可發現使用一張3090顯卡,batch為20即可滿負載,如下:
在這里插入圖片描述
注:以上微調內容某些在訓練時候遇到,按其修改即可!

四、pretrain預訓練

1.預訓練命令

可看出預訓練使用pretrain.py文件,參數配置是pretrain.yaml文件,模型類型選擇是ram_plus文件,如下:

python -m torch.distributed.run --nproc_per_node=8 pretrain.py \--model-type ram_plus \--config ram/configs/pretrain.yaml  \--output-dir outputs/ram_plus

2.swin_large_patch4_window12_384_22k.pth權重

a.下載

但你直接使用該命令時候,會報如下錯誤:
在這里插入圖片描述

以上報錯是因為缺失相應權重swin_large_patch4_window12_384_22k.pth,我們只需通過下面鏈接點擊這里,獲得如下圖權重下載即可,如下:
在這里插入圖片描述

b.權重加載修改

對應權重下載實際是pretrain.yaml參數設置的vit: 'swin_l'image_size: 224共同決定,我們將其定位為config_swinl_224.json文件,如下圖:
在這里插入圖片描述
上面我們已知權重路徑更改位置,我們將其下載權重絕對路徑替換即可,如下代碼示列:

{
{"ckpt": "絕對路徑位置/swin_large_patch4_window12_384_22k.pth","vision_width": 1536,"image_res": 224,"window_size": 7,"embed_dim": 192,"depths": [ 2, 2, 18, 2 ],"num_heads": [ 6, 12, 24, 48 ]}}

3.ram_plus_tag_embedding_class_4585_des_51.pth權重

a.下載

當你再次使用該命令時候,會報如下錯誤:
在這里插入圖片描述

不要慌張,依然是權重問題,我們只需鏈接:點擊這里,可在huggingface下載我們想要的權重文件。

b.權重加載修改

對應權重下載后有2種方法可實現權重正確加載,第一將下載權重放到指定路,第二將在源碼ram_plus.py改成絕對路徑,如下圖:
在這里插入圖片描述

4.變量設備匹配問題

當你很開心再次使用預訓練命令時,會報如下錯誤(該錯誤在finetune也會出現):
在這里插入圖片描述
該問題也是小問題,就是變量設備不匹配問題,從上圖報錯地方可追述到ram_plus.py文件除了問題,實際決定該問題是在pretrain.py文件調用那里,主要是image_tag是一個傳入參數未能給定device,我們在pretrain.py下面代碼給定即可,我也建議在pretrain.py修改,而不要動報錯地方修改,你只需添加image_tag = image_tag.to(device, non_blocking=True)指定設備,修改如下:

    for i, (image, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):if epoch==0:warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])optimizer.zero_grad()batch_text_embed = build_text_embed(model_clip,caption)image = image.to(device,non_blocking=True)image_tag = image_tag.to(device, non_blocking=True) #

5. 預訓練成功顯示

如出現下圖表示預訓練成功,如下:
在這里插入圖片描述
我們進一步可發現使用一張3090顯卡,batch為20即可滿負載,如下:
在這里插入圖片描述

五、數據加載源碼簡單解讀

標簽源碼如下,可看到圖像做了2次加工一次該模型本身使用image,一次為圖像特征提取swin模型使用image_224,而caption為一句話(若為多句隨機選擇一句),該句話直接通過clip的文本編碼獲得特征,image_tag 是union_label_id, parse_tag是parse_label_id,具體如下代碼:

def __getitem__(self, index):    ann = self.ann[index]   image_path_use = os.path.join(self.root, ann['image_path'])image = Image.open(image_path_use).convert('RGB')   image = self.transform(image)image_224 = Image.open(image_path_use).convert('RGB')  image_224 = self.transform_224(image_224)# image_tag 是union_label_idnum = ann['union_label_id']image_tag = np.zeros([self.class_num])image_tag[num] = 1image_tag = torch.tensor(image_tag, dtype = torch.long)caption_index = np.random.randint(0, len(ann['caption']))  # 有的數據集有多個描述caption = pre_caption(ann['caption'][caption_index],30)# parse_tag是parse_label_idnum = ann['parse_label_id'][caption_index]parse_tag = np.zeros([self.class_num])parse_tag[num] = 1parse_tag = torch.tensor(parse_tag, dtype = torch.long)return image, image_224, caption, image_tag, parse_tag

六、推理

推理可直接使用命令,指定權重我在pretrain已給出鏈接,可自行下載:

python batch_inference.py \   --model-type ram_plus \   --checkpoint pretrained/ram_plus_swin_large_14m.pth \   --dataset openimages_common_214 \   --output-dir outputs/ram_plus

當然,你也可以使用我的代碼,我是將一個文件夾循環推理,并將推理結果打印于圖上便于查看,如下:

'''* The Recognize Anything Plus Model (RAM++)* Written by Xinyu Huang
'''
import argparse
import osimport numpy as np
import randomimport torchfrom PIL import Image
from ram.models import ram_plus
from ram import inference_ram as inference
from ram import get_transformparser = argparse.ArgumentParser(description='Tag2Text inferece for tagging and captioning')
parser.add_argument('--image',help='path to dataset',default='images/demo/demo1.jpg')
parser.add_argument('--pretrained',help='path to pretrained model',default='路徑位置/ram_plus_swin_large_14m.pth')
parser.add_argument('--image-size',default=384,type=int,metavar='N',help='input image size (default: 448)')import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFontdef cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):if (isinstance(img, np.ndarray)):  # 判斷是否OpenCV圖片類型img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))# 創建一個可以在給定圖像上繪圖的對象draw = ImageDraw.Draw(img)# 字體的格式fontStyle = ImageFont.truetype("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", textSize, encoding="utf-8") # 繪制文本draw.text((left, top), text, textColor, font=fontStyle)# 轉換回OpenCV格式return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)def build_dir(out_dir):if not os.path.exists(out_dir):os.makedirs(out_dir,exist_ok=True)return out_dirif __name__ == "__main__":args = parser.parse_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')transform = get_transform(image_size=args.image_size)#######load modelmodel = ram_plus(pretrained=args.pretrained,image_size=args.image_size,vit='swin_l')model.eval()model = model.to(device)total = sum(p.numel() for p in model.parameters())  # 統計個數print("模型參數總量: %.2f million\t" % (total / 1e6), " 以float32模型內存占用:%.2f M" % (total * 4 / 1e6))# 下面是推理file_root='/推理文件路徑/sam_test' # 這個是多個文件夾路徑save_file_path=build_dir('runs')for file_name in os.listdir(file_root):save_path=os.path.join(save_file_path,file_name)img_root=os.path.join(file_root,file_name)for img_name in os.listdir(img_root):img_path=os.path.join(img_root,img_name)image = transform(Image.open(img_path)).unsqueeze(0).to(device)res = inference(image, model)# print("Image Tags: ", res[0])# print("圖像標簽: ", res[1])img = cv2.imread(img_path)N=int(len(res[1])/2)r1 = res[1][:N]r2 = res[1][N:]# r3 = res[1][2*N:]img = cv2ImgAddText(img, r1, 40, 50, textColor=(255, 0, 0), textSize=20)img = cv2ImgAddText(img, r2, 40, 200, textColor=(255, 0, 0), textSize=20)# img = cv2ImgAddText(img, r2, 40, 300, textColor=(255, 0, 0), textSize=40)build_dir(save_path)cv2.imwrite(os.path.join(save_path, img_name), img)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/168176.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/168176.shtml
英文地址,請注明出處:http://en.pswp.cn/news/168176.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用new Vue()的時候發生了什么?

前言 Vue.js是一個流行的JavaScript前端框架,用于構建單頁面應用(SPA)和用戶界面。當我們使用new Vue()來創建一個Vue實例時,Vue會執行一系列的初始化過程,將數據變成響應式,編譯模板,掛載實例…

RabbitMQ之發送者(生產者)可靠性

文章目錄 前言一、生產者重試機制二、生產者確認機制實現生產者確認(1)定義ReturnCallback(2)定義ConfirmCallback 總結 前言 生產者重試機制、生產者確認機制。 一、生產者重試機制 問題:生產者發送消息時&#xff0…

分布式事務總結

文章目錄 一、分布式事務基礎什么是事務?本地事物分布式事務分布式事務的場景 二、分布式事務解決方案全局事務可靠消息服務TCC 事務 三、Seata 分布式事務解決方案3.1 Seata-At模式3.2 秒殺項目集成 Seata啟動 Seata-Server項目集成seata配置AT模式代碼實現 3.3 Se…

openstack(2)

目錄 塊存儲服務 安裝并配置控制節點 安裝并配置一個存儲節點 驗證操作 封裝鏡像 上傳鏡像 塊存儲服務 安裝并配置控制節點 創建數據庫 [rootcontroller ~]# mysql -u root -pshg12345 MariaDB [(none)]> CREATE DATABASE cinder; MariaDB [(none)]> GRANT ALL PR…

1、Docker概述與安裝

相關資源網站: ● docker官網:http://www.docker.com ● Docker Hub倉庫官網: https://hub.docker.com/ 注意,如果只是想看Docker的安裝,可以直接往下拉跳轉到Docker架構與安裝章節下的Docker具體安裝步驟,一步步帶你安…

82基于matlab GUI的圖像處理

基于matlab GUI的圖像處理,功能包括圖像一般處理(灰度圖像、二值圖);圖像幾何變換(旋轉可輸入旋轉角度、平移、鏡像)、圖像邊緣檢測(拉普拉斯算子、sobel算子、wallis算子、roberts算子&#xf…

【Rust日報】2023-11-22 Floneum -- 基于 Rust 的一款用于 AI 工作流程的圖形編輯器

Floneum -- 基于 Rust 的一款用于 AI 工作流程的圖形編輯器 Floneum 是一款用于 AI 工作流程的圖形編輯器,專注于社區制作的插件、本地 AI 和安全性。 Floneum 有哪些特性: 可視化界面:您無需任何編程知識即可使用Floneum。可視化圖形編輯器可…

oled的使用 動態的變量 51

源碼均在IIC手寫程序中 外部中斷實現變量加一 #include "reg52.h" #include "main.h" #include <intrins.h> #include "OLED.h" #include "bmp.h" #include "Delay.h" sbit LED1 P1^0; sbit LED2 P1^1; sbit LED3…

【LeetCode每日一題】525. 連續數組

題目&#xff1a; 給定一個二進制數組 nums , 找到含有相同數量的 0 和 1 的最長連續子數組&#xff0c;并返回該子數組的長度。 媽的 連題目都沒有讀懂&#xff01;本來看成是找到兩個連續子數組&#xff0c;兩個連續子數組的 0 1 個數分別相同&#xff0c;我說怎么看著如此…

Python報錯:AttributeError(類屬性、實例屬性)

Python報錯&#xff1a;AttributeError&#xff08;類屬性、實例屬性&#xff09; Python報錯&#xff1a;AttributeError 這個錯誤就是說python找不到對應的對象的屬性&#xff0c;百度后才發現竟然是初始化類的時候函數名寫錯了 __init__應該有2條下劃線&#xff0c;如果只有…

構建未來:云計算 生成式 AI 誕生科技新局面

目錄 引言生成式 AI&#xff1a;開發者新伙伴云計算與生成式 AI 的無縫融合亞馬遜云與生成式 AI 結合的展望/總結我用亞馬遜云科技生成式 AI 產品打造了什么&#xff0c;解決了什么問題未來科技發展趨勢&#xff1a;開發者的機遇與挑戰結合實踐看未來結語開源項目 引言 2023年…

SpectralGPT: Spectral Foundation Model 論文翻譯1

遙感領域的通用大模型 2023.11.13在CVPR發表 原文地址&#xff1a;[2311.07113] SpectralGPT: Spectral Foundation Model (arxiv.org) 摘要 ? 基礎模型最近引起了人們的極大關注&#xff0c;因為它有可能以一種自我監督的方式徹底改變視覺表征學習領域。雖然大多數基礎模型…

VSCode 連接遠程服務器問題及解決辦法

端口號不一樣&#xff0c;需要在配置文件中添加Port Host 27.223.26.46HostName 27.223.*.*User userForwardAgent yesPort 14111輸入密碼后可以連接 在vscode界面&#xff0c;終端&#xff0c;生成公鑰&私鑰 ssh-keygen可以看到有id_rsa和id_rsa.pub兩個文件生成&#…

curl 命令的一些基本用法,

curl 是一個用于在命令行中進行網絡請求的工具。以下是一些 curl 命令的常見用法&#xff1a; 從 URL 下載文件并保存為本地文件&#xff1a; curl -O URL例如&#xff1a; curl -O https://example.com/file.zip這將會將 file.zip 下載到當前目錄。 將文件下載到指定位置&…

Nginx如何配置負載均衡

nginx的負載均衡有4種模式&#xff1a; 1)、輪詢&#xff08;默認&#xff09; 每個請求按時間順序逐一分配到不同的后端服務器&#xff0c;如果后端服務器down掉&#xff0c;能自動剔除。 2)、weight 指定輪詢幾率&#xff0c;weight和訪問比率成正比&#xff0c;用于后端服務…

C#,《小白學程序》第五課:隊列(Queue)其一,排隊的技術與算法

日常生活中常見的排隊&#xff0c;軟件怎么體現呢&#xff1f; 排隊的基本原則是&#xff1a;先到先得&#xff0c;先到先吃&#xff0c;先進先出 1 文本格式 /// <summary> /// 《小白學程序》第五課&#xff1a;隊列&#xff08;Queue&#xff09; /// 日常生活中常見…

antDesignPro a-table樣式二次封裝

antDesignPro是跟element-ui類似的一個樣式框架&#xff0c;其本身就是一個完整的后臺系統&#xff0c;風格樣式都很統一。我使用的是antd pro vue&#xff0c;版本是1.7.8。公司要求使用這個框架&#xff0c;但是UI又有自己的一套設計。這就導致我需要對部分組件進行一定的個性…

nodejs微信小程序+python+PHP-青云商場管理系統的設計與實現-安卓-計算機畢業設計

目 錄 摘 要 I ABSTRACT II 目 錄 II 第1章 緒論 1 1.1背景及意義 1 1.2 國內外研究概況 1 1.3 研究的內容 1 第2章 相關技術 3 2.1 nodejs簡介 4 2.2 express框架介紹 6 2.4 MySQL數據庫 4 第3章 系統分析 5 3.1 需求分析 5 3.2 系統可行性分析 5 3.2.1技術可行性&#xff1a;…

mysql 性能參數調優詳解

1 優化連接池 連接池運行機制 MySQL連接器中的連接池&#xff0c;用以提高數據庫密集型應用程序的性能和可擴展性&#xff0c;默認啟用。MySQL連接器負責管理連接池中的多個連接&#xff0c;自動創建、打開、關閉和破壞連接&#xff0c;多個連接的創建&#xff0c;可滿足多客戶…

C++算法 —— 貪心(4)

文章目錄 1、分發餅干2、最優除法3、跳躍游戲Ⅱ4、跳躍游戲Ⅰ5、加油站6、單調遞增的數字7、壞了的計算器 1、分發餅干 455. 分發餅干 其實看完這個題會發現&#xff0c;如果給定的兩個數組不排序的話會非常難受&#xff0c;所以無論怎樣&#xff0c;先排序。接下來需要比較兩…