一、背景
1.1 問題原因
QNN 模型在使用 fp16 進行模型轉換時,不支持類型是 boolean 的 Cast 算子,因為 HTP 后端只支持量化數據類型或 fp16,不支持 boolean 類型。
${QNN_SDK_ROOT_27}/bin/x86_64-linux-clang/qnn-model-lib-generator -c ./bge_small_fp16.cpp -b ./bge_small_fp16.bin -o output-so-small
也就是圖中的算子不支持。
嘗試了很多版本,后端,都不支持。沒辦法只能算子替換了。
1.2 替換算子
初步思路:
Sub
 ↓
Cast (to bool)
 ↓
Cast (to float32)      (另外一個輸入,假設是 y)
 ↓                      ↓
Mul (mask * x)         Mul ((1 - mask) * y)
 ↓                      ↓
         Add
          ↓
        Output
-
先做一個 Greater 比較,生成 0/1 tensor
-
再用這個 0/1 tensor 進行
(cond * x) + ((1-cond) * y)
操作, Where(cond, x, y) = cond * x + (1 - cond) * y 可以用Cast
+Mul
+Sub
+Add
基礎算子實現。 -
但是生成的還是有boolean 類型數據
不要 Greater
(即不要比較生成bool類型)
不要 BOOL
tensor (因為有些平臺對BOOL類型支持不好,比如QNN/DSP/NPU)
直接從 float tensor 生成 0/1 的 float tensor!
改進思路:
可以直接用 Clip + Sign 這種基礎算子來實現!
比如:
-
Sign(x)
:-
如果 x > 0,輸出 1
-
如果 x == 0,輸出 0
-
如果 x < 0,輸出 -1
-
-
Clip(Sign(x), 0.0, 1.0)
:-
把負數剪到 0
-
正數(1)保留為 1
-
這樣就完美地直接生成了一個全是 0 或 1 的 FLOAT tensor!✅ 沒有 BOOL 類型,✅ 沒有 Greater 節點,✅ 沒有 Cast,✅ 全是 float。
real_cond_input ---> Sign ---> Clip(0.0, 1.0) ---> mask (float 0/1 tensor)
二、算子代碼實現
2.1 替換算子
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np


def add_value_info(graph, name, dtype, shape):
    """Register shape/dtype metadata (value_info) for an intermediate tensor.

    Args:
        graph: onnx GraphProto to annotate.
        name: name of the tensor.
        dtype: TensorProto data type (e.g. TensorProto.FLOAT).
        shape: list of ints giving the tensor shape.
    """
    vi = helper.make_tensor_value_info(name, dtype, shape)
    graph.value_info.append(vi)


def add_constant(graph, base_name, value, dtype, shape):
    """Append a Constant node to the graph and return its output tensor name.

    Args:
        graph: onnx GraphProto to extend.
        base_name: prefix used to derive the constant's tensor/node name.
        value: scalar or iterable with the tensor payload.
        dtype: TensorProto data type.
        shape: dims of the constant tensor.

    Returns:
        The name of the Constant node's output tensor.
    """
    const_name = base_name + "_value"
    # Fix: helper.make_tensor expects an iterable of values; tolerate a
    # bare scalar (the original passed `value` through unchanged, which
    # raises for scalar inputs such as 0 or 1).
    vals = value if isinstance(value, (list, tuple, np.ndarray)) else [value]
    const_tensor = helper.make_tensor(
        name=const_name,
        data_type=dtype,
        dims=shape,
        vals=vals)
    const_node = helper.make_node(
        'Constant',
        inputs=[],
        outputs=[const_name],
        value=const_tensor)
    graph.node.append(const_node)
    add_value_info(graph, const_name, dtype, shape)
    return const_name
