四. 以Annoy算法建樹的方式聚類清洗圖像數據集,一次建樹,無限次聚類搜索,提升聚類搜索效率。(附完整代碼)

文章內容結構:

一. 先介紹什么是Annoy算法。
二. 用Annoy算法建樹的完整代碼。
三. 用Annoy建樹后的樹特征匹配聚類歸類圖像。

一. 先介紹什么是Annoy算法

下面的文章鏈接將Annoy算法講解的很詳細,這里就不再做過多原理的分析了,想詳細了解的可以看看這篇文章內容。

https://zhuanlan.zhihu.com/p/148819536

總的來說:

(1)通過多次遞歸迭代,建立一個二叉樹,以二叉樹的方式,提升數據聚類和搜索速度,但會損失一些精度。

(2)建樹過程相對比較耗時,但建樹只需要一次,部署到線上或者其他設備上,能無數次聚類搜索。(類似于人臉識別的人臉底庫)

(注: 這里全部是個人經驗,能提升樣本標注和清洗效率,不是標準的數據處理方式,希望對您有幫助。)

--------

二. 用Annoy算法建樹的完整代碼

對底庫聚類建樹,生成Annoy樹特征文件。?

下面參數說明:

最佳聚類類別數量, 是根據《三.以聚類的方式清洗圖像數據集,找到最佳聚類類別數 (圖像特征提取+Kmeans聚類)》獲取得到
BEST_NUM_CLUSTERS = 2501圖像特征提取后的向量維度,是pt或者onnx模型輸出的類別數
FEATURE_DIM = 190推斷圖像尺寸,是根據訓練pt模型時,輸入的圖像尺寸大小
CLASSIFY_SIZE = 224  

以下是正式的代碼:

