Dataset for Stable Diffusion

1.Dataset for Stable Diffusion

筆記來源:
1.Flickr8k數據集處理
2.處理Flickr8k數據集
3.Github:pytorch-stable-diffusion
4.Flickr 8k Dataset
5.dataset_flickr8k.json

1.1 Dataset

采用Flicker8k數據集,該數據集有兩個文件,第一個文件為Flicker8k_Dataset (全部為圖片),第二個文件為Flickr8k.token.txt (含兩列image_id和caption),其中一個image_id對應5個caption (sentence)

1.2 Dataset description file

數據集文本描述文件:dataset_flickr8k.json
文件格式如下:
{“images”: [ {“sentids”: [ ],“imgid”: 0,“sentences”:[{“tokens”:[ ]}, {“tokens”:[ ], “raw”: “…”, “imgid”:0, “sentid”:0}, …, “split”: “train”, “filename”: …jpg}, {“sentids”…} ], “dataset”: “flickr8k”}

參數解釋
“sentids”:[0,1,2,3,4]caption 的 id 范圍(一個image對應5個caption,所以sentids從0到4)
“imgid”:0image 的 id(從0到7999共8000張image)
“sentences”:[ ]包含一張照片的5個caption
“tokens”:[ ]每個caption分割為單個word
“raw”: " "每個token連接起來的caption
“imgid”: 0與caption相匹配的image的id
“sentid”: 0imag0對應的具體的caption的id
“split”:" "將該image和對應caption劃分到訓練集or驗證集or測試集
“filename”:“…jpg”image具體名稱

dataset_flickr8k.json

1.3 Process Datasets

下面代碼引用自:Flickr8k數據集處理(僅作學習使用)

import json
import os
import random
from collections import Counter, defaultdict
from matplotlib import pyplot as plt
from PIL import Image
from argparse import Namespace
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transformsdef create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):"""Parameters:dataset: Name of the datasetcaptions_per_image: Number of captions per imagemin_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)max_len: Maximum number of words in a caption. Captions longer than this will be truncated.Output:A vocabulary file: vocab.jsonThree dataset files: train_data.json, val_data.json, test_data.json"""# Paths for reading data and saving processed data# Path to the dataset JSON fileflickr_json_path = ".../sd/data/dataset_flickr8k.json"# Folder containing imagesimage_folder = ".../sd/data/Flicker8k_Dataset"# Folder to save processed results# The % operator is used to format the string by replacing %s with the value of the dataset variable.# For example, if dataset is "flickr8k", the resulting output_folder will be# /home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion/sd/data/flickr8k.output_folder = ".../sd/data/%s" % dataset# Ensure output directory existsos.makedirs(output_folder, exist_ok=True)print(f"Output folder: {output_folder}")# Read the dataset JSON filewith open(file=flickr_json_path, mode="r") as j:data = json.load(fp=j)# Initialize containers for image paths, captions, and vocabulary# Dictionary to store image pathsimage_paths = defaultdict(list)# Dictionary to store image captionsimage_captions = defaultdict(list)# Count the number of elements, then count and return a dictionary# key:element value:the number of elements.vocab = Counter()# read from file dataset_flickr8k.jsonfor img in data["images"]:  # Iterate over each image in the datasetsplit = img["split"]  # Determine the split (train, val, or test) for the imagecaptions = []for c in img["sentences"]:  # Iterate over each caption for the image# Update word frequency count, excluding test set dataif split != "test":  # Only update vocabulary for train/val splits# c['tokens'] is a list, The number of occurrences of each word in the list is increased by onevocab.update(c['tokens'])  # Update vocabulary with words in the caption# Only consider captions that are within the maximum lengthif len(c["tokens"]) <= max_len:captions.append(c["tokens"])  # Add the caption to the list if it meets the length requirementif len(captions) == 0:  # Skip images with no valid captionscontinue# Construct the full image path/home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion# image_folder + image_name# ./Flicker8k_Dataset/img['filename']path = os.path.join(image_folder, img['filename'])# Save the full image path and its captions in the respective dictionariesimage_paths[split].append(path)image_captions[split].append(captions)'''After the above steps, we have:- vocab(a dict) keys:words、values: counts of all words- image_paths: (a dict) keys "train", "val", and "test"; values: lists of absolute image paths- image_captions: (a dict) keys: "train", "val", and "test"; values: lists of captions'''/home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion........

我們通過dataset_flickr8k.json文件把數據集轉化為三個詞典

dictkeyvalue
vacabwordfrequency of words in all captions
image_path“train”、“val”、“test”lists of absolute image path
image_captions“train”、“val”、“test”lists of captions

