File size: 1,213 Bytes
b475327
e91dfad
 
60c56d7
e91dfad
60c56d7
b475327
60c56d7
e91dfad
 
 
 
 
60c56d7
e91dfad
 
 
 
 
 
 
 
b475327
 
e91dfad
 
 
 
 
 
b475327
 
e91dfad
 
 
 
 
e4599d1
e91dfad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image

app = FastAPI(title="GAN Image Colorization API")

# --------------------------
# Model Loading Section
# --------------------------
# Hugging Face Hub coordinates of the generator checkpoint.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

# Best-effort load performed once at import time: any failure is reported on
# stdout and leaves ``state_dict`` as None instead of aborting the import.
# NOTE(review): torch.load unpickles the downloaded file; if the checkpoint is
# a plain state dict, consider passing weights_only=True -- confirm first.
try:
    print("🔄 Downloading model from Hugging Face...")
    model_path = hf_hub_download(
        filename=MODEL_FILENAME,
        repo_id=MODEL_REPO,
    )
    # Map tensors onto the CPU regardless of the device they were saved from.
    state_dict = torch.load(model_path, map_location="cpu")
    print("✅ Model loaded successfully from:", model_path)
except Exception as load_error:
    print("❌ Failed to load model:", load_error)
    state_dict = None


# --------------------------
# Example Endpoint
# --------------------------
@app.get("/")
def read_root():
    """Return a fixed message confirming that the API is running."""
    payload = {"message": "GAN Colorization API is running!"}
    return payload


@app.post("/colorize")
async def colorize_image(file: UploadFile = File(...)):
    """Accept an uploaded image and confirm that it can be decoded.

    The upload is read fully into memory and decoded to RGB with Pillow.
    Inference is not implemented yet, so on success the endpoint only
    acknowledges the upload.

    Args:
        file: Multipart upload containing the image to colorize.

    Returns:
        A dict with ``"status": "success"`` and the uploaded ``filename``.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # The context manager releases the decoder's resources. convert()
        # forces a full decode, so truncated or corrupt data fails here
        # rather than later during inference.
        with Image.open(io.BytesIO(image_bytes)) as source:
            img = source.convert("RGB")
    except OSError as exc:
        # Pillow reports unidentifiable or truncated images as OSError
        # (UnidentifiedImageError is a subclass); that is a client error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # For now, just return confirmation (replace with inference logic on
    # `img` later).
    return {"status": "success", "filename": file.filename}