| """ | |
| Colorize model wrapper that forwards requests to the Hugging Face Inference API. | |
| """ | |
| from __future__ import annotations | |
| import io | |
| import logging | |
| import os | |
| from typing import Tuple | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from transformers import BlipForConditionalGeneration, BlipProcessor | |
| from app.config import settings | |
| logger = logging.getLogger(__name__) | |


def _ensure_cache_dir() -> str:
    """Find or create a writable Hugging Face cache directory.

    Tries $DATA_DIR/hf_cache first (when DATA_DIR is set), then /tmp/hf_cache,
    then the default ~/.cache/huggingface location.
    """
    data_dir = os.getenv("DATA_DIR")
    candidates = []
    if data_dir:
        candidates.append(os.path.join(data_dir, "hf_cache"))
    candidates.extend(
        [
            os.path.join("/tmp", "hf_cache"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
        ]
    )
    for path in candidates:
        try:
            os.makedirs(path, exist_ok=True)
            logger.info("Using HF cache directory: %s", path)
            # Export the cache location process-wide so that every subsequent
            # from_pretrained() call downloads into the writable directory.
            os.environ["HF_HOME"] = path
            os.environ["HUGGINGFACE_HUB_CACHE"] = path
            os.environ["TRANSFORMERS_CACHE"] = path
            return path
        except Exception as exc:
            logger.warning("Failed to create cache dir %s: %s", path, exc)
    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")


def _clean_caption(caption: str) -> str:
    """Strip phrases that would bias the diffusion prompt toward monochrome output."""
    removals = [
        "black and white", "black & white", "monochrome", "monochromatic",
        "bw photo", "blurry", "grainy", "historical", "restored", "circa",
        "taken in", "overcast", "desaturated", "low contrast",
    ]
    cleaned = caption
    for phrase in removals:
        cleaned = cleaned.replace(phrase, "")
    # Collapse the double spaces that the removals leave behind.
    cleaned = " ".join(cleaned.split())
    return cleaned.strip(" ,")
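
# Illustrative behaviour (the caption below is hypothetical, not from the codebase):
# _clean_caption("a black and white photo of a dog, circa 1950")
# -> "a photo of a dog, 1950"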


class ColorizeModel:
    """Colorization model that delegates image generation to the HF Inference API."""

    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"

    def __init__(self, model_id: str | None = None) -> None:
        self.model_id = model_id or settings.MODEL_ID
        self.api_url = f"https://router.huggingface.co/hf-inference/models/{self.model_id}"
        self.api_token = (
            os.getenv("HUGGINGFACE_API_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HF_TOKEN")
        )
        if not self.api_token:
            raise RuntimeError(
                "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. "
                "Please provide an access token with Inference API permissions."
            )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        self.cache_dir = _ensure_cache_dir()
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self.timeout = settings.INFERENCE_TIMEOUT
        self.provider = settings.INFERENCE_PROVIDER
        self._load_caption_model()

    def _load_caption_model(self) -> None:
        logger.info("Loading BLIP captioning model for prompt generation...")
        self.caption_processor = BlipProcessor.from_pretrained(
            self.CAPTION_MODEL,
            cache_dir=self.cache_dir,
        )
        # self.dtype is already float32 on CPU, so it is safe to pass directly.
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.CAPTION_MODEL,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
        ).to(self.device)

    def caption_image(self, image: Image.Image) -> str:
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)
        # Cast only the floating-point tensors (pixel_values) to the model dtype;
        # casting input_ids to a float dtype would break the embedding lookup.
        inputs = {
            k: v.to(self.dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
            for k, v in inputs.items()
        }
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def _build_payload(self, prompt: str) -> dict:
        payload = {
            "inputs": prompt,
            "parameters": {
                "num_inference_steps": self.num_inference_steps,
                "guidance_scale": self.guidance_scale,
                "negative_prompt": self.negative_prompt,
                "seed": self.seed,
            },
        }
        if self.provider:
            payload["provider"] = {"name": self.provider}
        return payload
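
    # Illustrative request body (the values below are invented for the example;
    # the real numbers come from settings):
    # {
    #     "inputs": "vivid natural colors, a photo of a dog",
    #     "parameters": {
    #         "num_inference_steps": 30,
    #         "guidance_scale": 7.5,
    #         "negative_prompt": "black and white, sepia",
    #         "seed": 1234,
    #     },
    #     "provider": {"name": "hf-inference"},
    # }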

    def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Colorize ``image`` and return the result together with the caption used."""
        caption = self.caption_image(image)
        prompt_parts = [self.positive_prompt, caption]
        prompt = ", ".join([p for p in prompt_parts if p])
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload = self._build_payload(prompt)
        logger.info("Calling HF Inference API for prompt: %s", prompt)
        response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)
        if response.status_code != 200:
            try:
                data = response.json()
            except ValueError:
                data = response.text
            logger.error("Inference API error (%s): %s", response.status_code, data)
            raise RuntimeError(f"Inference API error ({response.status_code}): {data}")
        # The API returns raw image bytes; resize back to the source dimensions,
        # since the generated image may not match them.
        colorized = Image.open(io.BytesIO(response.content)).convert("RGB")
        colorized = colorized.resize(image.size, Image.Resampling.LANCZOS)
        return colorized, caption
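

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): it assumes a valid
    # HUGGINGFACE_API_TOKEN in the environment and a grayscale test image at
    # ./sample_bw.jpg; neither the filename nor this invocation is part of the app.
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()
    with Image.open("sample_bw.jpg") as src:
        colorized, caption = model.colorize(src.convert("RGB"))
    colorized.save("sample_color.png")
    print(f"Prompt caption: {caption}")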