Refactor to use Hugging Face Inference API with fal-ai provider - Replace local model loading with InferenceClient API - Remove heavy SDXL/ControlNet/BLIP model dependencies - Use FLUX.1-Kontext-dev model via API - Keep FastAPI and Firebase authentication - Significantly reduce memory usage (no local models)
ae9bbd0
| """ | |
| FastAPI application for Text-Guided Image Colorization using Hugging Face Inference API | |
| Uses fal-ai provider for memory-efficient inference | |
| """ | |
import asyncio
import io
import logging
import os
import uuid
from pathlib import Path
from typing import Optional, Tuple

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import firebase_admin
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image
import uvicorn
import gradio as gr
# Hugging Face Inference API
from huggingface_hub import InferenceClient

from app.config import settings
# Configure root logging once at import time: INFO level, one line per record
# with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the writable scratch directories this app uses under /tmp.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir
# Initialize the FastAPI app. The description is shown in the generated
# OpenAPI docs, so it names the pipeline this module actually calls: the
# remote Hugging Face Inference API via the fal-ai provider.
app = FastAPI(
    title="Text-Guided Image Colorization API",
    description="Image colorization using the Hugging Face Inference API (fal-ai provider)",
    version="1.0.0",
)
# CORS middleware: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; authentication here is header-based (Authorization /
# X-Firebase-AppCheck), so confirm credentials support is actually needed.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# Initialize Firebase Admin SDK. A service-account file is preferred; if it is
# missing or unusable, fall back to a default-credential app (best effort).
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")


def _initialize_firebase_fallback() -> None:
    """Best-effort: initialize the default Firebase app without a key file.

    Called when the service-account file is absent or could not be used.
    Failures are logged rather than raised so that a Firebase problem never
    prevents the API from starting. Unlike the previous bare ``except:``,
    KeyboardInterrupt/SystemExit are no longer swallowed.

    NOTE: a successful call registers a default app, so verify_request's
    ``not firebase_admin._apps`` short-circuit will no longer skip auth.
    """
    try:
        # With no credential argument, firebase_admin uses Google
        # Application Default Credentials.
        firebase_admin.initialize_app()
    except Exception as exc:
        logger.warning("Fallback Firebase initialization failed: %s", exc)


if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
        _initialize_firebase_fallback()
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    _initialize_firebase_fallback()
# Storage directories for uploaded inputs and generated results.
UPLOAD_DIR = Path("/tmp/colorize_uploads")
RESULT_DIR = Path("/tmp/colorize_results")

