pytorch邏輯回歸實現垃圾郵件檢測

完整代碼:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np# 增強的數據集:更多的垃圾郵件與正常郵件樣本
X = ["Congratulations! You've won a $1000 gift card. Claim it now!","Dear friend, I hope you are doing well. Let's catch up soon.","Urgent: Your bank account has been compromised. Please contact support immediately.","Hello, just wanted to confirm our meeting at 2 PM today.","You have a new message from your friend. Click here to read.","Get a free iPhone now! Limited offer, click here.","Last chance to claim your prize, you won $500!","Meeting scheduled for tomorrow. Please confirm.","Hello! You are invited to an exclusive event!","Click here to get free lottery tickets. Hurry up!","Reminder: Your subscription will expire soon, renew now.","Don't forget to submit your report by end of day today."
]
y = [1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0]  # 1 為垃圾郵件,0 為正常郵件# 使用 TfidfVectorizer 進行文本向量化
vectorizer = TfidfVectorizer(stop_words='english')  # 去除停用詞
X_vec = vectorizer.fit_transform(X).toarray()# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X_vec, y, test_size=0.33, random_state=42)# 定義邏輯回歸模型
class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.fc = nn.Linear(input_dim, 1)  # 線性層,輸入維度是特征的數量,輸出是1def forward(self, x):return torch.sigmoid(self.fc(x))  # 使用sigmoid激活函數輸出0到1之間的概率# 定義訓練過程
def train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001):criterion = nn.BCELoss()  # 二分類交叉熵損失optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam優化器X_train_tensor = torch.tensor(X_train, dtype=torch.float32)y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train_tensor)loss = criterion(outputs, y_train_tensor)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 測試模型
def evaluate_model(model, X_test, y_test):model.eval()X_test_tensor = torch.tensor(X_test, dtype=torch.float32)y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)with torch.no_grad():outputs = model(X_test_tensor)predictions = (outputs >= 0.5).float()  # 閾值設為0.5accuracy = accuracy_score(y_test, predictions.numpy())print(f'Accuracy: {accuracy * 100:.2f}%')# 訓練并評估模型
input_dim = X_train.shape[1]  # 輸入特征的數量
model = LogisticRegressionModel(input_dim)
train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001)
evaluate_model(model, X_test, y_test)# 預測新郵件
def predict(model, new_email):model.eval()new_email_vec = vectorizer.transform([new_email]).toarray()new_email_tensor = torch.tensor(new_email_vec, dtype=torch.float32)with torch.no_grad():prediction = model(new_email_tensor)return "Spam" if prediction >= 0.5 else "Not Spam"# 檢測新郵件
email_1 = "Congratulations! You have a limited time offer for a free cruise."
email_2 = "Hi, let's discuss the project updates tomorrow."print(f"Email 1: {predict(model, email_1)}")  # 可能輸出:Spam
print(f"Email 2: {predict(model, email_2)}")  # 可能輸出:Not Spam
1. 數據預處理
  • 準備數據集:包含垃圾郵件(Spam)和正常郵件(Not Spam)。
  • 文本向量化:使用 TfidfVectorizer 將文本轉換為數值特征,使模型能夠處理。
  • 去除停用詞:排除無意義的常見詞(如 "the", "is", "and"),提高模型性能。
2. 訓練集與測試集劃分
  • 將數據集拆分為訓練集和測試集,以 67% 訓練,33% 測試,保證模型有足夠數據訓練,同時可以評估其泛化能力。
3. 邏輯回歸模型
  • 搭建 PyTorch 邏輯回歸模型
    • 采用 nn.Linear() 構建一個單層神經網絡(輸入為文本特征,輸出為 1 個數值)。
    • 使用 sigmoid 作為激活函數,將輸出轉換為 0-1 之間的概率值。
4. 訓練模型
  • 定義損失函數:使用二元交叉熵損失 (BCELoss),適用于二分類問題。
  • 優化器:采用 Adam 優化器,以 0.001 學習率進行參數優化。
  • 訓練流程
    1. 計算前向傳播的輸出。
    2. 計算損失值,衡量預測結果與真實標簽的差距。
    3. 進行反向傳播,更新權重參數。
    4. 迭代多輪(如 200 輪),不斷優化模型。
