DreamDiffusion代碼學習及復現

論文解讀在這里

File path | Description
```/pretrains
┣ 📂 models
┃   ┗ 📜 config.yaml
┃   ┗ 📜 v1-5-pruned.ckpt┣ 📂 generation  
┃   ┗ 📜 checkpoint_best.pth ┣ 📂 eeg_pretain
┃   ┗ 📜 checkpoint.pth  (pre-trained EEG encoder)/datasets
┣ 📂 imageNet_images (subset of Imagenet)┗  📜 block_splits_by_image_all.pth
┗  📜 block_splits_by_image_single.pth 
┗  📜 eeg_5_95_std.pth  /code
┣ 📂 sc_mbm
┃   ┗ 📜 mae_for_eeg.py
┃   ┗ 📜 trainer.py
┃   ┗ 📜 utils.py┣ 📂 dc_ldm
┃   ┗ 📜 ldm_for_eeg.py
┃   ┗ 📜 utils.py
┃   ┣ 📂 models
┃   ┃   ┗ (adopted from LDM)
┃   ┣ 📂 modules
┃   ┃   ┗ (adopted from LDM)┗  📜 stageA1_eeg_pretrain.py   (main script for EEG pre-training)
┗  📜 eeg_ldm.py    (main script for fine-tuning stable diffusion)
┗  📜 gen_eval_eeg.py               (main script for generating images)┗  📜 dataset.py                (functions for loading datasets)
┗  📜 eval_metrics.py           (functions for evaluation metrics)
┗  📜 config.py                 (configurations for the main scripts)```

目錄

dataset.py

gen_eval_eeg.py

stageA1_eeg_pretrain.py

eeg_ldm.py

gen_eval_eeg.py


dataset.py

一、基礎工具函數模塊

"沿時間軸進行環形填充"是一種信號處理技術,當數據長度不足時,用數據的起始部分循環填充到末尾(類似"循環播放")

  • 對比其他填充方式

    • 零填充(Zero-pad):[1,2,3] -> [1,2,3,0,0]

    • 環形填充:[1,2,3] -> [1,2,3,1,2]

  • 參數解讀

    • ((0,0), (0, pad_size)):表示只在第二個維度(時間軸)右側填充

    • 'wrap':指定環形填充模式

  • 輸入

    • x.shape = (128, 500)(128個EEG通道,500個時間點)

    • patch_size = 16(每個時間塊包含16個時間點)

  • 計算需要填充的長度

    • 當前時間點:500

    • 需要達到?N × patch_size?的最小長度

    • ceil(500 / 16) = 32?塊 →?32×16=512

    • 需填充:512 - 500 = 12?個時間點

  • 填充操作:從每個通道的起始位置取前12個時間點,拼接到末尾

為什么選擇環形填充?

填充方式優點缺點適用場景
環形填充保持信號周期性
避免邊界突變
可能引入周期性假象EEG/ECG等準周期信號
零填充實現簡單引入高頻噪聲通用場景
鏡像填充平滑邊界計算復雜圖像處理

對于EEG信號:

  • 具有準周期性(alpha/beta波等)

  • 避免零填充導致的頻譜泄漏(spectral leakage)

  • 更適合后續的塊處理(patch劃分)

Z-score標準化(又稱標準差標準化)是一種常見的數據標準化方法,其核心是通過線性變換將原始數據轉換為均值為0、標準差為1的分布。

對于一組數據?x,其標準化值?z的計算公式為:z=(x?μ)/σ

  • μ:數據的均值(平均值)

  • σ:數據的標準差(反映數據離散程度)

二、時間序列處理模塊

?時間窗口

  • 定義:將連續的EEG信號按固定時長分段處理

  • 目的

    • 降低計算復雜度

    • 捕捉局部時域特征

    • 匹配后續處理(如傅里葉變換、模型輸入長度)

  • 8 / 0.75 ≈ 10.67,0.75秒/幀:該數據集的時間分辨率(每幀持續時間)

三、數據增強模塊

四、核心數據集類

