Introduction: Generating Images from Text via Feature Vectors
Recap: The Multi-Modal Transformer
In Implementing an End-to-End Multi-Modal Transformer Model from Scratch in Python, we adapted our character-level Transformer to handle images (via ResNet features) alongside text prompts for tasks such as visual captioning. We saved its state to saved_models/multimodal_model.pt, including the text-processing components, the tokenizer, and the vision feature projection layer.
The Text-to-Image Challenge (Simplified)
Generating realistic images directly from text is a complex task that usually calls for specialized architectures such as generative adversarial networks (GANs) or diffusion models. To stay within the Transformer framework, we tackle a simplified version:
- Goal: Given a text prompt (e.g., "a blue square"), generate an image feature vector that represents the described image.
- Why feature vectors? Autoregressively generating raw pixels with a basic Transformer is both difficult and computationally expensive. Generating a fixed-size feature vector (like the one produced by ResNet) is a much more manageable intermediate step.
- Image reconstruction: Once a feature vector has been generated, we visualize the result with a simple nearest-neighbor lookup over the known training images. The model predicts a feature vector, and we find the training image (red square, blue square, or green circle) whose feature vector is most similar and display it, as sketched below.
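The sketch below illustrates this nearest-neighbor idea in isolation. The tensors here are random placeholders (the dimension 512 is only assumed); in the actual notebook the predicted vector comes from the Transformer and the known vectors from the frozen ResNet (Step 0.4). The full inline version appears in Step 4.3.
# A minimal, self-contained sketch of the nearest-neighbor lookup (illustrative only;
# the tensors below are random placeholders for the real predicted/known feature vectors).
import torch
import torch.nn.functional as F

vision_feature_dim = 512  # assumed ResNet-18 feature size
predicted_vec = torch.randn(vision_feature_dim)
known_features_list = [
    ("red_square.png", torch.randn(vision_feature_dim)),
    ("blue_square.png", torch.randn(vision_feature_dim)),
    ("green_circle.png", torch.randn(vision_feature_dim)),
]

best_path, best_sim = None, -float('inf')
for path, known_vec in known_features_list:
    # Cosine similarity between the predicted vector and a known image's feature vector
    sim = F.cosine_similarity(predicted_vec.unsqueeze(0), known_vec.unsqueeze(0)).item()
    if sim > best_sim:
        best_path, best_sim = path, sim

print(f"Closest known image: {best_path} (cosine similarity {best_sim:.4f})")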
Our Approach: A Transformer for Feature Prediction
- Load components: We load the text Transformer components (embeddings, positional encoding, attention/FFN blocks, final layer norm) and the tokenizer from multimodal_model.pt. We also load the frozen ResNet-18 feature extractor to obtain target features during training.
- Adapt the architecture: The input is now text only. We replace the final output layer (which previously predicted text tokens) with a new linear layer that maps the Transformer's final hidden state to the dimensionality of the ResNet image features (e.g., 512).
- Training data: We need (text prompt, target image) pairs. We will use descriptive prompts for the simple images created previously.
- Training process: The model reads the prompt, processes it through the Transformer blocks, and predicts an image feature vector with the new output layer. The loss (MSE) compares this predicted vector with the target image's actual feature vector.
- Inference (generation): Feed in a text prompt, obtain the predicted feature vector from the model, find the closest known image feature, and display the corresponding image. A shape-level sketch of the adapted head follows this list.
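Before the full inline implementation, here is a minimal, hedged sketch of the adapted prediction head and the MSE objective on dummy tensors. The dimensions (d_model=64, vision_feature_dim=512), the tensors, and the name text_to_image_head are illustrative assumptions; the real values come from the loaded configuration and the frozen ResNet.
# A shape-level sketch of the adapted head (illustrative; dimensions and tensors are dummies).
import torch
import torch.nn as nn

d_model, vision_feature_dim = 64, 512
text_to_image_head = nn.Linear(d_model, vision_feature_dim)  # replaces the token-prediction head

# Pretend `last_hidden` holds the final hidden state of the last non-padding token
# for a batch of 4 prompts, and `target_features` holds the frozen ResNet features.
last_hidden = torch.randn(4, d_model)                  # (B, d_model)
target_features = torch.randn(4, vision_feature_dim)   # (B, vision_feature_dim)

predicted_features = text_to_image_head(last_hidden)   # (B, vision_feature_dim)
loss = nn.MSELoss()(predicted_features, target_features)
print(f"MSE loss on dummy data: {loss.item():.4f}")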
Inline Implementation Style
We continue the extremely detailed, step-by-step inline implementation grounded in theory, avoiding functions and classes.
Step 0: Setup - Libraries, Loading, Data, Feature Extraction
Goal: Prepare the environment, load the relevant components from the previous multi-modal model, define the text-image pair data, and extract the target image features.
Step 0.1: Import Libraries
Theory: Import the required libraries. We need torch, torchvision, PIL, math, os, and numpy. Later we will also need scipy.spatial.distance or torch.nn.functional.cosine_similarity to find the closest feature vector during generation.
# Import the necessary libraries
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import math
import os
import numpy as np
# Used later to find the nearest feature vector
from scipy.spatial import distance as scipy_distance

# For reproducibility
torch.manual_seed(123)
np.random.seed(123)

print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print("Libraries imported.")

# --- Device Configuration ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
Step 0.2: Load the Relevant State from the Multi-Modal Model
Theory: Load the state dictionary from multimodal_model.pt. We need the configuration parameters, the tokenizer, the text Transformer components (embedding layer, attention blocks, etc.), and the ResNet feature extractor (kept frozen).
# --- Load the Saved Multi-Modal Model State ---
print("\nStep 0.2: Loading state from multi-modal model...")
model_load_path = 'saved_models/multimodal_model.pt'

if not os.path.exists(model_load_path):
    raise FileNotFoundError(f"Error: Model file not found at {model_load_path}. Please ensure 'multimodal.ipynb' was run and saved the model.")

loaded_state_dict = torch.load(model_load_path, map_location=device)
print(f"Loaded state dictionary from '{model_load_path}'.")

# --- Extract Config and Tokenizer ---
config = loaded_state_dict['config']
vocab_size = config['vocab_size'] # Includes special tokens
d_model = config['d_model']
n_heads = config['n_heads']
n_layers = config['n_layers']
d_ff = config['d_ff']
# Use block_size from the loaded config; it may need adjusting later for the prompt lengths
block_size = config['block_size']
vision_feature_dim = config['vision_feature_dim'] # ResNet feature dimension (e.g., 512)
d_k = d_model // n_heads

char_to_int = loaded_state_dict['tokenizer']['char_to_int']
int_to_char = loaded_state_dict['tokenizer']['int_to_char']
pad_token_id = char_to_int.get('<PAD>', -1) # Get PAD token ID
if pad_token_id == -1:
    print("Warning: PAD token not found in loaded tokenizer!")

print("Extracted model configuration and tokenizer:")
print(f" vocab_size: {vocab_size}")
print(f" d_model: {d_model}")
print(f" n_layers: {n_layers}")
print(f" n_heads: {n_heads}")
print(f" d_ff: {d_ff}")
print(f" block_size: {block_size}")
print(f" vision_feature_dim: {vision_feature_dim}")
print(f" PAD token ID: {pad_token_id}")# --- Load Positional Encoding ---
positional_encoding = loaded_state_dict['positional_encoding'].to(device)
# Verify block size consistency
if positional_encoding.shape[1] != block_size:
    print(f"Warning: Loaded PE size ({positional_encoding.shape[1]}) doesn't match loaded block_size ({block_size}). Using loaded PE size.")
    # block_size = positional_encoding.shape[1]  # Option 1: use the PE's size
    # Option 2: recompute the PE if necessary (as before), but try slicing/padding first
print(f"Loaded positional encoding with shape: {positional_encoding.shape}")

# --- Load Text Transformer Components (weights only; structures are created later) ---
loaded_embedding_dict = loaded_state_dict['token_embedding_table']
loaded_ln1_dicts = loaded_state_dict['layer_norms_1']
loaded_qkv_dicts = loaded_state_dict['mha_qkv_linears']
loaded_mha_out_dicts = loaded_state_dict['mha_output_linears']
loaded_ln2_dicts = loaded_state_dict['layer_norms_2']
loaded_ffn1_dicts = loaded_state_dict['ffn_linear_1']
loaded_ffn2_dicts = loaded_state_dict['ffn_linear_2']
loaded_final_ln_dict = loaded_state_dict['final_layer_norm']
print("Stored state dicts for text transformer components.")# --- 加載視覺特征提取器(ResNet) ---
# 重新加載之前使用的ResNet-18模型并保持凍結狀態
print("Loading pre-trained vision model (ResNet-18) for target feature extraction...")
vision_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
vision_model.fc = nn.Identity() # Remove the classifier head
vision_model = vision_model.to(device)
vision_model.eval() # Keep in evaluation mode
# Freeze the ResNet parameters - very important
for param in vision_model.parameters():
    param.requires_grad = False

print(f"Loaded and froze ResNet-18 feature extractor on device: {device}")

# --- Define Image Transformations (same as before) ---
image_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print("Defined image transformations.")
Step 0.3: Define the Sample Text-Image Data
Theory: Create (descriptive text prompt, target image path) pairs, using the same simple images generated in the previous notebook.
print("\nStep 0.3: Defining sample text-to-image data...")# --- Image Paths (Assuming they exist from multimodal.ipynb) ---
sample_data_dir = "sample_multimodal_data"
image_paths = {"red_square": os.path.join(sample_data_dir, "red_square.png"),"blue_square": os.path.join(sample_data_dir, "blue_square.png"),"green_circle": os.path.join(sample_data_dir, "green_circle.png")
}
# Verify paths exist
for key, path in image_paths.items():
    if not os.path.exists(path):
        print(f"Warning: Image file not found at {path}. Attempting to recreate.")
        if key == "red_square":
            img_ = Image.new('RGB', (64, 64), color='red')
            img_.save(path)
        elif key == "blue_square":
            img_ = Image.new('RGB', (64, 64), color='blue')
            img_.save(path)
        elif key == "green_circle":
            img_ = Image.new('RGB', (64, 64), color='white')
            draw = ImageDraw.Draw(img_)
            draw.ellipse((4, 4, 60, 60), fill='green', outline='green')
            img_.save(path)
        else:
            print(f"Error: Cannot recreate unknown image key '{key}'.")

# --- Define Text Prompt -> Image Path Pairs ---
text_to_image_data = [
    {"prompt": "a red square", "image_path": image_paths["red_square"]},
    {"prompt": "the square is red", "image_path": image_paths["red_square"]},
    {"prompt": "show a blue square", "image_path": image_paths["blue_square"]},
    {"prompt": "blue shape, square", "image_path": image_paths["blue_square"]},
    {"prompt": "a green circle", "image_path": image_paths["green_circle"]},
    {"prompt": "the circle, it is green", "image_path": image_paths["green_circle"]},
    # Add maybe one more variation
    {"prompt": "make a square that is red", "image_path": image_paths["red_square"]}
]

num_samples = len(text_to_image_data)
print(f"Defined {num_samples} sample text-to-image data points.")
# print(f"Sample 0: {text_to_image_data[0]}")
Step 0.4: Extract Target Image Features
Theory: Use the frozen ResNet to precompute the target feature vector for every unique image in the dataset. Store these features in a mapping from image path to feature tensor, and also keep a (path, feature) list for the nearest-neighbor lookup during generation.
print("\nStep 0.4: Extracting target image features...")
target_image_features = {} # Dict: {image_path: feature_tensor}
known_features_list = [] # List: [(path, feature_tensor)] for generation lookup

# --- Loop Through Unique Image Paths in this dataset ---
unique_image_paths_in_data = sorted(list(set(d["image_path"] for d in text_to_image_data)))
print(f"Found {len(unique_image_paths_in_data)} unique target images to process.")

for img_path in unique_image_paths_in_data:
    # Avoid re-extracting if already done (e.g., if loading from multimodal)
    # if img_path in extracted_image_features:  # Check previous notebook's dict
    #     feature_vector_squeezed = extracted_image_features[img_path]
    #     print(f"  Using pre-extracted features for '{os.path.basename(img_path)}'")
    # else:

    # --- Load Image ---
    try:
        img = Image.open(img_path).convert('RGB')
    except FileNotFoundError:
        print(f"Error: Image file not found at {img_path}. Skipping.")
        continue

    # --- Apply Transformations ---
    img_tensor = image_transforms(img).unsqueeze(0).to(device) # (1, 3, 224, 224)

    # --- Extract Features (using frozen vision_model) ---
    with torch.no_grad():
        feature_vector = vision_model(img_tensor) # (1, vision_feature_dim)
    feature_vector_squeezed = feature_vector.squeeze(0) # (vision_feature_dim,)
    print(f"  Extracted features for '{os.path.basename(img_path)}', shape: {feature_vector_squeezed.shape}")

    # --- Store Features ---
    target_image_features[img_path] = feature_vector_squeezed
    known_features_list.append((img_path, feature_vector_squeezed))

if not target_image_features:
    raise ValueError("No target image features were extracted. Cannot proceed.")

print("Finished extracting and storing target image features.")
print(f"Stored {len(known_features_list)} known (path, feature) pairs for generation lookup.")
Step 0.5: Define Training Hyperparameters
Theory: Adjust the hyperparameters for this specific task; a different learning rate or number of epochs may be appropriate. Here, block_size mainly relates to the maximum expected prompt length.
print("\nStep 0.5: Defining training hyperparameters for text-to-image...")# Use block_size from loaded config, ensure it's adequate for prompts
max_prompt_len = max(len(d["prompt"]) for d in text_to_image_data)
if block_size < max_prompt_len + 1: # +1 for potential special tokens like <EOS> if used in prompt
    print(f"Warning: Loaded block_size ({block_size}) might be small for max prompt length ({max_prompt_len}). Consider increasing block_size if issues arise.")
    # Adjust block_size if needed, and recompute PE / causal mask
    # block_size = max_prompt_len + 5  # Example adjustment
    # print(f"Adjusted block_size to {block_size}")
    # Need to recompute PE and causal_mask if block_size changes

# Recreate causal mask just in case block_size was adjusted, or use loaded size
causal_mask = torch.tril(torch.ones(block_size, block_size, device=device)).view(1, 1, block_size, block_size)
print(f"Using block_size: {block_size}")learning_rate = 1e-4 # Potentially lower LR for fine-tuning
batch_size = 4 # Keep small due to limited data
epochs = 5000 # Number of training iterations
eval_interval = 500

print(f" Training Params: LR={learning_rate}, BatchSize={batch_size}, Epochs={epochs}")
Step 1: Model Adaptation and Initialization
Goal: Rebuild the text Transformer components with the loaded weights, and initialize the new output projection layer.
Step 1.1: Initialize the Text Transformer Components
Theory: Create instances of the embedding layer, the LayerNorms, and the linear layers of the Transformer blocks, then load the pretrained weights from the dictionaries stored in Step 0.2.
print("\nStep 1.1: Initializing Text Transformer components and loading weights...")# --- Token Embedding Table ---
token_embedding_table = nn.Embedding(vocab_size, d_model).to(device)
token_embedding_table.load_state_dict(loaded_embedding_dict)
print(f" Loaded Token Embedding Table, shape: {token_embedding_table.weight.shape}")# --- Transformer Blocks Components ---
layer_norms_1 = []
mha_qkv_linears = []
mha_output_linears = []
layer_norms_2 = []
ffn_linear_1 = []
ffn_linear_2 = []

for i in range(n_layers):
    # LayerNorm 1
    ln1 = nn.LayerNorm(d_model).to(device)
    ln1.load_state_dict(loaded_ln1_dicts[i])
    layer_norms_1.append(ln1)
    # MHA QKV Linear (Check bias presence)
    qkv_dict = loaded_qkv_dicts[i]
    has_bias = 'bias' in qkv_dict
    qkv_linear = nn.Linear(d_model, 3 * d_model, bias=has_bias).to(device)
    qkv_linear.load_state_dict(qkv_dict)
    mha_qkv_linears.append(qkv_linear)
    # MHA Output Linear (Check bias presence)
    mha_out_dict = loaded_mha_out_dicts[i]
    has_bias = 'bias' in mha_out_dict
    output_linear_mha = nn.Linear(d_model, d_model, bias=has_bias).to(device)
    output_linear_mha.load_state_dict(mha_out_dict)
    mha_output_linears.append(output_linear_mha)
    # LayerNorm 2
    ln2 = nn.LayerNorm(d_model).to(device)
    ln2.load_state_dict(loaded_ln2_dicts[i])
    layer_norms_2.append(ln2)
    # FFN Linear 1 (Check bias presence)
    ffn1_dict = loaded_ffn1_dicts[i]
    has_bias = 'bias' in ffn1_dict
    lin1 = nn.Linear(d_model, d_ff, bias=has_bias).to(device)
    lin1.load_state_dict(ffn1_dict)
    ffn_linear_1.append(lin1)
    # FFN Linear 2 (Check bias presence)
    ffn2_dict = loaded_ffn2_dicts[i]
    has_bias = 'bias' in ffn2_dict
    lin2 = nn.Linear(d_ff, d_model, bias=has_bias).to(device)
    lin2.load_state_dict(ffn2_dict)
    ffn_linear_2.append(lin2)

print(f" Loaded components for {n_layers} Transformer Layers.")

# --- Final LayerNorm ---
final_layer_norm = nn.LayerNorm(d_model).to(device)
final_layer_norm.load_state_dict(loaded_final_ln_dict)
print(" Loaded Final LayerNorm.")print("Finished initializing and loading weights for text transformer components.")
Step 1.2: Initialize the New Output Projection Layer
Theory: Create a new linear layer that maps the Transformer's final hidden state (d_model) to the dimensionality of the image feature vector (vision_feature_dim). This layer's weights are randomly initialized and will be trained.
print("\nStep 1.2: Initializing new output projection layer (Text -> Image Feature)...")text_to_image_feature_layer = nn.Linear(d_model, vision_feature_dim).to(device)print(f" Initialized Text-to-Image-Feature Output Layer: {d_model} -> {vision_feature_dim}. Device: {device}")
Step 2: Data Preparation for Text-to-Image Training
Goal: Tokenize and pad the text prompts, and pair each with its corresponding target image feature vector.
Step 2.1: Tokenize and Pad the Prompts
Theory: Convert each text prompt into a sequence of token IDs using char_to_int. Pad each sequence to block_size with pad_token_id, and create the corresponding attention mask (1 for real tokens, 0 for padding).
print("\nStep 2.1: Tokenizing and padding text prompts...")prepared_prompts = []
target_features_ordered = [] # Store target features in the same orderfor sample in text_to_image_data:prompt = sample["prompt"]image_path = sample["image_path"]# --- Tokenize Prompt --- prompt_ids_no_pad = [char_to_int[ch] for ch in prompt]# --- Padding --- current_len = len(prompt_ids_no_pad)pad_len = block_size - current_lenif pad_len < 0:print(f"Warning: Prompt length ({current_len}) exceeds block_size ({block_size}). Truncating prompt.")prompt_ids = prompt_ids_no_pad[:block_size]pad_len = 0current_len = block_sizeelse:prompt_ids = prompt_ids_no_pad + ([pad_token_id] * pad_len)# --- Create Attention Mask --- attention_mask = ([1] * current_len) + ([0] * pad_len)# --- Store Prompt Data --- prepared_prompts.append({"input_ids": torch.tensor(prompt_ids, dtype=torch.long),"attention_mask": torch.tensor(attention_mask, dtype=torch.long)})# --- Store Corresponding Target Feature --- if image_path in target_image_features:target_features_ordered.append(target_image_features[image_path])else:print(f"Error: Target feature not found for {image_path}. Data mismatch?")# Handle error - maybe skip this sample or raise exceptiontarget_features_ordered.append(torch.zeros(vision_feature_dim, device=device)) # Placeholder# --- Stack into Tensors ---
all_prompt_input_ids = torch.stack([p['input_ids'] for p in prepared_prompts])
all_prompt_attention_masks = torch.stack([p['attention_mask'] for p in prepared_prompts])
all_target_features = torch.stack(target_features_ordered)

num_sequences_available = all_prompt_input_ids.shape[0]
print(f"Created {num_sequences_available} padded prompt sequences and gathered target features.")
print(f" Prompt Input IDs shape: {all_prompt_input_ids.shape}") # (num_samples, block_size)
print(f" Prompt Attention Mask shape: {all_prompt_attention_masks.shape}") # (num_samples, block_size)
print(f" Target Features shape: {all_target_features.shape}") # (num_samples, vision_feature_dim)
Step 2.2: Batching Strategy (Random Sampling)
Theory: Set up random batch sampling for training. We will pick random indices and fetch the corresponding prompt IDs, masks, and target image features.
print("\nStep 2.2: Preparing for batching text-to-image data...")# Check batch size feasibility
if num_sequences_available < batch_size:print(f"Warning: Number of sequences ({num_sequences_available}) is less than batch size ({batch_size}). Adjusting batch size.")batch_size = num_sequences_availableprint(f"Data ready for training. Will sample batches of size {batch_size} randomly.")
# In the training loop, we will use random indices to get:
# xb_prompt_ids = all_prompt_input_ids[indices]
# batch_prompt_masks = all_prompt_attention_masks[indices]
# yb_target_features = all_target_features[indices]
Step 3: Text-to-Image Training Loop (Inline)
Goal: Train the model to map text prompts to image feature vectors.
Step 3.1: Define the Optimizer and Loss Function
Theory: Gather the trainable parameters (the Transformer components plus the new output layer). Define the optimizer (AdamW). Define the loss function: Mean Squared Error (MSE) is well suited to comparing predicted and target feature vectors.
print("\nStep 3.1: Defining Optimizer and Loss Function for text-to-image...")# --- Gather Trainable Parameters ---
# Includes Transformer components and the new output layer
all_trainable_parameters_t2i = list(token_embedding_table.parameters())
for i in range(n_layers):
    all_trainable_parameters_t2i.extend(list(layer_norms_1[i].parameters()))
    all_trainable_parameters_t2i.extend(list(mha_qkv_linears[i].parameters()))
    all_trainable_parameters_t2i.extend(list(mha_output_linears[i].parameters()))
    all_trainable_parameters_t2i.extend(list(layer_norms_2[i].parameters()))
    all_trainable_parameters_t2i.extend(list(ffn_linear_1[i].parameters()))
    all_trainable_parameters_t2i.extend(list(ffn_linear_2[i].parameters()))
all_trainable_parameters_t2i.extend(list(final_layer_norm.parameters()))
all_trainable_parameters_t2i.extend(list(text_to_image_feature_layer.parameters())) # Add the new layer

# --- Define Optimizer ---
optimizer = optim.AdamW(all_trainable_parameters_t2i, lr=learning_rate)
print(f" Optimizer defined: AdamW with lr={learning_rate}")
print(f" Managing {len(all_trainable_parameters_t2i)} parameter groups/tensors.")# --- Define Loss Function ---
# Theory: MSE loss compares the predicted feature vector to the target feature vector.
criterion = nn.MSELoss()
print(f" Loss function defined: {type(criterion).__name__}")
Step 3.2: The Training Loop
Theory: Iterate for epochs steps. At each step:
- Select a batch (prompt IDs, masks, target features).
- Run the forward pass: embed the prompt, add positional encodings, pass it through the Transformer blocks, apply the final layer norm, and project to the image feature dimension with text_to_image_feature_layer.
- Compute the MSE loss between the predicted and target feature vectors.
- Backpropagate and update the weights.
print("\nStep 3.2: 開始文本到圖像訓練循環...")t2i_losses = []# --- Set Trainable Layers to Training Mode ---
token_embedding_table.train()
for i in range(n_layers):
    layer_norms_1[i].train()
    mha_qkv_linears[i].train()
    mha_output_linears[i].train()
    layer_norms_2[i].train()
    ffn_linear_1[i].train()
    ffn_linear_2[i].train()
final_layer_norm.train()
text_to_image_feature_layer.train() # New layer also needs training mode
# vision_model remains in eval() mode

# --- Training Loop ---
for epoch in range(epochs):
    # --- 1. Batch Selection ---
    indices = torch.randint(0, num_sequences_available, (batch_size,))
    xb_prompt_ids = all_prompt_input_ids[indices].to(device) # (B, T)
    batch_prompt_masks = all_prompt_attention_masks[indices].to(device) # (B, T)
    yb_target_features = all_target_features[indices].to(device) # (B, vision_feature_dim)

    # --- 2. Forward Pass ---
    B, T = xb_prompt_ids.shape
    C = d_model

    # --- Embeddings + Positional Encoding ---
    token_embed = token_embedding_table(xb_prompt_ids) # (B, T, C)
    pos_enc_slice = positional_encoding[:, :T, :] # (1, T, C)
    x = token_embed + pos_enc_slice # (B, T, C)

    # --- Transformer Blocks ---
    # Create the attention mask (causal mask combined with the prompt's padding mask)
    padding_mask_expanded = batch_prompt_masks.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
    combined_attn_mask = causal_mask[:,:,:T,:T] * padding_mask_expanded # (B, 1, T, T)

    for i in range(n_layers):
        x_input_block = x
        # Pre-LN MHA
        x_ln1 = layer_norms_1[i](x_input_block)
        qkv = mha_qkv_linears[i](x_ln1)
        qkv = qkv.view(B, T, n_heads, 3 * d_k).permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        attn_scores = (q @ k.transpose(-2, -1)) * (d_k ** -0.5)
        attn_scores_masked = attn_scores.masked_fill(combined_attn_mask == 0, float('-inf'))
        attention_weights = F.softmax(attn_scores_masked, dim=-1)
        attention_weights = torch.nan_to_num(attention_weights)
        attn_output = attention_weights @ v
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        mha_result = mha_output_linears[i](attn_output)
        x = x_input_block + mha_result # Residual 1
        # Pre-LN FFN
        x_input_ffn = x
        x_ln2 = layer_norms_2[i](x_input_ffn)
        ffn_hidden = ffn_linear_1[i](x_ln2)
        ffn_activated = F.relu(ffn_hidden)
        ffn_output = ffn_linear_2[i](ffn_activated)
        x = x_input_ffn + ffn_output # Residual 2

    # --- Final LayerNorm ---
    final_norm_output = final_layer_norm(x) # (B, T, C)

    # --- Select the Hidden State Used for Prediction ---
    # Theory: we need one vector per sequence to predict the image feature vector.
    # We take the hidden state of the *last non-padding* token.
    # Find the index of the last non-padding token for each sequence in the batch.
    # batch_prompt_masks has shape (B, T) with 1 for non-padding and 0 for padding.
    # `torch.sum(mask, 1) - 1` gives the index of the last '1'.
    last_token_indices = torch.sum(batch_prompt_masks, 1) - 1 # Shape: (B,)
    # Ensure the indices stay in bounds (handles the unlikely all-padding case)
    last_token_indices = torch.clamp(last_token_indices, min=0)
    # Gather the hidden states corresponding to these last tokens.
    # We need to index final_norm_output[batch_index, token_index, :]
    batch_indices = torch.arange(B, device=device)
    last_token_hidden_states = final_norm_output[batch_indices, last_token_indices, :] # (B, C)

    # --- Project to the Image Feature Dimension ---
    # Theory: use the new output layer to predict the image feature vector.
    predicted_image_features = text_to_image_feature_layer(last_token_hidden_states) # (B, vision_feature_dim)

    # --- 3. Compute Loss ---
    # Theory: compute the MSE between the predicted and target features.
    loss = criterion(predicted_image_features, yb_target_features)

    # --- 4. Zero Gradients ---
    optimizer.zero_grad()

    # --- 5. Backward Pass ---
    if not torch.isnan(loss) and not torch.isinf(loss):
        loss.backward()
        # Optional: Gradient Clipping
        # torch.nn.utils.clip_grad_norm_(all_trainable_parameters_t2i, max_norm=1.0)
        # --- 6. Update Parameters ---
        optimizer.step()
    else:
        print(f"Warning: Invalid loss detected (NaN or Inf) at epoch {epoch+1}. Skipping optimizer step.")
        loss = None

    # --- Logging ---
    if loss is not None:
        current_loss = loss.item()
        t2i_losses.append(current_loss)
        if epoch % eval_interval == 0 or epoch == epochs - 1:
            print(f"  Epoch {epoch+1}/{epochs}, MSE Loss: {current_loss:.6f}")
    elif epoch % eval_interval == 0 or epoch == epochs - 1:
        print(f"  Epoch {epoch+1}/{epochs}, Loss: Invalid (NaN/Inf)")

print("--- Text-to-Image Training Loop Completed ---\n")

# Optional: Plot losses
import matplotlib.pyplot as plt
plt.figure(figsize=(20, 3))
plt.plot(t2i_losses)
plt.title("Text-to-Image Training Loss (MSE)")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.grid(True)
plt.show()
Step 4: Text-to-Image Generation (Inline Implementation)
Goal: Use the trained model to generate an image feature vector from a text prompt and find the closest known image.
Step 4.1: Prepare the Input Prompt
Theory: Define a new text prompt, tokenize it, and pad it to block_size.
print("\nStep 4.1: Preparing input prompt for generation...")# --- Input Prompt ---
generation_prompt_text = "a blue square shape"
print(f"Input Prompt: '{generation_prompt_text}'")# --- Tokenize and Pad ---
gen_prompt_ids_no_pad = [char_to_int.get(ch, pad_token_id) for ch in generation_prompt_text] # Use get with default for safety
gen_current_len = len(gen_prompt_ids_no_pad)
gen_pad_len = block_size - gen_current_len

if gen_pad_len < 0:
    print(f"Warning: Generation prompt length ({gen_current_len}) exceeds block_size ({block_size}). Truncating.")
    gen_prompt_ids = gen_prompt_ids_no_pad[:block_size]
    gen_pad_len = 0
    gen_current_len = block_size
else:
    gen_prompt_ids = gen_prompt_ids_no_pad + ([pad_token_id] * gen_pad_len)

# --- Create Attention Mask ---
gen_attention_mask = ([1] * gen_current_len) + ([0] * gen_pad_len)

# --- Convert to Tensor ---
xb_gen_prompt_ids = torch.tensor([gen_prompt_ids], dtype=torch.long, device=device) # Add batch dim B=1
batch_gen_prompt_masks = torch.tensor([gen_attention_mask], dtype=torch.long, device=device) # Add batch dim B=1

print(f"Prepared prompt tensor shape: {xb_gen_prompt_ids.shape}")
print(f"Prepared mask tensor shape: {batch_gen_prompt_masks.shape}")
Step 4.2: Generate the Image Feature Vector
Theory: Run a forward pass over the input prompt with the trained model (in evaluation mode) to obtain the predicted image feature vector.
print("\nStep 4.2: Generating image feature vector...")# --- Set Model to Evaluation Mode ---
token_embedding_table.eval()
for i in range(n_layers):
    layer_norms_1[i].eval()
    mha_qkv_linears[i].eval()
    mha_output_linears[i].eval()
    layer_norms_2[i].eval()
    ffn_linear_1[i].eval()
    ffn_linear_2[i].eval()
final_layer_norm.eval()
text_to_image_feature_layer.eval()

# --- Forward Pass ---
with torch.no_grad():
    B_gen, T_gen = xb_gen_prompt_ids.shape
    C_gen = d_model

    # Embeddings + PE
    token_embed_gen = token_embedding_table(xb_gen_prompt_ids)
    pos_enc_slice_gen = positional_encoding[:, :T_gen, :]
    x_gen = token_embed_gen + pos_enc_slice_gen

    # Transformer Blocks
    padding_mask_expanded_gen = batch_gen_prompt_masks.unsqueeze(1).unsqueeze(2)
    combined_attn_mask_gen = causal_mask[:,:,:T_gen,:T_gen] * padding_mask_expanded_gen

    for i in range(n_layers):
        x_input_block_gen = x_gen
        # Pre-LN MHA
        x_ln1_gen = layer_norms_1[i](x_input_block_gen)
        qkv_gen = mha_qkv_linears[i](x_ln1_gen)
        qkv_gen = qkv_gen.view(B_gen, T_gen, n_heads, 3 * d_k).permute(0, 2, 1, 3)
        q_gen, k_gen, v_gen = qkv_gen.chunk(3, dim=-1)
        attn_scores_gen = (q_gen @ k_gen.transpose(-2, -1)) * (d_k ** -0.5)
        attn_scores_masked_gen = attn_scores_gen.masked_fill(combined_attn_mask_gen == 0, float('-inf'))
        attention_weights_gen = F.softmax(attn_scores_masked_gen, dim=-1)
        attention_weights_gen = torch.nan_to_num(attention_weights_gen)
        attn_output_gen = attention_weights_gen @ v_gen
        attn_output_gen = attn_output_gen.permute(0, 2, 1, 3).contiguous().view(B_gen, T_gen, C_gen)
        mha_result_gen = mha_output_linears[i](attn_output_gen)
        x_gen = x_input_block_gen + mha_result_gen # Residual 1
        # Pre-LN FFN
        x_input_ffn_gen = x_gen
        x_ln2_gen = layer_norms_2[i](x_input_ffn_gen)
        ffn_hidden_gen = ffn_linear_1[i](x_ln2_gen)
        ffn_activated_gen = F.relu(ffn_hidden_gen)
        ffn_output_gen = ffn_linear_2[i](ffn_activated_gen)
        x_gen = x_input_ffn_gen + ffn_output_gen # Residual 2

    # Final LayerNorm
    final_norm_output_gen = final_layer_norm(x_gen)

    # Select Hidden State (use last non-padding token's state)
    last_token_indices_gen = torch.sum(batch_gen_prompt_masks, 1) - 1
    last_token_indices_gen = torch.clamp(last_token_indices_gen, min=0)
    batch_indices_gen = torch.arange(B_gen, device=device)
    last_token_hidden_states_gen = final_norm_output_gen[batch_indices_gen, last_token_indices_gen, :]

    # Project to Image Feature Dimension
    predicted_feature_vector = text_to_image_feature_layer(last_token_hidden_states_gen)

print(f"Generated predicted feature vector with shape: {predicted_feature_vector.shape}") # Should be (1, vision_feature_dim)
Step 4.3: Find the Closest Known Image (Simplified Reconstruction)
Theory: Compare the predicted feature vector with the precomputed feature vectors of our known training images (known_features_list). Find the known image whose feature vector is at the smallest distance (e.g., Euclidean or cosine distance) from the prediction, and display that image.
print("\n步驟4.3:尋找最接近的已知圖像...")# --- 計算距離 ---
# 理論:計算預測向量與每個已知向量之間的距離。
# 我們使用余弦距離:1 - 余弦相似度。距離越小越好。
predicted_vec = predicted_feature_vector.squeeze(0).cpu().numpy() # 移至CPU并轉換為numpy供scipy使用min_distance = float('inf')
closest_image_path = Nonefor known_path, known_vec_tensor in known_features_list:known_vec = known_vec_tensor.cpu().numpy()# 計算余弦距離# dist = scipy_distance.cosine(predicted_vec, known_vec)# 或者計算歐幾里得距離(L2范數)dist = scipy_distance.euclidean(predicted_vec, known_vec)# print(f" 到 {os.path.basename(known_path)} 的距離: {dist:.4f}") # 可選:打印距離if dist < min_distance:min_distance = distclosest_image_path = known_path# --- 顯示結果 ---
if closest_image_path:print(f"Closest match found: '{os.path.basename(closest_image_path)}' with distance {min_distance:.4f}")# 使用PIL或matplotlib顯示圖像try:matched_img = Image.open(closest_image_path)print("Displaying the closest matching image:")# 在notebook環境中,簡單地顯示對象通常就可以工作:# matched_img # 或者使用matplotlib:import matplotlib.pyplot as pltplt.imshow(matched_img)plt.title(f"Generated for: '{generation_prompt_text}'\nClosest Match: {os.path.basename(closest_image_path)}")plt.axis('off')plt.show()except FileNotFoundError:print(f"Error: Could not load the matched image file at {closest_image_path}")except Exception as e:print(f"Error displaying image: {e}")
else:print("Could not determine the closest image.")
Step 5: Saving the Model State (Optional)
To save our text-to-image model, build a dictionary containing all the model components and configuration, then call torch.save(). Here is how:
import os

# Create the directory if it doesn't exist
save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'text_to_image_model.pt')

# Build a dictionary with all relevant components and configuration
text_to_image_state_dict = {
    # Configuration
    'config': {
        'vocab_size': vocab_size,
        'd_model': d_model,
        'n_heads': n_heads,
        'n_layers': n_layers,
        'd_ff': d_ff,
        'block_size': block_size,
        'vision_feature_dim': vision_feature_dim
    },
    # Tokenizer
    'tokenizer': {
        'char_to_int': char_to_int,
        'int_to_char': int_to_char
    },
    # Model weights (the trainable parts)
    'token_embedding_table': token_embedding_table.state_dict(),
    'positional_encoding': positional_encoding, # Not trained, but needed for reconstruction
    'layer_norms_1': [ln.state_dict() for ln in layer_norms_1],
    'mha_qkv_linears': [l.state_dict() for l in mha_qkv_linears],
    'mha_output_linears': [l.state_dict() for l in mha_output_linears],
    'layer_norms_2': [ln.state_dict() for ln in layer_norms_2],
    'ffn_linear_1': [l.state_dict() for l in ffn_linear_1],
    'ffn_linear_2': [l.state_dict() for l in ffn_linear_2],
    'final_layer_norm': final_layer_norm.state_dict(),
    'text_to_image_feature_layer': text_to_image_feature_layer.state_dict() # The new output layer
    # Note: we do not save the frozen vision_model weights here,
    # since we assume it will be loaded separately from torchvision when needed.
}

# Save to file
torch.save(text_to_image_state_dict, save_path)
print(f"文本到圖像模型已保存到 {save_path}")
Loading the Saved Text-to-Image Model
To load the model, reverse the saving process:
# Load the saved text-to-image model

# Load the saved state dictionary
load_path = 'saved_models/text_to_image_model.pt'
if os.path.exists(load_path):
    loaded_t2i_state = torch.load(load_path, map_location=device)
    print(f"Loaded state dictionary from '{load_path}'.")

    # Extract config and tokenizer
    config = loaded_t2i_state['config']
    vocab_size = config['vocab_size']
    d_model = config['d_model']
    n_heads = config['n_heads']
    n_layers = config['n_layers']
    d_ff = config['d_ff']
    block_size = config['block_size']
    vision_feature_dim = config['vision_feature_dim']
    d_k = d_model // n_heads

    char_to_int = loaded_t2i_state['tokenizer']['char_to_int']
    int_to_char = loaded_t2i_state['tokenizer']['int_to_char']

    # Recreate the causal mask
    causal_mask = torch.tril(torch.ones(block_size, block_size, device=device)).view(1, 1, block_size, block_size)

    # Rebuild and load the model components
    token_embedding_table = nn.Embedding(vocab_size, d_model).to(device)
    token_embedding_table.load_state_dict(loaded_t2i_state['token_embedding_table'])
    positional_encoding = loaded_t2i_state['positional_encoding'].to(device)

    layer_norms_1 = []
    mha_qkv_linears = []
    mha_output_linears = []
    layer_norms_2 = []
    ffn_linear_1 = []
    ffn_linear_2 = []

    for i in range(n_layers):
        # Load Layer Norm 1
        ln1 = nn.LayerNorm(d_model).to(device)
        ln1.load_state_dict(loaded_t2i_state['layer_norms_1'][i])
        layer_norms_1.append(ln1)
        # Load MHA QKV Linear
        qkv_dict = loaded_t2i_state['mha_qkv_linears'][i]
        has_bias = 'bias' in qkv_dict
        qkv = nn.Linear(d_model, 3 * d_model, bias=has_bias).to(device)
        qkv.load_state_dict(qkv_dict)
        mha_qkv_linears.append(qkv)
        # Load MHA Output Linear
        mha_out_dict = loaded_t2i_state['mha_output_linears'][i]
        has_bias = 'bias' in mha_out_dict
        mha_out = nn.Linear(d_model, d_model, bias=has_bias).to(device)
        mha_out.load_state_dict(mha_out_dict)
        mha_output_linears.append(mha_out)
        # Load Layer Norm 2
        ln2 = nn.LayerNorm(d_model).to(device)
        ln2.load_state_dict(loaded_t2i_state['layer_norms_2'][i])
        layer_norms_2.append(ln2)
        # Load FFN Linear 1
        ffn1_dict = loaded_t2i_state['ffn_linear_1'][i]
        has_bias = 'bias' in ffn1_dict
        ff1 = nn.Linear(d_model, d_ff, bias=has_bias).to(device)
        ff1.load_state_dict(ffn1_dict)
        ffn_linear_1.append(ff1)
        # Load FFN Linear 2
        ffn2_dict = loaded_t2i_state['ffn_linear_2'][i]
        has_bias = 'bias' in ffn2_dict
        ff2 = nn.Linear(d_ff, d_model, bias=has_bias).to(device)
        ff2.load_state_dict(ffn2_dict)
        ffn_linear_2.append(ff2)

    # Load Final LayerNorm
    final_layer_norm = nn.LayerNorm(d_model).to(device)
    final_layer_norm.load_state_dict(loaded_t2i_state['final_layer_norm'])

    # Load Text-to-Image Feature Layer
    t2i_out_dict = loaded_t2i_state['text_to_image_feature_layer']
    has_bias = 'bias' in t2i_out_dict
    text_to_image_feature_layer = nn.Linear(d_model, vision_feature_dim, bias=has_bias).to(device)
    text_to_image_feature_layer.load_state_dict(t2i_out_dict)

    print("Text-to-image model components loaded successfully.")
else:
    print(f"Model file not found at {load_path}. Cannot load the model.")
Step 6: Summary
The key steps were:
- Loading: Reuse the text Transformer components and the ResNet feature extractor from the previous multi-modal model.
- Data: Define text prompts paired with target images, and pre-extract the target image feature vectors with the frozen ResNet.
- Architecture adaptation: Replace the Transformer's final output layer with a new linear layer that projects the hidden state to the image feature dimension.
- Training: Train the Transformer components and the new output layer to minimize the Mean Squared Error (MSE) between the predicted and target image feature vectors for a given text prompt.
- Generation and simplified reconstruction: Use the trained model to predict a feature vector from a new text prompt, then find the training image whose actual feature vector is closest to the prediction (using Euclidean distance) and display it.
This approach demonstrates the concept of cross-modal generation (text to a visual representation) with a Transformer while keeping to the inline implementation constraint. It avoids the substantial complexity of direct pixel generation (GANs, diffusion models), yet still provides a tangible, if basic, visual output via a nearest-neighbor search in the feature space of known images. Real-world text-to-image models are far more sophisticated, typically combining different architectures and training on massive datasets.