# Source: Hugging Face Space by LogicGoInfotechSpaces — app/main.py
# Commit ab9de00 ("Update app/main.py", verified), 7.03 kB
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import os, uuid, io, json
# ======================================================
# 🚀 FASTAPI APP
# ======================================================
# The single ASGI application object; every route in this module registers on it.
app = FastAPI(
    title="UNet Image Colorization API",
)
# ======================================================
# 🔐 FIREBASE INITIALIZATION (ENV BASED)
# ======================================================
try:
    import firebase_admin
    from firebase_admin import credentials, app_check

    # The service-account JSON is passed inline through the environment,
    # not as a file path.
    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if not firebase_json:
        print("⚠️ No Firebase credentials found. Firebase disabled.")
    else:
        print("🔥 Loading Firebase credentials from ENV...")
        firebase_dict = json.loads(firebase_json)
        cred = credentials.Certificate(firebase_dict)
        firebase_admin.initialize_app(cred)
except Exception as exc:
    # Best effort: a missing package or malformed credentials must not
    # prevent the API from starting; the failure is only reported.
    print("❌ Firebase initialization failed:", exc)
# ======================================================
# 📁 DIRECTORIES
# ======================================================
UPLOAD_DIR = "uploads"
RESULTS_DIR = "results"
# Create both storage folders up front; exist_ok keeps this idempotent
# across restarts.
for _directory in (UPLOAD_DIR, RESULTS_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory
# ======================================================
# 🧠 SIMPLE UNET GENERATOR FOR COLORIZATION
# ======================================================
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
def CBR(in_c, out_c):
return nn.Sequential(
nn.Conv2d(in_c, out_c, 3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True)
)
self.enc1 = CBR(1, 64)
self.enc2 = CBR(64, 128)
self.enc3 = CBR(128, 256)
self.enc4 = CBR(256, 512)
self.pool = nn.MaxPool2d(2)
self.middle = CBR(512, 512)
self.up4 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.dec4 = CBR(512, 256)
self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.dec3 = CBR(256, 128)
self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec2 = CBR(128, 64)
self.out_layer = nn.Conv2d(64, 2, 1) # ab channels
def forward(self, x):
c1 = self.enc1(x)
p1 = self.pool(c1)
c2 = self.enc2(p1)
p2 = self.pool(c2)
c3 = self.enc3(p2)
p3 = self.pool(c3)
c4 = self.enc4(p3)
p4 = self.pool(c4)
mid = self.middle(p4)
u4 = self.up4(mid)
u4 = torch.cat([u4, c4], dim=1)
d4 = self.dec4(u4)
u3 = self.up3(d4)
u3 = torch.cat([u3, c3], dim=1)
d3 = self.dec3(u3)
u2 = self.up2(d3)
u2 = torch.cat([u2, c2], dim=1)
d2 = self.dec2(u2)
out = self.out_layer(d2)
return out
# ======================================================
# 🎨 LOAD MODEL WEIGHTS FROM HF
# ======================================================
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"
print("⬇️ Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
print("📦 Loading weights into UNet model...")
model = UNet()
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load from a repository you trust, and consider
# weights_only=True once the checkpoint format is confirmed to be a plain
# state dict.
state_dict = torch.load(model_path, map_location="cpu")
# strict=False silently skips every checkpoint key that does not match this
# UNet, leaving those layers randomly initialised. Report the mismatch instead
# of hiding it, so a wrong checkpoint is visible in the logs.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint/UNet mismatch: "
        f"{len(load_result.missing_keys)} missing keys "
        f"(e.g. {load_result.missing_keys[:5]}), "
        f"{len(load_result.unexpected_keys)} unexpected keys "
        f"(e.g. {load_result.unexpected_keys[:5]})"
    )
model.eval()
# ======================================================
# 🎨 COLORIZE FUNCTION (LAB → RGB)
# ======================================================
import numpy as np
import cv2
def colorize_image(img: Image.Image):
    """Colorize ``img`` with the module-level ``model`` and return an RGB PIL image.

    The image is converted to grayscale, scaled to [0, 1] and fed to ``model``.
    The model's two output channels are resized to the original image size,
    scaled by 128 and used as the a/b channels of a float CIE LAB image. That
    image's L channel is the original grayscale scaled to [0, 100], and OpenCV
    converts it to RGB.

    The UNet pools four times by 2, so the network input is resized to sides
    that are multiples of 16 (rounded up, hence at least 16). The final L
    channel still uses the full-resolution grayscale, so no luminance detail
    is lost.
    """
    gray = np.array(img.convert("L"))
    height, width = gray.shape
    L = gray.astype("float32") / 255.0

    # Round each side up to a multiple of 16 so the UNet's skip concatenations
    # line up. Skip the resize when the image already fits.
    net_h = (height + 15) // 16 * 16
    net_w = (width + 15) // 16 * 16
    if (net_h, net_w) == (height, width):
        net_in = L
    else:
        net_in = cv2.resize(L, (net_w, net_h))
    L_tensor = torch.tensor(net_in).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        ab = model(L_tensor).squeeze(0).numpy()
    ab = np.ascontiguousarray(np.transpose(ab, (1, 2, 0)))
    # The model's output size differs from the image size; bring ab back to it.
    ab = cv2.resize(ab, (width, height))

    # Combine L + ab -> LAB image (OpenCV float LAB: L in [0, 100]).
    LAB = np.zeros((height, width, 3), dtype=np.float32)
    LAB[..., 0] = L * 100
    # Assumes the model's ab output is normalised to roughly [-1, 1]. The last
    # layer has no activation, so this is not enforced. TODO confirm against
    # the checkpoint's training normalisation.
    LAB[..., 1:] = ab * 128

    rgb = cv2.cvtColor(LAB, cv2.COLOR_LAB2RGB)
    rgb = np.clip(rgb, 0, 1)
    return Image.fromarray((rgb * 255).astype("uint8"))
# ======================================================
# 🔐 FIREBASE CHECK
# ======================================================
def verify_app_check_token(token: str):
    """Validate a Firebase App Check token taken from the request header.

    Tokens that are missing or shorter than 20 characters are always rejected.
    If ``firebase_admin`` is importable and a default Firebase app has been
    initialised, the token is also verified with
    ``firebase_admin.app_check.verify_token``. Otherwise only the length check
    applies, which was the previous behaviour.

    Returns:
        True when the token is accepted.

    Raises:
        HTTPException: 401 when the token is missing, too short, or rejected
            by Firebase.
    """
    if not token or len(token) < 20:
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")

    try:
        import firebase_admin
        from firebase_admin import app_check

        # Raises ValueError when no default app was initialised (no credentials).
        firebase_admin.get_app()
    except (ImportError, ValueError):
        # Firebase is unavailable: keep the legacy length-only check.
        return True

    try:
        app_check.verify_token(token)
    except ValueError as exc:
        # verify_token signals a malformed, expired or forged token with ValueError.
        raise HTTPException(
            status_code=401, detail="Invalid Firebase App Check token"
        ) from exc
    return True
# ======================================================
# 🩺 HEALTH CHECK
# ======================================================
@app.get("/health")
def health_check():
    """Liveness probe reporting that the service is up.

    The model is loaded at module import time, before this route is
    registered, so ``unet_loaded`` is reported as a constant True.
    """
    payload = {"status": "healthy", "unet_loaded": True}
    return payload
# ======================================================
# 📤 UPLOAD
# ======================================================
@app.post("/upload")
async def upload_image(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
    """Store an uploaded image under ``UPLOAD_DIR`` and return its public URL.

    The request must carry an App Check token accepted by
    ``verify_app_check_token``. The body is checked with Pillow before it is
    written, so non-image payloads are rejected with 400 instead of being
    published under a ``.jpg`` name.

    Returns:
        dict with ``success``, ``image_id`` (the stored file name without its
        ``.jpg`` suffix) and ``url`` pointing at ``GET /uploads/{filename}``.

    Raises:
        HTTPException: 401 for a bad token, 400 if the body is not an image.
    """
    verify_app_check_token(x_firebase_appcheck)
    data = await file.read()
    try:
        # verify() checks integrity without decoding all of the pixel data.
        Image.open(io.BytesIO(data)).verify()
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    image_id = f"{uuid.uuid4()}.jpg"
    path = os.path.join(UPLOAD_DIR, image_id)
    with open(path, "wb") as f:
        f.write(data)
    base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    return {
        "success": True,
        "image_id": image_id[:-4],
        "url": f"{base}/uploads/{image_id}",
    }
# ======================================================
# 🎨 COLORIZE
# ======================================================
@app.post("/colorize")
async def colorize(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
    """Colorize an uploaded image, save it under ``RESULTS_DIR`` and return its URL.

    The request must carry an App Check token accepted by
    ``verify_app_check_token``. Undecodable uploads are rejected with 400
    rather than surfacing as a 500.

    Returns:
        dict with ``success``, ``result_id`` (the stored file name without its
        ``.jpg`` suffix) and ``url`` pointing at ``GET /results/{filename}``.

    Raises:
        HTTPException: 401 for a bad token, 400 if the body is not an image.
    """
    from fastapi.concurrency import run_in_threadpool

    verify_app_check_token(x_firebase_appcheck)
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data))
        # Image.open is lazy; load() forces a full decode so corrupt data fails here.
        img.load()
    except (OSError, SyntaxError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    # colorize_image does synchronous, CPU-bound inference. Running it in the
    # threadpool keeps the event loop free to serve other requests.
    output_img = await run_in_threadpool(colorize_image, img)

    result_id = f"{uuid.uuid4()}.jpg"
    path = os.path.join(RESULTS_DIR, result_id)
    output_img.save(path)
    base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    return {
        "success": True,
        "result_id": result_id[:-4],
        "url": f"{base}/results/{result_id}",
    }
# ======================================================
# PUBLIC FILE ENDPOINTS
# ======================================================
@app.get("/results/{filename}")
def get_result(filename: str):
    """Serve a colorized result previously written to ``RESULTS_DIR``.

    Only regular files located directly inside ``RESULTS_DIR`` are served.
    Directories such as ``.`` or ``..``, names that resolve elsewhere, and
    missing files all yield 404. Previously a directory name passed the
    existence check and made ``FileResponse`` fail with a 500.

    Raises:
        HTTPException: 404 when no such result file exists.
    """
    root = os.path.realpath(RESULTS_DIR)
    path = os.path.realpath(os.path.join(root, filename))
    if os.path.dirname(path) != root or not os.path.isfile(path):
        raise HTTPException(status_code=404)
    return FileResponse(path)
@app.get("/uploads/{filename}")
def get_upload(filename: str):
    """Serve an image previously stored in ``UPLOAD_DIR`` by ``/upload``.

    Only regular files located directly inside ``UPLOAD_DIR`` are served.
    Directories such as ``.`` or ``..``, names that resolve elsewhere, and
    missing files all yield 404. Previously a directory name passed the
    existence check and made ``FileResponse`` fail with a 500.

    Raises:
        HTTPException: 404 when no such uploaded file exists.
    """
    root = os.path.realpath(UPLOAD_DIR)
    path = os.path.realpath(os.path.join(root, filename))
    if os.path.dirname(path) != root or not os.path.isfile(path):
        raise HTTPException(status_code=404)
    return FileResponse(path)