5. 評估模型
  • 將測試數據輸入模型,預測結果并與真實標簽進行對比。
  • 計算準確率,評估模型在未見過的數據上的表現。
6. 預測新郵件
  • 將新郵件轉換為數值特征(與訓練時相同的方法)。
  • 使用訓練好的模型進行預測
  • 閾值判斷:如果輸出概率 ≥ 0.5,則判斷為垃圾郵件,否則為正常郵件。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/67466.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/67466.shtml
英文地址,請注明出處:http://en.pswp.cn/web/67466.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【 CVE-2025-21298】 通過ghidriff查看完整補丁差異

ole32_dec24.dll-ole32.dll 差異 目錄 視覺圖表差異元數據 Ghidra 差異引擎 命令行二進制元數據差異程序選項

洛谷P3383 【模板】線性篩素數

題目鏈接:P3383 【模板】線性篩素數 - 洛谷 | 計算機科學教育新生態 題目難度:普及一 題目分析:本題是模板題,用到了線性篩法,其中原理是保證范圍內的每個合數都被刪掉(在 bool 數組里面標記為非素數…

STM32標準庫移植RT-Thread nano

STM32標準庫移植RT-Thread Nano 嗶哩嗶哩教程鏈接:STM32F1標準庫移植RT_Thread Nano 移植前的準備 stm32標準庫的裸機代碼(最好帶有點燈和串口)RT-Thread Nano Pack自己的開發板 移植前的說明 本人是在讀學生,正在學習階段&a…

JVM--類加載器

概念 類加載器:只參與加載過程中的字節碼獲取并加載到內存中的部分;java虛擬機提供給應用程序去實現獲取類和接口字節碼數據的一種技術,也就是說java虛擬機是允許程序員寫代碼去獲取字節碼信息 類加載是加載的第一步,主要有以下三…

ECMAScript 6語法

1.ES6簡介 ECMAScript 6(簡稱ES6)是于2015年6月正式發布的JavaScript語言的標準,正式名為ECMAScript 2015(ES2015)。它的目標是使得JavaScript語言可以用來編寫復雜的大型應用程序,成為企業級開發語言 。 …

聯想Y7000+RTX4060+i7+Ubuntu22.04運行DeepSeek開源多模態大模型Janus-Pro-1B+本地部署

直接上手搓了: conda create -n myenv python3.10 -ygit clone https://github.com/deepseek-ai/Janus.gitcd Januspip install -e .pip install webencodings beautifulsoup4 tinycss2pip install -e .[gradio]pip install pexpect>4.3python demo/app_januspr…

Tez 0.10.1安裝

個人博客地址:Tez 0.10.1安裝 | 一張假鈔的真實世界 具體安裝步驟參照官網安裝手冊即可。此處只對官網手冊進行補充。 從官網下載apache-tez-0.10.1-bin.tar.gz進行安裝未成功,出現下面的異常。最終按照官網源代碼編譯的方式安裝測試成功。 環境 Had…

FastAPI + GraphQL + SQLAlchemy 實現博客系統

本文將詳細介紹如何使用 FastAPI、GraphQL(Strawberry)和 SQLAlchemy 實現一個帶有認證功能的博客系統。 技術棧 FastAPI:高性能的 Python Web 框架Strawberry:Python GraphQL 庫SQLAlchemy:Python ORM 框架JWT&…

微服務入門(go)

微服務入門(go) 和單體服務對比:里面的服務僅僅用于某個特定的業務 一、領域驅動設計(DDD) 基本概念 領域和子域 領域:有范圍的界限(邊界) 子域:劃分的小范圍 核心域…

深入解析 Linux 內核內存管理核心:mm/memory.c

在 Linux 內核的眾多組件中,內存管理模塊是系統性能和穩定性的關鍵。mm/memory.c 文件作為內存管理的核心實現,承載著頁面故障處理、頁面表管理、內存區域映射與取消映射等重要功能。本文將深入探討 mm/memory.c 的設計思想、關鍵機制以及其在內核中的作用,幫助讀者更好地理…

安卓通過網絡獲取位置的方法

一 方法介紹 1. 基本權限設置 首先需要在 AndroidManifest.xml 中添加必要權限&#xff1a; xml <uses-permission android:name"android.permission.INTERNET" /> <uses-permission android:name"android.permission.ACCESS_NETWORK_STATE" /&g…

【B站保姆級視頻教程:Jetson配置YOLOv11環境(二)SSH連接的三種方式】

B站同步視頻教程&#xff1a;https://www.bilibili.com/video/BV1m5wUeyEQD/ 在Jetson設備上配置YOLOv11環境時&#xff0c;SSH連接是實現遠程高效開發與管理的關鍵一環。不同的網絡環境和硬件配置可能會影響SSH連接的方式&#xff0c;本文將結合相關視頻內容&#xff0c;詳細…

視頻拼接,拼接時長版本

目錄 視頻較長&#xff0c;分辨率較大&#xff0c;這個效果很好&#xff0c;不耗用內存 ffmpeg imageio&#xff0c;適合視頻較短 視頻較長&#xff0c;分辨率較大&#xff0c;這個效果很好&#xff0c;不耗用內存 ffmpeg import subprocess import glob import os from nats…

Vue.js 什么是 Composition API?

Vue.js 什么是 Composition API&#xff1f; 今天我們來聊聊 Vue 3 引入的一個重要特性&#xff1a;組合式 API&#xff08;Composition API&#xff09;。如果你曾在開發復雜的 Vue 組件時感到代碼難以維護&#xff0c;那么組合式 API 可能正是你需要的工具。 什么是組合式 …

Selenium配合Cookies實現網頁免登錄

文章目錄 前言1 方案一&#xff1a;使用Chrome用戶數據目錄2 方案二&#xff1a;手動獲取并保存Cookies&#xff0c;后續使用保存的Cookies3 注意事項 前言 在進行使用Selenium進行爬蟲、網頁自動化操作時&#xff0c;登錄往往是一個必須解決的問題&#xff0c;但是Selenium每次…

計算機畢業設計Python+知識圖譜大模型AI醫療問答系統 健康膳食推薦系統 食譜推薦系統 醫療大數據 機器學習 深度學習 人工智能 爬蟲 大數據畢業設計

溫馨提示&#xff1a;文末有 CSDN 平臺官方提供的學長聯系方式的名片&#xff01; 溫馨提示&#xff1a;文末有 CSDN 平臺官方提供的學長聯系方式的名片&#xff01; 溫馨提示&#xff1a;文末有 CSDN 平臺官方提供的學長聯系方式的名片&#xff01; 作者簡介&#xff1a;Java領…

關于el-table翻頁后序號列遞增的組件封裝

需求說明&#xff1a; 項目中經常會用到的一個場景&#xff0c;表格第一列顯示序號&#xff08;1、2、3...&#xff09;&#xff0c;但是在翻頁后要遞增顯示序號&#xff0c;例如10、11、12&#xff08;假設一頁顯示10條數據&#xff09;&#xff0c;針對這種情況&#xff0c;封…

Elasticsearch的索引生命周期管理

目錄 說明零、參考一、ILM的基本概念二、ILM的實踐步驟Elasticsearch ILM策略中的“最小年齡”是如何計算的&#xff1f;如何監控和調整Elasticsearch ILM策略的性能&#xff1f; 1. **監控性能**使用/_cat/thread_pool API基本請求格式請求特定線程池的信息響應內容 2. **調整…

AI大模型開發原理篇-3:詞向量和詞嵌入

簡介 詞向量是用于表示單詞意義的向量&#xff0c; 并且還可以被認為是單詞的特征向量或表示。 將單詞映射到實向量的技術稱為詞嵌入。在實際應用中&#xff0c;詞向量和詞嵌入這兩個重要的NLP術語通常可以互換使用。它們都表示將詞匯表中的單詞映射到固定大小的連續向量空間中…

[內網安全] 內網滲透 - 學習手冊

這是一篇專欄的目錄文檔&#xff0c;方便讀者系統性的學習&#xff0c;筆者后續會持續更新文檔內容。 如果沒有特殊情況的話&#xff0c;大概是一天兩篇的速度。&#xff08;實驗多或者節假日&#xff0c;可能會放緩&#xff09; 筆者也是一邊學習一邊記錄筆記&#xff0c;如果…