第100+12步 ChatGPT學習:R實現KNN分類

基于R 4.2.2版本演示

一、寫在前面

有不少大佬問做機器學習分類能不能用R語言,不想學Python咯。

答曰:可!用GPT或者Kimi轉一下就得了唄。

加上最近也沒啥內容寫了,就幫各位搬運一下吧。

二、R代碼實現KNN分類

(1)導入數據

我習慣用RStudio自帶的導入功能:

(2)建立KNN模型

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)# Fit KNN model on the training set
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl, preProcess = "scale")# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]# Convert true values to factor for ROC analysis
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "blue") +geom_area(alpha = 0.2, fill = "blue") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Training ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +geom_line(color = "red") +geom_area(alpha = 0.2, fill = "red") +geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +ggtitle("Validation ROC Curve") +xlab("False Positive Rate") +ylab("True Positive Rate") +annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {conf_mat_df <- as.data.frame(as.table(conf_mat))colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +geom_tile(color = "white") +geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +scale_fill_gradient(low = "white", high = "steelblue") +labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))print(p)
}
# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R語言中,caret包提供了一個通用的接口來訓練KNN模型。使用caret的train函數來訓練KNN模型時,可以調整多種參數來優化模型的性能:

基本參數:

①formula: 指定模型的公式,如Y ~ .,表示使用數據框中的所有其他變量來預測Y。

②data:?提供包含訓練數據的數據框。

③method:?對于KNN模型,這個參數應設置為"knn"。

④preProcess: 預處理步驟,常用的包括標準化("scale")和中心化("center"),對于KNN這一步非常重要因為KNN依賴于變量的距離度量。

⑤trControl: 一個trainControl對象,定義了模型訓練的各種控制策略,如交叉驗證的類型和重復次數。

trainControl 函數的參數:

①method: 訓練的方法,如交叉驗證("cv"),重復交叉驗證("repeatedcv"),留一交叉驗證("LOOCV")等。

②number: 對于"cv"和"repeatedcv",這個參數定義了折數。

③repeats: 當使用"repeatedcv"時,定義重復的次數。

④search: 參數搜索方法,默認為"grid"。也可以設置為"random"進行隨機搜索。

⑤savePredictions: 是否保存預測結果,通常用于后續分析。

模型性能調整參數:

使用KNN時,最關鍵的參數之一是鄰居的數量(K值)。這可以通過train函數的以下參數來調整:

①tuneLength: 這個參數決定了在參數搜索中考慮多少個不同的K值。

②tuneGrid: 這是一個數據框,可以自定義K值的具體范圍,例如expand.grid(k = c(1, 5, 10))

結果輸出(默認參數):

三、KNN調參方法

如前所述,KNN的關鍵參數就是K值,所以可以對其進行一個暴力測試,比如取值1到10:

# 定義交叉驗證的控制方法,啟用網格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定義K值的網格搜索范圍
tuneGrid <- expand.grid(k = 1:10)
# 在訓練集上擬合KNN模型,指定網格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型結果,找出最優的K值
print(model)

解讀:

定義交叉驗證的控制方法:使用trainControl函數設定交叉驗證的詳細參數。

定義K值的網格:使用tuneGrid參數在train函數中指定K值的范圍。

擬合模型:使用train函數訓練模型,同時應用預處理步驟(比如標準化數據),以確保每個特征在距離計算中具有等同的權重。

結果輸出:

注意:用了caret包的train函數,并且通過網格搜索指定了一系列的參數(如K值的范圍),那么這個函數會自動選擇表現最好的參數配置來訓練最終的模型。train函數的輸出即是基于你提供的訓練數據和參數搜索范圍內表現最優的模型。因此,當你調用predict函數進行預測時,使用的就是這個最優化的模型。所以,下面的代碼不變。

結果吧,跟之前的完全一樣:

因為caret包對于KNN模型默認進行一系列的K值嘗試,通常這個范圍是1到最多的鄰居數,但具體的最大K值依賴于caret的內部設置。在大多數情況下,它會嘗試如1, 5, 7, 9等常用的K值。所以,我們默認參數的時候,其實軟件自動給我們尋找最優K值了。可以用這個代碼輸出最有K值:

# Print the best K value used by the model
best_k <- model$bestTune$k
cat("The best K value found is:", best_k, "\n")