我們通過Debug打印其中的內容

print(vocab)
print(image_paths["train"][1])
print(image_captions["train"][1])

def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):"""Parameters:dataset: Name of the datasetcaptions_per_image: Number of captions per imagemin_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)max_len: Maximum number of words in a caption. Captions longer than this will be truncated.Output:A vocabulary file: vocab.jsonThree dataset files: train_data.json, val_data.json, test_data.json"""........# Create the vocabulary, adding placeholders for special tokens# Add placeholders<pad>, unregistered word identifiers<unk>, sentence beginning and end identifiers<start><end>words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Filter words by minimum countvocab = {k: v + 1 for v, k in enumerate(words)}  # Create the vocabulary with indices# Add special tokens to the vocabularyvocab['<pad>'] = 0vocab['<unk>'] = len(vocab)vocab['<start>'] = len(vocab)vocab['<end>'] = len(vocab)# Save the vocabulary to a filewith open(os.path.join(output_folder, 'vocab.json'), "w") as fw:json.dump(vocab, fw)# Process each dataset split (train, val, test)# Iterate over each split: split = "train" 、 split = "val" 和 split = "test"for split in image_paths:# List of image paths for the splitimgpaths = image_paths[split]  # type(imgpaths)=list# List of captions for the splitimcaps = image_captions[split]  # type(imcaps)=list# store result that converting words of caption to their respective indices in the vocabularyenc_captions = []for i, path in enumerate(imgpaths):# Check if the image can be openedimg = Image.open(path)# Ensure each image has the required number of captionsif len(imcaps[i]) < captions_per_image:filled_num = captions_per_image - len(imcaps[i])# Repeat captions if neededcaptions = imcaps[i] + [random.choice(imcaps[i]) for _ in range(0, filled_num)]else:# Randomly sample captions if there are more than neededcaptions = random.sample(imcaps[i], k=captions_per_image)assert len(captions) == captions_per_imagefor j, c in enumerate(captions):# Encode each caption by converting words to their respective indices in the vocabularyenc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab["<end>"]]enc_captions.append(enc_c)assert len(imgpaths) * captions_per_image == len(enc_captions)data = {"IMAGES": imgpaths,"CAPTIONS": enc_captions}# Save the processed dataset for the current split (train,val,test)with open(os.path.join(output_folder, split + "_data.json"), 'w') as fw:json.dump(data, fw)create_dataset()

經過create_dataset函數,我們得到如下圖的文件

四個文件的詳細內容見下表

train_data.json中的第一個key:IMAGES
train_data.json中的第二個key:CAPTIONS
test_data.json中的第一個key:IMAGES
test_data.json中的第二個key:CAPTIONS
val_data.json中的第一個key:IMAGES
val_data.json中的第二個key:CAPTIONS
vocab.json開始部分
vocab.json結尾部分

生成vocab.json的關鍵代碼
首先統計所有caption中word出現至少大于5次的word,而后給這些word依次賦予一個下標

# Create the vocabulary, adding placeholders for special tokens# Add placeholders<pad>, unregistered word identifiers<unk>, sentence beginning and end identifiers<start><end># Create a list of words from the vocabulary that have a frequency higher than 'min_word_count'# min_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Filter words by minimum count# assign an index to each word, starting from 1 (indices start from 0, so add 1)vocab = {k: v + 1 for v, k in enumerate(words)}  # Create the vocabulary with indices

最終生成vocab.json

生成 [“split”]_data.json 的關鍵
讀入文件dataset_flickr8k.json,并創建兩個字典,第一個字典放置每張image的絕對路徑,第二個字典放置描述image的caption,根據vocab將token換為下標保存,根據文件dataset_flickr8k.json中不同的split,這image的絕對路徑和相應caption保存在不同文件中(train_data.json、test_data.json、val_data.json)

dataset_flickr8k.json

train_data.json


從vocab中獲取token的下標得到CAPTION的編碼

for j, c in enumerate(captions):# Encode each caption by converting words to their respective indices in the vocabularyenc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab["<end>"]]enc_captions.append(enc_c)


嘗試使用上面生成的測試集文件test_data.json和vocab.json輸出某張image以及對應的caption
下面代碼引用自:Flickr8k數據集處理(僅作學習使用)

