|
|
""" |
|
|
FastAPI application for FastAI GAN Image Colorization |
|
|
with Firebase Authentication and Gradio UI |
|
|
""" |
|
|
import os |
|
|
|
|
|
# Process-wide environment settings. They are applied before the library
# imports further down because those libraries may read them at import time.
os.environ["OMP_NUM_THREADS"] = "1"
# Point every Hugging Face / generic cache location at one writable directory.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HF_HUB_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        "/tmp/hf_cache",
    )
)
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
|
|
|
|
|
import io |
|
|
import uuid |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, app_check, auth as firebase_auth |
|
|
from PIL import Image |
|
|
import torch |
|
|
import uvicorn |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
|
|
|
from fastai.vision.all import * |
|
|
from huggingface_hub import from_pretrained_fastai |
|
|
|
|
|
from app.config import settings |
|
|
from app.pytorch_colorizer import PyTorchColorizer |
|
|
|
|
|
|
|
|
# Root logging: INFO and above, each line tagged with time, logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Make sure every cache / upload / result directory exists before use.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir
|
|
|
|
|
|
|
|
app = FastAPI(
    title="FastAI Image Colorizer API",
    description="Image colorization using FastAI GAN model with Firebase authentication",
    version="1.0.0"
)

# Allowed CORS origins come from the comma-separated CORS_ALLOW_ORIGINS env
# var; when it is unset or blank, "*" (any origin) is used.
# NOTE(review): the CORS spec forbids pairing a wildcard origin with
# credentials; confirm how Starlette's CORSMiddleware resolves "*" together
# with allow_credentials=True, and prefer an explicit origin list in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
        if origin.strip()
    ] or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")


def _init_default_firebase_app() -> None:
    """Best-effort initialization of the default Firebase app without explicit credentials.

    Any failure is logged instead of raised so the server can still start.
    ``verify_request`` skips authentication when no Firebase app is registered.
    """
    try:
        firebase_admin.initialize_app()
    except Exception as exc:  # deliberate best-effort; report, don't crash startup
        logger.warning("Default Firebase app initialization failed: %s", exc)


if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
        _init_default_firebase_app()
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    _init_default_firebase_app()
|
|
|
|
|
|
|
|
# Scratch locations for uploaded inputs and generated results.
UPLOAD_DIR = Path("/tmp/colorize_uploads")
RESULT_DIR = Path("/tmp/colorize_results")

# Serve each directory as static files; the route name is the path without "/".
# NOTE(review): nothing in this module as shown writes to UPLOAD_DIR.
for _mount_path, _directory in (("/results", RESULT_DIR), ("/uploads", UPLOAD_DIR)):
    app.mount(_mount_path, StaticFiles(directory=str(_directory)), name=_mount_path.lstrip("/"))
del _mount_path, _directory
|
|
|
|
|
|
|
|
# Model handles filled in by startup_event(); at most one of them is set.
learn = pytorch_colorizer = None
# Last model-loading failure (None once a model loaded successfully).
model_load_error: Optional[str] = None
# Active backend: "fastai", "pytorch" or "none".
model_type: str = "none"
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load FastAI or PyTorch model on startup.

    Tries ``from_pretrained_fastai(MODEL_ID)`` first. If that raises, it falls
    back to ``PyTorchColorizer`` loading ``generator.pt`` from the same repo.
    The outcome is published via these module globals:

    - ``learn`` / ``pytorch_colorizer``: whichever model loaded.
    - ``model_type``: "fastai", "pytorch" or "none".
    - ``model_load_error``: when both loaders fail, it records both failure
      messages, so ``/health`` also reports why the FastAI path was rejected.
    """
    global learn, pytorch_colorizer, model_load_error, model_type
    model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")

    try:
        logger.info("🔄 Attempting to load FastAI GAN Colorization Model: %s", model_id)
        learn = from_pretrained_fastai(model_id)
    except Exception as e:
        # Keep the reason; it would otherwise be lost if the fallback also fails.
        fastai_error = str(e)
        logger.warning("⚠️ FastAI model loading failed: %s. Trying PyTorch fallback...", fastai_error)
    else:
        logger.info("✅ FastAI model loaded successfully!")
        model_type = "fastai"
        model_load_error = None
        return

    try:
        logger.info("🔄 Attempting to load PyTorch GAN Colorization Model: %s", model_id)
        pytorch_colorizer = PyTorchColorizer(model_id=model_id, model_filename="generator.pt")
    except Exception as e:
        model_load_error = f"FastAI: {fastai_error}; PyTorch: {e}"
        model_type = "none"
        logger.error("❌ Failed to load both FastAI and PyTorch models: %s", model_load_error)
    else:
        logger.info("✅ PyTorch model loaded successfully!")
        model_type = "pytorch"
        model_load_error = None
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown: drop references to the loaded models.

    The globals are rebound to ``None`` rather than ``del``-ed. ``del`` on a
    ``global`` name removes the module attribute, so any later read would
    raise ``NameError``. Examples are ``/health`` (``learn is not None``) or
    ``colorize_pil`` during shutdown.
    """
    global learn, pytorch_colorizer
    learn = None
    pytorch_colorizer = None
    logger.info("Application shutdown")
