為pytorch前向和反向的Tensor生成描述性統計

為pytorch前向和反向的Tensor生成描述性統計

  • 代碼

在調試Megatron-DeepSpeed的精度時,我們希望對比每一層前向和反向傳播的輸入輸出誤差。然而,由于數據量過大,直接保存所有數據不太現實。因此,我們生成了輸入輸出tensor的描述性統計信息,并等間隔抽樣N個數據點,以比較這些點的相對誤差,從而查找精度異常的位置。為了準確定位,我們通過類名和對象ID生成唯一的對象名稱(形式為[類名-創建的第幾個])以及前向和反向傳播的次數。通過保存上述信息,我們可以詳細記錄并回溯當時的實際輸入輸出數據。

代碼

cat > linear_test.py <<-'EOF'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from datetime import datetime# 設置設備
device = "cpu"if torch.cuda.is_available():device = "cuda:4"def is_tensor(val):# 判斷是否為tensor或Parameterreturn isinstance(val, (torch.Tensor, nn.Parameter))def describe_tensor(tensor):# 返回tensor的描述,包括形狀和部分數據統計信息shape = list(tensor.shape)tensor_data = tensor.cpu().float().detach().numpy().ravel()num_points = min(16, len(tensor_data))indices = np.linspace(0, len(tensor_data) - 1, num_points, dtype=int)stats = [np.max(tensor_data), np.min(tensor_data), np.mean(tensor_data), np.std(tensor_data)]sample_data = tensor_data[indices]stats_str = ",".join(f"{x:.5f}" for x in stats)sample_str = ",".join(f"{x:.5f}" for x in sample_data)return f"{shape}-{stats_str},{sample_str}"def generate_random_data(shape):# 生成符合指定形狀的隨機數據max_val, min_val, mean, std = 0.04025, -0.04651, 0.0, 0.00134data = np.random.normal(mean, std, shape)data = (data - data.min()) / (data.max() - data.min()) * (max_val - min_val) + min_valreturn dataindex_counter = 0def log_tensor_data(name, tensor):# 打印tensor的日志數據global index_counterindex_counter += 1timestamp = datetime.now().strftime("%H%M%S%f")if is_tensor(tensor):print(f"{timestamp},{index_counter},{name},0,{describe_tensor(tensor)}")elif isinstance(tensor, (tuple, list)):for idx, t in enumerate(tensor):if is_tensor(t):print(f"{timestamp},{index_counter},{name},{idx},{describe_tensor(t)}")def log_gradient(model):# 打印模型參數梯度信息for name, param in model.named_parameters():if param.grad is not None:log_tensor_data(f"grad-{name}", param.grad)# 對象和類名緩存
object_cache = {}
class_name_count = {}def get_unique_name(class_name, obj_id):# 生成唯一的對象名稱if class_name not in class_name_count:class_name_count[class_name] = 0uid = f"{class_name}_{obj_id}"if uid not in object_cache:class_name_count[class_name] += 1object_cache[uid] = {"idx": class_name_count[class_name]}return f'{class_name}-{object_cache[uid]["idx"]}'def initialize_module_attributes(module):# 初始化模塊屬性if not hasattr(module, 'uuid'):module.uuid = get_unique_name(module.__class__.__name__, id(module))if not hasattr(module, 'backward_step'):module.backward_step = 0if not hasattr(module, 'forward_step'):module.forward_step = 0def forward_decorator():# 包裝forward函數的修飾器def decorator(func):def wrapped(*args, **kwargs):module = args[0]initialize_module_attributes(module)module.forward_step += 1log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-input", args)output = func(*args, **kwargs)log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-output", output)return outputreturn wrappedreturn decoratordef pre_backward_hook(module, grad_input):# 反向傳播前的鉤子函數initialize_module_attributes(module)module.backward_step += 1log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-input", grad_input)def post_backward_hook(module, grad_input, grad_output):# 反向傳播后的鉤子函數initialize_module_attributes(module)log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-output", grad_output)def register_backward_hooks(module):# 注冊反向傳播鉤子module.register_full_backward_pre_hook(pre_backward_hook)module.register_full_backward_hook(post_backward_hook)class CustomLinear(nn.Module):def __init__(self, shape):super(CustomLinear, self).__init__()weight_data = torch.from_numpy(generate_random_data(shape)).half().to(device)self.weight = nn.Parameter(weight_data)self.register_parameter('bias', None)register_backward_hooks(self)@forward_decorator()def forward(self, input_):return F.linear(input_, self.weight, self.bias)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer1 = CustomLinear((5504, 4096))self.layer2 = CustomLinear((4096, 5504))@forward_decorator()def forward(self, input_):out = self.layer1(input_)out = self.layer2(out)return out
# 設置隨機種子
np.random.seed(1)
torch.manual_seed(2)# 創建和訓練模型
model = MyModel().half().to(device)
model.train()input_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)
target_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)for _ in range(2):outputs = model(input_data)outputs.backward(target_data)  # 使用全一的梯度來反向傳播log_gradient(model)
EOF
python3 linear_test.py

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/12970.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/12970.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/12970.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

