"""Text-Guided Image Colorization API.

FastAPI service exposing health, upload, colorize and download endpoints.
The upload/colorize/download endpoints require a Firebase App Check token in
the ``X-Firebase-AppCheck`` request header.
"""

import io
import logging
import os
import uuid
from typing import Optional

import torch
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from fastapi.responses import FileResponse
from firebase_admin import app_check, credentials, initialize_app
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

logger = logging.getLogger(__name__)

app = FastAPI(title="Text-Guided Image Colorization API")

# -------------------------------------------------
# 🔐 Firebase App Check Initialization
# -------------------------------------------------
cred = credentials.Certificate("firebase-key.json")  # Your service account key
initialize_app(cred)

UPLOAD_DIR = "uploads"
RESULTS_DIR = "results"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Public origin used to build the URLs returned to clients. Defaults to the
# deployed Space; override with the PUBLIC_BASE_URL environment variable.
PUBLIC_BASE_URL = os.environ.get(
    "PUBLIC_BASE_URL",
    "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space",
).rstrip("/")

# -------------------------------------------------
# 🧠 Load ColorizeNet Model
# -------------------------------------------------
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
# NOTE(review): torch.load unpickles the checkpoint, which can execute code.
# Only load from a trusted repo; consider weights_only=True if the installed
# torch supports it and the checkpoint is a plain state dict.
state_dict = torch.load(model_path, map_location="cpu")

# TODO: replace the placeholder below with the real ColorizeNet, e.g.
#   model = ColorizeNet(); model.load_state_dict(state_dict); model.eval()


def colorize_image(img: Image.Image) -> Image.Image:
    """Return an RGB copy of *img* (placeholder for the real model).

    The image is converted to grayscale and the single channel is replicated
    into three channels, so the result is an RGB image that still looks gray.
    ``state_dict`` is loaded above but is not used here yet.
    """
    to_tensor = transforms.ToTensor()
    gray = to_tensor(img.convert("L")).unsqueeze(0)  # (1, 1, H, W)
    rgb = gray.repeat(1, 3, 1, 1)  # (1, 3, H, W)
    # squeeze(0) removes only the batch dimension; a bare squeeze() would also
    # drop H or W when either equals 1 and hand ToPILImage the wrong rank.
    return transforms.ToPILImage()(rgb.squeeze(0))


# -------------------------------------------------
# 🩺 1. Health Check
# -------------------------------------------------
@app.get("/health")
def health_check():
    """Liveness probe."""
    # NOTE(review): "model_loaded" only reflects that the checkpoint file was
    # loaded at import time; colorize_image does not use it yet.
    return {"status": "healthy", "model_loaded": True}


# -------------------------------------------------
# ✅ Firebase App Check Token Validation
# -------------------------------------------------
def verify_app_check_token(token: Optional[str]) -> bool:
    """Verify a Firebase App Check token with the Admin SDK.

    Returns:
        True when the token is valid.

    Raises:
        HTTPException: 401 if the token is missing, too short, or rejected by
            ``firebase_admin.app_check.verify_token``.
    """
    detail = "Missing or invalid Firebase App Check token"
    # Cheap pre-check so obviously bogus values never reach the SDK.
    if not token or len(token) < 20:
        raise HTTPException(status_code=401, detail=detail)
    try:
        app_check.verify_token(token)
    except Exception as err:  # auth boundary: any verification failure = deny
        # verify_token raises ValueError for bad claims/signature/expiry, and
        # PyJWT errors for malformed tokens or unknown signing keys.
        logger.info("App Check verification failed: %s", err)
        raise HTTPException(status_code=401, detail=detail) from err
    return True


def _require_image_upload(file: UploadFile) -> None:
    """Reject uploads whose declared content type is not ``image/*`` (400)."""
    # content_type is client-supplied and may be None; treat that as invalid
    # rather than crashing with AttributeError (HTTP 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type")


def _decode_image(data: bytes) -> Image.Image:
    """Decode *data* with Pillow, mapping undecodable input to HTTP 400."""
    try:
        img = Image.open(io.BytesIO(data))
        img.load()  # force a full decode so corrupt/truncated data fails here
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail="Invalid image file") from err
    return img


