LogicGoInfotechSpaces commited on
Commit
ae9bbd0
·
1 Parent(s): 807fb92

Refactor to use Hugging Face Inference API with fal-ai provider - Replace local model loading with InferenceClient API - Remove heavy SDXL/ControlNet/BLIP model dependencies - Use FLUX.1-Kontext-dev model via API - Keep FastAPI and Firebase authentication - Significantly reduce memory usage (no local models)

Browse files
Files changed (2) hide show
  1. app/config.py +3 -1
  2. app/main_sdxl.py +69 -171
app/config.py CHANGED
@@ -44,8 +44,10 @@ class Settings(BaseSettings):
44
  "FASTAI_OUTPUT_CAPTION",
45
  "Colorized using GAN-Colorization-Model"
46
  )
47
- INFERENCE_PROVIDER: str = os.getenv("INFERENCE_PROVIDER", "hf-inference")
 
48
  INFERENCE_TIMEOUT: int = int(os.getenv("INFERENCE_TIMEOUT", "180"))
 
49
 
50
  # Storage settings
51
  UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "uploads")
 
44
  "FASTAI_OUTPUT_CAPTION",
45
  "Colorized using GAN-Colorization-Model"
46
  )
47
+ INFERENCE_PROVIDER: str = os.getenv("INFERENCE_PROVIDER", "fal-ai")
48
+ INFERENCE_MODEL: str = os.getenv("INFERENCE_MODEL", "black-forest-labs/FLUX.1-Kontext-dev")
49
  INFERENCE_TIMEOUT: int = int(os.getenv("INFERENCE_TIMEOUT", "180"))
50
+ HF_TOKEN: str = os.getenv("HF_TOKEN", "")
51
 
52
  # Storage settings
53
  UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "uploads")
app/main_sdxl.py CHANGED
@@ -1,17 +1,8 @@
1
  """
2
- FastAPI application for Text-Guided Image Colorization using SDXL + ControlNet
3
- Based on fffiloni/text-guided-image-colorization
4
  """
5
  import os
6
- # Set environment variables BEFORE any imports
7
- os.environ["OMP_NUM_THREADS"] = "1"
8
- os.environ["HF_HOME"] = "/tmp/hf_cache"
9
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
- os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
11
- os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
12
- os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
13
- os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
14
-
15
  import io
16
  import uuid
17
  import logging
@@ -25,23 +16,11 @@ from fastapi.staticfiles import StaticFiles
25
  import firebase_admin
26
  from firebase_admin import credentials, app_check, auth as firebase_auth
27
  from PIL import Image
28
- import torch
29
  import uvicorn
30
  import gradio as gr
31
 
32
- # SDXL + ControlNet imports
33
- from accelerate import Accelerator
34
- from diffusers import (
35
- AutoencoderKL,
36
- StableDiffusionXLControlNetPipeline,
37
- ControlNetModel,
38
- UNet2DConditionModel,
39
- )
40
- from transformers import (
41
- BlipProcessor, BlipForConditionalGeneration,
42
- )
43
- from safetensors.torch import load_file
44
- from huggingface_hub import hf_hub_download, snapshot_download
45
 
46
  from app.config import settings
47
 
@@ -102,12 +81,8 @@ RESULT_DIR = Path("/tmp/colorize_results")
102
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
103
  app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
104
 
105
- # Global model variables
106
- pipe = None
107
- caption_model = None
108
- processor = None
109
- device = None
110
- weight_dtype = None
111
  model_load_error: Optional[str] = None
112
 
113
  # ========== Utility Functions ==========
@@ -177,110 +152,29 @@ def remove_unlikely_words(prompt: str) -> str:
177
 
178
  @app.on_event("startup")
179
  async def startup_event():
180
- """Load SDXL + ControlNet models on startup"""
181
- global pipe, caption_model, processor, device, weight_dtype, model_load_error
182
 
183
  try:
184
- logger.info("🔄 Loading SDXL + ControlNet colorization models...")
185
-
186
- # Use writable directory for model downloads
187
- controlnet_dir = "/tmp/sdxl_light_caption_output"
188
- try:
189
- os.makedirs(controlnet_dir, exist_ok=True)
190
- # Test write permissions
191
- test_file = os.path.join(controlnet_dir, ".test_write")
192
- with open(test_file, "w") as f:
193
- f.write("test")
194
- os.remove(test_file)
195
- logger.info(f"Using directory: {controlnet_dir}")
196
- except PermissionError as e:
197
- logger.error(f"Permission denied for directory {controlnet_dir}: {e}")
198
- raise
199
- except Exception as e:
200
- logger.error(f"Failed to create directory {controlnet_dir}: {e}")
201
- raise
202
-
203
- # Download controlnet model snapshot
204
- controlnet_path = os.path.join(controlnet_dir, "checkpoint-30000", "controlnet")
205
- if os.path.exists(controlnet_path):
206
- logger.info(f"ControlNet model already exists at {controlnet_path}")
207
- else:
208
- try:
209
- logger.info("Downloading ControlNet model...")
210
- snapshot_download(
211
- repo_id='nickpai/sdxl_light_caption_output',
212
- local_dir=controlnet_dir
213
- )
214
- logger.info("ControlNet model downloaded successfully")
215
- except Exception as e:
216
- logger.error(f"Could not download controlnet snapshot: {e}")
217
- if not os.path.exists(controlnet_path):
218
- raise
219
-
220
- # Device and precision setup
221
- accelerator = Accelerator(mixed_precision="fp16")
222
- weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
223
- device = accelerator.device
224
-
225
- logger.info(f"Using device: {device}, dtype: {weight_dtype}")
226
-
227
- # Pretrained paths
228
- base_model_path = settings.BASE_MODEL_ID
229
- safetensors_ckpt = settings.LIGHTNING_WEIGHTS
230
- # controlnet_path already defined above
231
 
232
- # Load diffusion components
233
- logger.info("Loading VAE...")
234
- vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
235
- # Enable VAE slicing for memory efficiency
236
- vae.enable_slicing()
237
- vae.enable_tiling()
238
 
239
- logger.info("Loading UNet...")
240
- unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
241
- unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
242
- # Enable attention slicing for memory efficiency
243
- unet.set_attention_slice("max")
244
-
245
- logger.info("Loading ControlNet...")
246
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)
247
- # Enable attention slicing for ControlNet
248
- controlnet.set_attention_slice("max")
249
-
250
- logger.info("Creating pipeline...")
251
- pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
252
- base_model_path, vae=vae, unet=unet, controlnet=controlnet, torch_dtype=weight_dtype
253
  )
254
- pipe.safety_checker = None
255
-
256
- # Enable sequential CPU offloading to reduce memory usage
257
- logger.info("Enabling CPU offloading for memory efficiency...")
258
- pipe.enable_sequential_cpu_offload()
259
- # Alternative: use model CPU offload (moves entire model to CPU when not in use)
260
- # pipe.enable_model_cpu_offload()
261
 
262
- logger.info("Memory optimizations enabled")
263
-
264
- # Load BLIP captioning model (use base to save memory)
265
- logger.info("Loading BLIP captioning model (using base model for memory efficiency)...")
266
- caption_model_name = "blip-image-captioning-base"
267
- try:
268
- processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
269
- caption_model = BlipForConditionalGeneration.from_pretrained(
270
- f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
271
- )
272
- # Keep BLIP on CPU and move to device only during inference
273
- caption_model.eval()
274
- except Exception as e:
275
- logger.error(f"Failed to load BLIP model: {e}")
276
- raise
277
-
278
- logger.info("✅ All models loaded successfully!")
279
  model_load_error = None
280
 
281
  except Exception as e:
282
  error_msg = str(e)
283
- logger.error(f"❌ Failed to load models: {error_msg}")
284
  model_load_error = error_msg
285
  # Don't raise - allow health check to work
286
 
@@ -288,11 +182,9 @@ async def startup_event():
288
  @app.on_event("shutdown")
289
  async def shutdown_event():
290
  """Cleanup on shutdown"""
291
- global pipe, caption_model
292
- if pipe:
293
- del pipe
294
- if caption_model:
295
- del caption_model
296
  logger.info("Application shutdown")
297
 
298
 
@@ -356,9 +248,9 @@ async def health_check():
356
  """Health check endpoint"""
357
  response = {
358
  "status": "healthy",
359
- "model_loaded": pipe is not None and caption_model is not None,
360
- "model_type": "sdxl_controlnet",
361
- "device": str(device) if device else None
362
  }
363
  if model_load_error:
364
  response["model_error"] = model_load_error
@@ -373,7 +265,7 @@ def colorize_image_sdxl(
373
  num_inference_steps: int = 8
374
  ) -> Tuple[Image.Image, str]:
375
  """
376
- Colorize a grayscale or low-color image using SDXL + ControlNet.
377
 
378
  Args:
379
  image: PIL Image to colorize
@@ -385,46 +277,52 @@ def colorize_image_sdxl(
385
  Returns:
386
  Tuple of (colorized PIL Image, caption string)
387
  """
388
- if pipe is None or caption_model is None:
389
- raise RuntimeError("Models not loaded")
390
 
391
- torch.manual_seed(seed)
392
  original_size = image.size
393
- control_image = image.convert("L").convert("RGB").resize((512, 512))
 
394
 
395
- # Image captioning - keep BLIP on CPU to save memory
396
- input_text = settings.CAPTION_PREFIX
397
- # Use CPU for BLIP to save GPU memory
398
- blip_device = torch.device("cpu")
399
- inputs = processor(control_image, input_text, return_tensors="pt").to(blip_device)
400
- with torch.no_grad():
401
- caption_ids = caption_model.generate(**inputs, max_length=50, num_beams=3)
402
- caption = processor.decode(caption_ids[0], skip_special_tokens=True)
403
- caption = remove_unlikely_words(caption)
404
 
405
- # Construct final prompt
406
- if positive_prompt:
407
- final_prompt = f"{positive_prompt}, {caption}"
 
 
408
  else:
409
- final_prompt = caption
410
 
411
- # Inference with memory-efficient settings
412
- with torch.no_grad():
413
- result = pipe(
 
 
 
414
  prompt=final_prompt,
415
- negative_prompt=negative_prompt or settings.NEGATIVE_PROMPT,
416
- num_inference_steps=num_inference_steps,
417
- generator=torch.manual_seed(seed),
418
- image=control_image,
419
- guidance_scale=7.5, # Lower guidance scale uses less memory
420
  )
421
-
422
- # Clear cache after inference
423
- if torch.cuda.is_available():
424
- torch.cuda.empty_cache()
425
-
426
- colorized = apply_color(control_image, result.images[0]).resize(original_size)
427
- return colorized, caption
 
 
 
 
 
 
 
 
 
428
 
429
 
430
  @app.post("/colorize")
