PyTorch中的Flatten

在 PyTorch 中,Flatten?操作是將多維張量轉換為一維向量的重要操作,常用于卷積神經網絡(CNN)的全連接層之前。以下是 PyTorch 中實現 Flatten 的各種方法及其應用場景。

一、基本 Flatten 方法

1. 使用?torch.flatten()?函數

import torch# 創建一個4D張量 (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28)  # 32張28x28的RGB圖像# 展平整個張量
flattened = torch.flatten(x)  # 輸出形狀: [75264] (32*3*28*28)# 從指定維度開始展平
flattened = torch.flatten(x, start_dim=1)  # 輸出形狀: [32, 2352] (保持batch維度)

2. 使用?nn.Flatten?層

import torch.nn as nnflatten = nn.Flatten()  # 默認從第1維開始展平(保持batch維度)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 輸出形狀: [32, 2352]

?可以指定開始和結束維度:

flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 輸出形狀: [32, 84, 28] (合并了第1和2維)

二、不同場景下的 Flatten 應用

1. CNN 中的典型用法

class CNN(nn.Module):def __init__(self):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 16, 3),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16, 32, 3),nn.ReLU(),nn.MaxPool2d(2))self.flatten = nn.Flatten()self.fc = nn.Linear(32 * 5 * 5, 10)  # 計算展平后的尺寸def forward(self, x):x = self.conv_layers(x)x = self.flatten(x)  # 形狀從 [B, 32, 5, 5] 變為 [B, 800]x = self.fc(x)return x

?2. 手動計算展平后的尺寸

# 計算卷積層輸出尺寸的輔助函數
def conv_output_size(input_size, kernel_size, stride=1, padding=0):return (input_size - kernel_size + 2 * padding) // stride + 1# 計算經過多層卷積和池化后的尺寸
h, w = 28, 28  # 輸入尺寸
h = conv_output_size(h, 3)  # conv1: 26
w = conv_output_size(w, 3)  # conv1: 26
h = conv_output_size(h, 2, 2)  # pool1: 13
w = conv_output_size(w, 2, 2)  # pool1: 13
h = conv_output_size(h, 3)  # conv2: 11
w = conv_output_size(w, 3)  # conv2: 11
h = conv_output_size(h, 2, 2)  # pool2: 5
w = conv_output_size(w, 2, 2)  # pool2: 5
print(f"展平后的特征數: {32 * h * w}")  # 32 * 5 * 5 = 800

三、高級用法

1. 部分展平

# 只展平圖像空間維度,保留通道維度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2)  # 形狀: [32, 3, 784]

?2. 自定義 Flatten 層

class ChannelLastFlatten(nn.Module):"""將通道維度移到最后的展平層"""def forward(self, x):# 輸入形狀: [B, C, H, W]x = x.permute(0, 2, 3, 1)  # [B, H, W, C]return x.reshape(x.size(0), -1)  # [B, H*W*C]

3. 展平特定維度

# 展平批量維度和通道維度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1)  # 形狀: [96, 28, 28] (32*3=96)

四、注意事項

  1. 維度計算:確保展平后的尺寸與全連接層的輸入尺寸匹配

  2. 批量維度:通常保留第0維(batch維度)不被展平

  3. 內存連續性view()需要連續內存,必要時先調用contiguous()

  4. 替代方法x.view(x.size(0), -1)flatten(start_dim=1)的常見替代寫法

五、性能比較

方法優點缺點
torch.flatten()官方推薦,可讀性好
nn.Flatten()可作為網絡層使用需要實例化對象
x.view()最簡潔需要手動計算尺寸
x.reshape()自動處理內存連續性性能略低于view

六、示例代碼

