stable diffusion 量化加速點

文章目錄

    • 一、導出為dynamic shape
      • 1)函數講解(函數導出、輸出檢查)
      • 2)代碼展示
    • 二、導出為static shape
      • 1)函數講解(略)
      • 2)代碼展示
    • 三、序列化為FP32測速
      • 1)測速
      • 2)代碼
    • 四、序列化為FP16測速
      • 1)測速
      • 2)代碼同上
    • 五、發現并解決解決CLIP FP16溢出,并測速
      • 1)如何找到溢出的算子
      • 2)CLIP溢出算子解決方案
      • 3)其他FP16算子溢出的解決方案
    • 六、cuda-graph代碼優化并測速
    • 七、圖片迭代次數優化PD、合并GroupNorm算子制作plugin,UNet和ControlNet拼batch測試
      • 1)迭代次數優化
      • 2)合并GroupNorm算子
      • 3)UNet和ControlNet拼batch
    • 八、根據smooth-quant算法優化INT8量化,對比測速PD
      • 1)smooth-quant算法原理
      • 2)smooth-quant算法代碼
      • 3)測速PD損失

一、導出為dynamic shape

1)函數講解(函數導出、輸出檢查)

①torch.onnx.export

    torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)
(1)export_params:默認為true,表示導出的 ONNX 模型文件會包含模型的所有參數(如權重、偏置等)。而當設置為 False 時,導出的 ONNX 模型文件僅包含模型的計算圖結構,不包含模型的參數。這意味著導出的 ONNX 文件會小很多,因為它沒有存儲大量的參數數據
(2)verbose:為true表示,將會輸出大量打印日志信息
(3)do_constant_folding:一般為true,是一個布爾類型的參數,其作用是控制在導出 ONNX 模型時是否進行常量折疊優化從而提高推理性能。為TRUE開啟常量折疊優化。在導出 ONNX 模型時,會對圖中所有僅包含常量輸入的操作進行預先計算,并用計算結果替換這些操作,以此簡化計算圖,減少模型的計算量和復雜度。
(4)input_names和output_names:輸入、輸出參數
(5)dynamic_axes:是一個字典,其鍵為輸入或輸出張量的名稱,值也是一個字典,用于指定該張量中哪些維度是動態的。內層字典的鍵是維度索引(從 0 開始),值是一個字符串,用于標識這個動態維度,通常在 ONNX 運行時會使用這個標識來指定具體的維度大小
(6)opset_version:指定optset的版本輸入參數舉例:dynamic_axes = {"x": {0: "batch_size"},"hint": {0: "batch_size"},"timesteps": {0: "batch_size"},"context": {0: "batch_size", 1: "sequence_length"},"output": {0: "batch_size", 1: "hint_height", 2: "hint_width"}}dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}dynamic_axes = {"x": {0: "latent"},}

②誤差檢查

#onnx_path onnx文件目錄
#input_dicts  輸入參數
#torch_outputs  模型輸出結果
def onnxruntime_check(onnx_path, input_dicts, torch_outputs):onnx_model = onnx.load(onnx_path)# onnx.checker.check_model(onnx_model)sess = rt.InferenceSession(onnx_path)# outputs = self.get_output_names()# latent input# data = np.zeros((4, 77), dtype=np.int32)result = sess.run(None, input_dicts)cnt = 0for i in range(0, len(torch_outputs)):ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)cnt = cnt +1if ret is False:#print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")print("Error onnxruntime_check")# import pdb; pdb.set_trace()#print("cnt:", cnt)

2)代碼展示

  • 代碼
