文章目錄
- 嘗試1:強行設置dropout層train mode為False
- 嘗試2:找到onnx模型中的dropout, train mode設置為False
- 嘗試3:直接刪除dropout層,連接其輸入輸出
- 結語
最近訓練模型使用了tinyvit,性能挺強的:

但是導出onnx時,會提示dropout層的train mode被設置為True了。
UserWarning: ONNX export mode is set to TrainingMode.EVAL, but operator 'dropout' is set to train=True. Exporting with train=True.
這個警告如果只是使用onnxruntime去推理的話,可以不用處理,但是如果使用openvino則會在轉換模型時失敗。因為導出的onnx中出現了Dropout層,一般的推理框架是不支持推理的時候用dropout的。
嘗試1:強行設置dropout層train mode為False
for m in torch_model.modules():if isinstance(m, torch.nn.Dropout):m.training = False
問題依舊
嘗試2:找到onnx模型中的dropout, train mode設置為False
做這個嘗試的本意是先設置為False, 再用onnx-simplify去優化一把,理論上會把dropout層去掉。
# 遍歷模型的所有Dropout節點, 找到所有的training mode節點名稱
training_mode_inputs=[]
for node in model.graph.node:if node.op_type == 'Dropout':# 獲取Dropout節點的training_mode輸入(假設是最后一個輸入)training_mode_input = node.input[-1]# 檢查這個輸入是否指向之前找到的值為True的常量節點training_mode_inputs.append(training_mode_input)# 遍歷所有初始化器
for initializer in model.graph.initializer:# 檢查初始化器是否是我們要找的training_mode輸入if initializer.name in training_mode_inputs:# 假設這個初始化器是一個布爾值,我們將其修改為False# 注意:ONNX中的布爾值是以int64類型存儲的,0表示False,1表示True# initializer.data_type = onnx.TensorProto.INT64initializer.int64_data[:] = [0] # 修改為False
from onnx import helper
new_initializers = []for initializer in model.graph.initializer:if initializer.name in training_mode_inputs:# 創建一個新的TensorProto對象,值為Falsenew_initializer = helper.make_tensor(name=initializer.name, # 保持原來的名稱data_type=onnx.TensorProto.BOOL,dims=initializer.dims, # 保持原來的維度vals=[0] # 設置值為False(在ONNX中用0表示))new_initializers.append(new_initializer)else:new_initializers.append(initializer)# 替換原來的初始化器列表
# Clear existing initializers
model.graph.ClearField('initializer')
# Add the new initializers
model.graph.initializer.extend(new_initializers)
理想很豐滿,現實很骨感···并沒有發生什么變化
嘗試3:直接刪除dropout層,連接其輸入輸出
dropout層在推理的時候也沒什么用,直接刪除,然后連接上原dropout的輸入輸出層就好了
import onnx
from onnx import helper# 加載模型
onnx_model = onnx.load(model_path)
graph = onnx_model.graph# 找到 Dropout 層
nodes_to_remove = [node for node in graph.node if node.op_type == 'Dropout']# 刪除 Dropout 層并重新連接
for node in nodes_to_remove:input_name = node.input[0]output_name = node.output[0]# 找到所有使用 Dropout 輸出作為輸入的節點for next_node in graph.node:for i, input_name in enumerate(next_node.input):if input_name == node.output[0]:next_node.input[i] = node.input[0]# 從圖中移除 Dropout 節點graph.node.remove(node)# 保存修改后的模型
# check if the model is valid
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'tinyvit_11m_sim_replace.onnx')
成功了,模型的dropout層都被刪除了。
結語
雖然嘗試了好幾種方式···不過這些具體的代碼我基本都是問的copilot,不得不說代碼助手減輕了好多工作。