def replace_where_and_cast(model_path, output_path):
    """Rewrite Where (and any bool Cast feeding it) with float-only ops.

    Where(cond, x, y) is replaced by:

        mask = Clip(Sign(cond_float), 0.0, 1.0)   # float 0/1, no BOOL tensor
        out  = mask * x + (1 - mask) * y

    so the exported graph contains no BOOL tensors, which QNN/HTP backends
    reject when converting to fp16.  The Where output keeps its original
    tensor name, so downstream consumers are untouched.

    Args:
        model_path: path of the input ONNX model.
        output_path: path where the rewritten model is saved.
    """
    model = onnx.load(model_path)
    print("old model node number" + str(len(model.graph.node)))

    new_nodes = []
    # NOTE(review): shape is hard-coded for this particular model's
    # attention-mask tensor — confirm against the actual cond shape.
    input_shape = [1, 1, 512, 512]

    # Condition tensors consumed by Where nodes; any Cast producing one of
    # these is dropped and its input consumed directly.
    where_cond_inputs = {
        n.input[0] for n in model.graph.node if n.op_type == "Where"
    }
    # Producer lookup so a Where can bypass a bool Cast in front of it.
    producer_of = {n.output[0]: n for n in model.graph.node if n.output}

    for idx, node in enumerate(model.graph.node):
        if node.op_type == "Where":
            # Where inputs: [condition, x, y]
            cond_input = node.input[0]
            print(cond_input)
            x_input = node.input[1]
            print(x_input)
            y_input = node.input[2]
            print(y_input)
            output_name = node.output[0]
            print(output_name)

            # If the condition is produced by a Cast, use the Cast's input
            # directly (the Cast itself is skipped below).
            real_cond_input = cond_input
            producer = producer_of.get(cond_input)
            if producer is not None and producer.op_type == "Cast":
                real_cond_input = producer.input[0]

            # Fix: unique prefix per Where occurrence — the original reused
            # names like "sign_add_my" for every replacement, which breaks
            # ONNX's unique-name requirement when several Where nodes exist.
            prefix = "%s_w%d" % (real_cond_input, idx)

            # 1. Sign: maps cond to {-1, 0, +1} while staying float.
            sign_output = prefix + "_sign"
            new_nodes.append(helper.make_node(
                'Sign',
                inputs=[real_cond_input],
                outputs=[sign_output],
                name="sign_add_my_%d" % idx))
            add_value_info(model.graph, sign_output, TensorProto.FLOAT,
                           input_shape)

            # 2. Clip(Sign(x), 0, 1): yields a float 0/1 mask, no BOOL.
            clip_output = prefix + "_clip"
            clip_min_name = prefix + "_min_value"
            clip_max_name = prefix + "_max_value"
            model.graph.initializer.append(numpy_helper.from_array(
                np.zeros(1, dtype=np.float32), name=clip_min_name))
            model.graph.initializer.append(numpy_helper.from_array(
                np.ones(1, dtype=np.float32), name=clip_max_name))
            new_nodes.append(helper.make_node(
                'Clip',
                inputs=[sign_output, clip_min_name, clip_max_name],
                outputs=[clip_output],
                name="clip_add_my_%d" % idx))
            add_value_info(model.graph, clip_output, TensorProto.FLOAT,
                           input_shape)

            # 3. (1 - mask)
            one_name = prefix + "_one"
            model.graph.initializer.append(numpy_helper.from_array(
                np.ones(input_shape, dtype=np.float32), name=one_name))
            one_minus_mask = prefix + "_one_minus_mask"
            new_nodes.append(helper.make_node(
                'Sub',
                inputs=[one_name, clip_output],
                outputs=[one_minus_mask],
                name="sub_my_%d" % idx))
            add_value_info(model.graph, one_minus_mask, TensorProto.FLOAT,
                           input_shape)

            # 4. mask * x
            mask_mul_x = prefix + "_mask_mul_x"
            new_nodes.append(helper.make_node(
                'Mul',
                inputs=[clip_output, x_input],
                outputs=[mask_mul_x],
                name="mul_my_%d" % idx))
            add_value_info(model.graph, mask_mul_x, TensorProto.FLOAT,
                           input_shape)

            # 5. (1 - mask) * y
            inv_mul_y = prefix + "_one_minus_mask_mul_y"
            new_nodes.append(helper.make_node(
                'Mul',
                inputs=[one_minus_mask, y_input],
                outputs=[inv_mul_y],
                name="mul_my2_%d" % idx))
            add_value_info(model.graph, inv_mul_y, TensorProto.FLOAT,
                           input_shape)

            # 6. mask*x + (1-mask)*y -> written under the original Where
            # output name; its value_info already exists in the graph.
            new_nodes.append(helper.make_node(
                'Add',
                inputs=[mask_mul_x, inv_mul_y],
                outputs=[output_name],
                name="add_my_%d" % idx))

        elif (node.op_type == 'Cast' and node.output
                and node.output[0] in where_cond_inputs):
            # Drop the bool Cast feeding a Where; its input is consumed
            # directly by the Sign node above.
            print(f"Skipping Cast node: {node.name}")
        else:
            new_nodes.append(node)

    # Rebuild the node list in one shot.  (The original first removed the
    # replaced nodes one by one and then called ClearField anyway — the
    # removal loop was dead code.)
    model.graph.ClearField('node')
    model.graph.node.extend(new_nodes)
    print("new model node number" + str(len(model.graph.node)))
    onnx.save(model, output_path)


if __name__ == "__main__":
    model_path = "./bge_small_model_simple.onnx"
    output_path = "./bge_replace_cast_where2.onnx"
    replace_where_and_cast(model_path, output_path)
