"""
Colorize model wrapper replicating the behaviour of the
`fffiloni/text-guided-image-colorization` Space.
"""

from __future__ import annotations

import logging
import os
from typing import Tuple

import torch
from PIL import Image, ImageCms
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import BlipForConditionalGeneration, BlipProcessor

from app.config import settings

logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
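    """Create (if needed) a writable Hugging Face cache directory and point the cache env vars at it."""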
    cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:  # pragma: no cover
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    return cache_dir


def _apply_lab_merge(original_luminance: Image.Image, color_map: Image.Image) -> Image.Image:
    """Keep the L channel of the original image and take the a/b chroma from the generated color map."""
    # Pillow has no direct RGB <-> LAB conversion, so go through ICC profiles via ImageCms.
    srgb_profile = ImageCms.createProfile("sRGB")
    lab_profile = ImageCms.createProfile("LAB")
    rgb_to_lab = ImageCms.buildTransformFromOpenProfiles(srgb_profile, lab_profile, "RGB", "LAB")
    lab_to_rgb = ImageCms.buildTransformFromOpenProfiles(lab_profile, srgb_profile, "LAB", "RGB")

    base_lab = ImageCms.applyTransform(original_luminance.convert("RGB"), rgb_to_lab)
    color_lab = ImageCms.applyTransform(color_map.convert("RGB"), rgb_to_lab)
    l_channel, _, _ = base_lab.split()
    _, a_channel, b_channel = color_lab.split()
    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
    return ImageCms.applyTransform(merged, lab_to_rgb)


def _clean_caption(prompt: str) -> str:
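    """Strip grayscale/desaturation terms from the caption so they do not bias the prompt."""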
    remove_terms = [
        "black and white", "black & white", "monochrome", "bw photo",
        "historical", "restored", "low contrast", "desaturated", "overcast",
    ]
    cleaned = prompt
    for term in remove_terms:
        cleaned = cleaned.replace(term, "")
    return cleaned.strip(" ,")


class ColorizeModel:
    """Colorization model that runs the SDXL + ControlNet pipeline locally."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        self.hf_token = (
            os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACE_HUB_TOKEN")
            or os.getenv("HUGGINGFACE_API_TOKEN")
        )
        if not self.hf_token:
            logger.warning("HF token not provided – attempting to download public models only.")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        self.controlnet_id = model_id or settings.MODEL_ID
        self.base_model_id = settings.BASE_MODEL_ID
        self.lightning_repo = settings.LIGHTNING_REPO
        self.lightning_weights = settings.LIGHTNING_WEIGHTS
        self.caption_model_id = settings.CAPTION_MODEL_ID

        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
        self.guidance_scale = settings.GUIDANCE_SCALE
        self.controlnet_scale = settings.CONTROLNET_SCALE
        self.positive_prompt = settings.POSITIVE_PROMPT
        self.negative_prompt = settings.NEGATIVE_PROMPT
        self.caption_prefix = settings.CAPTION_PREFIX
        self.seed = settings.COLORIZE_SEED

        self._load_caption_model()
        self._load_pipeline()

    def _load_caption_model(self) -> None:
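        """Load the BLIP processor and captioning model used to build colorization prompts."""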
        logger.info("Loading BLIP captioning model: %s", self.caption_model_id)
        self.caption_processor = BlipProcessor.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.caption_model_id,
            cache_dir=self.cache_dir,
            token=self.hf_token,
            torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
        ).to(self.device)

    def _load_pipeline(self) -> None:
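        """Build the SDXL + ControlNet pipeline, loading the SDXL-Lightning weights into the UNet."""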
        logger.info("Loading ControlNet model: %s", self.controlnet_id)
        controlnet = ControlNetModel.from_pretrained(
            self.controlnet_id,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )

        logger.info("Loading SDXL base model components: %s", self.base_model_id)
        vae = AutoencoderKL.from_pretrained(
            self.base_model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet = UNet2DConditionModel.from_config(
            self.base_model_id,
            subfolder="unet",
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        lightning_path = hf_hub_download(
            repo_id=self.lightning_repo,
            filename=self.lightning_weights,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        unet.load_state_dict(load_file(lightning_path))

        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            self.base_model_id,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
            token=self.hf_token,
        )
        self.pipe.set_progress_bar_config(disable=True)
        self.pipe.to(self.device, dtype=self.dtype)
        if self.device.type == "cuda" and hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as exc:  # pragma: no cover
                logger.warning("Could not enable xFormers optimizations: %s", exc)

        logger.info("Colorization pipeline ready.")

    def caption_image(self, image: Image.Image) -> str:
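        """Generate a BLIP caption for the image, cleaned of grayscale-related terms."""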
        # BatchFeature.to casts only floating-point tensors, so pixel_values are moved to the
        # model dtype/device while the integer input_ids stay as LongTensors.
        inputs = self.caption_processor(
            image,
            self.caption_prefix,
            return_tensors="pt",
        ).to(self.device, self.dtype)

        with torch.inference_mode():
            caption_ids = self.caption_model.generate(**inputs)
        caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
        return _clean_caption(caption)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
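        """Colorize the input image; returns the colorized image and the caption used in the prompt."""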
        original_size = image.size
        control_image = image.convert("L").convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)

        caption = self.caption_image(image)
        prompt_components = [self.positive_prompt, caption]
        prompt = ", ".join([p for p in prompt_components if p])
        steps = num_inference_steps or self.num_inference_steps
        generator = torch.Generator(device=self.device).manual_seed(self.seed)

        logger.info("Running ControlNet pipeline with prompt: %s", prompt)
        result = self.pipe(
            prompt=prompt,
            negative_prompt=self.negative_prompt or None,
            # StableDiffusionXLControlNetPipeline takes the conditioning image via `image`.
            image=control_image,
            num_inference_steps=steps,
            guidance_scale=self.guidance_scale,
            controlnet_conditioning_scale=self.controlnet_scale,
            generator=generator,
        )

        generated = result.images[0]
        colorized = _apply_lab_merge(control_image, generated)
        if colorized.size != original_size:
            colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

        return colorized, caption
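

if __name__ == "__main__":  # pragma: no cover
    # Minimal local smoke test, not part of the service. It assumes a grayscale
    # image named "sample_bw.jpg" in the working directory and that
    # `app.config.settings` is importable (e.g. run from the repository root).
    model = ColorizeModel()
    bw = Image.open("sample_bw.jpg").convert("RGB")
    colorized, caption = model.colorize(bw)
    print(f"Caption used for prompting: {caption}")
    colorized.save("sample_colorized.png")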