TVM:使用 Auto-scheduling 來優化算子

TVM:使用 Auto-scheduling 來優化算子

在本教程中,我們將展示 TVM 的 Auto-scheduling 功能如何在無需編寫自定義模板的情況下找到最佳 schedule。

與基于模板的 AutoTVM 依賴手動模板定義搜索空間不同,auto-scheduler 不需要任何模板。 用戶只需編寫計算聲明,無需任何調度命令或模板。 auto-scheduler 可以自動生成一個大的搜索空間,并在該空間中找到一個好的 schedule。

我們在本教程中同樣使用矩陣乘法作為示例。

import osimport numpy as np
import tvm
from tvm import te, auto_scheduler

定義矩陣乘法

首先,我們定義一個帶有偏置的矩陣乘法。 請注意,這使用了 TVM 張量表達式語言中可用的標準操作。 主要區別在于在函數定義的開始使用了 auto_sceduler 裝飾器。 該函數應返回輸入/輸出張量列表。 從這些張量中,自動調度器可以獲得整個計算圖。

@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):A = te.placeholder((N, L), name="A", dtype=dtype)B = te.placeholder((L, M), name="B", dtype=dtype)C = te.placeholder((N, M), name="C", dtype=dtype)k = te.reduce_axis((0, L), name="k")matmul = te.compute((N, M),lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),name="matmul",attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B)out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")return [A, B, C, out]

創建搜索任務

定義函數后,我們現在可以創建供 auto_scheduler 搜索的任務。 我們指定此矩陣乘法的特定參數,在本例中為 1024x1024 大小的方陣的乘法。 然后我們創建一個搜索任務,其中 N=L=M=1024 ,數據類型為 ”float32”。

target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

注意:自定義 target 可以提高性能

為了讓 TVM 充分利用特定硬件平臺,您需要手動指定 CPU 功能。 例如: - 將下面的“llvm”替換為“llvm -mcpu=core-avx2”以啟用 AVX2 - 將下面的“llvm”替換為“llvm -mcpu=skylake-avx512”以啟用 AVX-512

此處輸出:

Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])

為 Auto-Scheduler 設置參數

接下來,我們為自動調度程序設置參數。

  • num_measure_trials 是我們在搜索過程中可以使用的測量試驗次數。 為了快速演示,我們在本教程中僅進行了 10 次試驗。 在實踐中,1000 是一個很好的搜索收斂值。 您可以根據您的時間預算進行更多試驗。

  • 此外,我們使用 RecordToFile 將測量記錄記錄到文件 matmul.json 中。 測量記錄可用于最佳查詢歷史記錄、恢復搜索以及稍后進行更多分析。

  • 有關更多參數,請參閱 auto_scheduler.TuningOptions

log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(num_measure_trials=10,measure_callbacks=[auto_scheduler.RecordToFile(log_file)],verbose=2,
)

運行搜索

現在我們準備好所有輸入。 很簡單,不是嗎? 我們可以開始搜索并讓自動調度程序發揮它的魔力。 經過一些測量試驗后,我們可以從日志文件中加載最佳計劃并應用它。

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

檢查優化過的 Schedule

我們可以在 auto-scheduling 后降低(lower)schedule 以查看 IR。 auto-schduling 程序正確執行優化,包括多級平鋪、布局轉換、并行化、矢量化、展開和算子融合。

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

此處輸出:

Lowered TIR:
primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}buffers = {out: Buffer(out_2: Pointer(float32), float32, [1024, 1024], []),A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out} {allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global {for (ax0.ax1.fused.ax2.fused: int32, 0, 128) "parallel" {for (ax4: int32, 0, 256) {for (ax6: int32, 0, 4) {for (ax7: int32, 0, 8) {auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*8192) + (ax4*32)) + (ax6*8)) + ax7)] = (float32*)B_2[((((ax4*4096) + (ax6*1024)) + (ax0.ax1.fused.ax2.fused*8)) + ax7)]}}}}for (i.outer.outer.j.outer.outer.fused: int32, 0, 16384) "parallel" {allocate(matmul: Pointer(global float32x8), float32x8, [4]), storage_scope = global;for (i.outer.inner: int32, 0, 2) {matmul[ramp(0, 1, 8)] = broadcast(0f32, 8)matmul[ramp(8, 1, 8)] = broadcast(0f32, 8)matmul[ramp(16, 1, 8)] = broadcast(0f32, 8)matmul[ramp(24, 1, 8)] = broadcast(0f32, 8)for (k.outer: int32, 0, 256) {for (k.inner: int32, 0, 4) {matmul[ramp(0, 1, 8)] = ((float32x8*)matmul[ramp(0, 1, 8)] + (broadcast((float32*)A_2[((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(8, 1, 8)] = ((float32x8*)matmul[ramp(8, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 1024)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(16, 1, 8)] = ((float32x8*)matmul[ramp(16, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 2048)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(24, 1, 8)] = ((float32x8*)matmul[ramp(24, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 3072)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))}}for (i.inner: int32, 0, 4) {out_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)] = ((float32x8*)matmul[ramp((i.inner*8), 1, 8)] + (float32x8*)C_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)])}}}}
}

檢查正確性并評估性能

我們構建二進制文件并檢查其正確性和性能。

func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_npdev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print("Execution time of this operator: %.3f ms"% (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)

此處輸出:

Execution time of this operator: 45.418 ms

使用記錄文件

在搜索過程中,所有的測量記錄都被記錄到記錄文件“matmul.json”中。 測量記錄可用于重新應用搜索結果、恢復搜索和執行其他分析。

這是一個示例,我們從文件加載最佳 schedule,并打印等效的 Python schedule API。 這可用于調試和學習 auto-scheduling 程序的行為。

print("Equivalent python schedule:")
print(task.print_best(log_file))

此處輸出:

Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)

一個更復雜的例子是恢復搜索。 在這種情況下,我們需要自己創建搜索策略和成本模型,并通過日志文件恢復搜索策略和成本模型的狀態。 在下面的示例中,我們恢復狀態并再進行 5 次試驗。

def resume_search(task, log_file):print("Resume search:")cost_model = auto_scheduler.XGBModel()cost_model.update_from_file(log_file)search_policy = auto_scheduler.SketchPolicy(task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)])tune_option = auto_scheduler.TuningOptions(num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])task.tune(tune_option, search_policy=search_policy)resume_search(task, log_file)

