git倉庫:https://github.com/FoundationVision/LlamaGen
數據集準備
如果用ImageFolder讀取,則最好和ImageNet一致。
data_path/class_1/image_001.jpgimage_002.jpg...class_2/image_003.jpgimage_004.jpg......class_n/image_005.jpgimage_006.jpg...
則
def build_imagenet(args, transform):return ImageFolder(args.data_path, transform=transform)
如果是train,val,test,最好整理成
data_path/train/class_1/image_001.jpgimage_002.jpg...class_2/image_003.jpgimage_004.jpg......val/class_1/image_005.jpgimage_006.jpg...class_2/image_007.jpgimage_008.jpg......test/class_1/image_009.jpgimage_010.jpg...class_2/image_011.jpgimage_012.jpg......
讀取:
train_dataset = datasets.ImageFolder(root=args.data_path + '/train', transform=transform)# 加載驗證集
val_dataset = datasets.ImageFolder(root=args.data_path + '/val', transform=transform)# 加載測試集
test_dataset = datasets.ImageFolder(root=args.data_path + '/test', transform=transform)
數據集預處理
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=3 torchrun \
--nnodes=1 --nproc_per_node=1 --node_rank=0 \
--master_addr=localhost \
autoregressive/train/extract_codes_c2i.py \
--vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
--data-path 你的數據集 \
--code-path VQGAN處理的數據集放在哪 \--ten-crop \--crop-range 1.1 \--image-size 256
這里改成自己數據集的長度
ten-crop是作者定義的一種數據增強,每一個圖片生成10個crop。最好修改一下這里的代碼,訓練的時候僅僅取一個。
注釋掉這個self.flip
訓練
NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=4,5 torchrun \
--nnodes=1 --nproc_per_node=2 --node_rank=0 \
--master_addr=localhost \
--master_port=8902 \
./autoregressive/train/train_c2i.py \
--cloud-save-path xxx \
--code-path 之前放VQGAN處理后數據集的地方 \
--image-size 256 \
--gpt-model GPT-B
生成
修改類別,權重
parser.add_argument("--num-classes", type=int, default=xxx)
label定義:
我的生成結果(數據集用了TinyImageNet的8個類)
300step
1500step