# -------------------------------------------------
# 📤 2. Upload Image
# -------------------------------------------------
@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    x_firebase_appcheck: Optional[str] = Header(None),
):
    """Store an uploaded image and return its public URL.

    The upload is decoded and re-encoded as JPEG. This guarantees that the
    stored ``.jpg`` file really is a JPEG, matching the ``image/jpeg`` type it
    is served with, and that arbitrary non-image bytes are never stored.
    """
    verify_app_check_token(x_firebase_appcheck)
    _require_image_upload(file)
    img = _decode_image(await file.read())

    image_id = str(uuid.uuid4())
    filename = f"{image_id}.jpg"
    path = os.path.join(UPLOAD_DIR, filename)
    # JPEG cannot hold alpha or palette modes, so normalize to RGB first.
    img.convert("RGB").save(path, format="JPEG", quality=95)

    return {
        "success": True,
        "image_id": image_id,
        "image_url": f"{PUBLIC_BASE_URL}/uploads/{filename}",
        "filename": filename,
    }


# -------------------------------------------------
# 🎨 3. Colorize Image
# -------------------------------------------------
@app.post("/colorize")
async def colorize(
    file: UploadFile = File(...),
    x_firebase_appcheck: Optional[str] = Header(None),
):
    """Colorize an uploaded image and return URLs for the stored result."""
    verify_app_check_token(x_firebase_appcheck)
    _require_image_upload(file)
    img = _decode_image(await file.read())
    # NOTE(review): this runs CPU-bound work on the event loop; once a real
    # model is used, consider fastapi.concurrency.run_in_threadpool.
    output_img = colorize_image(img)

    result_id = str(uuid.uuid4())
    filename = f"{result_id}.jpg"
    path = os.path.join(RESULTS_DIR, filename)
    output_img.save(path)  # format inferred from the .jpg extension

    return {
        "success": True,
        "result_id": result_id,
        "download_url": f"{PUBLIC_BASE_URL}/results/{filename}",
        "api_download_url": f"{PUBLIC_BASE_URL}/download/{result_id}",
        "filename": filename,
    }
def _resolve_in_dir(directory: str, filename: str) -> str:
    """Return the path of *filename* inside *directory*, or raise HTTP 404.

    Only plain file names are accepted. The following are all reported as
    "File not found":

    * names containing a path separator, such as ``.`` or ``..``;
    * names that resolve outside *directory*;
    * paths that exist but are not regular files (e.g. directories).

    A directory would otherwise make FileResponse fail with a 500.
    """
    name = os.path.basename(filename)
    if name != filename or name in ("", ".", ".."):
        raise HTTPException(status_code=404, detail="File not found")
    base = os.path.realpath(directory)
    path = os.path.realpath(os.path.join(base, name))
    if os.path.dirname(path) != base or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="File not found")
    return path


# -------------------------------------------------
# ⬇️ 4. Download Processed Image
# -------------------------------------------------
@app.get("/download/{file_id}")
def download_result(file_id: str, x_firebase_appcheck: str = Header(None)):
    """Return the colorized result ``<file_id>.jpg`` (requires App Check)."""
    verify_app_check_token(x_firebase_appcheck)
    path = _resolve_in_dir(RESULTS_DIR, f"{file_id}.jpg")
    return FileResponse(path, media_type="image/jpeg")


# -------------------------------------------------
# 🌈 5. Get Result (Public URL)
# -------------------------------------------------
@app.get("/results/{filename}")
def get_result(filename: str):
    """Serve a stored colorization result by file name (public)."""
    path = _resolve_in_dir(RESULTS_DIR, filename)
    return FileResponse(path, media_type="image/jpeg")


# -------------------------------------------------
# 🖼️ 6. Get Uploaded Image (Public URL)
# -------------------------------------------------
@app.get("/uploads/{filename}")
def get_upload(filename: str):
    """Serve a stored uploaded image by file name (public)."""
    path = _resolve_in_dir(UPLOAD_DIR, filename)
    return FileResponse(path, media_type="image/jpeg")