# Scraped page header (Hugging Face Space by LogicGoInfotechSpaces):
# app/main.py — commit b475327 "Update app/main.py" (verified), 3.62 kB
import io
import os
import uuid

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as T
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, Response
from PIL import Image
# ==========================================================
# 🔧 PATHS
# ==========================================================
# Each location can be overridden by an environment variable of the same
# name; the defaults are the original hard-coded values. A relative
# MODEL_PATH is resolved against the current working directory.
MODEL_PATH = os.environ.get("MODEL_PATH", "generator.pt")
UPLOAD_DIR = os.environ.get("UPLOAD_DIR", "/tmp/uploads")
RESULT_DIR = os.environ.get("RESULT_DIR", "/tmp/results")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# ==========================================================
# 🧩 Define Generator Architecture (from repo style)
# ==========================================================
class Generator(nn.Module):
    """Small encoder/decoder mapping a 1-channel image to 3 channels.

    Two stride-2 convolutions halve the spatial size twice and two stride-2
    transposed convolutions double it back, so an input of shape
    (N, 1, H, W) with H and W divisible by 4 yields (N, 3, H, W). The final
    ``Tanh`` bounds every output value to [-1, 1].

    All layers live in ``self.main`` (an ``nn.Sequential``) in a fixed order.
    Saved state-dict keys (``main.0.weight``, ``main.3.running_mean``, ...)
    depend on both the attribute name and the order, so neither can change
    without breaking checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        # Downsampling half: 1 -> 64 -> 128 channels.
        encoder = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
        ]
        # Upsampling half: 128 -> 64 -> 3 channels, squashed by Tanh.
        decoder = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*encoder, *decoder)

    def forward(self, x):
        """Pass ``x`` through the encoder/decoder stack."""
        return self.main(x)
# ==========================================================
# 🚀 Load Model
# ==========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = Generator().to(device)
# Load weights. load_state_dict() needs a plain mapping of parameter/buffer
# names to tensors, so weights_only=True is sufficient. It also stops the
# checkpoint file from unpickling arbitrary objects. PyTorch >= 2.6 uses this
# default already; passing it explicitly keeps older versions (>= 1.13) safe.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
generator.load_state_dict(state_dict)
# Inference only: eval() makes BatchNorm use its stored running statistics.
generator.eval()
print("✅ Model loaded successfully!")
# ==========================================================
# 🎨 Colorization Function
# ==========================================================
# Preprocessing is identical for every call, so it is built once:
# 1. resize to the fixed 256x256 resolution the generator is run at,
# 2. collapse to one grayscale channel,
# 3. convert to a float tensor scaled to [0, 1].
_COLORIZE_TRANSFORM = T.Compose([
    T.Resize((256, 256)),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
])


def colorize_image(image: Image.Image) -> Image.Image:
    """Colorize ``image`` with the module-level generator.

    The input is resized to 256x256 and converted to grayscale before
    inference. The result is therefore always a 256x256 RGB image,
    regardless of the input's size or aspect ratio.

    Args:
        image: Source PIL image in any mode convertible to grayscale.

    Returns:
        The colorized 256x256 RGB PIL image.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("colorize_image() requires an image, got None")
    batch = _COLORIZE_TRANSFORM(image).unsqueeze(0).to(device)  # (1, 1, 256, 256)
    with torch.no_grad():
        output = generator(batch)
    # The generator ends in Tanh, so values lie in [-1, 1]; map them to
    # [0, 255]. Rounding avoids the systematic downward bias of plain
    # truncation, and the clamp guarantees nothing can wrap around in uint8.
    pixels = (
        output.squeeze(0)
        .clamp(-1.0, 1.0)
        .add(1.0)
        .mul(127.5)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)  # CHW -> HWC, the layout PIL expects
        .contiguous()
        .cpu()
        .numpy()
    )
    return Image.fromarray(pixels)
# ==========================================================
# 🌐 FASTAPI APP
# ==========================================================
app = FastAPI(title="GAN Image Colorization API")


@app.post("/colorize")
async def colorize_endpoint(file: UploadFile = File(...)):
    """Colorize an uploaded image and return the result as a PNG.

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    The PNG is encoded in memory rather than saved under RESULT_DIR, so
    nothing accumulates on disk across requests.
    """
    img_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError and truncated-file errors are OSErrors.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    # Inference is CPU/GPU-bound; run it in a worker thread so it does not
    # block the event loop (and every other request) while it executes.
    colorized = await run_in_threadpool(colorize_image, image)
    buffer = io.BytesIO()
    colorized.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")
# ==========================================================
# 💠 GRADIO UI
# ==========================================================
def gradio_ui(image):
    """Gradio callback: colorize the uploaded PIL image.

    ``gr.Image`` delivers None when the user submits without uploading
    anything. Show a readable message in the UI instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    return colorize_image(image)


iface = gr.Interface(
    fn=gradio_ui,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="🎨 GAN Image Colorization",
    description="Upload a black-and-white photo to get it colorized using a GAN model."
)
# Mount the UI at the root path. The /colorize route was registered on `app`
# before this mount, and routes are tried in registration order, so the API
# route still takes precedence.
gradio_app = gr.mount_gradio_app(app, iface, path="/")
if __name__ == "__main__":
    # Serve on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))