此處輸出:

Resume search:
/usr/local/lib/python3.6/dist-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.htmlwarnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)

總結

在本教程中,我們展示了如何使用 TVM Auto-Scheduler 自動優化矩陣乘法,而無需指定搜索模板。 它結束了一系列從張量表達式 (TE) 語言開始的示例,這些示例演示了 TVM 如何優化計算操作。

Ref:

https://tvm.apache.org/docs/tutorial/auto_scheduler_matmul_x86.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532663.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532663.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532663.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C語言—sort函數比較大小的快捷使用--algorithm頭文件下

sort函數 一般情況下要將一組數從的大到小排序或從小到大排序&#xff0c;要定義一個新的函數排序。 而我們也可以直接使用在函數下的sort函數&#xff0c;只需加上頭文件&#xff1a; #include<algorithm> using namespace std;sort格式&#xff1a;sort(首元素地址&…

散列的使用

散列 散列簡單來說&#xff1a;給N個正整數和M個負整數&#xff0c;問這M個數中的每個數是否在N中出現過。 比如&#xff1a;N&#xff1a;{1,2,3,4}&#xff0c;M{2,5,7}&#xff0c;其中M的2在N中出現過 對這個問題最直觀的思路是&#xff1a;對M中每個欲查的值x&#xff0…

關于C++中的unordered_map和unordered_set不能直接以pair作為鍵名的問題

關于C中的unordered_map和unordered_set不能直接以pair作為鍵名的問題 在 C STL 中&#xff0c;不同于有序的 std::map 和 std::set 是基于紅黑樹實現的&#xff0c;std::unordered_map 和 std::unordered_set 是基于哈希實現的&#xff0c;在不要求容器內的鍵有序&#xff0c…

AI編譯器與傳統編譯器的聯系與區別

AI編譯器與傳統編譯器的區別與聯系 總結整理自知乎問題 針對神經網絡的編譯器和傳統編譯器的區別和聯系是什么&#xff1f;。 文中提到的答主的知乎主頁&#xff1a;金雪鋒、楊軍、藍色、SunnyCase、貝殼與知了、工藤福爾摩 筆者本人理解 為了不用直接手寫機器碼&#xff0…

python學習1:注釋\變量類型\轉換函數\轉義字符\運算符

python基礎學習 與大多數語言不同&#xff0c;python最具特色的就是使用縮進來表示代碼塊&#xff0c;不需要使用大括號 {} 。縮進的空格數是可變的&#xff0c;但是同一個代碼塊的語句必須包含相同的縮進空格數。 &#xff08;一個tab4個空格&#xff09; Python語言中常見的…

Python、C++ lambda 表達式

Python、C lambda 表達式 lambda函數簡介 匿名函數lambda&#xff1a;是指一類無需定義標識符&#xff08;函數名&#xff09;的函數或子程序。所謂匿名函數&#xff0c;通俗地說就是沒有名字的函數&#xff0c;lambda函數沒有名字&#xff0c;是一種簡單的、在同一行中定義函…

python 學習2 /輸入/ 輸出 /列表 /字典

python基礎學習第二天 輸入輸出 xinput("輸入內容") print(x)input輸出&#xff1a; eval :去掉字符串外圍的引號&#xff0c;按照python的語法執行內容 aeval(12) print(a)eval輸出樣式&#xff1a; 列表 建立&#xff0c;添加&#xff0c;插入&#xff0c;刪去…

Linux、Mac 命令行快捷鍵

Linux、Mac 命令行快捷鍵 Linux 命令行編輯快捷鍵&#xff0c;參考了好多個&#xff0c;應該算是比較全的了&#xff0c;Linux 和 Mac 的都有&#xff0c;筆者本人比較常用的也已經紅色標出來了&#xff0c;如有錯誤或遺漏&#xff0c;歡迎留言指出。 光標移動及編輯&#xff…

