文章目錄
- 一、導出為dynamic shape
- 1)函數講解(函數導出、輸出檢查)
- 2)代碼展示
- 二、導出為static shape
- 1)函數講解(略)
- 2)代碼展示
- 三、序列化為FP32測速
- 1)測速
- 2)代碼
- 四、序列化為FP16測速
- 1)測速
- 2)代碼同上
- 五、發現并解決CLIP FP16溢出,并測速
- 1)如何找到溢出的算子
- 2)CLIP溢出算子解決方案
- 3)其他FP16算子溢出的解決方案
- 六、cuda-graph代碼優化并測速
- 七、圖片迭代次數優化PD、合并GroupNorm算子制作plugin,UNet和ControlNet拼batch測試
- 1)迭代次數優化
- 2)合并GroupNorm算子
- 3)UNet和ControlNet拼batch
- 八、根據smooth-quant算法優化INT8量化,對比測速PD
- 1)smooth-quant算法原理
- 2)smooth-quant算法代碼
- 3)測速PD損失
一、導出為dynamic shape
1)函數講解(函數導出、輸出檢查)
①torch.onnx.export
torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)
(1)export_params:默認為true,表示導出的 ONNX 模型文件會包含模型的所有參數(如權重、偏置等)。而當設置為 False 時,導出的 ONNX 模型文件僅包含模型的計算圖結構,不包含模型的參數。這意味著導出的 ONNX 文件會小很多,因為它沒有存儲大量的參數數據
(2)verbose:為true時,導出過程將會輸出大量打印日誌信息
(3)do_constant_folding:一般為true,是一個布爾類型的參數,其作用是控制在導出 ONNX 模型時是否進行常量折疊優化從而提高推理性能。為TRUE開啟常量折疊優化。在導出 ONNX 模型時,會對圖中所有僅包含常量輸入的操作進行預先計算,并用計算結果替換這些操作,以此簡化計算圖,減少模型的計算量和復雜度。
(4)input_names和output_names:輸入、輸出參數
(5)dynamic_axes:是一個字典,其鍵為輸入或輸出張量的名稱,值也是一個字典,用于指定該張量中哪些維度是動態的。內層字典的鍵是維度索引(從 0 開始),值是一個字符串,用于標識這個動態維度,通常在 ONNX 運行時會使用這個標識來指定具體的維度大小
(6)opset_version:指定opset算子集的版本。dynamic_axes輸入參數舉例:dynamic_axes = {"x": {0: "batch_size"},"hint": {0: "batch_size"},"timesteps": {0: "batch_size"},"context": {0: "batch_size", 1: "sequence_length"},"output": {0: "batch_size", 1: "hint_height", 2: "hint_width"}}dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}dynamic_axes = {"x": {0: "latent"},}
②誤差檢查
# onnx_path      path of the exported ONNX model file
# input_dicts    feed dict mapping input names to numpy arrays
# torch_outputs  reference output tensors produced by the PyTorch model
def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    """Run an exported ONNX model and compare it with PyTorch outputs.

    Loads the model at `onnx_path`, executes it with ONNX Runtime on
    `input_dicts`, and checks every output against the corresponding
    tensor in `torch_outputs` via np.allclose (rtol=1e-03, atol=1e-05).
    A mismatch only prints an error; nothing is raised, so this is a
    best-effort sanity check rather than a hard assertion.
    """
    onnx.load(onnx_path)  # fails early if the file is not a readable model
    sess = rt.InferenceSession(onnx_path)
    # None -> fetch all outputs, in the model's declared order.
    results = sess.run(None, input_dicts)
    for onnx_out, torch_out in zip(results, torch_outputs):
        if not np.allclose(onnx_out, torch_out.detach().numpy(),
                           rtol=1e-03, atol=1e-05, equal_nan=False):
            print("Error onnxruntime_check")
2)代碼展示
- 代碼
import numpy as np
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3
import cv2
import datetime
from share import *
import config
import einops
import gradio as gr
import torch
import random
import os
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from onnx import shape_inference
import onnx_graphsurgeon as gs
import onnx
import onnxruntime as rt


def optimize(onnx_path, opt_onnx_path):
    """Simplify the ONNX graph at `onnx_path` and write it to `opt_onnx_path`.

    Uses onnx-simplifier; the result is saved with external data because
    the SD/ControlNet weights can exceed the 2 GB protobuf limit.
    """
    from onnxsim import simplify
    model = onnx.load(onnx_path)
    print(f"{onnx_path} simplify start !")
    model_simp, check = simplify(model)
    # Bug fix: validate BEFORE saving, so an invalid simplified model is
    # never written to disk (the original asserted after onnx.save).
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_simp, opt_onnx_path, save_as_external_data=True)
    print(f"{onnx_path} simplify done !")


def onnxruntime_check(onnx_path, input_dicts, torch_outputs):
    """Compare ONNX Runtime outputs against reference PyTorch outputs.

    Prints an error line for every mismatching output (np.allclose with
    rtol=1e-03, atol=1e-05); it never raises, so callers can treat it as
    a best-effort sanity check.
    """
    onnx.load(onnx_path)  # fails early if the file is not a readable model
    sess = rt.InferenceSession(onnx_path)
    # None -> fetch all outputs, in the model's declared order.
    results = sess.run(None, input_dicts)
    for onnx_out, torch_out in zip(results, torch_outputs):
        if not np.allclose(onnx_out, torch_out.detach().numpy(),
                           rtol=1e-03, atol=1e-05, equal_nan=False):
            print("Error onnxruntime_check")


class hackathon():
    """Holds the ControlNet pipeline pieces used by the export helpers."""

    def initialize(self):
        # Canny edge detector used to build the control hint image.
        self.apply_canny = CannyDetector()
        # Build and load the model on CPU so export does not require a GPU.
        self.model = create_model('./models/cldm_v15.yaml').cpu()
        self.model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
        self.model = self.model.cpu()
        self.model.eval()
        self.ddim_sampler = DDIMSampler(self.model)


hk = hackathon()
hk.initialize()def export_clip_model():clip_model = hk.model.cond_stage_modelimport typesdef forward(self, tokens):outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")if self.layer == "last":z = outputs.last_hidden_stateelif self.layer == "pooled":z = outputs.pooler_output[:, None, :]else:z = outputs.hidden_states[self.layer_idx]return zclip_model.forward = types.MethodType(forward, clip_model)onnx_path = "./onnx/CLIP.onnx"tokens = torch.zeros(1, 77, dtype=torch.int32)input_names = ["input_ids"]output_names = ["last_hidden_state"]dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)print("======================= CLIP model export onnx done!")# verify onnx modeloutput = clip_model(tokens)input_dicts = {"input_ids": tokens.numpy()}onnxruntime_check(onnx_path, input_dicts, [output])print("======================= CLIP onnx model verify done!")# opt_onnx_path = "./onnx/CLIP.opt.onnx"# optimize(onnx_path, opt_onnx_path)def export_control_net_model():control_net_model = hk.model.control_modelonnx_path = "./onnx/control_net_model.onnx"def get_shape(B=1,S=64):return [(B, 4, 32, 48),(B, 3, 256, 384),tuple([B])