In PyTorch, the Flatten operation converts a multi-dimensional tensor into a one-dimensional vector (or merges a range of its dimensions). It is most commonly used just before the fully connected layers of a convolutional neural network (CNN). Below are the various ways to implement Flatten in PyTorch and their typical use cases.
I. Basic Flatten Methods
1. Using the `torch.flatten()` function

```python
import torch

# Create a 4D tensor (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28)  # 32 RGB images of size 28x28

# Flatten the entire tensor
flattened = torch.flatten(x)  # shape: [75264] (32*3*28*28)

# Flatten starting from a given dimension
flattened = torch.flatten(x, start_dim=1)  # shape: [32, 2352] (keeps the batch dim)
```
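Flattening reads elements in row-major (C-contiguous) order, which is easy to verify on a small tensor:

```python
import torch

# A 2x3 tensor with known values:
# [[0, 1, 2],
#  [3, 4, 5]]
t = torch.arange(6).reshape(2, 3)

flat = torch.flatten(t)
print(flat)  # tensor([0, 1, 2, 3, 4, 5]) -- rows laid out one after another
```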
2. Using the `nn.Flatten` layer

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()  # by default flattens from dim 1 onward (keeps the batch dim)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # shape: [32, 2352]
```
Start and end dimensions can also be specified:

```python
flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # shape: [32, 84, 28] (dims 1 and 2 merged)
```
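`start_dim` and `end_dim` also accept negative indices, counting from the last dimension (the default `end_dim` is `-1`); a quick shape check:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 3, 28, 28)

# end_dim=-1 (the default) flattens through the last dimension
assert nn.Flatten(start_dim=1, end_dim=-1)(x).shape == (32, 3 * 28 * 28)

# A negative start_dim counts from the end: -2 merges the two spatial dims
assert nn.Flatten(start_dim=-2)(x).shape == (32, 3, 28 * 28)
```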
II. Flatten in Different Scenarios
1. Typical usage in a CNN

```python
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 5 * 5, 10)  # input size computed from the flattened shape

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.flatten(x)  # shape goes from [B, 32, 5, 5] to [B, 800]
        x = self.fc(x)
        return x
```
2. Computing the flattened size by hand

```python
# Helper: output size of a convolution/pooling layer
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    return (input_size - kernel_size + 2 * padding) // stride + 1

# Trace the size through the conv and pooling layers above
h, w = 28, 28                  # input size
h = conv_output_size(h, 3)     # conv1: 26
w = conv_output_size(w, 3)     # conv1: 26
h = conv_output_size(h, 2, 2)  # pool1: 13
w = conv_output_size(w, 2, 2)  # pool1: 13
h = conv_output_size(h, 3)     # conv2: 11
w = conv_output_size(w, 3)     # conv2: 11
h = conv_output_size(h, 2, 2)  # pool2: 5
w = conv_output_size(w, 2, 2)  # pool2: 5
print(f"Number of features after flattening: {32 * h * w}")  # 32 * 5 * 5 = 800
```
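Alternatively, rather than tracing each layer by hand, a common trick is to run a dummy tensor through the convolutional stack once and read off the flattened size; a sketch using the same conv stack as the CNN above:

```python
import torch
import torch.nn as nn

# Same conv stack as the CNN above
conv_layers = nn.Sequential(
    nn.Conv2d(1, 16, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
)

# One dummy forward pass with the expected input size
with torch.no_grad():
    n_features = conv_layers(torch.zeros(1, 1, 28, 28)).flatten(start_dim=1).shape[1]

print(n_features)  # 800, matching the hand calculation
```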
III. Advanced Usage
1. Partial flattening

```python
# Flatten only the spatial dimensions, keeping the channel dimension
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2)  # shape: [32, 3, 784]
```
2. A custom Flatten layer

```python
class ChannelLastFlatten(nn.Module):
    """Flatten layer that moves the channel dimension to the end."""
    def forward(self, x):
        # input shape: [B, C, H, W]
        x = x.permute(0, 2, 3, 1)        # [B, H, W, C]
        return x.reshape(x.size(0), -1)  # [B, H*W*C]
```
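A quick sanity check of the custom layer (the class definition is repeated so the snippet runs on its own):

```python
import torch
import torch.nn as nn

class ChannelLastFlatten(nn.Module):
    """Flatten layer that moves the channel dimension to the end."""
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)        # [B, C, H, W] -> [B, H, W, C]
        return x.reshape(x.size(0), -1)  # [B, H*W*C]

x = torch.randn(4, 3, 8, 8)
out = ChannelLastFlatten()(x)
assert out.shape == (4, 8 * 8 * 3)
# Values match flattening the permuted tensor directly
assert torch.equal(out, x.permute(0, 2, 3, 1).reshape(4, -1))
```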
3. Flattening specific dimensions

```python
# Flatten the batch and channel dimensions together
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1)  # shape: [96, 28, 28] (32*3=96)
```
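The inverse operation also exists: `Tensor.unflatten` (or the `nn.Unflatten` layer) splits a dimension back into several, so a merge like the one above can be undone:

```python
import torch

x = torch.randn(32, 3, 28, 28)
merged = x.flatten(end_dim=1)            # [96, 28, 28]
restored = merged.unflatten(0, (32, 3))  # back to [32, 3, 28, 28]

assert restored.shape == x.shape
assert torch.equal(restored, x)
```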
IV. Caveats

- Dimension calculation: make sure the flattened size matches the input size of the following fully connected layer.
- Batch dimension: dimension 0 (the batch dimension) is usually left unflattened.
- Memory contiguity: `view()` requires contiguous memory; call `contiguous()` first when necessary.
- Alternative spelling: `x.view(x.size(0), -1)` is a common equivalent of `flatten(start_dim=1)`.
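The contiguity caveat can be seen directly: a transposed tensor is non-contiguous, so `view()` raises a `RuntimeError`, while `reshape()` (or `contiguous().view()`) still works:

```python
import torch

x = torch.randn(3, 4).t()  # transposing makes the tensor non-contiguous
assert not x.is_contiguous()

try:
    x.view(-1)
except RuntimeError as e:
    print("view() failed:", e)

flat = x.reshape(-1)             # reshape copies when it has to
flat2 = x.contiguous().view(-1)  # or make the memory contiguous first
assert torch.equal(flat, flat2)
```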
V. Method Comparison

| Method | Pros | Cons |
| --- | --- | --- |
| `torch.flatten()` | Officially recommended; readable | None |
| `nn.Flatten()` | Can be used as a network layer | Requires instantiating an object |
| `x.view()` | Most concise | Sizes must be computed manually |
| `x.reshape()` | Handles non-contiguous memory automatically | Slightly slower than `view` when a copy is needed |
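In practice the differences are usually negligible, since on contiguous input all three return views that share storage with the original tensor (no data copy); a rough micro-benchmark sketch, whose numbers will vary by machine:

```python
import timeit
import torch

x = torch.randn(64, 3, 224, 224)

for name, fn in [
    ("torch.flatten", lambda: torch.flatten(x, start_dim=1)),
    ("view",          lambda: x.view(x.size(0), -1)),
    ("reshape",       lambda: x.reshape(x.size(0), -1)),
]:
    t = timeit.timeit(fn, number=10_000)
    print(f"{name:14s} {t:.4f} s")

# On contiguous input all three share storage with x (views, no copy)
assert x.view(x.size(0), -1).data_ptr() == x.data_ptr()
assert torch.flatten(x, start_dim=1).data_ptr() == x.data_ptr()
```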
VI. Complete Example

```python
import torch
import torch.nn as nn

# A complete model that includes Flatten
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1024),  # assumes 32x32 input images
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

# Usage example
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32)  # batch=16, 3 channels, 32x32 images
output = model(input_tensor)
print(output.shape)  # torch.Size([16, 10])
```