Python 命令行傳參

Python 命令行傳參 說到 python 命令行傳參&#xff0c;可能大部分人的第一反應就是用 argparse。的確&#xff0c;argparse 在我們需要指定多個預設的參數&#xff08;如深度學習中指定模型的超參數等&#xff09;時&#xff0c;是非常有用的。但是如果有時我們只需要一個參數…

快速排序 C++

快速排序 C 本文圖示借鑒自清華大學鄧俊輝老師數據結構課程。 快速排序的思想 快速排序是分治思想的典型應用。該排序算法可以原地實現&#xff0c;即空間復雜度為 O(1)O(1)O(1)&#xff0c;而時間復雜度為 O(nlogn)O(nlogn)O(nlogn) 。 算法將待排序的序列 SSS 分為兩個子…

Linux命令行下感嘆號的幾個用法

Linux命令行下 " ! " 的幾個用法 ! 在大多數編程語言中表示取反的意思&#xff0c;但是在命令行中&#xff0c;他還有一些其他的神奇用法。熟練掌握這些用法&#xff0c;可以大大提高我們日常命令行操作的效率。 1 執行歷史命令 !! ! 在命令行中可以用來執行歷史…

三地址碼簡介

三地址碼簡介 三地址碼&#xff08;Three Address Code&#xff09;是一種最常用的中間語言&#xff0c;編譯器可以通過它來改進代碼轉換效率。每個三地址碼指令&#xff0c;都可以被分解為一個四元組&#xff08;4-tuple&#xff09;的形式&#xff1a;&#xff08;運算符&am…

llvm與gcc

llvm與gcc llvm 是一個編譯器&#xff0c;也是一個編譯器架構&#xff0c;是一系列編譯工具&#xff0c;也是一個編譯器工具鏈&#xff0c;開源 C11 實現。 gcc 相對于 clang 的優勢&#xff1a; gcc 支持更過語言前端&#xff0c;如 Java, Ada, FORTRAN, Go等gcc 支持更多地 …

攻防世界web新手區解題 view_source / robots / backup

1**. view_source** 題目描述&#xff1a;X老師讓小寧同學查看一個網頁的源代碼&#xff0c;但小寧同學發現鼠標右鍵好像不管用了。 f12查看源碼即可發現flag 2. robots 題目描述&#xff1a;X老師上課講了Robots協議&#xff0c;小寧同學卻上課打了瞌睡&#xff0c;趕緊來教教…

python參數傳遞*args和**kwargs

python參數傳遞*args和**kwargs 和* 實際上真正的Python參數傳遞語法是 * 和 ** 。*args 和 **kwargs 只是一種約定俗成的編程實踐。我們也可以寫成 *vars 和 **kvars 。就如同其他常規變量的命名一樣&#xff0c; args 和 kwargs 只是一種習慣的名稱。 *args 和 **kwargs 一…

聽GPT 講Rust源代碼--src/tools(25)

File: rust/src/tools/clippy/clippy_lints/src/methods/suspicious_command_arg_space.rs 在Rust源代碼中&#xff0c;suspicious_command_arg_space.rs文件位于clippy_lints工具包的methods目錄下&#xff0c;用于實現Clippy lint SUSPICIOUS_COMMAND_ARG_SPACE。 Clippy是Ru…

Java一次編譯,到處運行是如何實現的

Java一次編譯&#xff0c;到處運行是如何實現的 轉自&#xff1a;https://cloud.tencent.com/developer/article/1415194 &#xff08;排版微調&#xff09; JAVA編譯運行總覽 Java是一種高級語言&#xff0c;要讓計算機執行你撰寫的Java程序&#xff0c;也得通過編譯程序的…

JIT(動態編譯)和AOT(靜態編譯)編譯技術比較

JIT&#xff08;動態編譯&#xff09;和AOT&#xff08;靜態編譯&#xff09;編譯技術比較 轉自&#xff1a;https://www.cnblogs.com/tinytiny/p/3200448.html Java 應用程序的性能經常成為開發社區中的討論熱點。因為該語言的設計初衷是使用解釋的方式支持應用程序的可移植…

python解釋器

python解釋器 計算機編程語言 本部分參考自&#xff1a;https://zhuanlan.zhihu.com/p/141212114 從計算機編程語言說起&#xff0c;它主要分為三類&#xff1a;機器語言、匯編語言、高級語言。 機器語言是一種計算機可以直接識別并執行的二進制指令集。由于其可以直接交給…

編譯型語言與解釋型語言

編譯型語言與解釋型語言 首先要說明&#xff0c;編譯型語言與解釋型語言這種分類方法是不科學的&#xff0c;或者說已經過時了&#xff0c;但是這種稱呼大抵還是能夠讓人明白我們將要討論的是什么東西。 文中所列參考是筆者認為比較有幫助的一些擴展閱讀內容。 首先貼一個很形…