LogicGoInfotechSpaces committed
Commit f79a7fe · 1 Parent(s): dfc30a3

Switch colorization to HF Inference API

Files changed (2):
  1. app/colorize_model.py +78 -200
  2. app/config.py +5 -3
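
At its core, this commit replaces the local SDXL ControlNet pipeline with a single authenticated POST to the hosted text-to-image endpoint. A minimal sketch of that call pattern, for orientation only: the endpoint and payload shape follow the diff below, while the token env var and the example prompt are illustrative.

    import os

    import requests

    # Hosted text-to-image call the new module wraps. The API returns
    # raw image bytes on success. Prompt and output path are examples.
    API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

    response = requests.post(
        API_URL,
        headers=headers,
        json={"inputs": "high quality color photo, a vintage car on a street"},
        timeout=180,
    )
    response.raise_for_status()
    with open("out.png", "wb") as f:
        f.write(response.content)  # response body is the generated image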
app/colorize_model.py CHANGED
@@ -1,24 +1,17 @@
 """
-Colorize model wrapper replicating the behaviour of the
-`fffiloni/text-guided-image-colorization` Space.
+Colorize model wrapper that forwards requests to the Hugging Face Inference API.
 """
 
 from __future__ import annotations
 
+import io
 import logging
 import os
 from typing import Tuple
 
+import requests
 import torch
 from PIL import Image
-from diffusers import (
-    AutoencoderKL,
-    ControlNetModel,
-    StableDiffusionXLControlNetPipeline,
-    UNet2DConditionModel,
-)
-from huggingface_hub import hf_hub_download, snapshot_download
-from safetensors.torch import load_file
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
 from app.config import settings
@@ -29,17 +22,17 @@ logger = logging.getLogger(__name__)
 def _ensure_cache_dir() -> str:
     """Ensure we have a writable Hugging Face cache directory."""
     data_dir = os.getenv("DATA_DIR")
-    candidate_dirs = []
+    candidates = []
     if data_dir:
-        candidate_dirs.append(os.path.join(data_dir, "hf_cache"))
-    candidate_dirs.extend(
+        candidates.append(os.path.join(data_dir, "hf_cache"))
+    candidates.extend(
         [
             os.path.join("/tmp", "hf_cache"),
             os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
         ]
     )
 
-    for path in candidate_dirs:
+    for path in candidates:
         try:
             os.makedirs(path, exist_ok=True)
             logger.info("Using HF cache directory: %s", path)
@@ -47,235 +40,120 @@ def _ensure_cache_dir() -> str:
             os.environ["HUGGINGFACE_HUB_CACHE"] = path
             os.environ["TRANSFORMERS_CACHE"] = path
             return path
-        except Exception as exc:  # pragma: no cover - best effort
+        except Exception as exc:
            logger.warning("Failed to create cache dir %s: %s", path, exc)
 
    raise RuntimeError("Unable to create a writable cache directory for Hugging Face downloads.")
 
 
-def _apply_color(luminance_image: Image.Image, color_map: Image.Image) -> Image.Image:
-    """Merge the L channel of the grayscale control image with AB channels from generated image."""
-    image_lab = luminance_image.convert("LAB")
-    color_map_lab = color_map.convert("LAB")
-    l_channel, _, _ = image_lab.split()
-    _, a_channel, b_channel = color_map_lab.split()
-    merged = Image.merge("LAB", (l_channel, a_channel, b_channel))
-    return merged.convert("RGB")
-
-
-def _remove_unlikely_words(prompt: str) -> str:
-    """Clean up BLIP captions to avoid misleading descriptors."""
-    unlikely_words = []
-
-    decades = [f"{i}s" for i in range(1900, 2000)]
-    years = [f"{i}" for i in range(1900, 2000)]
-    years_with_word = [f"year {i}" for i in range(1900, 2000)]
-    circa_years = [f"circa {i}" for i in range(1900, 2000)]
-
-    expanded = [
-        [f"{d[0]} {d[1]} {d[2]} {d[3]} s" for d in decades],
-        [f"{d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
-        [f"year {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
-        [f"circa {d[0]} {d[1]} {d[2]} {d[3]}" for d in decades],
-    ]
-
-    manual_terms = [
-        "black and white,", "black and white", "black & white,", "black & white",
-        "circa", "monochrome,", "monochrome", "bw", "bw,", "b&w", "b&w,",
-        "grainy", "grainy photo", "grainy photograph", "grainy footage",
-        "black-and-white", "black - and - white", "black on white",
-        "historical photo", "historic photo", "restored", "desaturated",
-        "low contrast", "blurry", "overcast", "taken in", "photo taken in",
-        ", photo", ", photo", ", photo", ", photograph",
+def _clean_caption(prompt: str) -> str:
+    replacements = [
+        "black and white", "black & white", "monochrome", "monochromatic",
+        "bw photo", "blurry", "grainy", "historical", "restored", "circa",
+        "taken in", "overcast", "desaturated", "low contrast",
     ]
-
-    for seq in expanded:
-        unlikely_words.extend(seq)
-    unlikely_words.extend(decades + years + years_with_word + circa_years + manual_terms)
-
     cleaned = prompt
-    for word in unlikely_words:
+    for word in replacements:
         cleaned = cleaned.replace(word, "")
     return cleaned.strip(" ,")
 
 
 class ColorizeModel:
-    """Colorization model wrapper."""
+    """Colorization model that leverages the HF Inference API."""
 
-    CONTROLNET_REPO = "nickpai/sdxl_light_caption_output"
-    CONTROLNET_SUBDIR = os.path.join("checkpoint-30000", "controlnet")
-    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
-    LIGHTNING_REPO = "ByteDance/SDXL-Lightning"
-    LIGHTNING_WEIGHTS = "sdxl_lightning_8step_unet.safetensors"
     CAPTION_MODEL = "Salesforce/blip-image-captioning-large"
 
     def __init__(self, model_id: str | None = None) -> None:
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        logger.info("Using device: %s", self.device)
+        self.model_id = model_id or settings.MODEL_ID
+        self.api_url = f"https://api-inference.huggingface.co/models/{self.model_id}"
+
+        self.api_token = (
+            os.getenv("HUGGINGFACE_API_TOKEN")
+            or os.getenv("HUGGINGFACE_HUB_TOKEN")
+            or os.getenv("HF_TOKEN")
+        )
+        if not self.api_token:
+            raise RuntimeError(
+                "HUGGINGFACE_API_TOKEN (or HUGGINGFACE_HUB_TOKEN / HF_TOKEN) is not set. "
+                "Please provide an access token with Inference API permissions."
+            )
 
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
         os.environ.setdefault("OMP_NUM_THREADS", "1")
 
-        self.hf_token = (
-            os.getenv("HF_TOKEN")
-            or os.getenv("HUGGINGFACE_HUB_TOKEN")
-            or None
-        )
         self.cache_dir = _ensure_cache_dir()
-
-        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
-        self.guidance_scale = settings.GUIDANCE_SCALE
-        self.controlnet_scale = settings.CONTROLNET_SCALE
         self.positive_prompt = settings.POSITIVE_PROMPT
         self.negative_prompt = settings.NEGATIVE_PROMPT
+        self.num_inference_steps = settings.NUM_INFERENCE_STEPS
+        self.guidance_scale = settings.GUIDANCE_SCALE
         self.caption_prefix = settings.CAPTION_PREFIX
         self.seed = settings.COLORIZE_SEED
+        self.timeout = settings.INFERENCE_TIMEOUT
+        self.provider = settings.INFERENCE_PROVIDER
 
-        self.model_id = model_id or settings.MODEL_ID
-
-        self._load_pipeline()
         self._load_caption_model()
-        self.last_caption: str | None = None
-
-    # --------------------------------------------------------------------- #
-    # Initialisation helpers
-    # --------------------------------------------------------------------- #
-    def _download_controlnet(self) -> str:
-        logger.info("Downloading ControlNet snapshot: %s", self.CONTROLNET_REPO)
-        local_dir = os.path.join(self.cache_dir, "sdxl_light_caption_output")
-        path = snapshot_download(
-            repo_id=self.CONTROLNET_REPO,
-            local_dir=local_dir,
-            local_dir_use_symlinks=False,
-            token=self.hf_token,
-        )
-        controlnet_path = os.path.join(path, self.CONTROLNET_SUBDIR)
-        if not os.path.isdir(controlnet_path):
-            raise RuntimeError(f"ControlNet weights not found at {controlnet_path}")
-        return controlnet_path
-
-    def _load_pipeline(self) -> None:
-        controlnet_path = self._download_controlnet()
-        base_kwargs = {"use_auth_token": self.hf_token} if self.hf_token else {}
-
-        logger.info("Loading SDXL components...")
-        vae = AutoencoderKL.from_pretrained(self.BASE_MODEL, subfolder="vae", torch_dtype=self.dtype, token=self.hf_token)
-        unet = UNet2DConditionModel.from_config(
-            self.BASE_MODEL,
-            subfolder="unet",
-            token=self.hf_token if self.hf_token else None,
-        )
-        lightning_path = hf_hub_download(
-            repo_id=self.LIGHTNING_REPO,
-            filename=self.LIGHTNING_WEIGHTS,
-            token=self.hf_token if self.hf_token else None,
-        )
-        unet.load_state_dict(load_file(lightning_path))
-
-        controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=self.dtype)
-
-        try:
-            self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
-                self.BASE_MODEL,
-                vae=vae,
-                unet=unet,
-                controlnet=controlnet,
-                torch_dtype=self.dtype,
-                safety_checker=None,
-                requires_safety_checker=False,
-                token=self.hf_token if self.hf_token else None,
-            )
-        except Exception as exc:
-            logger.error("Failed to load base SDXL model: %s", exc)
-            logger.error(
-                "Ensure the account associated with HUGGINGFACE_HUB_TOKEN has accepted "
-                "the license for %s and that the token has access.", self.BASE_MODEL
-            )
-            raise
-        self.pipe.set_progress_bar_config(disable=True)
-
-        if self.device.type == "cuda":
-            self.pipe.to(self.device, dtype=self.dtype)
-            if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
-                try:
-                    self.pipe.enable_xformers_memory_efficient_attention()
-                except Exception as exc:  # pragma: no cover
-                    logger.warning("Could not enable xformers attention: %s", exc)
-        else:
-            self.pipe.to(self.device, dtype=self.dtype)
-
-        logger.info("Colorization pipeline ready.")
 
     def _load_caption_model(self) -> None:
-        logger.info("Loading BLIP captioning model...")
-        processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL, token=self.hf_token)
-        model = BlipForConditionalGeneration.from_pretrained(
+        logger.info("Loading BLIP captioning model for prompt generation...")
+        self.caption_processor = BlipProcessor.from_pretrained(self.CAPTION_MODEL)
+        self.caption_model = BlipForConditionalGeneration.from_pretrained(
             self.CAPTION_MODEL,
             torch_dtype=self.dtype if self.device.type == "cuda" else torch.float32,
-            token=self.hf_token,
-        )
-        self.caption_processor = processor
-        self.caption_model = model.to(self.device)
+        ).to(self.device)
 
-    # --------------------------------------------------------------------- #
-    # Public API
-    # --------------------------------------------------------------------- #
     def caption_image(self, image: Image.Image) -> str:
-        """Generate a cleaned caption for the image."""
         inputs = self.caption_processor(
             image,
             self.caption_prefix,
             return_tensors="pt",
         ).to(self.device)
 
-        # BLIP on CPU expects float32 inputs
         if self.device.type != "cuda":
             inputs = {k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
 
         with torch.inference_mode():
             caption_ids = self.caption_model.generate(**inputs)
         caption = self.caption_processor.decode(caption_ids[0], skip_special_tokens=True)
-        cleaned_caption = _remove_unlikely_words(caption)
-        return cleaned_caption or caption
-
-    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
-        """Colorize a grayscale image."""
-        try:
-            original_size = image.size
-            control_image = image.convert("L").convert("RGB").resize(
-                (512, 512), Image.Resampling.LANCZOS
-            )
-
-            caption = self.caption_image(image)
-            self.last_caption = caption
-
-            prompt_parts = [caption]
-            if self.positive_prompt:
-                prompt_parts.insert(0, self.positive_prompt)
-            final_prompt = ", ".join([part for part in prompt_parts if part])
-
-            negative_prompt = self.negative_prompt or None
-            steps = num_inference_steps or self.num_inference_steps
-            generator = torch.Generator(device=self.device).manual_seed(self.seed)
-
-            logger.info("Running SDXL pipeline with prompt: %s", final_prompt)
-            result = self.pipe(
-                prompt=final_prompt,
-                negative_prompt=negative_prompt,
-                image=control_image,
-                num_inference_steps=steps,
-                guidance_scale=self.guidance_scale,
-                controlnet_conditioning_scale=self.controlnet_scale,
-                generator=generator,
-            )
-
-            generated_image = result.images[0]
-            colorized = _apply_color(control_image, generated_image)
-            if colorized.size != original_size:
-                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
-
-            return colorized, caption
-        except Exception as exc:
-            logger.exception("Error during colorization: %s", exc)
-            raise
+        return _clean_caption(caption)
+
+    def _build_payload(self, prompt: str) -> dict:
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "num_inference_steps": self.num_inference_steps,
+                "guidance_scale": self.guidance_scale,
+                "negative_prompt": self.negative_prompt,
+                "seed": self.seed,
+            },
+        }
+        if self.provider:
+            payload["provider"] = {"name": self.provider}
+        return payload
+
+    def colorize(self, image: Image.Image, _num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
+        caption = self.caption_image(image)
+        prompt_parts = [self.positive_prompt, caption]
+        prompt = ", ".join([p for p in prompt_parts if p])
+
+        headers = {
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": "application/json",
+        }
+        payload = self._build_payload(prompt)
+
+        logger.info("Calling HF Inference API for prompt: %s", prompt)
+        response = requests.post(self.api_url, headers=headers, json=payload, timeout=self.timeout)
+
+        if response.status_code != 200:
+            try:
+                data = response.json()
+            except ValueError:
+                data = response.text
+            logger.error("Inference API error (%s): %s", response.status_code, data)
+            raise RuntimeError(f"Inference API error ({response.status_code}): {data}")
+
+        colorized = Image.open(io.BytesIO(response.content)).convert("RGB")
+        colorized = colorized.resize(image.size, Image.Resampling.LANCZOS)
+        return colorized, caption
 
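The wrapper's public surface (construct, then call colorize on a PIL image) is unchanged by this commit; only the backend moved from a local pipeline to the remote API. A usage sketch, assuming one of the token env vars above is exported; the file names here are illustrative only.

    from PIL import Image

    from app.colorize_model import ColorizeModel

    # __init__ raises RuntimeError unless HUGGINGFACE_API_TOKEN,
    # HUGGINGFACE_HUB_TOKEN, or HF_TOKEN is set in the environment.
    model = ColorizeModel()  # MODEL_ID defaults come from app.config settings

    image = Image.open("input.jpg")              # illustrative path
    colorized, caption = model.colorize(image)   # BLIP caption drives the prompt
    print("caption used as prompt:", caption)
    colorized.save("output.png")                 # illustrative path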
app/config.py CHANGED
@@ -17,9 +17,9 @@ class Settings(BaseSettings):
     # API settings
     BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
 
-    # Model settings
-    MODEL_ID: str = os.getenv("MODEL_ID", "nickpai/sdxl_light_caption_output")
-    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "8"))
+    # Model / inference settings
+    MODEL_ID: str = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")
+    NUM_INFERENCE_STEPS: int = int(os.getenv("NUM_INFERENCE_STEPS", "30"))
     POSITIVE_PROMPT: str = os.getenv(
         "POSITIVE_PROMPT",
         "high quality color photo, vibrant natural colors, detailed lighting"
@@ -32,6 +32,8 @@ class Settings(BaseSettings):
     CONTROLNET_SCALE: float = float(os.getenv("CONTROLNET_SCALE", "1.0"))
     CAPTION_PREFIX: str = os.getenv("CAPTION_PREFIX", "a photography of")
     COLORIZE_SEED: int = int(os.getenv("COLORIZE_SEED", "123"))
+    INFERENCE_PROVIDER: str = os.getenv("INFERENCE_PROVIDER", "hf-inference")
+    INFERENCE_TIMEOUT: int = int(os.getenv("INFERENCE_TIMEOUT", "180"))
 
     # Storage settings
     UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "uploads")
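
Note that these defaults are read with os.getenv in the class body, so they are evaluated when app.config is first imported; overrides must be in the environment before that import. An illustrative sketch (values are examples, and the token is a placeholder, not a real credential):

    import os

    # Set overrides before importing app.config.
    os.environ["MODEL_ID"] = "stabilityai/stable-diffusion-xl-base-1.0"
    os.environ["NUM_INFERENCE_STEPS"] = "30"
    os.environ["INFERENCE_PROVIDER"] = "hf-inference"
    os.environ["INFERENCE_TIMEOUT"] = "180"
    os.environ["HUGGINGFACE_API_TOKEN"] = "hf_..."  # placeholder token

    from app.config import settings  # defaults resolved at import time

    print(settings.MODEL_ID, settings.INFERENCE_TIMEOUT)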