多模態分類案例實現

以下是基于飛槳平臺實現的多模態分類詳細案例,結合圖像和文本信息進行分類任務。案例包含數據處理、模型構建、訓練和評估的完整流程,并提供詳細注釋:

一、多模態分類案例實現

import os
import json
import numpy as np
from PIL import Image
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader
from paddle.vision import models
import paddlenlp as ppnlp
from paddlenlp.transformers import ErnieTokenizer, ErnieModel# 設置隨機種子,確保結果可復現
paddle.seed(42)
np.random.seed(42)# ---------------------- 1. 數據集定義 ----------------------
class MultiModalDataset(Dataset):"""多模態圖像-文本分類數據集"""def __init__(self, data_path, image_dir, tokenizer, max_seq_len=128, mode='train'):"""data_path: 標注文件路徑image_dir: 圖像文件夾路徑tokenizer: 文本tokenizermax_seq_len: 文本最大長度mode: 模式,train/val/test"""super().__init__()self.image_dir = image_dirself.tokenizer = tokenizerself.max_seq_len = max_seq_lenself.mode = mode# 加載數據集with open(data_path, 'r', encoding='utf-8') as f:self.data = json.load(f)# 定義類別到ID的映射(根據數據集調整)self.label2id = {'科技': 0, '娛樂': 1, '體育': 2, '財經': 3, '教育': 4}self.id2label = {v: k for k, v in self.label2id.items()}def __len__(self):return len(self.data)def __getitem__(self, idx):# 獲取單條數據item = self.data[idx]image_path = os.path.join(self.image_dir, item['image'])text = item['text']label = self.label2id[item['label']]# 處理圖像image = Image.open(image_path).convert('RGB')image = self._preprocess_image(image)# 處理文本encoded_inputs = self.tokenizer(text=text,max_seq_len=self.max_seq_len,pad_to_max_seq_len=True,return_attention_mask=True,return_token_type_ids=True)# 轉換為Tensorinput_ids = paddle.to_tensor(encoded_inputs['input_ids'], dtype='int64')attention_mask = paddle.to_tensor(encoded_inputs['attention_mask'], dtype='int64')token_type_ids = paddle.to_tensor(encoded_inputs['token_type_ids'], dtype='int64')label = paddle.to_tensor(label, dtype='int64')return {'image': image,'input_ids': input_ids,'attention_mask': attention_mask,'token_type_ids': token_type_ids,'label': label}def _preprocess_image(self, image):"""圖像預處理:縮放、歸一化、轉Tensor"""# 調整圖像大小為224x224image = image.resize((224, 224), Image.BICUBIC)# 轉換為numpy數組image = np.array(image).astype('float32')# 歸一化image = image / 255.0# 標準化(ImageNet均值和標準差)image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])# 調整通道順序 (HWC -> CHW)image = np.transpose(image, (2, 0, 1))return paddle.to_tensor(image, dtype='float32')# ---------------------- 2. 多模態分類模型 ----------------------
class MultiModalClassifier(nn.Layer):"""基于圖像和文本的多模態分類模型"""def __init__(self, num_classes, text_encoder='ernie-1.0', pretrained=True):super().__init__()# 圖像編碼器(使用預訓練ResNet50)self.image_encoder = models.resnet50(pretrained=pretrained)# 移除最后的全連接層self.image_encoder.fc = nn.Identity()# 添加投影層,將圖像特征映射到共同空間self.image_proj = nn.Linear(2048, 512)# 文本編碼器(使用預訓練ERNIE)self.text_encoder = ErnieModel.from_pretrained(text_encoder)# 添加投影層,將文本特征映射到共同空間self.text_proj = nn.Linear(768, 512)# 特征融合層self.fusion = nn.Sequential(nn.Linear(1024, 512),  # 拼接圖像和文本特征 (512+512)nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.5))# 分類器self.classifier = nn.Linear(256, num_classes)def forward(self, image, input_ids, attention_mask, token_type_ids=None):# 提取圖像特征image_features = self.image_encoder(image)  # [batch_size, 2048]image_features = self.image_proj(image_features)  # [batch_size, 512]# 提取文本特征text_outputs = self.text_encoder(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)# 獲取[CLS] token的表示text_features = text_outputs[1]  # [batch_size, 768]text_features = self.text_proj(text_features)  # [batch_size, 512]# 特征融合fused_features = paddle.concat([image_features, text_features], axis=1)  # [batch_size, 1024]fused_features = self.fusion(fused_features)  # [batch_size, 256]# 分類預測logits = self.classifier(fused_features)  # [batch_size, num_classes]return logits# ---------------------- 3. 模型訓練 ----------------------
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, save_dir):"""訓練多模態分類模型"""best_acc = 0.0for epoch in range(epochs):# 訓練模式model.train()train_loss = 0.0correct = 0total = 0for batch in train_loader:# 獲取數據image = batch['image']input_ids = batch['input_ids']attention_mask = batch['attention_mask']token_type_ids = batch['token_type_ids']label = batch['label']# 前向傳播logits = model(image, input_ids, attention_mask, token_type_ids)loss = criterion(logits, label)# 反向傳播loss.backward()optimizer.step()optimizer.clear_grad()# 統計訓練指標train_loss += loss.numpy()[0]total += label.shape[0]pred = paddle.argmax(logits, axis=1)correct += (pred == label).sum().numpy()[0]# 計算訓練準確率train_acc = correct / totalprint(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}')# 驗證val_acc = evaluate_model(model, val_loader)print(f'Epoch [{epoch+1}/{epochs}], Val Acc: {val_acc:.4f}')# 保存最佳模型if val_acc > best_acc:best_acc = val_accpaddle.save(model.state_dict(), os.path.join(save_dir, 'best_model.pdparams'))print(f'Model saved at acc: {best_acc:.4f}')# ---------------------- 4. 模型評估 ----------------------
def evaluate_model(model, data_loader):"""評估模型性能"""model.eval()correct = 0total = 0with paddle.no_grad():for batch in data_loader:# 獲取數據image = batch['image']input_ids = batch['input_ids']attention_mask = batch['attention_mask']token_type_ids = batch['token_type_ids']label = batch['label']# 模型預測logits = model(image, input_ids, attention_mask, token_type_ids)pred = paddle.argmax(logits, axis=1)# 統計準確率total += label.shape[0]correct += (pred == label).sum().numpy()[0]return correct / total# ---------------------- 5. 主函數 ----------------------
def main():# 配置參數config = {'train_data_path': 'data/train.json',  # 訓練數據路徑'val_data_path': 'data/val.json',      # 驗證數據路徑'image_dir': 'data/images',            # 圖像文件夾路徑'save_dir': 'checkpoints',             # 模型保存路徑'num_classes': 5,                      # 分類類別數'batch_size': 16,                      # 批次大小'epochs': 10,                          # 訓練輪數'learning_rate': 1e-4,                 # 學習率'max_seq_len': 128                     # 文本最大長度}# 創建保存目錄os.makedirs(config['save_dir'], exist_ok=True)# 初始化tokenizertokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')# 創建數據集train_dataset = MultiModalDataset(config['train_data_path'], config['image_dir'], tokenizer, config['max_seq_len'],mode='train')val_dataset = MultiModalDataset(config['val_data_path'], config['image_dir'], tokenizer, config['max_seq_len'],mode='val')# 創建數據加載器train_loader = DataLoader(train_dataset,batch_size=config['batch_size'],shuffle=True,num_workers=4)val_loader = DataLoader(val_dataset,batch_size=config['batch_size'],shuffle=False,num_workers=4)# 初始化模型model = MultiModalClassifier(config['num_classes'])# 定義損失函數和優化器criterion = nn.CrossEntropyLoss()optimizer = paddle.optimizer.AdamW(learning_rate=config['learning_rate'],parameters=model.parameters())# 訓練模型train_model(model, train_loader, val_loader, optimizer, criterion, config['epochs'], config['save_dir'])# 加載最佳模型并評估model.set_state_dict(paddle.load(os.path.join(config['save_dir'], 'best_model.pdparams')))test_acc = evaluate_model(model, val_loader)print(f'Final Test Accuracy: {test_acc:.4f}')if __name__ == '__main__':main()

