《Transformer如何進行圖像分類:從新手到入門》

引言

如果你對人工智能(AI)或深度學習(Deep Learning)感興趣,可能聽說過“Transformer”這個詞。它最初在自然語言處理(NLP)領域大放異彩,比如在翻譯、聊天機器人和文本生成中表現出色。但你知道嗎?Transformer不僅能處理文字,還能用來分類圖像!這聽起來是不是有點神奇?別擔心,這篇博客將帶你從零開始,了解Transformer的基本概念、它如何被應用到圖像分類,以及通過一個簡單的例子讓你直觀理解它的運作原理。無論你是AI新手還是好奇的技術愛好者,這篇文章都會盡量用通俗的語言為你解鎖Transformer的奧秘。

第一部分:Transformer是什么?

Transformer是一種深度學習模型,最早由Vaswani等人在2017年的論文《Attention is All You Need》中提出。它的核心思想是“注意力機制”(Attention Mechanism),這是一種讓模型學會“關注”輸入中重要部分的能力。傳統的模型,比如卷積神經網絡(CNN)和循環神經網絡(RNN),在處理圖像或序列數據時有局限性,而Transformer通過注意力機制突破了這些限制。

1.1 為什么叫“Transformer”?

“Transformer”這個名字聽起來很酷,但它其實反映了模型的功能:它能將輸入數據“轉換”(Transform)成更有意義的表示形式。比如,把一句話翻譯成另一種語言,或者把一張圖片“翻譯”成一個分類標簽(比如“貓”或“狗”)。它的核心在于通過計算輸入數據之間的關系,生成更有用的輸出。

1.2 Transformer的基本結構

Transformer由兩個主要部分組成:編碼器(Encoder)和解碼器(Decoder)。不過,在圖像分類任務中,我們通常只用到編碼器部分。讓我們簡單看看它的組成:

  • 輸入嵌入(Input Embedding):把輸入數據(比如單詞或圖像塊)轉換成數字向量。
  • 注意力機制(Attention):讓模型關注輸入中最重要的部分。
  • 前饋神經網絡(Feed-Forward Network):對數據進一步處理。
  • 層歸一化和殘差連接(Layer Normalization & Residual Connection):幫助模型穩定訓練,避免“梯度消失”等問題。

這些組件堆疊在一起,形成多層結構,每一層都讓模型對數據的理解更深一層。

1.3 注意力機制:Transformer的“超能力”

注意力機制是Transformer的核心。想象你在讀一本書,當你看到“貓”這個詞時,你會自動想到整句話的上下文,比如“貓在睡覺”還是“貓在跑”。注意力機制讓模型也能做到這一點:它會計算輸入中每個部分對其他部分的“重要性”,然后根據這些關系調整輸出。

具體來說,Transformer使用的是“自注意力”(Self-Attention)。它會為輸入的每個部分(比如圖像的一個小塊)生成三個向量:

  • 查詢(Query):我想知道什么?
  • 鍵(Key):我有哪些信息?
  • 值(Value):這些信息有多重要?

通過計算查詢和鍵之間的相似度,模型決定每個值的權重,然后把它們加權組合起來。這種方式讓Transformer能捕捉全局關系,而不是像CNN那樣只關注局部區域。

第二部分:從NLP到圖像分類:Vision Transformer (ViT)

Transformer最初是為NLP設計的,那它是怎么“跨界”到圖像分類的呢?這要歸功于2020年提出的Vision Transformer(簡稱ViT)。讓我們看看它是如何工作的。

2.1 圖像怎么變成Transformer的輸入?

圖像和文字完全不同,對吧?圖像是一堆像素,而文字是一串單詞。要讓Transformer處理圖像,第一步就是把圖像“翻譯”成它能理解的形式。ViT的做法是:

  1. 切分圖像:把一張圖片(比如224x224像素)切成固定大小的小塊(比如16x16像素),就像把一張大拼圖拆成小碎片。
  2. 展平并嵌入:把每個小塊展平成一個向量(就像把拼圖碎片攤平),然后通過一個線性層把它們變成嵌入向量(Embedding)。
  3. 加上位置信息:因為Transformer不像CNN有固定的空間感知能力,我們需要手動告訴它每個小塊在圖像中的位置。這通過“位置編碼”(Positional Encoding)實現。