1. 預訓練數據集

2. 完整EEG-Image數據集
class EEGDataset(Dataset):def __init__(self, eeg_signals_path):loaded = torch.load(eeg_signals_path)  # 加載預處理數據self.data = [{'eeg': tensor,       # EEG信號 [通道, 時間]'label': int,        # 類別標簽 'image': 'n01440764' # ImageNet ID}, ...]def __getitem__(self, i):# EEG處理eeg = data[i]['eeg'].t()     # 轉置為[時間, 通道]eeg = eeg[20:460]            # 選擇有效時間窗口eeg = interp1d(...)          # 插值到512點# 圖像處理image_path = 'n01440764/n01440764_10026.JPEG'image = Image.open(path)image = processor(image)     # CLIP預處理

五、數據劃分模塊

class Splitter:def __init__(self, dataset, split_path):loaded = torch.load(split_path)self.split_idx = loaded['splits'][0]['train']  # 取第一個劃分方案# 過濾條件:# 1. EEG長度在450-600之間# 2. 被試匹配(當subject!=0時)

六、圖像處理模塊

class random_crop:def __call__(self, img):if 概率p: 執行隨機裁剪else: 返回原圖def normalize2(img):return img * 2.0 - 1.0  # 歸一化到[-1,1]

七、重要技術細節

對齊流程:

sequenceDiagramparticipant EEG_Dataparticipant ImageNetEEG_Data->>EEGDataset: 加載樣本iEEGDataset->>EEG_Data: 讀取self.data[i]["image"]字段EEGDataset->>ImageNet: 根據ID構造路徑ImageNet-->>EEGDataset: 返回對應圖像EEGDataset->>Model: 返回{'eeg':eeg, 'image':image}

gen_eval_eeg.py

基于MAE (Masked Autoencoder) 的EEG信號預訓練框架,主要包含以下核心模塊:

  1. 環境配置與工具函數

  2. 數據加載與預處理

  3. 模型定義與訓練流程

  4. 可視化與日志記錄

  5. 分布式訓練支持

1. 核心模塊解析

2. 關鍵實現細節

4. 可視化模塊

代碼流程圖

graph TDA[初始化配置] --> B[加載數據集]B --> C[構建MAE模型]C --> D[初始化優化器]D --> E[訓練循環]E --> F{達到保存點?}F -- 是 --> G[保存模型+可視化]F -- 否 --> EG --> H[完成訓練]

stageA1_eeg_pretrain.py

Pre-training on EEG data

用于大量訓練的數據集從MOABB上下載,還沒學會,,,,

eeg_ldm.py

Finetune the Stable Diffusion with Pre-trained EEG Encoder

實現了一個基于Latent Diffusion Model (LDM) 的EEG信號到圖像生成的完整流程:


一、代碼整體架構

本代碼是DreamDiffusion項目的第二階段(Stage B),主要包含以下核心模塊:

  1. 配置管理(Config_Generative_Model)

  2. 數據加載與預處理(create_EEG_dataset)

  3. 生成模型定義(eLDM)

  4. 訓練流程控制(main函數)

  5. 圖像生成與評估(generate_images)

  6. 實驗日志記錄(wandb集成)

二、核心組件詳解

