1. Dataset for Stable Diffusion
Sources for these notes:
1. Flickr8k數據集處理 (Flickr8k dataset processing)
2. 處理Flickr8k數據集 (Processing the Flickr8k dataset)
3. GitHub: pytorch-stable-diffusion
4. Flickr 8k Dataset
5. dataset_flickr8k.json
1.1 Dataset
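We use the Flickr8k dataset. It consists of two parts: Flicker8k_Dataset, a folder containing all the images, and Flickr8k.token.txt, a text file with two columns (image_id and caption). Each image_id corresponds to 5 captions (sentences).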
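Reading the token file amounts to grouping captions by image. The sketch below assumes each line of Flickr8k.token.txt has the form `<filename>#<n>\t<caption>`, where the `#n` suffix numbers the image's captions. The function name `load_token_file` and the sample lines are made up for illustration.

```python
from collections import defaultdict


def load_token_file(lines):
    """Group token-file lines into {image filename: [captions]}.

    Assumes each line looks like '<filename>#<n>\t<caption>'.
    """
    captions = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        image_id, caption = line.split("\t", 1)
        filename = image_id.split("#", 1)[0]  # drop the '#n' caption number
        captions[filename].append(caption)
    return dict(captions)


# Made-up sample lines in the assumed format
sample = [
    "img_a.jpg#0\tA dog runs on the grass .",
    "img_a.jpg#1\tA brown dog is running .",
]
print(load_token_file(sample))
```

With the real file, pass the open file object instead, e.g. `with open("Flickr8k.token.txt", encoding="utf-8") as f: caps = load_token_file(f)`.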
1.2 Dataset description file
Dataset description file: dataset_flickr8k.json
Its format is as follows:
```json
{
  "images": [
    {
      "sentids": [0, 1, 2, 3, 4],
      "imgid": 0,
      "sentences": [
        {"tokens": [...], "raw": "…", "imgid": 0, "sentid": 0},
        ...
      ],
      "split": "train",
      "filename": "…jpg"
    },
    {"sentids": ...},
    ...
  ],
  "dataset": "flickr8k"
}
```
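Field | Description |
---|---|
"sentids": [0, 1, 2, 3, 4] | IDs of the image's captions (each image has 5 captions; for image 0 they are 0 to 4) |
"imgid": 0 | Image ID (0 to 7999, 8000 images in total) |
"sentences": [ ] | The 5 captions of the image |
"tokens": [ ] | The caption split into individual words |
"raw": " " | The full caption as a single string |
"imgid": 0 (inside a sentence) | ID of the image this caption belongs to |
"sentid": 0 | ID of this particular caption |
"split": " " | Assigns the image and its captions to the training, validation or test set ("train", "val" or "test") |
"filename": "…jpg" | File name of the image |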
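To see how these fields nest, the sketch below walks a small hand-made dict with the same layout and counts images and captions per split. It also checks two relationships implied by the table: `sentids` lists the `sentid` of each of the image's sentences, and every sentence's `imgid` points back to its image. The helper name `summarize_splits` and the toy values are made up (real images have 5 captions; the toy uses fewer for brevity).

```python
from collections import Counter


def summarize_splits(data):
    """Return {split: (num_images, num_captions)} for a dict laid out like dataset_flickr8k.json."""
    n_images = Counter()
    n_captions = Counter()
    for img in data["images"]:
        sents = img["sentences"]
        # "sentids" should list exactly the "sentid" values of the image's sentences
        assert sorted(img["sentids"]) == sorted(s["sentid"] for s in sents)
        # each sentence's "imgid" should point back to its image
        assert all(s["imgid"] == img["imgid"] for s in sents)
        n_images[img["split"]] += 1
        n_captions[img["split"]] += len(sents)
    return {split: (n_images[split], n_captions[split]) for split in n_images}


# Made-up toy data with the same structure as dataset_flickr8k.json
toy = {
    "dataset": "flickr8k",
    "images": [
        {"imgid": 0, "sentids": [0, 1], "split": "train", "filename": "a.jpg",
         "sentences": [
             {"tokens": ["a", "dog", "runs"], "raw": "A dog runs.", "imgid": 0, "sentid": 0},
             {"tokens": ["a", "brown", "dog"], "raw": "A brown dog.", "imgid": 0, "sentid": 1}]},
        {"imgid": 1, "sentids": [2], "split": "test", "filename": "b.jpg",
         "sentences": [
             {"tokens": ["two", "kids", "play"], "raw": "Two kids play.", "imgid": 1, "sentid": 2}]},
    ],
}
print(summarize_splits(toy))  # {'train': (1, 2), 'test': (1, 1)}
```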
dataset_flickr8k.json
1.3 Process Datasets
The code below is taken from Flickr8k數據集處理 (for study purposes only).
```python
import json
import os
import random
from collections import Counter, defaultdict
from matplotlib import pyplot as plt
from PIL import Image
from argparse import Namespace
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms


def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):
    """
    Parameters:
        dataset: Name of the dataset
        captions_per_image: Number of captions per image
        min_word_count: Only keep words that appear more than this many times
            in the dataset (test set excluded)
        max_len: Maximum number of words in a caption. Captions longer than
            this are discarded.
    Output:
        A vocabulary file: vocab.json
        Three dataset files: train_data.json, val_data.json, test_data.json
    """
    # Paths for reading data and saving processed data
    # Path to the dataset JSON file
    flickr_json_path = ".../sd/data/dataset_flickr8k.json"
    # Folder containing the images
    image_folder = ".../sd/data/Flicker8k_Dataset"
    # Folder for the processed results. The % operator replaces %s with the
    # value of `dataset`, so dataset="flickr8k" gives .../sd/data/flickr8k
    output_folder = ".../sd/data/%s" % dataset
    # Ensure the output directory exists
    os.makedirs(output_folder, exist_ok=True)
    print(f"Output folder: {output_folder}")

    # Read the dataset JSON file
    with open(file=flickr_json_path, mode="r") as j:
        data = json.load(fp=j)

    # Containers for image paths, captions, and word counts
    image_paths = defaultdict(list)     # split -> list of image paths
    image_captions = defaultdict(list)  # split -> list of caption lists
    vocab = Counter()                   # word -> number of occurrences

    # Read from dataset_flickr8k.json
    for img in data["images"]:  # Iterate over each image in the dataset
        split = img["split"]    # Split of the image: train, val, or test
        captions = []
        for c in img["sentences"]:  # Iterate over each caption of the image
            # Update word counts for the train/val splits only
            if split != "test":
                # c['tokens'] is a list; each word's count is increased by one
                vocab.update(c['tokens'])
            # Keep only captions within the maximum length
            if len(c["tokens"]) <= max_len:
                captions.append(c["tokens"])
        if len(captions) == 0:  # Skip images with no valid captions
            continue
        # Full image path: image_folder + image file name
        path = os.path.join(image_folder, img['filename'])
        # Save the image path and its captions under the image's split
        image_paths[split].append(path)
        image_captions[split].append(captions)

    '''
    After the above steps, we have:
    - vocab (a Counter): keys: words; values: counts of each word
    - image_paths (a dict): keys "train", "val", "test"; values: lists of absolute image paths
    - image_captions (a dict): keys "train", "val", "test"; values: lists of captions
    '''
    # ........ (continued below)
```
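Using dataset_flickr8k.json, we turn the dataset into three dictionaries:
dict | key | value |
---|---|---|
vocab | word | number of occurrences of the word in the train and val captions (test captions are not counted) |
image_paths | "train", "val", "test" | lists of absolute image paths |
image_captions | "train", "val", "test" | lists of captions (for each image, a list of tokenized captions) |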
While debugging, we print part of their contents (inside create_dataset, after the loop above):
```python
print(vocab)
print(image_paths["train"][1])
print(image_captions["train"][1])
```
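The first stage can also be exercised without the real files. The sketch below is a stripped-down copy of the loop in create_dataset, run on a made-up two-image dict with `max_len=3` so that the filtering is visible: test captions do not enter the word counts, and over-long captions are dropped (although their words are still counted, because the count is updated before the length check). The name `build_dicts` is hypothetical.

```python
import os
from collections import Counter, defaultdict


def build_dicts(data, image_folder, max_len=30):
    """First stage of create_dataset: word counts, image paths and captions per split."""
    image_paths = defaultdict(list)
    image_captions = defaultdict(list)
    vocab = Counter()
    for img in data["images"]:
        split = img["split"]
        captions = []
        for c in img["sentences"]:
            if split != "test":              # test captions are not counted
                vocab.update(c["tokens"])
            if len(c["tokens"]) <= max_len:  # longer captions are dropped
                captions.append(c["tokens"])
        if not captions:                     # image without any valid caption
            continue
        image_paths[split].append(os.path.join(image_folder, img["filename"]))
        image_captions[split].append(captions)
    return vocab, image_paths, image_captions


# Made-up toy data; only the fields used above are included
toy = {"images": [
    {"split": "train", "filename": "a.jpg", "sentences": [
        {"tokens": ["a", "dog", "runs"]},
        {"tokens": ["a", "dog", "runs", "fast"]}]},  # 4 tokens > max_len=3
    {"split": "test", "filename": "b.jpg", "sentences": [
        {"tokens": ["a", "cat"]}]},
]}
vocab, paths, caps = build_dicts(toy, "imgs", max_len=3)
print(vocab)  # Counter({'a': 2, 'dog': 2, 'runs': 2, 'fast': 1})
print(caps["train"])  # [[['a', 'dog', 'runs']]]
```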
```python
def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):
    # ........ (docstring and the first part shown above)

    # Create the vocabulary and add the special tokens:
    # <pad> (padding), <unk> (unknown word), <start>/<end> (sentence boundaries)
    words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Keep words seen more than min_word_count times
    vocab = {k: v + 1 for v, k in enumerate(words)}  # Indices start at 1; 0 is reserved for <pad>
    # Add special tokens to the vocabulary
    vocab['<pad>'] = 0
    vocab['<unk>'] = len(vocab)
    vocab['<start>'] = len(vocab)
    vocab['<end>'] = len(vocab)

    # Save the vocabulary to a file
    with open(os.path.join(output_folder, 'vocab.json'), "w") as fw:
        json.dump(vocab, fw)

    # Process each split: "train", "val" and "test"
    for split in image_paths:
        imgpaths = image_paths[split]   # list of image paths for this split
        imcaps = image_captions[split]  # list of caption lists for this split
        # Captions converted to lists of vocabulary indices
        enc_captions = []

        for i, path in enumerate(imgpaths):
            # Check that the image can be opened (raises if it is missing or not an image)
            Image.open(path).close()
            # Ensure each image has exactly captions_per_image captions
            if len(imcaps[i]) < captions_per_image:
                filled_num = captions_per_image - len(imcaps[i])
                # Too few: repeat randomly chosen captions
                captions = imcaps[i] + [random.choice(imcaps[i]) for _ in range(filled_num)]
            else:
                # Enough or too many: randomly sample captions_per_image of them
                captions = random.sample(imcaps[i], k=captions_per_image)
            assert len(captions) == captions_per_image

            for c in captions:
                # Encode each caption: <start> + word indices (<unk> for unknown words) + <end>
                enc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab['<end>']]
                enc_captions.append(enc_c)

        assert len(imgpaths) * captions_per_image == len(enc_captions)

        data = {"IMAGES": imgpaths, "CAPTIONS": enc_captions}
        # Save the processed dataset for the current split (train, val, test)
        with open(os.path.join(output_folder, split + "_data.json"), 'w') as fw:
            json.dump(data, fw)


create_dataset()
```
After running create_dataset, we get the files shown below.
The contents of the four files are shown in the table below.
[Figure: contents of vocab.json, train_data.json, val_data.json and test_data.json]
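Key code for generating vocab.json
First, count how often each word appears in the captions (train and val splits only) and keep the words that appear more than 5 times (min_word_count); then assign the kept words consecutive indices starting from 1.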
```python
# Create the vocabulary and add the special tokens:
# <pad> (padding), <unk> (unknown word), <start>/<end> (sentence boundaries)
# Keep only words whose frequency is higher than min_word_count
# (counted on the train and val splits; the test set is excluded)
words = [w for w in vocab.keys() if vocab[w] > min_word_count]
# Assign each word an index starting from 1 (enumerate starts at 0, so add 1;
# index 0 is reserved for <pad>)
vocab = {k: v + 1 for v, k in enumerate(words)}
```
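To make the resulting index layout concrete, the sketch below applies the same rules as create_dataset to a small made-up Counter with `min_word_count=1`: kept words get 1 to n, `<pad>` is 0, and `<unk>`, `<start>`, `<end>` take the next three indices. The helper name `build_vocab` is hypothetical.

```python
from collections import Counter


def build_vocab(counts, min_word_count=5):
    """Word -> index mapping built with the same rules as create_dataset."""
    words = [w for w in counts.keys() if counts[w] > min_word_count]
    vocab = {k: v + 1 for v, k in enumerate(words)}  # regular words: 1..len(words)
    vocab['<pad>'] = 0
    vocab['<unk>'] = len(vocab)    # len(words) + 1
    vocab['<start>'] = len(vocab)  # len(words) + 2
    vocab['<end>'] = len(vocab)    # len(words) + 3
    return vocab


counts = Counter({"a": 5, "dog": 3, "runs": 1})  # made-up counts
print(build_vocab(counts, min_word_count=1))
# {'a': 1, 'dog': 2, '<pad>': 0, '<unk>': 3, '<start>': 4, '<end>': 5}
```

Note that "runs" (count 1) is dropped because the comparison is strictly greater than `min_word_count`.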
The resulting vocab.json:
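Key steps for generating [split]_data.json
Read dataset_flickr8k.json and build two dictionaries: the first holds the absolute path of every image, the second the captions describing each image. Each image is then given exactly captions_per_image captions (if it has fewer, randomly chosen captions are repeated; otherwise captions_per_image of them are sampled at random), and every token is replaced by its index in vocab. Finally, according to each image's split in dataset_flickr8k.json, the image paths and their encoded captions are saved to separate files (train_data.json, val_data.json, test_data.json).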
dataset_flickr8k.json
train_data.json
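The step that gives every image exactly captions_per_image captions can be isolated as below. It mirrors the branch in create_dataset: repeat random captions when there are too few, sample without replacement when there are enough. The name `fix_caption_count` is hypothetical, and passing a seeded `random.Random` instance is an addition for reproducibility that the original code does not use.

```python
import random


def fix_caption_count(caps, captions_per_image=5, rng=random):
    """Return exactly captions_per_image captions for one image."""
    if len(caps) < captions_per_image:
        # Too few: append randomly chosen existing captions (duplicates allowed)
        extra = [rng.choice(caps) for _ in range(captions_per_image - len(caps))]
        return caps + extra
    # Enough or too many: sample without replacement
    return rng.sample(caps, k=captions_per_image)


rng = random.Random(0)
few = [["a", "dog"], ["a", "brown", "dog"]]  # made-up captions
print(len(fix_caption_count(few, 5, rng)))  # 5
```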
Each caption is encoded by looking up every token's index in vocab and wrapping the result in <start> and <end>; words that are not in vocab become <unk>:
```python
for c in captions:
    # Encode each caption by converting words to their indices in the vocabulary
    enc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab['<end>']]
    enc_captions.append(enc_c)
```
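A small round trip, using a made-up vocabulary with the same special-token layout, shows how an out-of-vocabulary word becomes `<unk>` and how the indices map back to words. `encode` and `decode` are hypothetical helper names, and dropping the special tokens in `decode` is an illustrative choice, not something the original code does.

```python
def encode(tokens, vocab):
    """<start> + word indices (<unk> for unknown words) + <end>, as in create_dataset."""
    return [vocab['<start>']] + [vocab.get(w, vocab['<unk>']) for w in tokens] + [vocab['<end>']]


def decode(indices, vocab, skip_special=True):
    """Map indices back to words, optionally dropping <pad>/<start>/<end>."""
    idx2word = {i: w for w, i in vocab.items()}
    special = {'<pad>', '<start>', '<end>'}
    words = [idx2word[i] for i in indices]
    if skip_special:
        words = [w for w in words if w not in special]
    return ' '.join(words)


vocab = {'a': 1, 'dog': 2, '<pad>': 0, '<unk>': 3, '<start>': 4, '<end>': 5}  # made-up
enc = encode(['a', 'dog', 'barks'], vocab)
print(enc)                 # [4, 1, 2, 3, 5]
print(decode(enc, vocab))  # a dog <unk>
```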
Let's use the generated test_data.json and vocab.json to display one test image together with its captions.
The code below is taken from [Flickr8k數據集處理](https://blog.csdn.net/weixin_48981284/article/details/134676813) (for study purposes only).
```python
'''
Test:
1. Iterate over the 5 captions of the image at index 250.
2. Retrieve the word indices of each caption.
3. Convert the word indices to words using vocab_idx2word.
4. Join the words to form complete sentences.
5. Print each caption.
'''
import json
from PIL import Image
from matplotlib import pyplot as plt

# Load the vocabulary (word -> index) from the JSON file
with open('.../sd/data/flickr8k/vocab.json', 'r') as f:
    vocab = json.load(f)

# Build the inverse mapping (index -> word)
vocab_idx2word = {idx: word for word, idx in vocab.items()}

# Load the test data from the JSON file
with open('.../sd/data/flickr8k/test_data.json', 'r') as f:
    data = json.load(f)

# Open and display the image at index 250 of the 'IMAGES' list
content_img = Image.open(data['IMAGES'][250])
plt.figure(figsize=(6, 6))
plt.subplot(1, 1, 1)
plt.imshow(content_img)
plt.title('Image')
plt.axis('off')
plt.show()

# Number of keys in the dictionary (2: 'IMAGES' and 'CAPTIONS')
print(len(data))
print(len(data['IMAGES']))    # Number of images in the 'IMAGES' list
print(len(data['CAPTIONS']))  # Number of captions in the 'CAPTIONS' list

# The 5 captions of image k are stored at indices k*5 ... k*5+4,
# so iterate over the 5 captions of the image at index 250
for i in range(5):
    word_indices = data['CAPTIONS'][250 * 5 + i]
    # Convert indices to words and join them with spaces to form a caption
    print(' '.join([vocab_idx2word[idx] for idx in word_indices]))
```
data has two keys, IMAGES and CAPTIONS.
The test set contains 1000 images, each with 5 captions, for 5000 captions in total.
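Because create_dataset appends the captions_per_image captions of each image consecutively, caption j of image k sits at position k * captions_per_image + j, and ImageTextDataset later inverts this with `i // captions_per_image`. A quick check of that arithmetic (both helper names are hypothetical):

```python
def caption_index(image_idx, j, captions_per_image=5):
    """Position in CAPTIONS of the j-th caption of image image_idx."""
    return image_idx * captions_per_image + j


def image_index(caption_idx, captions_per_image=5):
    """Inverse mapping, as used in ImageTextDataset.__getitem__."""
    return caption_idx // captions_per_image


print([caption_index(250, j) for j in range(5)])  # [1250, 1251, 1252, 1253, 1254]
# Every caption of the 1000 test images maps back to its own image
assert all(image_index(caption_index(k, j)) == k for k in range(1000) for j in range(5))
```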
The 5 captions of the image at index 250 are shown in the figure below.
1.4 Dataloader
The code below is taken from Flickr8k數據集處理 (for study purposes only).
```python
import json
import os

import torch
from PIL import Image
from torch.utils import data
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageTextDataset(Dataset):
    """PyTorch Dataset class that generates data batches for torch DataLoader"""

    def __init__(self, dataset_path, vocab_path, split, captions_per_image=5, max_len=30, transform=None):
        """
        Parameters:
            dataset_path: Path to the JSON file containing the dataset
            vocab_path: Path to the JSON file containing the vocabulary
            split: The dataset split: "train", "val", or "test"
            captions_per_image: Number of captions per image
            max_len: Maximum number of words per caption
            transform: Image transformation methods
        """
        self.split = split
        # Validate that the split is one of the allowed values
        assert self.split in {"train", "val", "test"}
        self.cpi = captions_per_image  # Captions per image
        self.max_len = max_len         # Maximum caption length (words, without <start>/<end>)
        # Load the dataset
        with open(dataset_path, "r") as f:
            self.data = json.load(f)
        # Load the vocabulary
        with open(vocab_path, "r") as f:
            self.vocab = json.load(f)
        # Image transformation methods
        self.transform = transform
        # Dataset size = number of captions (each caption is one sample)
        self.dataset_size = len(self.data["CAPTIONS"])

    def __getitem__(self, i):
        """Retrieve the i-th sample from the dataset"""
        # Caption i belongs to image i // cpi (each image has cpi consecutive captions)
        img = Image.open(self.data['IMAGES'][i // self.cpi]).convert("RGB")
        # Apply the image transformation if provided
        if self.transform is not None:
            img = self.transform(img)
        # Length of the encoded caption (including <start> and <end>)
        caplen = len(self.data["CAPTIONS"][i])
        # Pad every caption to max_len + 2 tokens (+2 for <start> and <end>)
        pad_caps = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)
        caption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)
        return img, caption, caplen  # Image, padded caption, caption length

    def __len__(self):
        return self.dataset_size  # Number of samples in the dataset


def make_train_val(data_dir, vocab_path, batch_size, workers=4):
    """
    Create DataLoader objects for the training, validation, and test sets.
    Parameters:
        data_dir: Directory containing the dataset JSON files
        vocab_path: Path to the vocabulary JSON file
        batch_size: Number of samples per batch
        workers: Number of subprocesses used for data loading (default 4)
    Returns:
        train_loader, val_loader, test_loader
    """
    # Resize(256) scales the shorter side to 256 and keeps the aspect ratio;
    # CenterCrop(256) then gives every image the same 256x256 size, which the
    # default collate function needs in order to stack a batch
    train_tx = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet mean and std
    ])
    val_tx = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Dataset objects for the training, validation, and test sets
    train_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "train_data.json"),
                                 vocab_path=vocab_path, split="train", transform=train_tx)
    valid_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "val_data.json"),
                                 vocab_path=vocab_path, split="val", transform=val_tx)
    test_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "test_data.json"),
                                vocab_path=vocab_path, split="test", transform=val_tx)

    # Training loader shuffles the data
    train_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                                   num_workers=workers, pin_memory=True)
    # Validation and test loaders keep the original order
    val_loader = data.DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=False,
                                 num_workers=workers, pin_memory=True, drop_last=False)
    test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False,
                                  num_workers=workers, pin_memory=True, drop_last=False)
    return train_loader, val_loader, test_loader
```
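ImageTextDataset pads every encoded caption to max_len + 2 tokens with `<pad>` (index 0), so all captions in a batch have the same length and can be stacked, while `caplen` keeps the unpadded length. Below is a stdlib-only sketch of that padding rule. The name `pad_caption` is hypothetical, and its length check is an addition: in the original, a caption longer than max_len + 2 would silently get no padding (a negative repeat count yields an empty list).

```python
def pad_caption(enc_caption, max_len=30, pad_idx=0):
    """Pad an encoded caption (<start> ... <end>) to max_len + 2 tokens.

    Returns (padded caption, original length). The length check is an
    illustrative addition; ImageTextDataset itself does not validate it.
    """
    caplen = len(enc_caption)
    if caplen > max_len + 2:
        raise ValueError(f"caption has {caplen} tokens, more than max_len + 2 = {max_len + 2}")
    return enc_caption + [pad_idx] * (max_len + 2 - caplen), caplen


padded, caplen = pad_caption([4, 1, 2, 3, 5], max_len=6)
print(padded, caplen)  # [4, 1, 2, 3, 5, 0, 0, 0] 5
```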
With train_loader created, we can move on to training SD!