解決qnn htp 后端不支持boolean 數據類型的方法。

一、背景

??1.1 問題原因

??Qnn 模型在使用fp16的模型轉換不支持類型是boolean的cast 算子,因為 htp 后端支持量化數據類型或者fp16,不支持boolean 類型。

${QNN_SDK_ROOT_27}/bin/x86_64-linux-clang/qnn-model-lib-generator -c ./bge_small_fp16.cpp -b ./bge_small_fp16.bin -o output-so-small

也就是圖中的算子不支持。

嘗試了很多版本,后端,都不支持。沒辦法只能算子替換了。

1.2 替換算子

初步思路:

       Sub↓Cast (to bool)↓Cast (to float32)    (另外一個輸入,假設是 y)↓                  ↓Mul              Mul (1 - mask)↓                  ↓Add↓Output
  1. 先做一個 Greater 比較,生成 0/1 tensor

  2. 再用這個 0/1 tensor 進行 (cond * x) + ((1-cond) * y) 操作, Where(cond, x, y) = cond * x + (1 - cond) * y 可以用 Cast + Mul + Sub + Add 基礎算子實現。

  3. 但是生成的還是有boolean 類型數據

不要 Greater (即不要比較生成bool類型)

不要 BOOL tensor (因為有些平臺對BOOL類型支持不好,比如QNN/DSP/NPU)

直接從 float tensor 生成 0/1 的 float tensor!

改進思路:

可以直接用 Clip + Sign 這種基礎算子來實現!

比如:

  • Sign(x)

    • 如果 x > 0,輸出 1

    • 如果 x == 0,輸出 0

    • 如果 x < 0,輸出 -1

  • Clip(Sign(x), 0.0, 1.0)

    • 把負數剪到 0

    • 正數(1)保留為 1

這樣就完美地直接生成了一個 全是 0 或 1FLOAT tensor! ? 沒有 BOOL 類型,? 沒有 Greater 節點,? 沒有 Cast,? 全是 float。

real_cond_input ---> Sign ---> Clip(0.0, 1.0) ---> mask (float 0/1 tensor)

二、算子代碼實現

1.1 替換算子