import numpy as np
from pytorch_fid import fid_score
from pytorch_fid.inception import InceptionV3
import cv2
import datetime
from share import *
import configimport cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import osfrom pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from onnx import shape_inference
import onnx_graphsurgeon as gs
import onnx
import onnxruntime as rtdef optimize(onnx_path, opt_onnx_path):from onnxsim import simplifymodel = onnx.load(onnx_path)graph = gs.import_onnx(model)print(f"{onnx_path} simplify start !")# self.info("init", graph)model_simp, check = simplify(model)# self.info("opt", gs.import_onnx(model_simp))onnx.save(model_simp, opt_onnx_path, save_as_external_data=True)assert check, "Simplified ONNX model could not be validated"print(f"{onnx_path} simplify done !")def onnxruntime_check(onnx_path, input_dicts, torch_outputs):onnx_model = onnx.load(onnx_path)# onnx.checker.check_model(onnx_model)sess = rt.InferenceSession(onnx_path)# outputs = self.get_output_names()# latent input# data = np.zeros((4, 77), dtype=np.int32)result = sess.run(None, input_dicts)cnt = 0for i in range(0, len(torch_outputs)):ret = np.allclose(result[i], torch_outputs[i].detach().numpy(), rtol=1e-03, atol=1e-05, equal_nan=False)cnt = cnt +1if ret is False:#print(f"onnxruntime_check {i} ret:{ret}  result[i]:{result[i]}  torch_outputs[i]:{torch_outputs[i].detach().numpy()} ")print("Error onnxruntime_check")# import pdb; pdb.set_trace()#print("cnt:", cnt)class hackathon():def initialize(self):self.apply_canny = CannyDetector()self.model = create_model('./models/cldm_v15.yaml').cpu()self.model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))# self.model.load_state_dict(load_state_dict('/home/player/ControlNet/models/control_sd15_canny.pth', location='cuda'))self.model = self.model.cpu()self.model.eval()self.ddim_sampler = DDIMSampler(self.model)hk = hackathon()
hk.initialize()def export_clip_model():clip_model = hk.model.cond_stage_modelimport typesdef forward(self, tokens):outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")if self.layer == "last":z = outputs.last_hidden_stateelif self.layer == "pooled":z = outputs.pooler_output[:, None, :]else:z = outputs.hidden_states[self.layer_idx]return zclip_model.forward = types.MethodType(forward, clip_model)onnx_path = "./onnx/CLIP.onnx"tokens = torch.zeros(1, 77, dtype=torch.int32)input_names = ["input_ids"]output_names = ["last_hidden_state"]dynamic_axes = {"input_ids": {1: "S"}, "last_hidden_state": {1: "S"}}torch.onnx.export(clip_model,(tokens),onnx_path,verbose=True,opset_version=18,do_constant_folding=True,input_names=input_names,output_names=output_names,dynamic_axes=dynamic_axes,)print("======================= CLIP model export onnx done!")# verify onnx modeloutput = clip_model(tokens)input_dicts = {"input_ids": tokens.numpy()}onnxruntime_check(onnx_path, input_dicts, [output])print("======================= CLIP onnx model verify done!")# opt_onnx_path = "./onnx/CLIP.opt.onnx"# optimize(onnx_path, opt_onnx_path)def export_control_net_model():control_net_model = hk.model.control_modelonnx_path = "./onnx/control_net_model.onnx"def get_shape(B=1,S=64):return [(B, 4, 32, 48),(B, 3, 256, 384),tuple([B])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75349.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75349.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75349.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

7-openwrt-one通過web頁面配置訪客網絡、無線中繼等功能

前幾個章節一直在介紹編譯、分區之類的,都還沒正常開始使用這個路由器的wifi。默認wifi是沒有啟動的,前面還是通過手動修改uci配置啟動的,這個章節介紹下官方web頁面的使用。特別是訪客網絡、無線中繼 1、開啟wifi,配置wifi基本信息 我們使用有線連接路由器,通過192.168.…

AcWing 6099. 座位

原題目鏈接 問題描述 有 n 頭奶牛(n ≥ 5),編號為 1 ~ n,按照某種順序圍著一張圓桌坐成一圈。 奶牛之間存在如下的朋友關系: 如果兩頭奶牛相鄰,則它們是朋友;如果兩頭奶牛之間只隔著一頭奶…

44、Spring Boot 詳細講義(一)

Spring Boot 詳細講義 目錄 Spring Boot 簡介Spring Boot 快速入門Spring Boot 核心功能Spring Boot 技術棧與集成Spring Boot 高級主題Spring Boot 項目實戰Spring Boot 最佳實踐總結 一、Spring Boot 簡介 1. Spring Boot 概念和核心特點 1.1、什么是 Spring Boot&#…

配置mac mini M4 的一些軟件

最近更換了 mac mini M4 ,想要重新下載配置軟件 ,記錄一下。 Homebrew是什么? homebrew是一款Mac OS平臺下的軟件包管理工具,擁有安裝、卸載、更新、查看、搜索等功能。通過簡單的指令可以實現包管理,而不用關心各種…

網絡空間安全(54)CSRF

一、定義與原理 CSRF(Cross-Site Request Forgery),全稱為跨站請求偽造,也被稱為One Click Attack或Session Riding,縮寫為CSRF或XSRF。它是一種網絡安全漏洞,攻擊者通過偽造用戶的請求,利用用戶…

分布式文件存儲系統FastDFS

文章目錄 1 分布式文件存儲1_分布式文件存儲的由來2_常見的分布式存儲框架 2 FastDFS介紹3 FastDFS安裝1_拉取鏡像文件2_構建Tracker服務3_構建Storage服務4_測試圖片上傳 4 客戶端操作1_Fastdfs-java-client2_文件上傳3_文件下載4_獲取文件信息5_問題 5 SpringBoot整合 1 分布…

安裝了VM Tools,仍無法復制拖動-解決方案