K值就是9,跟我們自行調參的一致。

那我們猛點,把K的范圍設置的寬一些:

# 定義交叉驗證的控制方法,啟用網格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定義K值的網格搜索范圍
tuneGrid <- expand.grid(k = 1:20)
# 在訓練集上擬合KNN模型,指定網格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型結果,找出最優的K值
print(model)

結果:

K=19,性能指標如下,似乎大同小異:

四、最后

數據嘛:

鏈接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取碼:x8xm

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/35046.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/35046.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/35046.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Docker】Consul 和API

目錄 一、Consul 1. 拉取鏡像 2. 啟動第一個consul服務&#xff1a;consul1 3. 查看consul service1 的ip地址 4. 啟動第二個consul服務&#xff1a;consul2&#xff0c; 并加入consul1&#xff08;使用join命令&#xff09; 5. 啟動第三個consul服務&#xff1a;consul3&…

攻擊者開始使用 XLL 文件進行攻擊

近期&#xff0c;研究人員發現使用惡意 Microsoft Excel 加載項&#xff08;XLL&#xff09;文件發起攻擊的行動有所增加&#xff0c;這項技術的 MITRE ATT&CK 技術項編號為 T1137.006。 這些加載項都是為了使用戶能夠利用高性能函數&#xff0c;為 Excel 工作表提供 API …

【SQL Server數據庫】關系模式與關系代數

目錄 一、請用關系代數完成下列查詢 1. 求 供應工程J1 零件P1的供應商號碼SNO&#xff1b; 2. 求 供應工程J1 零件&#xff08;P&#xff09;為紅色 的供應商號碼SNO&#xff1b; 3. 求 沒有使用 天津供應商&#xff08;P&#xff09;生產的紅色零件&#xff08;S&#xff0…

【雜記-淺談OSPF協議之RouterDeadInterval死區間隔】

OSPF協議之RouterDeadInterval死區間隔 一、RouterDeadInterval概述二、設置RouterDeadInterval三、RouterDeadInterval的重要性 一、RouterDeadInterval概述 RouterDeadInterval&#xff0c;即路由器死區間隔&#xff0c;它涉及到路由器如何在廣播網絡上發現和維護鄰居關系。…

pycharm中的使用技巧

1、更改主題&#xff1a;找到設置&#xff0c;然后更改主題 點擊選擇自己喜歡的主題&#xff0c;然后就可以更改主題了 2、設置字體的快捷鍵 找到設置&#xff0c;如下&#xff1a; 找到increase&#xff0c;如下&#xff1a; 右鍵選擇&#xff0c;增加字體快捷鍵 按住ctrl滑輪…

Excel 查找后隱去右邊列

Excel 有幾列數字 ABC11002042002202100102326027010841199100512100100 當給定參數時&#xff0c;請從每行找到該參數&#xff0c;隱去右邊的列。如果某行不含該參數&#xff0c;則隱去整行。當參數是 100 時&#xff0c;結果如下&#xff1a; ABC710082021009119910010121…

shell之免交互

免交互 交互&#xff1a;發出指令控制指令的運行&#xff0c;程序再接收到指令的效果做出對應的反應。 免交互&#xff1a;間接的&#xff0c;通過第三方的方式把指令傳送給程序&#xff0c;不用直接的下達指令 Hhere Document 免交互 這是命令行格式&#xff0c;也可以寫在腳本…

QTableWidget的使用

使用QTableWidget&#xff0c;初始化數據、設置列頭及格式&#xff0c;設置行數&#xff0c;設置每個單元格的編輯&#xff0c;間隔行底色變換、行選擇 &#xff0c;模式&#xff0c;單元格選擇模式、插入行 、追加行、刪除行&#xff0c;單元格加圖標&#xff0c;單元格顯示ch…

Android Gradle開發與應用

Android Gradle 開發是指在 Android 應用開發中使用 Gradle 作為構建工具的過程。Gradle 是一個基于 Groovy 的自動化構建工具&#xff0c;它允許開發者定義靈活的構建邏輯&#xff0c;并且能夠很好地與 Android Studio 集成。以下是一些關于 Android Gradle 開發與應用的基本概…

替換特殊符號

content content.replaceAll("[\\x00-\\x09\\x11\\x12\\x14-\\x1F\\x7F]", ""); 打印特殊符號&#xff1a; String s new String( Character.toChars(0)); System.out.println((char)0); 2024-06-20 17:21:26.155 ERROR 5584 --- [6884333_inbound] c.…

