李沐動手學習深度學習——4.2練習

1. 在所有其他參數保持不變的情況下,更改超參數num_hiddens的值,并查看此超參數的變化對結果有何影響。確定此超參數的最佳值。

通過改變隱藏層的數量,導致就是函數擬合復雜度下降,隱藏層過多可能導致過擬合,而過少導致欠擬合。
我們將層數改為128可得:
在這里插入圖片描述

2. 嘗試添加更多的隱藏層,并查看它對結果有何影響。

過擬合,導致測試機精確度下降。

3. 改變學習速率會如何影響結果?保持模型架構和其他超參數(包括輪數)不變,學習率設置為多少會帶來最好的結果?

過高的學習率導致,梯度跨度過大,使得降低不到對應的駐點。
過低的學習率導致訓練緩慢,需要增加epoch。
在訓練輪數不變的情況下,我們可以通過for 設置不同的學習率找出最合適的學習率。一般來說設置為0.01或者0.1足以

4. 通過對所有超參數(學習率、輪數、隱藏層數、每層的隱藏單元數)進行聯合優化,可以得到的最佳結果是什么?

跑了一次學習率lr=0.01的情況:
在這里插入圖片描述

需要大量的訓練,但是目前我訓練結果是學習率lr=0.1、輪數是num_epochs=10,隱藏層數為1,隱藏層數單元num_hiddens=128。

5. 描述為什么涉及多個超參數更具挑戰性。

因為組合的情況更多,當層數越多時,訓練時間也更多,這玩意就是煉丹了,看你自己的GPU還有時間、運氣。

6. 如果想要構建多個超參數的搜索方法,請想出一個聰明的策略。

套用for 循環暴力破解,時間上肯定慢的要死,我們可以先固定其他變量,挑選一個變量尋找最優解,以此類推對所有的超參數這樣使用,但是這種做法肯定不是最優的,只是能夠較好的找出比較好的超參數。

由于學校窮逼所以沒有閑置GPU服務器,所有的模型只能在colab上進行運行,其中遇到了d2l的版本對應問題,所以對于d2l.train_ch3跑不起來,只能使用自寫進行替代如下:

import torch.nn
from d2l import torch as d2l
from IPython import displayclass Accumulator:"""在n個變量上累加"""def __init__(self, n):self.data = [0.0] * n       # 創建一個長度為 n 的列表,初始化所有元素為0.0。def add(self, *args):           # 累加self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):                # 重置累加器的狀態,將所有元素重置為0.0self.data = [0.0] * len(self.data)def __getitem__(self, idx):     # 獲取所有數據return self.data[idx]def accuracy(y_hat, y):"""計算正確的數量:param y_hat::param y::return:"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)            # 在每行中找到最大值的索引,以確定每個樣本的預測類別cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy(net, data_iter):"""計算指定數據集的精度:param net::param data_iter::return:"""if isinstance(net, torch.nn.Module):net.eval()                  # 通常會關閉一些在訓練時啟用的行為metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Animator:"""在動畫中繪制數據"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量的繪制多條線if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函數捕獲參數self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):"""向圖表中添加多個數據點:param x::param y::return:"""if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def train_epoch_ch3(net, train_iter, loss, updater):"""訓練模型一輪:param net:是要訓練的神經網絡模型:param train_iter:是訓練數據的數據迭代器,用于遍歷訓練數據集:param loss:是用于計算損失的損失函數:param updater:是用于更新模型參數的優化器:return:"""if isinstance(net, torch.nn.Module):  # 用于檢查一個對象是否屬于指定的類(或類的子類)或數據類型。net.train()# 訓練損失總和, 訓練準確總和, 樣本數metric = Accumulator(3)for X, y in train_iter:  # 計算梯度并更新參數y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):  # 用于檢查一個對象是否屬于指定的類(或類的子類)或數據類型。# 使用pytorch內置的優化器和損失函數updater.zero_grad()l.mean().backward()  # 方法用于計算損失的平均值updater.step()else:# 使用定制(自定義)的優化器和損失函數l.sum().backward()updater(X.shape())metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回訓練損失和訓練精度return metric[0] / metric[2], metric[1] / metric[2]def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):"""訓練模型():param net::param train_iter::param test_iter::param loss::param num_epochs::param updater::return:"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):trans_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, trans_metrics + (test_acc,))train_loss, train_acc = trans_metricsprint(trans_metrics)def predict_ch3(net, test_iter, n=6):"""進行預測:param net::param test_iter::param n::return:"""global X, yfor X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + "\n" + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])d2l.plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/714605.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/714605.shtml
英文地址,請注明出處:http://en.pswp.cn/news/714605.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Git多人合作的推送流程

