File size: 5,969 Bytes
60c56d7
f79a7fe
60c56d7
8f6f449
 
 
f79a7fe
60c56d7
7471c96
8f6f449
 
f79a7fe
60c56d7
 
8f6f449
 
2ae242d
60c56d7
 
 
8f6f449
 
 
 
f79a7fe
8f6f449
f79a7fe
 
8f6f449
 
 
 
 
 
f79a7fe
8f6f449
 
 
 
 
 
 
f79a7fe
8f6f449
 
 
 
 
f79a7fe
 
 
 
 
8f6f449
 
f79a7fe
8f6f449
 
 
 
60c56d7
f79a7fe
8f6f449
 
 
 
f79a7fe
d58eb50
f79a7fe
 
 
 
 
 
 
 
 
 
 
7471c96
f79a7fe
8f6f449
7471c96
 
8f6f449
 
 
f79a7fe
 
8f6f449
 
f79a7fe
 
8f6f449
 
 
 
f79a7fe
bdef219
 
 
 
f79a7fe
8f6f449
 
bdef219
f79a7fe
8f6f449
 
 
 
 
 
 
 
 
 
 
 
 
 
f79a7fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c56d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Colorize model wrapper: captions the input image locally with BLIP and sends the resulting text prompt to the Hugging Face Inference API.
"""

from __future__ import annotations

import io
import logging
import os
import re
from typing import Tuple

import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

from app.config import settings

logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    """Pick a writable Hugging Face cache directory and export it via env vars.

    Candidates are tried in order: ``$DATA_DIR/hf_cache`` (only when ``DATA_DIR``
    is set), ``/tmp/hf_cache``, then ``~/.cache/huggingface``. The first one that
    can be created (or already exists) and is writable is exported as
    ``HF_HOME``, ``HUGGINGFACE_HUB_CACHE`` and ``TRANSFORMERS_CACHE`` and returned.

    Returns:
        The chosen cache directory path.

    Raises:
        RuntimeError: If none of the candidate directories is usable.
    """
    data_dir = os.getenv("DATA_DIR")
    candidates = [os.path.join(data_dir, "hf_cache")] if data_dir else []
    candidates += [
        os.path.join("/tmp", "hf_cache"),
        os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
    ]

    for path in candidates:
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as exc:
            logger.warning("Failed to create cache dir %s: %s", path, exc)
            continue

        # makedirs(exist_ok=True) also succeeds on an existing read-only
        # directory, so verify we can actually create entries inside it.
        if not os.access(path, os.W_OK | os.X_OK):
            logger.warning("Cache dir %s is not writable; skipping", path)
            continue

        logger.info("Using HF cache directory: %s", path)
        os.environ["HF_HOME"] = path
        os.environ["HUGGINGFACE_HUB_CACHE"] = path
        os.environ["TRANSFORMERS_CACHE"] = path
        return path

    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")


def _clean_caption(prompt: str) -> str:
    replacements = [
        "black and white", "black & white", "monochrome", "monochromatic",
        "bw photo", "blurry", "grainy", "historical", "restored", "circa",
        "taken in", "overcast", "desaturated", "low contrast",
    ]
    cleaned = prompt
    for word in replacements:
        cleaned = cleaned.replace(word, "")
    return cleaned.strip(" ,")


class ColorizeModel:
    """Colorization model backed by a local BLIP captioner and the HF Inference API.

    The input image is captioned locally with BLIP. The cleaned caption is joined
    with the configured positive prompt and sent as a text prompt to the Hugging
    Face Inference API endpoint for ``model_id``. The returned image is resized
    to the input image's size.
    """

    CAPTION_MODEL = "Salesforce/blip-image-captioning-large"

    def __init__(self, model_id: str | None = None) -> None:
        """Read configuration, validate the API token and load the BLIP captioner.

        Args:
            model_id: Hugging Face model id to call; defaults to ``settings.MODEL_ID``.

        Raises:
            RuntimeError: If no Hugging Face access token is found in the
                environment, or if no writable cache directory can be created.
        """
        self.model_id = model_id or settings.MODEL_ID
        self.api_url = f"https://router.huggingface.co/hf-inference/models/{self.model_id}"

        # The first non-empty token among the supported variable names wins.
        self.api_token = (
            os.getenv("HUGGINGFACE_API_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HF_TOKEN")
        )
        if not self.api_token:
            raise RuntimeError(
                "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. "
                "Please provide an access token with Inference API permissions."
            )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU only; CPU stays in float32.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        # NOTE(review): torch is already imported at this point, so this likely
        # does not change torch's thread pool in this process — confirm intent.
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        self.cache_dir = _ensure_cache_dir()
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED
        self.timeout = settings.INFERENCE_TIMEOUT
        self.provider = settings.INFERENCE_PROVIDER

        self._load_caption_model()

    def _load_caption_model(self) -> None:
        """Load the BLIP processor and model (via ``self.cache_dir``) onto ``self.device``."""
        logger.info("Loading BLIP captioning model for prompt generation...")
        self.caption_processor = BlipProcessor.from_pretrained(
            self.CAPTION_MODEL,
            cache_dir=self.cache_dir,
        )
        # self.dtype is already float32 on CPU and float16 on CUDA.
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.CAPTION_MODEL,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
        ).to(self.device)

    def _prepare_caption_inputs(self, encoded) -> dict:
        """Move processor outputs onto ``self.device``, matching the model dtype.

        Only floating-point tensors (such as pixel values) are cast to
        ``self.dtype``. Integer tensors (such as token ids) keep their dtype,
        because embedding lookups require integer indices. Non-tensor values
        pass through unchanged.

        Args:
            encoded: Mapping of input name to value, as returned by the processor.

        Returns:
            A plain dict suitable for ``caption_model.generate(**inputs)``.
        """
        prepared = {}
        for key, value in encoded.items():
            if isinstance(value, torch.Tensor):
                if value.is_floating_point():
                    value = value.to(device=self.device, dtype=self.dtype)
                else:
                    value = value.to(device=self.device)
            prepared[key] = value
        return prepared

    def caption_image(self, image: Image.Image) -> str:
        """Caption ``image`` with BLIP (conditioned on ``caption_prefix``) and clean it.

        Returns:
            The decoded caption after ``_clean_caption``.
        """
        encoded = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        )
        inputs = self._prepare_caption_inputs(encoded)

        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def _build_payload(self, prompt: str) -> dict:
        """Build the JSON body for the Inference API request.

        The ``provider`` key is included only when ``self.provider`` is truthy.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "num_inference_steps": self.num_inference_steps,
                "guidance_scale": self.guidance_scale,
                "negative_prompt": self.negative_prompt,
                "seed": self.seed,
            },
        }
        if self.provider:
            payload["provider"] = {"name": self.provider}
        return payload

    def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """Produce a colorized version of ``image`` via the Inference API.

        Args:
            image: Source image. It is captioned locally; only the text prompt is
                sent to the API.
            _num_inference_steps: Accepted for interface compatibility but
                ignored; the step count always comes from
                ``self.num_inference_steps``.

        Returns:
            A tuple ``(colorized, caption)``. ``colorized`` is the API's RGB image
            resized to ``image.size``; ``caption`` is the cleaned BLIP caption.

        Raises:
            RuntimeError: If the API returns a non-200 status, or a 200 response
                whose body cannot be decoded as an image.
            requests.RequestException: Propagated from ``requests.post`` on
                network errors or timeouts.
        """
        caption = self.caption_image(image)
        prompt = ", ".join(part for part in (self.positive_prompt, caption) if part)

        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload = self._build_payload(prompt)

        logger.info("Calling HF Inference API for prompt: %s", prompt)
        response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)

        if response.status_code != 200:
            try:
                data = response.json()
            except ValueError:
                data = response.text
            logger.error("Inference API error (%s): %s", response.status_code, data)
            raise RuntimeError(f"Inference API error ({response.status_code}): {data}")

        try:
            # Image.open is lazy; convert() forces decoding, so keep both in the try.
            colorized = Image.open(io.BytesIO(response.content)).convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError
            content_type = response.headers.get("Content-Type", "<unknown>")
            logger.error("Inference API returned a non-image payload (content-type: %s)", content_type)
            raise RuntimeError(
                f"Inference API returned a payload that is not a decodable image (content-type: {content_type})"
            ) from err

        colorized = colorized.resize(image.size, Image.Resampling.LANCZOS)
        return colorized, caption