二、數據集格式說明

數據集采用JSON格式,每條數據包含圖像路徑、文本描述和類別標簽:

[{"image": "image_001.jpg","text": "這款新手機的相機功能非常出色,拍照效果堪比專業相機","label": "科技"},{"image": "image_002.jpg","text": "這支足球隊在本賽季表現出色,有望奪得冠軍","label": "體育"},...
]

三、模型架構解析

  1. 圖像編碼器:使用預訓練的ResNet50提取圖像特征,最后通過全連接層投影到512維空間。
  2. 文本編碼器:使用預訓練的ERNIE模型提取文本特征,取[CLS]標記表示,再通過全連接層投影到512維空間。
  3. 特征融合:將圖像和文本特征拼接后,通過多層感知機進行融合和降維。
  4. 分類器:基于融合特征進行多分類預測。

四、訓練和評估流程

  1. 數據加載:使用自定義數據集類加載圖像和文本數據,并進行預處理。
  2. 模型訓練:采用交叉熵損失函數和AdamW優化器,訓練10個epoch。
  3. 模型評估:在驗證集上計算分類準確率,并保存性能最佳的模型。

五、擴展建議

  1. 特征融合改進:嘗試更復雜的融合方法,如注意力機制、雙線性池化等。
  2. 數據增強:對圖像進行隨機裁剪、翻轉等增強,對文本進行同義詞替換、插入等操作。
  3. 模型調優:調整學習率、批次大小、dropout率等超參數。
  4. 多模態權重平衡:為圖像和文本分支設計可學習的權重,自適應調整各模態的重要性。