import os
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
import shutil
from sklearn.cluster import KMeans
from sklearn.preprocessing import Normalizer
from  tqdm import tqdm
import math
import matplotlib.pyplot as plt# 圖像預處理函數
def preprocess_image(image_path):roi_frame= cv2.imread(image_path)width = roi_frame.shape[1]height = roi_frame.shape[0]if (width != CLASSIFY_SIZE) or (height != CLASSIFY_SIZE) :if width > height:# 將圖像逆時針旋轉90度roi_frame = cv2.rotate(roi_frame, cv2.ROTATE_90_COUNTERCLOCKWISE)new_height = CLASSIFY_SIZEnew_width = int(roi_frame.shape[1] * (CLASSIFY_SIZE / roi_frame.shape[0]))roi_frame = cv2.resize(roi_frame, (new_width, new_height))# 計算上下左右漂移量y_offset = (CLASSIFY_SIZE - roi_frame.shape[0]) // 2x_offset = (CLASSIFY_SIZE - roi_frame.shape[1]) // 2gray_image = np.full((CLASSIFY_SIZE, CLASSIFY_SIZE, 3), 128, dtype=np.uint8)# 將調整大小后的目標圖像放置到灰度圖上gray_image[y_offset:y_offset + roi_frame.shape[0], x_offset:x_offset + roi_frame.shape[1]] = roi_frame# # 顯示結果# cv2.imshow("gray_image", gray_image)# cv2.waitKey(1)# 將圖像轉為 rgbgray_image =  cv2.cvtColor(gray_image, cv2.COLOR_BGR2RGB)else:gray_image = cv2.cvtColor(roi_frame, cv2.COLOR_BGR2RGB)img_np = np.array(gray_image).transpose(2, 0, 1).astype(np.float32)# 假設模型需要[0,1]歸一化img_np = img_np / 255.0# 均值 方差mean = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(3, 1, 1)std = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(3, 1, 1)img_np= (img_np - mean)/stdreturn np.expand_dims(img_np, axis=0)# 卸載 onnxruntime
# 安裝  pip install onnxruntime-gpu
def get_onnx_providers():# 檢查是否安裝了GPU版本的ONNX Runtimeall_provider = ort.get_available_providers()if "CUDAExecutionProvider" in all_provider:providers = [("CUDAExecutionProvider", {"device_id": 0,"arena_extend_strategy": "kNextPowerOfTwo","gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 限制GPU內存使用為2GB"cudnn_conv_algo_search": "EXHAUSTIVE","do_copy_in_default_stream": True,}),"CPUExecutionProvider"]print("檢測到NVIDIA GPU,使用CUDA加速")return providerselse:print("未檢測到NVIDIA GPU,使用CPU")return ["CPUExecutionProvider"]if __name__ =="__main__":root_path =  "/home/xxx/Download"# ONNX模型路徑MODEL_PATH = os.path.join(root_path, "08以圖搜圖_找相似度/98_weights/classify_modified_model_224.onnx")# 圖像文件夾路徑IMAGE_DIR = os.path.join(root_path, "08以圖搜圖_找相似度/99_test_datasets/8_bcd已驗收/8")# 分類結果輸出路徑OUTPUT_DIR = os.path.join(root_path, "08以圖搜圖_找相似度/99_test_datasets/8_bcd已驗收/8_kmeans_besk_k_classify")# 保存ann建樹文件路徑ANNOY_PATH = "08以圖搜圖_找相似度/01kmeans和DBscan/kmeans/annoy_cls.ann"# 最佳聚類類別數量(用kmeans和inner找到的)BEST_NUM_CLUSTERS = 2501# 圖像特征提取后的向量維度FEATURE_DIM = 190  # 根據自己的模型輸出維度修改# 推斷圖像尺寸CLASSIFY_SIZE = 224# 手動劃分分類數量# NUM_CLUSTERS = 3000# 創建輸出文件夾os.makedirs(OUTPUT_DIR, exist_ok=True)print("ONNX Runtime版本:", ort.__version__)print("可用執行器:", ort.get_available_providers())#   可用執行器: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']# 加載ONNX模型(動態獲取輸入/輸出名稱)ort_session = ort.InferenceSession(MODEL_PATH,providers=get_onnx_providers())# 確保輸出名稱正確input_name = ort_session.get_inputs()[0].nameoutput_name = ort_session.get_outputs()[0].namefrom annoy import AnnoyIndext = AnnoyIndex(FEATURE_DIM, metric="angular")  # FEATURE_DIM是圖像特征提取后的向量維度# 提取特征向量features = []image_paths = []print("====開始對所有圖像推理, 提取特征====")for index, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)try:# 前處理input_tensor = preprocess_image(path)# 推斷feature = ort_session.run([output_name], {input_name: input_tensor})[0]# 確保特征展平為1D,  190維度features.append(feature.reshape(-1))image_paths.append(path)# 增加到Annoy樹t.add_item(index, feature.reshape(-1))except Exception as e:print(f"Error processing {filename}: {str(e)}")t.build(BEST_NUM_CLUSTERS)    # 根據kmeans聚類找到最佳的聚類類別數量t.save(ANNOY_PATH)print("+++++提取特征結束+++++")print("+++++Annoy建樹結束+++++++++")

生成建樹annoy_cls.ann文件。

三. 用Annoy建樹后的樹特征匹配聚類歸類圖像

使用流程:

(1)加載ann建樹文件

(2)提取單張A圖像特征

(3)單張A圖像特征與ann建樹文件的特征進行比對,找到ann建樹文件里面的與A圖像特征相似的TOP_K的底庫圖像,拷貝走或者移動走。