有哪些好用的3dMax大神插件?

有哪些好用的3dMax大神插件&#xff1f; Mesh Insert 3DMAX網格插入插件Mesh Insert&#xff0c;在選擇的面上安門窗、打螺絲、挖洞、插入眼耳口鼻及其它網格模型等可以分分鐘搞定&#xff01;它通過將面選擇替換為庫中的資源來加快建模過程。非常適合硬網格和有機建模&#xf…

Go 一個類型轉換工具包strconv包

Go 語言的 strconv 包提供了用于基本數據類型之間轉換的函數&#xff0c;包括字符串到其他基本類型的轉換&#xff0c;以及其他基本類型到字符串的轉換。 字符串轉換為基本數據類型 strconv.Atoi&#xff1a;將字符串轉換為 intstrconv.ParseBool&#xff1a;將字符串轉換為 b…

iOS ------ 多線程基礎

一&#xff0c;進程和線程 1&#xff0c;進程 定義&#xff1a; 進程是指在系統中正在運行的一個應用程序每個進程之間是獨立的&#xff0c;每個進程均運行在其專有的且受保護的內存進程是系統進行資源分配和調度的一個獨立單位 補充&#xff1a;iOS系統是相對封閉的系統&a…

SQL中的LAG函數與LEAD函數用法

LAG&#xff1a;函數用于獲取結果集中當前行之前的某一行的值 LAG (scalar_expression [,offset] [,default]) OVER ([partition_by_clause ] order_by_clause ) -----漢字解釋 LAG (字段 [,偏移量默認為1] [,如果沒有值時候默認值]) OVER ( [ partition_by 字段 ] order_by 字…

服務網格 SolarMesh v1.13 重磅發布

SolarMesh是行云創新推出的流量治理平臺&#xff0c;它基于Istio&#xff0c;為部署在K8s集群上的應用提供全面的流量治理能力。 在之前的版本中&#xff0c;SolarMesh提供的能力有&#xff1a;流量視圖&#xff0c;流量控制策略批量配置&#xff0c;API級別的流量數據采集和展…

【上海大學計算機組成原理實驗報告】五、機器語言程序實驗

一、實驗目的 理解計算機執行程序的實際過程。 學習編制機器語言簡單程序的方法。 二、實驗原理 根據實驗指導書的相關內容&#xff0c;指令的形式化表示是指采用一種規范化的符號系統&#xff0c;以更清晰、精確地描述和表示指令的邏輯功能和操作步驟。 匯編是一種編程語言…

MM模塊學習二 (供應商,物料后臺相關配置)

公司代碼配置 新建條目&#xff08;只是建了一個名字出來&#xff0c;后面很多表都是沒有得&#xff09; 接下來定義公司代碼&#xff1a; 公司代碼復制完成&#xff08;后續修改交給財務顧問去做&#xff09; 復制工廠&#xff1a; 復制工廠完成&#xff1a; 修改復制過去的工…

Linux服務器lvm磁盤管理fdisk和df磁盤大小不同修改

服務器端由于硬盤是通過VCenter原來100G磁盤復制的虛擬機,復制完成后,原來100G的磁盤通過選擇 磁盤重新復制出150G的磁盤,開機后發現還是原來的100G的磁盤,通過fdisk -l 查看有個sdb是150G, 但是已經劃轉的lvm盤只有100G, 通過df查看也是原來的100G: pvs查看pv里也是10…

用c++實現快速排序、最大子段和問題

6.2.2 快速排序 【問題】快速排序(quick sort)的分治策略如下&#xff08;圖6-5)。 (1)劃分&#xff1a;&#xff08;選定一個記錄作為軸值&#xff0c;以軸值為基準將整個序列劃分為兩個子序列&#xff0c;軸值的位置在劃分的過程中確定&#xff0c;并且左側子序列的所有記錄…

26 分鐘驚訝世界,GPT-4o 引領未來人機交互