import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as npdef add_value_info(graph, name, dtype, shape):"""輔助函數:添加中間 tensor 的 shape 和 dtype"""vi = helper.make_tensor_value_info(name, dtype, shape)graph.value_info.append(vi)def add_constant(graph, base_name, value, dtype, shape):const_name = base_name + "_value"const_tensor = helper.make_tensor(name=const_name,data_type=dtype,dims=shape,vals=value)const_node = helper.make_node('Constant',inputs=[],outputs=[const_name],value=const_tensor)graph.node.append(const_node)add_value_info(graph, const_name, dtype, shape)return const_name
def replace_where_and_cast(model_path, output_path):"""替換 onnx 中的 Where 和 Cast 節點,保持功能等效"""# 讀取模型model = onnx.load(model_path)nodes = model.graph.nodeprint("old model node number" + str(len(model.graph.node)))new_nodes = []nodes_to_remove = []input_shape = [1,1, 512, 512]for node in model.graph.node:if node.op_type == "Where":# 記錄要移除的原始 Wherenodes_to_remove.append(node)# Where輸入:[condition, x, y]cond_input = node.input[0]print(cond_input)x_input = node.input[1]print(x_input)y_input = node.input[2]print(y_input)output_name = node.output[0]print(output_name)# 處理可能前面有 Cast 的情況real_cond_input = cond_inputfor sub_node in model.graph.node:if sub_node.output and sub_node.output[0] == cond_input and sub_node.op_type == "Cast":real_cond_input = sub_node.input[0]nodes_to_remove.append(sub_node)break# ========== 關鍵步驟 ==========# 1. Signsign_output = real_cond_input + "_sign"sign_node = helper.make_node('Sign',inputs=[real_cond_input],outputs=[sign_output],name ="sign_add_my")new_nodes.append(sign_node)add_value_info(model.graph, sign_output, TensorProto.FLOAT, input_shape)# 2. Clip(0,1)clip_output = real_cond_input + "_clip"clip_min_tensor_name = real_cond_input + "_min_value"clip_min_initializer = numpy_helper.from_array(np.zeros(1, dtype=np.float32),name=clip_min_tensor_name)clip_max_tensor_name = real_cond_input + "_max_value"clip_max_initializer = numpy_helper.from_array(np.ones(1, dtype=np.float32),name=clip_max_tensor_name)model.graph.initializer.append(clip_min_initializer)model.graph.initializer.append(clip_max_initializer)# min_val_const_node = add_constant(model.graph, "min_value", 0, TensorProto.FLOAT, input_shape)# max_val_const_node = add_constant(model.graph, "max_value", 1, TensorProto.FLOAT, input_shape)clip_node = helper.make_node('Clip',inputs=[sign_output, clip_min_tensor_name, clip_max_tensor_name],outputs=[clip_output],name="clip_add_my")new_nodes.append(clip_node)add_value_info(model.graph, clip_output, TensorProto.FLOAT, input_shape)# 3. 生成 (1 - mask)one_tensor_name = real_cond_input + "_one"one_initializer = numpy_helper.from_array(np.ones(input_shape, dtype=np.float32),name=one_tensor_name)model.graph.initializer.append(one_initializer)one_minus_mask_output = real_cond_input + "_one_minus_mask"sub_node = helper.make_node('Sub',inputs=[one_tensor_name, clip_output],outputs=[one_minus_mask_output],name="sub_my")new_nodes.append(sub_node)add_value_info(model.graph, one_minus_mask_output, TensorProto.FLOAT, input_shape)# 4. mask * xmask_mul_x_output = real_cond_input + "_mask_mul_x"mul1_node = helper.make_node('Mul',inputs=[clip_output, x_input],outputs=[mask_mul_x_output],name="mul_my")new_nodes.append(mul1_node)add_value_info(model.graph, mask_mul_x_output, TensorProto.FLOAT, input_shape)# 5. (1-mask) * yone_minus_mask_mul_y_output = real_cond_input + "_one_minus_mask_mul_y"mul2_node = helper.make_node('Mul',inputs=[one_minus_mask_output, y_input],outputs=[one_minus_mask_mul_y_output],name="mul_my2")new_nodes.append(mul2_node)add_value_info(model.graph, one_minus_mask_mul_y_output, TensorProto.FLOAT, input_shape)# 6. 加起來得到最終輸出add_node = helper.make_node('Add',inputs=[mask_mul_x_output, one_minus_mask_mul_y_output],outputs=[output_name],name="add_my")new_nodes.append(add_node)# output shape 已經有定義,不需要額外addelif node.op_type == 'Cast':# 如果是 Where 的 Cast,不保留if any(wn.input[0] == node.output[0] for wn in nodes if wn.op_type == 'Where'):print(f"Skipping Cast node: {node.name}")continueelse:new_nodes.append(node)else:new_nodes.append(node)# 移除舊節點for node in nodes_to_remove:model.graph.node.remove(node)# 更新新的節點列表model.graph.ClearField('node')model.graph.node.extend(new_nodes)print("new model node number" + str(len(model.graph.node)))# 保存新的模型onnx.save(model, output_path)if __name__ == "__main__":model_path = "./bge_small_model_simple.onnx"output_path = "./bge_replace_cast_where2.onnx"replace_where_and_cast(model_path, output_path)

?

2.2 運行原始模型和算子替換之后的模型

