Softmax回歸

一、Softmax回歸關鍵思想

1、回歸問題和分類問題的區別

? ? ? ?Softmax回歸雖然叫“回歸”,但是它本質是一個分類問題。回歸是估計一個連續值,而分類是預測一個離散類別。

2、Softmax回歸模型

???????Softmax回歸跟線性回歸一樣將輸入特征與權重做線性疊加。與線性回歸的一個主要不同在于,Softmax回歸的輸出值個數等于標簽里的類別數。比如一共有4種特征和3種輸出動物類別(貓、狗、豬),則權重包含12個標量(帶下標的$w$),偏差包含3個標量(帶下標的$b$),且對每個輸入計算$ O_1,O_2,O_3 $這三個輸出:

$ \begin{aligned} o_1 &= x_1 w_{11} + x_2 w_{12} + x_3 w_{13} + x_4 w_{14} + b_1,\\ o_2 &= x_1 w_{21} + x_2 w_{22} + x_3 w_{23} + x_4 w_{24} + b_2,\\ o_3 &= x_1 w_{31} + x_2 w_{32} + x_3 w_{33} + x_4 w_{34} + b_3. \end{aligned} $

最后,再對這些輸出值進行Softmax函數運算

???????softmax回歸同線性回歸一樣,也是一個單層神經網絡。由于每個輸出$ O_1,O_2,O_3 $的計算都要依賴于所有的輸入$ X_1,X_2,X_3,X_4 $,所以softmax回歸的輸出層也是一個全連接層。

3、Softmax函數

???????Softmax用于多分類過程中,它將多個神經元的輸出(比如$ O_1,O_2,O_3 $)映射到(0,1)區間內,可以看成概率來理解,從而來進行多分類!它通過下式將輸出值變換成值為正且和為1的概率分布:

$\widehat{y_1},\widehat{y_2},\widehat{y_3} = \mathrm{softmax}(o_1,o_2,o_3)$

其中:

$ \widehat{y}_j=\frac{\exp \left( o_1 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $,?$ \widehat{y}_j=\frac{\exp \left( o_2 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $,?$ \widehat{y}_j=\frac{\exp \left( o_3 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $

???????容易看出?$ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $?且?$ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $,因此?$ \widehat{y_1},\widehat{y_2},\widehat{y_3} $?是一個合法的概率分布。此外,我們注意到:

$ arg\max\text{\ }o_i=arg\max\text{\ }\widehat{y_i} $

?因此softmax運算不改變預測類別輸出。

? ? ? ?下圖可以更好的理解Softmax函數,其實就是取自然常數e的指數相加后算比例,由于自然常數的指數($ e^x $)在$ \left( -\infty ,+\infty \right) $單調遞增,因此softmax運算不改變預測類別輸出。

4、交叉熵損失函數

? ? ? ?假設我們希望根據圖片動物的輪廓、顏色等特征,來預測動物的類別,有三種可預測類別:貓、狗、豬。假設我們當前有兩個模型(參數不同),這兩個模型都是通過sigmoid/softmax的方式得到對于每個預測結果的概率值:

模型1:

模型1
預測真實是否正確
0.30.30.4001正確
0.30.40.3010正確
0.10.20.7100錯誤

???????模型評價:模型1對于樣本1和樣本2以非常微弱的優勢判斷正確,對于樣本3的判斷則徹底錯誤。

模型2:

模型2
預測真實是否正確
0.10.20.7001正確
0.10.70.2010正確
0.30.40.3100錯誤

???????模型評價:模型2對于樣本1和樣本2判斷非常準確,對于樣本3判斷錯誤,但是相對來說沒有錯得太離譜。

???????好了,有了模型之后,我們需要通過定義損失函數來判斷模型在樣本上的表現了,那么我們可以定義哪些損失函數呢?我們可以先嘗試使用以下幾種損失函數,然后討論哪種效果更好。

(1)Classification Error(分類錯誤率)

???????最為直接的損失函數定義為:

$ classification\ error=\frac{count\ of\ error\ items}{count\ of\ all\ items} $

模型1:$ classification\ error=\frac{1}{3} $

模型2:$ classification\ error=\frac{2}{3} $

???????我們知道,模型1模型2雖然都是預測錯了1個,但是相對來說模型2表現得更好,損失函數值照理來說應該更小,但是,很遺憾的是,classification error?并不能判斷出來,所以這種損失函數雖然好理解,但表現不太好。

(2)Mean Squared Error(均方誤差MSE)

???????均方誤差損失也是一種比較常見的損失函數,其定義為:

$ MSE=\frac{1}{n}\sum_i^n{\left( \widehat{y_i}-y_i \right) ^2} $

模型1:

對所有樣本的loss求平均:

模型2:

對所有樣本的loss求平均:

???????我們發現,MSE能夠判斷出來模型2優于模型1,那為什么不采樣這種損失函數呢?主要原因是在分類問題中,使用sigmoid/softmx得到概率,配合MSE損失函數時,采用梯度下降法進行學習時,會出現模型一開始訓練時,學習速率非常慢的情況(損失函數 | Mean-Squared Loss - 知乎)。

???????有了上面的直觀分析,我們可以清楚的看到,對于分類問題的損失函數來說,分類錯誤率和均方誤差損失都不是很好的損失函數,下面我們來看一下交叉熵損失函數的表現情況。

(3)Cross Entropy Loss Function(交叉熵損失函數)

其中:

$M$:類別的數量

$ y_{ic} $:符號函數(0或1),如果樣本 i 的真實類別等于 c 取 1,否則取 0

$ p_{ic} $:觀測樣本 i 屬于類別 c 的預測概率

$N$:樣本的數量

現在我們利用這個表達式計算上面例子中的損失函數值:

模型1

對所有樣本的loss求平均:

模型2:

對所有樣本的loss求平均:

???????可以發現,交叉熵損失函數可以捕捉到模型1和模型2預測效果的差異,因此對于Softmax回歸問題我們常用交叉熵損失函數。

? ? ? 下面兩圖可以很清晰的反應整個Softmax回歸算法的流程:

二、圖像分類數據集

???????MNIST數據集是圖像分類中廣泛使用的數據集之一,但作為基準數據集過于簡單。我們將使用類似但更復雜的Fashion-MNIST數據集。

???????在這里我們定義一些函數用于數據的讀取與顯示,這些函數已經在Python包d2l中定義好了,但為了便于大家理解,這里沒有直接調用d2l中的函數。

1、讀取數據集

???????我們可以通過框架中的內置函數將Fashion-MNIST數據集下載并讀取到內存中。

# 通過ToTensor實例將圖像數據從PIL類型變換成32位浮點數格式,
# 并除以255使得所有像素的數值均在0~1之間
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

???????Fashion-MNIST由10個類別的圖像組成,每個類別由訓練數據集(train dataset)中的6000張圖像和測試數據集(test dataset)中的1000張圖像組成。因此,訓練集和測試集分別包含60000和10000張圖像。測試數據集不會用于訓練,只用于評估模型性能。

print(len(mnist_train), len(mnist_test))
60000 10000

???????每個輸入圖像的高度和寬度均為28像素。數據集由灰度圖像組成,其通道數為1。為了簡潔起見,本書將高度$h$像素、寬度$w$像素圖像的形狀記為$h \times w$($h$,$w$)。接下來我們可以打印一下mnist_train的類型和mnist_train的第一個元素。

print(type(mnist_train))
print(type(mnist_train[0]))
print(mnist_train[0])
print(mnist_train[0][0].shape)

???????可以看出mnist_train的類型為<class 'torchvision.datasets.mnist.FashionMNIST'>。mnist_train的第一個元素的類型是<class 'tuple'>,是一個元組,元組第一個元素是轉化為tensor后的灰度值,第二個元素是圖像所屬類別index,這里是9。因為是灰度圖,因此channel數量為1,圖片長和寬都是28,因此形狀是(1,28,28)。

???????Fashion-MNIST中包含的10個類別,分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)

???????以下函數用于在數字標簽索引及其文本名稱之間進行轉換。

def get_fashion_mnist_labels(labels):   # labels:mnist_train和mnist_test里面圖像的類別index(數字)"""返回Fashion-MNIST數據集的文本標簽"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]    # 根據index返回文本標簽列表('t-shirt', 'trouser'...)

???????我們現在可以創建一個函數來可視化這些樣本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""繪制圖像列表""""""imgs: tensor向量num_rows: 畫圖時的行數num_cols: 畫圖時的列數titles: 每張圖片的標題scales: 因為要將num_rows*num_cols張圖片畫到一張圖上,并且還要添加一些文字,因此需要對大圖進行一定的縮放才能保證每張小圖之間的間隙"""figsize = (num_cols * scale, num_rows * scale)# figsize = (num_cols, num_rows)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 圖片張量ax.imshow(img.numpy())else:# PIL圖片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

???????以下是訓練數據集中前18個樣本的圖像及其相應的標簽。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

2、讀取小批量數據

???????為了使我們在讀取訓練集和測試集時更容易,我們使用內置的數據迭代器,而不是從零開始創建。在每次迭代中,數據加載器每次都會讀取一小批量數據,大小為`batch_size`。通過內置數據迭代器,我們可以隨機打亂所有樣本,從而無偏見地讀取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4個進程來讀取數據"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

3、整合所有組件

???????現在我們定義`load_data_fashion_mnist`函數,用于獲取和讀取Fashion-MNIST數據集。這個函數返回訓練集和驗證集的數據迭代器。此外,這個函數還接受一個可選參數`resize`,用來將圖像大小調整為另一種形狀。

def load_data_fashion_mnist(batch_size, resize=None):"""下載Fashion-MNIST數據集,然后將其加載到內存中"""trans = [transforms.ToTensor()]    # 此時的trans是一個列表if resize:trans.insert(0, transforms.Resize(resize))    # 如果提供了resize參數,則在轉換鏈中插入Resize操作trans = transforms.Compose(trans)    # 將一系列的圖像轉換操作組合成一個轉換鏈。# trans是一個由多個圖像轉換操作組成的列表。它按照列表中的順序依次應用這些轉換操作。# 這樣可以將多個轉換操作組合在一起,以便在加載數據時一次性應用它們。mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

???????下面,我們通過指定`resize`參數來測試`load_data_fashion_mnist`函數的圖像大小調整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

三、softmax回歸的從零開始實現

...

參考文獻

[1]??損失函數|交叉熵損失函數

[2]??深度學習模型系列一——多分類模型——Softmax 回歸-CSDN博客

[3]??Softmax 回歸_嗶哩嗶哩_bilibili

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/214479.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/214479.shtml
英文地址,請注明出處:http://en.pswp.cn/news/214479.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux安裝Nginx并部署Vue項目

今天部署了一個Vue項目到阿里云的云服務器上&#xff0c;現記錄該過程。 1. 修改Vue項目配置 我們去項目中發送axios請求的文件里更改一下后端的接口路由&#xff1a; 2. 執行命令打包 npm run build ### 或者 yarn build 打包成功之后&#xff0c;我們會看到一個dist包&a…

[MySQL]SQL優化之索引的使用規則

&#x1f308;鍵盤敲爛&#xff0c;年薪30萬&#x1f308; 目錄 一、索引失效 &#x1f4d5;最左前綴法則 &#x1f4d5;范圍查詢> &#x1f4d5;索引列運算&#xff0c;索引失效 &#x1f4d5;前模糊匹配 &#x1f4d5;or連接的條件 &#x1f4d5;字符串類型不加 …

110. 平衡二叉樹(Java)

給定一個二叉樹&#xff0c;判斷它是否是高度平衡的二叉樹。 本題中&#xff0c;一棵高度平衡二叉樹定義為&#xff1a; 一個二叉樹每個節點 的左右兩個子樹的高度差的絕對值不超過 1 。 示例 1&#xff1a; 輸入&#xff1a;root [3,9,20,null,null,15,7] 輸出&#xff1a;t…

如何通過SPI控制Peregrine的數控衰減器

概要 Peregrine的數控衰減器PE4312是6位射頻數字步進衰減器(DSA,Digital Step Attenuator)工作頻率覆蓋1MHz~4GHz,插入損耗2dB左右,衰減步進0.5dB,最大衰減量為31.5dB,高達59dBm的IIP3提供了良好的動態性能,切換時間0.5微秒,供電電源2.3V~5.5V,邏輯控制兼容1.8V,20…

?如何使用https://www.krea.ai/來實現文生圖,圖生圖,

網址&#xff1a;https://www.krea.ai/apps/image/realtime Krea.ai 是一個強大的人工智能藝術生成器&#xff0c;可用于創建各種創意內容。它可以用來生成文本描述的圖像、將圖像轉換為其他圖像&#xff0c;甚至寫博客文章。 文本描述生成圖像 要使用 Krea.ai 生成文本描述…

設計模式——建造者模式(Java示例)

引言 生成器是一種創建型設計模式&#xff0c; 使你能夠分步驟創建復雜對象。 與其他創建型模式不同&#xff0c; 生成器不要求產品擁有通用接口。 這使得用相同的創建過程生成不同的產品成為可能。 復雜度&#xff1a; 中等 流行度&#xff1a; 流行 使用示例&#xff1a…

【conda】利用Conda創建虛擬環境,Pytorch各版本安裝教程(Ubuntu)

TOC conda 系列&#xff1a; 1. conda指令教程 2. 利用Conda創建虛擬環境&#xff0c;安裝Pytorch各版本教程(Ubuntu) 1. 利用Conda創建虛擬環境 nolonolo:~/sun/SplaTAM$ conda create -n splatam python3.10查看結果&#xff1a; (splatam) nolonolo:~/sun/SplaTAM$ cond…

Java 中的 Deque 接口及其用途

文章目錄 Deque 介紹Deque 使用雙端隊列普通隊列棧 總結 在 Java 中&#xff0c;Deque 接口是一個雙端隊列&#xff08;double-ended queue&#xff09;的數據結構&#xff0c;它支持在兩端插入和移除元素。Deque 是 “Double Ended Queue” 的縮寫&#xff0c;而且它可以同時充…

Linux系統編程(一):基本概念

參考引用 Unix和Linux操作系統有什么區別&#xff1f;一文帶你徹底搞懂posix Linux系統編程&#xff08;文章鏈接匯總&#xff09; 1. Unix 和 Linux 1.1 Unix Unix 操作系統誕生于 1969 年&#xff0c;貝爾實驗室發布了一個用 C 語言編寫的名為「Unix」的操作系統&#xff0…

【基于LSTM的電商評論情感分析:Flask與Sklearn的完美結合】

基于LSTM的電商評論情感分析&#xff1a;Flask與Sklearn的完美結合 引言數據集與爬取數據處理與可視化情感分析模型構建Flask應用搭建詞云展示創新點結論 引言 在當今數字化時代&#xff0c;電商平臺上涌現出大量的用戶評論數據。了解和分析這些評論對于企業改進產品、服務以及…

?expect命令運用于bash?

目錄 ?expect命令運用于bash? expect使用原理 expet使用場景 常用的expect命令選項 Expect腳本的結尾 常用的expect命令選參數 Expect執行方式 單一分支語法 多分支模式語法第一種 多分支模式語法第二種 在shell 中嵌套expect Shell Here Document&#xff08;內…

基于Java實驗室管理系統

基于Java實驗室管理系統 功能需求 1、實驗室設備管理&#xff1a;系統需要提供實驗室設備管理功能&#xff0c;包括設備的查詢、預訂、使用記錄、維護和報廢等。 2、實驗項目管理&#xff1a;系統需要提供實驗項目管理功能&#xff0c;包括項目的創建、審批、執行和驗收等&a…

以太坊:前世今生與未來

一、引言 以太坊&#xff0c;這個在區塊鏈領域大放異彩的名字&#xff0c;似乎已經成為了去中心化應用&#xff08;DApps&#xff09;的代名詞。從初期的萌芽到如今的繁榮發展&#xff0c;以太坊經歷了一段曲折而精彩的旅程。讓我們一起回顧一下以太坊的前世今生&#xff0c;以…

樹實驗代碼

哈夫曼樹 #include <stdio.h> #include <stdlib.h> #define MAXLEN 100typedef struct {int weight;int lchild, rchild, parent; } HTNode;typedef HTNode HT[MAXLEN]; int n;void CreatHFMT(HT T); void InitHFMT(HT T); void InputWeight(HT T); void SelectMi…

【算法專題】分治 - 快速排序

分治 - 快速排序 分治 - 快速排序1. 顏色分類2. 排序數組(快速排序)3. 數組中的第K個最大元素4. 庫存管理Ⅲ5. 排序數組(歸并排序)6. 交易逆序對的總數7. 計算右側小于當前元素的個數8. 翻轉對 分治 - 快速排序 1. 顏色分類 做題鏈接 -> Leetcode -75.顏色分類 題目&…

【華為數據之道學習筆記】3-5 規則數據治理

在業務規則管理方面&#xff0c;華為經常面對“各種業務場景業務規則不同&#xff0c;記不住&#xff0c;找不到”“大量規則在政策、流程等文件中承載&#xff0c;難以遵守”“各國規則均不同&#xff0c;IT能否一國一策、快速上線”等問題。 規則數據是結構化描述業務規則變量…

【Qt開發流程】之UI風格、預覽及QPalette使用

概述 一個優秀的應用程序不僅要有實用的功能&#xff0c;還要有一個漂亮美膩的外觀&#xff0c;這樣才能使應用程序更加友善、操作性良好&#xff0c;更加符合人體工程學。作為一個跨平臺的UI開發框架&#xff0c;Qt提供了強大而且靈活的界面外觀設計機制&#xff0c;能夠幫助…

利用Rclone將阿里云對象存儲遷移至雨云對象存儲的教程,對象存儲數據遷移教程

使用Rclone將阿里云對象存儲(OSS)的文件全部遷移至雨云對象存儲(ROS)的教程&#xff0c;其他的對象存儲也可以參照本教程。 Rclone簡介 Rclone 是一個用于和同步云平臺同步文件和目錄命令行工具。采用 Go 語言開發。 它允許在文件系統和云存儲服務之間或在多個云存儲服務之間…

STM32-EXTI外部中斷

目錄 一、中斷系統 二、STM32中斷 三、NVIC&#xff08;嵌套中斷向量控制器&#xff09;基本結構 四、NVIC優先級分組 五、EXTI外部中斷 5.1 外部中斷基本知識 5.2 外部中斷&#xff08;EXTI&#xff09;基本結構 ?編輯 5.2.1開發步驟&#xff1a; 5.3 AFIO復用IO口…

ADAudit Plus:強大的網絡安全衛士

隨著數字化時代的不斷發展&#xff0c;企業面臨著越來越復雜和多樣化的網絡安全威脅。在這個信息爆炸的時代&#xff0c;保護組織的敏感信息和確保網絡安全已經成為企業發展不可或缺的一環。為了更好地管理和監控網絡安全&#xff0c;ADAudit Plus應運而生&#xff0c;成為網絡…