|
|
""" |
|
|
FastAPI application for FastAI GAN Image Colorization |
|
|
with Firebase Authentication and Gradio UI |
|
|
""" |
|
|
import os |
|
|
|
|
|
# Redirect library caches, config dirs and temp dirs to writable locations
# under /tmp, and cap OpenMP at one thread. This runs before the third-party
# imports further down (presumably because some of those libraries read these
# variables when they are first imported).
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HUB_CACHE": "/tmp/hf_cache",
        "HUGGINGFACE_HUB_CACHE": "/tmp/hf_cache",
        "XDG_CACHE_HOME": "/tmp/hf_cache",
        "MPLCONFIGDIR": "/tmp/matplotlib_config",
        "GRADIO_TEMP_DIR": "/tmp/gradio",
    }
)
|
|
|
|
|
import asyncio
import io
import logging
import threading
import uuid
from pathlib import Path
from typing import Optional
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, app_check, auth as firebase_auth |
|
|
from PIL import Image |
|
|
import torch |
|
|
import uvicorn |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
from fastai.vision.all import * |
|
|
from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files |
|
|
|
|
|
from app.config import settings |
|
|
|
|
|
|
|
|
# Root logging: INFO and above, timestamped "time - logger - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Scratch directories used for the HF cache, matplotlib config, uploads and
# colorized results; created up front, tolerating ones that already exist.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir
|
|
|
|
|
|
|
|
# The FastAPI application; title/description/version feed the generated API docs.
app = FastAPI(
    title="FastAI Image Colorizer API",
    version="1.0.0",
    description="Image colorization using FastAI GAN model with Firebase authentication",
)

# Cross-origin policy: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive. Authentication here travels in request headers (Authorization /
# X-Firebase-AppCheck) rather than cookies, so credentials may not be needed;
# confirm the intended client origins before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
# Path to the Firebase service-account JSON; overridable via env var.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")


def _initialize_default_firebase_app() -> None:
    """Best-effort fallback: initialise firebase_admin without explicit credentials.

    Failures are logged instead of raised so the service can still start. This
    replaces the previous bare ``except: pass``, which also swallowed
    ``KeyboardInterrupt``/``SystemExit`` and hid the reason for the failure.
    """
    try:
        firebase_admin.initialize_app()
    except Exception as exc:  # top-level startup boundary: log, don't crash
        logger.warning("Default Firebase initialization failed: %s", exc)


if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
        _initialize_default_firebase_app()
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    _initialize_default_firebase_app()
|
|
|
|
|
|
|
|
# Directories for uploaded inputs and colorized outputs (created earlier at import).
UPLOAD_DIR = Path("/tmp/colorize_uploads")
RESULT_DIR = Path("/tmp/colorize_results")

