LogicGoInfotechSpaces's picture
Align pipeline with text-guided colorization Space
8f6f449
raw
history blame
10.7 kB
"""
Colorize model wrapper replicating the behaviour of the
`fffiloni/text-guided-image-colorization` Space.
"""
from __future__ import annotations

import functools
import logging
import os
import re
from typing import Tuple

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import BlipForConditionalGeneration, BlipProcessor

from app.config import settings
logger = logging.getLogger(__name__)
def _ensure_cache_dir() -> str:
    """Pick the first writable Hugging Face cache directory and export it.

    Candidates, in order: ``$DATA_DIR/hf_cache`` (only when ``DATA_DIR`` is
    set), ``/tmp/hf_cache`` and ``~/.cache/huggingface``. The chosen path is
    written to ``HF_HOME``, ``HUGGINGFACE_HUB_CACHE`` and ``TRANSFORMERS_CACHE``.

    NOTE: libraries imported before this call may already have resolved their
    default cache locations from the environment, so callers should also pass
    the returned path explicitly (e.g. as ``cache_dir``).

    Returns:
        The selected directory path.

    Raises:
        RuntimeError: if no candidate can be created and written to.
    """
    data_dir = os.getenv("DATA_DIR")
    candidate_dirs = []
    if data_dir:
        candidate_dirs.append(os.path.join(data_dir, "hf_cache"))
    candidate_dirs.extend(
        [
            os.path.join("/tmp", "hf_cache"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
        ]
    )
    for path in candidate_dirs:
        try:
            os.makedirs(path, exist_ok=True)
        except (OSError, ValueError) as exc:  # pragma: no cover - best effort
            # ValueError covers malformed paths (e.g. an embedded NUL in DATA_DIR).
            logger.warning("Failed to create cache dir %s: %s", path, exc)
            continue
        # makedirs(exist_ok=True) also succeeds for an existing read-only
        # directory, so check that we can actually create files in it.
        if not os.access(path, os.W_OK | os.X_OK):
            logger.warning("Cache dir %s is not writable, skipping", path)
            continue
        logger.info("Using HF cache directory: %s", path)
        os.environ["HF_HOME"] = path
        os.environ["HUGGINGFACE_HUB_CACHE"] = path
        os.environ["TRANSFORMERS_CACHE"] = path
        return path
    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")
def _apply_color(luminance_image: Image.Image, color_map: Image.Image) -> Image.Image:
    """Combine the lightness of one image with the colour of another.

    Both images are converted to CIE LAB. The result keeps the L channel of
    ``luminance_image`` and the A/B channels of ``color_map``, and is returned
    in RGB at the size of ``luminance_image``.

    Args:
        luminance_image: Image supplying lightness (typically a grayscale
            version of the original picture).
        color_map: Image supplying chroma (typically the diffusion output).
            It is resized with LANCZOS when its size differs, because
            ``Image.merge`` requires all bands to have the same size.
    """
    # Normalise unusual modes (L, P, RGBA, ...) to RGB before the LAB
    # conversion; Pillow may not offer a direct LAB conversion from them.
    if luminance_image.mode not in ("RGB", "LAB"):
        luminance_image = luminance_image.convert("RGB")
    if color_map.mode not in ("RGB", "LAB"):
        color_map = color_map.convert("RGB")
    if color_map.size != luminance_image.size:
        color_map = color_map.resize(luminance_image.size, Image.Resampling.LANCZOS)

    l_channel, _, _ = luminance_image.convert("LAB").split()
    _, a_channel, b_channel = color_map.convert("LAB").split()
    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
    return merged.convert("RGB")
def _remove_unlikely_words(prompt: str) -> str:
"""Clean up BLIP captions to avoid misleading descriptors."""
unlikely_words = []
decades = [f"{i}s" for i in range(1900, 2000)]
years = [f"{i}" for i in range(1900, 2000)]
years_with_word = [f"year {i}" for i in range(1900, 2000)]
circa_years = [f"circa {i}" for i in range(1900, 2000)]
expanded = [
[f"{d[0]} {d[1]} {d[2]} {d[3]} s" for d in decades],
[f"{d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
[f"year {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
[f"circa {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
]
manual_terms = [
"black and white,", "black and white", "black & white,", "black & white",
"circa", "monochrome,", "monochrome", "bw", "bw,", "b&w", "b&w,",
"grainy", "grainy photo", "grainy photograph", "grainy footage",
"black-and-white", "black - and - white", "black on white",
"historical photo", "historic photo", "restored", "desaturated",
"low contrast", "blurry", "overcast", "taken in", "photo taken in",
", photo", ", photo", ", photo", ", photograph",
]
for seq in expanded:
unlikely_words.extend(seq)
unlikely_words.extend(decades + years + years_with_word + circa_years + manual_terms)
cleaned = prompt
for word in unlikely_words:
cleaned = cleaned.replace(word, "")
return cleaned.strip(" ,")
class ColorizeModel:
    """Text-guided SDXL ControlNet colorizer.

    On construction it downloads and loads:

    * the SDXL base VAE and pipeline components (``BASE_MODEL``),
    * a UNet built from the base config with SDXL-Lightning weights
      (``LIGHTNING_REPO`` / ``LIGHTNING_WEIGHTS``),
    * the ControlNet stored under ``CONTROLNET_SUBDIR`` of ``CONTROLNET_REPO``,
    * the BLIP captioning model (``CAPTION_MODEL``).

    ``colorize`` captions the input with BLIP, runs the pipeline on a square
    grayscale control image and transfers the generated colours onto the
    input's full-resolution luminance.
    """

    CONTROLNET_REPO = "nickpai/sdxl_light_caption_output"
    CONTROLNET_SUBDIR = os.path.join("checkpoint-30000", "controlnet")
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    LIGHTNING_REPO = "ByteDance/SDXL-Lightning"
    LIGHTNING_WEIGHTS = "sdxl_lightning_8step_unet.safetensors"
    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"
    # Side length (pixels) of the square control image passed to the pipeline.
    CONTROL_SIZE = 512

    def __init__(self, model_id: str | None = None) -> None:
        """Pick the device, read generation settings and load all models.

        Args:
            model_id: Stored as ``self.model_id`` (defaults to
                ``settings.MODEL_ID``). The loaders use the class-level
                repository constants, not this value.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)
        # Half precision on GPU only; CPU inference stays in float32.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        # An empty token variable is treated like an unset one.
        self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or None
        self.cache_dir = _ensure_cache_dir()

        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self.model_id = model_id or settings.MODEL_ID
        # Caption produced by the most recent ``colorize`` call.
        self.last_caption: str | None = None

        self._load_pipeline()
        self._load_caption_model()

    # --------------------------------------------------------------------- #
    # Initialisation helpers
    # --------------------------------------------------------------------- #
    def _download_controlnet(self) -> str:
        """Download the ControlNet sub-folder of ``CONTROLNET_REPO``.

        Returns:
            Local path of the ControlNet weights directory.

        Raises:
            RuntimeError: if the expected sub-folder is missing after download.
        """
        logger.info("Downloading ControlNet snapshot: %s", self.CONTROLNET_REPO)
        local_dir = os.path.join(self.cache_dir, "sdxl_light_caption_output")
        # Hub paths always use "/", whereas CONTROLNET_SUBDIR uses os.sep.
        subdir_pattern = self.CONTROLNET_SUBDIR.replace(os.sep, "/") + "/*"
        path = snapshot_download(
            repo_id=self.CONTROLNET_REPO,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            # Only the ControlNet sub-folder is ever loaded, so skip the rest
            # of the repository.
            allow_patterns=[subdir_pattern],
            token=self.hf_token,
        )
        controlnet_path = os.path.join(path, self.CONTROLNET_SUBDIR)
        if not os.path.isdir(controlnet_path):
            raise RuntimeError(f"ControlNet weights not found at {controlnet_path}")
        return controlnet_path

    def _load_pipeline(self) -> None:
        """Assemble ``self.pipe``: SDXL + ControlNet with the SDXL-Lightning UNet.

        Every Hub download is pointed at ``self.cache_dir`` explicitly.
        ``_ensure_cache_dir`` also exports that path via environment variables,
        but huggingface_hub and transformers resolve their default cache
        location when they are imported, which happens before that call.
        """
        controlnet_path = self._download_controlnet()
        logger.info("Loading SDXL components...")
        hub_kwargs = {"token": self.hf_token, "cache_dir": self.cache_dir}

        vae = AutoencoderKL.from_pretrained(
            self.BASE_MODEL,
            subfolder="vae",
            torch_dtype=self.dtype,
            **hub_kwargs,
        )
        # Only the base UNet *config* is needed; the weights come from the
        # SDXL-Lightning checkpoint loaded right below.
        unet_config = UNet2DConditionModel.load_config(
            self.BASE_MODEL,
            subfolder="unet",
            **hub_kwargs,
        )
        unet = UNet2DConditionModel.from_config(unet_config)
        lightning_path = hf_hub_download(
            repo_id=self.LIGHTNING_REPO,
            filename=self.LIGHTNING_WEIGHTS,
            **hub_kwargs,
        )
        unet.load_state_dict(load_file(lightning_path))

        controlnet = ControlNetModel.from_pretrained(
            controlnet_path,
            torch_dtype=self.dtype,
        )
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.BASE_MODEL,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            **hub_kwargs,
        )
        # The SDXL-Lightning model card requires "trailing" timestep spacing,
        # so override whatever spacing the base scheduler config specifies.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)

        if self.device.type == "cuda" and hasattr(
            self.pipe, "enable_xformers_memory_efficient_attention"
        ):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:  # pragma: no cover - optional speed-up
                logger.warning("Could not enable xformers attention: %s", exc)
        logger.info("Colorization pipeline ready.")

    def _load_caption_model(self) -> None:
        """Load the BLIP processor and captioning model onto ``self.device``."""
        logger.info("Loading BLIP captioning model...")
        hub_kwargs = {"token": self.hf_token, "cache_dir": self.cache_dir}
        processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL, **hub_kwargs)
        model = BlipForConditionalGeneration.from_pretrained(
            self.CAPTION_MODEL,
            # self.dtype is already float32 on CPU.
            torch_dtype=self.dtype,
            **hub_kwargs,
        )
        self.caption_processor = processor
        self.caption_model = model.to(self.device)

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #
    def caption_image(self, image: Image.Image) -> str:
        """Caption ``image`` with BLIP and strip colour-biasing descriptors.

        ``self.caption_prefix`` is passed to the processor as conditioning text.

        Returns:
            The cleaned caption, or the raw caption if cleaning empties it.
        """
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)
        # Match floating-point inputs (pixel values) to the model dtype, but
        # leave integer tensors such as input_ids / attention_mask untouched:
        # casting them to float breaks the text decoder's embedding lookup.
        model_dtype = self.caption_model.dtype
        inputs = {
            key: value.to(model_dtype)
            if isinstance(value, torch.Tensor) and value.is_floating_point()
            else value
            for key, value in inputs.items()
        }
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _remove_unlikely_words(caption) or caption

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Colorize ``image``.

        Args:
            image: Input picture; it is converted to grayscale first, so colour
                inputs are accepted too.
            num_inference_steps: Overrides ``self.num_inference_steps`` when truthy.

        Returns:
            ``(colorized, caption)``: an RGB image at the input's size, and the
            caption used in the prompt (also stored in ``self.last_caption``).

        Raises:
            Any exception from captioning or the pipeline, after logging it.
        """
        try:
            original_size = image.size
            # Full-resolution grayscale copy: its luminance is kept in the
            # output, so the low-resolution diffusion pass only supplies colour.
            grayscale = image.convert("L").convert("RGB")
            control_image = grayscale.resize(
                (self.CONTROL_SIZE, self.CONTROL_SIZE), Image.Resampling.LANCZOS
            )

            caption = self.caption_image(image)
            self.last_caption = caption
            final_prompt = ", ".join(
                part for part in (self.positive_prompt, caption) if part
            )
            negative_prompt = self.negative_prompt or None
            steps = num_inference_steps or self.num_inference_steps
            generator = torch.Generator(device=self.device).manual_seed(self.seed)

            logger.info("Running SDXL pipeline with prompt: %s", final_prompt)
            result = self.pipe(
                prompt=final_prompt,
                negative_prompt=negative_prompt,
                image=control_image,
                num_inference_steps=steps,
                guidance_scale=self.guidance_scale,
                controlnet_conditioning_scale=self.controlnet_scale,
                generator=generator,
            )
            generated_image = result.images[0]
            # Upscale the generated colours to the input size before merging,
            # instead of upscaling the merged result (which blurs luminance).
            if generated_image.size != original_size:
                generated_image = generated_image.resize(
                    original_size, Image.Resampling.LANCZOS
                )
            colorized = _apply_color(grayscale, generated_image)
            return colorized, caption
        except Exception as exc:
            logger.exception("Error during colorization: %s", exc)
            raise