import os
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
import shutil
from sklearn.cluster import KMeans
from sklearn.preprocessing import Normalizer
from  tqdm import tqdm
import math
import matplotlib.pyplot as plt# 圖像預處理函數
def preprocess_image(image_path):roi_frame= cv2.imread(image_path)width = roi_frame.shape[1]height = roi_frame.shape[0]if (width != CLASSIFY_SIZE) or (height != CLASSIFY_SIZE) :if width > height:# 將圖像逆時針旋轉90度roi_frame = cv2.rotate(roi_frame, cv2.ROTATE_90_COUNTERCLOCKWISE)new_height = CLASSIFY_SIZEnew_width = int(roi_frame.shape[1] * (CLASSIFY_SIZE / roi_frame.shape[0]))roi_frame = cv2.resize(roi_frame, (new_width, new_height))# 計算上下左右漂移量y_offset = (CLASSIFY_SIZE - roi_frame.shape[0]) // 2x_offset = (CLASSIFY_SIZE - roi_frame.shape[1]) // 2gray_image = np.full((CLASSIFY_SIZE, CLASSIFY_SIZE, 3), 128, dtype=np.uint8)# 將調整大小后的目標圖像放置到灰度圖上gray_image[y_offset:y_offset + roi_frame.shape[0], x_offset:x_offset + roi_frame.shape[1]] = roi_frame# # 顯示結果# cv2.imshow("gray_image", gray_image)# cv2.waitKey(1)# 將圖像轉為 rgbgray_image =  cv2.cvtColor(gray_image, cv2.COLOR_BGR2RGB)else:gray_image = cv2.cvtColor(roi_frame, cv2.COLOR_BGR2RGB)img_np = np.array(gray_image).transpose(2, 0, 1).astype(np.float32)# 假設模型需要[0,1]歸一化img_np = img_np / 255.0# 均值 方差mean = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(3, 1, 1)std = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(3, 1, 1)img_np= (img_np - mean)/stdreturn np.expand_dims(img_np, axis=0)# todo
# 卸載 onnxruntime
# 安裝  pip install onnxruntime-gpu
def get_onnx_providers():# 檢查是否安裝了GPU版本的ONNX Runtimeall_provider = ort.get_available_providers()if "CUDAExecutionProvider" in all_provider:providers = [("CUDAExecutionProvider", {"device_id": 0,"arena_extend_strategy": "kNextPowerOfTwo","gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 限制GPU內存使用為2GB"cudnn_conv_algo_search": "EXHAUSTIVE","do_copy_in_default_stream": True,}),"CPUExecutionProvider"]print("檢測到NVIDIA GPU,使用CUDA加速")return providerselse:print("未檢測到NVIDIA GPU,使用CPU")return ["CPUExecutionProvider"]if __name__ =="__main__":root_path =  "/home/xxx/Download"# ONNX模型路徑MODEL_PATH = os.path.join(root_path, "08以圖搜圖_找相似度/98_weights/classify_modified_model_224.onnx")# 圖像文件夾路徑IMAGE_DIR = os.path.join(root_path, "08以圖搜圖_找相似度/99_test_datasets/8_bcd已驗收/8")# 分類結果輸出路徑OUTPUT_DIR = os.path.join(root_path, "08以圖搜圖_找相似度/99_test_datasets/8_bcd已驗收/8_kmeans_besk_k_classify")# 保存annoy建樹路徑ANNOY_PATH = os.path.join(root_path, "08以圖搜圖_找相似度/01kmeans和DBscan/kmeans/annoy_cls.ann")# 最佳聚類類別數量BEST_NUM_CLUSTERS = 2501# 圖像特征提取后的向量維度FEATURE_DIM = 190# 推斷圖像尺寸CLASSIFY_SIZE = 224# 取top10TOP_K = 10# 手動劃分分類數量# NUM_CLUSTERS = 3000# 創建輸出文件夾os.makedirs(OUTPUT_DIR, exist_ok=True)print("ONNX Runtime版本:", ort.__version__)print("可用執行器:", ort.get_available_providers())#   可用執行器: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']# 加載ONNX模型(動態獲取輸入/輸出名稱)ort_session = ort.InferenceSession(MODEL_PATH,providers=get_onnx_providers())# 確保輸出名稱正確input_name = ort_session.get_inputs()[0].nameoutput_name = ort_session.get_outputs()[0].namefrom annoy import AnnoyIndexAnnoy_ = AnnoyIndex(FEATURE_DIM, metric="angular")  # FEATURE_DIM是圖像特征提取后的向量維度Annoy_.load(ANNOY_PATH) # 提取特征向量features = []image_paths = []# 獲取所有圖像路徑for _, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)image_paths.append(path)print("====開始對所有圖像推理, 提取特征, 根據創建的樹進行聚類====")for _, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)try:# 前處理input_tensor = preprocess_image(path)# 推斷feature = ort_session.run([output_name], {input_name: input_tensor})[0]# 確保特征展平為1D,  190維度features.append(feature.reshape(-1))# image_paths.append(path)# 取top10的相似圖像similar_img_indices, similar_img_distances=Annoy_.get_nns_by_vector(feature.reshape(-1), TOP_K, include_distances=True)print("similar_img_index:", similar_img_indices)print("similar_img_distance:", similar_img_distances)shutil.copy(path, os.path.join(OUTPUT_DIR,"11"))#  移動相似圖像到輸出目錄for idx in similar_img_indices:similar_image_path = image_paths[idx]# shutil.move(similar_image_path, OUTPUT_DIR)shutil.copy(similar_image_path, OUTPUT_DIR)except Exception as e:print(f"Error processing {filename}: {str(e)}")print("+++++提取特征結束+++++")print("+++++根據Annoy數特征聚類歸類圖像結束+++++++++")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902858.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902858.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902858.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