經過這些步驟,一張圖像就變成了一個序列(Sequence),就像NLP中的一句話,只不過這里的“單詞”是圖像塊。

2.2 Transformer處理圖像的過程

一旦圖像被轉換成序列,Transformer的編碼器就開始工作:

  • 自注意力:計算每個圖像塊和其他圖像塊之間的關系。比如,在一張貓的圖片中,耳朵和眼睛的圖像塊可能會被關聯起來。
  • 多層堆疊:通過多層編碼器,模型逐漸提取更高層次的特征。
    分類頭:在最后一層,添加一個簡單的分類層(比如全連接層),輸出圖像的類別(比如“貓”或“狗”)。

2.3 ViT的優勢和挑戰

相比傳統的CNN,ViT有幾個優點:

  • 全局視野:它能一次性看到整張圖像的關系,而不像CNN只關注局部。
  • 靈活性:同一個模型可以輕松處理不同大小的輸入。

但它也有挑戰:

  • 計算量大:自注意力機制需要大量計算,尤其當圖像塊很多時。
  • 數據需求高:ViT需要大量標注數據才能訓練得好。

第三部分:一個簡單的例子:用ViT分類貓和狗

為了讓新手更容易理解,我們通過一個具體的例子來說明Transformer如何進行圖像分類。假設我們要訓練一個模型,區分CIFAR-10數據集中的“貓”和“狗”圖片(CIFAR-10是PyTorch內置的一個小型圖像數據集,包含10類32x32像素的圖像)。下面我們逐步拆解過程,并新增代碼實現。

3.1 數據準備

CIFAR-10中的每張圖片是32x32像素,RGB格式。我們將它切成4x4的小塊(為了簡化示例),總共有64個塊(32 ÷ 4 = 8,8x8 = 64)。每個小塊有48個數值(4x4x3,因為RGB有3個通道)。

3.2 嵌入過程

  • 把每個小塊展平成一個48維向量。
  • 通過一個線性層,把48維映射到一個固定維度(比如32維),得到嵌入向量。
  • 加上位置編碼,告訴模型每個塊的位置。

現在,這張圖片變成了一個64x32的矩陣,就像一個有64個“單詞”的序列。

3.3 自注意力計算

假設貓咪的耳朵在第10個塊,眼睛在第20個塊。Transformer會:

  1. 為每個塊生成查詢、鍵和值向量。
  2. 計算第10個塊的查詢和第20個塊的鍵之間的相似度,發現它們關系密切。
  3. 根據相似度加權組合值向量,生成一個新的表示。

經過多層自注意力,模型學會關聯貓的特征。

3.4 分類輸出

在最后一層,ViT取一個特殊的“分類標記”(CLS Token),通過全連接層輸出10個類別的概率(CIFAR-10有10類),比如“貓”的概率是0.8,“狗”是0.1。

3.5 代碼實現

下面我們提供兩種代碼實現方式,幫助你直觀感受ViT的運作。代碼基于PyTorch,使用CIFAR-10數據集。

實現方式1:從頭實現一個簡化的ViT

