LogicGoInfotechSpaces commited on
Commit
963b208
·
verified ·
1 Parent(s): ab9de00

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +98 -156
app/main.py CHANGED
@@ -1,20 +1,22 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
4
- from torchvision import transforms
 
 
 
5
  from PIL import Image
6
  import torch
7
- import torch.nn as nn
8
- import os, uuid, io, json
9
 
10
- # ======================================================
11
- # 🚀 FASTAPI APP
12
- # ======================================================
13
- app = FastAPI(title="UNet Image Colorization API")
14
 
15
- # ======================================================
16
- # 🔐 FIREBASE INITIALIZATION (ENV BASED)
17
- # ======================================================
18
  try:
19
  import firebase_admin
20
  from firebase_admin import credentials, app_check
@@ -32,202 +34,142 @@ try:
32
  except Exception as e:
33
  print("❌ Firebase initialization failed:", e)
34
 
35
- # ======================================================
36
- # 📁 DIRECTORIES
37
- # ======================================================
38
- UPLOAD_DIR = "uploads"
39
- RESULTS_DIR = "results"
40
  os.makedirs(UPLOAD_DIR, exist_ok=True)
41
  os.makedirs(RESULTS_DIR, exist_ok=True)
42
 
43
- # ======================================================
44
- # 🧠 SIMPLE UNET GENERATOR FOR COLORIZATION
45
- # ======================================================
46
- class UNet(nn.Module):
47
- def __init__(self):
48
- super(UNet, self).__init__()
49
-
50
- def CBR(in_c, out_c):
51
- return nn.Sequential(
52
- nn.Conv2d(in_c, out_c, 3, padding=1),
53
- nn.BatchNorm2d(out_c),
54
- nn.ReLU(inplace=True)
55
- )
56
-
57
- self.enc1 = CBR(1, 64)
58
- self.enc2 = CBR(64, 128)
59
- self.enc3 = CBR(128, 256)
60
- self.enc4 = CBR(256, 512)
61
-
62
- self.pool = nn.MaxPool2d(2)
63
-
64
- self.middle = CBR(512, 512)
65
-
66
- self.up4 = nn.ConvTranspose2d(512, 256, 2, stride=2)
67
- self.dec4 = CBR(512, 256)
68
-
69
- self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
70
- self.dec3 = CBR(256, 128)
71
-
72
- self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
73
- self.dec2 = CBR(128, 64)
74
-
75
- self.out_layer = nn.Conv2d(64, 2, 1) # ab channels
76
-
77
- def forward(self, x):
78
- c1 = self.enc1(x)
79
- p1 = self.pool(c1)
80
-
81
- c2 = self.enc2(p1)
82
- p2 = self.pool(c2)
83
-
84
- c3 = self.enc3(p2)
85
- p3 = self.pool(c3)
86
-
87
- c4 = self.enc4(p3)
88
- p4 = self.pool(c4)
89
-
90
- mid = self.middle(p4)
91
-
92
- u4 = self.up4(mid)
93
- u4 = torch.cat([u4, c4], dim=1)
94
- d4 = self.dec4(u4)
95
-
96
- u3 = self.up3(d4)
97
- u3 = torch.cat([u3, c3], dim=1)
98
- d3 = self.dec3(u3)
99
-
100
- u2 = self.up2(d3)
101
- u2 = torch.cat([u2, c2], dim=1)
102
- d2 = self.dec2(u2)
103
-
104
- out = self.out_layer(d2)
105
- return out
106
-
107
-
108
- # ======================================================
109
- # 🎨 LOAD MODEL WEIGHTS FROM HF
110
- # ======================================================
111
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
112
  MODEL_FILENAME = "generator.pt"
113
 
114
  print("⬇️ Downloading model...")
115
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
116
 
117
- print("📦 Loading weights into UNet model...")
118
- model = UNet()
119
  state_dict = torch.load(model_path, map_location="cpu")
120
- model.load_state_dict(state_dict, strict=False)
121
- model.eval()
122
 
123
- # ======================================================
124
- # 🎨 COLORIZE FUNCTION (LAB → RGB)
125
- # ======================================================
126
- import numpy as np
127
- import cv2
128
 
