import os
import sys
import time
import warnings
import argparse
import yaml
import torch
import math
import logging
import transformers
import diffusers
from pathlib import Path
from transformers import Qwen2Model, Qwen2TokenizerFast
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import EMAModel
from diffusers.utils.import_utils import is_xformers_available
from transformers import AutoTokenizer, AutoModel, AutoProcessor

# Make the repository root importable before pulling in the local packages.
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))

from train_dataset import build_dataloader
from longcat_image.models import LongCatImageTransformer2DModel
from longcat_image.utils import LogBuffer
from longcat_image.utils import pack_latents, unpack_latents, calculate_shift, prepare_pos_ids

warnings.filterwarnings("ignore")  # ignore warnings

logger = get_logger(__name__)


def train(global_step=0):
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    # Train!
    total_batch_size = args.train_batch_size * \
        accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    last_tic = time.time()

    # Now you train the model
    for epoch in range(first_epoch, args.num_train_epochs + 1):
        data_time_start = time.time()
        data_time_all = 0
        for step, batch in enumerate(train_dataloader):
            image = batch['images']
            ref_image = batch['ref_images']
            data_time_all += time.time() - data_time_start

            with torch.no_grad():
                latents = vae.encode(image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
                latents = latents.to(dtype=weight_dtype)
                latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
                ref_latents = vae.encode(ref_image.to(weight_dtype).to(accelerator.device)).latent_dist.sample()
                ref_latents = ref_latents.to(dtype=weight_dtype)
                ref_latents = (ref_latents - vae.config.shift_factor) * vae.config.scaling_factor
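            # Both the target image and the reference image are encoded to VAE latents and
            # normalised with the VAE's shift/scaling factors ((z - shift) * scale), so the
            # transformer sees roughly unit-variance inputs; decoding applies the inverse.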
            text_input_ids = batch['input_ids'].to(accelerator.device)
            text_attention_mask = batch['attention_mask'].to(accelerator.device)
            pixel_values = batch['pixel_values'].to(accelerator.device)
            image_grid_thw = batch['image_grid_thw'].to(accelerator.device)

            with torch.no_grad():
                text_output = text_encoder(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                    output_hidden_states=True,
                )
                prompt_embeds = text_output.hidden_states[-1].clone().detach()
                prompt_embeds = prompt_embeds.to(weight_dtype)
                prompt_embeds = prompt_embeds[:, args.prompt_template_encode_start_idx:-args.prompt_template_encode_end_idx, :]
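            # The conditioning comes from the VLM text encoder's last hidden states; the
            # slice above drops the leading/trailing tokens of the fixed prompt template
            # (per prompt_template_encode_start_idx / _end_idx), keeping only the prompt itself.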
            # Sample a random timestep for each image
            grad_norm = None
            with accelerator.accumulate(transformer):
                # Predict the noise residual
                optimizer.zero_grad()
                # logit-normal
                sigmas = torch.sigmoid(torch.randn((latents.shape[0],), device=accelerator.device, dtype=latents.dtype))
                if args.use_dynamic_shifting:
                    sigmas = noise_scheduler.time_shift(mu, 1.0, sigmas)
                timesteps = sigmas * 1000.0
                sigmas = sigmas.view(-1, 1, 1, 1)

                noise = torch.randn_like(latents)
                noisy_latents = (1 - sigmas) * latents + sigmas * noise
                noisy_latents = noisy_latents.to(weight_dtype)
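                # Flow-matching forward process: x_sigma = (1 - sigma) * latents + sigma * noise,
                # with sigma in (0, 1) drawn logit-normally (sigmoid of a standard normal).
                # When dynamic shifting is enabled, the scheduler's time_shift warps sigma
                # using mu (computed below from the image token count), which is intended to
                # put more probability mass on higher noise levels for longer sequences.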
                packed_noisy_latents = pack_latents(
                    noisy_latents,
                    batch_size=latents.shape[0],
                    num_channels_latents=latents.shape[1],
                    height=latents.shape[2],
                    width=latents.shape[3],
                )
                packed_ref_latents = pack_latents(
                    ref_latents,
                    batch_size=ref_latents.shape[0],
                    num_channels_latents=ref_latents.shape[1],
                    height=ref_latents.shape[2],
                    width=ref_latents.shape[3],
                )
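                # pack_latents presumably folds each 2x2 latent patch into a single token
                # (Flux-style packing), which is why the positional ids below are laid out
                # on a (height // 2, width // 2) grid.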
                guidance = None
                img_ids = prepare_pos_ids(modality_id=1,
                                          type='image',
                                          start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
                                          height=latents.shape[2] // 2,
                                          width=latents.shape[3] // 2).to(accelerator.device, dtype=torch.float64)
                img_ids_ref = prepare_pos_ids(modality_id=2,
                                              type='image',
                                              start=(prompt_embeds.shape[1], prompt_embeds.shape[1]),
                                              height=ref_latents.shape[2] // 2,
                                              width=ref_latents.shape[3] // 2).to(accelerator.device, dtype=torch.float64)
                timesteps = (
                    torch.tensor(timesteps)
                    .expand(noisy_latents.shape[0])
                    .to(device=accelerator.device)
                    / 1000
                )
                text_ids = prepare_pos_ids(modality_id=0,
                                           type='text',
                                           start=(0, 0),
                                           num_token=prompt_embeds.shape[1]).to(accelerator.device, torch.float64)
                img_ids = torch.cat([img_ids, img_ids_ref], dim=0)
                latent_model_input = torch.cat([packed_noisy_latents, packed_ref_latents], dim=1)
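                # The transformer presumably attends over one joint sequence: noisy image
                # tokens and reference-image tokens are concatenated along the sequence
                # dimension, while text / image / reference image carry distinct modality
                # ids (0 / 1 / 2) in their position ids, with image positions offset by the
                # text length.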
                with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
                    model_pred = transformer(latent_model_input, prompt_embeds, timesteps,
                                             img_ids, text_ids, guidance, return_dict=False)[0]
                model_pred = model_pred[:, :packed_noisy_latents.size(1)]
                model_pred = unpack_latents(
                    model_pred,
                    height=latents.shape[2] * 8,
                    width=latents.shape[3] * 8,
                    vae_scale_factor=16,
                )
                target = noise - latents
                loss = torch.mean(
                    ((model_pred.float() - target.float()) ** 2).reshape(
                        target.shape[0], -1
                    ),
                    1,
                ).mean()
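                # Flow-matching objective: only the prediction for the noisy (non-reference)
                # tokens is kept, and it regresses the constant velocity noise - latents,
                # i.e. d/dsigma of (1 - sigma) * latents + sigma * noise, with a per-sample
                # MSE averaged over the batch.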
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = transformer.get_global_grad_norm()
                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()

            if accelerator.sync_gradients and args.use_ema:
                model_ema.step(transformer.parameters())

            lr = lr_scheduler.get_last_lr()[0]

            if accelerator.sync_gradients:
                bsz, ic, ih, iw = image.shape
                logs = {"loss": accelerator.gather(loss).mean().item(), 'aspect_ratio': (ih * 1.0 / iw)}
                if grad_norm is not None:
                    logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
                log_buffer.update(logs)
                if (step + 1) % args.log_interval == 0 or (step + 1) == 1:
                    t = (time.time() - last_tic) / args.log_interval
                    t_d = data_time_all / args.log_interval
                    log_buffer.average()
                    info = f"Step={step+1}, Epoch={epoch}, global_step={global_step}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:(ch:{latents.shape[1]},h:{latents.shape[2]},w:{latents.shape[3]}), "
                    info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                    logger.info(info)
                    last_tic = time.time()
                    log_buffer.clear()
                    data_time_all = 0
                logs.update(lr=lr)
                accelerator.log(logs, step=global_step)
                global_step += 1

            data_time_start = time.time()

            if global_step != 0 and global_step % args.save_model_steps == 0:
                save_path = os.path.join(args.work_dir, f'checkpoints-{global_step}')
                if args.use_ema:
                    model_ema.store(transformer.parameters())
                    model_ema.copy_to(transformer.parameters())
                accelerator.save_state(save_path)
                if args.use_ema:
                    model_ema.restore(transformer.parameters())
                logger.info(f"Saved state to {save_path} (global_step: {global_step})")
                accelerator.wait_for_everyone()

            if global_step >= args.max_train_steps:
                break


def parse_args():
    parser = argparse.ArgumentParser(description="LongCat-Image training script.")
    parser.add_argument("--config", type=str, default='', help="path to the training config YAML")
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    args = parse_args()

    if args.config != '' and os.path.exists(args.config):
        config = yaml.safe_load(open(args.config, 'r'))
    else:
        config = yaml.safe_load(open(f'{cur_dir}/train_config.yaml', 'r'))
    args_dict = vars(args)
    args_dict.update(config)
    args = argparse.Namespace(**args_dict)

    os.umask(0o000)
    os.makedirs(args.work_dir, exist_ok=True)
    log_dir = args.work_dir + '/logs'
    os.makedirs(log_dir, exist_ok=True)
    accelerator_project_config = ProjectConfiguration(project_dir=args.work_dir, logging_dir=log_dir)
    with open(f'{log_dir}/train.yaml', 'w') as f:
        yaml.dump(args_dict, f)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    logger.info(f'using weight_dtype {weight_dtype}')
    if args.diffusion_pretrain_weight:
        transformer = LongCatImageTransformer2DModel.from_pretrained(
            args.diffusion_pretrain_weight, ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights from {args.diffusion_pretrain_weight}')
    else:
        transformer = LongCatImageTransformer2DModel.from_pretrained(
            os.path.join(args.pretrained_model_name_or_path, "transformer"), ignore_mismatched_sizes=False)
        logger.info(f'successfully loaded model weights from {args.pretrained_model_name_or_path + "/transformer"}')
    transformer = transformer.train()

    total_trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    logger.info(f">>>>>> total_trainable_params: {total_trainable_params}")

    if args.use_ema:
        model_ema = EMAModel(transformer.parameters(), decay=args.ema_rate)
    else:
        model_ema = None
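    # EMA keeps an exponential moving average of the transformer weights; at save time
    # the averaged weights are temporarily swapped in (see the checkpointing branch in
    # train()) so checkpoints contain the smoothed parameters.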
    vae_dtype = torch.float32
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype).cuda().eval()
    text_encoder = AutoModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype, trust_remote_code=True).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", trust_remote_code=True)
    text_processor = AutoProcessor.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", trust_remote_code=True)
    logger.info("all models loaded successfully")
    # build the flow-matching noise scheduler
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler")

    latent_size = int(args.resolution) // 8
    mu = calculate_shift(
        (latent_size // 2) ** 2,
        noise_scheduler.config.base_image_seq_len,
        noise_scheduler.config.max_image_seq_len,
        noise_scheduler.config.base_shift,
        noise_scheduler.config.max_shift,
    )
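    # calculate_shift presumably implements the Flux-style dynamic shift: mu interpolates
    # linearly between base_shift and max_shift as the image token count
    # ((latent_size // 2) ** 2 after 2x2 packing) moves from base_image_seq_len to
    # max_image_seq_len; train() passes this mu to noise_scheduler.time_shift.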
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "transformer"))
                if len(weights) != 0:
                    weights.pop()

    def load_model_hook(models, input_dir):
        while len(models) > 0:
            # pop models so that they are not loaded again
            model = models.pop()
            # load diffusers style into model
            load_model = LongCatImageTransformer2DModel.from_pretrained(
                input_dir, subfolder="transformer")
            model.register_to_config(**load_model.config)
            model.load_state_dict(load_model.state_dict())
            del load_model

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
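    # With these hooks, accelerator.save_state() writes the transformer in diffusers
    # format under <checkpoint>/transformer (and pops the raw weights so they are not
    # serialized twice), and load_state() restores it from the same layout.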
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = transformer.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    transformer.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        model_ema.to(accelerator.device, dtype=weight_dtype)

    train_dataloader = build_dataloader(args, args.data_txt_root, tokenizer, text_processor, args.resolution)
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.work_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.work_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        try:
            accelerator.init_trackers('sft', tracker_config)
        except Exception as e:
            logger.warning(f'got an error while logging the tracker config: {e}')
            accelerator.init_trackers(f"sft_{timestamp}")

    transformer, optimizer, _, _ = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler)

    log_buffer = LogBuffer()
    train(global_step=global_step)
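
# Example launch (a sketch; the script and config file names are assumptions, and
# values such as work_dir, data_txt_root, resolution and max_train_steps are expected
# to come from the YAML config rather than the CLI):
#   accelerate launch train.py --config train_config.yaml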