LogicGoInfotechSpaces commited on
Commit
d8c6239
Β·
verified Β·
1 Parent(s): b805ad0

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +64 -36
app/main.py CHANGED
@@ -1,43 +1,69 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
4
- from firebase_admin import credentials, initialize_app, app_check
5
  import uuid
6
  import os
 
 
7
  from PIL import Image
8
  import torch
9
- import io
10
  from torchvision import transforms
11
 
 
 
 
12
  app = FastAPI(title="Text-Guided Image Colorization API")
13
 
14
  # -------------------------------------------------
15
- # πŸ” Firebase App Check Initialization
16
  # -------------------------------------------------
17
- cred = credentials.Certificate("firebase-key.json") # Your service account key
18
- initialize_app(cred)
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
 
 
 
20
  UPLOAD_DIR = "uploads"
21
  RESULTS_DIR = "results"
22
  os.makedirs(UPLOAD_DIR, exist_ok=True)
23
  os.makedirs(RESULTS_DIR, exist_ok=True)
24
 
25
  # -------------------------------------------------
26
- # 🧠 Load ColorizeNet Model
27
  # -------------------------------------------------
28
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
29
  MODEL_FILENAME = "generator.pt"
 
 
30
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
 
 
31
  state_dict = torch.load(model_path, map_location="cpu")
32
 
33
- # (Example model structure – replace with your actual ColorizeNet)
34
- # from your_model import ColorizeNet
 
 
35
  # model = ColorizeNet()
36
  # model.load_state_dict(state_dict)
37
  # model.eval()
38
 
39
- # Dummy colorization function
40
  def colorize_image(img: Image.Image):
 
41
  transform = transforms.ToTensor()
42
  tensor = transform(img.convert("L")).unsqueeze(0)
43
  tensor = tensor.repeat(1, 3, 1, 1)
@@ -45,23 +71,22 @@ def colorize_image(img: Image.Image):
45
  return output_img
46
 
47
  # -------------------------------------------------
48
- # 🩺 1. Health Check
49
  # -------------------------------------------------
50
  @app.get("/health")
51
  def health_check():
52
  return {"status": "healthy", "model_loaded": True}
53
 
54
  # -------------------------------------------------
55
- # βœ… Firebase App Check Token Validation
56
  # -------------------------------------------------
57
  def verify_app_check_token(token: str):
58
- # In production, verify token with Firebase REST API or Admin SDK.
59
  if not token or len(token) < 20:
60
- raise HTTPException(status_code=401, detail="Missing or invalid Firebase App Check token")
61
  return True
62
 
63
  # -------------------------------------------------
64
- # πŸ“€ 2. Upload Image
65
  # -------------------------------------------------
66
  @app.post("/upload")