這個實現簡化了ViT的核心組件,適合理解原理。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 超參數
patch_size = 4  # 切分圖像為4x4的小塊
embed_dim = 32  # 每個小塊的嵌入維度
num_heads = 4   # 注意力頭的數量
num_classes = 10  # CIFAR-10有10個類別
num_patches = (32 // patch_size) ** 2  # 64個小塊 (32x32圖像)# 數據加載
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)# 簡化的ViT模型
class SimpleViT(nn.Module):def __init__(self):super(SimpleViT, self).__init__()# 將圖像塊映射到嵌入空間self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)# 位置編碼self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))# CLS Tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# Transformer編碼器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)# 分類頭self.fc = nn.Linear(embed_dim, num_classes)def forward(self, x):b, c, h, w = x.shape  # [batch_size, 3, 32, 32]# 切分成小塊并展平x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5).contiguous()  # [b, 8, 8, 3, 4, 4]x = x.view(b, num_patches, -1)  # [b, 64, 48]# 映射到嵌入空間x = self.patch_to_embedding(x)  # [b, 64, 32]# 添加CLS Tokencls_tokens = self.cls_token.expand(b, -1, -1)  # [b, 1, 32]x = torch.cat((cls_tokens, x), dim=1)  # [b, 65, 32]# 加上位置編碼x = x + self.pos_embedding# 通過Transformerx = self.transformer(x)  # [b, 65, 32]# 取CLS Token的輸出進行分類x = self.fc(x[:, 0])  # [b, 10]return x# 訓練模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):  # 訓練5個epochfor images, labels in trainloader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代碼解釋:

  • 數據加載:從CIFAR-10加載32x32的圖像,歸一化處理。
  • 圖像切分:將32x32圖像切成64個4x4的小塊,展平后映射到32維嵌入。
  • CLS Token:添加一個特殊標記,用于最終分類。
  • Transformer:使用PyTorch內置的Transformer編碼器,包含2層,每層有4個注意力頭。
  • 訓練:簡單訓練5個epoch,優化分類損失。
實現方式2:使用預訓練ViT模型(Hugging Face)

這個實現利用Hugging Face的預訓練ViT模型,適合快速上手。

