引言
如果你對人工智能(AI)或深度學習(Deep Learning)感興趣,可能聽說過“Transformer”這個詞。它最初在自然語言處理(NLP)領域大放異彩,比如在翻譯、聊天機器人和文本生成中表現出色。但你知道嗎?Transformer不僅能處理文字,還能用來分類圖像!這聽起來是不是有點神奇?別擔心,這篇博客將帶你從零開始,了解Transformer的基本概念、它如何被應用到圖像分類,以及通過一個簡單的例子讓你直觀理解它的運作原理。無論你是AI新手還是好奇的技術愛好者,這篇文章都會盡量用通俗的語言為你解鎖Transformer的奧秘。
第一部分:Transformer是什么?
Transformer是一種深度學習模型,最早由Vaswani等人在2017年的論文《Attention is All You Need》中提出。它的核心思想是“注意力機制”(Attention Mechanism),這是一種讓模型學會“關注”輸入中重要部分的能力。傳統的模型,比如卷積神經網絡(CNN)和循環神經網絡(RNN),在處理圖像或序列數據時有局限性,而Transformer通過注意力機制突破了這些限制。
1.1 為什么叫“Transformer”?
“Transformer”這個名字聽起來很酷,但它其實反映了模型的功能:它能將輸入數據“轉換”(Transform)成更有意義的表示形式。比如,把一句話翻譯成另一種語言,或者把一張圖片“翻譯”成一個分類標簽(比如“貓”或“狗”)。它的核心在于通過計算輸入數據之間的關系,生成更有用的輸出。
1.2 Transformer的基本結構
Transformer由兩個主要部分組成:編碼器(Encoder)和解碼器(Decoder)。不過,在圖像分類任務中,我們通常只用到編碼器部分。讓我們簡單看看它的組成:
- 輸入嵌入(Input Embedding):把輸入數據(比如單詞或圖像塊)轉換成數字向量。
- 注意力機制(Attention):讓模型關注輸入中最重要的部分。
- 前饋神經網絡(Feed-Forward Network):對數據進一步處理。
- 層歸一化和殘差連接(Layer Normalization & Residual Connection):幫助模型穩定訓練,避免“梯度消失”等問題。
這些組件堆疊在一起,形成多層結構,每一層都讓模型對數據的理解更深一層。
1.3 注意力機制:Transformer的“超能力”
注意力機制是Transformer的核心。想象你在讀一本書,當你看到“貓”這個詞時,你會自動想到整句話的上下文,比如“貓在睡覺”還是“貓在跑”。注意力機制讓模型也能做到這一點:它會計算輸入中每個部分對其他部分的“重要性”,然后根據這些關系調整輸出。
具體來說,Transformer使用的是“自注意力”(Self-Attention)。它會為輸入的每個部分(比如圖像的一個小塊)生成三個向量:
- 查詢(Query):我想知道什么?
- 鍵(Key):我有哪些信息?
- 值(Value):這些信息有多重要?
通過計算查詢和鍵之間的相似度,模型決定每個值的權重,然后把它們加權組合起來。這種方式讓Transformer能捕捉全局關系,而不是像CNN那樣只關注局部區域。
第二部分:從NLP到圖像分類:Vision Transformer (ViT)
Transformer最初是為NLP設計的,那它是怎么“跨界”到圖像分類的呢?這要歸功于2020年提出的Vision Transformer(簡稱ViT)。讓我們看看它是如何工作的。
2.1 圖像怎么變成Transformer的輸入?
圖像和文字完全不同,對吧?圖像是一堆像素,而文字是一串單詞。要讓Transformer處理圖像,第一步就是把圖像“翻譯”成它能理解的形式。ViT的做法是:
- 切分圖像:把一張圖片(比如224x224像素)切成固定大小的小塊(比如16x16像素),就像把一張大拼圖拆成小碎片。
- 展平并嵌入:把每個小塊展平成一個向量(就像把拼圖碎片攤平),然后通過一個線性層把它們變成嵌入向量(Embedding)。
- 加上位置信息:因為Transformer不像CNN有固定的空間感知能力,我們需要手動告訴它每個小塊在圖像中的位置。這通過“位置編碼”(Positional Encoding)實現。
經過這些步驟,一張圖像就變成了一個序列(Sequence),就像NLP中的一句話,只不過這里的“單詞”是圖像塊。
2.2 Transformer處理圖像的過程
一旦圖像被轉換成序列,Transformer的編碼器就開始工作:
- 自注意力:計算每個圖像塊和其他圖像塊之間的關系。比如,在一張貓的圖片中,耳朵和眼睛的圖像塊可能會被關聯起來。
- 多層堆疊:通過多層編碼器,模型逐漸提取更高層次的特征。
分類頭:在最后一層,添加一個簡單的分類層(比如全連接層),輸出圖像的類別(比如“貓”或“狗”)。
2.3 ViT的優勢和挑戰
相比傳統的CNN,ViT有幾個優點:
- 全局視野:它能一次性看到整張圖像的關系,而不像CNN只關注局部。
- 靈活性:同一個模型可以輕松處理不同大小的輸入。
但它也有挑戰:
- 計算量大:自注意力機制需要大量計算,尤其當圖像塊很多時。
- 數據需求高:ViT需要大量標注數據才能訓練得好。
第三部分:一個簡單的例子:用ViT分類貓和狗
為了讓新手更容易理解,我們通過一個具體的例子來說明Transformer如何進行圖像分類。假設我們要訓練一個模型,區分CIFAR-10數據集中的“貓”和“狗”圖片(CIFAR-10是PyTorch內置的一個小型圖像數據集,包含10類32x32像素的圖像)。下面我們逐步拆解過程,并新增代碼實現。
3.1 數據準備
CIFAR-10中的每張圖片是32x32像素,RGB格式。我們將它切成4x4的小塊(為了簡化示例),總共有64個塊(32 ÷ 4 = 8,8x8 = 64)。每個小塊有48個數值(4x4x3,因為RGB有3個通道)。
3.2 嵌入過程
- 把每個小塊展平成一個48維向量。
- 通過一個線性層,把48維映射到一個固定維度(比如32維),得到嵌入向量。
- 加上位置編碼,告訴模型每個塊的位置。
現在,這張圖片變成了一個64x32的矩陣,就像一個有64個“單詞”的序列。
3.3 自注意力計算
假設貓咪的耳朵在第10個塊,眼睛在第20個塊。Transformer會:
- 為每個塊生成查詢、鍵和值向量。
- 計算第10個塊的查詢和第20個塊的鍵之間的相似度,發現它們關系密切。
- 根據相似度加權組合值向量,生成一個新的表示。
經過多層自注意力,模型學會關聯貓的特征。
3.4 分類輸出
在最后一層,ViT取一個特殊的“分類標記”(CLS Token),通過全連接層輸出10個類別的概率(CIFAR-10有10類),比如“貓”的概率是0.8,“狗”是0.1。
3.5 代碼實現
下面我們提供兩種代碼實現方式,幫助你直觀感受ViT的運作。代碼基于PyTorch,使用CIFAR-10數據集。
實現方式1:從頭實現一個簡化的ViT
這個實現簡化了ViT的核心組件,適合理解原理。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 超參數
patch_size = 4 # 切分圖像為4x4的小塊
embed_dim = 32 # 每個小塊的嵌入維度
num_heads = 4 # 注意力頭的數量
num_classes = 10 # CIFAR-10有10個類別
num_patches = (32 // patch_size) ** 2 # 64個小塊 (32x32圖像)# 數據加載
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)# 簡化的ViT模型
class SimpleViT(nn.Module):def __init__(self):super(SimpleViT, self).__init__()# 將圖像塊映射到嵌入空間self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)# 位置編碼self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))# CLS Tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# Transformer編碼器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)# 分類頭self.fc = nn.Linear(embed_dim, num_classes)def forward(self, x):b, c, h, w = x.shape # [batch_size, 3, 32, 32]# 切分成小塊并展平x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5).contiguous() # [b, 8, 8, 3, 4, 4]x = x.view(b, num_patches, -1) # [b, 64, 48]# 映射到嵌入空間x = self.patch_to_embedding(x) # [b, 64, 32]# 添加CLS Tokencls_tokens = self.cls_token.expand(b, -1, -1) # [b, 1, 32]x = torch.cat((cls_tokens, x), dim=1) # [b, 65, 32]# 加上位置編碼x = x + self.pos_embedding# 通過Transformerx = self.transformer(x) # [b, 65, 32]# 取CLS Token的輸出進行分類x = self.fc(x[:, 0]) # [b, 10]return x# 訓練模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(5): # 訓練5個epochfor images, labels in trainloader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代碼解釋:
- 數據加載:從CIFAR-10加載32x32的圖像,歸一化處理。
- 圖像切分:將32x32圖像切成64個4x4的小塊,展平后映射到32維嵌入。
- CLS Token:添加一個特殊標記,用于最終分類。
- Transformer:使用PyTorch內置的Transformer編碼器,包含2層,每層有4個注意力頭。
- 訓練:簡單訓練5個epoch,優化分類損失。
實現方式2:使用預訓練ViT模型(Hugging Face)
這個實現利用Hugging Face的預訓練ViT模型,適合快速上手。
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 數據加載
transform = transforms.Compose([transforms.Resize((224, 224)), # ViT需要224x224輸入transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)# 加載預訓練ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10) # 修改分類頭為10類# 訓練設置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 訓練模型
model.train()
for epoch in range(3): # 訓練3個epochfor images, labels in trainloader:inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")inputs = {k: v for k, v in inputs.items()} # 轉換為模型輸入格式optimizer.zero_grad()outputs = model(**inputs).logits # 獲取分類輸出loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')
代碼解釋:
- 數據預處理:將CIFAR-10圖像調整到224x224(ViT預訓練模型的要求)。
- 預訓練模型:加載Google的vit-base-patch16-224,替換分類頭為10類。
- 特征提取器:自動處理圖像輸入,切分并嵌入。
- 訓練:微調模型,適應CIFAR-10任務。
注意:運行第二種方式需要安裝transformers庫(pip install transformers)。
第四部分:新手常見問題解答
4.1 Transformer和CNN有什么不同?
CNN像一個放大鏡,逐步掃描圖像的局部特征;而Transformer像一個全景相機,一次性捕捉全局關系。兩者各有千秋,ViT證明了Transformer也能在圖像任務中大放異彩。
4.2 我需要多強的編程基礎才能用Transformer?
好消息是,你不需要從頭寫Transformer!開源工具(如PyTorch和Hugging Face)提供了預訓練模型。你只需要學會加載模型、準備數據和微調,就能上手。
4.3 ViT適合所有圖像任務嗎?
不完全是。ViT在大數據集(如ImageNet)上表現很好,但在小數據集或需要精細局部特征的任務上,CNN可能更合適。
第五部分
Transformer通過注意力機制和全局視野,為圖像分類帶來了新思路。Vision Transformer(ViT)展示了它如何將圖像切分成塊,像處理句子一樣處理圖片,最終實現分類。對于新手來說,理解它的關鍵在于:
- 圖像如何變成序列。
- 自注意力如何捕捉關系。
- 分類如何通過簡單輸出實現。
通過上面的代碼示例,你可以看到:
- 從頭實現ViT幫助理解原理。
- 使用預訓練模型能快速應用到實際任務。