# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import os

import torch
from torch import Tensor
import torch.nn.functional as F

from diffusers.pipelines import FluxPipeline
from diffusers.utils import logging
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.pipelines.flux.pipeline_flux import FluxLoraLoaderMixin
from diffusers.models.transformers.transformer_flux import (
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
    logger,
)
from torchvision.transforms import ToPILImage
from peft.tuners.tuners_utils import BaseTunerLayer
# from optimum.quanto import (
#     freeze, quantize, QTensor, qfloat8, qint8, qint4, qint2,
# )
import re
import safetensors
import safetensors.torch  # load_dit_lora below uses safetensors.torch.load_file

from src.adapters.mod_adapters import CLIPModAdapter
from peft import LoraConfig, set_peft_model_state_dict
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPVisionModel


def encode_vae_images(pipeline: FluxPipeline, images: Tensor):
    images = pipeline.image_processor.preprocess(images)
    images = images.to(pipeline.device).to(pipeline.dtype)
    images = pipeline.vae.encode(images).latent_dist.sample()
    images = (
        images - pipeline.vae.config.shift_factor
    ) * pipeline.vae.config.scaling_factor
    images_tokens = pipeline._pack_latents(images, *images.shape)
    images_ids = pipeline._prepare_latent_image_ids(
        images.shape[0],
        images.shape[2],
        images.shape[3],
        pipeline.device,
        pipeline.dtype,
    )
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = pipeline._prepare_latent_image_ids(
            images.shape[0],
            images.shape[2] // 2,
            images.shape[3] // 2,
            pipeline.device,
            pipeline.dtype,
        )
    return images_tokens, images_ids


def decode_vae_images(pipeline: FluxPipeline, latents: Tensor, height, width, output_type: Optional[str] = "pil"):
    latents = pipeline._unpack_latents(latents, height, width, pipeline.vae_scale_factor)
    latents = (latents / pipeline.vae.config.scaling_factor) + pipeline.vae.config.shift_factor
    image = pipeline.vae.decode(latents, return_dict=False)[0]
    return pipeline.image_processor.postprocess(image, output_type=output_type)
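
# Usage sketch (illustrative; `pipe` is a loaded FluxPipeline and `imgs` is a hypothetical
# (B, 3, H, W) image batch accepted by `pipe.image_processor.preprocess`):
#
#     tokens, ids = encode_vae_images(pipe, imgs)          # packed latent tokens + latent image ids
#     pil_images = decode_vae_images(pipe, tokens, H, W)   # unpack, VAE-decode, and postprocess to PIL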


def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
):
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use the pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds


def encode_prompt_with_clip_t5(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
):
    r"""
    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in all text-encoders.
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from the `prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from the `prompt` input argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # We only use the pooled prompt output from the CLIPTextModel
        pooled_prompt_embeds = _get_clip_prompt_embeds(
            self=self,
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        if self.text_encoder_2 is not None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

    if self.text_encoder is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
    if self.text_encoder_2 is not None:
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    else:
        text_ids = None

    return prompt_embeds, pooled_prompt_embeds, text_ids


def prepare_text_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
):
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        t5_prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = encode_prompt_with_clip_t5(
        self=pipeline,
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return t5_prompt_embeds, pooled_prompt_embeds, text_ids
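
# Usage sketch (illustrative; `pipe` is a loaded FluxPipeline and the prompt is hypothetical):
#
#     t5_embeds, pooled_embeds, text_ids = prepare_text_input(pipe, ["a photo of a corgi"])
#     # t5_embeds is the per-token T5 text stream; pooled_embeds is the pooled CLIP vector
#     # that the FLUX transformer consumes as `pooled_projections`.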


def prepare_t5_input(
    pipeline: FluxPipeline,
    prompts,
    max_sequence_length=512,
):
    # Note: currently identical to prepare_text_input.
    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    (
        t5_prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = encode_prompt_with_clip_t5(
        self=pipeline,
        prompt=prompts,
        prompt_2=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        device=pipeline.device,
        num_images_per_prompt=1,
        max_sequence_length=max_sequence_length,
        lora_scale=None,
    )
    # Turn on warnings
    logger.setLevel(logging.WARNING)
    return t5_prompt_embeds, pooled_prompt_embeds, text_ids


def tokenize_t5_prompt(pipe, input_prompt, max_length, **kargs):
    return pipe.tokenizer_2(
        input_prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
        **kargs,
    )


def clear_attn_maps(transformer):
    for i, block in enumerate(transformer.transformer_blocks):
        if hasattr(block.attn, "attn_maps"):
            del block.attn.attn_maps
            del block.attn.timestep
    for i, block in enumerate(transformer.single_transformer_blocks):
        if hasattr(block.attn, "cond2latents"):
            del block.attn.cond2latents


def gather_attn_maps(transformer, clear=False):
    t2i_attn_maps = {}
    i2t_attn_maps = {}
    for i, block in enumerate(transformer.transformer_blocks):
        name = f"block_{i}"
        if hasattr(block.attn, "attn_maps"):
            attention_maps = block.attn.attn_maps
            timesteps = block.attn.timestep  # (B,)
            for (timestep, (t2i_attn_map, i2t_attn_map)) in zip(timesteps, attention_maps):
                timestep = str(timestep.item())

                t2i_attn_maps[timestep] = t2i_attn_maps.get(timestep, dict())
                t2i_attn_maps[timestep][name] = t2i_attn_maps[timestep].get(name, [])
                t2i_attn_maps[timestep][name].append(t2i_attn_map.cpu())

                i2t_attn_maps[timestep] = i2t_attn_maps.get(timestep, dict())
                i2t_attn_maps[timestep][name] = i2t_attn_maps[timestep].get(name, [])
                i2t_attn_maps[timestep][name].append(i2t_attn_map.cpu())

            if clear:
                del block.attn.attn_maps

    for timestep in t2i_attn_maps:
        for name in t2i_attn_maps[timestep]:
            t2i_attn_maps[timestep][name] = torch.cat(t2i_attn_maps[timestep][name], dim=0)
            i2t_attn_maps[timestep][name] = torch.cat(i2t_attn_maps[timestep][name], dim=0)

    return t2i_attn_maps, i2t_attn_maps


def process_token(token, startofword):
    if '</w>' in token:
        token = token.replace('</w>', '')
        if startofword:
            token = '<' + token + '>'
        else:
            token = '-' + token + '>'
            startofword = True
    elif token not in ['<|startoftext|>', '<|endoftext|>']:
        if startofword:
            token = '<' + token + '-'
            startofword = False
        else:
            token = '-' + token + '-'
    return token, startofword
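
# Illustration of the marking scheme above, using made-up CLIP BPE tokens:
#   "photo</w>" starting a word -> "<photo>"
#   "pho" starting a word       -> "<pho-"   (startofword becomes False)
#   "to</w>" continuing it      -> "-to>"    (startofword becomes True again)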


def save_attention_image(attn_map, tokens, batch_dir, to_pil):
    startofword = True
    for i, (token, a) in enumerate(zip(tokens, attn_map[:len(tokens)])):
        token, startofword = process_token(token, startofword)
        token = token.replace("/", "-")
        if token == '-<pad>-':
            continue
        a = a.to(torch.float32)
        a = a / a.max() * 255 / 256
        to_pil(a).save(os.path.join(batch_dir, f'{i}-{token}.png'))


def save_attention_maps(attn_maps, pipe, prompts, base_dir='attn_maps'):
    to_pil = ToPILImage()

    token_ids = tokenize_t5_prompt(pipe, prompts, 512).input_ids  # (B, 512)
    token_ids = [x for x in token_ids]
    total_tokens = [pipe.tokenizer_2.convert_ids_to_tokens(token_id) for token_id in token_ids]

    os.makedirs(base_dir, exist_ok=True)

    total_attn_map_shape = (256, 256)
    total_attn_map_number = 0

    # (B, 24, H, W, 512) -> (B, H, W, 512) -> (B, 512, H, W)
    print(attn_maps.keys())
    total_attn_map = list(list(attn_maps.values())[0].values())[0].sum(1)
    total_attn_map = total_attn_map.permute(0, 3, 1, 2)
    total_attn_map = torch.zeros_like(total_attn_map)
    total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)

    for timestep, layers in attn_maps.items():
        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            layer_dir = os.path.join(timestep_dir, f'{layer}')
            os.makedirs(layer_dir, exist_ok=True)

            attn_map = attn_map.sum(1).squeeze(1).permute(0, 3, 1, 2)

            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            total_attn_map += resized_attn_map
            total_attn_map_number += 1

            for batch, (attn_map, tokens) in enumerate(zip(resized_attn_map, total_tokens)):
                save_attention_image(attn_map, tokens, layer_dir, to_pil)

            # for batch, (tokens, attn) in enumerate(zip(total_tokens, attn_map)):
            #     batch_dir = os.path.join(layer_dir, f'batch-{batch}')
            #     os.makedirs(batch_dir, exist_ok=True)
            #     save_attention_image(attn, tokens, batch_dir, to_pil)

    total_attn_map /= total_attn_map_number
    for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        os.makedirs(batch_dir, exist_ok=True)
        save_attention_image(attn_map, tokens, batch_dir, to_pil)
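
# Usage sketch (illustrative; assumes the attention processors populated `attn_maps` on the
# double-stream blocks during a previous denoising run, and `pipe`/`prompts` are hypothetical):
#
#     t2i_maps, i2t_maps = gather_attn_maps(pipe.transformer, clear=True)
#     save_attention_maps(t2i_maps, pipe, prompts, base_dir="attn_maps/t2i")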


def gather_cond2latents(transformer, clear=False):
    c2l_attn_maps = {}
    # for i, block in enumerate(transformer.transformer_blocks):
    for i, block in enumerate(transformer.single_transformer_blocks):
        name = f"block_{i}"
        if hasattr(block.attn, "cond2latents"):
            attention_maps = block.attn.cond2latents
            timesteps = block.attn.cond_timesteps  # (B,)
            for (timestep, c2l_attn_map) in zip(timesteps, attention_maps):
                timestep = str(timestep.item())

                c2l_attn_maps[timestep] = c2l_attn_maps.get(timestep, dict())
                c2l_attn_maps[timestep][name] = c2l_attn_maps[timestep].get(name, [])
                c2l_attn_maps[timestep][name].append(c2l_attn_map.cpu())

            if clear:
                # del block.attn.attn_maps
                del block.attn.cond2latents
                del block.attn.cond_timesteps

    for timestep in c2l_attn_maps:
        for name in c2l_attn_maps[timestep]:
            c2l_attn_maps[timestep][name] = torch.cat(c2l_attn_maps[timestep][name], dim=0)

    return c2l_attn_maps


def save_cond2latent_image(attn_map, batch_dir, to_pil):
    for i, a in enumerate(attn_map):  # (N, H, W)
        a = a.to(torch.float32)
        a = a / a.max() * 255 / 256
        to_pil(a).save(os.path.join(batch_dir, f'{i}.png'))


def save_cond2latent(attn_maps, base_dir='attn_maps'):
    to_pil = ToPILImage()

    os.makedirs(base_dir, exist_ok=True)

    total_attn_map_shape = (256, 256)
    total_attn_map_number = 0

    # (N, H, W) -> (1, N, H, W)
    total_attn_map = list(list(attn_maps.values())[0].values())[0].unsqueeze(0)
    total_attn_map = torch.zeros_like(total_attn_map)
    total_attn_map = F.interpolate(total_attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)

    for timestep, layers in attn_maps.items():
        cur_ts_attn_map = torch.zeros_like(total_attn_map)
        cur_ts_attn_map_number = 0

        timestep_dir = os.path.join(base_dir, f'{timestep}')
        os.makedirs(timestep_dir, exist_ok=True)

        for layer, attn_map in layers.items():
            # layer_dir = os.path.join(timestep_dir, f'{layer}')
            # os.makedirs(layer_dir, exist_ok=True)

            attn_map = attn_map.unsqueeze(0)  # (1, N, H, W)
            resized_attn_map = F.interpolate(attn_map, size=total_attn_map_shape, mode='bilinear', align_corners=False)
            cur_ts_attn_map += resized_attn_map
            cur_ts_attn_map_number += 1

        for batch, attn_map in enumerate(cur_ts_attn_map / cur_ts_attn_map_number):
            save_cond2latent_image(attn_map, timestep_dir, to_pil)

        total_attn_map += cur_ts_attn_map
        total_attn_map_number += cur_ts_attn_map_number

    total_attn_map /= total_attn_map_number
    for batch, attn_map in enumerate(total_attn_map):
        batch_dir = os.path.join(base_dir, f'batch-{batch}')
        os.makedirs(batch_dir, exist_ok=True)
        save_cond2latent_image(attn_map, batch_dir, to_pil)
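
# Usage sketch (illustrative; assumes `cond2latents` were recorded on the single-stream blocks
# during a previous denoising run):
#
#     c2l_maps = gather_cond2latents(pipe.transformer, clear=True)
#     save_cond2latent(c2l_maps, base_dir="attn_maps/cond2latent")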


def quantization(pipe, qtype):
    if qtype != "None" and qtype != "":
        if qtype.endswith("quanto"):
            # Requires optimum.quanto (freeze, quantize, qfloat8, qint8, qint4, qint2);
            # the corresponding import at the top of this file is currently commented out.
            if qtype == "int2-quanto":
                quant_level = qint2
            elif qtype == "int4-quanto":
                quant_level = qint4
            elif qtype == "int8-quanto":
                quant_level = qint8
            elif qtype == "fp8-quanto":
                quant_level = qfloat8
            else:
                raise ValueError(f"Invalid quantisation level: {qtype}")

            extra_quanto_args = {}
            extra_quanto_args["exclude"] = [
                "*.norm",
                "*.norm1",
                "*.norm2",
                "*.norm2_context",
                "proj_out",
                "x_embedder",
                "norm_out",
                "context_embedder",
            ]
            try:
                quantize(pipe.transformer, weights=quant_level, **extra_quanto_args)
                quantize(pipe.text_encoder_2, weights=quant_level, **extra_quanto_args)
                print("[Quantization] Start freezing")
                freeze(pipe.transformer)
                freeze(pipe.text_encoder_2)
                print("[Quantization] Finished")
            except Exception as e:
                if "out of memory" in str(e).lower():
                    print(
                        "GPU ran out of memory during quantisation. Use --quantize_via=cpu to use the slower CPU method."
                    )
                raise e
        else:
            assert qtype == "fp8-ao"
            from torchao.float8 import convert_to_float8_training, Float8LinearConfig

            def module_filter_fn(mod: torch.nn.Module, fqn: str):
                # don't convert the output module
                if fqn == "proj_out":
                    return False
                # don't convert linear modules with weight dimensions not divisible by 16
                if isinstance(mod, torch.nn.Linear):
                    if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                        return False
                return True

            convert_to_float8_training(
                pipe.transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
            )
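
# Usage sketch (illustrative; "*-quanto" options require re-enabling the optimum.quanto import
# above, while "fp8-ao" requires torchao):
#
#     quantization(pipe, "int8-quanto")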


class CustomFluxPipeline:
    def __init__(
        self,
        config,
        device="cuda",
        ckpt_root=None,
        ckpt_root_condition=None,
        torch_dtype=torch.bfloat16,
    ):
        print("[CustomFluxPipeline] Loading FLUX Pipeline")
        if config["model"].get("dit_quant", "None") == "int8-quanto":
            self.pipe = FluxPipeline.from_pretrained(
                "diffusers/FLUX.1-dev-torchao-int8",
                torch_dtype=torch_dtype,
                use_safetensors=False,
            ).to(device)
        else:
            self.pipe = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch_dtype,
            ).to(device)
        # self.pipe.enable_sequential_cpu_offload()
        self.config = config
        self.device = device
        self.dtype = torch_dtype

        # if config["model"].get("dit_quant", "None") != "None":
        #     quantization(self.pipe, config["model"]["dit_quant"])

        self.modulation_adapters = []
        self.pipe.modulation_adapters = []

        try:
            if config["model"]["modulation"]["use_clip"]:
                load_clip(self, config, torch_dtype, device, None, is_training=False)
        except Exception as e:
            print(e)

        if config["model"]["use_dit_lora"] or config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]:
            if ckpt_root_condition is None and (config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]):
                ckpt_root_condition = ckpt_root
            load_dit_lora(self, self.pipe, config, torch_dtype, device, f"{ckpt_root}", f"{ckpt_root_condition}", is_training=False)

    def add_modulation_adapter(self, modulation_adapter):
        self.modulation_adapters.append(modulation_adapter)
        self.pipe.modulation_adapters.append(modulation_adapter)

    def clear_modulation_adapters(self):
        self.modulation_adapters = []
        self.pipe.modulation_adapters = []
        torch.cuda.empty_cache()
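
# Usage sketch (illustrative; a minimal hand-written config, whereas real configs are loaded
# elsewhere in the project):
#
#     config = {"model": {"dit_quant": "None",
#                         "modulation": {"use_clip": True},
#                         "use_dit_lora": False,
#                         "use_condition_dblock_lora": False,
#                         "use_condition_sblock_lora": False}}
#     model = CustomFluxPipeline(config, device="cuda", torch_dtype=torch.bfloat16)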


def load_clip(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
    model_path = os.getenv("CLIP_MODEL_PATH", "openai/clip-vit-large-patch14")
    clip_model = CLIPVisionModelWithProjection.from_pretrained(model_path).to(device, dtype=torch_dtype)
    clip_processor = CLIPProcessor.from_pretrained(model_path)
    self.pipe.clip_model = clip_model
    self.pipe.clip_processor = clip_processor


def load_dit_lora(self, pipe, config, torch_dtype, device, ckpt_dir=None, condition_ckpt_dir=None, is_training=False):
    if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"] and not config["model"]["use_dit_lora"]:
        print("[load_dit_lora] no dit lora, no condition lora")
        return []

    adapter_names = ["default", "condition"]

    if condition_ckpt_dir is None:
        condition_ckpt_dir = ckpt_dir

    if not config["model"]["use_condition_dblock_lora"] and not config["model"]["use_condition_sblock_lora"]:
        print("[load_dit_lora] no condition lora")
        adapter_names.pop(1)
    elif condition_ckpt_dir is not None and os.path.exists(os.path.join(condition_ckpt_dir, "pytorch_lora_weights_condition.safetensors")):
        assert "condition" in adapter_names
        print(f"[load_dit_lora] load condition lora from {condition_ckpt_dir}")
        pipe.transformer.load_lora_adapter(condition_ckpt_dir, use_safetensors=True, adapter_name="condition", weight_name="pytorch_lora_weights_condition.safetensors")  # TODO: check if they are trainable
    else:
        assert is_training
        assert "condition" in adapter_names
        print("[load_dit_lora] init new condition lora")
        pipe.transformer.add_adapter(LoraConfig(**config["model"]["condition_lora_config"]), adapter_name="condition")

    if not config["model"]["use_dit_lora"]:
        print("[load_dit_lora] no dit lora")
        adapter_names.pop(0)
    elif ckpt_dir is not None and os.path.exists(os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")):
        assert "default" in adapter_names
        print(f"[load_dit_lora] load dit lora from {ckpt_dir}")
        lora_file = os.path.join(ckpt_dir, "pytorch_lora_weights.safetensors")
        lora_state_dict = safetensors.torch.load_file(lora_file, device="cpu")

        single_lora_pattern = "(.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"
        latent_lora_pattern = "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2)"
        use_pretrained_dit_single_lora = config["model"].get("use_pretrained_dit_single_lora", True)
        use_pretrained_dit_latent_lora = config["model"].get("use_pretrained_dit_latent_lora", True)
        if not use_pretrained_dit_single_lora or not use_pretrained_dit_latent_lora:
            lora_state_dict_keys = list(lora_state_dict.keys())
            for layer_name in lora_state_dict_keys:
                if not use_pretrained_dit_single_lora:
                    if re.search(single_lora_pattern, layer_name):
                        del lora_state_dict[layer_name]
                if not use_pretrained_dit_latent_lora:
                    if re.search(latent_lora_pattern, layer_name):
                        del lora_state_dict[layer_name]
            pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")
            set_peft_model_state_dict(pipe.transformer, lora_state_dict, adapter_name="default")
        else:
            pipe.transformer.load_lora_adapter(ckpt_dir, use_safetensors=True, adapter_name="default", weight_name="pytorch_lora_weights.safetensors")  # TODO: check if they are trainable
    else:
        assert is_training
        assert "default" in adapter_names
        print("[load_dit_lora] init new dit lora")
        pipe.transformer.add_adapter(LoraConfig(**config["model"]["dit_lora_config"]), adapter_name="default")

    assert len(adapter_names) <= 2 and len(adapter_names) > 0
    for name, module in pipe.transformer.named_modules():
        if isinstance(module, BaseTunerLayer):
            module.set_adapter(adapter_names)

    if "default" in adapter_names:
        assert config["model"]["use_dit_lora"]
    if "condition" in adapter_names:
        assert config["model"]["use_condition_dblock_lora"] or config["model"]["use_condition_sblock_lora"]

    lora_layers = list(filter(
        lambda p: p[1].requires_grad, pipe.transformer.named_parameters()
    ))

    lora_layers = [l[1] for l in lora_layers]
    return lora_layers


def load_modulation_adapter(self, config, torch_dtype, device, ckpt_dir=None, is_training=False):
    adapter_type = config["model"]["modulation"]["adapter_type"]

    if ckpt_dir is not None and os.path.exists(ckpt_dir):
        print(f"loading modulation adapter from {ckpt_dir}")
        modulation_adapter = CLIPModAdapter.from_pretrained(
            ckpt_dir, subfolder="modulation_adapter", strict=False,
            low_cpu_mem_usage=False, device_map=None,
        ).to(device)
    else:
        print("Init new modulation adapter")
        adapter_layers = config["model"]["modulation"]["adapter_layers"]
        adapter_width = config["model"]["modulation"]["adapter_width"]
        pblock_adapter_layers = config["model"]["modulation"]["per_block_adapter_layers"]
        pblock_adapter_width = config["model"]["modulation"]["per_block_adapter_width"]
        pblock_adapter_single_blocks = config["model"]["modulation"]["per_block_adapter_single_blocks"]
        use_text_mod = config["model"]["modulation"]["use_text_mod"]
        use_img_mod = config["model"]["modulation"]["use_img_mod"]
        out_dim = config["model"]["modulation"]["out_dim"]

        if adapter_type == "clip_adapter":
            modulation_adapter = CLIPModAdapter(
                out_dim=out_dim,
                width=adapter_width,
                pblock_width=pblock_adapter_width,
                layers=adapter_layers,
                pblock_layers=pblock_adapter_layers,
                heads=8,
                input_text_dim=4096,
                input_image_dim=1024,
                pblock_single_blocks=pblock_adapter_single_blocks,
            )
        else:
            raise NotImplementedError()

    if is_training:
        modulation_adapter.train()
        try:
            modulation_adapter.enable_gradient_checkpointing()
        except Exception as e:
            print(e)
        if not config["model"]["modulation"]["use_perblock_adapter"]:
            try:
                modulation_adapter.net2.requires_grad_(False)
            except Exception as e:
                print(e)
    else:
        modulation_adapter.requires_grad_(False)

    modulation_adapter.to(device, dtype=torch_dtype)
    return modulation_adapter
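
# Usage sketch (illustrative; `model` is a CustomFluxPipeline and the checkpoint path is hypothetical;
# the returned adapter is typically registered via add_modulation_adapter):
#
#     adapter = load_modulation_adapter(model, model.config, torch.bfloat16, "cuda",
#                                       ckpt_dir="ckpt/modulation_adapter")
#     model.add_modulation_adapter(adapter)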


def load_ckpt(self, ckpt_dir, is_training=False):
    if self.config["model"]["use_dit_lora"]:
        self.pipe.transformer.delete_adapters(["subject"])
        lora_path = f"{ckpt_dir}/pytorch_lora_weights.safetensors"
        print(f"Loading DIT Lora from {lora_path}")
        self.pipe.load_lora_weights(lora_path, adapter_name="subject")