使用pytorch創建/訓練/推理OCR模型

一、任務描述

????????從手寫數字圖像中自動識別出對應的數字(0-9)” 的問題,屬于單標簽圖像分類任務(每張圖像僅對應一個類別,即 0-9 中的一個數字)

? ? ? ? 1、任務的核心定義:輸入與輸出

  • 輸入:28×28 像素的灰度圖像(像素值范圍 0-255,0 代表黑色背景,255 代表白色前景),圖像內容是人類手寫的 0-9 中的某一個數字,例如:一張 28×28 的圖像,像素分布呈現 “3” 的形狀,就是模型的輸入。
  • 輸出:一個 “類別標簽”,即從 10 個可能的類別(0、1、2、…、9)中選擇一個,作為輸入圖像對應的數字,例如:輸入 “3” 的圖像,模型輸出 “類別 3”,即完成一次正確識別。
  • 目標:讓模型在 “未見的手寫數字圖像” 上,盡可能準確地輸出正確類別(通常用 “準確率” 衡量,即正確識別的圖像數 / 總圖像數)

? ? ? ? 2、任務的核心挑戰

  • 不同人書寫習慣差異極大:有人寫的 “4” 帶彎鉤,有人寫的 “7” 帶橫線,有人字體粗大,有人字體纖細;甚至同一個人不同時間寫的同一數字,筆畫粗細、傾斜角度也會不同。例如:同樣是 “5”,可能是 “直筆 5”“圓筆 5”,也可能是傾斜 10° 或 20° 的 “5”—— 模型需要忽略這些 “風格差異”,抓住 “數字的本質特征”(如 “5 有一個上半圓 + 一個豎線”)。
  • 圖像噪聲與干擾:手寫數字圖像可能存在噪聲,比如紙張上的污漬、書寫時的斷筆、掃描時的光線不均,這些都會影響像素分布。例如:一張 “0” 的圖像,邊緣有一小塊污漬,模型需要判斷 “這是噪聲” 而不是 “0 的一部分”,避免誤判為 “6” 或 “8”。

二、模型訓練

? ? ? ?1、MNIST數據集

????????MNIST(Modified National Institute of Standards and Technology database)是由美國國家標準與技術研究院(NIST)整理的手寫數字數據集,后經修改(調整圖像大小、居中對齊)成為機器學習領域的 “基準數據集”,MNIST手寫數字識別的核心是 “讓計算機從標準化的手寫數字灰度圖中,自動識別出對應的 0-9 數字”,它看似基礎,卻濃縮了圖像分類的核心挑戰(風格多樣性、噪聲魯棒性、特征自動提取),同時是實際 OCR 場景的技術基礎和機器學習入門的經典案例。

  • 數據量適中:包含 70000 張圖像,其中 60000 張用于訓練(讓模型學習特征),10000 張用于測試(驗證模型泛化能力);
  • 圖像規格統一:所有圖像都是 28×28 灰度圖,無需復雜的預處理(如尺寸縮放、顏色通道處理),降低入門門檻;
  • 標注準確:每張圖像都有明確的 “正確數字標簽”(人工標注),無需額外標注成本。

? ? ? ? 2、代碼

  • 數據準備:使用torchvision.datasets加載 MNIST 數據集,對數據進行轉換(轉為 Tensor 并標準化),使用DataLoader創建可迭代的數據加載器;
  • 模型定義:定義了一個簡單的兩層神經網絡SimpleNN,第一層將 28x28 的圖像展平后映射到 128 維,第二層將 128 維特征映射到 10 個類別(對應數字 0-9);
  • 訓練設置:使用交叉熵損失函數(CrossEntropyLoss),使用 Adam 優化器,設置批量大小為64,訓練輪次為5;
  • 訓練過程:循環多個訓練輪次(epoch),每個輪次中迭代所有批次數據,執行前向傳播、計算損失、反向傳播和參數更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 設置隨機種子,確保結果可復現