|
|
|
|
|
def _extract_bearer_token(authorization_header: str | None) -> str | None: |
|
|
if not authorization_header: |
|
|
return None |
|
|
parts = authorization_header.split(" ", 1) |
|
|
if len(parts) == 2 and parts[0].lower() == "bearer": |
|
|
return parts[1].strip() |
|
|
return None |
|
|
|
|
|
async def verify_request(request: Request):
    """FastAPI dependency that authenticates a request via Firebase.

    Checks run in this order:

    1. If no Firebase app is initialized, or DISABLE_AUTH=true, the request
       is allowed without inspection.
    2. An ``Authorization: Bearer <id_token>`` header is verified as a
       Firebase Auth ID token. On success the decoded claims are stored on
       ``request.state.user`` and the request is allowed. A failed
       verification is only logged, and checking continues.
    3. When ``settings.ENABLE_APP_CHECK`` is set, an ``X-Firebase-AppCheck``
       header is required and must verify.

    Returns:
        True whenever the request is allowed.

    Raises:
        HTTPException: 401 if App Check is enabled and its token is missing
            or invalid.

    NOTE(review): with App Check disabled, requests that carry no token (or
    an invalid ID token) fall through and are allowed. Confirm that this open
    mode is intended.
    """
    if not firebase_admin._apps:
        return True
    if os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    id_token = _extract_bearer_token(request.headers.get("Authorization"))
    if id_token:
        try:
            claims = firebase_auth.verify_id_token(id_token)
            request.state.user = claims
            logger.info("Firebase Auth id_token verified for uid: %s", claims.get("uid"))
            return True
        except Exception as err:
            logger.warning("Auth token verification failed: %s", str(err))

    if not settings.ENABLE_APP_CHECK:
        return True

    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if not app_check_token:
        raise HTTPException(status_code=401, detail="Missing App Check token")
    try:
        app_check_claims = app_check.verify_token(app_check_token)
        logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
    except Exception as err:
        logger.warning("App Check token verification failed: %s", str(err))
        raise HTTPException(status_code=401, detail="Invalid App Check token")
    return True