'''
test
1.Iterates over the 5 captions for 下面代碼引用自:[Flickr8k數據集處理](https://blog.csdn.net/weixin_48981284/article/details/134676813)(僅作學習使用)the 250th image.
2.Retrieves the word indices for each caption.
3.Converts the word indices to words using vocab_idx2word.
4.Joins the words to form complete sentences.
5.Prints each caption.
'''
import json
from PIL import Image
from matplotlib import pyplot as plt
# Load the vocabulary from the JSON file
with open('.../sd/data/flickr8k/vocab.json', 'r') as f:vocab = json.load(f)  # Load the vocabulary from the JSON file into a dictionary
# Create a dictionary to map indices to words
vocab_idx2word = {idx: word for word, idx in vocab.items()}
# Load the test data from the JSON file
with open('.../sd/data/flickr8k/test_data.json', 'r') as f:data = json.load(f)  # Load the test data from the JSON file into a dictionary
# Open and display the 250th image in the test set
# Open the image at index 250 in the 'IMAGES' list
content_img = Image.open(data['IMAGES'][250])
plt.figure(figsize=(6, 6))
plt.subplot(1,1,1)
plt.imshow(content_img)
plt.title('Image')
plt.axis('off')
plt.show()
# Print the lengths of the data, image list, and caption list
# Print the number of keys in the dataset dictionary (should be 2: 'IMAGES' and 'CAPTIONS')
print(len(data))
print(len(data['IMAGES']))  # Print the number of images in the 'IMAGES' list
print(len(data["CAPTIONS"]))  # Print the number of captions in the 'CAPTIONS' list
# Display the captions for the 300th image
# Iterate over the 5 captions associated with the 300th image
for i in range(5):# Get the word indices for the i-th caption of the 300th imageword_indices = data['CAPTIONS'][250 * 5 + i]# Convert indices to words and join them to form a captionprint(''.join([vocab_idx2word[idx] for idx in word_indices]))

data 的 key 有兩個 IMAGES 和 CAPTIONS
測試集image有1000張,每張對應5個caption,共5000個caption
第250張圖片的5個caption如下圖

1.4 Dataloader

下面代碼引用自:Flickr8k數據集處理(僅作學習使用)

import json
import os
import random
from collections import Counter, defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils import data
import torchvision.transforms as transformsclass ImageTextDataset(Dataset):"""Pytorch Dataset class to generate data batches using torch DataLoader"""def __init__(self, dataset_path, vocab_path, split, captions_per_image=5, max_len=30, transform=None):"""Parameters:dataset_path: Path to the JSON file containing the datasetvocab_path: Path to the JSON file containing the vocabularysplit: The dataset split, which can be "train", "val", or "test"captions_per_image: Number of captions per imagemax_len: Maximum number of words per captiontransform: Image transformation methods"""self.split = split# Validate that the split is one of the allowed valuesassert self.split in {"train", "val", "test"}# Store captions per imageself.cpi = captions_per_image# Store maximum caption lengthself.max_len = max_len# Load the datasetwith open(dataset_path, "r") as f:self.data = json.load(f)# Load the vocabularywith open(vocab_path, "r") as f:self.vocab = json.load(f)# Store the image transformation methodsself.transform = transform# Number of captions in the dataset# Calculate the size of the datasetself.dataset_size = len(self.data["CAPTIONS"])def __getitem__(self, i):"""Retrieve the i-th sample from the dataset"""# Get [i // self.cpi]-th image corresponding to the i-th sample (each image has multiple captions)img = Image.open(self.data['IMAGES'][i // self.cpi]).convert("RGB")# Apply image transformation if providedif self.transform is not None:# Apply the transformation to the imageimg = self.transform(img)# Get the length of the captioncaplen = len(self.data["CAPTIONS"][i])# Pad the caption if its length is less than max_lenpad_caps = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)# Convert the caption to a tensor and pad itcaption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)return img, caption, caplen  # Return the image, caption, and caption lengthdef __len__(self):return self.dataset_size  # Number of samples in the datasetdef make_train_val(data_dir, vocab_path, batch_size, workers=4):"""Create DataLoader objects for training, validation, and testing sets.Parameters:data_dir: Directory where the dataset JSON files are locatedvocab_path: Path to the vocabulary JSON filebatch_size: Number of samples per batchworkers: Number of subprocesses to use for data loading (default is 4)Returns:train_loader: DataLoader for the training setval_loader: DataLoader for the validation settest_loader: DataLoader for the test set"""# Define transformation for training settrain_tx = transforms.Compose([transforms.Resize(256),  # Resize images to 256x256transforms.ToTensor(),  # Convert image to PyTorch tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet mean and std])val_tx = transforms.Compose([transforms.Resize(256),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# Create dataset objects for training, validation, and test setstrain_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "train_data.json"), vocab_path=vocab_path,split="train", transform=train_tx)vaild_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "val_data.json"), vocab_path=vocab_path,split="val", transform=val_tx)test_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "test_data.json"), vocab_path=vocab_path,split="test", transform=val_tx)# Create DataLoader for training set with data shufflingtrain_loder = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffer=True,num_workers=workers, pin_memory=True)# Create DataLoader for validation set without data shufflingval_loder = data.DataLoader(dataset=vaild_set, batch_size=batch_size, shuffer=False,num_workers=workers, pin_memory=True, drop_last=False)# Create DataLoader for test set without data shufflingtest_loder = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffer=False,num_workers=workers, pin_memory=True, drop_last=False)return train_loder, val_loder, test_loder

