Pytorch02:深度學習基礎示例——貓狗識別

?一、第三方庫介紹

庫/模塊功能
torch提供張量操作、自動求導、優化算法、神經網絡模塊等基礎設施。
torchvision計算機視覺工具集,提供預訓練模型、數據集、圖像轉換等功能。
datasets (torchvision)用于加載常見數據集(如 ImageNet、CIFAR-10、MNIST)。
transforms (torchvision)提供圖像數據的預處理、數據增強操作(如大小調整、裁剪、轉換為張量、歸一化等)。
nn (torch)用于定義和構建神經網絡,包含各類網絡層、損失函數等。
optim (torch)提供優化算法(如 Adam、SGD、RMSprop)用于更新神經網絡權重。
DataLoader (torch.utils.data)用于批量加載數據,支持多線程加載數據,按批次讀取數據。
Image (PIL)用于圖像處理(加載、裁剪、旋轉、縮放、保存圖像等)。
ResNet18_Weights (torchvision.models)提供 ResNet18 模型的預訓練權重,可用于遷移學習。

二、訓練數據集介紹

三、原理簡介

????????該代碼使用PyTorch訓練一個基于ResNet-18的貓狗分類模型。通過加載并處理數據、訓練模型、調整最后輸出層、使用Adam優化器進行反向傳播,并在每個訓練周期輸出損失與準確率。訓練完畢后,保存模型用于后續預測。

四、代碼思路簡介

  • 加載數據?→ 使用?datasets.ImageFolder?加載貓狗數據集,并應用圖像轉換。
  • 構建模型?→ 使用預訓練的 ResNet-18 模型,修改輸出層以適應2類分類。
  • 定義損失和優化器?→ 使用交叉熵損失函數和 Adam 優化器。
  • 訓練模型?→ 遍歷數據集,前向傳播、計算損失、反向傳播、更新模型參數。
  • 保存模型?→ 訓練完成后,保存模型權重。
  • 預測圖片?→ 加載已訓練模型,輸入圖片進行預測并輸出分類結果。

五、代碼

場景:使用pytorch識別貓狗
貓的圖片路徑:F:\pycharm\AIDEMO\data\cat
狗的圖片路徑:F:\pycharm\AIDEMO\data\dog
需要判斷的圖片:F:\pycharm\AIDEMO\01.jpeg

import torch
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from PIL import Image
from torchvision.models import ResNet18_Weights# 定義transform類(視覺轉換類,將圖片格式轉化為張量格式)
transform = transforms.Compose([transforms.Resize((128, 128)),  # 將圖片縮放到統一大小transforms.ToTensor(),  # 轉換為Tensor格式transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 標準化處理
])def train_model(data_dir, num_epochs=10, batch_size=32, save_path='cat_dog_model.pth'):"""訓練模型并保存。:param data_dir: 數據路徑,包含cat和dog文件夾:param num_epochs: 訓練周期,默認為10:param batch_size: 批次大小,默認為32:param save_path: 模型保存路徑,默認為'cat_dog_model.pth'"""# 1. 加載訓練數據train_data = datasets.ImageFolder(root=data_dir,  # 數據路徑transform=transform)print(train_data.class_to_idx)  # 輸出文件夾編號,例如這里輸出{'cat': 0, 'dog': 1},表達0代表貓貓,1代表狗狗train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)# 2. 使用預訓練的ResNet18模型model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.fc = nn.Linear(model.fc.in_features, 2)  # 修改輸出層以適應2類分類(貓、狗)# 3. 定義損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 4. 開始訓練模型for epoch in range(num_epochs):model.train()  # 設置模型為訓練模式running_loss = 0.0  # 初始化損失值correct = 0  # 模型預測準確數total = 0  # 模型預測總數for images, labels in train_loader:optimizer.zero_grad()  # 清除之前的梯度outputs = model(images)  # 前向傳播,得出預測結果loss = criterion(outputs, labels)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數running_loss += loss.item()# 計算準確率_, predicted = torch.max(outputs, 1)  # 獲取預測結果total += labels.size(0)  # 累計總樣本數correct += (predicted == labels).sum().item()  # 累計預測正確的樣本數# 輸出訓練周期的損失和準確率accuracy = 100 * correct / total  # 計算準確率print(f'周期 [{epoch + 1}/{num_epochs}], 損失: {running_loss / len(train_loader):.4f}, 準確率: {accuracy:.2f}%')if accuracy == 100:  # 準確率達到100%就停止訓練,避免過度擬合break# 保存訓練模型torch.save(model.state_dict(), save_path)def predict_image(model_path, img_path):"""加載訓練好的模型并進行圖片預測。:param model_path: 訓練好的模型路徑:param img_path: 需要預測的圖片路徑:return: 預測結果(貓或狗)"""# 加載模型model = torchvision.models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)model.fc = nn.Linear(model.fc.in_features, 2)model.load_state_dict(torch.load(model_path))model.eval()  # 設置模型為評估模式# 預測指定圖片img = Image.open(img_path)img = transform(img).unsqueeze(0)  # 將圖片處理成張量輸出并增加batch維度# 模型預測with torch.no_grad():  # 不需要梯度計算,只是進行模型預測outputs = model(img)_, predicted = torch.max(outputs, 1)# 輸出預測結果return "這是貓的圖片" if predicted.item() == 0 else "這是狗的圖片"if __name__ == "__main__":# 01 訓練出模型(若已訓練出準確度較高模型,可注釋下面兩句話,直接用訓練完畢的模型預測)data_dir = 'F:/pycharm/AIDEMO/data'  # 數據路徑train_model(data_dir, num_epochs=10, batch_size=32)# 02 預測指定圖片img_path = 'F:/pycharm/AIDEMO/data/01.jpeg'  # 圖片路徑result = predict_image('cat_dog_model.pth', img_path)print(result)