torch.manual_seed(42)# 1. 數據準備
# 定義數據變換
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為Tensortransforms.Normalize((0.1307,), (0.3081,))  # 標準化,MNIST數據集的均值和標準差
])# 加載MNIST數據集
train_dataset = datasets.MNIST(root='./data',  # 數據保存路徑train=True,  # 訓練集download=True,  # 如果數據不存在則下載transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,  # 測試集download=True,transform=transform
)# 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 輸入層到隱藏層self.fc1 = nn.Linear(28 * 28, 128)  # MNIST圖像大小為28x28# 隱藏層到輸出層self.fc2 = nn.Linear(128, 10)  # 10個類別(0-9)def forward(self, x):# 將圖像展平為一維向量x = x.view(-1, 28 * 28)# 隱藏層,使用ReLU激活函數x = torch.relu(self.fc1(x))# 輸出層,不使用激活函數(因為后面會用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、損失函數和優化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # 交叉熵損失,適用于分類問題
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 4. 訓練模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 設置為訓練模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向傳播outputs = model(data)loss = criterion(outputs, target)# 反向傳播和優化loss.backward()optimizer.step()running_loss += loss.item()# 每100個批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 6. 運行訓練和測試
if __name__ == '__main__':# 訓練模型print("開始訓練模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)print("模型訓練完成...")# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存為 mnist_model.pth")

三、模型使用測試

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的導入方式# 定義與訓練時相同的模型結構
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 加載模型
def load_model(model_path='mnist_model.pth'):model = SimpleNN()# 加載模型時添加參數以避免潛在的Python 3兼容性問題model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))model.eval()  # 設置為評估模式return model# 圖像預處理(與訓練時保持一致)
def preprocess_image(image_path):# 打開圖像并轉換為灰度圖img = Image.open(image_path).convert('L')  # 'L'表示灰度模式# 調整大小為28x28img = img.resize((28, 28))# 轉換為numpy數組并歸一化img_array = np.array(img) / 255.0# 定義圖像轉換(使用torchvision的transforms)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 注意:這里需要先將numpy數組轉換為PIL圖像再應用transformimg_pil = Image.fromarray((img_array * 255).astype(np.uint8))img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次維度return img_tensor# 預測函數
def predict_digit(model, image_path):# 預處理圖像img_tensor = preprocess_image(image_path)# 預測with torch.no_grad():  # 不計算梯度outputs = model(img_tensor)_, predicted = torch.max(outputs.data, 1)return predicted.item()  # 返回預測的數字# 示例使用
if __name__ == '__main__':# 加載模型model = load_model('mnist_model.pth')# 預測示例圖像test_image_path = 'test_digit.png'  # 用戶需要提供的測試圖像路徑try:predicted_digit = predict_digit(model, test_image_path)print(f"預測的數字是: {predicted_digit}")except Exception as e:print(f"預測出錯: {str(e)}")

使用gpu0(第一塊gpu)進行訓練/推理:
????????torch.cuda.set_device(0) ???
????????model = model.cuda(0)
使用cpu記性訓練/推理:
????????model = model.cpu()


怎么用pytorch訓練一個模型-手寫數字識別
手把手教你如何跑通一個手寫中文漢字識別模型-OCR識別【pytorch】
手把手教你用PyTorch從零訓練自己的大模型(非常詳細)零基礎入門到精通,收藏這一篇就夠了
揭秘大模型的訓練方法:使用PyTorch進行超大規模深度學習模型訓練
全套解決方案:基于pytorch、transformers的中文NLP訓練框架,支持大模型訓練和文本生成,快速上手,海量訓練數據!
用 pytorch 從零開始創建大語言模型(三):編碼注意力機制

YOLOv5源碼逐行超詳細注釋與解讀(1)——項目目錄結構解析

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98023.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98023.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98023.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