多人合作時,使用Git進行代碼推動(push)需要一定的協調和規范,以確保代碼庫的整體健康。以下是一個常見的多人合作時的Git代碼推動流程: 同步主分支: 在推送之前,確保你的本地主分支(…

【Java】四大函數式接口

消費型接口Consumer 消費型接口接收一個輸入,沒有返回值 在stream流計算中 forEach() 接收一個消費型接口Consumer用于 遍歷元素 /*** 消費型接口* 接收一個輸入,沒有返回值*/ public class demo01 {public static void main(String[] args) {//TODO 消…

【MySQL】表的內連和外連(重點)

表的連接分為內連和外連。 一、內連接 內連接實際上就是利用 where 子句對兩種表形成的笛卡兒積進行篩選,前面學習的查詢都是內連接,也是在開發過程中使用的最多的連接查詢。 select 字段 from 表1 inner join 表2 on 連接條件 and 其他條件; 注意&…

【數倉】Hadoop集群配置常用參數說明

Hadoop集群中,需要配置的文件主要包括四個 配置核心Hadoop參數: 編輯core-site.xml文件,設置Hadoop集群的基本參數,如文件系統、Hadoop臨時目錄等。 配置HDFS參數: 編輯hdfs-site.xml文件,設置HDFS的相關參…

策略開發:EMA如何計算

EMA的計算原理 EMA 是MA(平滑移動平均線)的另一種形式。全名“加權指數移動平均線”。 2/13就是12日移動平均線的平滑因子,他的意思是指:給予新價格 2/13的權重,給予過去的EMA 11/13的權重。 在計算的時候第一天的M…

Linux使用基礎命令

1.常用系統工作命令 (1).用echo命令查看SHELL變量的值 qiangziqiangzi-virtual-machine:~$ echo $SHELL /bin/bash(2).查看本機主機名 qiangziqiangzi-virtual-machine:~$ echo $HOSTNAME qiangzi-virtual-machine (3).date命令用于顯示/設置系統的時間或日期 qiangziqian…

Linux多線程服務端編程:使用muduo C++網絡庫 學習筆記 附錄B 從《C++ Primer(第4版)》入手學習C++

這是作者為《C Primer(第4版)(評注版)》寫的序言,文中“本書”指的是這本書評注版。 B.1 為什么要學習C 2009年本書作者Stanley Lippman先生應邀來華參加上海祝成科技舉辦的C技術大會,他表示人們現在還用…

MySQL存儲過程和Function

一、存儲過程 MySQL中提供存儲過程和存儲函數機制,將其統稱為存儲程序。 SQL語句要先編譯,然后執行,存儲程序是一組為了完成特定功能的SQL語句,編譯后存到數據庫中。 用戶通過指定存儲程序的名字并給定參數來調用才會執行。 存…

擴展學習|大數據分析的現狀和分類

文獻來源:[1] Mohamed A , Najafabadi M K , Wah Y B ,et al.The state of the art and taxonomy of big data analytics: view from new big data framework[J].Artificial Intelligence Review: An International Science and Engineering Journal, 2020(2):53. 下…

藍橋杯(3.2)

1209. 帶分數 import java.io.*;public class Main {static BufferedReader br new BufferedReader(new InputStreamReader(System.in));static PrintWriter pw new PrintWriter(new OutputStreamWriter(System.out));static final int N 10;static int n, cnt;static int[…

LabVIEW流量控制系統

LabVIEW流量控制系統 為響應水下航行體操縱舵翼環量控制技術的試驗研究需求,通過LabVIEW開發了一套小量程流量控制系統。該系統能夠滿足特定流量控制范圍及精度要求,展現了其在實驗研究中的經濟性、可靠性和實用性,具有良好的推廣價值。 項…

tritonserver學習之八:redis_caches實踐

tritonserver學習之一:triton使用流程 tritonserver學習之二:tritonserver編譯 tritonserver學習之三:tritonserver運行流程 tritonserver學習之四:命令行解析 tritonserver學習之五:backend實現機制 tritonserv…

【C++初階】內存管理

目錄 一.C語言中的動態內存管理方式 二.C中的內存管理方式 1.new/delete操作內置類型 2.new和delete操作自定義類型 3.淺識拋異常 (內存申請失敗) 4.new和delete操作自定義類型 三.new和delete的實現原理 1.內置類型 2.自定義類型 一.C語…

C++學習筆記:二叉搜索樹

二叉搜索樹 什么是二叉搜索樹?搜索二叉樹的操作查找插入刪除 二叉搜索樹的應用二叉搜索樹的代碼實現K模型:KV模型 二叉搜索樹的性能怎么樣? 什么是二叉搜索樹? 二叉搜索樹又稱二叉排序樹,它或者是一棵空樹,或者是具有以下性質的二叉樹: 若它的左子樹…

Linux安裝Nginx詳細步驟

1、創建兩臺虛擬機,分別為主機和從機,區別兩臺虛擬機的IP地址 2、將Nginx素材內容上傳到/usr/local目錄(pcre,zlib,openssl,nginx) 附件 3、安裝pcre庫   3.1 cd到/usr/local目錄 3.2 tar -zxvf pcre-8.36.tar.gz 解壓 3.3 cd…

MATLAB圖像噪聲添加與濾波

在 MATLAB 中添加圖像噪聲和進行濾波通常使用以下函數: 添加噪聲:可以使用imnoise函數向圖像添加各種類型的噪聲,如高斯噪聲、椒鹽噪聲等。 濾波:可以使用各種濾波器對圖像進行濾波處理,例如中值濾波、高斯濾波等。 …

前端學習、HTML

html是由一些標簽構成的,標簽之間可以嵌套,每個標簽都有開始標簽和結束標簽,也有部分標簽只有開始標簽,沒有結束標簽。html的標簽也可以成為元素。(樹形結構) html文件的最頂層標簽就是html。 head用來放…

**藍橋OJ 178全球變暖 DFS

藍橋OJ 178全球變暖 思路: 將每一座島嶼用一個顏色scc代替, 用dx[]和dy[]判斷他的上下左右是否需要標記顏色,如果已經標記過顏色或者是海洋就跳過.后面的淹沒,實際上就是哪個塊上下左右有陸地,那么就不會被淹沒,我用一個tag標記,如果上下左右一旦有海洋,tag就變為false.如果tag…

用冒泡排序模擬C語言中的內置快排函數qsort!

目錄 ?編輯 1.回調函數的介紹 2. 回調函數實現轉移表 3. 冒泡排序的實現 4. qsort的介紹和使用 5. qsort的模擬實現 6. 完結散花 悟已往之不諫,知來者猶可追 創作不易,寶子們!如果這篇文章對你們有幫助的話,別忘了給個免…

機器學習:模型評估和模型保存

一、模型評估 from sklearn.metrics import accuracy_score, confusion_matrix, classification_report# 使用測試集進行預測 y_pred model.predict(X_test)# 計算準確率 accuracy accuracy_score(y_test, y_pred) print(f"Accuracy: {accuracy*100:.2f}%")# 打印…