?
2.2 運行原始模型和算子替換之后的模型
def run_bge_small_model_onnx():
    """Run the original HF bge-small model and the op-replaced ONNX model
    on the same tokenized input and print both outputs for comparison.

    NOTE(review): relies on file-level imports of transformers
    (AutoModel/AutoTokenizer), torch, onnxruntime (ort) and numpy (np) —
    confirm they are present at the top of the file.
    """
    model = AutoModel.from_pretrained("BAAI/bge-small-zh-v1.5")
    tokenizers = AutoTokenizer.from_pretrained("BAAI/bge-small-zh-v1.5")
    input_data = "ZhongGuo, nihao, 日本再見, good cat!"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    input_tensor_data = tokenizers(
        input_data,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt").to(device)

    with torch.no_grad():
        output = model(**input_tensor_data)
    # Fix: corrected the typo'd log message ("oringal model putput").
    print("original model output")
    output_data = output.last_hidden_state.flatten().tolist()[:100]
    print(len(output.last_hidden_state.flatten().tolist()))
    print(output_data)

    print("run modify model")
    # Load the rewritten ONNX model.
    model_path = './bge_replace_cast_where2.onnx'
    session = ort.InferenceSession(model_path)

    # Fix: move tensors back to CPU before .numpy() — tensor.numpy()
    # raises when the tokenized batch lives on CUDA.
    input_name1 = session.get_inputs()[0].name
    print(input_name1)
    input_data1 = input_tensor_data["input_ids"].cpu().numpy()
    input_name2 = session.get_inputs()[1].name
    input_data2 = input_tensor_data["attention_mask"].cpu().numpy()
    print(input_name2)
    input_name3 = session.get_inputs()[2].name
    input_data3 = input_tensor_data["token_type_ids"].cpu().numpy()
    print(input_name3)

    replace_model_output = session.run(
        None,
        {input_name1: input_data1,
         input_name2: input_data2,
         input_name3: input_data3})

    print("replace_model_output shape:", replace_model_output[0].shape)
    print("replace_model_output data:", replace_model_output[0])
    # Fix: slice the flattened first output tensor — the original sliced
    # the list of graph outputs (one element per output), not the values.
    replace_model_output_data = replace_model_output[0].flatten()[:100]
    print(len(replace_model_output))
    print(replace_model_output_data)
    np.array(replace_model_output).tofile(
        "last_output-onnx_bge_small_replace.raw")
2.3 原始模型和替換算子模型精度對齊
def compare_nchw_data(nchw_file, nchw_file2):
    """Element-wise comparison of two fp32 raw dumps of shape [1, 512, 512].

    Prints per-file statistics, the first/last 100 elements side by side,
    and aggregate absolute-error statistics.

    NOTE(review): relies on a file-level helper `read_bin_fp32` and numpy.

    Args:
        nchw_file: path to the reference dump (e.g. original ONNX output).
        nchw_file2: path to the dump under test (e.g. replaced-op output).
    """
    data_a = read_bin_fp32(nchw_file, shape=[1, 512, 512])
    print("NCHW 原始數據形狀:", data_a.shape)
    print("NCHW 數據統計 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(
        data_a.min(), data_a.max(), data_a.mean()))

    data_b = read_bin_fp32(nchw_file2, shape=[1, 512, 512])
    # Fix: the second file is also NCHW; the labels used to say "NHWC2".
    print("NCHW2 原始數據形狀:", data_b.shape)
    print("NCHW2 數據統計 -> min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(
        data_b.min(), data_b.max(), data_b.mean()))

    diff = data_a - data_b
    print("\n==== 差異對比 ====")
    print("差值 min: {:.6f}, max: {:.6f}, mean: {:.6f}".format(
        diff.min(), diff.max(), diff.mean()))
    print(diff)

    flat_a = data_a.flatten()
    flat_b = data_b.flatten()

    # First 100 elements, side by side.
    print("\n--- 前100個元素 ---")
    for i in range(100):
        print(f"[{i}] onnx-v={flat_a[i]:.6f} | qnn-v={flat_b[i]:.6f} | "
              f"diff={abs(flat_a[i] - flat_b[i]):.6f}")

    # Last 100 elements (negative indexing; idx is the absolute position).
    print("\n--- 后100個元素 ---")
    for i in range(-100, 0):
        idx = len(flat_a) + i
        print(f"[{idx}] onnx-v={flat_a[i]:.6f} | qnn-v={flat_b[i]:.6f} | "
              f"diff={abs(flat_a[i] - flat_b[i]):.6f}")

    # Fix: report errors on |diff| — max/min of the *signed* difference can
    # hide the true worst-case deviation (e.g. a large negative error).
    abs_diff = np.abs(flat_b - flat_a)
    max_diff = np.max(abs_diff)
    mean_diff = np.mean(abs_diff)
    min_diff = np.min(abs_diff)
    print(f"\n 總元素數: {flat_a.size}")
    print(f" 最大誤差: {max_diff}")
    print(f" 最小誤差: {min_diff}")
    print(f" 平均誤差: {mean_diff}")
2.4 對齊結果展示
?
?
結果對齊了,表示模型替換成功了。