1. 配置管理
class Config_Generative_Model:def __init__(self):# 項目參數self.seed = 2022self.root_path = '.'self.eeg_signals_path = 'datasets/eeg_5_95_std.pth'# 模型參數self.pretrain_mbm_path = 'pretrains/generation/checkpoint.pth'self.pretrain_gm_path = 'pretrains/stable-diffusion-v1-5'# 訓練參數self.batch_size = 25self.lr = 5.3e-5self.num_epoch = 500
2. 數據加載
  1. 加載EEG信號和對應的ImageNet圖像路徑

  2. 應用兩種圖像變換:

    • 訓練集:隨機裁剪+歸一化(img_transform_train

    • 測試集:僅歸一化(img_transform_test

  3. 返回包含EEG-圖像對的數據集

3. 生成模型(eLDM)
  • 雙條件機制:同時接受EEG特征和CLIP文本特征

  • 基于Latent Diffusion架構

  • 支持從檢查點恢復訓練

5. 圖像生成與評估
def generate_images(generative_model, dataset, num_samples, ddim_steps):grid, samples = generative_model.generate(dataset, num_samples, ddim_steps)# 保存圖像網格Image.fromarray(grid).save('samples.png')# 計算評估指標metrics = get_eval_metric(samples)return metrics

評估指標

  • 像素級:MSE, PCC, SSIM

  • 語義級:Top-1分類準確率

三、關鍵技術細節

1. 條件擴散模型
graph LRA[EEG信號] --> B[EEG編碼器]C[CLIP文本編碼] --> D[LDM UNet]B --> DD --> E[圖像生成]
2. 雙階段訓練策略
  1. 階段A:預訓練EEG編碼器(MAE架構)

  2. 階段B:微調擴散模型(本代碼)

3. 圖像變換流水線
img_transform_train = transforms.Compose([normalize,                     # 歸一化到[-1,1]transforms.Resize(512),        # 調整大小random_crop(448, p=0.5),       # 隨機裁剪(數據增強)transforms.Resize(512),        # 再次調整channel_last                   # 通道順序轉換
])

gen_eval_eeg.py

Generating Images with Trained Checkpoints

實現了EEG信號到圖像生成的評估流程:

一、代碼整體架構

這段代碼是DreamDiffusion項目的評估部分,主要功能是加載預訓練好的生成模型,對EEG信號進行圖像生成并保存結果。核心模塊包括:

  1. 配置加載:從檢查點恢復實驗配置

  2. 數據準備:加載EEG測試數據集

  3. 模型初始化:構建條件擴散模型(eLDM)

  4. 圖像生成:使用訓練好的模型生成圖像

  5. 結果保存:存儲生成的圖像網格


二、核心組件詳解

圖像變換流程
img_transform_test = transforms.Compose([normalize,                  # 歸一化到[-1,1]transforms.Resize((512,512)), # 調整尺寸channel_last                # 通道順序轉換 (C,H,W)->(H,W,C)
])
  • 數據規格

    • 輸入EEG形狀:(num_samples, 128通道, 512時間點)

    • 輸出圖像尺寸:512×512

3. 模型初始化
generative_model = eLDM(pretrain_mbm_metafile,   # EEG編碼器配置num_voxels,              # 輸入維度=EEG特征長度device=device,           # 計算設備pretrain_root=config.pretrain_gm_path,  # SD權重路徑ddim_steps=config.ddim_steps  # 擴散步數(默認250)
)
generative_model.model.load_state_dict(sd['model_state_dict'])  # 加載訓練權重

模型架構特點

  • 雙條件機制:EEG特征 + CLIP文本特征

  • 基于Latent Diffusion架構

  • 使用DDIM采樣方法

4. 圖像生成
# 生成訓練集樣本(10個實例)
grid, _ = generative_model.generate(dataset_train, num_samples=config.num_samples,ddim_steps=config.ddim_steps,HW=config.HW,  # 圖像尺寸limit=10
)# 生成測試集樣本
grid, samples = generative_model.generate(dataset_test,num_samples=config.num_samples,ddim_steps=config.ddim_steps,state=sd['state']  # 隨機狀態恢復
)

生成參數

參數含義典型值
num_samples每樣本生成數量5
ddim_steps擴散采樣步數250
HW圖像高寬[512,512]
limit最大生成樣本數10

三、關鍵技術細節

1.?條件生成流程
sequenceDiagramparticipant EEGparticipant Modelparticipant ImageEEG->>Model: 輸入EEG信號(128ch×512t)Model->>Model: 通過EEG編碼器提取特征Model->>Model: 擴散模型條件生成Model->>Image: 輸出512×512圖像

這個生成代碼很有問題啊,一直報錯,類似這樣,很多人都出現了,但目前無法解決,,,,

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/75508.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/75508.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/75508.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

用Python實現TCP代理

依舊是Python黑帽子這本書 先附上代碼,我在原書代碼上加了注釋,更好理解 import sys import socket import threading#生成可打印字符映射 HEX_FILTER.join([(len(repr(chr(i)))3) and chr(i) or . for i in range(256)])#接收bytes或string類型的輸入…

Pyinstaller 打包flask_socketio為exe程序后出現:ValueError: Invalid async_mode specified

Pyinstaller 打包flask_socketio為exe程序后出現&#xff1a;ValueError: Invalid async_mode specified 一、詳細描述問題描述 Traceback (most recent call last): File "app_3.py", line 22, in <module> File "flask_socketio\__init__.py"…

django REST framework(DRF)教程

Django DRF API Django 基本使用Django DRF序列化器Django DRF視圖Django DRF常用功能Django 基本使用 前后端分離開發模式認識RestFulAPI回顧Django開發模式Django REST Framework初探前后端分離開發模式 前后端分離前:前端頁面看到的效果都是由后端控制,即后端渲染HTML頁面…

【Linux】Orin NX + Ubuntu22.04配置國內源

1、獲取源 清華源 arm 系統的源,可以在如下地址獲取到 https://mirror.tuna.tsinghua.edu.cn/help/ubuntu-ports/ 選擇HTTPS,否則可能報錯: 明文簽署文件不可用,結果為‘NOSPLIT’(您的網絡需要認證嗎?)查看Orin NX系統版本 選擇jammy的源 2、更新源 1)備份原配…

【含文檔+PPT+源碼】基于微信小程序的社交攝影約拍平臺的設計與實現

項目介紹 本課程演示的是一款基于微信小程序的社交攝影約拍平臺的設計與實現&#xff0c;主要針對計算機相關專業的正在做畢設的學生與需要項目實戰練習的 Java 學習者。 1.包含&#xff1a;項目源碼、項目文檔、數據庫腳本、軟件工具等所有資料 2.帶你從零開始部署運行本套系…

JDBC常用的接口

一、什么是JDBC JDBC是Java語言連接數據庫的接口規范。 二、JDBC的體系 1、Java官方提供一個操作數據庫的抽象接口 抽象接口有很多的接口和抽象類。 例如&#xff1a;Driver、Connection、Statement。 2、各個數據庫廠商提供各自的Java實現類 需要各自實現具體的細節。 例如&am…

容器適配器-stack棧

C標準庫不只是包含了順序容器&#xff0c;還包含一些為滿足特殊需求而設計的容器&#xff0c;它們提供簡單的接口。 這些容器可被歸類為容器適配器(container adapter)&#xff0c;它們是改造別的標準順序容器&#xff0c;使之滿足特殊需求的新容器。 適配器:也稱配置器,把一…

[250403] HuggingFace 新增檢查模型與電腦兼容性的功能 | Firefox 發布137.0 支持標簽組

目錄 Hugging Face 讓尋找兼容的 AI 模型變得更容易Firefox 137 版本更新摘要 Hugging Face 讓尋找兼容的 AI 模型變得更容易 Hugging Face 是一個流行的在線平臺&#xff0c;用于訪問開源人工智能 (AI) 工具和模型。該平臺推出了一項有用的新功能&#xff0c;允許個人輕松檢查…

.NET 創建MCP使用大模型對話二:調用遠程MCP服務

在上一篇文章.NET 創建MCP使用大模型對話-CSDN博客中&#xff0c;我們簡述了如何使用mcp client使用StdIo模式調用本地mcp server。本次實例將會展示如何使用mcp client模式調用遠程mcp server。 一&#xff1a;創建mcp server 我們創建一個天氣服務。 新建WebApi項目&#x…

Redis 中 Set(例如標簽) 和 ZSet(例如排行榜) 的詳細對比,涵蓋定義、特性、命令、適用場景及總結表格

以下是 Redis 中 Set 和 ZSet 的詳細對比&#xff0c;涵蓋定義、特性、命令、適用場景及總結表格&#xff1a; 1. 核心定義 數據類型SetZSet&#xff08;Sorted Set&#xff09;定義無序的、唯一的字符串集合&#xff0c;元素不重復。有序的、唯一的字符串集合&#xff0c;每個…

解決Spring參數解析異常:Name for argument of type XXX not specified

前言 在開發 Spring Boot 應用時&#xff0c;我們常遇到類似 java.lang.IllegalArgumentException: Name for argument not specified 的報錯。這類問題通常與方法參數名稱的解析機制相關&#xff0c;尤其在使用 RequestParam、PathVariable 等注解時更為常見。 一、問題現象與…

剛剛,OpenAI開源PaperBench,重塑頂級AI Agent評測

今天凌晨1點&#xff0c;OpenAI開源了一個全新的AI Agent評測基準——PaperBench。 這個基準主要考核智能體的搜索、整合、執行等能力&#xff0c;需要對2024年國際機器學習大會上頂尖論文的復現&#xff0c;包括對論文內容的理解、代碼編寫以及實驗執行等方面的能力。 根據O…

Golang封裝Consul 服務發現庫

以下是一個經過生產驗證的 Consul 服務發現封裝庫,支持注冊/注銷、健康檢查、智能發現等核心功能,可直接集成到項目中: package consulimport ("context""fmt""log""math/rand""net""os""sync"&quo…

自適應信號處理任務(過濾,預測,重建,分類)

自適應濾波 # signals creation: u, v, d N = 5000 n = 10 u = np.sin(np.arange(0, N/10., N/50000

PyTorch深度學習框架 的基礎知識

目錄 1.pyTorch檢查是否安裝成功 2.PyTorch的張量tensor 基礎創建方式&#xff08;三種&#xff09; 2.2用列表創建tensor 2.2使用元組創建 tensor 2.3使用ndarray創建創建 tensor 2.4 快速創建tensor的常用方法 3.pyTorch中的張量tensor的常用屬性 4. tensor中的基礎數據…

MySQL學習集--DDL

DDL 數據庫操作 查詢所有數據庫 SHOW DATABASES;查詢當前數據庫 SELECT DATABASE();創建 CREATE DATABASE[IF NOT EXISTS]數據庫名[DEFAULT CHARSET 字符集][COLLATE 排序規則];刪除 DROR DATABASE[IF EXISTS]數據庫名;使用 USE 數據庫名;表操作 創建表格 CREATE TABL…

Vue 3 中按照某個字段將數組分成多個數組

方法一&#xff1a;使用 reduce 方法 const originalArray [{ id: 1, category: A, name: Item 1 },{ id: 2, category: B, name: Item 2 },{ id: 3, category: A, name: Item 3 },{ id: 4, category: C, name: Item 4 },{ id: 5, category: B, name: Item 5 }, ];const grou…

LeetCode刷題 -- 48. 旋轉圖像

題目 算法題解&#xff1a;順時針旋轉矩陣&#xff08;90度&#xff09; 1. 算法描述 給定一個 n n 的二維矩陣&#xff0c;請將矩陣順時針旋轉 90 度。 例如&#xff1a; 輸入&#xff1a; [[1,2,3],[4,5,6],[7,8,9] ]輸出&#xff1a; [[7,4,1],[8,5,2],[9,6,3] ]2. 思…

Vulkan進階系列1 - Vulkan應用程序結構(完整代碼)

一: 概述 在前面的20多篇文章中,我們了解了Vulkan的基礎知識,和相關API的使用,接下來我們要從零開始寫一套完整Vulkan應用程序,在這個過程中加深對Vulkan中的各種概念的理解。 Vulkan 應用程序一般遵循 初始化 -> 運行循環 -> 資源清理 的結構,本實例也基本遵循了…

VTK的兩種顯示刷新方式

在類中先聲明vtk的顯示對象 vtkRenderer out_render; vtkVertexGlyphFilter glyphFilter; vtkPolyDataMapper mapper; // 新建制圖器 vtkActor actor; // 新建角色 然后在init中先初始化一下&#xff1a; out_rend…