好記性不如爛筆頭(三)——文件保存后打開呈現亂碼問題

現象 請隨博主進行下列操作&#xff0c;神奇的事情會發生—— 1、新建記事本&#xff0c;里面輸入“同”字&#xff0c;保存為ANSI格式 2、再次打開會發現&#xff0c;“同”已經變成了亂碼 3、類似的字還有很多&#xff0c;例如“同學”的“學”。而有些字則不會出現這種情況…

3_電機的發展及學習方法

一、電機組成及發展 1、什么是勵磁&#xff1f; 在電磁學中&#xff0c;勵磁是通過電流產生磁場的過程。 發電機或電動機由在磁場中旋轉的轉子組成。磁場可以由 永磁體或勵磁線圈產生。對于帶有勵磁線圈的機器&#xff0c;電流必須在線圈中流動才能產生&#xff08;激發&#x…

香港服務器托管對外貿行業必要性和優勢

在當今全球化的經濟環境下&#xff0c;外貿企業面臨著前所未有的機遇與挑戰。其中&#xff0c;服務器托管的選擇對于外貿企業的運營效率和市場拓展具有舉足輕重的作用。香港服務器&#xff0c;憑借其獨特的地理位置、優質的網絡環境和卓越的服務性能&#xff0c;一直是外貿企業…

“Hello, World” 的歷史

“Hello, World!” —— 初學者進入編程世界的第一步 由布萊恩柯林漢 撰寫的“Hello, world”程序 (1978年) 布萊恩W.克尼漢&#xff08;Brian W. Kernighan&#xff09;—— Unix 和 C 語言背后的巨人 布萊恩W.克尼漢 布萊恩W.克尼漢在 1942 年出生在加拿大多倫多&#xff…

OS中斷機制-嵌套和競爭

對于FreeRTOS最好不去用中斷嵌套,中斷嵌套會增加堆棧空間的使用,因為每個中斷服務程序都需要保存和恢復寄存器狀態,這可能會耗盡有限的堆棧空間,從而導致系統故障。以及中斷嵌套時,不同的中斷服務程序可能會競爭訪問共享資源,從而增加死鎖的風險。這可能會導致系統出現故…

Verilog進行結構描述(structural modeling)(一):基本概念

目錄 1.結構描述(structural modeling)的內容&#xff1a;2.實例 微信公眾號獲取更多FPGA相關源碼&#xff1a; 1.結構描述(structural modeling)的內容&#xff1a; 用門來描述器件的功能基于基本元件和底層模塊例化語句最接近實際的硬件結構主要使用元件的定義、使用聲明以…

Flink——最流批的大數據框架(流批一體)

Apache Flink基礎教程 資料來源&#xff1a;Apache Flink Tutorial (tutorialspoint.com) Apache Flink是Apache Hadoop的開源本地分析數據庫。它由Cloudera、MapR、Oracle和Amazon等供應商提供。本教程中提供的示例是使用Cloudera Apache Flink開發的。 本教程是為那些想要學…

fork 是一個創建新進程的系統調用

在計算機科學中&#xff0c;fork 是一個創建新進程的系統調用。具體來說&#xff0c;fork 調用會創建一個與當前進程幾乎完全相同的副本&#xff0c;包括父進程的內存布局、環境變量、打開的文件描述符等。這個新的進程被稱為子進程&#xff0c;而原始進程被稱為父進程。 以下…

光伏開發有沒有難點?如何解決?

隨著全球對可再生能源的日益重視&#xff0c;光伏技術作為其中的佼佼者&#xff0c;已成為實現能源轉型的關鍵手段。然而&#xff0c;光伏開發并非一帆風順&#xff0c;其過程中也面臨著諸多難點和挑戰。本文將對這些難點進行探討&#xff0c;并提出相應的解決策略。 一、光伏開…

12 學習總結:操作符

目錄 一、操作符的分類 二、二進制和進制轉換 &#xff08;一&#xff09;概念 &#xff08;二&#xff09;二進制 &#xff08;三&#xff09;進制轉換 1、2進制與10進制的互換 &#xff08;1&#xff09;2進制轉化10進制 &#xff08;2&#xff09;10進制轉化2進制 2…