|
|
import io
import os
import uuid

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as T
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, Response
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path to the generator checkpoint; it is read with torch.load and fed to
# load_state_dict when the module is imported.
MODEL_PATH = "generator.pt"

# Working directories under /tmp, created eagerly at import time.
# NOTE(review): UPLOAD_DIR is not referenced anywhere else in this file.
UPLOAD_DIR = "/tmp/uploads"
RESULT_DIR = "/tmp/results"

for _directory in (UPLOAD_DIR, RESULT_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Generator(nn.Module):
    """Small encoder-decoder mapping a 1-channel image to a 3-channel image.

    Two stride-2 convolutions halve the spatial size twice. Two stride-2
    transposed convolutions then double it twice, so an input whose height
    and width are divisible by 4 yields an output of the same spatial size.
    The trailing ``Tanh`` bounds every output value to [-1, 1].

    All layers live in ``self.main`` in a fixed order. Checkpoint keys such
    as ``main.0.weight`` and ``main.3.running_mean`` are derived from those
    positions, so renaming or reordering them breaks ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Encoder: (N, 1, H, W) -> (N, 64, H/2, W/2) -> (N, 128, H/4, W/4)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # Decoder: (N, 128, H/4, W/4) -> (N, 64, H/2, W/2) -> (N, 3, H, W)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator on a batch of single-channel images."""
        return self.main(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

generator = Generator().to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file, so MODEL_PATH must point to a
# trusted checkpoint. Consider weights_only=True if the installed torch
# version supports it.
state_dict = torch.load(MODEL_PATH, map_location=device)
generator.load_state_dict(state_dict)

# Inference mode: BatchNorm layers use their running statistics instead of
# per-batch statistics.
generator.eval()

print("✅ Model loaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def colorize_image(image: Image.Image) -> Image.Image:
    """Colorize an image with the loaded GAN generator.

    The input is resized to 256x256 (aspect ratio is not preserved) and
    reduced to a single grayscale channel before it reaches the generator.
    The returned RGB image is therefore always 256x256, whatever the size
    of the input.

    Args:
        image: Source PIL image.

    Returns:
        A 256x256 RGB ``PIL.Image.Image``.
    """
    transform = T.Compose([
        T.Resize((256, 256)),
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ])

    # ToTensor yields a (1, 256, 256) float tensor in [0, 1]; unsqueeze adds
    # the batch axis the generator expects.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = generator(img_tensor)

    # The generator ends in Tanh, so values lie in [-1, 1]. Map them to
    # [0, 255] and round to nearest. A bare uint8 cast truncates, which would
    # darken every channel by ~0.5 on average. Clamp defensively before the
    # cast.
    pixels = (
        output.squeeze(0)      # (3, 256, 256)
        .add(1)
        .mul(127.5)            # == (x + 1) / 2 * 255
        .round()
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(1, 2, 0)      # CHW -> HWC layout expected by PIL
        .contiguous()
        .cpu()
        .numpy()
    )
    return Image.fromarray(pixels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="GAN Image Colorization API")


def _decode_upload(img_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot read the bytes as an image.
    """
    try:
        # convert() forces Pillow's lazy decode, so truncated files fail here
        # as well as unrecognized ones.
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err


def _colorize_to_png(img_bytes: bytes) -> bytes:
    """Decode, colorize, and PNG-encode an uploaded image (blocking work)."""
    colorized = colorize_image(_decode_upload(img_bytes))
    buffer = io.BytesIO()
    colorized.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/colorize")
async def colorize_endpoint(file: UploadFile = File(...)):
    """Colorize an uploaded image and return it as a PNG.

    Responds with 400 when the upload cannot be decoded as an image.
    """
    img_bytes = await file.read()
    # Decoding, inference and encoding are synchronous and CPU/GPU bound.
    # Run them in a worker thread so the event loop keeps serving other
    # requests, including the mounted UI.
    png_bytes = await run_in_threadpool(_colorize_to_png, img_bytes)
    # Serve from memory. Writing a uuid-named file per request into
    # RESULT_DIR and never deleting it grew the disk without bound.
    return Response(content=png_bytes, media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_ui(image):
    """Gradio callback: colorize the uploaded image.

    Returns ``None``, which leaves the output panel empty, when the form is
    submitted without an image. Otherwise ``colorize_image`` would be called
    with ``None`` and fail.
    """
    if image is None:
        return None
    return colorize_image(image)
|
|
|
|
|
# Browser UI: a single image in, the colorized image out.
iface = gr.Interface(
    fn=gradio_ui,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    # Restored from a mis-decoded "π¨" (UTF-8 bytes of the emoji read with a
    # Greek codepage).
    title="🎨 GAN Image Colorization",
    description="Upload a black-and-white photo to get it colorized using a GAN model.",
)

# Serve the Gradio UI from the same FastAPI app at the site root. The
# /colorize route is registered on `app` before this mount.
gradio_app = gr.mount_gradio_app(app, iface, path="/")
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 7860, as before, but can
    # be overridden via the PORT environment variable for hosts that assign
    # one.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
|
|