""" Colorize model wrapper replicating the behaviour of the `fffiloni/text-guided-image-colorization` Space. """ from __future__ import annotations import logging import os from typing import Tuple import torch from PIL import Image from diffusers import ( AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel, ) from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file from transformers import BlipForConditionalGeneration, BlipProcessor from app.config import settings logger = logging.getLogger(__name__) def _ensure_cache_dir() -> str: """Ensure we have a writable Hugging Face cache directory.""" data_dir = os.getenv("DATA_DIR") candidate_dirs = [] if data_dir: candidate_dirs.append(os.path.join(data_dir, "hf_cache")) candidate_dirs.extend( [ os.path.join("/tmp", "hf_cache"), os.path.join(os.path.expanduser("~"), ".cache", "huggingface"), ] ) for path in candidate_dirs: try: os.makedirs(path, exist_ok=True) logger.info("Using HF cache directory: %s", path) os.environ["HF_HOME"] = path os.environ["HUGGINGFACE_HUB_CACHE"] = path os.environ["TRANSFORMERS_CACHE"] = path return path except Exception as exc: # pragma: no cover - best effort logger.warning("Failed to create cache dir %s: %s", path, exc) raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.") def _apply_color(luminance_image: Image.Image, color_map: Image.Image) -> Image.Image: """Merge the L channel of the grayscale control image with AB channels from generated image.""" image_lab = luminance_image.convert("LAB") color_map_lab = color_map.convert("LAB") l_channel, _, _ = image_lab.split() _, a_channel, b_channel = color_map_lab.split() merged = Image.merge("LAB", (l_channel, a_channel, b_channel)) return merged.convert("RGB") def _remove_unlikely_words(prompt: str) -> str: """Clean up BLIP captions to avoid misleading descriptors.""" unlikely_words = [] decades = [f"{i}s" for i in range(1900, 2000)] years = [f"{i}" for i in range(1900, 2000)] years_with_word = [f"year {i}" for i in range(1900, 2000)] circa_years = [f"circa {i}" for i in range(1900, 2000)] expanded = [ [f"{d[0]} {d[1]} {d[2]} {d[3]} s" for d in decades], [f"{d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], [f"year {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], [f"circa {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], ] manual_terms = [ "black and white,", "black and white", "black & white,", "black & white", "circa", "monochrome,", "monochrome", "bw", "bw,", "b&w", "b&w,", "grainy", "grainy photo", "grainy photograph", "grainy footage", "black-and-white", "black - and - white", "black on white", "historical photo", "historic photo", "restored", "desaturated", "low contrast", "blurry", "overcast", "taken in", "photo taken in", ", photo", ", photo", ", photo", ", photograph", ] for seq in expanded: unlikely_words.extend(seq) unlikely_words.extend(decades + years + years_with_word + circa_years + manual_terms) cleaned = prompt for word in unlikely_words: cleaned = cleaned.replace(word, "") return cleaned.strip(" ,") class ColorizeModel: """Colorization model wrapper.""" CONTROLNET_REPO = "nickpai/sdxl_light_caption_output" CONTROLNET_SUBDIR = os.path.join("checkpoint-30000", "controlnet") BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" LIGHTNING_REPO = "ByteDance/SDXL-Lightning" LIGHTNING_WEIGHTS = "sdxl_lightning_8step_unet.safetensors" CAPTION_MODEL = "Salesforce/blip-image-captioning-large" def __init__(self, model_id: str | None = None) -> 
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        self.hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or None
        )
        self.cache_dir = _ensure_cache_dir()

        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self.model_id = model_id or settings.MODEL_ID

        self._load_pipeline()
        self._load_caption_model()
        self.last_caption: str | None = None

    # --------------------------------------------------------------------- #
    # Initialisation helpers
    # --------------------------------------------------------------------- #
    def _download_controlnet(self) -> str:
        logger.info("Downloading ControlNet snapshot: %s", self.CONTROLNET_REPO)
        local_dir = os.path.join(self.cache_dir, "sdxl_light_caption_output")
        path = snapshot_download(
            repo_id=self.CONTROLNET_REPO,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=self.hf_token,
        )
        controlnet_path = os.path.join(path, self.CONTROLNET_SUBDIR)
        if not os.path.isdir(controlnet_path):
            raise RuntimeError(f"ControlNet weights not found at {controlnet_path}")
        return controlnet_path

    def _load_pipeline(self) -> None:
        controlnet_path = self._download_controlnet()
        logger.info("Loading SDXL components...")

        vae = AutoencoderKL.from_pretrained(
            self.BASE_MODEL,
            subfolder="vae",
            torch_dtype=self.dtype,
            token=self.hf_token,
        )

        # SDXL-Lightning ships UNet weights only: build a UNet from the base model's
        # config, then load the distilled 8-step weights on top of it.
        unet = UNet2DConditionModel.from_config(
            self.BASE_MODEL,
            subfolder="unet",
            token=self.hf_token,
        )
        lightning_path = hf_hub_download(
            repo_id=self.LIGHTNING_REPO,
            filename=self.LIGHTNING_WEIGHTS,
            token=self.hf_token,
        )
        unet.load_state_dict(load_file(lightning_path))

        controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=self.dtype)

        try:
            self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                self.BASE_MODEL,
                vae=vae,
                unet=unet,
                controlnet=controlnet,
                torch_dtype=self.dtype,
                token=self.hf_token,
            )
        except Exception as exc:
            logger.error("Failed to load base SDXL model: %s", exc)
            logger.error(
                "Ensure the account associated with HUGGINGFACE_HUB_TOKEN has accepted "
                "the license for %s and that the token has access.",
                self.BASE_MODEL,
            )
            raise

        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)
        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:  # pragma: no cover
                logger.warning("Could not enable xformers attention: %s", exc)

        logger.info("Colorization pipeline ready.")

    def _load_caption_model(self) -> None:
        logger.info("Loading BLIP captioning model...")
        processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL, token=self.hf_token)
        model = BlipForConditionalGeneration.from_pretrained(
            self.CAPTION_MODEL,
            torch_dtype=self.dtype,
            token=self.hf_token,
        )
        self.caption_processor = processor
        self.caption_model = model.to(self.device)

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def caption_image(self, image: Image.Image) -> str:
        """Generate a cleaned caption for the image."""
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)
        if self.device.type == "cuda":
            # The BLIP weights are loaded in float16 on GPU; cast the pixel values to
            # match while leaving the integer token ids untouched.
            inputs["pixel_values"] = inputs["pixel_values"].to(self.dtype)
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        cleaned_caption = _remove_unlikely_words(caption)
        return cleaned_caption or caption

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Colorize a grayscale image and return it together with the caption used as prompt."""
        try:
            original_size = image.size
            control_image = image.convert("L").convert("RGB").resize(
                (512, 512), Image.Resampling.LANCZOS
            )
            caption = self.caption_image(image)
            self.last_caption = caption

            prompt_parts = [caption]
            if self.positive_prompt:
                prompt_parts.insert(0, self.positive_prompt)
            final_prompt = ", ".join(part for part in prompt_parts if part)
            negative_prompt = self.negative_prompt or None
            steps = num_inference_steps or self.num_inference_steps
            generator = torch.Generator(device=self.device).manual_seed(self.seed)

            logger.info("Running SDXL pipeline with prompt: %s", final_prompt)
            result = self.pipe(
                prompt=final_prompt,
                negative_prompt=negative_prompt,
                image=control_image,
                num_inference_steps=steps,
                guidance_scale=self.guidance_scale,
                controlnet_conditioning_scale=self.controlnet_scale,
                generator=generator,
            )
            generated_image = result.images[0]

            # Keep the luminance of the input and take only the chrominance from the
            # generated image, then restore the original resolution.
            colorized = _apply_color(control_image, generated_image)
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            return colorized, caption
        except Exception as exc:
            logger.exception("Error during colorization: %s", exc)
            raise
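

# Minimal usage sketch (illustrative only, not part of the service): it loads the wrapper
# and colorizes a single local file. The input/output paths below are hypothetical
# placeholders; the real application constructs ColorizeModel from app.config settings
# and calls colorize() on PIL images it already holds in memory.
if __name__ == "__main__":
    import sys

    input_path = sys.argv[1] if len(sys.argv) > 1 else "input_gray.jpg"  # hypothetical path
    output_path = sys.argv[2] if len(sys.argv) > 2 else "output_color.png"  # hypothetical path

    model = ColorizeModel()
    with Image.open(input_path) as img:
        colorized, caption = model.colorize(img.convert("RGB"))
    colorized.save(output_path)
    print(f"Caption used for colorization: {caption}")
    print(f"Saved colorized image to {output_path}")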