|
|
|
|
|
@app.get("/api")
async def api_info():
    """Describe the service: name, version and its main endpoint paths."""
    return dict(
        app="FastAI Image Colorizer API",
        version="1.0.0",
        health="/health",
        colorize="/colorize",
        gradio="/",
    )
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness plus which colorization backend is active.

    ``status`` is always "healthy". ``model_loaded`` is true when either the
    FastAI learner or the PyTorch colorizer is set. ``model_error`` appears
    only when startup recorded a load failure, and ``message`` summarizes the
    state.
    """
    loaded = learn is not None or pytorch_colorizer is not None
    payload = dict(
        status="healthy",
        model_loaded=loaded,
        model_type=model_type,
        model_id=os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model"),
        using_fallback=not loaded,
    )

    if not model_load_error:
        if loaded:
            payload["message"] = f"Model loaded successfully ({model_type})"
        else:
            payload["message"] = "No model loaded. Using fallback colorization method."
        return payload

    payload["model_error"] = model_load_error
    payload["message"] = "Model failed to load. Using fallback colorization method."
    return payload
|
|
|
|
|
def simple_colorize_fallback(image: Image.Image) -> Image.Image:
    """
    Heuristic colorization used when no trained model is loaded.

    Works in OpenCV's 8-bit LAB space:
      1. Boost local contrast of the L (lightness) channel with CLAHE
         (clip limit 3.0, 8x8 tiles).
      2. Add a brightness-weighted offset to the a and b channels: +2..+8 on a
         and +3..+12 on b. The offset sits at the low end where L/255 <= 0.3,
         at the high end where L/255 >= 0.8, and is linear in between.
         Brightness is taken from the original L channel, not the
         CLAHE-enhanced one.
      3. Convert back to RGB and raise HSV saturation by 20%.

    This is a simple hand-tuned tint and will not match a trained model.

    Args:
        image: PIL image in any mode; non-RGB input is converted to RGB.

    Returns:
        A new RGB PIL image with the same width and height.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    lab = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)

    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_enhanced = clahe.apply(lightness)

    # 0 where L/255 <= 0.3, 1 where L/255 >= 0.8, linear in between.
    l_normalized = lightness.astype(np.float32) / 255.0
    brightness_mask = np.clip((l_normalized - 0.3) * 2, 0, 1)

    # Blend a small offset (dark pixels) with a larger one (bright pixels);
    # clip before the uint8 cast so values cannot wrap around.
    a_hint = np.clip(
        a_chan.astype(np.float32) + brightness_mask * 8 + (1 - brightness_mask) * 2, 0, 255
    ).astype(np.uint8)
    b_hint = np.clip(
        b_chan.astype(np.float32) + brightness_mask * 12 + (1 - brightness_mask) * 3, 0, 255
    ).astype(np.uint8)

    lab_colored = cv2.merge([l_enhanced, a_hint, b_hint])
    colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)

    # Final 20% saturation boost in HSV.
    hsv = cv2.cvtColor(colored_rgb, cv2.COLOR_RGB2HSV)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1].astype(np.float32) * 1.2, 0, 255).astype(np.uint8)
    colored_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    return Image.fromarray(colored_rgb)