創建好train_loader后,接下來我們就可以著手開始訓練SD了!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/44813.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/44813.shtml
英文地址,請注明出處:http://en.pswp.cn/web/44813.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Node.js_mongodb用戶名和密碼操作

mongodb用戶名和密碼操作 查看用戶密碼創建管理員用戶和密碼mongodb的目標是實現快速簡單部署,所以存在很多安全問題 默認配置下沒有用戶和密碼,無需身份驗證即可登錄,不像mysql那樣需要登錄才能操作數據庫本身安全問題:升級3.0以上版本查看用戶密碼 密碼是加密存儲的,并且…

前端工程化10-webpack靜態的模塊化打包工具之各種loader處理器

9.1、案例編寫 我們創建一個component.js 通過JavaScript創建了一個元素&#xff0c;并且希望給它設置一些樣式&#xff1b; 我們自己寫的css,要把他加入到Webpack的圖結構當中&#xff0c;這樣才能被webpack檢測到進行打包&#xff0c; style.css–>div_cn.js–>main…

速盾:ddos高防ip哪里好用?

隨著互聯網的飛速發展&#xff0c;DDoS攻擊問題逐漸突出。DDoS攻擊是一種通過在網絡上創建大量請求&#xff0c;使目標網絡或服務器過載而無法正常工作的攻擊方式。為了應對DDoS攻擊&#xff0c;提高網絡的安全性和穩定性&#xff0c;使用高防IP成為了一種常見的解決辦法。 DD…

Flower花所比特幣交易及交易費用科普

在加密貨幣交易中&#xff0c;選擇一個可靠的平臺至關重要。Flower花所通過提供比特幣交易服務脫穎而出。本文將介紹在Flower花所進行比特幣交易的基礎知識及其交易費用。 什么是Flower花所&#xff1f; Flower花所是一家加密貨幣交易平臺&#xff0c;為新手和資深交易者提供…

【C++】開源:drogon-web框架配置使用

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 這篇文章主要介紹drogon-web框架配置使用。 無專精則不能成&#xff0c;無涉獵則不能通。——梁啟超 歡迎來到我的博客&#xff0c;一起學習&#xff0c;共同進步。 喜歡的朋友可以關注一下&#xff0c;…

Linux系統編程-線程同步詳解

線程同步是指多個線程協調工作&#xff0c;以便在共享資源的訪問和操作過程中保持數據一致性和正確性。在多線程環境中&#xff0c;線程是并發執行的&#xff0c;因此如果多個線程同時訪問和修改共享資源&#xff0c;可能會導致數據不一致、競態條件&#xff08;race condition…

面試題008-Java-SpringBoot

面試題008-Java-SpringBoot 目錄 面試題008-Java-SpringBoot題目自測題目答案1. Spring 和 Spring Boot有什么區別&#xff1f;2. Spring Boot 的主要優點是什么&#xff1f;3. 什么是Spring Boot Starter&#xff1f;4. 介紹一下SpringBootApplication注解&#xff1f;5. Spri…

【密碼學】消息認證

你發送給朋友一條消息&#xff08;內容&#xff1a;明天下午來我家吃飯&#xff09;&#xff0c;這一過程中你不想讓除你朋友以外的人看到消息的內容&#xff0c;這就叫做消息的機密性&#xff0c;用來保護消息機密性的方式被叫做加密機制。 現在站在朋友的視角&#xff0c;某一…

使用PyQt5實現添加工具欄、增加SwitchButton控件

前言&#xff1a;通過在網上找到的“電池電壓監控界面”&#xff0c;學習PyQt5中添加工具欄、增加SwitchButton控件&#xff0c;在滑塊控件右側增加文本顯示、設置界面背景顏色、修改文本控件字體顏色等。 1. 上位機界面效果展示 網絡上原圖如下&#xff1a; 自己使用PyQt5做…

springboot異常(一):springboot自定義全局異常處理

