基于flask的貓狗圖像預測案例

📚博客主頁:knighthood2001
?公眾號:認知up吧 (目前正在帶領大家一起提升認知,感興趣可以來圍觀一下)
🎃知識星球:【認知up吧|成長|副業】介紹
??如遇文章付費,可先看看我公眾號中是否發布免費文章??
🙏筆者水平有限,歡迎各位大佬指點,相互學習進步!

假設,你有模型,有訓練好的模型文件,有模型推理代碼,就可以把他放到flask上進行展示。

項目架構

在這里插入圖片描述

  • index.html是模板文件
  • app.py是項目運行的入口
  • best_model.pth是訓練好的模型參數
  • model.py是神經網絡模型,這里采用的是GoogleNet網絡。
  • model_reasoning.py是模型推理,通過這里面的代碼,我們可以在本地進行貓狗圖片的預測。

運行圖

在這里插入圖片描述

點擊選擇文件
在這里插入圖片描述
圖片下面就顯示預測結果了。
在這里插入圖片描述

項目完整代碼與講解

index.html

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><title>圖像分類</title><style>body {font-family: Arial, sans-serif;margin: 20px;}#result {margin-top: 10px;}#preview-image {max-width: 400px;margin-top: 20px;}</style>
</head>
<body><h1>圖像分類</h1><form id="upload-form" action="/predict" method="post" enctype="multipart/form-data"><input type="file" name="file" accept="image/*" onchange="previewImage(event)"><input type="submit" value="預測"></form><img id="preview-image" src="" alt=""><br><div id="result"></div><script>document.getElementById('upload-form').addEventListener('submit', async (e) => {e.preventDefault();  // 阻止默認的表單提交行為const formData = new FormData(); // 創建一個新的FormData對象,用于封裝表單數據formData.append('file', document.querySelector('input[type=file]').files[0]);  // 添加表單數據// 使用fetch API發送POST請求到'/predict'路徑,并將formData作為請求體const response = await fetch('/predict', {method: 'POST',body: formData});// 獲取響應的JSON數據const result = await response.json();// 將預測結果顯示在頁面上ID為'result'的元素中document.getElementById('result').innerText = `預測結果: ${result.prediction}`;});function previewImage(event) {const file = event.target.files[0];  // 獲取上傳的文件對象const reader = new FileReader();  // 創建一個FileReader對象,用于讀取文件內容// 清空上一次的預測結果document.getElementById('result').innerText = '';// 當文件讀取完成后,將文件內容顯示在頁面上ID為'preview-image'的元素中reader.onload = function(event) {document.getElementById('preview-image').setAttribute('src', event.target.result);}// 如果用戶選擇了文件,則開始讀取文件內容if (file) {reader.readAsDataURL(file); // 將文件讀取為DataURL格式,這樣可以直接用作img元素的src屬性}}</script>
</body>
</html>

前端我練的不多,很多解釋已經在代碼中講了。

model.py

這是GoogleNet的網絡架構

import torch
from torch import nn
from torchsummary import summary
# 定義一個Inception模塊
class Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):  # 這些參數,所在的位置都會發送變化,所有需要這個參數super(Inception, self).__init__()self.ReLU = nn.ReLU()# 路線1,單1×1卷積層self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)# 路線2,1×1卷積層, 3×3的卷積self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)# 路線3,1×1卷積層, 5×5的卷積self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)# 路線4,3×3的最大池化, 1×1的卷積self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLU(self.p1_1(x))p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))p4 = self.ReLU(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)class GoogLeNet(nn.Module):def __init__(self, Inception, in_channels, out_channels):super(GoogLeNet, self).__init__()self.b1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (128, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(1024, out_channels))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.b1(x)x = self.b2(x)x = self.b3(x)x = self.b4(x)x = self.b5(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = GoogLeNet(Inception, 1, 10).to(device)print(summary(model, (1, 224, 224)))

model_reasoning.py

import torch
from torchvision import transforms
from model import GoogLeNet, Inception
from PIL import Imagedef test_model(model, test_file):# 設定測試所用到的設備,有GPU用GPU沒有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'model = model.to(device)classes = ['貓', '狗']print(classes)image = Image.open(test_file)# normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])# # 定義數據集處理方法變量# test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])# 定義數據集處理方法變量test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])image = test_transform(image)# 添加批次維度,變成[1,3,224,224]image = image.unsqueeze(0)with torch.no_grad():model.eval()image = image.to(device)  # 圖片也要放到設備當中output = model(image)print(output.tolist())pre_lab = torch.argmax(output, dim=1)result = pre_lab.item()print("預測值:", classes[result])return classes[result]def test_special_model(best_model_file, test_file):# 加載模型model = GoogLeNet(Inception, in_channels=3, out_channels=2)model.load_state_dict(torch.load(best_model_file))# 模型的推理判斷return test_model(model, test_file)if __name__ == "__main__":# # 加載模型# model = GoogLeNet(Inception, in_channels=3, out_channels=2)# model.load_state_dict(torch.load('best_model.pth'))# # 模型的推理判斷# test_model(model, "test_data/images.jfif")test_special_model("best_model.pth", "static/1.jpg")

這段代碼與之前的模型推理代碼不同的是,我添加了test_special_model函數,方便后續app.py中可以直接調用這個函數進行模型推理。

app.py

import os
from flask import Flask, request, jsonify, render_templatefrom model_reasoning import test_special_model
from model_reasoning import test_model
app = Flask(__name__)# 定義路由
@app.route('/')
def index():return render_template('index.html')@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# 獲取上傳的文件file = request.files['file']if file:# 調用模型進行預測# # 加載模型# model = GoogLeNet(Inception, in_channels=3, out_channels=2)# basedir = os.path.abspath(os.path.dirname(__file__))## model.load_state_dict(torch.load(basedir + '/best_model.pth'))# result = test_model(model, file)basedir = os.path.abspath(os.path.dirname(__file__))best_model_file = basedir + '/best_model.pth'result = test_special_model(best_model_file, file)return jsonify({'prediction': result})else:return jsonify({'error': 'No file found'})if __name__ == '__main__':app.run(debug=True)

如果沒有上文中的test_special_model函數,那么這里你就需要

   # 加載模型model = GoogLeNet(Inception, in_channels=3, out_channels=2)basedir = os.path.abspath(os.path.dirname(__file__))model.load_state_dict(torch.load(basedir + '/best_model.pth'))result = test_model(model, file)

并且還需要導入相應的庫。

best_model.pth

最重要的是,你需要訓練好的一個模型。

有需要的,可以聯系我,我直接把這個項目代碼發你。省得你還需要配置項目架構。

小插曲

我為什么會使用絕對路徑,因為我在使用相對路徑后,代碼提示找不到這個路徑。

    basedir = os.path.abspath(os.path.dirname(__file__))best_model_file = basedir + '/best_model.pth'

然后,我剛剛又試了一下,發現使用相對路徑,又可以運行成功了。

真是不可思議(這個小插曲花了我大半個小時)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/42036.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/42036.shtml
英文地址,請注明出處:http://en.pswp.cn/web/42036.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二次元轉向SLG,B站游戲的破圈之困

文 | 螳螂觀察 作者 | 夏至 2023年是B站游戲的滑鐵盧&#xff0c;盡管這年B站的游戲營收還有40多億&#xff0c;但相比去年大幅下降了20%&#xff0c;整整少了10億&#xff0c;這是過去5年來的最大跌幅&#xff0c;也是陳睿接管B站游戲業務一年以來&#xff0c;在鼻子上碰的第…

鴻蒙語言基礎類庫:【@ohos.process (獲取進程相關的信息)】

獲取進程相關的信息 說明&#xff1a; 本模塊首批接口從API version 7開始支持。后續版本的新增接口&#xff0c;采用上角標單獨標記接口的起始版本。開發前請熟悉鴻蒙開發指導文檔&#xff1a;gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md點擊或者復制轉到。…

昇思13天

ResNet50遷移學習 ResNet50遷移學習總結 背景介紹 在實際應用場景中&#xff0c;由于訓練數據集不足&#xff0c;很少有人會從頭開始訓練整個網絡。普遍做法是使用在大數據集上預訓練得到的模型&#xff0c;然后將該模型的權重參數用于特定任務中。本章使用遷移學習方法對Im…

放棄華為OD,選擇最合適而不是最難得

時間不知不覺邁入了七月&#xff0c;五月嘗試去重新找一份工作&#xff0c;但釋放出來的崗位太少了&#xff0c;難得有進華為OD的機會&#xff0c;還是比較核心的部門&#xff0c;但思來想起&#xff0c;還是放棄了。 如果想去&#xff0c;是很有機會的&#xff0c;一路過關斬…

imx6ull/linux應用編程學習(13) CMAKE

什么是cmake&#xff1f; cmake 工具通過解析 CMakeLists.txt 自動幫我們生成 Makefile&#xff0c;可以實現跨平臺的編譯。cmake 就是用來產生 Makefile 的工具&#xff0c;解析 CMakeLists.txt 自動生成 Makefile&#xff1a; cmake 的使用方法 cmake 就是一個工具命令&am…

怎么將aac文件弄成mp3格式?把aac改成MP3格式的四種方法

怎么將aac文件弄成mp3格式&#xff1f;手頭有一些aac格式的音頻文件&#xff0c;但由于某些設備或軟件不支持這種格式&#xff0c;你希望將它們轉換成更為通用的MP3格式。而且音頻格式的轉換在現在已經是一個常見且必要的操作。aac是一種相對較新的音頻編碼格式&#xff0c;通常…

大模型增量預訓練新技巧-解決災難性遺忘

大模型增量預訓練新技巧-解決災難性遺忘 機器學習算法與自然語言處理 2024年03月21日 00:02 吉林 以下文章來源于NLP工作站 &#xff0c;作者劉聰NLP NLP工作站. AIGC前沿知識分享&落地經驗總結 轉載自 | NLP工作站 作者 | 劉聰NLP 目前不少開源模型在通用領域具有不錯…

G1 和 CMS

1、CMS CMS&#xff08;Concurrent Mark Sweep&#xff0c;并發標記清除&#xff0c;是為了解決早期垃圾收集器在執行垃圾回收時導致應用程序暫停時間過長的問題而設計的。 CMS的工作流程主要包括以下幾個階段&#xff1a; 初始標記&#xff08;Initial Mark&#xff09;&…

一體化運維監控平臺:賦能各行業用戶運維升級

在當今數字化轉型的大潮中&#xff0c;企業IT系統的復雜性和規模不斷攀升&#xff0c;對運維團隊提出了前所未有的挑戰。如何高效、精準地監控和管理IT基礎設施&#xff0c;確保業務連續性和穩定性&#xff0c;成為所有企業關注的焦點。美信&#xff0c;自2007年成立以來&#…

el-scrollbar實現自動滾動到底部(AI聊天)

目錄 項目背景 實現步驟 實現代碼 完整示例代碼 項目背景 chatGPT聊天消息展示滾動面板&#xff0c;每次用戶輸入提問內容或者ai進行流式回答時需要不斷的滾動到底部確保展示最新的消息。 實現步驟 采用element ui 的el-scrollbar作為聊天消息展示組件。 通過操作dom來實…

端、邊、云三級算力網絡

目錄 端、邊、云三級算力網絡 NPU Arm架構 OpenStack kubernetes k3s輕量級Kubernetes kubernetes和docker區別 DCI(Data Center Interconnect) SD/WAN TF 端、邊、云三級算力網絡 算力網絡從傳統云網融合的角度出發,結合 邊緣計算、網絡云化以及智能控制的優勢,通…

Qt開發 | Qt創建線程 | Qt并發-QtConcurrent

文章目錄 一、Qt創建線程的三種方法二、Qt并發&#xff1a;QtConcurrent介紹三、QtConcurrent run參數說明四、獲取QtConcurrent的返回值五、C其他線程技術介紹 一、Qt創建線程的三種方法 以下是Qt創建線程的三種方法&#xff1a; 方法一&#xff1a;派生于QThread 派生于QThre…

理解算法復雜度:空間復雜度詳解

引言 在計算機科學中&#xff0c;算法復雜度是衡量算法效率的重要指標。時間復雜度和空間復雜度是算法復雜度的兩個主要方面。在這篇博客中&#xff0c;我們將深入探討空間復雜度&#xff0c;了解其定義、常見類型以及如何進行分析。空間復雜度是衡量算法在執行過程中所需內存…

ceph mgr [errno 39] RBD image has snapshots (error deleting image from trash)

ceph mgr 報錯 debug 2024-07-08T09:25:56.512+0000 7f9c63bd2700 0 [rbd_support INFO root] execute_task: task={"sequence": 3, "id": "260b9fee-d567-4301-b7eb-b1fe1b037413", "message": "Removing image replicapool/8…

昇思25天學習打卡營第19天|Diffusion擴散模型

學AI還能贏獎品&#xff1f;每天30分鐘&#xff0c;25天打通AI任督二脈 (qq.com) Diffusion擴散模型 本文基于Hugging Face&#xff1a;The Annotated Diffusion Model一文翻譯遷移而來&#xff0c;同時參考了由淺入深了解Diffusion Model一文。 本教程在Jupyter Notebook上成…

python庫 - missingno

missingno 是一個用于可視化和分析數據集中缺失值的 Python 庫。它提供了一系列簡單而強大的工具&#xff0c;幫助用戶直觀地理解數據中的缺失模式&#xff0c;從而更好地進行數據清洗和預處理。missingno 庫特別適用于數據分析和數據科學項目&#xff0c;尤其是在處理缺失數據…

昇思MindSpore學習筆記5-02生成式--RNN實現情感分類

摘要&#xff1a; 記錄MindSpore AI框架使用RNN網絡對自然語言進行情感分類的過程、步驟和方法。 包括環境準備、下載數據集、數據集加載和預處理、構建模型、模型訓練、模型測試等。 一、概念 情感分類。 RNN網絡模型 實現效果&#xff1a; 輸入: This film is terrible 正…

放大鏡案例

放大鏡 <!DOCTYPE html> <html lang"zh-cn"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>商品放大鏡</title><link rel&qu…

如何使用allure生成測試報告

第一步下載安裝JDK1.8&#xff0c;參考鏈接JDK1.8下載、安裝和環境配置教程-CSDN博客 第二步配置allure環境&#xff0c;參考鏈接allure的安裝和使用(windows環境)_allure windows-CSDN博客 第三步&#xff1a; 第四步&#xff1a; pytest 查看目前運行的測試用例有無錯誤 …

如何使用 pytorch 創建一個神經網絡

我已發布在&#xff1a;如何使用 pytorch 創建一個神經網絡 SapientialM.Github.io 構建神經網絡 1 導入所需包 import os import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets, transforms2 檢查GPU是否可用 dev…