# Expose both directories as static file routes.
# NOTE(review): these mounts do not go through verify_request, so anyone who
# knows a file name can fetch it — confirm that is intended.
app.mount(path="/results", app=StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount(path="/uploads", app=StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Shared Inference API client; assigned by startup_event. If that
# initialization fails, the error text is kept in model_load_error so that
# /health can report it.
inference_client: Optional[InferenceClient] = None
model_load_error: Optional[str] = None
# ========== Utility Functions ==========
def apply_color(image: Image.Image, color_map: Image.Image) -> Image.Image:
    """
    Recolor ``image`` using the chroma of ``color_map``.

    Both inputs are converted to LAB. The result keeps the L (lightness)
    channel of ``image``, takes the a/b channels from ``color_map``, and is
    converted back to RGB. Image.merge requires all bands to share one size,
    so presumably callers pass images of equal dimensions.
    """
    lightness = image.convert('LAB').split()[0]
    chroma_a, chroma_b = color_map.convert('LAB').split()[1:]
    recombined = Image.merge('LAB', (lightness, chroma_a, chroma_b))
    return recombined.convert('RGB')
# Hand-picked phrases that describe the photo's age or medium rather than its
# content. They are applied after the generated year patterns, in this order.
# NOTE(review): the repeated ", photo" and " ," entries look like variants
# whose internal whitespace was collapsed in this copy — confirm upstream.
_MANUAL_UNLIKELY_PHRASES = (
    "black and white,", "black and white", "black & white,", "black & white", "circa",
    "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
    "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
    "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
    "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
    "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
    "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
    "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
    "black-and-white photo,", "black-and-white photo", "black - and - white photography",
    "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
    "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
    "black - and - white photograph,", "black - and - white photograph", "black on white,",
    "black on white", "black-and-white", "historical image,", "historical picture,",
    "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
    "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
    "historical photo", "historical setting,",
    "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
    "taken in", "shot on leica", "shot on leica sl2", "sl2",
    "taken with a leica camera", "leica sl2", "leica", "setting",
    "overcast day", "overcast weather", "slight overcast", "overcast",
    "picture taken in", "photo taken in",
    ", photo", ", photo", ", photo", ", photo", ", photograph",
    ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
)


def _build_unlikely_words() -> tuple[str, ...]:
    """Return every phrase remove_unlikely_words strips, in application order."""
    years = [str(year) for year in range(1900, 2000)]  # "1900" .. "1999"
    spaced = [" ".join(year) for year in years]        # "1 9 0 0" .. "1 9 9 9"
    return (
        *(f"{year}s" for year in years),
        *years,
        *(f"year {year}" for year in years),
        *(f"circa {year}" for year in years),
        *(f"{digits} s" for digits in spaced),
        *spaced,
        *(f"year {digits}" for digits in spaced),
        *(f"circa {digits}" for digits in spaced),
        *_MANUAL_UNLIKELY_PHRASES,
    )


# Built once at import time; previously the ~800-entry list was rebuilt on
# every call.
_UNLIKELY_WORDS = _build_unlikely_words()


def remove_unlikely_words(prompt: str) -> str:
    """
    Remove predefined "old photo" phrases from ``prompt``.

    Each phrase in ``_UNLIKELY_WORDS`` is deleted with plain substring
    replacement, in list order. Order matters: bare years ("1950") are
    stripped before the "year 1950" form is tried, so "year 1950" becomes
    "year ". Matching is not word-bounded, so "bw" is also removed from inside
    words (e.g. "subway" -> "suay").

    Args:
        prompt: Caption or prompt text to clean.

    Returns:
        The text with all listed phrases removed. Surrounding whitespace is
        left as-is.
    """
    for word in _UNLIKELY_WORDS:
        prompt = prompt.replace(word, "")
    return prompt
# ========== Model Loading ==========
async def startup_event():
    """
    Create the shared Hugging Face ``InferenceClient`` (fal-ai provider).

    The token is read from the ``HF_TOKEN`` environment variable, falling back
    to ``settings.HF_TOKEN``. Any failure is recorded in ``model_load_error``
    instead of being raised, so the health endpoint keeps working and can
    report it.

    NOTE(review): no registration of this coroutine (decorator or
    add_event_handler) is visible in this view — confirm it runs at startup,
    otherwise inference_client stays None.
    """
    global inference_client, model_load_error
    try:
        logger.info("🔄 Initializing Hugging Face Inference API client...")
        token = os.getenv("HF_TOKEN") or settings.HF_TOKEN
        if not token:
            raise ValueError("HF_TOKEN environment variable is required for Inference API")
        inference_client = InferenceClient(provider="fal-ai", api_key=token)
        logger.info("✅ Inference API client initialized successfully!")
        model_load_error = None
    except Exception as exc:
        # Record but do not re-raise, so /health stays available.
        model_load_error = str(exc)
        logger.error("❌ Failed to initialize Inference API client: %s", model_load_error)
async def shutdown_event():
    """Drop the shared Inference API client reference and log the shutdown."""
    global inference_client
    had_client = bool(inference_client)
    if had_client:
        inference_client = None
    logger.info("Application shutdown")
| # ========== Authentication ========== | |
| def _extract_bearer_token(authorization_header: str | None) -> str | None: | |
| if not authorization_header: | |
| return None | |
| parts = authorization_header.split(" ", 1) | |
| if len(parts) == 2 and parts[0].lower() == "bearer": | |
| return parts[1].strip() | |
| return None | |
async def verify_request(request: Request):
    """
    FastAPI dependency that checks Firebase credentials on a request.

    Checks are applied in this order:
      1. If no Firebase app is initialized, or DISABLE_AUTH is "true", the
         request is allowed.
      2. If an ``Authorization: Bearer`` token is present, it is verified as a
         Firebase ID token. On success the decoded claims are stored on
         ``request.state.user`` and the request is allowed. On failure the
         error is only logged and checking continues.
      3. If ``settings.ENABLE_APP_CHECK`` is set, an ``X-Firebase-AppCheck``
         token is required and verified.
      4. Otherwise the request is allowed.

    The Firebase verification calls are blocking (they may fetch signing keys
    over the network), so they run in a worker thread instead of stalling the
    event loop.

    Returns:
        True when the request is allowed.

    Raises:
        HTTPException: 401 if App Check is enabled and its token is missing or
            invalid.
    """
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = await asyncio.to_thread(firebase_auth.verify_id_token, bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True
        except Exception as e:
            # NOTE(review): fail-open — an invalid ID token is only logged, and
            # the request is still allowed unless App Check is enabled.
            logger.warning("Auth token verification failed: %s", str(e))

    if settings.ENABLE_APP_CHECK:
        app_check_token = request.headers.get("X-Firebase-AppCheck")
        if not app_check_token:
            raise HTTPException(status_code=401, detail="Missing App Check token")
        try:
            app_check_claims = await asyncio.to_thread(app_check.verify_token, app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
            raise HTTPException(status_code=401, detail="Invalid App Check token") from e

    return True
# ========== API Endpoints ==========
async def api_info():
    """Return the service name and version plus the paths of its main entry points."""
    endpoints = {"health": "/health", "colorize": "/colorize", "gradio": "/"}
    return {"app": "Text-Guided Image Colorization API", "version": "1.0.0", **endpoints}
async def health_check():
    """
    Report liveness plus the Inference API client status.

    Always reports "healthy". ``model_loaded`` tells whether the client
    exists, and ``model_error`` is added only when client initialization
    recorded an error.
    """
    status = dict(
        status="healthy",
        model_loaded=inference_client is not None,
        model_type="hf_inference_api",
        provider="fal-ai",
    )
    if model_load_error:
        status["model_error"] = model_load_error
    return status
_DEFAULT_COLORIZE_PROMPT = "colorize this image with vibrant natural colors, high quality"
_INFERENCE_SIZE = (512, 512)


def _build_colorization_prompt(
    positive_prompt: Optional[str],
    negative_prompt: Optional[str],
) -> str:
    """Combine the user prompts into one instruction string for the model.

    A missing, empty or whitespace-only positive prompt falls back to
    ``_DEFAULT_COLORIZE_PROMPT``. Some models may not support a separate
    negative_prompt, so it is folded into the text as ". Avoid: ...".
    """
    base_prompt = (positive_prompt or "").strip() or _DEFAULT_COLORIZE_PROMPT
    avoid = (negative_prompt or "").strip()
    return f"{base_prompt}. Avoid: {avoid}" if avoid else base_prompt


def colorize_image_sdxl(
    image: Image.Image,
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8
) -> Tuple[Image.Image, str]:
    """
    Colorize a grayscale or low-color image using the Hugging Face Inference API.

    The image is converted to RGB, resized to 512x512 (aspect ratio is not
    preserved), and sent as PNG to ``inference_client.image_to_image`` with
    the model named by ``settings.INFERENCE_MODEL``. The result is resized
    back to the input's original size. This is a blocking network call.

    Args:
        image: PIL Image to colorize.
        positive_prompt: Instruction text. Defaults to a generic colorization
            prompt when missing or blank.
        negative_prompt: Words or phrases to avoid; appended to the prompt.
        seed: Currently not forwarded to the Inference API; kept for
            interface compatibility.
        num_inference_steps: Currently not forwarded to the Inference API;
            kept for interface compatibility.

    Returns:
        Tuple of (colorized PIL Image, caption). The caption is the first 100
        characters of the prompt that was sent.

    Raises:
        RuntimeError: If the client is not initialized, or if the API call or
            result decoding fails. In the latter case the original error is
            chained as ``__cause__``.
    """
    if inference_client is None:
        raise RuntimeError("Inference API client not initialized")

    original_size = image.size
    control_image = image.convert("RGB").resize(_INFERENCE_SIZE)
    buffer = io.BytesIO()
    control_image.save(buffer, format="PNG")
    input_image = buffer.getvalue()

    final_prompt = _build_colorization_prompt(positive_prompt, negative_prompt)
    model_name = settings.INFERENCE_MODEL
    logger.info("Calling Inference API with model %s, prompt: %s", model_name, final_prompt)

    try:
        result_image = inference_client.image_to_image(
            input_image,
            prompt=final_prompt,
            model=model_name,
        )
        # Normally a PIL Image; decode raw bytes defensively.
        if not isinstance(result_image, Image.Image):
            result_image = Image.open(io.BytesIO(result_image))
        colorized = result_image.resize(original_size)
    except Exception as e:
        logger.exception("Inference API error: %s", e)
        raise RuntimeError(f"Failed to colorize image: {str(e)}") from e

    caption = final_prompt[:100]  # Truncated for display
    return colorized, caption
async def colorize_api(
    file: UploadFile = File(...),
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8,
    verified: bool = Depends(verify_request)
):
    """
    Upload a grayscale image -> returns colorized image.

    The image is colorized through the Hugging Face Inference API (via
    ``colorize_image_sdxl``), and the result is saved as a PNG in RESULT_DIR.

    Args:
        file: Uploaded image; its declared content type must start with "image/".
        positive_prompt: Optional instruction text for the colorization.
        negative_prompt: Optional text describing what to avoid.
        seed: Forwarded to ``colorize_image_sdxl``.
        num_inference_steps: Forwarded to ``colorize_image_sdxl``.
        verified: Result of the ``verify_request`` dependency.

    Returns:
        JSONResponse with ``success``, ``result_id``, ``caption``,
        ``download_url`` and ``api_download``.

    Raises:
        HTTPException: 503 if the Inference API client is not initialized;
            400 if the upload is not an image or cannot be decoded;
            500 if colorization or saving fails.
    """
    if inference_client is None:
        raise HTTPException(status_code=503, detail="Inference API client not initialized")
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # An unreadable or oversized upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from e

        logger.info("Colorizing image via Hugging Face Inference API...")
        # colorize_image_sdxl makes a blocking network call; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        colorized, caption = await asyncio.to_thread(
            colorize_image_sdxl,
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            num_inference_steps=num_inference_steps,
        )

        result_id = str(uuid.uuid4())
        output_filename = f"{result_id}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")
        logger.info("Colorized image saved: %s", output_filename)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error colorizing image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e

    return JSONResponse({
        "success": True,
        "result_id": result_id,
        "caption": caption,
        "download_url": f"/results/{output_filename}",
        "api_download": f"/download/{result_id}",
    })
def _resolve_result_path(filename: str) -> Path:
    """Map ``filename`` to a regular file directly inside RESULT_DIR, or raise 404.

    Names that resolve outside RESULT_DIR (e.g. ".."), as well as directories,
    are rejected. FileResponse could not serve a directory anyway; it would
    fail later with a 500.
    """
    results_root = RESULT_DIR.resolve()
    path = (RESULT_DIR / filename).resolve()
    if path.parent != results_root or not path.is_file():
        raise HTTPException(status_code=404, detail="Result not found")
    return path


def download_result(file_id: str, verified: bool = Depends(verify_request)):
    """Download colorized image by file ID (the result_id without ".png")."""
    path = _resolve_result_path(f"{file_id}.png")
    return FileResponse(path, media_type="image/png")


def get_result(filename: str):
    """Public endpoint (no verify_request dependency) to access colorized images."""
    path = _resolve_result_path(filename)
    return FileResponse(path, media_type="image/png")
# ========== Gradio Interface (Optional) ==========
def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123):
    """
    Gradio callback: colorize ``image`` and return ``(image, text)``.

    The text is the caption on success, or an error message on failure.
    With no input image, returns ``(None, "")`` without calling the API.
    """
    if image is None:
        return None, ""
    if inference_client is None:
        return None, "Inference API client not initialized"
    try:
        colorized, caption = colorize_image_sdxl(
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
        )
    except Exception as err:
        logger.error("Gradio colorization error: %s", str(err))
        return None, str(err)
    return colorized, caption
# Gradio front end for the same colorization pipeline.
title = "🎨 Text-Guided Image Colorization"
description = "Upload a grayscale image and generate a color version using Hugging Face Inference API (fal-ai provider)."

_gradio_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
    gr.Textbox(label="Negative Prompt", value=settings.NEGATIVE_PROMPT),
    gr.Slider(minimum=0, maximum=1000, value=123, label="Seed"),
]
_gradio_outputs = [
    gr.Image(type="pil", label="Colorized Image"),
    gr.Textbox(label="Caption", show_copy_button=True),
]
iface = gr.Interface(
    fn=gradio_colorize,
    inputs=_gradio_inputs,
    outputs=_gradio_outputs,
    title=title,
    description=description,
)

# Mount the Gradio UI onto the FastAPI app at the root path.
app = gr.mount_gradio_app(app, iface, path="/")
# ========== Run Server ==========
if __name__ == "__main__":
    # Bind on all interfaces; the PORT environment variable overrides 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))