新啟航開啟深孔測量新紀元:激光頻率梳技術攻克光學遮擋,達 130mm 深度 2μm 精度

摘要:本文聚焦于深孔測量領域,介紹了一種創新的激光頻率梳技術。該技術成功攻克傳統測量中的光學遮擋難題,在深孔測量深度達 130mm 時,可實現 2μm 的高精度測量,為深孔測量開啟了新的發展篇章。關鍵詞:激光…

GEO優化推薦:AI搜索新紀元下的品牌內容權威構建

引言:AI搜索引擎崛起與GEO策略的戰略重心轉移2025年,以ChatGPT、百度文心一言、DeepSeek為代表的AI搜索引擎已深入成為公眾信息獲取的核心渠道。這標志著品牌營銷策略的重心,正從傳統的搜索引擎優化(SEO)加速向生成式引…

uniapp的上拉加載H5和小程序

小程序配置{"path": "list/course-list","style": {"navigationBarTitleText": "課程列表","enablePullDownRefresh": true,"onReachBottomDistance": 150}}上拉拉觸底鉤子onReachBottom() {var that …

【和春筍一起學C++】(四十)抽象數據類型

抽象數據類型(abstract data type, ADT)以通用的方式描述數據類型。C中類的概念非常適合于ADT方法。例如,C程序通過堆棧來管理自動變量,堆棧可由對它執行的操作來描述。可創建空堆棧;可將數據項添加到堆頂(…

大文件斷點續傳解決方案:基于Vue 2與Spring Boot的完整實現

大文件斷點續傳解決方案:基于Vue 2與Spring Boot的完整實現 在現代Web應用中,大文件上傳是一個常見但具有挑戰性的需求。傳統的文件上傳方式在面對網絡不穩定、大文件傳輸時往往表現不佳。本文將詳細介紹如何實現一個支持斷點續傳的大文件上傳功能,結合Vue 2前端和Spring Bo…

LeNet-5:手寫數字識別經典CNN

配套講解視頻,點擊下方名片獲取20 世紀 90 年代,計算機已經能識別文本,但圖片識別很困難。比如銀行支票的手寫數字識別,傳統方法需要人工設計規則,費時費力且精度不高。 于是,Yann LeCun 及其團隊提出了 Le…

如何在 C# 中將文本轉換為 Word 以及將 Word 轉換為文本

在現代軟件開發中,處理文檔內容是一個非常常見的需求。無論是生成報告、存儲日志,還是處理用戶輸入,開發者都可能需要在純文本與 Word 文檔之間進行轉換。有時需要將文本轉換為 Word,以便生成結構化的 .docx 文件,使內…

Open SWE:重構代碼協作的智能范式——從規劃到PR的全流程自動化革命

在軟件開發的演進史上,工具鏈的每一次革新都深刻重塑著開發者的工作方式。LangChain AI推出的Open SWE,作為首個開源的異步編程代理,正在重新定義代碼協作的邊界——它不再僅僅是代碼生成工具,而是構建了從代碼庫分析、方案規劃、代碼實現到拉取請求創建的端到端自動化工作…

【ARDUINO】通過ESP8266控制電機【待測試】

需求 通過Wi-Fi控制Arduino驅動的3V直流電機。這個方案使用外部6V或9V電源,ESP8266作為Wi-Fi模塊,Arduino作為主控制器,L298N作為電機驅動器。 手機/電腦 (Wi-Fi客戶端) | | (Wi-Fi) | ESP8266 (Wi-Fi模塊, AT指令模式) | | (串口通信) | A…

cuda編程筆記(18)-- 使用im2col + GEMM 實現卷積

我們之前介紹了cudnn調用api直接實現卷積,本文我們探究手動實現。對于直接使用for循環在cpu上的實現方法,就不過多介紹,只要了解卷積的原理,就很容易實現。im2col 的核心思想im2col image to column把輸入 feature map 的每個卷積…

Loopback for Mac:一鍵打造虛擬音頻矩陣,實現跨應用音頻自由流轉

虛擬音頻設備創建 模擬物理設備:Loopback允許用戶在Mac上創建虛擬音頻設備,這些設備可被系統及其他應用程序識別為真實硬件,實現音頻的虛擬化傳輸。多源聚合:支持將麥克風、應用程序(如Skype、Zoom、GarageBand、Logic…

深入解析Django重定向機制

概述 核心是一個基類 HttpResponseRedirectBase,以及兩個具體的子類 HttpResponseRedirect(302 臨時重定向)和 HttpResponsePermanentRedirect(301 永久重定向)。它們都是 HttpResponse 的子類,專門用于告訴…

【Java實戰?】從IO到NIO:Java高并發編程的飛躍

目錄一、NIO 與 IO 的深度剖析1.1 IO 的局限性1.2 NIO 核心特性1.3 NIO 核心組件1.4 NIO 適用場景二、NIO 核心組件實戰2.1 Buffer 緩沖區2.2 Channel 通道2.3 Selector 選擇器2.4 NIO 文件操作案例三、NIO2.0 實戰3.1 Path 類3.2 Files 類3.3 Files 類高級操作3.4 NIO2.0 實戰…

OpenCV 實戰:圖像模板匹配與旋轉處理實現教程

目錄 一、功能概述:代碼能做什么? 二、環境準備:先搭好運行基礎 1. 安裝 Python 2. 安裝 OpenCV 庫 3. 準備圖像文件 三、代碼逐段解析:從基礎到核心 1. 導入 OpenCV 庫 2. 讀取圖像文件 3. 模板圖像旋轉:處理…

一、cadence的安裝及入門教學(反相器的設計與仿真)

一、Cadence的安裝 1、安裝VMware虛擬機 2、安裝帶有cadence軟件的Linux系統 注:網盤鏈接 分享鏈接:https://disk.ningsuan.com.cn/#s/8XaVdtRQ 訪問密碼:11111 所有文件壓縮包及文檔密碼: Cadence_ic 3、安裝tsmc18工藝庫…

用ai寫了個UE5插件

文章目錄實際需求1.頭文件2.源文件3.用法小結實際需求 這個需求來源于之前的一個項目,當時用了一個第三方插件,里邊有一些繪制線段的代碼,c層用的是drawdebugline,當時看底層,覺得應該沒問題,不應該在rele…

機器學習從入門到精通 - 強化學習初探:Q-Learning到Deep Q-Network實戰

機器學習從入門到精通 - 強化學習初探:從 Q-Learning 到 Deep Q-Network 實戰 一、開場白:推開強化學習這扇門 不知道你有沒有過這種感覺 —— 盯著一個復雜的系統,既想讓它達到某個目標,又苦于無法用傳統規則去精確描述每一步該怎…

【OpenHarmony文件管理子系統】文件訪問接口解析

OpenHarmony文件訪問接口(filemanagement_file_api) 概述 OpenHarmony文件訪問接口(filemanagement_file_api)是開源鴻蒙操作系統中的核心文件系統接口,為應用程序提供了完整的文件IO操作能力。該項目基于Node-API&…

云手機運行是否消耗自身流量?

云手機運行是否消耗自身流量,取決于具體的使用場景和設置:若用戶在連接云手機時,使用的是家中Wi-Fi、辦公室局域網等非移動數據網絡,那么在云手機運行過程中,基本不會消耗用戶自身的移動數據流量,在家中連接…

JavaSe之多線程

一、多線程基本了解 1、多線程基本知識 1.進程:進入到內存中執行的應用程序 2.線程:內存和CPU之間開通的通道->進程中的一個執行單元 3.線程作用:負責當前進程中程序的運行.一個進程中至少有一個線程,一個進程還可以有多個線程,這樣的應用程序就稱之為多線程程序 4.簡單理解…