|
|
""" |
|
|
Colorization model wrapper replicating the behaviour of the
|
|
`fffiloni/text-guided-image-colorization` Space. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from diffusers import ( |
|
|
AutoencoderKL, |
|
|
ControlNetModel, |
|
|
StableDiffusionXLControlNetPipeline, |
|
|
UNet2DConditionModel, |
|
|
) |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
from safetensors.torch import load_file |
|
|
from transformers import BlipForConditionalGeneration, BlipProcessor |
|
|
|
|
|
from app.config import settings |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _ensure_cache_dir() -> str: |
|
|
"""Ensure we have a writable Hugging Face cache directory.""" |
|
|
data_dir = os.getenv("DATA_DIR") |
|
|
candidate_dirs = [] |
|
|
if data_dir: |
|
|
candidate_dirs.append(os.path.join(data_dir, "hf_cache")) |
|
|
candidate_dirs.extend( |
|
|
[ |
|
|
os.path.join("/tmp", "hf_cache"), |
|
|
os.path.join(os.path.expanduser("~"), ".cache", "huggingface"), |
|
|
] |
|
|
) |
|
|
|
|
|
for path in candidate_dirs: |
|
|
try: |
|
|
os.makedirs(path, exist_ok=True) |
|
|
logger.info("Using HF cache directory: %s", path) |
|
|
os.environ["HF_HOME"] = path |
|
|
os.environ["HUGGINGFACE_HUB_CACHE"] = path |
|
|
os.environ["TRANSFORMERS_CACHE"] = path |
|
|
return path |
|
|
except Exception as exc: |
|
|
logger.warning("Failed to create cache dir %s: %s", path, exc) |
|
|
|
|
|
raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.") |
|
|
|
|
|
|
|
|
def _apply_color(luminance_image: Image.Image, color_map: Image.Image) -> Image.Image: |
|
|
"""Merge the L channel of the grayscale control image with AB channels from generated image.""" |
|
|
image_lab = luminance_image.convert("LAB") |
|
|
color_map_lab = color_map.convert("LAB") |
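    # Keep luminance (L) from the grayscale input and take only chroma (A/B) from the
    # generated image, so structural detail comes from the original photo rather than
    # from the diffusion output.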
|
|
l_channel, _, _ = image_lab.split() |
|
|
_, a_channel, b_channel = color_map_lab.split() |
|
|
merged = Image.merge("LAB", (l_channel, a_channel, b_channel)) |
|
|
return merged.convert("RGB") |
|
|
|
|
|
|
|
|
def _remove_unlikely_words(prompt: str) -> str: |
|
|
"""Clean up BLIP captions to avoid misleading descriptors.""" |
|
|
unlikely_words = [] |
|
|
|
|
|
decades = [f"{i}s" for i in range(1900, 2000)] |
|
|
years = [f"{i}" for i in range(1900, 2000)] |
|
|
years_with_word = [f"year {i}" for i in range(1900, 2000)] |
|
|
circa_years = [f"circa {i}" for i in range(1900, 2000)] |
|
|
|
|
|
expanded = [ |
|
|
[f"{d[0]} {d[1]} {d[2]} {d[3]} s" for d in decades], |
|
|
[f"{d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], |
|
|
[f"year {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], |
|
|
[f"circa {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades], |
|
|
] |
|
|
|
|
|
manual_terms = [ |
|
|
"black and white,", "black and white", "black & white,", "black & white", |
|
|
"circa", "monochrome,", "monochrome", "bw", "bw,", "b&w", "b&w,", |
|
|
"grainy", "grainy photo", "grainy photograph", "grainy footage", |
|
|
"black-and-white", "black - and - white", "black on white", |
|
|
"historical photo", "historic photo", "restored", "desaturated", |
|
|
"low contrast", "blurry", "overcast", "taken in", "photo taken in", |
|
|
", photo", ", photo", ", photo", ", photograph", |
|
|
] |
|
|
|
|
|
for seq in expanded: |
|
|
unlikely_words.extend(seq) |
|
|
unlikely_words.extend(decades + years + years_with_word + circa_years + manual_terms) |
|
|
|
|
|
cleaned = prompt |
|
|
for word in unlikely_words: |
|
|
cleaned = cleaned.replace(word, "") |
|
|
return cleaned.strip(" ,") |
|
|
|
|
|
|
|
|
class ColorizeModel: |
|
|
"""Colorization model wrapper.""" |
|
|
|
|
|
CONTROLNET_REPO = "nickpai/sdxl_light_caption_output" |
|
|
CONTROLNET_SUBDIR = os.path.join("checkpoint-30000", "controlnet") |
|
|
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" |
|
|
LIGHTNING_REPO = "ByteDance/SDXL-Lightning" |
|
|
LIGHTNING_WEIGHTS = "sdxl_lightning_8step_unet.safetensors" |
|
|
CAPTION_MODEL = "Salesforce/blip-image-captioning-large" |
|
|
|
|
|
def __init__(self, model_id: str | None = None) -> None: |
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
logger.info("Using device: %s", self.device) |
|
|
|
|
|
self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 |
|
|
os.environ.setdefault("OMP_NUM_THREADS", "1") |
|
|
|
|
|
self.hf_token = ( |
|
|
os.getenv("HF_TOKEN") |
|
|
or os.getenv("HUGGINGFACE_HUB_TOKEN") |
|
|
or None |
|
|
) |
|
|
self.cache_dir = _ensure_cache_dir() |
|
|
|
|
|
self.num_inference_steps = settings.NUM_INFERENCE_STEPS |
|
|
self.guidance_scale = settings.GUIDANCE_SCALE |
|
|
self.controlnet_scale = settings.CONTROLNET_SCALE |
|
|
self.positive_prompt = settings.POSITIVE_PROMPT |
|
|
self.negative_prompt = settings.NEGATIVE_PROMPT |
|
|
self.caption_prefix = settings.CAPTION_PREFIX |
|
|
self.seed = settings.COLORIZE_SEED |
|
|
|
|
|
self.model_id = model_id or settings.MODEL_ID |
|
|
|
|
|
self._load_pipeline() |
|
|
self._load_caption_model() |
|
|
self.last_caption: str | None = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_controlnet(self) -> str: |
|
|
logger.info("Downloading ControlNet snapshot: %s", self.CONTROLNET_REPO) |
|
|
local_dir = os.path.join(self.cache_dir, "sdxl_light_caption_output") |
|
|
path = snapshot_download( |
|
|
repo_id=self.CONTROLNET_REPO, |
|
|
local_dir=local_dir, |
|
|
local_dir_use_symlinks=False, |
|
|
token=self.hf_token, |
|
|
) |
|
|
controlnet_path = os.path.join(path, self.CONTROLNET_SUBDIR) |
|
|
if not os.path.isdir(controlnet_path): |
|
|
raise RuntimeError(f"ControlNet weights not found at {controlnet_path}") |
|
|
return controlnet_path |
|
|
|
|
|
def _load_pipeline(self) -> None: |
|
|
controlnet_path = self._download_controlnet() |
|
|
|
|
|
logger.info("Loading SDXL components...") |
|
|
vae = AutoencoderKL.from_pretrained( |
|
|
self.BASE_MODEL, |
|
|
subfolder="vae", |
|
|
torch_dtype=self.dtype, |
|
|
token=self.hf_token, |
|
|
) |
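        # Build the UNet from the base SDXL config and swap in the distilled
        # SDXL-Lightning 8-step weights, following the usage shown in the
        # ByteDance/SDXL-Lightning model card.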
|
|
unet = UNet2DConditionModel.from_config( |
|
|
self.BASE_MODEL, |
|
|
subfolder="unet", |
|
|
token=self.hf_token, |
|
|
) |
|
|
lightning_path = hf_hub_download( |
|
|
repo_id=self.LIGHTNING_REPO, |
|
|
filename=self.LIGHTNING_WEIGHTS, |
|
|
token=self.hf_token, |
|
|
) |
|
|
unet.load_state_dict(load_file(lightning_path)) |
|
|
|
|
|
controlnet = ControlNetModel.from_pretrained( |
|
|
controlnet_path, |
|
|
torch_dtype=self.dtype, |
|
|
) |
|
|
|
|
|
self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained( |
|
|
self.BASE_MODEL, |
|
|
vae=vae, |
|
|
unet=unet, |
|
|
controlnet=controlnet, |
|
|
torch_dtype=self.dtype, |
|
|
safety_checker=None, |
|
|
requires_safety_checker=False, |
|
|
token=self.hf_token, |
|
|
) |
|
|
self.pipe.set_progress_bar_config(disable=True) |
|
|
|
|
|
if self.device.type == "cuda": |
|
|
self.pipe.to(self.device, dtype=self.dtype) |
|
|
if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"): |
|
|
try: |
|
|
self.pipe.enable_xformers_memory_efficient_attention() |
|
|
except Exception as exc: |
|
|
logger.warning("Could not enable xformers attention: %s", exc) |
|
|
else: |
|
|
self.pipe.to(self.device, dtype=self.dtype) |
|
|
|
|
|
logger.info("Colorization pipeline ready.") |
|
|
|
|
|
def _load_caption_model(self) -> None: |
|
|
logger.info("Loading BLIP captioning model...") |
|
|
processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL, token=self.hf_token) |
|
|
model = BlipForConditionalGeneration.from_pretrained( |
|
|
self.CAPTION_MODEL, |
|
|
torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32, |
|
|
token=self.hf_token, |
|
|
) |
|
|
self.caption_processor = processor |
|
|
self.caption_model = model.to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def caption_image(self, image: Image.Image) -> str: |
|
|
"""Generate a cleaned caption for the image.""" |
|
|
inputs = self.caption_processor( |
|
|
image, |
|
|
self.caption_prefix, |
|
|
return_tensors="pt", |
|
|
).to(self.device) |
|
|
|
|
|
|
|
|
if self.device.type != "cuda": |
|
|
inputs = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} |
|
|
|
|
|
with torch.inference_mode(): |
|
|
caption_ids = self.caption_model.generate(**inputs) |
|
|
caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True) |
|
|
cleaned_caption = _remove_unlikely_words(caption) |
|
|
return cleaned_caption or caption |
|
|
|
|
|
def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]: |
|
|
"""Colorize a grayscale image.""" |
|
|
try: |
|
|
original_size = image.size |
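            # The control image is downscaled to 512x512 for inference (the resolution the
            # ControlNet appears to have been trained at) and the result is resized back to
            # the original dimensions afterwards.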
|
|
control_image = image.convert("L").convert("RGB").resize( |
|
|
(512, 512), Image.Resampling.LANCZOS |
|
|
) |
|
|
|
|
|
caption = self.caption_image(image) |
|
|
self.last_caption = caption |
|
|
|
|
|
prompt_parts = [caption] |
|
|
if self.positive_prompt: |
|
|
prompt_parts.insert(0, self.positive_prompt) |
|
|
final_prompt = ", ".join([part for part in prompt_parts if part]) |
|
|
|
|
|
negative_prompt = self.negative_prompt or None |
|
|
steps = num_inference_steps or self.num_inference_steps |
|
|
generator = torch.Generator(device=self.device).manual_seed(self.seed) |
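            # A fixed seed (settings.COLORIZE_SEED) keeps results reproducible across requests.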
|
|
|
|
|
logger.info("Running SDXL pipeline with prompt: %s", final_prompt) |
|
|
result = self.pipe( |
|
|
prompt=final_prompt, |
|
|
negative_prompt=negative_prompt, |
|
|
image=control_image, |
|
|
num_inference_steps=steps, |
|
|
guidance_scale=self.guidance_scale, |
|
|
controlnet_conditioning_scale=self.controlnet_scale, |
|
|
generator=generator, |
|
|
) |
|
|
|
|
|
generated_image = result.images[0] |
|
|
colorized = _apply_color(control_image, generated_image) |
|
|
if colorized.size != original_size: |
|
|
colorized = colorized.resize(original_size, Image.Resampling.LANCZOS) |
|
|
|
|
|
return colorized, caption |
|
|
except Exception as exc: |
|
|
logger.exception("Error during colorization: %s", exc) |
|
|
raise |
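

# Minimal local smoke test (a sketch, not part of the service API): assumes a grayscale
# image at ./test_gray.jpg and that app.config.settings provides the fields referenced
# above. The first run downloads several GB of model weights.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()
    colorized, caption = model.colorize(Image.open("test_gray.jpg").convert("RGB"))
    print(f"Caption used for guidance: {caption}")
    colorized.save("test_colorized.png")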
|
|
|
|
|
|