"""
Colorize model wrapper that forwards requests to the Hugging Face Inference API.
"""
from __future__ import annotations
import io
import logging
import os
from typing import Tuple
import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
from app.config import settings
logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    """Ensure we have a writable Hugging Face cache directory."""
    data_dir = os.getenv("DATA_DIR")
    candidates = []
    if data_dir:
        candidates.append(os.path.join(data_dir, "hf_cache"))
    candidates.extend(
        [
            os.path.join("/tmp", "hf_cache"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
        ]
    )
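    # Try each candidate in order and use the first directory we can create.
    # Pointing HF_HOME / HUGGINGFACE_HUB_CACHE / TRANSFORMERS_CACHE at it makes
    # huggingface_hub and transformers download into the writable location.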
    for path in candidates:
        try:
            os.makedirs(path, exist_ok=True)
            logger.info("Using HF cache directory: %s", path)
            os.environ["HF_HOME"] = path
            os.environ["HUGGINGFACE_HUB_CACHE"] = path
            os.environ["TRANSFORMERS_CACHE"] = path
            return path
        except Exception as exc:
            logger.warning("Failed to create cache dir %s: %s", path, exc)
    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")
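

# Strip monochrome- and quality-related words that BLIP tends to produce for
# grayscale photos, so the caption does not steer the colorization prompt
# back toward a black-and-white result.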
def _clean_caption(prompt: str) -> str:
    replacements = [
        "black and white", "black & white", "monochrome", "monochromatic",
        "bw photo", "blurry", "grainy", "historical", "restored", "circa",
        "taken in", "overcast", "desaturated", "low contrast",
    ]
    cleaned = prompt
    for word in replacements:
        cleaned = cleaned.replace(word, "")
    return cleaned.strip(" ,")


class ColorizeModel:
    """Colorization model that leverages the HF Inference API."""

    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"

    def __init__(self, model_id: str | None = None) -> None:
        self.model_id = model_id or settings.MODEL_ID
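        # Colorization requests go to the hf-inference provider route on the
        # Hugging Face router.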
        self.api_url = f"https://router.huggingface.co/hf-inference/models/{self.model_id}"
        self.api_token = (
            os.getenv("HUGGINGFACE_API_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HF_TOKEN")
        )
        if not self.api_token:
            raise RuntimeError(
                "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. "
                "Please provide an access token with Inference API permissions."
            )
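        # Only BLIP captioning runs locally: the device/dtype below apply to the
        # caption model, while colorization itself happens remotely over HTTP.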
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        self.cache_dir = _ensure_cache_dir()
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self.timeout = settings.INFERENCE_TIMEOUT
        self.provider = settings.INFERENCE_PROVIDER
        self._load_caption_model()

    def _load_caption_model(self) -> None:
        logger.info("Loading BLIP captioning model for prompt generation...")
        self.caption_processor = BlipProcessor.from_pretrained(
            self.CAPTION_MODEL,
            cache_dir=self.cache_dir,
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.CAPTION_MODEL,
            torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
            cache_dir=self.cache_dir,
        ).to(self.device)
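
    # Generate a short description of the input photo; the cleaned caption is
    # later appended to the positive prompt sent to the colorization model.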
    def caption_image(self, image: Image.Image) -> str:
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device)
        if self.device.type != "cuda":
            # Cast only floating-point tensors (pixel_values) to float32;
            # input_ids and attention_mask must keep their integer dtypes.
            inputs = {
                k: v.to(torch.float32) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
                for k, v in inputs.items()
            }
        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)
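
    # Build the JSON body sent to the API: the prompt under "inputs", generation
    # settings under "parameters", and an optional provider pin.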
    def _build_payload(self, prompt: str) -> dict:
        payload = {
            "inputs": prompt,
            "parameters": {
                "num_inference_steps": self.num_inference_steps,
                "guidance_scale": self.guidance_scale,
                "negative_prompt": self.negative_prompt,
                "seed": self.seed,
            },
        }
        if self.provider:
            payload["provider"] = {"name": self.provider}
        return payload
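
    # Full pipeline: caption the input image, merge the caption with the
    # configured positive prompt, call the remote model, then resize the
    # result back to the original dimensions.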
    def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        caption = self.caption_image(image)
        prompt_parts = [self.positive_prompt, caption]
        prompt = ", ".join([p for p in prompt_parts if p])
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload = self._build_payload(prompt)
        logger.info("Calling HF Inference API for prompt: %s", prompt)
        response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)
        if response.status_code != 200:
            try:
                data = response.json()
            except ValueError:
                data = response.text
            logger.error("Inference API error (%s): %s", response.status_code, data)
            raise RuntimeError(f"Inference API error ({response.status_code}): {data}")
        colorized = Image.open(io.BytesIO(response.content)).convert("RGB")
        colorized = colorized.resize(image.size, Image.Resampling.LANCZOS)
        return colorized, caption
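

# Usage sketch (illustrative only; the input file name is an assumption, and
# settings plus one of the HF token env vars must be configured):
#
#     from PIL import Image
#
#     model = ColorizeModel()  # reads MODEL_ID and the token from the environment
#     gray = Image.open("old_photo.jpg").convert("RGB")
#     colorized, caption = model.colorize(gray)
#     colorized.save("colorized.jpg")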