""" Colorize model wrapper replicating the behaviour of the `fffiloni/text-guided-image-colorization` Space. """ from __future__ import annotations import logging import os from typing import Tuple import torch from PIL import Image from diffusers import ( AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel, ) from huggingface_hub import hf_hub_download from safetensors.torch import load_file from transformers import BlipForConditionalGeneration, BlipProcessor from app.config import settings logger = logging.getLogger(__name__) def _ensure_cache_dir() -> str: cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache" try: os.makedirs(cache_dir, exist_ok=True) except Exception as exc: # pragma: no cover logger.warning("Could not create cache directory %s: %s", cache_dir, exc) os.environ["HF_HOME"] = cache_dir os.environ["TRANSFORMERS_CACHE"] = cache_dir os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir os.environ["HF_HUB_CACHE"] = cache_dir return cache_dir def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image: base_lab = original_luminance.convert("LAB") color_lab = color_map.convert("LAB") l_channel, _, _ = base_lab.split() _, a_channel, b_channel = color_lab.split() merged = Image.merge("LAB", (l_channel, a_channel, b_channel)) return merged.convert("RGB") def _clean_caption(prompt: str) -> str: remove_terms = [ "black and white", "black & white", "monochrome", "bw photo", "historical", "restored", "low contrast", "desaturated", "overcast", ] cleaned = prompt for term in remove_terms: cleaned = cleaned.replace(term, "") return cleaned.strip(" ,") class ColorizeModel: """Colorization model that runs the SDXL + ControlNet pipeline locally.""" def __init__(self, model_id: str | None = None) -> None: self.cache_dir = _ensure_cache_dir() self.hf_token = ( os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HUGGINGFACE_API_TOKEN") ) if not self.hf_token: logger.warning("HF token not provided – attempting to download public models only.") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 os.environ.setdefault("OMP_NUM_THREADS", "1") self.controlnet_id = model_id or settings.MODEL_ID self.base_model_id = settings.BASE_MODEL_ID self.lightning_repo = settings.LIGHTNING_REPO self.lightning_weights = settings.LIGHTNING_WEIGHTS self.caption_model_id = settings.CAPTION_MODEL_ID self.num_inference_steps = settings.NUM_INFERENCE_STEPS self.guidance_scale = settings.GUIDANCE_SCALE self.controlnet_scale = settings.CONTROLNET_SCALE self.positive_prompt = settings.POSITIVE_PROMPT self.negative_prompt = settings.NEGATIVE_PROMPT self.caption_prefix = settings.CAPTION_PREFIX self.seed = settings.COLORIZE_SEED self._load_caption_model() self._load_pipeline() def _load_caption_model(self) -> None: logger.info("Loading BLIP captioning model: %s", self.caption_model_id) self.caption_processor = BlipProcessor.from_pretrained( self.caption_model_id, cache_dir=self.cache_dir, token=self.hf_token, ) self.caption_model = BlipForConditionalGeneration.from_pretrained( self.caption_model_id, cache_dir=self.cache_dir, token=self.hf_token, torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32, ).to(self.device) def _load_pipeline(self) -> None: logger.info("Loading ControlNet model: %s", self.controlnet_id) controlnet = ControlNetModel.from_pretrained( self.controlnet_id, torch_dtype=self.dtype, 
class ColorizeModel:
    """Colorization model that runs the SDXL + ControlNet pipeline locally."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        self.hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGINGFACE_API_TOKEN")
        )
        if not self.hf_token:
            logger.warning("HF token not provided; attempting to download public models only.")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        self.controlnet_id = model_id or settings.MODEL_ID
        self.base_model_id = settings.BASE_MODEL_ID
        self.lightning_repo = settings.LIGHTNING_REPO
        self.lightning_weights = settings.LIGHTNING_WEIGHTS
        self.caption_model_id = settings.CAPTION_MODEL_ID
        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED

        self._load_caption_model()
        self._load_pipeline()

    def _load_caption_model(self) -> None:
        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
        self.caption_processor = BlipProcessor.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        # ``self.dtype`` is already float32 on CPU, so no conditional is needed.
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
            torch_dtype=self.dtype,
        ).to(self.device)

    def _load_pipeline(self) -> None:
        logger.info("Loading ControlNet model: %s", self.controlnet_id)
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )

        logger.info("Loading SDXL base model components: %s", self.base_model_id)
        vae = AutoencoderKL.from_pretrained(
            self.base_model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        # Instantiate the UNet from its config only, then load the distilled
        # SDXL-Lightning weights on top. ``from_config`` does not accept a repo
        # id plus download kwargs, so the config is fetched explicitly first.
        unet_config = UNet2DConditionModel.load_config(
            self.base_model_id,
            subfolder="unet",
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet = UNet2DConditionModel.from_config(unet_config)
        lightning_path = hf_hub_download(
            repo_id=self.lightning_repo,
            filename=self.lightning_weights,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet.load_state_dict(load_file(lightning_path))

        # Note: SDXL pipelines have no safety checker, so no
        # ``safety_checker``/``requires_safety_checker`` kwargs are passed.
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model_id,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        # SDXL-Lightning checkpoints are distilled for Euler sampling with
        # "trailing" timestep spacing; the default scheduler degrades badly at
        # the low step counts these weights target.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)
        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:  # pragma: no cover
                logger.warning("Could not enable xFormers optimizations: %s", exc)
        logger.info("Colorization pipeline ready.")

    def caption_image(self, image: Image.Image) -> str:
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)
        if self.device.type == "cuda":
            # The processor emits fp32 pixel values; cast them to match the
            # fp16 model weights (token ids must stay integral).
            inputs["pixel_values"] = inputs["pixel_values"].to(self.dtype)
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        original_size = image.size
        control_image = image.convert("L").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
        caption = self.caption_image(image)
        prompt_components = [self.positive_prompt, caption]
        prompt = ", ".join([p for p in prompt_components if p])
        steps = num_inference_steps or self.num_inference_steps
        generator = torch.Generator(device=self.device).manual_seed(self.seed)
        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
        # The text-to-image ControlNet pipeline takes the conditioning image as
        # ``image``; there is no separate ``control_image`` argument.
        result = self.pipe(
            prompt=prompt,
            negative_prompt=self.negative_prompt or None,
            image=control_image,
            num_inference_steps=steps,
            guidance_scale=self.guidance_scale,
            controlnet_conditioning_scale=self.controlnet_scale,
            generator=generator,
        )
        generated = result.images[0]
        # Recombine at the source resolution, as the Space does: upscale only
        # the generated chrominance and keep the original luminance detail.
        if generated.size != original_size:
            generated = generated.resize(original_size, Image.Resampling.LANCZOS)
        colorized = _apply_lab_merge(image.convert("RGB"), generated)
        return colorized, caption
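

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service. Assumes the values in
    # ``app.config.settings`` point at reachable model repos and that an
    # ``input.jpg`` exists in the working directory (the filename is
    # hypothetical; substitute your own image).
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()
    source = Image.open("input.jpg")
    colorized_image, derived_caption = model.colorize(source)
    colorized_image.save("colorized.jpg")
    print(f"Caption used for colorization: {derived_caption}")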