「日拱一碼」039 機器學習-訓練時間VS精確度

目錄

時間-精度權衡曲線(不同模型復雜度)

訓練與驗證損失對比

帕累托前沿分析(3D)?


在機器學習實踐中,理解模型收斂所需時間及其與精度的關系至關重要。下面介紹如何分析模型收斂時間與精度之間的權衡,并找到最佳性價比點。

時間-精度權衡曲線(不同模型復雜度)

# 1. 時間-精度曲線(不同模型復雜度)import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import time
import seaborn as sns# 設置樣式
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
sns.set_palette("viridis")X, y = make_regression(n_samples=10000, n_features=50, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 模擬不同訓練策略的結果
def simulate_training(model, max_iter=100):"""模擬訓練過程并記錄時間和精度"""train_losses = []val_losses = []times = []# 初始模型model.warm_start = Truemodel.n_estimators = 1model.fit(X_train, y_train)start_time = time.time()for i in range(1, max_iter + 1):model.n_estimators = imodel.fit(X_train, y_train)# 記錄時間elapsed = time.time() - start_timetimes.append(elapsed)# 記錄訓練損失train_pred = model.predict(X_train)train_loss = mean_squared_error(y_train, train_pred)train_losses.append(train_loss)# 記錄驗證損失val_pred = model.predict(X_test)val_loss = mean_squared_error(y_test, val_pred)val_losses.append(val_loss)return np.array(times), np.array(train_losses), np.array(val_losses)model_simple = GradientBoostingRegressor(learning_rate=0.1, max_depth=3, random_state=42)
model_medium = GradientBoostingRegressor(learning_rate=0.05, max_depth=5, random_state=42)
model_complex = GradientBoostingRegressor(learning_rate=0.01, max_depth=8, random_state=42)# 模擬訓練過程
times_simple, train_loss_simple, val_loss_simple = simulate_training(model_simple, 100)
times_medium, train_loss_medium, val_loss_medium = simulate_training(model_medium, 100)
times_complex, train_loss_complex, val_loss_complex = simulate_training(model_complex, 100)plt.figure(figsize=(12, 8))
plt.plot(times_simple, val_loss_simple, 'o-', label='簡單模型 (depth=3, lr=0.1)', alpha=0.7)
plt.plot(times_medium, val_loss_medium, 's-', label='中等模型 (depth=5, lr=0.05)', alpha=0.7)
plt.plot(times_complex, val_loss_complex, 'd-', label='復雜模型 (depth=8, lr=0.01)', alpha=0.7)# 標記最佳平衡點(最大曲率點)
def find_elbow_point(times, losses):"""找到曲線的拐點(最大曲率點)"""# 計算一階導數(梯度)grad = np.gradient(losses, times)# 計算二階導數grad2 = np.gradient(grad, times)# 計算曲率curvature = np.abs(grad2) / (1 + grad ** 2) ** 1.5# 找到最大曲率點(排除前10%的點)exclude_n = max(5, int(len(curvature) * 0.1))max_idx = np.argmax(curvature[exclude_n:-5]) + exclude_nreturn times[max_idx], losses[max_idx]# 標記各模型的最佳點
elbow_time_simple, elbow_loss_simple = find_elbow_point(times_simple, val_loss_simple)
elbow_time_medium, elbow_loss_medium = find_elbow_point(times_medium, val_loss_medium)
elbow_time_complex, elbow_loss_complex = find_elbow_point(times_complex, val_loss_complex)plt.scatter(elbow_time_simple, elbow_loss_simple, s=150, c='red', marker='*', zorder=10)
plt.scatter(elbow_time_medium, elbow_loss_medium, s=150, c='red', marker='*', zorder=10)
plt.scatter(elbow_time_complex, elbow_loss_complex, s=150, c='red', marker='*', zorder=10)plt.annotate('最佳平衡點',(elbow_time_medium, elbow_loss_medium),xytext=(elbow_time_medium + 0.5, elbow_loss_medium + 5),arrowprops=dict(facecolor='red', shrink=0.05, width=2),fontsize=12, color='red')plt.xlabel('訓練時間 (秒)', fontsize=12)
plt.ylabel('驗證集 MSE', fontsize=12)
plt.title('不同復雜度模型的時間-精度權衡', fontsize=14)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

訓練與驗證損失對比

# 2. 訓練與驗證損失對比
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import time
import seaborn as sns# 設置樣式
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
sns.set_palette("viridis")X, y = make_regression(n_samples=10000, n_features=50, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 模擬不同訓練策略的結果
def simulate_training(model, max_iter=100):"""模擬訓練過程并記錄時間和精度"""train_losses = []val_losses = []times = []# 初始模型model.warm_start = Truemodel.n_estimators = 1model.fit(X_train, y_train)start_time = time.time()for i in range(1, max_iter + 1):model.n_estimators = imodel.fit(X_train, y_train)# 記錄時間elapsed = time.time() - start_timetimes.append(elapsed)# 記錄訓練損失train_pred = model.predict(X_train)train_loss = mean_squared_error(y_train, train_pred)train_losses.append(train_loss)# 記錄驗證損失val_pred = model.predict(X_test)val_loss = mean_squared_error(y_test, val_pred)val_losses.append(val_loss)return np.array(times), np.array(train_losses), np.array(val_losses)model_simple = GradientBoostingRegressor(learning_rate=0.1, max_depth=3, random_state=42)
model_medium = GradientBoostingRegressor(learning_rate=0.05, max_depth=5, random_state=42)
model_complex = GradientBoostingRegressor(learning_rate=0.01, max_depth=8, random_state=42)# 模擬訓練過程
times_simple, train_loss_simple, val_loss_simple = simulate_training(model_simple, 100)
times_medium, train_loss_medium, val_loss_medium = simulate_training(model_medium, 100)
times_complex, train_loss_complex, val_loss_complex = simulate_training(model_complex, 100)# 標記最佳平衡點(最大曲率點)
def find_elbow_point(times, losses):"""找到曲線的拐點(最大曲率點)"""# 計算一階導數(梯度)grad = np.gradient(losses, times)# 計算二階導數grad2 = np.gradient(grad, times)# 計算曲率curvature = np.abs(grad2) / (1 + grad ** 2) ** 1.5# 找到最大曲率點(排除前10%的點)exclude_n = max(5, int(len(curvature) * 0.1))max_idx = np.argmax(curvature[exclude_n:-5]) + exclude_nreturn times[max_idx], losses[max_idx]elbow_time_simple, elbow_loss_simple = find_elbow_point(times_simple, val_loss_simple)
elbow_time_medium, elbow_loss_medium = find_elbow_point(times_medium, val_loss_medium)
elbow_time_complex, elbow_loss_complex = find_elbow_point(times_complex, val_loss_complex)plt.figure(figsize=(12, 8))
plt.plot(times_medium, train_loss_medium, 'b-', label='訓練損失')
plt.plot(times_medium, val_loss_medium, 'r-', label='驗證損失')# 標記過擬合開始點
diff = val_loss_medium - train_loss_medium
overfit_idx = np.where(diff > np.percentile(diff, 90))[0][0]
overfit_time = times_medium[overfit_idx]
overfit_val_loss = val_loss_medium[overfit_idx]plt.axvline(x=overfit_time, color='purple', linestyle='--', alpha=0.7)
plt.scatter(overfit_time, overfit_val_loss, s=150, c='purple', marker='x')plt.annotate('過擬合開始點',(overfit_time, overfit_val_loss),xytext=(overfit_time+0.5, overfit_val_loss+5),arrowprops=dict(facecolor='purple', shrink=0.05, width=2),fontsize=12, color='purple')# 標記最佳點
plt.scatter(elbow_time_medium, elbow_loss_medium, s=150, c='red', marker='*', zorder=10)
plt.annotate('最佳平衡點',(elbow_time_medium, elbow_loss_medium),xytext=(elbow_time_medium+0.5, elbow_loss_medium+2),arrowprops=dict(facecolor='red', shrink=0.05, width=2),fontsize=12, color='red')plt.xlabel('訓練時間 (秒)', fontsize=12)
plt.ylabel('MSE', fontsize=12)
plt.title('訓練與驗證損失隨時間變化', fontsize=14)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

帕累托前沿分析(3D)?

# 3. 帕累托前沿分析 (3D)
# 模擬不同超參數配置的結果import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import time
import pandas as pd
import seaborn as sns# 設置樣式
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
sns.set_palette("viridis")X, y = make_regression(n_samples=10000, n_features=50, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
np.random.seed(42)
results = []
for _ in range(10):depth = np.random.randint(3, 10)lr = 10 ** np.random.uniform(-2, -0.5)n_estimators = np.random.randint(20, 200)model = GradientBoostingRegressor(max_depth=depth,learning_rate=lr,n_estimators=n_estimators,random_state=42)start_time = time.time()model.fit(X_train, y_train)train_time = time.time() - start_timey_pred = model.predict(X_test)mse = mean_squared_error(y_test, y_pred)results.append({'depth': depth,'lr': lr,'n_estimators': n_estimators,'time': train_time,'mse': mse,'cost_score': 0.7 * mse + 0.3 * np.log1p(train_time)  # 自定義成本函數})# 轉換為DataFrame
df = pd.DataFrame(results)# 找到帕累托前沿點
def is_pareto_efficient(costs):"""找到帕累托最優解"""is_efficient = np.ones(costs.shape[0], dtype=bool)for i, c in enumerate(costs):if is_efficient[i]:# 所有成本都小于等于當前點的點is_efficient[is_efficient] = np.any(costs[is_efficient] < c, axis=1)is_efficient[i] = True  # 保持當前點return is_efficientcosts = df[['mse', 'time']].values
pareto_mask = is_pareto_efficient(costs)
pareto_points = df[pareto_mask]# 3D可視化
fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111, projection='3d')# 繪制所有點
scatter = ax.scatter(df['depth'],df['lr'],df['n_estimators'],c=df['cost_score'],cmap='viridis',s=50,alpha=0.7,label='所有配置'
)# 繪制帕累托前沿點
pareto_scatter = ax.scatter(pareto_points['depth'],pareto_points['lr'],pareto_points['n_estimators'],c='red',s=100,edgecolors='black',label='帕累托前沿'
)# 標記最佳性價比點
best_idx = df['cost_score'].idxmin()
best_point = df.loc[best_idx]
ax.scatter(best_point['depth'],best_point['lr'],best_point['n_estimators'],c='gold',s=200,edgecolors='black',marker='*',label='最佳性價比'
)ax.set_xlabel('樹深度', fontsize=12)
ax.set_ylabel('學習率', fontsize=12)
ax.set_zlabel('樹數量', fontsize=12)
ax.set_title('超參數空間的帕累托前沿分析', fontsize=14)
ax.legend()# 添加顏色條
cbar = fig.colorbar(scatter, pad=0.1)
cbar.set_label('成本分數 (MSE + 時間成本)', fontsize=12)plt.tight_layout()
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/91249.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/91249.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/91249.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

面試刷題平臺項目總結

項目簡介&#xff1a; 面試刷題平臺是一款基于 Spring Boot Redis MySQL Elasticsearch 的 面試刷題平臺&#xff0c;運用 Druid HotKey Sa-Token Sentinel 提高了系統的性能和安全性。 第一階段&#xff0c;開發基礎的刷題平臺&#xff0c;帶大家熟悉項目開發流程&#xff…

負載均衡、算法/策略

負載均衡一、負載均衡層級對比特性四層負載均衡 (L4)七層負載均衡 (L7)工作層級傳輸層 (TCP/UDP)應用層 (HTTP/HTTPS等)決策依據源/目標IP端口URL路徑、Header、Cookie、內容等轉發方式IP地址/端口替換重建連接并深度解析報文性能更高吞吐量&#xff0c;更低延遲需內容解析&…

StackingClassifier參數詳解與示例

StackingClassifier參數詳解與示例 StackingClassifier是一種集成學習方法&#xff0c;通過組合多個基分類器的預測結果作為元分類器的輸入特征&#xff0c;從而提高整體模型性能。以下是關鍵參數的詳細說明和示例&#xff1a; 1. classifiers&#xff08;基分類器&#xff09;…

嵌入式中間件-uorb解析

uORB系統詳細解析 1. 系統概述 1.1 設計理念 uORB&#xff08;Micro Object Request Broker&#xff09;是一個專為嵌入式實時系統設計的發布-訂閱式進程間通信框架。該系統借鑒了ROS中topic的概念&#xff0c;為無人機飛控系統提供了高效、可靠的數據傳輸機制。 1.2 核心特征 …

HTTP.Client 庫對比與選擇

HTTP.Client 庫對比與選擇在 Python 中&#xff0c;除了標準庫 http.client&#xff0c;還有許多功能更強大、使用更便捷的 HTTP 庫。以下是一些常用的庫及其特點&#xff1a;1. Requests&#xff08;最流行&#xff09;特點&#xff1a;高層 API&#xff0c;簡單易用&#xff…

RabbitMQ面試精講 Day 5:Virtual Host與權限控制

【RabbitMQ面試精講 Day 5】Virtual Host與權限控制 開篇 歡迎來到"RabbitMQ面試精講"系列的第5天&#xff01;今天我們將深入探討RabbitMQ中Virtual Host與權限控制的核心機制&#xff0c;這是構建安全、隔離的消息系統必須掌握的重要知識。在面試中&#xff0c;面…

【前端實戰】純HTML+CSS+JS實現蠟筆小新無盡冒險:從零打造網頁版超級瑪麗

摘要&#xff1a;本文將詳細介紹一款完全由HTMLCSSJS實現的網頁版橫版闖關游戲——"蠟筆小新無盡冒險"。游戲采用純前端技術實現&#xff0c;無需任何外部依賴&#xff0c;完美復刻了經典超級瑪麗的核心玩法&#xff0c;并創新性地融入了蠟筆小新角色元素。通過本文&…

[工具類] 網絡請求HttpUtils

引言在現代應用程序開發中&#xff0c;網絡請求是必不可少的功能之一。無論是訪問第三方API、微服務之間的通信&#xff0c;還是請求遠程數據&#xff0c;都需要通過HTTP協議實現。在Java中&#xff0c;java.net.HttpURLConnection、Apache的HttpClient庫以及OkHttp等庫提供了豐…

基于Spring Boot的裝飾工程管理系統(源碼+論文)

一、 開發環境與技術 本章節對開發裝飾工程管理系統------項目立項子系統需要搭建的開發環境&#xff0c;以及裝飾工程管理系統------項目立項子系統開發中使用的編程技術等進行闡述。 1 開發環境 工具/環境描述操作系統Windows 10/11 或 Linux&#xff08;如 Ubuntu&#x…

【WebGPU學習雜記】數學基礎拾遺(2)變換矩陣中的齊次坐標推導與幾何理解

今天打算開始 3D 數學基礎的復習&#xff0c;本文假設你了解以下概念&#xff1a;一次多項式、矩陣、向量&#xff0c;基于以上拓展的概念 歸一化、2&#xff5e;3階矩陣的幾何意義。幾何意義結論 齊次坐標是對三維的人工的特定的升維&#xff0c;它是一個工具而已。圖形學中常…

JS前端壓縮算法——WWDHCAPOF-算法導論論文——東方仙盟算法

代碼function customCompressString(input) {// 第一步&#xff1a;將字符串轉換為ANSI碼數組并乘以位置序號let resultArray Array.from(input).map((char, index) > {const ansiCode char.charCodeAt(0);return ansiCode * (index 東方仙盟); // 位置序號從1開始});// …

linux命令less的實際應用

less 是 Linux/Unix 中交互式文件查看神器&#xff0c;相比 more 和 cat&#xff0c;它支持自由導航、搜索、高亮等強大功能&#xff0c;尤其適合處理大文件或實時日志。以下是深度應用指南&#xff1a;?一、核心優勢?less large_file.log # 秒開GB級文件&#xff08…

DAY31 整數矩陣及其運算

DAY31 整數矩陣及其運算 本次代碼通過IntMatrix類封裝了二維整數矩陣的核心操作&#xff0c;思路如下&#xff1a;數據封裝→基礎操作&#xff08;修改和獲取元素、獲取維度&#xff0c;toString返回字符串表示&#xff0c;getData返回內部數組引用&#xff09;→矩陣運算&…

飛槳深度學習環境搭建

一、安裝 PyCharm PyCharm 官網下載頁面 記得全部勾選。 二、安裝 miniconda miniconda 官網下載頁面 根據你的操作系統選擇。 記得勾選前三個。 三、安裝 CUDA 首先 nvidia-smi 查看支持最高的 CUDA 版本。 然后去 nvidia 官網下載 CUDA&#xff0c;選擇適合你的版本。 …

MySQL 8.0 OCP 1Z0-908 題目解析(37)

題目146 Choose two. Which two are true about binary logs used in asynchronous replication? □ A) The master connects to the slave and initiates log transfer. □ B) They contain events that describe all queries run on the master. □ C) They contain events …

vue element 封裝表單

背景&#xff1a; 在前端系統開發中&#xff0c;系統頁面涉及到的表單組件比較多&#xff0c;所以進行了簡單的封裝。封裝的包括一些Form表單組件&#xff0c;如下&#xff1a;input輸入框、select下拉框、等 實現效果&#xff1a; 理論知識&#xff1a; 表單組件官方鏈接&…

flutter-完美解決鍵盤彈出遮擋輸入框的問題

文章目錄1. 前言2. 借助 Scaffold 的特性自動調整3. 使用 MediaQuery 精準控制抬升高度3.1. 底部抽屜內輸入框的方案4. 注意事項5. 總結1. 前言 在 Flutter 的開發過程中&#xff0c;經常會碰到某一個頁面有個 TextField 輸入組件&#xff0c;點擊的時候鍵盤會彈起來&#xff…

機器學習筆記(四)——聚類算法KNN、Kmeans、Dbscan

寫在前面&#xff1a;寫本系列(自用)的目的是回顧已經學過的知識、記錄新學習的知識或是記錄心得理解&#xff0c;方便自己以后快速復習&#xff0c;減少遺忘。概念部分大部分來自于機器學習菜鳥教程&#xff0c;公式部分也會參考機器學習書籍、阿里云天池。機器學習如果只啃概…

【C#】事務(進程 ID 64)與另一個進程被死鎖在鎖資源上,并且已被選作死鎖犧牲品。請重新運行該事務。不能在具有唯一索引“XXX_Index”的對象“dbo.Test”中插入重復鍵的行。

&#x1f339;歡迎來到《小5講堂》&#x1f339; &#x1f339;這是《C#》系列文章&#xff0c;每篇文章將以博主理解的角度展開講解。&#x1f339; &#x1f339;溫馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不對之處望指正&#xff01;&#…

LeetCode Hot 100 搜索二維矩陣

給你一個滿足下述兩條屬性的 m x n 整數矩陣&#xff1a;每行中的整數從左到右按非嚴格遞增順序排列。每行的第一個整數大于前一行的最后一個整數。給你一個整數 target &#xff0c;如果 target 在矩陣中&#xff0c;返回 true &#xff1b;否則&#xff0c;返回 false 。示例…