""" Colorize model wrapper that forwards requests to the Hugging Face Inference API. """ from __future__ import annotations import io import logging import os from typing import Tuple import requests import torch from PIL import Image from transformers import BlipForConditionalGeneration, BlipProcessor from app.config import settings logger = logging.getLogger(__name__) def _ensure_cache_dir() -> str: """Ensure we have a writable Hugging Face cache directory.""" data_dir = os.getenv("DATA_DIR") candidates = [] if data_dir: candidates.append(os.path.join(data_dir, "hf_cache")) candidates.extend( [ os.path.join("/tmp", "hf_cache"), os.path.join(os.path.expanduser("~"), ".cache", "huggingface"), ] ) for path in candidates: try: os.makedirs(path, exist_ok=True) logger.info("Using HF cache directory: %s", path) os.environ["HF_HOME"] = path os.environ["HUGGINGFACE_HUB_CACHE"] = path os.environ["TRANSFORMERS_CACHE"] = path return path except Exception as exc: logger.warning("Failed to create cache dir %s: %s", path, exc) raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.") def _clean_caption(prompt: str) -> str: replacements = [ "black and white", "black & white", "monochrome", "monochromatic", "bw photo", "blurry", "grainy", "historical", "restored", "circa", "taken in", "overcast", "desaturated", "low contrast", ] cleaned = prompt for word in replacements: cleaned = cleaned.replace(word, "") return cleaned.strip(" ,") class ColorizeModel: """Colorization model that leverages the HF Inference API.""" CAPTION_MODEL = "Salesforce/blip-image-captioning-large" def __init__(self, model_id: str | None = None) -> None: self.model_id = model_id or settings.MODEL_ID self.api_url = f"https://router.huggingface.co/hf-inference/models/{self.model_id}" self.api_token = ( os.getenv("HUGGINGFACE_API_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN") ) if not self.api_token: raise RuntimeError( "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. " "Please provide an access token with Inference API permissions." 
) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32 os.environ.setdefault("OMP_NUM_THREADS", "1") self.cache_dir = _ensure_cache_dir() self.positive_prompt = settings.POSITIVE_PROMPT self.negative_prompt = settings.NEGATIVE_PROMPT self.num_inference_steps = settings.NUM_INFERENCE_STEPS self.guidance_scale = settings.GUIDANCE_SCALE self.caption_prefix = settings.CAPTION_PREFIX self.seed = settings.COLORIZE_SEED self.timeout = settings.INFERENCE_TIMEOUT self.provider = settings.INFERENCE_PROVIDER self._load_caption_model() def _load_caption_model(self) -> None: logger.info("Loading BLIP captioning model for prompt generation...") self.caption_processor = BlipProcessor.from_pretrained( self.CAPTION_MODEL, cache_dir=self.cache_dir ) self.caption_model = BlipForConditionalGeneration.from_pretrained( self.CAPTION_MODEL, torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32, cache_dir=self.cache_dir ).to(self.device) def caption_image(self, image: Image.Image) -> str: inputs = self.caption_processor( image, self.caption_prefix, return_tensors="pt", ).to(self.device) if self.device.type != "cuda": inputs = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} with torch.inference_mode(): caption_ids = self.caption_model.generate(**inputs) caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True) return _clean_caption(caption) def _build_payload(self, prompt: str) -> dict: payload = { "inputs": prompt, "parameters": { "num_inference_steps": self.num_inference_steps, "guidance_scale": self.guidance_scale, "negative_prompt": self.negative_prompt, "seed": self.seed, }, } if self.provider: payload["provider"] = {"name": self.provider} return payload def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]: caption = self.caption_image(image) prompt_parts = [self.positive_prompt, caption] prompt = ", ".join([p for p in prompt_parts if p]) headers = { "Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json", } payload = self._build_payload(prompt) logger.info("Calling HF Inference API for prompt: %s", prompt) response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout) if response.status_code != 200: try: data = response.json() except ValueError: data = response.text logger.error("Inference API error (%s): %s", response.status_code, data) raise RuntimeError(f"Inference API error ({response.status_code}): {data}") colorized = Image.open(io.BytesIO(response.content)).convert("RGB") colorized = colorized.resize(image.size, Image.Resampling.LANCZOS) return colorized, caption
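

# A minimal smoke-test sketch, not part of the original module: it assumes a
# grayscale image at the hypothetical path "sample_bw.jpg", an HF token in the
# environment, and that app.config settings resolve (MODEL_ID, prompts, etc.).
# Run the module directly to caption and colorize one image end to end.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()
    source = Image.open("sample_bw.jpg").convert("RGB")
    result, caption = model.colorize(source)
    result.save("sample_colorized.png")
    print(f"Caption used for prompting: {caption}")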