@@ -440,8 +338,8 @@ async def colorize_api(
440
  Upload a grayscale image -> returns colorized image.
441
  Uses SDXL + ControlNet with automatic captioning.
442
  """
443
- if pipe is None or caption_model is None:
444
- raise HTTPException(status_code=503, detail="Colorization models not loaded")
445
 
446
  if not file.content_type or not file.content_type.startswith("image/"):
447
  raise HTTPException(status_code=400, detail="File must be an image")
@@ -505,8 +403,8 @@ def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123)
505
  if image is None:
506
  return None, ""
507
  try:
508
- if pipe is None or caption_model is None:
509
- return None, "Models not loaded"
510
  colorized, caption = colorize_image_sdxl(
511
  image,
512
  positive_prompt=positive_prompt,
@@ -520,7 +418,7 @@ def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123)
520
 
521
 
522
  title = "🎨 Text-Guided Image Colorization"
523
- description = "Upload a grayscale image and generate a color version guided by automatic captioning using SDXL + ControlNet."
524
 
525
  iface = gr.Interface(
526
  fn=gradio_colorize,
 
1
  """
2
+ FastAPI application for Text-Guided Image Colorization using Hugging Face Inference API
3
+ Uses fal-ai provider for memory-efficient inference
4
  """
5
  import os
 
 
 
 
 
 
 
 
 
6
  import io
7
  import uuid
8
  import logging
 
16
  import firebase_admin
17
  from firebase_admin import credentials, app_check, auth as firebase_auth
18
  from PIL import Image
 
19
  import uvicorn
20
  import gradio as gr
21
 
22
+ # Hugging Face Inference API
23
+ from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  from app.config import settings
26
 
 
81
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
82
  app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
83
 
84
+ # Global Inference API client
85
+ inference_client = None
 
 
 
 
86
  model_load_error: Optional[str] = None
87
 
88
  # ========== Utility Functions ==========
 
152
 
153
  @app.on_event("startup")
154
  async def startup_event():
155
+ """Initialize Hugging Face Inference API client"""
156
+ global inference_client, model_load_error
157
 
158
  try:
159
+ logger.info("🔄 Initializing Hugging Face Inference API client...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ # Get HF token from environment or settings
162
+ hf_token = os.getenv("HF_TOKEN") or settings.HF_TOKEN
163
+ if not hf_token:
164
+ raise ValueError("HF_TOKEN environment variable is required for Inference API")
 
 
165
 
166
+ # Initialize InferenceClient with fal-ai provider
167
+ inference_client = InferenceClient(
168
+ provider="fal-ai",
169
+ api_key=hf_token,
 
 
 
 
 
 
 
 
 
 
170
  )
 
 
 
 
 
 
 
171
 
172
+ logger.info(" Inference API client initialized successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  model_load_error = None
174
 
175
  except Exception as e:
176
  error_msg = str(e)
177
+ logger.error(f"❌ Failed to initialize Inference API client: {error_msg}")
178
  model_load_error = error_msg
179
  # Don't raise - allow health check to work
180
 
 
182
  @app.on_event("shutdown")
183
  async def shutdown_event():
184
  """Cleanup on shutdown"""
185
+ global inference_client
186
+ if inference_client:
187
+ inference_client = None
 
 
188
  logger.info("Application shutdown")
189
 
190
 
 
248
  """Health check endpoint"""
249
  response = {
250
  "status": "healthy",
251
+ "model_loaded": inference_client is not None,
252
+ "model_type": "hf_inference_api",
253
+ "provider": "fal-ai"
254
  }
255
  if model_load_error:
256
  response["model_error"] = model_load_error
 
265
  num_inference_steps: int = 8
266
  ) -> Tuple[Image.Image, str]:
267
  """
268
+ Colorize a grayscale or low-color image using Hugging Face Inference API.
269
 
270
  Args:
271
  image: PIL Image to colorize
 
277
  Returns:
278
  Tuple of (colorized PIL Image, caption string)
279
  """
280
+ if inference_client is None:
281
+ raise RuntimeError("Inference API client not initialized")
282
 
 
283
  original_size = image.size
284
+ # Resize to 512x512 for inference (FLUX models work well at this size)
285
+ control_image = image.convert("RGB").resize((512, 512))
286
 
287
+ # Convert image to bytes for API
288
+ img_bytes = io.BytesIO()
289
+ control_image.save(img_bytes, format="PNG")
290
+ img_bytes.seek(0)
291
+ input_image = img_bytes.read()
 
 
 
 
292
 
293
+ # Construct prompt
294
+ base_prompt = positive_prompt or "colorize this image with vibrant natural colors, high quality"
295
+ if negative_prompt:
296
+ # Note: Some models may not support negative_prompt directly
297
+ final_prompt = f"{base_prompt}. Avoid: {negative_prompt}"
298
  else:
299
+ final_prompt = base_prompt
300
 
301
+ # Use Inference API for image-to-image generation
302
+ model_name = settings.INFERENCE_MODEL
303
+ logger.info(f"Calling Inference API with model {model_name}, prompt: {final_prompt}")
304
+ try:
305
+ result_image = inference_client.image_to_image(
306
+ input_image,
307
  prompt=final_prompt,
308
+ model=model_name,
 
 
 
 
309
  )
310
+
311
+ # Resize back to original size
312
+ if isinstance(result_image, Image.Image):
313
+ colorized = result_image.resize(original_size)
314
+ else:
315
+ # If it's bytes, convert to PIL Image
316
+ colorized = Image.open(io.BytesIO(result_image)).resize(original_size)
317
+
318
+ # Generate a simple caption from the prompt
319
+ caption = final_prompt[:100] # Truncate for display
320
+
321
+ return colorized, caption
322
+
323
+ except Exception as e:
324
+ logger.error(f"Inference API error: {e}")
325
+ raise RuntimeError(f"Failed to colorize image: {str(e)}")
326
 
327
 
328
  @app.post("/colorize")
 
338
  Upload a grayscale image -> returns colorized image.
339
  Uses SDXL + ControlNet with automatic captioning.
340
  """
341
+ if inference_client is None:
342
+ raise HTTPException(status_code=503, detail="Inference API client not initialized")
343
 
344
  if not file.content_type or not file.content_type.startswith("image/"):
345
  raise HTTPException(status_code=400, detail="File must be an image")
 
403
  if image is None:
404
  return None, ""
405
  try:
406
+ if inference_client is None:
407
+ return None, "Inference API client not initialized"
408
  colorized, caption = colorize_image_sdxl(
409
  image,
410
  positive_prompt=positive_prompt,
 
418
 
419
 
420
  title = "🎨 Text-Guided Image Colorization"
421
+ description = "Upload a grayscale image and generate a color version using Hugging Face Inference API (fal-ai provider)."
422
 
423
  iface = gr.Interface(
424
  fn=gradio_colorize,