這個案例展示了如何結合圖像和文本信息進行多模態分類,您可以根據實際需求調整模型架構和數據集。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84348.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84348.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84348.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Express框架:Node.js的輕量級Web應用利器

Hi,我是布蘭妮甜 !在當今快速發展的Web開發領域,Node.js已成為構建高性能、可擴展網絡應用的重要基石。而在這片肥沃的生態系統中,Express框架猶如一座經久不衰的燈塔,指引著無數開發者高效構建Web應用的方向。本文章在為讀者提供一份全面而深入的Express框架指南。無論您…

K-Means顏色變卦和漸變色

一、理論深度提升:補充算法細節與數學基礎 1. K-Means 算法核心公式(增強專業性) 在 “原理步驟” 中加入數學表達式,說明聚類目標: K-Means 的目標是最小化簇內平方和(Within-Cluster Sum of Squares, W…

深入解析C#表達式求值:優先級、結合性與括號的魔法

—— 為什么2/6*4不等于1/12? 🔍 一、表達式求值順序為何重要? 表達式如精密儀器,子表達式求值順序直接決定結果。例如: int result 3 * 5 2;若先算乘法:(3*5)2 17 ?若先算加法:3*(52)21…

Docker 離線安裝指南

參考文章 1、確認操作系統類型及內核版本 Docker依賴于Linux內核的一些特性,不同版本的Docker對內核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux內核3.10及以上版本,Docker17.09及更高版本對應Linux內核4.9.x及更高版本。…

Spring——Spring相關類原理與實戰

摘要 本文深入探討了 Spring 框架中 InitializingBean 接口的原理與實戰應用,該接口是 Spring 提供的一個生命周期接口,用于在 Bean 屬性注入完成后執行初始化邏輯。文章詳細介紹了接口定義、作用、典型使用場景,并與其他相關概念如 PostCon…

Angular微前端架構:Module Federation + ngx-build-plus (Webpack)

以下是一個完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 實現了主應用(Shell)與子應用(Remote)的集成。 🛠? 項目結構 angular-mf/ ├── shell-app/ # 主應用&…

ESP32 I2S音頻總線學習筆記(四): INMP441采集音頻并實時播放

簡介 前面兩期文章我們介紹了I2S的讀取和寫入,一個是通過INMP441麥克風模塊采集音頻,一個是通過PCM5102A模塊播放音頻,那如果我們將兩者結合起來,將麥克風采集到的音頻通過PCM5102A播放,是不是就可以做一個擴音器了呢…

馮諾依曼架構是什么?

馮諾依曼架構是什么? 馮諾依曼架構(Von Neumann Architecture)是現代計算機的基礎設計框架,由數學家約翰馮諾依曼(John von Neumann)及其團隊在1945年提出。其核心思想是通過統一存儲程序與數據&#xff0…

【持續更新】linux網絡編程試題

問題1 請簡要說明TCP/IP協議棧的四層結構,并分別舉出每一層出現的典型協議或應用。 答案 應用層:ping,telnet,dns 傳輸層:tcp,udp 網絡層:ip,icmp 數據鏈路層:arp,rarp 問題2 下列協議或應用分別屬于TCP/IP協議…

橢圓曲線密碼學(ECC)

一、ECC算法概述 橢圓曲線密碼學(Elliptic Curve Cryptography)是基于橢圓曲線數學理論的公鑰密碼系統,由Neal Koblitz和Victor Miller在1985年獨立提出。相比RSA,ECC在相同安全強度下密鑰更短(256位ECC ≈ 3072位RSA…

【JVM】- 內存結構

引言 JVM:Java Virtual Machine 定義:Java虛擬機,Java二進制字節碼的運行環境好處: 一次編寫,到處運行自動內存管理,垃圾回收的功能數組下標越界檢查(會拋異常,不會覆蓋到其他代碼…

React 基礎入門筆記

一、JSX語法規則 1. 定義虛擬DOM時,不要寫引號 2.標簽中混入JS表達式時要用 {} (1).JS表達式與JS語句(代碼)的區別 (2).使用案例 3.樣式的類名指定不要用class,要用className 4.內…

Linux鏈表操作全解析

Linux C語言鏈表深度解析與實戰技巧 一、鏈表基礎概念與內核鏈表優勢1.1 為什么使用鏈表?1.2 Linux 內核鏈表與用戶態鏈表的區別 二、內核鏈表結構與宏解析常用宏/函數 三、內核鏈表的優點四、用戶態鏈表示例五、雙向循環鏈表在內核中的實現優勢5.1 插入效率5.2 安全…

SQL進階之旅 Day 19:統計信息與優化器提示

【SQL進階之旅 Day 19】統計信息與優化器提示 文章簡述 在數據庫性能調優中,統計信息和優化器提示是兩個至關重要的工具。統計信息幫助數據庫優化器評估查詢成本并選擇最佳執行計劃,而優化器提示則允許開發人員對優化器的行為進行微調。本文深入探討了…

安寶特方案丨船舶智造AR+AI+作業標準化管理系統解決方案(維保)

船舶維保管理現狀:設備維保主要由維修人員負責,根據設備運行狀況和維護計劃進行定期保養和故障維修。維修人員憑借經驗判斷設備故障原因,制定維修方案。 一、痛點與需求 1 Arbigtec 人工經驗限制維修效率: 復雜設備故障的診斷和…

MFC內存泄露

1、泄露代碼示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 獲取 Ribbon Bar 指針// 創建自定義按鈕CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…

基于區塊鏈的供應鏈溯源系統:構建與實踐

前言 在當今全球化的經濟環境中,供應鏈的復雜性不斷增加,商品從原材料采購到最終交付給消費者的過程涉及多個環節和眾多參與者。如何確保供應鏈的透明度、可追溯性和安全性,成為企業和消費者關注的焦點。區塊鏈技術以其去中心化、不可篡改和透…

Web攻防-SQL注入數據格式參數類型JSONXML編碼加密符號閉合

知識點: 1、Web攻防-SQL注入-參數類型&參數格式 2、Web攻防-SQL注入-XML&JSON&BASE64等 3、Web攻防-SQL注入-數字字符搜索等符號繞過 案例說明: 在應用中,存在參數值為數字,字符時,符號的介入&#xff0c…

探秘鴻蒙 HarmonyOS NEXT:實戰用 CodeGenie 構建鴻蒙應用頁面

在開發鴻蒙應用時,你是否也曾為一個頁面的布局反復調整?是否還在為查 API、寫模板代碼而浪費大量時間?今天帶大家實戰體驗一下鴻蒙官方的 AI 編程助手——CodeGenie(代碼精靈) ,如何從 0 到 1 快速構建一個…

DBAPI如何優雅的獲取單條數據

API如何優雅的獲取單條數據 案例一 對于查詢類API,查詢的是單條數據,比如根據主鍵ID查詢用戶信息,sql如下: select id, name, age from user where id #{id}API默認返回的數據格式是多條的,如下: {&qu…