import os
import random
import traceback
import math
import json

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer
from PIL import Image
from tqdm import tqdm

from longcat_image.dataset import MULTI_RESOLUTION_MAP, MultiResolutionDistributedSampler
from longcat_image.utils import encode_prompt

Image.MAX_IMAGE_PIXELS = 2000000000
MAX_RETRY_NUMS = 100


class DpoPairDataSet(torch.utils.data.Dataset):
    """Paired (win/lose) image dataset for DPO training.

    Each line of `txt_root` is a JSON record holding a prompt, the paths of
    the preferred and rejected images, and the original image resolution.
    """

    def __init__(self, cfg: dict, txt_root: str, tokenizer: AutoTokenizer,
                 resolution: tuple = (1024, 1024)):
        super().__init__()
        self.resolution = resolution
        self.text_tokenizer_max_length = cfg.text_tokenizer_max_length
        self.null_text_ratio = cfg.null_text_ratio
        self.aspect_ratio_type = cfg.aspect_ratio_type
        self.aspect_ratio = MULTI_RESOLUTION_MAP[self.aspect_ratio_type]
        self.tokenizer = tokenizer

        self.total_datas = []
        self.data_resolution_infos = []
        with open(txt_root, 'r') as f:
            for line in tqdm(f.readlines()):
                # Skip malformed records instead of aborting the whole load.
                try:
                    data = json.loads(line.strip())
                    height, width = int(data['height']), int(data['width'])
                    self.data_resolution_infos.append((height, width))
                    self.total_datas.append(data)
                except Exception as e:
                    print(f'got error {e}, data {line.strip()}.')
                    continue
        self.data_nums = len(self.total_datas)

    def transform_img(self, image, original_size, target_size):
        """Resize so the target box is fully covered, then center-crop to it."""
        img_h, img_w = original_size
        target_height, target_width = target_size
        original_aspect = img_h / img_w  # height / width
        crop_aspect = target_height / target_width
        if original_aspect >= crop_aspect:
            # Image is taller than the target box: match widths, crop height.
            resize_width = target_width
            resize_height = math.ceil(img_h * (target_width / img_w))
        else:
            # Image is wider than the target box: match heights, crop width.
            resize_width = math.ceil(img_w * (target_height / img_h))
            resize_height = target_height
        return T.Compose([
            T.Resize((resize_height, resize_width), interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop((target_height, target_width)),
            T.ToTensor(),
            T.Normalize([.5], [.5]),  # map pixel values to [-1, 1]
        ])(image)

    def __getitem__(self, index_tuple):
        # The multi-resolution sampler yields (sample index, bucket target size).
        index, target_size = index_tuple
        for _ in range(MAX_RETRY_NUMS):
            try:
                item = self.total_datas[index]
                img_path_win = item["img_path_win"]
                img_path_lose = item["img_path_lose"]
                prompt = item['prompt']
                # Randomly drop the caption to support classifier-free guidance.
                if random.random() < self.null_text_ratio:
                    prompt = ''
                raw_image_win = Image.open(img_path_win).convert('RGB')
                raw_image_lose = Image.open(img_path_lose).convert('RGB')
                # Both images of a pair are assumed to share one resolution.
                img_w, img_h = raw_image_win.size
                raw_image_win = self.transform_img(
                    raw_image_win, original_size=(img_h, img_w), target_size=target_size)
                raw_image_lose = self.transform_img(
                    raw_image_lose, original_size=(img_h, img_w), target_size=target_size)
                input_ids, attention_mask = encode_prompt(
                    prompt, self.tokenizer, self.text_tokenizer_max_length)
                return {"image_win": raw_image_win,
                        "image_lose": raw_image_lose,
                        "prompt": prompt,
                        'input_ids': input_ids,
                        'attention_mask': attention_mask}
            except Exception as e:
                traceback.print_exc()
                print(f"failed to read data: {e}!!!")
                # Fall back to a random sample and retry.
                index = random.randint(0, self.data_nums - 1)
        raise RuntimeError(f'failed to load a valid sample after {MAX_RETRY_NUMS} retries.')

    def __len__(self):
        return self.data_nums

    def collate_fn(self, batch):
        images_win = torch.stack([example["image_win"] for example in batch])
        images_lose = torch.stack([example["image_lose"] for example in batch])
        input_ids = torch.stack([example["input_ids"] for example in batch])
        attention_mask = torch.stack([example["attention_mask"] for example in batch])
        prompts = [example['prompt'] for example in batch]
        return {
            "images_win": images_win,
            "images_lose": images_lose,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompts": prompts,
        }
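
# For reference, a minimal sketch of the JSONL record format this dataset
# expects, inferred from the keys read in __init__ and __getitem__ above.
# The concrete paths and values here are illustrative placeholders, not
# from the original repo.
EXAMPLE_JSONL_RECORD = {
    "img_path_win": "/path/to/preferred_image.jpg",   # preferred (chosen) image
    "img_path_lose": "/path/to/rejected_image.jpg",   # rejected image
    "prompt": "a photo of a cat wearing a hat",
    "height": 1024,  # original image height, used for resolution bucketing
    "width": 1024,   # original image width, used for resolution bucketing
}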
def build_dataloader(cfg: dict, txt_root: str, tokenizer: AutoTokenizer,
                     resolution: tuple = (1024, 1024)):
    dataset = DpoPairDataSet(cfg, txt_root, tokenizer, resolution)
    sampler = MultiResolutionDistributedSampler(batch_size=cfg.train_batch_size,
                                                dataset=dataset,
                                                data_resolution_infos=dataset.data_resolution_infos,
                                                bucket_info=dataset.aspect_ratio,
                                                epoch=0,
                                                num_replicas=None,
                                                rank=None)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collate_fn,
        batch_size=cfg.train_batch_size,
        num_workers=cfg.dataloader_num_workers,
        sampler=sampler,
        shuffle=False,  # shuffling is handled by the sampler
    )
    return train_loader


if __name__ == '__main__':
    import argparse

    txt_root = 'xxxx'
    cfg = argparse.Namespace(
        txt_root=txt_root,
        text_tokenizer_max_length=256,
        resolution=1024,
        text_encoder_path="xxx",
        center_crop=True,
        dataloader_num_workers=0,
        null_text_ratio=0.1,
        train_batch_size=16,
        seed=0,
        aspect_ratio_type='mar_1024',
        revision=None)

    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_path, trust_remote_code=True)
    data_loader = build_dataloader(cfg, cfg.txt_root, tokenizer, cfg.resolution)

    _oroot = './debug_data_example_show'
    os.makedirs(_oroot, exist_ok=True)
    cnt = 0
    for epoch in range(1):
        print(f"Start, epoch {epoch}!!!")
        for i_batch, batch in enumerate(data_loader):
            print(batch['attention_mask'].shape)
            print(batch['images_win'].shape, '-', batch['images_lose'].shape)
            if cnt > 100:
                break
            cnt += 1
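
# Illustrative footnote: a minimal sketch of how multi-resolution bucketing
# typically assigns a sample to a target size. This is NOT the actual
# MultiResolutionDistributedSampler implementation (which lives in
# longcat_image.dataset); it only shows the idea behind passing
# data_resolution_infos and bucket_info to the sampler, and why __getitem__
# receives an (index, target_size) tuple.
def nearest_bucket(height: int, width: int, buckets: dict) -> tuple:
    """Return the (target_h, target_w) bucket closest in aspect ratio.

    Assumes `buckets` maps an aspect-ratio key to a (target_h, target_w)
    pair, in the spirit of the MULTI_RESOLUTION_MAP entries used above.
    """
    aspect = height / width
    return min(buckets.values(), key=lambda hw: abs(hw[0] / hw[1] - aspect))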