什么是電容?

什么是電容? 電荷與電壓的比值就是電容量C。電容單位為法拉(F)。1法拉電容器在電壓為1V時儲存的電荷量為1庫倫(C)。圖1.1中的球體表面電壓與儲存的電荷Q關聯。電壓V等于。Q/V等于。如果球體位于電介質媒介中,電壓V降低倍,Q/V等于。在電介質媒…

Linux服務器上mysql8.0+數據庫優化

1.配置文件路徑 /etc/my.cnf # CentOS/RHEL /etc/mysql/my.cnf # Debian/Ubuntu /etc/mysql/mysql.conf.d/mysqld.cnf # Ubuntu/Debian檢查當前配置文件 sudo grep -v "^#" /etc/mysql/mysql.conf.d/mysqld.cnf | grep -v "^$&q…

MQTT學習資源

MQTT入門:強烈推薦

第十二章 Python語言-大數據分析PySpark(終)

目錄 一. PySpark前言介紹 二.基礎準備 三.數據輸入 四.數據計算 1.數據計算-map方法 2.數據計算-flatMap算子 3.數據計算-reduceByKey方法 4.數據計算-filter方法 5.數據計算-distinct方法 6.數據計算-sortBy方法 五.數據輸出 1.輸出Python對象 (1&am…

【XR手柄交互】Unity 中使用 InputActions 實現手柄控制詳解(基于 OpenXR + Unity新輸入系統(Input Actions))

摘要: 本文主要介紹如何使用 Input Actions(Unity 新輸入系統) OpenXR 來實現 VR手柄控制(監聽ABXY按鈕、搖桿、抓握等操作)。 🎮 Unity 中使用 InputActions 實現手柄控制詳解(基于 OpenXR 新…

java實現網格交易回測

以下是一個基于Java實現的簡單網格交易回測程序框架,以證券ETF(512880)為例。代碼包含歷史數據加載、網格策略邏輯和基礎統計指標: import java.io.BufferedReader; import java.io.FileReader; import java.text.ParseException…

探秘 3D 展廳之卓越優勢,解鎖沉浸式體驗新境界

(一)打破時空枷鎖,全球觸達? 3D 展廳的首要優勢便是打破了時空限制。在傳統展廳中,觀眾需要親臨現場,且必須在展廳開放的特定時間內參觀。而 3D 展廳依托互聯網,讓觀眾無論身處世界哪個角落,只…

第十二屆藍橋杯 2021 C/C++組 直線

目錄 題目: 題目描述: 題目鏈接: 思路: 核心思路: 兩點確定一條直線: 思路詳解: 代碼: 第一種方式代碼詳解: 第二種方式代碼詳解: 題目:…

微信小程序藍牙連接打印機打印單據完整Demo【藍牙小票打印】

文章目錄 一、準備工作1. 硬件準備2. 開發環境 二、小程序配置1. 修改app.json 三、完整代碼實現1. pages/index/index.wxml2. pages/index/index.wxss3. pages/index/index.js 四、ESC/POS指令說明五、測試流程六、常見問題解決七、進一步優化建議 下面我將提供一個完整的微信…

ubuntu opencv 安裝

1.ubuntu opencv 安裝 在Ubuntu系統中安裝OpenCV,可以通過多種方式進行,以下是一種常用的安裝方法,包括從源代碼編譯安裝。請注意,安裝步驟可能會因OpenCV的版本和Ubuntu系統的具體版本而略有不同。 一、安裝準備 更新系統&…

【C++】class靜態常量

Usage: static const T 1 background static const成員屬于類,而不是類的實例,所以它們的初始化需要在類外進行(或者在C17之后可以用inline初始化)。 使用中可能遇到的情況: 在頭文件中聲明一個static const成員,然后在多個cpp…

Java 安全:如何防止 DDoS 攻擊?

一、DDoS 攻擊簡介 DDoS(分布式拒絕服務)攻擊是一種常見的網絡攻擊手段,攻擊者通過控制大量的僵尸主機向目標服務器發送海量請求,致使服務器資源耗盡,無法正常響應合法用戶請求。在 Java 應用開發中,了解 …

統計文件中單詞出現的次數并累計

# 統計單詞出現次數 fileopen("E:\Dasktape/python_test.txt","r",encoding"UTF-8") f1file.read() # 讀取文件 countf1.count("is") # 統計文件中is 單詞出現的次數 print(f"此文件中單詞is出現了{count}次")# 2.判斷單詞出…

C語言實現貪心算法

一、貪心算法核心思想 特征:在每一步選擇中都采取當前狀態下最優(局部最優)的選擇,從而希望導致全局最優解 適用場景:需要滿足貪心選擇性質和最優子結構性質 二、經典貪心算法示例 1. 活動選擇問題 目標&#xff1a…

《一文讀懂Transformers庫:開啟自然語言處理新世界的大門》

《一文讀懂Transformers庫:開啟自然語言處理新世界的大門》 GitHub - huggingface/transformers: ?? Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. HF-Mirror Hello! Transformers快速入門 pip install transformers -i https:/…

Vue里面elementUi-aside 和el-main不垂直排列

先說解決方法 main.js少導包 import element-ui/lib/theme-chalk/index.css; //加入此行即可 問題復現 排查了一個小時終于找出來問題了,建議導包去看官方的文檔,作者就是因為看了別人的導包流程導致的問題 導包官網地址Element UI導包快速入門

MYSQL 常用字符串函數 和 時間函數詳解

一、字符串函數 1、?CONCAT(str1, str2, …) 拼接多個字符串。 SELECT CONCAT(Hello, , World); -- 輸出 Hello World2、SUBSTRING(str, start, length)?? 或 ?SUBSTR() 截取字符串。 SELECT SUBSTRING(MySQL, 3, 2); -- 輸出 SQ3、LENGTH(str)?? 與 ?CHAR_LENGTH…

Python-Agent調用多個Server-FastAPI版本

Python-Agent調用多個Server-FastAPI版本 Agent調用多個McpServer進行工具調用 1-核心知識點 fastAPI的快速使用agent調用多個server 2-思路整理 1)先把每個子服務搭建起來2)再暴露一個Agent 3-參考網址 VSCode配置Python開發環境:https:/…

Drools+自定義規則庫

文章目錄 前言一、創建規則庫二、SpringBootDrools程序1.Maven依賴2.application.yml3.Mapper.xml4.Drools配置類5.Service6.Contoller7.測試接口 前言 公司的技術方案想搭建Drools自定義規則庫配合大模型進行數據的校驗。本篇用來記錄使用SpringBoot配合Drools開發Demo程序。…

潮了 低配電腦6G顯存生成60秒AI視頻 本地部署/一鍵包/云算力部署/批量生成

最近發現了一個讓人眼前一亮的工具——FramePack,它能用一塊普通的6GB顯存筆記本GPU,生成60秒電影級的高清視頻畫面,效果堪稱炸裂!那么我們就把他本地部署起來玩一玩、下載離線一鍵整合包,或者是用云算力快速上手。接下…