129
  def colorize_image(img: Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- img = img.convert("L") # grayscale
132
- img_np = np.array(img)
133
-
134
- # Normalize L channel
135
- L = img_np.astype("float32") / 255.0
136
- L_tensor = torch.tensor(L).unsqueeze(0).unsqueeze(0)
137
-
138
- with torch.no_grad():
139
- ab = model(L_tensor).squeeze(0).numpy()
140
-
141
- ab = np.transpose(ab, (1, 2, 0))
142
-
143
- # Resize ab to match L
144
- ab = cv2.resize(ab, (img_np.shape[1], img_np.shape[0]))
145
-
146
- # Combine L + ab -> LAB image
147
- LAB = np.zeros((img_np.shape[0], img_np.shape[1], 3), dtype=np.float32)
148
- LAB[..., 0] = L * 100
149
- LAB[..., 1:] = ab * 128
150
-
151
- # Convert LAB → RGB
152
- rgb = cv2.cvtColor(LAB.astype("float32"), cv2.COLOR_LAB2RGB)
153
- rgb = np.clip(rgb, 0, 1)
154
-
155
- rgb_img = Image.fromarray((rgb * 255).astype("uint8"))
156
- return rgb_img
157
-
158
- # ======================================================
159
- # 🔐 FIREBASE CHECK
160
- # ======================================================
161
  def verify_app_check_token(token: str):
162
  if not token or len(token) < 20:
163
  raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
164
  return True
165
 
166
- # ======================================================
167
- # 🩺 HEALTH CHECK
168
- # ======================================================
169
- @app.get("/health")
170
- def health_check():
171
- return {"status": "healthy", "unet_loaded": True}
172
-
173
- # ======================================================
174
- # 📤 UPLOAD
175
- # ======================================================
176
  @app.post("/upload")
177
- async def upload_image(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
178
-
 
 
179
  verify_app_check_token(x_firebase_appcheck)
180
 
 
 
 
181
  image_id = f"{uuid.uuid4()}.jpg"
182
- path = os.path.join(UPLOAD_DIR, image_id)
183
 
184
- with open(path, "wb") as f:
185
  f.write(await file.read())
186
 
187
- base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
188
 
189
  return {
190
  "success": True,
191
- "image_id": image_id[:-4],
192
- "url": f"{base}/uploads/{image_id}"
193
  }
194
 
195
- # ======================================================
196
- # 🎨 COLORIZE
197
- # ======================================================
198
  @app.post("/colorize")
199
- async def colorize(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
200
-
 
 
201
  verify_app_check_token(x_firebase_appcheck)
202
 
 
 
 
203
  img = Image.open(io.BytesIO(await file.read()))
204
  output_img = colorize_image(img)
205
 
206
  result_id = f"{uuid.uuid4()}.jpg"
207
- path = os.path.join(RESULTS_DIR, result_id)
208
- output_img.save(path)
209
 
210
- base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
211
 
212
  return {
213
  "success": True,
214
- "result_id": result_id[:-4],
215
- "url": f"{base}/results/{result_id}"
 
216
  }
217
 
218
- # ======================================================
219
- # PUBLIC FILE ENDPOINTS
220
- # ======================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @app.get("/results/{filename}")
222
  def get_result(filename: str):
223
  path = os.path.join(RESULTS_DIR, filename)
224
  if not os.path.exists(path):
225
- raise HTTPException(status_code=404)
226
- return FileResponse(path)
227
 
 
 
 
228
  @app.get("/uploads/{filename}")
229
  def get_upload(filename: str):
230
  path = os.path.join(UPLOAD_DIR, filename)
231
  if not os.path.exists(path):
232
- raise HTTPException(status_code=404)
233
- return FileResponse(path)
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
4
+ import uuid
5
+ import os
6
+ import io
7
+ import json
8
  from PIL import Image
9
  import torch
10
+ from torchvision import transforms
 
11
 
12
+ # -------------------------------------------------
13
+ # 🚀 FastAPI App
14
+ # -------------------------------------------------
15
+ app = FastAPI(title="Text-Guided Image Colorization API")
16
 
17
+ # -------------------------------------------------
18
+ # 🔐 Firebase Initialization (ENV-based)
19
+ # -------------------------------------------------
20
  try:
21
  import firebase_admin
22
  from firebase_admin import credentials, app_check
 
34
  except Exception as e:
35
  print("❌ Firebase initialization failed:", e)
36
 
37
+ # -------------------------------------------------
38
+ # 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
39
+ # -------------------------------------------------
40
+ UPLOAD_DIR = "/tmp/uploads"
41
+ RESULTS_DIR = "/tmp/results"
42
  os.makedirs(UPLOAD_DIR, exist_ok=True)
43
  os.makedirs(RESULTS_DIR, exist_ok=True)
44
 
45
+ # -------------------------------------------------
46
+ # 🧠 Load GAN Colorization Model
47
+ # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
49
  MODEL_FILENAME = "generator.pt"
50
 
51
  print("⬇️ Downloading model...")
52
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
53
 
54
+ print("📦 Loading model weights...")
 
55
  state_dict = torch.load(model_path, map_location="cpu")
 
 
56
 
57
+ # NOTE: Replace with real model architecture
58
+ # from model import ColorizeNet
59
+ # model = ColorizeNet()
60
+ # model.load_state_dict(state_dict)
61
+ # model.eval()
62
 
63
  def colorize_image(img: Image.Image):
64
+ """ Dummy colorizer (replace with real model.predict) """
65
+ transform = transforms.ToTensor()
66
+ tensor = transform(img.convert("L")).unsqueeze(0)
67
+ tensor = tensor.repeat(1, 3, 1, 1)
68
+ output_img = transforms.ToPILImage()(tensor.squeeze())
69
+ return output_img
70
+
71
+ # -------------------------------------------------
72
+ # 🩺 Health Check
73
+ # -------------------------------------------------
74
+ @app.get("/health")
75
+ def health_check():
76
+ return {"status": "healthy", "model_loaded": True}
77
 
78
+ # -------------------------------------------------
79
+ # 🔐 Firebase Token Validator
80
+ # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def verify_app_check_token(token: str):
82
  if not token or len(token) < 20:
83
  raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
84
  return True
85
 
86
+ # -------------------------------------------------
87
+ # 📤 Upload Image
88
+ # -------------------------------------------------
 
 
 
 
 
 
 
89
  @app.post("/upload")
90
+ async def upload_image(
91
+ file: UploadFile = File(...),
92
+ x_firebase_appcheck: str = Header(None)
93
+ ):
94
  verify_app_check_token(x_firebase_appcheck)
95
 
96
+ if not file.content_type.startswith("image/"):
97
+ raise HTTPException(status_code=400, detail="Invalid file type")
98
+
99
  image_id = f"{uuid.uuid4()}.jpg"
100
+ file_path = os.path.join(UPLOAD_DIR, image_id)
101
 
102
+ with open(file_path, "wb") as f:
103
  f.write(await file.read())
104
 
105
+ base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
106
 
107
  return {
108
  "success": True,
109
+ "image_id": image_id.replace(".jpg", ""),
110
+ "file_url": f"{base_url}/uploads/{image_id}"
111
  }
112
 
113
+ # -------------------------------------------------
114
+ # 🎨 Colorize Image
115
+ # -------------------------------------------------
116
  @app.post("/colorize")
117
+ async def colorize(
118
+ file: UploadFile = File(...),
119
+ x_firebase_appcheck: str = Header(None)
120
+ ):
121
  verify_app_check_token(x_firebase_appcheck)
122
 
123
+ if not file.content_type.startswith("image/"):
124
+ raise HTTPException(status_code=400, detail="Invalid file type")
125
+
126
  img = Image.open(io.BytesIO(await file.read()))
127
  output_img = colorize_image(img)
128
 
129
  result_id = f"{uuid.uuid4()}.jpg"
130
+ output_path = os.path.join(RESULTS_DIR, result_id)
131
+ output_img.save(output_path)
132
 
133
+ base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
134
 
135
  return {
136
  "success": True,
137
+ "result_id": result_id.replace(".jpg", ""),
138
+ "download_url": f"{base_url}/results/{result_id}",
139
+ "api_download": f"{base_url}/download/{result_id.replace('.jpg','')}"
140
  }
141
 
142
+ # -------------------------------------------------
143
+ # ⬇️ Download via API (Secure)
144
+ # -------------------------------------------------
145
+ @app.get("/download/{file_id}")
146
+ def download_result(file_id: str, x_firebase_appcheck: str = Header(None)):
147
+ verify_app_check_token(x_firebase_appcheck)
148
+
149
+ filename = f"{file_id}.jpg"
150
+ path = os.path.join(RESULTS_DIR, filename)
151
+
152
+ if not os.path.exists(path):
153
+ raise HTTPException(status_code=404, detail="Result not found")
154
+
155
+ return FileResponse(path, media_type="image/jpeg")
156
+
157
+ # -------------------------------------------------
158
+ # 🌐 Public Result File
159
+ # -------------------------------------------------
160
  @app.get("/results/{filename}")
161
  def get_result(filename: str):
162
  path = os.path.join(RESULTS_DIR, filename)
163
  if not os.path.exists(path):
164
+ raise HTTPException(status_code=404, detail="Result not found")
165
+ return FileResponse(path, media_type="image/jpeg")
166
 
167
+ # -------------------------------------------------
168
+ # 🌐 Public Uploaded File
169
+ # -------------------------------------------------
170
  @app.get("/uploads/{filename}")
171
  def get_upload(filename: str):
172
  path = os.path.join(UPLOAD_DIR, filename)
173
  if not os.path.exists(path):
174
+ raise HTTPException(status_code=404, detail="File not found")
175
+ return FileResponse(path, media_type="image/jpeg")