import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 數據加載
transform = transforms.Compose([transforms.Resize((224, 224)),  # ViT需要224x224輸入transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)# 加載預訓練ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)  # 修改分類頭為10類# 訓練設置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 訓練模型
model.train()
for epoch in range(3):  # 訓練3個epochfor images, labels in trainloader:inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")inputs = {k: v for k, v in inputs.items()}  # 轉換為模型輸入格式optimizer.zero_grad()outputs = model(**inputs).logits  # 獲取分類輸出loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代碼解釋:

  • 數據預處理:將CIFAR-10圖像調整到224x224(ViT預訓練模型的要求)。
  • 預訓練模型:加載Google的vit-base-patch16-224,替換分類頭為10類。
  • 特征提取器:自動處理圖像輸入,切分并嵌入。
  • 訓練:微調模型,適應CIFAR-10任務。

注意:運行第二種方式需要安裝transformers庫(pip install transformers)。

第四部分:新手常見問題解答

4.1 Transformer和CNN有什么不同?

CNN像一個放大鏡,逐步掃描圖像的局部特征;而Transformer像一個全景相機,一次性捕捉全局關系。兩者各有千秋,ViT證明了Transformer也能在圖像任務中大放異彩。

4.2 我需要多強的編程基礎才能用Transformer?

好消息是,你不需要從頭寫Transformer!開源工具(如PyTorch和Hugging Face)提供了預訓練模型。你只需要學會加載模型、準備數據和微調,就能上手。

4.3 ViT適合所有圖像任務嗎?

不完全是。ViT在大數據集(如ImageNet)上表現很好,但在小數據集或需要精細局部特征的任務上,CNN可能更合適。

第五部分

Transformer通過注意力機制和全局視野,為圖像分類帶來了新思路。Vision Transformer(ViT)展示了它如何將圖像切分成塊,像處理句子一樣處理圖片,最終實現分類。對于新手來說,理解它的關鍵在于:

  1. 圖像如何變成序列。
  2. 自注意力如何捕捉關系。
  3. 分類如何通過簡單輸出實現。

通過上面的代碼示例,你可以看到:

  • 從頭實現ViT幫助理解原理。
  • 使用預訓練模型能快速應用到實際任務。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/73265.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/73265.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/73265.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java --- 根據身份證號計算年齡

介紹 根據身份證號計算年齡 Java代碼 /*** 根據身份證號計算年齡* param birthDateStr* return*/public static int calculateAge(String birthDateStr) {try {birthDateStrbirthDateStr.substring(6,68);// 定義日期格式SimpleDateFormat sdf new SimpleDateFormat("…

零成本搭建Calibre個人數字圖書館支持EPUB MOBI格式遠程直讀

文章目錄 前言1.網絡書庫軟件下載安裝2.網絡書庫服務器設置3.內網穿透工具設置4.公網使用kindle訪問內網私人書庫 前言 嘿,各位書蟲們!今天要給大家安利一個超級炫酷的技能——如何在本地Windows電腦上搭建自己的私人云端書庫。亞馬遜服務停了&#xff…

【Linux 指北】常用 Linux 指令匯總

第一章、常用基本指令 # 注意: # #表示管理員 # $表示普通用戶 [rootlocalhost Practice]# 說明此處表示管理員01. ls 指令 語法: ls [選項][目錄或文件] 功能:對于目錄,該命令列出該目錄下的所有子目錄與文件。對于文件&#xf…

跟蹤napi_gro_receive_entry時IP頭信息缺失的分析

問題描述 在使用eBPF程序跟蹤napi_gro_receive_entry內核跟蹤點時,發現獲取到的IP頭部字段(如saddr、daddr、protocol)為空值。 代碼如下: /* 自定義結構體來映射 napi_gro_receive_entry tracepoint 的 format */ struct napi…

Android子線程更新View的方法原理

對于所有的Android開發者來說,“View的更新必須在UI線程中進行”是一項最基本常識。 如果不在UI線程中更新View,系統會拋出CalledFromWrongThreadException異常。那么有沒有什么辦法可以不在UI線程中更新View?答案當然是有的! 一…

【Manus資料合集】激活碼內測渠道+《Manus Al:Agent應用的ChatGPT時刻》(附資源)

DeepSeek 之后,又一個AI沸騰,沖擊的不僅僅是通用大模型。 ——全球首款通用AI Agent的破圈啟示錄 2025年3月6日凌晨,全球AI圈被一款名為Manus的產品徹底點燃。由Monica團隊(隸屬中國夜鶯科技)推出的“全球首款通用AI…

Python----計算機視覺處理(opencv:像素,RGB顏色,圖像的存儲,opencv安裝,代碼展示)

一、計算機眼中的圖像 像素 像素是圖像的基本單元,每個像素存儲著圖像的顏色、亮度和其他特征。一系列像素組合到一起就形成 了完整的圖像,在計算機中,圖像以像素的形式存在并采用二進制格式進行存儲。根據圖像的顏色不 同,每個像…

SQLiteStudio:一款免費跨平臺的SQLite管理工具

SQLiteStudio 是一款專門用于管理和操作 SQLite 數據庫的免費工具。它提供直觀的圖形化界面,簡化了數據庫的創建、編輯、查詢和維護,適合數據庫開發者和數據分析師使用。 功能特性 SQLiteStudio 提供的主要功能包括: 免費開源,可…

【軟考網工-實踐篇】DHCP 動態主機配置協議

一、DHCP簡介 DHCP,Dynamic Host Configuration Protocol,動態主機配置協議。 位置:DHCP常見運行于路由器上,作為DHCP服務器功能:用于自動分配IP地址及其他網絡參數給網絡中的設備作用:簡化網絡管理&…

【Linux學習筆記】Linux用戶和文件權限的深度剖析

【Linux學習筆記】Linux用戶和文件權限的深度剖析 🔥個人主頁:大白的編程日記 🔥專欄:Linux學習筆記 前言 文章目錄 【Linux學習筆記】Linux用戶和文件權限的深度剖析前言一. Linux權限管理1.1 文件訪問者的分類(人)…

Centos離線安裝openssl-devel

文章目錄 Centos離線安裝openssl-devel1. openssl-devel是什么?2. openssl-devel下載地址3. openssl-devel安裝4. 安裝結果驗證 Centos離線安裝openssl-devel 1. openssl-devel是什么? openssl-devel 是 Linux 系統中與 OpenSSL 加密庫相關的開發包&…

深度學習篇---Opencv中Haar級聯分類器的自定義

文章目錄 1. 準備工作1.1安裝 OpenCV1.2準備數據集1.2.1正樣本1.2.2負樣本 2. 數據準備2.1 正樣本的準備2.1.1步驟2.1.2生成正樣本描述文件2.1.3示例命令2.1.4正樣本描述文件格式 2.2 負樣本的準備2.2.1步驟2.2.2負樣本描述文件格式 3. 訓練分類器3.1命令格式3.2參數說明 4. 訓…

Smart Time Plus smarttimeplus-MySQLConnection SQL注入漏洞(CVE-2024-53544)

免責聲明 本文所描述的漏洞及其復現步驟僅供網絡安全研究與教育目的使用。任何人不得將本文提供的信息用于非法目的或未經授權的系統測試。作者不對任何由于使用本文信息而導致的直接或間接損害承擔責任。如涉及侵權,請及時與我們聯系,我們將盡快處理并刪除相關內容。 0x01…

58.Harmonyos NEXT 圖片預覽組件架構設計與實現原理

溫馨提示:本篇博客的詳細代碼已發布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下載運行哦! Harmonyos NEXT 圖片預覽組件架構設計與實現原理 文章目錄 Harmonyos NEXT 圖片預覽組件架構設計與實現原理效果預覽一、組件架構概述1. 核心組件層…

虛擬機下ubuntu進不了圖形界面

6.844618] piix4_smbus 0000:07.3: SMBus Host ContrFoller not enabled! 7.859836] sd 2:0:0:0:0: [sda] Assuming drive cache: wirite through /dev/sda1: clean, 200424/1966080 files, 4053235/7864064 blocks ubuntu啟動時,卡在上面輸出位置 當前遇到的原因…

Appium高級操作--從源碼角度解析--模擬復雜手勢操作

書接上回,Android自動化--Appium基本操作-CSDN博客文章瀏覽閱讀600次,點贊10次,收藏5次。書接上回,上一篇文章已經介紹了appium在Android端的元素定位方法和識別工具Inspector,本次要介紹使用如何利用Appium對找到的元…

SpringBoot學生宿舍管理系統的設計與開發

項目概述 幽絡源分享的《SpringBoot學生宿舍管理系統的設計與開發》是一款專為校園宿舍管理設計的智能化系統,基于SpringBoot框架開發,功能全面,操作便捷。該系統涵蓋管理員、宿管員和學生三大角色,分別提供宿舍管理、學生信息管…

愛普生溫補晶振 TG5032CFN高精度穩定時鐘的典范

在科技日新月異的當下,眾多領域對時鐘信號的穩定性與精準度提出了極為嚴苛的要求。愛普生溫補晶振TG5032CFN是一款高穩定性溫度補償晶體振蕩器(TCXO)。該器件通過內置溫度補償電路,有效抑制環境溫度變化對頻率穩定性的影響&#x…

【原創】在高性能服務器上,使用受限用戶運行Nginx,充當反向代理服務器[未完待續]

起因 在公共高性能服務器上運行OllamaDeepSeek,如果按照默認配置啟動Ollama程序,則自己在遠程無法連接你啟動的Ollama服務。 如果修改配置,則會遇到你的Ollama被他人完全控制的安全風險。 不過,我們可以使用一個方向代理&#…

Bash和Zsh的主要差異是?

Bash(GNU Bourne-Again Shell) 和 Zsh(Z Shell) 都是功能強大的Unix/Linux Shell,廣泛用于交互式使用和腳本編寫。 盡管它們有很多相似之處,但在功能、語法、配置選項等方面也存在一些顯著的區別。 是Bas…