|
|
|
|
|
|
|
|
def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a model-output tensor into an RGB PIL image.

    Conversion rules:

    - A 4-D tensor is treated as a batch, and only its first item is used.
    - The remaining 3-D tensor is assumed to be channel-first (C, H, W).
      TODO: confirm this against the model's actual output layout.
    - Floating-point values of any precision are clamped to [0, 1] and scaled
      to 0..255. Other non-uint8 dtypes are clamped to 0..255.
    - A single-channel result is replicated to three channels.

    Raises:
        ValueError: if the tensor is not 3-D after dropping the batch axis.
    """
    # detach() so tensors that still require grad can be turned into NumPy.
    tensor = tensor.detach()
    if tensor.dim() == 4:
        tensor = tensor[0]
    if tensor.dim() != 3:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

    hwc = tensor.permute(1, 2, 0).cpu()
    if hwc.is_floating_point():
        # Covers float16/bfloat16/float32/float64, not just float32/float16.
        hwc = (torch.clamp(hwc, 0, 1) * 255).byte()
    elif hwc.dtype != torch.uint8:
        hwc = torch.clamp(hwc, 0, 255).to(torch.uint8)

    array = hwc.numpy()
    if array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    # Let Pillow infer the mode from the (H, W, C) uint8 array, then normalize to RGB.
    return Image.fromarray(array).convert("RGB")


def colorize_pil(image: Image.Image) -> Image.Image:
    """Colorize ``image`` with whichever backend is available.

    Backend priority:

    1. FastAI learner (``learn``). The image is converted to RGB and passed to
       ``learn.predict``. If the prediction is a list/tuple, its first element
       is used; an empty sequence yields the input image. A PIL result is
       normalized to RGB, and a ``torch.Tensor`` result is converted with
       ``_tensor_to_pil``.
    2. PyTorch colorizer (``pytorch_colorizer``). The image is passed to its
       ``colorize`` method unchanged.
    3. Otherwise ``simple_colorize_fallback``.

    Raises:
        ValueError: if the FastAI prediction is neither a PIL image nor a
            tensor, or the tensor has an unsupported shape.
    """
    if learn is not None:
        if image.mode != "RGB":
            image = image.convert("RGB")
        prediction = learn.predict(image)

        if isinstance(prediction, (list, tuple)):
            colorized = prediction[0] if prediction else image
        else:
            colorized = prediction

        if isinstance(colorized, torch.Tensor):
            colorized = _tensor_to_pil(colorized)
        elif not isinstance(colorized, Image.Image):
            raise ValueError(f"Unexpected prediction type: {type(colorized)}")

        return colorized if colorized.mode == "RGB" else colorized.convert("RGB")

    if pytorch_colorizer is not None:
        return pytorch_colorizer.colorize(image)

    logger.info("No model loaded, using enhanced colorization fallback (LAB color space method)")
    return simple_colorize_fallback(image)
|
|
|
|
|
@app.post("/colorize")
async def colorize_api(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Upload a black & white image -> returns colorized image.

    Requires Firebase authentication unless DISABLE_AUTH=true; this is
    enforced by the ``verify_request`` dependency, and ``verified`` itself is
    unused here. The PNG result is also saved to ``RESULT_DIR`` under a random
    UUID file name.

    Raises:
        HTTPException(400): the declared content type is not ``image/*``, or
            Pillow cannot decode the bytes (corrupt or truncated data, or an
            image over Pillow's decompression-bomb limit).
        HTTPException(500): colorizing or saving the result failed.
    """
    # content_type is client-declared; the decode step below is the real check.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    img_bytes = await file.read()
    try:
        # convert() forces a full decode, so truncated data fails here rather
        # than later inside the model.
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Bad client input: report 400, not a server error.
        logger.warning("Rejected undecodable image upload: %s", e)
        raise HTTPException(status_code=400, detail="Could not decode image file") from e

    # NOTE(review): colorize_pil() is synchronous model work running on the
    # event loop, so it blocks other requests while it runs. Offloading it
    # (e.g. fastapi.concurrency.run_in_threadpool) would fix that. First
    # confirm that the loaded model tolerates concurrent calls.
    try:
        logger.info("Colorizing image...")
        colorized = colorize_pil(image)

        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")
    except Exception as e:
        logger.exception("Error colorizing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e

    logger.info("Colorized image saved: %s", output_filename)
    return FileResponse(
        output_path,
        media_type="image/png",
        filename=f"colorized_{output_filename}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_colorize(image):
    """Gradio callback: colorize the uploaded PIL image.

    Returns ``None`` when no image was provided. Colorization errors are
    logged with their traceback and re-raised as ``gr.Error``. The UI then
    shows the failure instead of silently rendering an empty output.
    """
    if image is None:
        return None
    try:
        return colorize_pil(image)
    except Exception as e:
        logger.exception("Gradio colorization error: %s", str(e))
        raise gr.Error(f"Colorization failed: {e}") from e
|
|
|
|
|
# Gradio UI: one image in, one image out, both exchanged as PIL images.
title = "🎨 Image Colorizer"
description = (
    "Upload a black & white photo to generate a colorized version. "
    "Uses AI model when available, otherwise uses enhanced colorization fallback."
)

iface = gr.Interface(
    title=title,
    description=description,
    fn=gradio_colorize,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
)

# Attach the Gradio UI to the FastAPI app at the site root.
app = gr.mount_gradio_app(app, iface, path="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: listen on all interfaces; PORT defaults to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
|
|
|
|
|