|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Header |
|
|
from fastapi.responses import FileResponse |
|
|
from huggingface_hub import hf_hub_download |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import os, uuid, io, json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(
    title="UNet Image Colorization API",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional Firebase Admin setup. The FIREBASE_CREDENTIALS environment variable
# holds a JSON document that is parsed and passed to credentials.Certificate.
# Any failure (package missing, malformed JSON, rejected credentials) is
# printed and startup continues without Firebase.
try:
    import firebase_admin
    from firebase_admin import credentials, app_check

    firebase_json = os.getenv("FIREBASE_CREDENTIALS")

    if not firebase_json:
        print("⚠️ No Firebase credentials found. Firebase disabled.")
    else:
        print("🔥 Loading Firebase credentials from ENV...")
        firebase_dict = json.loads(firebase_json)
        cred = credentials.Certificate(firebase_dict)
        firebase_admin.initialize_app(cred)
except Exception as e:
    print("❌ Firebase initialization failed:", e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local storage: raw uploads and generated results. Both directories are
# created at import time if they do not already exist.
UPLOAD_DIR = "uploads"
RESULTS_DIR = "results"

for _storage_dir in (UPLOAD_DIR, RESULTS_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """UNet-style network mapping a 1-channel image to a 2-channel map.

    Encoder: four conv-BN-ReLU stages (1 -> 64 -> 128 -> 256 -> 512 channels),
    each followed by 2x2 max-pooling, then a 512-channel middle stage.
    Decoder: three 2x transposed-conv upsamplings. After each one, the result is
    concatenated with the encoder feature map of the same resolution (outputs
    of enc4, enc3 and enc2) and refined by a conv-BN-ReLU stage. enc1's output
    is not used as a skip connection, so the output has half the input's
    height and width.

    Input ``(N, 1, H, W)`` -> output ``(N, 2, H // 2, W // 2)``. H and W that
    are multiples of 16 always line up at every level. Many other sizes make
    the skip concatenation fail.
    """

    def __init__(self):
        super(UNet, self).__init__()

        def CBR(in_c, out_c):
            # 3x3 conv (padding=1 keeps H/W) -> batch norm -> ReLU.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            )

        self.enc1 = CBR(1, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.middle = CBR(512, 512)

        # Each decoder conv consumes [upsampled features, encoder skip]
        # concatenated on the channel axis (see forward), so its input width
        # is the sum of both tensors' channels.
        # NOTE(review): confirm the downloaded checkpoint was trained with
        # these decoder input widths.
        self.up4 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec4 = CBR(256 + 512, 256)  # up4 output + enc4 output

        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec3 = CBR(128 + 256, 128)  # up3 output + enc3 output

        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = CBR(64 + 128, 64)  # up2 output + enc2 output

        self.out_layer = nn.Conv2d(64, 2, 1)

    def forward(self, x):
        """Run the encoder/decoder; returns a 2-channel map at half resolution."""
        c1 = self.enc1(x)
        p1 = self.pool(c1)

        c2 = self.enc2(p1)
        p2 = self.pool(c2)

        c3 = self.enc3(p2)
        p3 = self.pool(c3)

        c4 = self.enc4(p3)
        p4 = self.pool(c4)

        mid = self.middle(p4)

        u4 = self.up4(mid)
        u4 = torch.cat([u4, c4], dim=1)
        d4 = self.dec4(u4)

        u3 = self.up3(d4)
        u3 = torch.cat([u3, c3], dim=1)
        d3 = self.dec3(u3)

        u2 = self.up2(d3)
        u2 = torch.cat([u2, c2], dim=1)
        d2 = self.dec2(u2)

        return self.out_layer(d2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Download the generator checkpoint from the Hugging Face Hub and load it into
# the UNet defined above. This runs at import time; the module-level ``model``
# is what colorize_image uses.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

print("⬇️ Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)

print("📦 Loading weights into UNet model...")
model = UNet()
# NOTE(review): torch.load unpickles the downloaded file, and unpickling can
# execute arbitrary code if the remote repo is compromised. Consider
# weights_only=True once the checkpoint is confirmed to be a plain state_dict.
state_dict = torch.load(model_path, map_location="cpu")
# strict=False tolerates key differences between the checkpoint and this UNet.
# Report them so a mismatch (i.e. partially random weights) is visible in the
# logs instead of silently degrading every result.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint/UNet key mismatch: "
        f"{len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected "
        f"(e.g. missing={load_result.missing_keys[:3]}, "
        f"unexpected={load_result.unexpected_keys[:3]})"
    )
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
def colorize_image(img: Image.Image):
    """Colorize ``img`` with the module-level ``model`` and return an RGB image.

    The image is converted to grayscale, and its values are scaled to [0, 1]
    to form the model input. The model's two output channels are used as the
    Lab a/b channels (multiplied by 128). The grayscale values are used as the
    Lab L channel (multiplied by 100). The Lab image is then converted to RGB.

    Args:
        img: Input image in any mode PIL can convert to ``"L"``.

    Returns:
        A new RGB ``PIL.Image.Image`` with the same width and height as ``img``.
    """
    gray = np.array(img.convert("L"))
    height, width = gray.shape
    L = gray.astype("float32") / 255.0

    # The UNet max-pools four times and concatenates upsampled features with
    # encoder features, so only sizes that line up at every level work.
    # Multiples of 16 always do. Resize the model input up to the next
    # multiple of 16 so arbitrary upload sizes don't make torch.cat fail. The
    # chroma output is resized back to the original size below.
    factor = 16
    model_h = -(-height // factor) * factor
    model_w = -(-width // factor) * factor
    model_in = cv2.resize(L, (model_w, model_h))  # cv2 takes (width, height)

    with torch.no_grad():
        batch = torch.from_numpy(model_in).unsqueeze(0).unsqueeze(0)
        ab = model(batch).squeeze(0).numpy()

    # (2, h, w) -> (h, w, 2), then back to the original image size.
    ab = cv2.resize(np.transpose(ab, (1, 2, 0)), (width, height))

    lab = np.empty((height, width, 3), dtype=np.float32)
    lab[..., 0] = L * 100
    lab[..., 1:] = ab * 128

    # cvtColor on float32 Lab yields float RGB nominally in [0, 1]; clip
    # out-of-gamut values before the 8-bit conversion.
    rgb = np.clip(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB), 0, 1)
    return Image.fromarray((rgb * 255).astype("uint8"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_app_check_token(token: str):
    """Validate the Firebase App Check token sent with a request.

    The token must be a non-empty string of at least 20 characters. If the
    ``firebase_admin`` package is importable and a default Firebase app has
    been initialized, the token is also cryptographically verified with
    ``firebase_admin.app_check.verify_token``. Otherwise only the length
    check applies.

    Args:
        token: Raw header value; ``None`` when the header is absent.

    Returns:
        True when the token is accepted.

    Raises:
        HTTPException: 401 if the token is missing, too short, or rejected by
            Firebase.
    """
    if not token or len(token) < 20:
        raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")

    try:
        import firebase_admin
        from firebase_admin import app_check
    except ImportError:
        # NOTE(review): without the Firebase SDK only the format check applies.
        return True

    try:
        firebase_admin.get_app()
    except ValueError:
        # No default app (e.g. FIREBASE_CREDENTIALS unset or init failed).
        return True

    try:
        app_check.verify_token(token)
    except Exception as exc:  # verify_token raises ValueError and JWT errors
        raise HTTPException(
            status_code=401, detail="Invalid Firebase App Check token"
        ) from exc
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe.

    Always reports healthy. ``unet_loaded`` is a constant True because the
    model is loaded at module import, before any request can be served.
    """
    payload = dict(status="healthy", unet_loaded=True)
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_image(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
    """Store an uploaded image under a random name and return its public URL.

    Args:
        file: Multipart file upload; must be a file PIL can identify and verify
            as an image.
        x_firebase_appcheck: Value of the ``x-firebase-appcheck`` request header.

    Returns:
        ``{"success": True, "image_id": <uuid>, "url": <public URL>}``.

    Raises:
        HTTPException: 401 from ``verify_app_check_token`` for a bad token;
            400 if the payload is not a readable image.
    """
    verify_app_check_token(x_firebase_appcheck)

    data = await file.read()

    # These bytes are stored and later served publicly under a .jpg name, so
    # refuse anything PIL cannot identify and verify as an image.
    try:
        with Image.open(io.BytesIO(data)) as probe:
            probe.verify()
    except Exception as exc:  # verify() may raise assorted decoder errors
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    image_id = f"{uuid.uuid4()}.jpg"
    path = os.path.join(UPLOAD_DIR, image_id)

    with open(path, "wb") as f:
        f.write(data)

    base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

    return {
        "success": True,
        "image_id": image_id[:-4],
        "url": f"{base}/uploads/{image_id}",
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/colorize")
async def colorize(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
    """Colorize an uploaded image, save the result and return its public URL.

    Args:
        file: Multipart file upload containing the image to colorize.
        x_firebase_appcheck: Value of the ``x-firebase-appcheck`` request header.

    Returns:
        ``{"success": True, "result_id": <uuid>, "url": <public URL>}``.

    Raises:
        HTTPException: 401 from ``verify_app_check_token`` for a bad token;
            400 if the upload cannot be decoded as an image.
    """
    from fastapi.concurrency import run_in_threadpool

    verify_app_check_token(x_firebase_appcheck)

    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data))
        # Image.open is lazy. load() forces a full decode so corrupt or
        # truncated uploads are rejected here with a 400 instead of failing
        # later with a 500.
        img.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Model inference is CPU-bound. Running it in the threadpool keeps this
    # async handler from blocking the event loop for other requests.
    output_img = await run_in_threadpool(colorize_image, img)

    result_id = f"{uuid.uuid4()}.jpg"
    path = os.path.join(RESULTS_DIR, result_id)
    output_img.save(path)

    base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

    return {
        "success": True,
        "result_id": result_id[:-4],
        "url": f"{base}/results/{result_id}",
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/results/{filename}")
def get_result(filename: str):
    """Serve a file from ``RESULTS_DIR`` (where ``/colorize`` writes results).

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly inside ``RESULTS_DIR``.
    """
    base_dir = os.path.realpath(RESULTS_DIR)
    path = os.path.realpath(os.path.join(base_dir, filename))
    # Containment + regular-file check. It rejects names that resolve outside
    # RESULTS_DIR, and names like "." or ".." that resolve to a directory
    # (os.path.exists accepts those, but FileResponse cannot serve them).
    if os.path.dirname(path) != base_dir or not os.path.isfile(path):
        raise HTTPException(status_code=404)
    return FileResponse(path)
|
|
|
|
|
@app.get("/uploads/{filename}")
def get_upload(filename: str):
    """Serve a file from ``UPLOAD_DIR`` (where ``/upload`` stores uploads).

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly inside ``UPLOAD_DIR``.
    """
    base_dir = os.path.realpath(UPLOAD_DIR)
    path = os.path.realpath(os.path.join(base_dir, filename))
    # Containment + regular-file check. It rejects names that resolve outside
    # UPLOAD_DIR, and names like "." or ".." that resolve to a directory
    # (os.path.exists accepts those, but FileResponse cannot serve them).
    if os.path.dirname(path) != base_dir or not os.path.isfile(path):
        raise HTTPException(status_code=404)
    return FileResponse(path)
|
|
|