LogicGoInfotechSpaces committed
Commit e91dfad · verified · 1 parent: 311af3e

Update app/main.py

Files changed (1): app/main.py (+29, -101)
app/main.py CHANGED
@@ -1,112 +1,40 @@
- import io
- import os
- import uuid
- import torch
- import torch.nn as nn
  from fastapi import FastAPI, UploadFile, File
- from fastapi.responses import FileResponse
+ from huggingface_hub import hf_hub_download
+ import torch
  from PIL import Image
- import torchvision.transforms as T
- import gradio as gr
- import uvicorn
+ import io
-
- # ==========================================================
- # 🔧 PATHS
- # ==========================================================
- MODEL_PATH = "generator.pt"
- UPLOAD_DIR = "/tmp/uploads"
- RESULT_DIR = "/tmp/results"
- os.makedirs(UPLOAD_DIR, exist_ok=True)
- os.makedirs(RESULT_DIR, exist_ok=True)
-
- # ==========================================================
- # 🧩 Define Generator Architecture (from repo style)
- # ==========================================================
- class Generator(nn.Module):
-     def __init__(self):
-         super(Generator, self).__init__()
-         self.main = nn.Sequential(
-             nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
-             nn.ReLU(True),
-
-             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(128),
-             nn.ReLU(True),
-
-             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
-             nn.BatchNorm2d(64),
-             nn.ReLU(True),
-
-             nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
-             nn.Tanh()
-         )
-
-     def forward(self, x):
-         return self.main(x)
-
- # ==========================================================
- # 🚀 Load Model
- # ==========================================================
- device = "cuda" if torch.cuda.is_available() else "cpu"
- generator = Generator().to(device)
-
- # Load weights
- state_dict = torch.load(MODEL_PATH, map_location=device)
- generator.load_state_dict(state_dict)
- generator.eval()
-
- print("✅ Model loaded successfully!")
-
- # ==========================================================
- # 🎨 Colorization Function
- # ==========================================================
- def colorize_image(image: Image.Image):
-     transform = T.Compose([
-         T.Resize((256, 256)),
-         T.Grayscale(num_output_channels=1),
-         T.ToTensor()
-     ])
-
-     img_tensor = transform(image).unsqueeze(0).to(device)
-     with torch.no_grad():
-         output = generator(img_tensor)
-     output = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1) / 2.0  # Scale 0-1
-
-     output_img = Image.fromarray((output * 255).astype("uint8"))
-     return output_img
 
- # ==========================================================
- # 🌐 FASTAPI APP
- # ==========================================================
  app = FastAPI(title="GAN Image Colorization API")
 
- @app.post("/colorize")
- async def colorize_endpoint(file: UploadFile = File(...)):
-     img_bytes = await file.read()
-     image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
-     colorized = colorize_image(image)
-
-     output_filename = f"{uuid.uuid4()}.png"
-     output_path = os.path.join(RESULT_DIR, output_filename)
-     colorized.save(output_path)
-
-     return FileResponse(output_path, media_type="image/png")
-
- # ==========================================================
- # 💠 GRADIO UI
- # ==========================================================
- def gradio_ui(image):
-     return colorize_image(image)
-
- iface = gr.Interface(
-     fn=gradio_ui,
-     inputs=gr.Image(type="pil", label="Upload B&W Image"),
-     outputs=gr.Image(type="pil", label="Colorized Image"),
-     title="🎨 GAN Image Colorization",
-     description="Upload a black-and-white photo to get it colorized using a GAN model."
- )
-
- gradio_app = gr.mount_gradio_app(app, iface, path="/")
-
- if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=7860)
+ # --------------------------
+ # Model Loading Section
+ # --------------------------
+ MODEL_REPO = "Hammad712/GAN-Colorization-Model"
+ MODEL_FILENAME = "generator.pt"
+
+ try:
+     print("🔄 Downloading model from Hugging Face...")
+     model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
+     state_dict = torch.load(model_path, map_location="cpu")
+     print("✅ Model loaded successfully from:", model_path)
+ except Exception as e:
+     print("❌ Failed to load model:", e)
+     state_dict = None
+
+
+ # --------------------------
+ # Example Endpoint
+ # --------------------------
+ @app.get("/")
+ def read_root():
+     return {"message": "GAN Colorization API is running!"}
+
+
+ @app.post("/colorize")
+ async def colorize_image(file: UploadFile = File(...)):
+     # Load the uploaded image
+     image_bytes = await file.read()
+     img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+     # For now, just return confirmation (replace with inference logic later)
+     return {"status": "success", "filename": file.filename}