# Expose both directories as static files.
# NOTE(review): these mounts have no verify_request dependency, so anything
# saved here is publicly downloadable by path; confirm that is intended.
app.mount(path="/results", app=StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount(path="/uploads", app=StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
|
|
|
|
|
|
|
# Module-level model state; both names are assigned by the startup hook.
# The loaded model object (presumably a fastai Learner), or None while unloaded.
learn = None
# Text of the most recent model-loading failure, or None.
model_load_error: Optional[str] = None
|
|
|
|
|
# Filenames tried, in order, when the repository listing is unavailable or
# contains neither .pkl nor .pt files.
_FALLBACK_MODEL_FILENAMES = (
    "model.pkl",
    "export.pkl",
    "learner.pkl",
    "model_export.pkl",
    "generator.pt",
)


def _candidate_model_files(model_id: str, hf_token: Optional[str]) -> list[str]:
    """Return the repository files worth downloading, in the order to try them.

    ``.pkl`` files (FastAI exports) are preferred, then ``.pt`` files. When the
    repository has neither, or cannot be listed, the fixed fallback list of
    common filenames is returned instead.
    """
    try:
        repo_files = list_repo_files(repo_id=model_id, token=hf_token)
    except Exception as list_err:
        logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
        return list(_FALLBACK_MODEL_FILENAMES)

    logger.info("Repository files: %s", repo_files)
    pkl_files = [f for f in repo_files if f.endswith(".pkl")]
    if pkl_files:
        logger.info("Found .pkl files in repository: %s", pkl_files)
        return pkl_files
    pt_files = [f for f in repo_files if f.endswith(".pt")]
    if pt_files:
        logger.info("Found .pt files in repository: %s", pt_files)
        return pt_files
    return list(_FALLBACK_MODEL_FILENAMES)


def _download_first_available(
    model_id: str, filenames: list[str], hf_token: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Download the first of *filenames* that the Hub serves.

    Returns:
        ``(local_path, filename)`` for the first successful download, or
        ``(None, None)`` when every attempt failed.
    """
    cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
    for filename in filenames:
        try:
            local_path = hf_hub_download(
                repo_id=model_id,
                filename=filename,
                cache_dir=cache_dir,
                token=hf_token,
            )
        except Exception as dl_err:
            # Expected for guessed filenames that are absent from the repo.
            logger.debug("Failed to download %s: %s", filename, str(dl_err))
            continue
        logger.info("Found model file: %s", filename)
        return local_path, filename
    return None, None


def _load_learner_from_repo_files(model_id: str, hub_error: Exception):
    """Fallback loader used when ``from_pretrained_fastai`` fails.

    Locates a model file in the repository, downloads it and loads it with
    ``load_learner``.

    Args:
        model_id: Hugging Face repository id.
        hub_error: the exception raised by ``from_pretrained_fastai``; its text
            is included in the "not found" error message.

    Raises:
        RuntimeError: if no candidate file could be downloaded, or the file
            found is a PyTorch ``.pt`` checkpoint rather than a FastAI export.
    """
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    filenames = _candidate_model_files(model_id, hf_token)
    model_path, filename = _download_first_available(model_id, filenames, hf_token)

    if not model_path or not os.path.exists(model_path):
        raise RuntimeError(
            f"Could not find model file in repository '{model_id}'. "
            f"Tried: {', '.join(filenames)}. "
            f"Original error: {str(hub_error)}"
        )
    if filename.endswith(".pt"):
        raise RuntimeError(
            f"Repository '{model_id}' contains a PyTorch model (.pt file), "
            f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
            f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
        )

    logger.info("Loading FastAI model from: %s", model_path)
    # SECURITY NOTE(review): load_learner unpickles the downloaded file, and
    # unpickling can execute arbitrary code. Only point MODEL_ID at trusted repos.
    loaded = load_learner(model_path)
    logger.info("✅ Model loaded successfully from %s", model_path)
    return loaded


@app.on_event("startup")
async def startup_event():
    """Load the FastAI colorization model into the module-level ``learn``.

    Tries ``from_pretrained_fastai`` first, then falls back to locating and
    loading an exported ``.pkl`` from the repository named by ``MODEL_ID``.
    Failures never abort startup. The error text is stored in
    ``model_load_error`` (reported by ``/health``) and ``learn`` is left
    unchanged.
    """
    global learn, model_load_error
    try:
        model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
        logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
        try:
            learn = from_pretrained_fastai(model_id)
            logger.info("✅ Model loaded successfully via from_pretrained_fastai!")
        except Exception as e1:
            logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
            learn = _load_learner_from_repo_files(model_id, e1)
        model_load_error = None
    except Exception as e:
        # Always record the current failure (not a stale one), and keep the traceback.
        model_load_error = str(e)
        logger.exception("❌ Failed to load model: %s", model_load_error)
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Release the loaded model and log shutdown.

    The module-level ``learn`` is rebound to ``None`` instead of being deleted.
    ``del learn`` under ``global`` removes the module attribute itself. Any
    later read of ``learn`` (the health check, ``colorize_pil``, or a second
    shutdown's own check) would then raise ``NameError`` instead of seeing
    "not loaded".
    """
    global learn
    learn = None
    logger.info("Application shutdown")
|
|
|
|
|
def _extract_bearer_token(authorization_header: str | None) -> str | None: |
|
|
if not authorization_header: |
|
|
return None |
|
|
parts = authorization_header.split(" ", 1) |
|
|
if len(parts) == 2 and parts[0].lower() == "bearer": |
|
|
return parts[1].strip() |
|
|
return None |
|
|
|
|
|
async def verify_request(request: Request):
    """
    Verify Firebase authentication for an incoming request (FastAPI dependency).

    Accept either:
    - Firebase Auth id_token via Authorization: Bearer <id_token>
    - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)

    Checks run in this order:
    1. Verification is skipped entirely when no Firebase app is registered, or
       when the DISABLE_AUTH env var equals "true" (case-insensitive).
    2. A Bearer id_token that verifies is accepted; its decoded claims are
       stored on ``request.state.user``.
    3. Otherwise, if ``settings.ENABLE_APP_CHECK`` is truthy, the
       X-Firebase-AppCheck header must be present and verify.

    Returns:
        True whenever the request is allowed through.

    Raises:
        HTTPException: 401 when App Check is enabled and no valid id_token was
            given, and the App Check token is missing or fails verification.

    NOTE(review): when App Check is disabled, this returns True even if no
    Bearer token was sent or the one sent failed verification. Authentication
    is therefore not enforced in that configuration; confirm this is intended.
    NOTE(review): the firebase verification calls below are synchronous (not
    awaited), so they run on the event loop.
    """
    # firebase_admin._apps is a private SDK attribute; presumably empty when no
    # initialize_app() call succeeded at import time.
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True
    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True
        except Exception as e:
            # Not rejected here: falls through to the App Check path below.
            logger.warning("Auth token verification failed: %s", str(e))
    if settings.ENABLE_APP_CHECK:
        app_check_token = request.headers.get("X-Firebase-AppCheck")
        if not app_check_token:
            raise HTTPException(status_code=401, detail="Missing App Check token")
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
            raise HTTPException(status_code=401, detail="Invalid App Check token")
    # NOTE(review): reached with App Check disabled and no verified id_token;
    # the request is admitted unauthenticated.
    return True
|
|
|
|
|
@app.get("/api")
async def api_info():
    """Describe the service: its name, version and the paths of its main endpoints."""
    endpoints = {"health": "/health", "colorize": "/colorize", "gradio": "/"}
    return {"app": "FastAI Image Colorizer API", "version": "1.0.0", **endpoints}
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness together with model-loading status.

    ``status`` is always "healthy". ``model_error`` is only present when a
    model-loading failure has been recorded.
    """
    payload = dict(
        status="healthy",
        model_loaded=learn is not None,
        model_id=os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model"),
    )
    if model_load_error:
        payload.update(model_error=model_load_error)
    return payload
|
|
|
|
|
def _prediction_tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a model-output tensor into a PIL image.

    Accepts channel-first ``(C, H, W)`` tensors, or ``(N, C, H, W)`` batches
    (only the first item is used), with 1, 3 or 4 channels. Floating-point
    values are treated as [0, 1] and integer values as [0, 255]. Both are
    clamped to their range and cast to uint8, because ``Image.fromarray`` needs
    8-bit data to build "L"/"RGB"/"RGBA" images.

    Raises:
        ValueError: for any other rank or channel count.
    """
    pixels = tensor.detach().cpu()
    if pixels.dim() == 4:
        pixels = pixels[0]
    if pixels.dim() != 3 or pixels.shape[0] not in (1, 3, 4):
        raise ValueError(f"Unexpected tensor shape: {pixels.shape}")
    pixels = pixels.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    if pixels.is_floating_point():
        # Covers every float dtype (float16/32/64, bfloat16), not just two.
        pixels = pixels.float().clamp(0, 1) * 255
    else:
        # Integer outputs (e.g. int64) were previously passed to fromarray
        # unconverted, which does not produce a valid 8-bit RGB image.
        pixels = pixels.clamp(0, 255)
    array = pixels.to(torch.uint8).numpy()
    if array.shape[2] == 1:
        return Image.fromarray(array[:, :, 0]).convert("RGB")
    return Image.fromarray(array)


def colorize_pil(image: Image.Image) -> Image.Image:
    """Run the loaded model on *image* and return the colorized result as RGB.

    The input is converted to RGB before prediction. If ``learn.predict``
    returns a list/tuple, its first element is used (an empty sequence falls
    back to the input image). Tensor outputs are converted to PIL.

    Raises:
        RuntimeError: if no model is loaded.
        ValueError: if the prediction is neither a PIL image nor a supported tensor.

    NOTE(review): fastai's ``Learner.predict`` is presumably not safe to call
    concurrently on one Learner, and this function is reachable from both the
    FastAPI route and the Gradio UI. Confirm whether calls need serialising.
    """
    model = learn
    if model is None:
        raise RuntimeError("Model not loaded")
    if image.mode != "RGB":
        image = image.convert("RGB")

    pred = model.predict(image)
    if isinstance(pred, (list, tuple)):
        colorized = pred[0] if pred else image
    else:
        colorized = pred

    if isinstance(colorized, torch.Tensor):
        colorized = _prediction_tensor_to_pil(colorized)
    elif not isinstance(colorized, Image.Image):
        raise ValueError(f"Unexpected prediction type: {type(colorized)}")

    if colorized.mode != "RGB":
        colorized = colorized.convert("RGB")
    return colorized
|
|
|
|
|
# Serialises /colorize inference. The work runs in a worker thread so the
# event loop stays responsive, but requests are still processed one at a time,
# as they were when inference ran directly on the loop.
_colorize_api_lock = threading.Lock()


def _colorize_upload_to_png(img_bytes: bytes, output_path: Path) -> None:
    """Decode *img_bytes*, colorize it and write a PNG to *output_path* (blocking).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    with _colorize_api_lock:
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as exc:
            # PIL signals unreadable/truncated data with OSError (including
            # UnidentifiedImageError). These are client errors, not server faults.
            logger.warning("Rejected undecodable upload: %s", exc)
            raise HTTPException(status_code=400, detail="Could not decode image file") from exc
        colorize_pil(image).save(output_path, "PNG")


@app.post("/colorize")
async def colorize_api(
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request),
):
    """
    Upload a black & white image -> returns colorized image.
    Requires Firebase authentication unless DISABLE_AUTH=true

    The result is also kept under the results directory as ``<uuid>.png``.
    Errors: 503 if the model is not loaded, 400 for a non-image content type or
    undecodable image data, 500 for failures during colorization.
    """
    if learn is None:
        raise HTTPException(status_code=503, detail="Colorization model not loaded")
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()
        logger.info("Colorizing image...")
        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        # Decoding, inference and PNG encoding are synchronous CPU work; run
        # them off the event loop so other requests (e.g. /health) are served.
        await asyncio.to_thread(_colorize_upload_to_png, img_bytes, output_path)
        logger.info("Colorized image saved: %s", output_filename)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error colorizing image: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}") from e

    return FileResponse(
        output_path,
        media_type="image/png",
        filename=f"colorized_{output_filename}",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_colorize(image):
    """Gradio callback: colorize the uploaded PIL image.

    Returns None when nothing was uploaded. When the model is not loaded or
    colorization fails, raises ``gr.Error`` so the reason is shown in the UI.
    Previously it returned None silently, leaving a blank output with no
    explanation.
    """
    if image is None:
        return None
    if learn is None:
        raise gr.Error("Colorization model is not loaded. Check /health for details.")
    try:
        return colorize_pil(image)
    except Exception as e:
        logger.exception("Gradio colorization error: %s", str(e))
        raise gr.Error(f"Colorization failed: {e}") from e
|
|
|
|
|
# Gradio front-end: one image in, one colorized image out, mounted at "/".
title = "🎨 FastAI GAN Image Colorizer"
description = (
    "Upload a black & white photo to generate a colorized version "
    "using the FastAI GAN model."
)

_gradio_input = gr.Image(type="pil", label="Upload B&W Image")
_gradio_output = gr.Image(type="pil", label="Colorized Image")

iface = gr.Interface(
    fn=gradio_colorize,
    inputs=_gradio_input,
    outputs=_gradio_output,
    title=title,
    description=description,
)

app = gr.mount_gradio_app(app, iface, path="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port from $PORT (default 7860).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|
|
|
|
|
|