&#x1f337;1. 自定義一個異常類 自定義一個異常&#xff0c;有兩個變量異常代碼、異常消息&#xff0c;定義了兩個構造方法&#xff0c;一個無參構造方法&#xff0c;一個所有參數構造方法。 在構造方法中要掉用父類的構造方法&#xff0c;主要目的是在日志或控制臺打印異…

【Linux】多線程_3

文章目錄 九、多線程3. C11中的多線程4. 線程的簡單封裝 未完待續 九、多線程 3. C11中的多線程 Linux中是根據多線程庫來實現多線程的&#xff0c;C11也有自己的多線程&#xff0c;那它的多線程又是怎樣的&#xff1f;我們來使用一些C11的多線程。 Makefile&#xff1a; te…

Linux - 探索命令行

探索命令行 Linux命令行中的命令使用格式都是相同的: 命令名稱 參數1 參數2 參數3 ...參數之間用任意數量的空白字符分開. 關于命令行, 可以先閱讀一些基本常識. 然后我們介紹最常用的一些命令: ls用于列出當前目錄(即"文件夾")下的所有文件(或目錄). 目錄會用藍色…

面試經典題型:調用HashMap的put方法的具體執行流程

在調用put方法時時&#xff0c;有幾個關鍵點需要考慮&#xff1a; 哈希沖突的發生與解決&#xff1a; 哈希沖突指不同的鍵通過哈希函數計算得到相同的哈希值&#xff0c;導致它們應該存放在哈希表的同一個位置。解決沖突的常用方法包括開放尋址法和鏈表法&#xff08;或其升級形…

CSIP-FTE考試專業題

靶場下載鏈接&#xff1a; https://pan.baidu.com/s/1ce1Kk0hSYlxrUoRTnNsiKA?pwdha1x pte-2003密碼&#xff1a;admin123 centos:root admin123 解壓密碼&#xff1a; PTE考試專用 下載好后直接用vmware打開&#xff0c;有兩個靶機&#xff0c;一個是基礎題&#x…

【CTF-Crypto】數論基礎-02

【CTF-Crypto】數論基礎-02 文章目錄 【CTF-Crypto】數論基礎-021-16 二次剩余1-20 模p下-1的平方根*1-21 Legendre符號*1-22 Jacobi符號*2-1 群*2-2 群的性質2-3 阿貝爾群*2-4 子群2-11 群同態2-18 原根2-21 什么是環2-23 什么是域2-25 子環2-26 理想2-32 多項式環 1-16 二次剩…

打造智慧校園德育管理,提升學生操行基礎分

智慧校園的德育管理系統內嵌的操行基礎分功能&#xff0c;是對學生日常行為規范和道德素養進行量化評估的一個創新實踐。該功能通過將抽象的道德品質轉化為具體可量化的指標&#xff0c;如遵守紀律、尊師重道、團結協作、愛護環境及參與集體活動的積極性等&#xff0c;為每個學…

醫療器械FDA |FDA網絡安全測試具體內容

醫療器械FDA網絡安全測試的具體內容涵蓋了多個方面&#xff0c;以確保醫療器械在網絡環境中的安全性和合規性。以下是根據權威來源歸納的FDA網絡安全測試的具體內容&#xff1a; 一、技術文件審查 網絡安全計劃&#xff1a;制造商需要提交網絡安全計劃&#xff0c;詳細描述產…

Matlab【光伏預測】基于雪融優化算法SAO優化高斯過程回歸GPR實現光伏多輸入單輸出預測附代碼

% 光伏預測 - 基于SAO優化的GPR % 數據準備 % 假設有多個輸入特征 X1, X2, …, Xn 和一個目標變量 Y % 假設數據已經存儲在 X 和 Y 中&#xff0c;每個變量為矩陣&#xff0c;每行表示一個樣本&#xff0c;每列表示一個特征 % 參數設置 numFeatures size(X, 2); % 輸入特征的…

Spring Boot集成easyposter快速入門Demo

1.什么是easyposter&#xff1f; easyposter是一個簡單的,便于擴展的繪制海報工具包 使用場景 在日常工作過程中&#xff0c;通常一些C端平臺會伴隨著海報生成與分享業務。因為隨著移動互聯網的迅猛發展&#xff0c;社交分享已成為我們日常生活的重要組成部分。海報分享作為…

visual studio 2019版下載以及與UE4虛幻引擎配置(過程記錄)(官網無法下載visual studio 2019安裝包)

一、概述 由于需要使用到UE4虛幻引擎&#xff0c;我使用的版本是4.27版本的&#xff0c;其官方默認的visual studio版本是2019版本的&#xff0c;相應的版本對應關系可以通過下面的官方網站對應關系查詢。https://docs.unrealengine.com/4.27/zh-CN/ProductionPipelines/Develo…