import os
import sys
import time
import warnings
import argparse
import yaml
import torch
import math
import logging
import transformers
import diffusers
import torch.nn.functional as F
from pathlib import Path
from transformers import Qwen2Model, Qwen2TokenizerFast
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import EMAModel
from diffusers.utils.import_utils import is_xformers_available
from transformers import AutoTokenizer, AutoModel

from train_dataset import build_dataloader
from longcat_image.models import LongCatImageTransformer2DModel
from longcat_image.utils import LogBuffer
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids

warnings.filterwarnings("ignore")  # ignore warnings

current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))

logger = get_logger(__name__)


def train(global_step=0):
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    last_tic = time.time()
    # Now you train the model
    for epoch in range(first_epoch, args.num_train_epochs + 1):
        data_time_start = time.time()
        data_time_all = 0
        for step, batch in enumerate(train_dataloader):
            # Concatenate preferred (win) and dispreferred (lose) images so both halves
            # share the same VAE encoding path, noise and timesteps below.
            images_win = batch['images_win']
            images_lose = batch['images_lose']
            half_batch_size = images_win.shape[0]
            image = torch.concat([images_win, images_lose], dim=0)

            data_time_all += time.time() - data_time_start
            with torch.no_grad():
                latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
                latents = latents.to(dtype=weight_dtype)
                latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

            text_input_ids = batch['input_ids'].to(accelerator.device)
            text_attention_mask = batch['attention_mask'].to(accelerator.device)
            with torch.no_grad():
                text_output = text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask,
                    output_hidden_states=True
                )
                prompt_embeds = text_output.hidden_states[-1].clone().detach()
                prompt_embeds = prompt_embeds.to(weight_dtype)
            prompt_embeds = torch.concat([prompt_embeds, prompt_embeds], dim=0)  # [batch_size, 256, 4096]
            prompt_embeds = prompt_embeds[:, args.prompt_template_encode_start_idx:-args.prompt_template_encode_end_idx, :]

            # Sample a random timestep for each image
            grad_norm = None
            with accelerator.accumulate(transformer):
                # Predict the noise residual
                optimizer.zero_grad()
                # logit-normal timestep sampling
                sigmas = torch.sigmoid(torch.randn((half_batch_size,), device=accelerator.device, dtype=latents.dtype))
                if args.use_dynamic_shifting:
                    sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)
                sigmas = torch.concat([sigmas, sigmas], dim=0)
                timesteps = sigmas * 1000.0
                sigmas = sigmas.view(-1, 1, 1, 1)
                noise = torch.randn_like(latents)
                # Use the same noise for the win and lose halves of the batch.
                noise = noise.chunk(2)[0].repeat(2, 1, 1, 1)
                noisy_latents = (1 - sigmas) * latents + sigmas * noise
                noisy_latents = noisy_latents.to(weight_dtype)

                packed_noisy_latents = pack_latents(
                    noisy_latents,
                    batch_size=latents.shape[0],
                    num_channels_latents=latents.shape[1],
                    height=latents.shape[2],
                    width=latents.shape[3],
                )
                guidance = None
                img_ids = prepare_pos_ids(
                    modality_id=1, type='image',
                    start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
                    height=latents.shape[2] // 2,
                    width=latents.shape[3] // 2,
                ).to(accelerator.device, dtype=torch.float64)
                timesteps = (
                    torch.tensor(timesteps)
                    .expand(noisy_latents.shape[0])
                    .to(device=accelerator.device)
                    / 1000
                )
                text_ids = prepare_pos_ids(
                    modality_id=0, type='text', start=(0, 0),
                    num_token=prompt_embeds.shape[1],
                ).to(accelerator.device, torch.float64)

                with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
                    model_pred = transformer(packed_noisy_latents, prompt_embeds, timesteps,
                                             img_ids, text_ids, guidance, return_dict=False)[0]
                model_pred = unpack_latents(
                    model_pred,
                    height=latents.shape[2] * 8,
                    width=latents.shape[3] * 8,
                    vae_scale_factor=16,
                )
                target = noise - latents
                model_losses = torch.mean(
                    ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                model_loss_win, model_loss_lose = model_losses.chunk(2)
                model_diff = model_loss_win - model_loss_lose

                # Flow-matching losses of the frozen reference model on the same inputs.
                with torch.no_grad():
                    ref_pred = ref_model(packed_noisy_latents, prompt_embeds, timesteps,
                                         img_ids, text_ids, guidance, return_dict=False)[0]
                    ref_pred = unpack_latents(
                        ref_pred,
                        height=latents.shape[2] * 8,
                        width=latents.shape[3] * 8,
                        vae_scale_factor=16,
                    )
                    ref_losses = torch.mean(
                        ((ref_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                        1,
                    )
                    ref_loss_win, ref_loss_lose = ref_losses.chunk(2)
                    ref_diff = ref_loss_win - ref_loss_lose

                # Diffusion-DPO objective: compare the policy's win/lose error gap with the reference gap.
                inside_term = -args.beta_dpo * (model_diff - ref_diff)
                # implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
                loss = -F.logsigmoid(inside_term).mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = transformer.get_global_grad_norm()
                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()

            if accelerator.sync_gradients and args.use_ema:
                model_ema.step(transformer.parameters())

            lr = lr_scheduler.get_last_lr()[0]

            if accelerator.sync_gradients:
                global_step += 1
                bsz, ic, ih, iw = image.shape
                logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih * 1.0 / iw)}
                if grad_norm is not None:
                    logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
                log_buffer.update(logs)
                if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
                    t = (time.time() - last_tic) / args.log_interval
                    t_d = data_time_all / args.log_interval
                    log_buffer.average()
                    info = (
                        f"Step={step + 1}, Epoch={epoch}, global_step={global_step}, "
                        f"time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, "
                        f"s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
                    )
                    info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                    logger.info(info)
                    last_tic = time.time()
                    log_buffer.clear()
                    data_time_all = 0
                logs.update(lr=lr)
                accelerator.log(logs, step=global_step)

            data_time_start = time.time()

            if global_step != 0 and global_step % args.save_model_steps == 0:
                save_path = os.path.join(args.work_dir, f'checkpoints-{global_step}')
                if args.use_ema:
                    model_ema.store(transformer.parameters())
                    model_ema.copy_to(transformer.parameters())
                accelerator.save_state(save_path)
                if args.use_ema:
                    model_ema.restore(transformer.parameters())
logger.info(f"Saved state to {save_path} (global_step: {global_step})") accelerator.wait_for_everyone() def parse_args(): parser = argparse.ArgumentParser(description="Process some integers.") parser.add_argument("--config", type=str, default='', help="config") parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) args = parser.parse_args() return args if __name__ == '__main__': cur_dir = os.path.dirname(os.path.abspath(__file__)) args = parse_args() if args.config != '' and os.path.exists(args.config): config = yaml.safe_load(open(args.config, 'r')) else: config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r')) args_dict = vars(args) args_dict.update(config) args = argparse.Namespace(**args_dict) os.umask(0o000) os.makedirs(args.work_dir, exist_ok=True) log_dir = args.work_dir + f'/logs' os.makedirs(log_dir, exist_ok=True) accelerator_project_config = ProjectConfiguration( project_dir=args.work_dir, logging_dir=log_dir) with open(f'{log_dir}/train.yaml', 'w') as f: yaml.dump(args_dict, f) accelerator = Accelerator( mixed_precision=args.mixed_precision, gradient_accumulation_steps=args.gradient_accumulation_steps, log_with=args.report_to, project_config= accelerator_project_config, ) # Make one log on every process with the configuration for debugging. 
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    logger.info(f'using weight_dtype {weight_dtype}!!!')

    if args.diffusion_pretrain_weight:
        transformer = LongCatImageTransformer2DModel.from_pretrained(
            args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
        ref_model = LongCatImageTransformer2DModel.from_pretrained(
            args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights for BaseModel and RefModel, {args.diffusion_pretrain_weight}!!!')
    else:
        transformer = LongCatImageTransformer2DModel.from_pretrained(
            os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
        ref_model = LongCatImageTransformer2DModel.from_pretrained(
            os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights for BaseModel and RefModel, {args.pretrained_model_name_or_path + "/transformer"}!!!')

    transformer = transformer.train()
    ref_model.requires_grad_(False)
    ref_model.to(accelerator.device, weight_dtype)

    total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")

    if args.use_ema:
        model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
    else:
        model_ema = None

    vae_dtype = torch.float32
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()
    text_encoder = AutoModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder",
        torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer",
        torch_dtype=weight_dtype, trust_remote_code=True)
    logger.info("all models loaded successfully")

    # build models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler")

    latent_size = int(args.resolution) // 8
    mu = calculate_shift(
        (latent_size // 2) ** 2,
        noise_scheduler.config.base_image_seq_len,
        noise_scheduler.config.max_image_seq_len,
        noise_scheduler.config.base_shift,
        noise_scheduler.config.max_shift,
    )

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "transformer"))
                if len(weights) != 0:
                    weights.pop()

    def load_model_hook(models, input_dir):
        while len(models) > 0:
            # pop models so that they are not loaded again
            model = models.pop()

            # load diffusers style into model
            load_model = LongCatImageTransformer2DModel.from_pretrained(
                input_dir, subfolder="transformer")
            model.register_to_config(**load_model.config)
            model.load_state_dict(load_model.state_dict())

            del load_model
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = transformer.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    transformer.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        model_ema.to(accelerator.device, dtype=weight_dtype)

    train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, args.resolution)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.work_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.work_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        try:
            accelerator.init_trackers('dpo', tracker_config)
        except Exception as e:
            logger.warning(f'got an error while logging the config to the trackers, {e}')
            accelerator.init_trackers(f"dpo_{timestamp}")

    transformer, optimizer, _, _ = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler)

    log_buffer = LogBuffer()
    train(global_step=global_step)
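
# Example launch, as a rough sketch only: the script filename and process count below are
# assumptions for illustration, not part of this repo. When `--config` is omitted, the script
# falls back to `train_config.yaml` next to this file.
#
#   accelerate launch --num_processes 8 train_dpo.py --config train_config.yaml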