六、輸出結果展示

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92150.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92150.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92150.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

spring簡單項目實戰

項目路徑 modelspackage com.qcby.demo1;import com.qcby.service.UserService; import com.qcby.service.UserServiceImpl;public class Dfactory {public UserService createUs(){System.out.println("實例化工廠的方式...");return new UserServiceImpl();} }pack…

ServBay for Windows 1.4.0 發布:新增MySQL、PostgreSQL等數據庫自定義配置

各位 Windows 平臺的開發者們, ServBay 始終致力于為您打造一個強大、高效且靈活的本地開發環境。距離上次更新僅過去短短一周,經過我們技術團隊的快速開發,我們正式推出了 ServBay for Windows 1.4.0 版本。 專業開發者不僅需要一個能用的環…

python網絡爬蟲小項目(爬取評論)超級簡單

python網絡爬蟲小項目(爬取評論)超級簡單 學習python網絡爬蟲的完整路徑: (第一章) python網絡爬蟲(第一章/共三章:網絡爬蟲庫、robots.txt規則(防止犯法)、查看獲取網頁源代碼)-…

本周大模型新動向:獎勵引導、多模態代理、鏈式思考推理

點擊藍字關注我們AI TIME歡迎每一位AI愛好者的加入!01Iterative Distillation for Reward-Guided Fine-Tuning of Diffusion Models in Biomolecular Design本文提出了一種用于生物分子設計中獎勵引導生成的擴散模型微調框架。擴散模型在建模復雜、高維數據分布方面…

JAVA+AI教程-第三天

我將由簡入繁,由零基礎到詳細跟大家一起學習java---------------------------------------------------------------------01、程序流程控制:今日課程介紹02、程序流程控制:if分支結構if分支有三種形式,執行順序就是先執行if&…

自定義命令行解釋器shell

目錄 一、模塊框架圖 二、實現目標 三、實現原理 四、全局變量 五、環境變量函數 六、初始化環境變量表函數 七、輸出命令行提示符模塊 八、提取命令輸入模塊 九、填充命令行參數表模塊 十、檢測并處理內建命令模塊 十一、執行命令模塊 十二、源碼 一、模塊框架圖…

uniapp使用uni-ui怎么修改默認的css樣式比如多選框及樣式覆蓋小程序/安卓/ios兼容問題

修改 uni-ui 多選框 (uni-data-checkbox) 的默認樣式 在 uniapp 中使用 uni-ui 的 uni-data-checkbox 組件時,可以通過以下幾種方式修改其默認樣式: 方法一:使用深度選擇器格式一:在頁面的 style 部分使用深度選擇器 >>>…

《Linux 環境下 Nginx 多站點綜合實踐:域名解析、訪問控制與 HTTPS 加密部署》?

綜合練習:請給openlab搭建web網站,網站需求: 1.基于域名www.openlab.com可以訪問網站內容為 welcome to openlab!!, 2.給該公司創建三個子界面分別顯示學生信息,教學資料和繳費網站,基于www.openlab.com/student 網站訪…