import torch
import torch.nn as nn# 定義一個包含Flatten的完整模型
class ImageClassifier(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.flatten = nn.Flatten()self.classifier = nn.Sequential(nn.Linear(256 * 4 * 4, 1024),  # 假設輸入圖像是32x32nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.features(x)x = self.flatten(x)x = self.classifier(x)return x# 使用示例
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32)  # batch=16, 3通道, 32x32圖像
output = model(input_tensor)
print(output.shape)  # 輸出形狀: [16, 10]

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/77125.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/77125.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/77125.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Boot + MyBatis + Maven論壇內容管理系統源碼

項目描述 xxxForum是一個基于Spring Boot MyBatis Maven開發的一個論壇內容管理系統,主要實現了的功能有: 前臺頁面展示數據、廣告展示內容模塊:發帖、評論、帖子分類、分頁、回帖統計、訪問統計、表單驗證用戶模塊:權限、資料…

探索AI編程規范化的利器:Awesome Cursor Rules

在AI輔助編程逐漸成為開發者標配的今天,如何讓AI生成的代碼既符合項目規范又保持高質量,成為開發者面臨的新挑戰。GitHub倉庫**awesome-cursorrules**正是為解決這一問題而生的開源項目,它通過系統化的規則模板庫,重新定義了AI編程的規范邊界。本文將深入解析這一工具的核心…

AnimateCC基礎教學:json數據結構的測試

一.核心代碼: const user1String {"name": "張三", "age": 30, "gender": "男"}; const user1Obj JSON.parse(user1String); console.log("測試1:", user1Obj.name, user1Obj.age, user1Obj.gender);/*const u…

阿里云域名證書自動更新acme.sh

因為阿里云的免費證書只有三個月的有效期,每次更換都比較繁瑣,所以找到了 acme.sh,還有一種 certbot 我沒有去了解,就直接使用了 acme.sh 來更新證書,acme.sh 的主要特點就是: 支持多種 DNS 服務商自動化續…

PDF 中提取數學公式

? 方法一:使用 doc2x extract_formula_imgs Pix2Text 一鍵運行腳本(自動提取識別) 👉 適合你如果用 Python 的話,只需要運行一段腳本即可: ? 🔁 一步搞定腳本(僅需安裝一次&…

SQL并行產生進程數量問題

有一些數據庫性能問題可能是因為同時啟動的并行進程過多造成的,特別常見于RAC節點重啟,很多時候是因為瞬間啟動了幾百個并行進程,導致OS各項指標“彪高”,后臺進程失去響應。最近遇到的一個,是因為SQL語句中寫了/* par…

【Vue-組件】學習筆記

目錄 <<回到導覽組件1.項目1.1.Vue Cli1.2.項目目錄1.3.運行流程1.4.組件的組成1.5.注意事項 2.組件2.1.組件注冊2.2.scoped樣式沖突2.3.data是一個函數2.4.props詳解2.5.data和prop的區別 3.組件通信3.1.父子通信3.1.1.父傳子&#xff08;props&#xff09;3.1.2.子傳父…

【Kafka基礎】單機安裝與配置指南,從零搭建環境

學習Kafka&#xff0c;掌握Kafka的單機部署是理解其分布式特性的第一步。本文將手把手帶你完成Kafka單機環境的安裝、配置及基礎驗證&#xff0c;涵蓋常見問題排查技巧。 1 環境準備 1.1 系統要求 操作系統&#xff1a;CentOS 7.9依賴組件&#xff1a;JDK 8&#xff08;Kafka …

OpenCV 圖形API(21)逐像素操作

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 在OpenCV的G-API模塊中&#xff0c;逐像素操作指的是對圖像中的每個像素單獨進行處理的操作。這些操作可以通過G-API的計算圖&#xff08;Graph …

CubeMX配置STM32VET6實現網口通信(無操作系統版-附源碼)

下面是使用CubeMX配置STM32F407VET6,實現以太網通訊(PHY芯片為LAN8720)的具體步驟總結: 一、硬件連接方式: 硬件原理圖: 使用外部晶振為PHY芯片提供時鐘。 STM32F407VET6 與 LAN8720 采用 RMII 模式連接。STM32F407VET6引腳功能(RMII)LAN8720引腳PA1ETH_REF_CLKREF_CL…

Android Compose 中獲取和使用 Context 的完整指南

在 Android Jetpack Compose 中&#xff0c;雖然大多數 UI 組件不再需要直接使用 Context&#xff0c;但有時你仍然需要訪問它來執行一些 Android 平臺特定的操作。以下是幾種在 Compose 中獲取和使用 Context 的方法&#xff1a; 1. 使用 LocalContext 這是 Compose 中最常用…

在VMware下Hadoop分布式集群環境的配置--基于Yarn模式的一個Master節點、兩個Slaver(Worker)節點的配置

你遇到的大部分ubuntu中配置hadoop的問題這里都有解決方法&#xff01;&#xff01;&#xff01;&#xff08;近10000字&#xff09; 概要 在Docker虛擬容器環境下&#xff0c;進行Hadoop-3.2.2分布式集群環境的配置與安裝&#xff0c;完成基于Yarn模式的一個Master節點、兩個…

PID燈控算法

根據代碼分析&#xff0c;以下是針對PID算法和光敏傳感器系統的優化建議&#xff0c;分為算法優化、代碼結構優化和系統級優化三部分&#xff1a; 一、PID算法優化 1. 增量式PID 輸出平滑 // 修改PID計算函數 uint16_t PID_calculation_fun(void) {if(PID_Str_Val.Tdata >…

文件映射mmap與管道文件

在用戶態申請內存&#xff0c;內存內容和磁盤內容建立一一映射 讀寫內存等價于讀寫磁盤 支持隨機訪問 簡單來說&#xff0c;把磁盤里的數據與內存的用戶態建立一一映射關系&#xff0c;讓讀寫內存等價于讀寫磁盤&#xff0c;支持隨機訪問。 管道文件&#xff1a;進程間通信機…

在 Java 中調用 ChatGPT API 并實現流式接收(Server-Sent Events, SSE)

文章目錄 簡介OkHttp 流式獲取 GPT 響應通過 SSE 流式推送前端后端代碼消息實體接口接口實現數據推送給前端 前端代碼創建 sseClient.jsvue3代碼 優化后端代碼 簡介 用過 ChatGPT 的伙伴應該想過自己通過調用ChatGPT官網提供的接口來實現一個自己的問答機器人&#xff0c;但是…

硬盤分區格式之GPT(GUID Partition Table)筆記250407

硬盤分區格式之GPT&#xff08;GUID Partition Table&#xff09;筆記250407 GPT&#xff08;GUID Partition Table&#xff09;硬盤分區格式詳解 GPT&#xff08;GUID Partition Table&#xff09;是替代傳統 MBR 的現代分區方案&#xff0c;專為 UEFI&#xff08;統一可擴展固…

Vite環境下解決跨域問題

在 Vite 開發環境中&#xff0c;可以通過配置代理來解決跨域問題。以下是具體步驟&#xff1a; 在項目根目錄下找到 vite.config.js 文件&#xff1a;如果沒有&#xff0c;則需要創建一個。配置代理&#xff1a;在 vite.config.js 文件中&#xff0c;使用 server.proxy 選項來…

交換機與ARP

交換機與 ARP&#xff08;Address Resolution Protocol&#xff0c;地址解析協議&#xff09; 的關系主要體現在 局域網&#xff08;LAN&#xff09;內設備通信的地址解析與數據幀轉發 過程中。以下是二者的核心關聯&#xff1a; 1. 基本角色 交換機&#xff1a;工作在 數據鏈…

【Spring】小白速通AOP-日志記錄Demo

這篇文章我將通過一個最常用的AOP場景-方法調用日志記錄&#xff0c;帶你徹底理解AOP的使用。例子使用Spring BootSpring AOP實現。 如果對你有幫助可以點個贊和關注。謝謝大家的支持&#xff01;&#xff01; 一、Demo實操步驟&#xff1a; 1.首先添加Maven依賴 <!-- Sp…

git功能點管理

需求&#xff1a; 功能模塊1 已經完成&#xff0c;已經提交并推送到遠程&#xff0c;準備交給測試。功能模塊2 已經完成&#xff0c;但不提交給測試&#xff0c;繼續開發。功能模塊3 正在開發中。 管理流程&#xff1a; 創建并開發功能模塊1&#xff1a; git checkout main…