67
  async def upload_image(
@@ -73,22 +98,22 @@ async def upload_image(
73
  if not file.content_type.startswith("image/"):
74
  raise HTTPException(status_code=400, detail="Invalid file type")
75
 
76
- image_id = str(uuid.uuid4())
77
- filename = f"{image_id}.jpg"
78
- path = os.path.join(UPLOAD_DIR, filename)
79
 
80
- with open(path, "wb") as f:
81
  f.write(await file.read())
82
 
 
 
83
  return {
84
  "success": True,
85
- "image_id": image_id,
86
- "image_url": f"https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/uploads/{filename}",
87
- "filename": filename
88
  }
89
 
90
  # -------------------------------------------------
91
- # 🎨 3. Colorize Image
92
  # -------------------------------------------------
93
  @app.post("/colorize")
94
  async def colorize(
@@ -103,43 +128,46 @@ async def colorize(
103
  img = Image.open(io.BytesIO(await file.read()))
104
  output_img = colorize_image(img)
105
 
106
- result_id = str(uuid.uuid4())
107
- filename = f"{result_id}.jpg"
108
- path = os.path.join(RESULTS_DIR, filename)
109
- output_img.save(path)
110
 
111
  base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
 
112
  return {
113
  "success": True,
114
- "result_id": result_id,
115
- "download_url": f"{base_url}/results/{filename}",
116
- "api_download_url": f"{base_url}/download/{result_id}",
117
- "filename": filename
118
  }
119
 
120
  # -------------------------------------------------
121
- # ⬇️ 4. Download Processed Image
122
  # -------------------------------------------------
123
  @app.get("/download/{file_id}")
124
  def download_result(file_id: str, x_firebase_appcheck: str = Header(None)):
125
  verify_app_check_token(x_firebase_appcheck)
126
- path = os.path.join(RESULTS_DIR, f"{file_id}.jpg")
 
 
 
127
  if not os.path.exists(path):
128
- raise HTTPException(status_code=404, detail="File not found")
 
129
  return FileResponse(path, media_type="image/jpeg")
130
 
131
  # -------------------------------------------------
132
- # 🌈 5. Get Result (Public URL)
133
  # -------------------------------------------------
134
  @app.get("/results/{filename}")
135
  def get_result(filename: str):
136
  path = os.path.join(RESULTS_DIR, filename)
137
  if not os.path.exists(path):
138
- raise HTTPException(status_code=404, detail="File not found")
139
  return FileResponse(path, media_type="image/jpeg")
140
 
141
  # -------------------------------------------------
142
- # πŸ–ΌοΈ 6. Get Uploaded Image (Public URL)
143
  # -------------------------------------------------
144
  @app.get("/uploads/{filename}")
145
  def get_upload(filename: str):
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
 
4
  import uuid
5
  import os
6
+ import io
7
+ import json
8
  from PIL import Image
9
  import torch
 
10
  from torchvision import transforms
11
 
12
+ # -------------------------------------------------
13
+ # πŸš€ FastAPI App
14
+ # -------------------------------------------------
15
  app = FastAPI(title="Text-Guided Image Colorization API")
16
 
17
  # -------------------------------------------------
18
+ # πŸ” Firebase Initialization (ENV-based)
19
  # -------------------------------------------------
20
+ try:
21
+ import firebase_admin
22
+ from firebase_admin import credentials, app_check
23
+
24
+ firebase_json = os.getenv("FIREBASE_CREDENTIALS")
25
+
26
+ if firebase_json:
27
+ print("πŸ”₯ Loading Firebase credentials from ENV...")
28
+ firebase_dict = json.loads(firebase_json)
29
+ cred = credentials.Certificate(firebase_dict)
30
+ firebase_admin.initialize_app(cred)
31
+ else:
32
+ print("⚠️ No Firebase credentials found. Firebase disabled.")
33
 
34
+ except Exception as e:
35
+ print("❌ Firebase initialization failed:", e)
36
+
37
+ # -------------------------------------------------
38
+ # πŸ“ Directories
39
+ # -------------------------------------------------
40
  UPLOAD_DIR = "uploads"
41
  RESULTS_DIR = "results"
42
  os.makedirs(UPLOAD_DIR, exist_ok=True)
43
  os.makedirs(RESULTS_DIR, exist_ok=True)
44
 
45
  # -------------------------------------------------
46
+ # 🧠 Load GAN Colorization Model
47
  # -------------------------------------------------
48
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
49
  MODEL_FILENAME = "generator.pt"
50
+
51
+ print("⬇️ Downloading model...")
52
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
53
+
54
+ print("πŸ“¦ Loading model weights...")
55
  state_dict = torch.load(model_path, map_location="cpu")
56
 
57
+ # NOTE:
58
+ # Replace this with your actual model class.
59
+ # This is just example placeholder.
60
+ # from model import ColorizeNet
61
  # model = ColorizeNet()
62
  # model.load_state_dict(state_dict)
63
  # model.eval()
64
 
 
65
  def colorize_image(img: Image.Image):
66
+ """ Dummy colorizer (convert grayscale β†’ fake color) """
67
  transform = transforms.ToTensor()
68
  tensor = transform(img.convert("L")).unsqueeze(0)
69
  tensor = tensor.repeat(1, 3, 1, 1)
 
71
  return output_img
72
 
73
  # -------------------------------------------------
74
+ # 🩺 Health Check
75
  # -------------------------------------------------
76
  @app.get("/health")
77
  def health_check():
78
  return {"status": "healthy", "model_loaded": True}
79
 
80
  # -------------------------------------------------
81
+ # πŸ” Firebase Token Validator
82
  # -------------------------------------------------
83
  def verify_app_check_token(token: str):
 
84
  if not token or len(token) < 20:
85
+ raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
86
  return True
87
 
88
  # -------------------------------------------------
89
+ # πŸ“€ Upload Image
90
  # -------------------------------------------------
91
  @app.post("/upload")
92
  async def upload_image(
 
98
  if not file.content_type.startswith("image/"):
99
  raise HTTPException(status_code=400, detail="Invalid file type")
100
 
101
+ image_id = str(uuid.uuid4()) + ".jpg"
102
+ file_path = os.path.join(UPLOAD_DIR, image_id)
 
103
 
104
+ with open(file_path, "wb") as f:
105
  f.write(await file.read())
106
 
107
+ base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
108
+
109
  return {
110
  "success": True,
111
+ "image_id": image_id.replace(".jpg", ""),
112
+ "file_url": f"{base_url}/uploads/{image_id}"
 
113
  }
114
 
115
  # -------------------------------------------------
116
+ # 🎨 Colorize Image
117
  # -------------------------------------------------
118
  @app.post("/colorize")
119
  async def colorize(
 
128
  img = Image.open(io.BytesIO(await file.read()))
129
  output_img = colorize_image(img)
130
 
131
+ result_id = str(uuid.uuid4()) + ".jpg"
132
+ output_path = os.path.join(RESULTS_DIR, result_id)
133
+ output_img.save(output_path)
 
134
 
135
  base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
136
+
137
  return {
138
  "success": True,
139
+ "result_id": result_id.replace(".jpg", ""),
140
+ "download_url": f"{base_url}/results/{result_id}",
141
+ "api_download": f"{base_url}/download/{result_id.replace('.jpg','')}"
 
142
  }
143
 
144
  # -------------------------------------------------
145
+ # ⬇️ Download via API (Secure)
146
  # -------------------------------------------------
147
  @app.get("/download/{file_id}")
148
  def download_result(file_id: str, x_firebase_appcheck: str = Header(None)):
149
  verify_app_check_token(x_firebase_appcheck)
150
+
151
+ filename = f"{file_id}.jpg"
152
+ path = os.path.join(RESULTS_DIR, filename)
153
+
154
  if not os.path.exists(path):
155
+ raise HTTPException(status_code=404, detail="Result not found")
156
+
157
  return FileResponse(path, media_type="image/jpeg")
158
 
159
  # -------------------------------------------------
160
+ # 🌐 Public Result File
161
  # -------------------------------------------------
162
  @app.get("/results/{filename}")
163
  def get_result(filename: str):
164
  path = os.path.join(RESULTS_DIR, filename)
165
  if not os.path.exists(path):
166
+ raise HTTPException(status_code=404, detail="Result not found")
167
  return FileResponse(path, media_type="image/jpeg")
168
 
169
  # -------------------------------------------------
170
+ # 🌐 Public Uploaded File
171
  # -------------------------------------------------
172
  @app.get("/uploads/{filename}")
173
  def get_upload(filename: str):