前言 原文鏈接&#xff1a;OpenAI最新模型——GPT-4o&#xff0c;實時語音視頻交互&#xff0c;未來人機交互近在眼前 - Kaiho小站 北京時間 5 月 14 日凌晨&#xff0c;OpenAI 發布新一代模型——GPT-4o&#xff0c;僅在 ChatGPT 面世 17 個月后&#xff0c;OpenAI 再次通過…

qt的udp通訊

QString mylocalip; const QList interfaces QNetworkInterface::allInterfaces(); foreach(QNetworkInterface ip, interfaces) { if (ip.humanReadableName() QStringLiteral(“以太網”)) { //if (ip.type() QNetworkInterface::Ethernet) { const QList iplist ip.addr…

【EasyX】快速入門——靜態圖形篇

1.基本說明 EasyX 是針對 C 的圖形庫&#xff0c;可以幫助 C/C 初學者快速上手圖形和游戲編程。 比如&#xff0c;可以基于 EasyX 圖形庫很快的用幾何圖形畫一個房子&#xff0c;或者一輛移動的小車&#xff0c;可以編寫俄羅斯方塊、貪吃蛇、黑白棋等小游戲&#xff0c;可以練…

Go 注釋生成 api文檔

在 Go 語言中&#xff0c;通常會使用 godoc 工具來從注釋中生成 API 文檔。godoc 是 Go 官方提供的文檔生成工具&#xff0c;它可以解析 Go 源代碼中的注釋&#xff0c;并生成在線的、可交互的文檔。 為了使用 godoc 生成 API 文檔&#xff0c;你需要遵循一些特定的注釋格式。…

使用VMware或VirtualBox安裝eNSP Pro并使用CRT連接設備

文章目錄 使用Oracle Virtual Box安裝eNSP Pro創建虛擬機配置網卡配置帶外管理網絡 使用VMware Workstation安裝eNSP Pro轉換文件格式及虛擬磁盤模式配置網卡創建虛擬機配置使用CRT連接管理設備 前一段時間是開放了eNSP Pro的賬號權限&#xff0c;但是在寫博客時&#xff0c;權…

2024OD機試卷-字符串分割(二) (java\python\c++)

題目:字符串分割(二) 題目描述 給定一個非空字符串S,其被N個‘-’分隔成N+1的子串,給定正整數K,要求除第一個子串外,其余的子串每K個字符組成新的子串,并用‘-’分隔。 對于新組成的每一個子串,如果它含有的小寫字母比大寫字母多,則將這個子串的所有 大寫字母轉換為小…

27.哀家要長腦子了!

目錄 1.316. 去除重復字母 - 力扣&#xff08;LeetCode&#xff09; 2. 1209. 刪除字符串中的所有相鄰重復項 II - 力扣&#xff08;LeetCode 哎喲 煩死了 剛剛不小心退出又沒保存 又要寫一遍 煩死了 最近刷題不得勁啊 感覺這腦子沒長一點 1.316. 去除重復字母 - 力扣&am…

(實測驗證)【移遠EC800M-CN 】GNSS功能打開和關閉關閉步驟驗證

引言 本文章使用自研“超小體積TTL轉4GGPS集成模塊”進行實測驗證&#xff1b; 一、打開GNSS功能 步驟一、通過 ATQGPSCFG 配置 GNSS 參數 &#xff08;1&#xff09;該命令用于查詢和配置 GNSS 不同的設置&#xff0c;包括 NMEA 語句輸出端口、NMEA 語句的輸出類型等。 1.1…

NSSCTF | [SWPUCTF 2021 新生賽]easyupload2.0

先傳一個普通的一句話木馬試一試 GIF89a <?php eval($_POST[shell]);?> 可以看到回顯&#xff0c;不允許上傳php文件。 使用Burpsuite抓包只修改ContentType后發現也不能繞過&#xff0c;說明服務器使用了黑名單后綴限制&#xff0c;那么我們可以使用其他的后綴代替ph…

RPA的實施過程通常包括哪些步驟?

RPA&#xff08;Robotic Process Automation&#xff09;的實施過程通常涉及一系列詳細的步驟&#xff0c;旨在確保自動化項目的成功部署和運行。以下是RPA實施過程的一般步驟&#xff1a; ### 1. 需求分析與目標設定 實施RPA的第一步是進行需求分析&#xff0c;明確企業希望通…

電路板維修【四】

【開關電源輸出電壓偏低不穩&#xff0c;用示波器立馬鎖定故障范圍】&#xff1a;https://www.bilibili.com/video/BV1pf421D73K?vd_source3cc3c07b09206097d0d8b0aefdf07958 可以用示波器查看MOS的輸出波形來查看其是否損壞&#xff1a; 電源芯片的供電電壓來回跳變&#xf…