def run_bge_small_model_onnx():model = AutoModel.from_pretrained("BAAI/bge-small-zh-v1.5")tokenizers = AutoTokenizer.from_pretrained("BAAI/bge-small-zh-v1.5")input_data = "ZhongGuo, nihao, 日本再見, good cat!"device = "cuda" if torch.cuda.is_available() else "cpu"model.to(device)model.eval()input_tensor_data = tokenizers(input_data, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(device)with torch.no_grad():output = model(**input_tensor_data)print("oringal model putput")output_data = output.last_hidden_state.flatten().tolist()[:100]print(len(output.last_hidden_state.flatten().tolist()))print(output_data)print("run modify model")# 步驟 2:加載 ONNX 模型model_path = './bge_replace_cast_where2.onnx'  # 替換為你的 ONNX 模型文件路徑session = ort.InferenceSession(model_path)# 步驟 3:準備輸入數據# 假設模型的輸入是一個形狀為 (1, 3, 224, 224) 的浮點張量input_name1 = session.get_inputs()[0].nameprint(input_name1)input_data1 = input_tensor_data["input_ids"].numpy()input_name2 = session.get_inputs()[1].nameinput_data2 = input_tensor_data["attention_mask"].numpy()print(input_name2)input_name3 = session.get_inputs()[2].nameinput_data3 = input_tensor_data["token_type_ids"].numpy()print(input_name3)# 步驟 4:運行模型并獲取輸出replace_model_output = session.run(None, {input_name1: input_data1, input_name2: input_data2, input_name3: input_data3})# 打印輸出結果print("replace_model_output shape:", replace_model_output[0].shape)print("replace_model_output data:", replace_model_output[0])replace_model_output_data = replace_model_output[:100]print(len(replace_model_output))print(replace_model_output_data)np.array(replace_model_output).tofile("last_output-onnx_bge_small_replace.raw")

2.3 原始模型和替換算子模型精度對齊


def compare_nchw_data(nchw_file, nchw_file2):data_nchw = read_bin_fp32(nchw_file, shape=[1, 512, 512])print("NCHW 原始數據形狀:", data_nchw.shape)print("NCHW 數據統計 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw.min(), data_nchw.max(), data_nchw.mean()))data_nchw2 = read_bin_fp32(nchw_file2, shape=[1, 512, 512])print("NHWC2 原始數據形狀:", data_nchw2.shape)print("NHWC2 數據統計 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(data_nchw2.min(), data_nchw2.max(), data_nchw2.mean()))diff = data_nchw - data_nchw2print("\n==== 差異對比 ====")print("差值 min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(diff.min(), diff.max(), diff.mean()))print(diff)# ==== 打印前100個數據 ====onnx_output_flat = data_nchw.flatten()onnx_output_flat2 = data_nchw2.flatten()print("\n--- 前100個元素 ---")for i in range(100):print(f"[{i}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 打印后100個數據 ====print("\n--- 后100個元素 ---")for i in range(-100, 0):idx = len(onnx_output_flat) + iprint(f"[{idx}] onnx-v={onnx_output_flat[i]:.6f} | qnn-v={onnx_output_flat2[i]:.6f} | diff={abs(onnx_output_flat[i] - onnx_output_flat2[i]):.6f}")# ==== 可選:統計誤差 ====max_diff = np.max(onnx_output_flat2 - onnx_output_flat)mean_diff = np.mean(onnx_output_flat2 - onnx_output_flat )min_diff = np.min(onnx_output_flat2 -onnx_output_flat)print(f"\n 總元素數: {onnx_output_flat.size}")print(f" 最大誤差: {max_diff}")print(f" 最小誤差: {min_diff}")print(f" 平均誤差: {mean_diff}")

2.4 對齊結果展示

?

?

結果對齊了,表示模型替換成功了。?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/80696.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/80696.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/80696.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用Three.js搭建自己的3Dweb模型(從0到1無廢話版本)

教學視頻參考&#xff1a;B站——Three.js教學 教學鏈接&#xff1a;Three.js中文網 老陳打碼 | 麒躍科技 一.什么是Three.js&#xff1f; Three.js? 是一個基于 JavaScript 的 ?3D 圖形庫&#xff0c;用于在網頁瀏覽器中創建和渲染交互式 3D 內容。它基于 WebGL&#xff0…

PostgreSQL WAL 冪等性詳解

1. WAL簡介 WAL&#xff08;Write-Ahead Logging&#xff09;是PostgreSQL的核心機制之一。其基本理念是&#xff1a;在修改數據庫數據頁之前&#xff0c;必須先將這次修改操作寫入到WAL日志中。 這確保了即使發生崩潰&#xff0c;數據庫也可以根據WAL日志進行恢復。 恢復的核…

git提交規范記錄,常見的提交類型及模板、示例

Git提交規范是一種約定俗成的提交信息編寫標準&#xff0c;旨在使代碼倉庫的提交歷史更加清晰、可讀和有組織。以下是常見的Git提交類型及其對應的提交模板&#xff1a; 提交信息的基本結構 一個標準的Git提交信息通常包含以下三個主要部分&#xff1a; Header?&#xff1a;描…

FastAPI系列06:FastAPI響應(Response)

FastAPI響應&#xff08;Response&#xff09; 1、Response入門2、Response基本操作設置響應體&#xff08;返回數據&#xff09;設置狀態碼設置響應頭設置 Cookies 3、響應模型 response_model4、響應類型 response_classResponse派生類自定義response_class 在“FastAPI系列0…

每日一題(小白)模擬娛樂篇33

首先&#xff0c;理解題意是十分重要的&#xff0c;我們是要求最短路徑&#xff0c;這道題可以用dfs&#xff0c;但是題目給出的數據是有規律的&#xff0c;我們可以嘗試模擬的過程使用簡單的方法做出來。每隔w數字就會向下轉向&#xff0c;就比如題目上示例的w6&#xff0c;無…

哈希封裝unordered_map和unordered_set的模擬實現

文章目錄 &#xff08;一&#xff09;認識unordered_map和unordered_set&#xff08;二&#xff09;模擬實現unordered_map和unordered_set2.1 實現出復用哈希表的框架2.2 迭代器iterator的實現思路分析2.3 unordered_map支持[] &#xff08;三&#xff09;結束語 &#xff08;…

Java學習-Java基礎

1.重寫與重載的區別 重寫發生在父子類之間,重載發生在同類之間構造方法不能重寫,只能重載重寫的方法返回值,參數列表,方法名必須相同重載的方法名相同,參數列表必須不同重寫的方法的訪問權限不能比父類方法的訪問權限更低 2.接口和抽象類的區別 接口是interface,抽象類是abs…

BG開發者日志0427:故事的起點

1、4月26日晚上&#xff0c;BG項目的gameplay部分開發完畢&#xff0c;后續是細節以及試玩版優化。 開發重心轉移到story部分&#xff0c;目前剛開始&#xff0c; 確切地說以前是長期擱置狀態&#xff0c;因為過去的四個月中gameplay部分優先開發。 --- 2、BG這個項目的起點…

頭歌實訓之游標觸發器

&#x1f31f; 各位看官好&#xff0c;我是maomi_9526&#xff01; &#x1f30d; 種一棵樹最好是十年前&#xff0c;其次是現在&#xff01; &#x1f680; 今天來學習C語言的相關知識。 &#x1f44d; 如果覺得這篇文章有幫助&#xff0c;歡迎您一鍵三連&#xff0c;分享給更…

【深度學習】多頭注意力機制的實現|pytorch

博主簡介&#xff1a;努力學習的22級計算機科學與技術本科生一枚&#x1f338;博主主頁&#xff1a; Yaoyao2024往期回顧&#xff1a;【深度學習】注意力機制| 基于“上下文”進行編碼,用更聰明的矩陣乘法替代笨重的全連接每日一言&#x1f33c;: 路漫漫其修遠兮&#xff0c;吾…

java16

1.API續集 可以導入別人寫好的clone的jar包 注意&#xff1a;方法要有調用者&#xff0c;如果調用者是null就會報錯 2.如何導入別人寫好的jar包 復制jar包然后粘貼在lib里面&#xff0c;然后右鍵點擊jar包再點擊下面的add 3.關于打印java中的引用數據類型

PostgreSQL的擴展 credcheck

PostgreSQL的擴展 credcheck credcheck 是 PostgreSQL 的一個安全擴展&#xff0c;專門用于強制實施密碼策略和憑證檢查&#xff0c;特別適合需要符合安全合規要求的數據庫環境。 一、擴展概述 1. 主要功能 強制密碼復雜度要求防止使用常見弱密碼密碼過期策略實施密碼重復使…

MyBatis中的@Param注解-如何傳入多個不同類型的參數

mybatis中參數識別規則 默認情況下,MyBatis 會按照參數位置自動分配名稱:param1, param2, param3, ...或者 arg0, arg1。 // Mapper 接口方法 User getUserByIdAndName(Integer id, String name); 以上接口在XML中只能通過param1或者arg0這樣的方式來引用,可讀性差。 &l…

DIFY教程第一集:安裝Dify配置環境

一、Dify的介紹 https://dify.ai/ Dify 是一款創新的智能生活助手應用&#xff0c;旨在為您提供便捷、高效的服務。通過人工智能技術&#xff0c; Dify 可以實現語音 助手、智能家居控制、日程管理等功能&#xff0c;助您輕松應對生活瑣事&#xff0c;享受智慧生活。簡約的…

5、Rag基礎:RAG 專題

RAG 簡介 什么是檢索增強生成? 檢索增強生成(RAG)是指對大型語言模型輸出進行優化,使其能夠在生成響應之前引用訓練數據來源之外的權威知識庫。大型語言模型(LLM)用海量數據進行訓練,使用數十億個參數為回答問題、翻譯語言和完成句子等任務生成原始輸出。在 LLM 本就強…

GAMES202-高質量實時渲染(homework1)

目錄 Homework1shadow MapPCF(Percentage Closer Filter)PCSS(Percentage Closer Soft Shadow) GitHub主頁&#xff1a;https://github.com/sdpyy1 作業實現:https://github.com/sdpyy1/CppLearn/tree/main/games202 Homework1 shadow Map 首先需要完成MVP矩陣的構造&#xf…

JDK(Ubuntu 18.04.6 LTS)安裝筆記

一、前言 本文與【MySQL 8&#xff08;Ubuntu 18.04.6 LTS&#xff09;安裝筆記】同批次&#xff1a;先搭建數據庫&#xff0c;再安裝JDK&#xff0c;后面肯定就是部署Web應用&#xff1a;典型的單機部署。“麻雀雖小五臟俱全”&#xff0c;善始善終&#xff0c;還是記下來吧。…

軟件測試之接口測試常見面試題

一、什么是(軟件)接口測試? 接口測試&#xff1a;是測試系統組件間接口的一種測試方法 接口測試的重點&#xff1a;檢查數據的交換&#xff0c;數據傳遞的正確性&#xff0c;以及接口間的邏輯依賴關系 接口測試的意義&#xff1a;在較早期開展&#xff0c;在軟件開發的同時…

Lua 第11部分 小插曲:出現頻率最高的單詞

在本章中&#xff0c;我們要開發一個讀取并輸出一段文本中出現頻率最高的單詞的程序。像之前的小插曲一樣&#xff0c;本章的程序也十分簡單但是也使用了諸如迭代器和匿名函數這樣的高級特性。 該程序的主要數據結構是一個記錄文本中出現的每一個單詞及其出現次數之間關系的表。…

軟件項目進度管理活動詳解

目錄 1. 活動定義&#xff08;Activity Definition&#xff09; 2. 活動排序&#xff08;Activity Sequencing&#xff09; 3. 活動資源估算&#xff08;Activity Resource Estimating&#xff09; 4. 活動歷時估算&#xff08;Activity Duration Estimating&#xff09; …