網絡基礎1-11綜合實驗(eNSP):vlan/DHCP/Web/HTTP/動態PAT/靜態NAT

注:在華為模擬器(eNSP)上做的實驗其中,在內網實驗:Vlan/DHCP/VWeb/HTTP,在外網實驗:動態PAT/靜態NAT一、拓撲結構1. 核心設備與連接設備接口連接對象VLAN/IP角色LSW2/LSW3Ethernet 0/0/1-2PC1/P…

Mac上安裝Claude Code的步驟

以下是基于現有信息的簡明安裝指南,適用于macOS系統。請按照以下步驟操作: 前提條件 操作系統:macOS 10.15或更高版本。Node.js和npm:Claude Code基于Node.js,需安裝Node.js 18和npm。請檢查是否已安裝: …

MybatisPlus-15.擴展功能-邏輯刪除

一.邏輯刪除配置邏輯刪除的字段時,logic-delete-field字段配置的是邏輯刪除的實體字段名。字段類型可以是boolean和integer。在java中默認是boolean類型。邏輯已刪除值默認為1,而邏輯未刪除值默認為0。當是1時代表已刪除(1在數據庫表中為true&#xff0c…

IDEA 同時修改某個區域內所有相同變量名

在 IntelliJ IDEA 中,同時修改某個區域內所有 相同變量名 的快捷鍵是: ? Shift F6(重命名變量) 但這個快捷鍵默認是 全局重命名,如果你想 僅修改某個方法或代碼塊內的變量名,可以這樣做:&…

Telink BLE 低功耗學習

低功耗管理(Low Power Management)也可以稱為功耗管理(Power Management),本?檔中會簡稱為PM。Telink低功耗解惑我查閱多連接SDK開發手冊時,低功耗管理章節看了兩三遍也沒太明白,有以下幾個問題…

設備管理系統(MMS)如何在工廠MOM功能設計和系統落地

一、核心系統功能模塊設備管理系統圍繞設備全生命周期管理設計,涵蓋基礎數據管理、設備運維全流程管控及統計分析功能,具體如下:基礎數據管理設備與備件臺賬:包含設備臺賬(設備編號、識別碼、型號、生產日期等&#xf…

低空經濟展 | 牧羽天航空攜飛行重卡AT1300亮相2025深圳eVTOL展

為深入推動低空經濟產業高質量發展,構建全球eVTOL(電動垂直起降飛行器)產業交流合作高端平臺,2025深圳eVTOL展定于2025年9月23日至25日在深圳坪山燕子湖國際會展中心隆重舉辦。本屆展會以“低空經濟?eVTOL?航空應急救援?商載大…

CS231n-2017 Lecture4神經網絡筆記

神經網絡:我們之前的線性分類器可以接受輸入,進而給出評分,這是一種線性變換,再此基礎上,我們對這種線性變換結果進行非線性變換,并輸入到下一層線性分類器中,這個過程就像是人類大腦神經的運作…

暑期算法訓練.5

目錄 20. 力扣 34.在排序數組中查找元素的第一個位置和最后一個位置 20.1 題目解析: 20.2 算法思路: 20.3 代碼演示: ?編輯 20.4 總結反思: 21.力扣 69.x的平方根 21.1 題目解析: 21.2 算法思路:…

【HDLBits習題詳解 2】Circuit - Sequential Logic(5)Finite State Machines 更新中...

1. Fsm1(Simple FSM 1 - asynchronous reset)狀態機可分為兩類:(1)Mealy狀態機:輸出由當前狀態和輸入共同決定。輸入變化可能立即改變輸出。(2)Moore狀態機:輸出僅由當前…

多級緩存(億級流量緩存)

傳統緩存方案問題 多級緩存方案 流程 1.客戶端瀏覽器緩存頁面靜態資源; 2. 客戶端請求到Nginx反向代理;[一級緩存_瀏覽器緩存] 3.Nginx反向代理將請求分發到Nginx集群(OpenResty); 4.先重Nginx集群OpenResty中獲取Nginx本地緩存數據;[二級緩存_Nginx本地緩存] 5.若Nginx本地緩存…

淺談Rust語言特性

如大家所了解的,Rust是一種由Mozilla開發的系統編程語言,專注于內存安全、并發性和高性能,旨在替代C/C等傳統系統編程語言。Rust 有著非常優秀的特性,例如:可重用模塊 內存安全和保證(安全的操作與不安全的…