今天在安裝ubuntu時遇到了困擾許久的問題,安裝了VM Tools,仍無法拖動主機文件到虛擬機,主要有兩種原因并對應解決辦法。 1.相關虛擬機設置選項卡中-客戶機隔離-兩個功能沒有勾選 解決方案:勾選重啟虛擬機即可 2.(這個…

Jmeter分布式測試啟動

代理客戶端配置 打開jmeter.properties文件,取消注釋并設置端口(如server_port1099), 并添加server.rmi.ssl.disabletrue禁用SSL加密。 (Linux系統)修改jmeter-server文件中的RMI_HOST_DEF為代理機實際IP。…

火語言RPA--Oracle-導入數據表格

【組件功能】:導入特定的表格數據到包含同樣字段的數據表 將表格對象數據通過數據庫操作對象導入到指定數據庫。 配置預覽 配置說明 源表格 表格來源有“來自表格對象”和“來自表達式”2種,表達式支持DataTable類型變量。 對象 對應來自表格對象&…

Java的Selenium的特殊元素操作與定位之驗證碼

1.使用OCR技術識別驗證 步驟: 截取整個網頁的截圖。 定位驗證碼圖片元素。 根據驗證碼圖片的位置和大小,從截圖中裁剪出驗證碼圖片。 使用OCR工具(如Tesseract)識別驗證碼圖片中的文本。 2.手動處理驗證碼 步驟:…

OpenStack Yoga版安裝筆記(十七)安全組筆記

一、安全組與iptables的關系 OpenStack的安全組(Security Group)默認是通過Linux的iptables實現的。以下是其主要實現原理和機制: 安全組與iptables的關系 OpenStack的安全組規則通過iptables的規則鏈實現。每條安全組規則會被轉換為相應的i…

starrocks split函數和trino split函數差異性

在trino419和starrocks3.2.8中分別執行下面這兩條sql,出來的結果是不一樣的 select split(,,,)[1] as t1 select coalesce(split(,,&#

Spring Data JPA中的List底層:深入解析ArrayList的奧秘!!!

&#x1f31f; Spring Data JPA中的List底層&#xff1a;深入解析ArrayList的奧秘 &#x1f4a1; 你是否好奇過&#xff0c;為什么Spring Data JPA的查詢方法返回的List<T>總是默認為ArrayList&#xff1f;本文將通過技術原理解析、驗證實驗和性能優化指南&#xff0c;為…

騰訊云智測試開發面經

1、投遞時間線 2.20投遞簡歷,3.11第一輪面試,3.30第二輪面試,4.4第三輪面試,4.10第四輪面試,4.11offer意向書 2、第一輪面試 第一輪面試技術面,面試官是導師,面試時長40多分鐘 1)自我介紹 2)數組和列表的區別 3)了解哪些數據庫 4)進程和線程的區別 5)了解哪…

【深度學習】【目標檢測】【Ultralytics-YOLO系列】YOLOV3源碼整體結構解析

【深度學習】【目標檢測】【Ultralytics-YOLO系列】YOLOV3源碼整體結構解析 文章目錄 【深度學習】【目標檢測】【Ultralytics-YOLO系列】YOLOV3源碼整體結構解析前言代碼結構整體data文件結構模型訓練超參數配置文件解析數據集配置文件解析 models文件結構utils文件結構runs文…

Python常用排序算法

1. 冒泡排序 冒泡排序是一種簡單的排序算法&#xff0c;它重復地遍歷要排序的列表&#xff0c;比較相鄰的元素&#xff0c;如果他們的順序錯誤就交換他們。 def bubble_sort(arr):# 遍歷所有數組元素for i in range(len(arr)):# 最后i個元素是已經排序好的for j in range(0, …

解鎖塔能科技,開啟工廠綠色轉型與可持續發展雙引擎

在全球積極推進可持續發展的大背景下&#xff0c;能源的高效利用與節能減排&#xff0c;已成為各行各業邁向高質量發展進程中無法回避的核心任務。工廠作為能源消耗大戶與污染排放重點源頭&#xff0c;其綠色轉型迫在眉睫&#xff0c;這不僅關乎企業自身的長遠發展&#xff0c;…

Spring Boot 線程池配置詳解

Spring Boot 線程池配置詳解 一、核心配置參數及作用 基礎參數核心線程數 (corePoolSize)? 作用?:線程池中始終保持存活的線程數量,即使空閑也不回收?。 建議?:根據任務類型設定(如 I/O 密集型任務可設為 CPU 核心數 2)?。 最大線程數 (maxPoolSize)? 作用?:…

入侵檢測系統(IDS)和入侵防御系統(IPS)有啥區別?

入侵檢測系統&#xff08;IDS&#xff09;和入侵防御系統&#xff08;IPS&#xff09;是網絡安全中的兩種關鍵技術&#xff0c;它們的核心區別在于 檢測后的響應方式 和 部署位置。以下是詳細對比&#xff1a; 1. 核心功能 - IDS&#xff08;入侵檢測系統&#xff09; - 僅監…

【MySQL 數據庫】數據表的操作

&#x1f525;博客主頁&#x1f525;&#xff1a;【 坊鈺_CSDN博客 】 歡迎各位點贊&#x1f44d;評論?收藏? 目錄 1. 表的查看 1.1 語法 2. 表的創建 2.1 語法 2.2 練習 3. 查看表結構 3.1 語法 3.2 示例 4. 表的修改 4.1 語法 4.2 示例操作 4.2.1 向表中添加字段…