LogicGoInfotechSpaces commited on
Commit
e4599d1
·
1 Parent(s): 0454a91

Integrate FastAI colorization with Firebase auth and Gradio UI - Replace main.py with FastAI implementation - Add Gradio interface for Space UI - Add Firebase authentication to /colorize endpoint - Add curl examples documentation - Update test.py with User-Agent headers

Browse files
Files changed (6) hide show
  1. CURL_EXAMPLES.md +83 -0
  2. Dockerfile +1 -1
  3. app/main.py +138 -191
  4. app/main_fastai.py +301 -0
  5. requirements.txt +1 -0
  6. test.py +5 -4
CURL_EXAMPLES.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cURL Examples for Colorization API
2
+
3
+ ## Base URL
4
+ ```
5
+ https://logicgoinfotechspaces-text-guided-image-colorization.hf.space
6
+ ```
7
+
8
+ ## 1. Health Check (No Auth Required)
9
+ ```bash
10
+ curl -X GET "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/health" \
11
+ -H "User-Agent: Mozilla/5.0"
12
+ ```
13
+
14
+ ## 2. Colorize Image (With Firebase Auth)
15
+
16
+ ### Using Firebase ID Token
17
+ ```bash
18
+ curl -X POST "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/colorize" \
19
+ -H "Authorization: Bearer YOUR_FIREBASE_ID_TOKEN" \
20
+ -H "User-Agent: Mozilla/5.0" \
21
+ -F "file=@/path/to/your/image.jpg"
22
+ ```
23
+
24
+ ### Using App Check Token
25
+ ```bash
26
+ curl -X POST "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/colorize" \
27
+ -H "X-Firebase-AppCheck: YOUR_APP_CHECK_TOKEN" \
28
+ -H "User-Agent: Mozilla/5.0" \
29
+ -F "file=@/path/to/your/image.jpg"
30
+ ```
31
+
32
+ ### Without Auth (if DISABLE_AUTH=true)
33
+ ```bash
34
+ curl -X POST "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/colorize" \
35
+ -H "User-Agent: Mozilla/5.0" \
36
+ -F "file=@/path/to/your/image.jpg" \
37
+ -o colorized_result.png
38
+ ```
39
+
40
+ ## 3. Windows PowerShell Example
41
+ ```powershell
42
+ $imagePath = "C:\projects\colorize_text\pexels-andrey-grushnikov-223358-707676.jpg"
43
+ $headers = @{
44
+ "User-Agent" = "Mozilla/5.0"
45
+ }
46
+ $form = @{
47
+ file = Get-Item $imagePath
48
+ }
49
+ Invoke-RestMethod -Uri "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/colorize" `
50
+ -Method Post `
51
+ -Headers $headers `
52
+ -Form $form `
53
+ -OutFile "colorized_result.png"
54
+ ```
55
+
56
+ ## 4. Python requests Example
57
+ ```python
58
+ import requests
59
+
60
+ url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space/colorize"
61
+ headers = {
62
+ "User-Agent": "Mozilla/5.0",
63
+ # "Authorization": "Bearer YOUR_FIREBASE_ID_TOKEN", # Uncomment if auth enabled
64
+ }
65
+ files = {
66
+ "file": ("image.jpg", open("path/to/image.jpg", "rb"), "image/jpeg")
67
+ }
68
+
69
+ response = requests.post(url, headers=headers, files=files)
70
+ if response.status_code == 200:
71
+ with open("colorized_result.png", "wb") as f:
72
+ f.write(response.content)
73
+ print("Colorized image saved!")
74
+ else:
75
+ print(f"Error: {response.status_code} - {response.text}")
76
+ ```
77
+
78
+ ## Response
79
+ - **Success (200)**: Returns PNG image file
80
+ - **Error (400)**: Bad request (invalid file type)
81
+ - **Error (401)**: Unauthorized (missing/invalid auth token)
82
+ - **Error (503)**: Service unavailable (model not loaded)
83
+
Dockerfile CHANGED
@@ -63,4 +63,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
63
  ENTRYPOINT ["/entrypoint.sh"]
64
 
65
  # Run the application (port will be set via environment variable)
66
- CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
 
63
  ENTRYPOINT ["/entrypoint.sh"]
64
 
65
  # Run the application (port will be set via environment variable)
66
+ CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}"]
app/main.py CHANGED
@@ -1,43 +1,39 @@
1
  """
2
- FastAPI application for image colorization using ColorizeNet model
3
- with Firebase App Check integration
4
  """
5
  import os
6
- # Set environment variables BEFORE any imports to ensure they're used
7
- # Set OMP_NUM_THREADS before any torch imports to avoid libgomp warnings
8
  os.environ["OMP_NUM_THREADS"] = "1"
9
- # Set HF cache directories to writable /tmp location BEFORE any HF imports
10
  os.environ["HF_HOME"] = "/tmp/hf_cache"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
13
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
14
  os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
15
- # Set matplotlib config directory to writable location
16
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
17
 
 
18
  import uuid
19
  import logging
20
  from pathlib import Path
21
  from typing import Optional
22
- from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Request
 
23
  from fastapi.responses import FileResponse, JSONResponse
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from fastapi.staticfiles import StaticFiles
26
  import firebase_admin
27
  from firebase_admin import credentials, app_check, auth as firebase_auth
28
- import numpy as np
29
- import torch
30
  from PIL import Image
31
- import io
 
 
32
 
33
- from app.colorize_model import ColorizeModel
34
- from app.config import settings
 
35
 
36
- # Create writable directories (env vars already set above)
37
- Path("/tmp/hf_cache").mkdir(parents=True, exist_ok=True)
38
- Path("/tmp/matplotlib_config").mkdir(parents=True, exist_ok=True)
39
- Path("/tmp/colorize_uploads").mkdir(parents=True, exist_ok=True)
40
- Path("/tmp/colorize_results").mkdir(parents=True, exist_ok=True)
41
 
42
  # Configure logging
43
  logging.basicConfig(
@@ -46,17 +42,23 @@ logging.basicConfig(
46
  )
47
  logger = logging.getLogger(__name__)
48
 
 
 
 
 
 
 
49
  # Initialize FastAPI app
50
  app = FastAPI(
51
- title="Colorize API",
52
- description="Image colorization API using ColorizeNet model",
53
  version="1.0.0"
54
  )
55
 
56
  # CORS middleware
57
  app.add_middleware(
58
  CORSMiddleware,
59
- allow_origins=["*"], # Configure appropriately for production
60
  allow_credentials=True,
61
  allow_methods=["*"],
62
  allow_headers=["*"],
@@ -71,7 +73,10 @@ if os.path.exists(firebase_cred_path):
71
  logger.info("Firebase Admin SDK initialized")
72
  except Exception as e:
73
  logger.warning("Failed to initialize Firebase: %s", str(e))
74
- firebase_admin.initialize_app()
 
 
 
75
  else:
76
  logger.warning("Firebase credentials file not found. App Check will be disabled.")
77
  try:
@@ -79,74 +84,40 @@ else:
79
  except:
80
  pass
81
 
82
- # Always use /data in Spaces (writable). Allow override via env.
83
- DATA_ROOT = Path(os.getenv("DATA_DIR", "/data"))
 
84
 
85
-
86
- def _ensure_dir(preferred: Path, fallback: Path) -> Path:
87
- try:
88
- preferred.mkdir(parents=True, exist_ok=True)
89
- return preferred
90
- except Exception as exc:
91
- logger.warning("Failed to create directory %s: %s. Falling back to %s", preferred, exc, fallback)
92
- try:
93
- fallback.mkdir(parents=True, exist_ok=True)
94
- return fallback
95
- except Exception as fallback_exc:
96
- logger.error("Could not create fallback directory %s: %s", fallback, fallback_exc)
97
- raise
98
-
99
-
100
- UPLOAD_DIR = _ensure_dir(
101
- Path(os.getenv("UPLOAD_DIR", str(DATA_ROOT / "uploads"))),
102
- Path("/tmp/colorize_uploads")
103
- )
104
- RESULT_DIR = _ensure_dir(
105
- Path(os.getenv("RESULT_DIR", str(DATA_ROOT / "results"))),
106
- Path("/tmp/colorize_results")
107
- )
108
- logger.info("Storage directories -> uploads: %s, results: %s", str(UPLOAD_DIR), str(RESULT_DIR))
109
-
110
- # Mount static files for serving results from resolved directories
111
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
112
  app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
113
 
114
- # Initialize ColorizeNet model
115
- colorize_model = None
116
  model_load_error: Optional[str] = None
117
 
118
- @app.get("/")
119
- async def root():
120
- return {
121
- "app": "Colorize API",
122
- "version": "1.0.0",
123
- "health": "/health",
124
- "upload": "/upload",
125
- "colorize": "/colorize"
126
- }
127
-
128
  @app.on_event("startup")
129
  async def startup_event():
130
- """Initialize the colorization model on startup"""
131
- global colorize_model, model_load_error
132
  try:
133
- logger.info("Loading ColorizeNet model with MODEL_ID: %s", settings.MODEL_ID)
134
- logger.info("HF_TOKEN present: %s", "Yes" if os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") else "No")
135
- colorize_model = ColorizeModel(settings.MODEL_ID)
136
- logger.info("ColorizeNet model loaded successfully")
137
  model_load_error = None
138
  except Exception as e:
139
  error_msg = str(e)
140
- logger.error("Failed to load ColorizeNet model: %s", error_msg)
141
  model_load_error = error_msg
142
- # Don't raise - allow health check to work even if model fails
143
 
144
  @app.on_event("shutdown")
145
  async def shutdown_event():
146
  """Cleanup on shutdown"""
147
- global colorize_model
148
- if colorize_model:
149
- del colorize_model
150
  logger.info("Application shutdown")
151
 
152
  def _extract_bearer_token(authorization_header: str | None) -> str | None:
@@ -157,9 +128,9 @@ def _extract_bearer_token(authorization_header: str | None) -> str | None:
157
  return parts[1].strip()
158
  return None
159
 
160
-
161
  async def verify_request(request: Request):
162
  """
 
163
  Accept either:
164
  - Firebase Auth id_token via Authorization: Bearer <id_token>
165
  - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
@@ -173,12 +144,11 @@ async def verify_request(request: Request):
173
  if bearer:
174
  try:
175
  decoded = firebase_auth.verify_id_token(bearer)
176
- request.state.user = decoded # make claims available if needed
177
  logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
178
  return True
179
  except Exception as e:
180
  logger.warning("Auth token verification failed: %s", str(e))
181
- # fall through to App Check if enabled
182
 
183
  # If App Check is enabled, require valid App Check token
184
  if settings.ENABLE_APP_CHECK:
@@ -196,159 +166,136 @@ async def verify_request(request: Request):
196
  # Neither token required nor provided → allow (App Check disabled)
197
  return True
198
 
 
 
 
 
 
 
 
 
 
 
 
199
  @app.get("/health")
200
  async def health_check():
201
  """Health check endpoint"""
202
  response = {
203
  "status": "healthy",
204
- "model_loaded": colorize_model is not None,
205
- "model_id": settings.MODEL_ID
206
  }
207
  if model_load_error:
208
  response["model_error"] = model_load_error
209
  return response
210
 
211
- @app.post("/upload")
212
- async def upload_image(
213
- file: UploadFile = File(...),
214
- verified: bool = Depends(verify_request)
215
- ):
216
- """
217
- Upload an image and return the uploaded image URL
218
- """
219
- if not file.content_type or not file.content_type.startswith("image/"):
220
- raise HTTPException(status_code=400, detail="File must be an image")
 
 
221
 
222
- # Generate unique filename
223
- file_id = str(uuid.uuid4())
224
- file_extension = Path(file.filename).suffix or ".jpg"
225
- filename = f"{file_id}{file_extension}"
226
- filepath = UPLOAD_DIR / filename
 
 
 
 
 
 
 
 
 
 
 
227
 
228
- # Save uploaded file
229
- try:
230
- contents = await file.read()
231
- with open(filepath, "wb") as f:
232
- f.write(contents)
233
- logger.info("Image uploaded: %s", filename)
234
-
235
- # Return the URL to access the uploaded image
236
- base_url = os.getenv("BASE_URL", os.getenv("SPACE_HOST", "http://localhost:7860"))
237
- image_url = f"{base_url}/uploads/{filename}"
238
-
239
- return {
240
- "success": True,
241
- "image_id": file_id,
242
- "image_url": image_url,
243
- "filename": filename
244
- }
245
- except Exception as e:
246
- logger.error("Error uploading image: %s", str(e))
247
- raise HTTPException(status_code=500, detail=f"Error uploading image: {str(e)}")
248
 
249
  @app.post("/colorize")
250
- async def colorize_image(
251
  file: UploadFile = File(...),
252
  verified: bool = Depends(verify_request)
253
  ):
254
  """
255
- Colorize an uploaded grayscale image using ColorizeNet
256
- Returns the colorized image URL
257
  """
258
- if colorize_model is None:
259
  raise HTTPException(status_code=503, detail="Colorization model not loaded")
260
 
261
  if not file.content_type or not file.content_type.startswith("image/"):
262
  raise HTTPException(status_code=400, detail="File must be an image")
263
 
264
  try:
265
- # Read image
266
- contents = await file.read()
267
- image = Image.open(io.BytesIO(contents))
268
-
269
- # Convert to RGB if needed
270
- if image.mode != "RGB":
271
- image = image.convert("RGB")
272
 
273
- # Colorize the image
274
  logger.info("Colorizing image...")
275
- colorized_image, caption = colorize_model.colorize(image)
276
 
277
- # Save colorized image
278
- file_id = str(uuid.uuid4())
279
- result_filename = f"{file_id}.jpg"
280
- result_filepath = RESULT_DIR / result_filename
281
 
282
- colorized_image.save(result_filepath, "JPEG", quality=95)
283
- logger.info("Colorized image saved: %s", result_filename)
284
 
285
- # Return URLs
286
- base_url = os.getenv("BASE_URL", os.getenv("SPACE_HOST", "http://localhost:7860"))
287
- download_url = f"{base_url}/results/{result_filename}"
288
- api_download_url = f"{base_url}/download/{file_id}"
289
-
290
- return {
291
- "success": True,
292
- "result_id": file_id,
293
- "download_url": download_url,
294
- "api_download_url": api_download_url,
295
- "filename": result_filename,
296
- "caption": caption
297
- }
298
  except Exception as e:
299
  logger.error("Error colorizing image: %s", str(e))
300
  raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}")
301
 
302
- @app.get("/download/{file_id}")
303
- async def download_result(
304
- file_id: str,
305
- verified: bool = Depends(verify_request)
306
- ):
307
- """
308
- Download the colorized image by file ID
309
- """
310
- result_filepath = RESULT_DIR / f"{file_id}.jpg"
311
-
312
- if not result_filepath.exists():
313
- raise HTTPException(status_code=404, detail="Result not found")
314
-
315
- return FileResponse(
316
- result_filepath,
317
- media_type="image/jpeg",
318
- filename=f"colorized_{file_id}.jpg"
319
- )
320
 
321
- @app.get("/results/{filename}")
322
- async def get_result_file(filename: str):
323
- """
324
- Serve result files directly (public endpoint for browser access)
325
- """
326
- result_filepath = RESULT_DIR / filename
327
-
328
- if not result_filepath.exists():
329
- raise HTTPException(status_code=404, detail="File not found")
330
-
331
- return FileResponse(
332
- result_filepath,
333
- media_type="image/jpeg"
334
- )
335
 
336
- @app.get("/uploads/{filename}")
337
- async def get_upload_file(filename: str):
338
- """
339
- Serve uploaded files directly
340
- """
341
- upload_filepath = UPLOAD_DIR / filename
342
-
343
- if not upload_filepath.exists():
344
- raise HTTPException(status_code=404, detail="File not found")
345
-
346
- return FileResponse(
347
- upload_filepath,
348
- media_type="image/jpeg"
349
- )
350
 
 
 
 
 
 
 
 
351
  if __name__ == "__main__":
352
- import uvicorn
353
  port = int(os.getenv("PORT", "7860"))
354
  uvicorn.run(app, host="0.0.0.0", port=port)
 
 
1
  """
2
+ FastAPI application for FastAI GAN Image Colorization
3
+ with Firebase Authentication and Gradio UI
4
  """
5
  import os
6
+ # Set environment variables BEFORE any imports
 
7
  os.environ["OMP_NUM_THREADS"] = "1"
 
8
  os.environ["HF_HOME"] = "/tmp/hf_cache"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
11
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
12
  os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
 
13
  os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
14
 
15
+ import io
16
  import uuid
17
  import logging
18
  from pathlib import Path
19
  from typing import Optional
20
+
21
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
22
  from fastapi.responses import FileResponse, JSONResponse
23
  from fastapi.middleware.cors import CORSMiddleware
24
  from fastapi.staticfiles import StaticFiles
25
  import firebase_admin
26
  from firebase_admin import credentials, app_check, auth as firebase_auth
 
 
27
  from PIL import Image
28
+ import torch
29
+ import uvicorn
30
+ import gradio as gr
31
 
32
+ # FastAI imports
33
+ from fastai.vision.all import *
34
+ from huggingface_hub import from_pretrained_fastai
35
 
36
+ from app.config import settings
 
 
 
 
37
 
38
  # Configure logging
39
  logging.basicConfig(
 
42
  )
43
  logger = logging.getLogger(__name__)
44
 
45
+ # Create writable directories
46
+ Path("/tmp/hf_cache").mkdir(parents=True, exist_ok=True)
47
+ Path("/tmp/matplotlib_config").mkdir(parents=True, exist_ok=True)
48
+ Path("/tmp/colorize_uploads").mkdir(parents=True, exist_ok=True)
49
+ Path("/tmp/colorize_results").mkdir(parents=True, exist_ok=True)
50
+
51
  # Initialize FastAPI app
52
  app = FastAPI(
53
+ title="FastAI Image Colorizer API",
54
+ description="Image colorization using FastAI GAN model with Firebase authentication",
55
  version="1.0.0"
56
  )
57
 
58
  # CORS middleware
59
  app.add_middleware(
60
  CORSMiddleware,
61
+ allow_origins=["*"],
62
  allow_credentials=True,
63
  allow_methods=["*"],
64
  allow_headers=["*"],
 
73
  logger.info("Firebase Admin SDK initialized")
74
  except Exception as e:
75
  logger.warning("Failed to initialize Firebase: %s", str(e))
76
+ try:
77
+ firebase_admin.initialize_app()
78
+ except:
79
+ pass
80
  else:
81
  logger.warning("Firebase credentials file not found. App Check will be disabled.")
82
  try:
 
84
  except:
85
  pass
86
 
87
+ # Storage directories
88
+ UPLOAD_DIR = Path("/tmp/colorize_uploads")
89
+ RESULT_DIR = Path("/tmp/colorize_results")
90
 
91
+ # Mount static files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
93
  app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
94
 
95
+ # Initialize FastAI model
96
+ learn = None
97
  model_load_error: Optional[str] = None
98
 
 
 
 
 
 
 
 
 
 
 
99
  @app.on_event("startup")
100
  async def startup_event():
101
+ """Load FastAI model on startup"""
102
+ global learn, model_load_error
103
  try:
104
+ model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
105
+ logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
106
+ learn = from_pretrained_fastai(model_id)
107
+ logger.info(" Model loaded successfully!")
108
  model_load_error = None
109
  except Exception as e:
110
  error_msg = str(e)
111
+ logger.error("Failed to load model: %s", error_msg)
112
  model_load_error = error_msg
113
+ # Don't raise - allow health check to work
114
 
115
  @app.on_event("shutdown")
116
  async def shutdown_event():
117
  """Cleanup on shutdown"""
118
+ global learn
119
+ if learn:
120
+ del learn
121
  logger.info("Application shutdown")
122
 
123
  def _extract_bearer_token(authorization_header: str | None) -> str | None:
 
128
  return parts[1].strip()
129
  return None
130
 
 
131
  async def verify_request(request: Request):
132
  """
133
+ Verify Firebase authentication
134
  Accept either:
135
  - Firebase Auth id_token via Authorization: Bearer <id_token>
136
  - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
 
144
  if bearer:
145
  try:
146
  decoded = firebase_auth.verify_id_token(bearer)
147
+ request.state.user = decoded
148
  logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
149
  return True
150
  except Exception as e:
151
  logger.warning("Auth token verification failed: %s", str(e))
 
152
 
153
  # If App Check is enabled, require valid App Check token
154
  if settings.ENABLE_APP_CHECK:
 
166
  # Neither token required nor provided → allow (App Check disabled)
167
  return True
168
 
169
+ @app.get("/api")
170
+ async def api_info():
171
+ """API info endpoint"""
172
+ return {
173
+ "app": "FastAI Image Colorizer API",
174
+ "version": "1.0.0",
175
+ "health": "/health",
176
+ "colorize": "/colorize",
177
+ "gradio": "/"
178
+ }
179
+
180
  @app.get("/health")
181
  async def health_check():
182
  """Health check endpoint"""
183
  response = {
184
  "status": "healthy",
185
+ "model_loaded": learn is not None,
186
+ "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
187
  }
188
  if model_load_error:
189
  response["model_error"] = model_load_error
190
  return response
191
 
192
+ def colorize_pil(image: Image.Image) -> Image.Image:
193
+ """Run model prediction and return colorized image"""
194
+ if learn is None:
195
+ raise RuntimeError("Model not loaded")
196
+ if image.mode != "RGB":
197
+ image = image.convert("RGB")
198
+ pred = learn.predict(image)
199
+ # Handle different return types from FastAI
200
+ if isinstance(pred, (list, tuple)):
201
+ colorized = pred[0] if len(pred) > 0 else image
202
+ else:
203
+ colorized = pred
204
 
205
+ # Ensure we have a PIL Image
206
+ if not isinstance(colorized, Image.Image):
207
+ if isinstance(colorized, torch.Tensor):
208
+ # Convert tensor to PIL
209
+ if colorized.dim() == 4:
210
+ colorized = colorized[0]
211
+ if colorized.dim() == 3:
212
+ colorized = colorized.permute(1, 2, 0).cpu()
213
+ if colorized.dtype in (torch.float32, torch.float16):
214
+ colorized = torch.clamp(colorized, 0, 1)
215
+ colorized = (colorized * 255).byte()
216
+ colorized = Image.fromarray(colorized.numpy(), 'RGB')
217
+ else:
218
+ raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
219
+ else:
220
+ raise ValueError(f"Unexpected prediction type: {type(colorized)}")
221
 
222
+ if colorized.mode != "RGB":
223
+ colorized = colorized.convert("RGB")
224
+
225
+ return colorized
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  @app.post("/colorize")
228
+ async def colorize_api(
229
  file: UploadFile = File(...),
230
  verified: bool = Depends(verify_request)
231
  ):
232
  """
233
+ Upload a black & white image -> returns colorized image.
234
+ Requires Firebase authentication unless DISABLE_AUTH=true
235
  """
236
+ if learn is None:
237
  raise HTTPException(status_code=503, detail="Colorization model not loaded")
238
 
239
  if not file.content_type or not file.content_type.startswith("image/"):
240
  raise HTTPException(status_code=400, detail="File must be an image")
241
 
242
  try:
243
+ img_bytes = await file.read()
244
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
 
 
 
245
 
 
246
  logger.info("Colorizing image...")
247
+ colorized = colorize_pil(image)
248
 
249
+ output_filename = f"{uuid.uuid4()}.png"
250
+ output_path = RESULT_DIR / output_filename
251
+ colorized.save(output_path, "PNG")
 
252
 
253
+ logger.info("Colorized image saved: %s", output_filename)
 
254
 
255
+ # Return the image file
256
+ return FileResponse(
257
+ output_path,
258
+ media_type="image/png",
259
+ filename=f"colorized_{output_filename}"
260
+ )
 
 
 
 
 
 
 
261
  except Exception as e:
262
  logger.error("Error colorizing image: %s", str(e))
263
  raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}")
264
 
265
+ # ==========================================================
266
+ # Gradio Interface (for Space UI)
267
+ # ==========================================================
268
+ def gradio_colorize(image):
269
+ """Gradio colorization function"""
270
+ if image is None:
271
+ return None
272
+ try:
273
+ if learn is None:
274
+ return None
275
+ return colorize_pil(image)
276
+ except Exception as e:
277
+ logger.error("Gradio colorization error: %s", str(e))
278
+ return None
 
 
 
 
279
 
280
+ title = "🎨 FastAI GAN Image Colorizer"
281
+ description = "Upload a black & white photo to generate a colorized version using the FastAI GAN model."
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ iface = gr.Interface(
284
+ fn=gradio_colorize,
285
+ inputs=gr.Image(type="pil", label="Upload B&W Image"),
286
+ outputs=gr.Image(type="pil", label="Colorized Image"),
287
+ title=title,
288
+ description=description,
289
+ )
 
 
 
 
 
 
 
290
 
291
+ # Mount Gradio app at root (this will be the Space UI)
292
+ # Note: This will override the root endpoint, so use /api for API info
293
+ app = gr.mount_gradio_app(app, iface, path="/")
294
+
295
+ # ==========================================================
296
+ # Run Server
297
+ # ==========================================================
298
  if __name__ == "__main__":
 
299
  port = int(os.getenv("PORT", "7860"))
300
  uvicorn.run(app, host="0.0.0.0", port=port)
301
+
app/main_fastai.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for FastAI GAN Image Colorization
3
+ with Firebase Authentication and Gradio UI
4
+ """
5
+ import os
6
+ # Set environment variables BEFORE any imports
7
+ os.environ["OMP_NUM_THREADS"] = "1"
8
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
9
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
+ os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
11
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/hf_cache"
12
+ os.environ["XDG_CACHE_HOME"] = "/tmp/hf_cache"
13
+ os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_config"
14
+
15
+ import io
16
+ import uuid
17
+ import logging
18
+ from pathlib import Path
19
+ from typing import Optional
20
+
21
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
22
+ from fastapi.responses import FileResponse, JSONResponse
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ from fastapi.staticfiles import StaticFiles
25
+ import firebase_admin
26
+ from firebase_admin import credentials, app_check, auth as firebase_auth
27
+ from PIL import Image
28
+ import torch
29
+ import uvicorn
30
+ import gradio as gr
31
+
32
+ # FastAI imports
33
+ from fastai.vision.all import *
34
+ from huggingface_hub import from_pretrained_fastai
35
+
36
+ from app.config import settings
37
+
38
+ # Configure logging
39
+ logging.basicConfig(
40
+ level=logging.INFO,
41
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
42
+ )
43
+ logger = logging.getLogger(__name__)
44
+
45
+ # Create writable directories
46
+ Path("/tmp/hf_cache").mkdir(parents=True, exist_ok=True)
47
+ Path("/tmp/matplotlib_config").mkdir(parents=True, exist_ok=True)
48
+ Path("/tmp/colorize_uploads").mkdir(parents=True, exist_ok=True)
49
+ Path("/tmp/colorize_results").mkdir(parents=True, exist_ok=True)
50
+
51
+ # Initialize FastAPI app
52
+ app = FastAPI(
53
+ title="FastAI Image Colorizer API",
54
+ description="Image colorization using FastAI GAN model with Firebase authentication",
55
+ version="1.0.0"
56
+ )
57
+
58
+ # CORS middleware
59
+ app.add_middleware(
60
+ CORSMiddleware,
61
+ allow_origins=["*"],
62
+ allow_credentials=True,
63
+ allow_methods=["*"],
64
+ allow_headers=["*"],
65
+ )
66
+
67
+ # Initialize Firebase Admin SDK
68
+ firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
69
+ if os.path.exists(firebase_cred_path):
70
+ try:
71
+ cred = credentials.Certificate(firebase_cred_path)
72
+ firebase_admin.initialize_app(cred)
73
+ logger.info("Firebase Admin SDK initialized")
74
+ except Exception as e:
75
+ logger.warning("Failed to initialize Firebase: %s", str(e))
76
+ try:
77
+ firebase_admin.initialize_app()
78
+ except:
79
+ pass
80
+ else:
81
+ logger.warning("Firebase credentials file not found. App Check will be disabled.")
82
+ try:
83
+ firebase_admin.initialize_app()
84
+ except:
85
+ pass
86
+
87
+ # Storage directories
88
+ UPLOAD_DIR = Path("/tmp/colorize_uploads")
89
+ RESULT_DIR = Path("/tmp/colorize_results")
90
+
91
+ # Mount static files
92
+ app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
93
+ app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
94
+
95
+ # Initialize FastAI model
96
+ learn = None
97
+ model_load_error: Optional[str] = None
98
+
99
+ @app.on_event("startup")
100
+ async def startup_event():
101
+ """Load FastAI model on startup"""
102
+ global learn, model_load_error
103
+ try:
104
+ model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
105
+ logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
106
+ learn = from_pretrained_fastai(model_id)
107
+ logger.info("✅ Model loaded successfully!")
108
+ model_load_error = None
109
+ except Exception as e:
110
+ error_msg = str(e)
111
+ logger.error("❌ Failed to load model: %s", error_msg)
112
+ model_load_error = error_msg
113
+ # Don't raise - allow health check to work
114
+
115
+ @app.on_event("shutdown")
116
+ async def shutdown_event():
117
+ """Cleanup on shutdown"""
118
+ global learn
119
+ if learn:
120
+ del learn
121
+ logger.info("Application shutdown")
122
+
123
+ def _extract_bearer_token(authorization_header: str | None) -> str | None:
124
+ if not authorization_header:
125
+ return None
126
+ parts = authorization_header.split(" ", 1)
127
+ if len(parts) == 2 and parts[0].lower() == "bearer":
128
+ return parts[1].strip()
129
+ return None
130
+
131
+ async def verify_request(request: Request):
132
+ """
133
+ Verify Firebase authentication
134
+ Accept either:
135
+ - Firebase Auth id_token via Authorization: Bearer <id_token>
136
+ - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
137
+ """
138
+ # If Firebase is not initialized or auth is explicitly disabled, allow
139
+ if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
140
+ return True
141
+
142
+ # Try Firebase Auth id_token first if present
143
+ bearer = _extract_bearer_token(request.headers.get("Authorization"))
144
+ if bearer:
145
+ try:
146
+ decoded = firebase_auth.verify_id_token(bearer)
147
+ request.state.user = decoded
148
+ logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
149
+ return True
150
+ except Exception as e:
151
+ logger.warning("Auth token verification failed: %s", str(e))
152
+
153
+ # If App Check is enabled, require valid App Check token
154
+ if settings.ENABLE_APP_CHECK:
155
+ app_check_token = request.headers.get("X-Firebase-AppCheck")
156
+ if not app_check_token:
157
+ raise HTTPException(status_code=401, detail="Missing App Check token")
158
+ try:
159
+ app_check_claims = app_check.verify_token(app_check_token)
160
+ logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
161
+ return True
162
+ except Exception as e:
163
+ logger.warning("App Check token verification failed: %s", str(e))
164
+ raise HTTPException(status_code=401, detail="Invalid App Check token")
165
+
166
+ # Neither token required nor provided → allow (App Check disabled)
167
+ return True
168
+
169
+ @app.get("/api")
170
+ async def api_info():
171
+ """API info endpoint"""
172
+ return {
173
+ "app": "FastAI Image Colorizer API",
174
+ "version": "1.0.0",
175
+ "health": "/health",
176
+ "colorize": "/colorize",
177
+ "gradio": "/"
178
+ }
179
+
180
+ @app.get("/health")
181
+ async def health_check():
182
+ """Health check endpoint"""
183
+ response = {
184
+ "status": "healthy",
185
+ "model_loaded": learn is not None,
186
+ "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
187
+ }
188
+ if model_load_error:
189
+ response["model_error"] = model_load_error
190
+ return response
191
+
192
+ def colorize_pil(image: Image.Image) -> Image.Image:
193
+ """Run model prediction and return colorized image"""
194
+ if learn is None:
195
+ raise RuntimeError("Model not loaded")
196
+ if image.mode != "RGB":
197
+ image = image.convert("RGB")
198
+ pred = learn.predict(image)
199
+ # Handle different return types from FastAI
200
+ if isinstance(pred, (list, tuple)):
201
+ colorized = pred[0] if len(pred) > 0 else image
202
+ else:
203
+ colorized = pred
204
+
205
+ # Ensure we have a PIL Image
206
+ if not isinstance(colorized, Image.Image):
207
+ if isinstance(colorized, torch.Tensor):
208
+ # Convert tensor to PIL
209
+ if colorized.dim() == 4:
210
+ colorized = colorized[0]
211
+ if colorized.dim() == 3:
212
+ colorized = colorized.permute(1, 2, 0).cpu()
213
+ if colorized.dtype in (torch.float32, torch.float16):
214
+ colorized = torch.clamp(colorized, 0, 1)
215
+ colorized = (colorized * 255).byte()
216
+ colorized = Image.fromarray(colorized.numpy(), 'RGB')
217
+ else:
218
+ raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
219
+ else:
220
+ raise ValueError(f"Unexpected prediction type: {type(colorized)}")
221
+
222
+ if colorized.mode != "RGB":
223
+ colorized = colorized.convert("RGB")
224
+
225
+ return colorized
226
+
227
+ @app.post("/colorize")
228
+ async def colorize_api(
229
+ file: UploadFile = File(...),
230
+ verified: bool = Depends(verify_request)
231
+ ):
232
+ """
233
+ Upload a black & white image -> returns colorized image.
234
+ Requires Firebase authentication unless DISABLE_AUTH=true
235
+ """
236
+ if learn is None:
237
+ raise HTTPException(status_code=503, detail="Colorization model not loaded")
238
+
239
+ if not file.content_type or not file.content_type.startswith("image/"):
240
+ raise HTTPException(status_code=400, detail="File must be an image")
241
+
242
+ try:
243
+ img_bytes = await file.read()
244
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
245
+
246
+ logger.info("Colorizing image...")
247
+ colorized = colorize_pil(image)
248
+
249
+ output_filename = f"{uuid.uuid4()}.png"
250
+ output_path = RESULT_DIR / output_filename
251
+ colorized.save(output_path, "PNG")
252
+
253
+ logger.info("Colorized image saved: %s", output_filename)
254
+
255
+ # Return the image file
256
+ return FileResponse(
257
+ output_path,
258
+ media_type="image/png",
259
+ filename=f"colorized_{output_filename}"
260
+ )
261
+ except Exception as e:
262
+ logger.error("Error colorizing image: %s", str(e))
263
+ raise HTTPException(status_code=500, detail=f"Error colorizing image: {str(e)}")
264
+
265
+ # ==========================================================
266
+ # Gradio Interface (for Space UI)
267
+ # ==========================================================
268
+ def gradio_colorize(image):
269
+ """Gradio colorization function"""
270
+ if image is None:
271
+ return None
272
+ try:
273
+ if learn is None:
274
+ return None
275
+ return colorize_pil(image)
276
+ except Exception as e:
277
+ logger.error("Gradio colorization error: %s", str(e))
278
+ return None
279
+
280
+ title = "🎨 FastAI GAN Image Colorizer"
281
+ description = "Upload a black & white photo to generate a colorized version using the FastAI GAN model."
282
+
283
+ iface = gr.Interface(
284
+ fn=gradio_colorize,
285
+ inputs=gr.Image(type="pil", label="Upload B&W Image"),
286
+ outputs=gr.Image(type="pil", label="Colorized Image"),
287
+ title=title,
288
+ description=description,
289
+ )
290
+
291
+ # Mount Gradio app at root (this will be the Space UI)
292
+ # Note: This will override the root endpoint, so use /api for API info
293
+ app = gr.mount_gradio_app(app, iface, path="/")
294
+
295
+ # ==========================================================
296
+ # Run Server
297
+ # ==========================================================
298
+ if __name__ == "__main__":
299
+ port = int(os.getenv("PORT", "7860"))
300
+ uvicorn.run(app, host="0.0.0.0", port=port)
301
+
requirements.txt CHANGED
@@ -16,4 +16,5 @@ huggingface-hub>=0.16.0
16
  safetensors>=0.3.0
17
  fastai>=2.7.13
18
  toml>=0.10.2
 
19
 
 
16
  safetensors>=0.3.0
17
  fastai>=2.7.13
18
  toml>=0.10.2
19
+ gradio>=4.0.0
20
 
test.py CHANGED
@@ -33,9 +33,10 @@ def wait_for_model(base_url: str, timeout_seconds: int = 300) -> None:
33
  health_url = f"{base_url}/health"
34
  logging.info("Waiting for model to load at %s", health_url)
35
  last_status = None
 
36
  while time.time() < deadline:
37
  try:
38
- resp = requests.get(health_url, timeout=15)
39
  if resp.ok:
40
  data = resp.json()
41
  last_status = data
@@ -53,7 +54,7 @@ def wait_for_model(base_url: str, timeout_seconds: int = 300) -> None:
53
 
54
  def upload_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
55
  url = f"{base_url}/upload"
56
- headers = {}
57
  if auth_bearer:
58
  headers["Authorization"] = f"Bearer {auth_bearer}"
59
  if app_check:
@@ -70,7 +71,7 @@ def upload_image(base_url: str, image_path: str, auth_bearer: Optional[str], app
70
 
71
  def colorize_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
72
  url = f"{base_url}/colorize"
73
- headers = {}
74
  if auth_bearer:
75
  headers["Authorization"] = f"Bearer {auth_bearer}"
76
  if app_check:
@@ -87,7 +88,7 @@ def colorize_image(base_url: str, image_path: str, auth_bearer: Optional[str], a
87
 
88
  def download_result(base_url: str, result_id: str, output_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> None:
89
  url = f"{base_url}/download/{result_id}"
90
- headers = {}
91
  if auth_bearer:
92
  headers["Authorization"] = f"Bearer {auth_bearer}"
93
  if app_check:
 
33
  health_url = f"{base_url}/health"
34
  logging.info("Waiting for model to load at %s", health_url)
35
  last_status = None
36
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
37
  while time.time() < deadline:
38
  try:
39
+ resp = requests.get(health_url, headers=headers, timeout=15)
40
  if resp.ok:
41
  data = resp.json()
42
  last_status = data
 
54
 
55
  def upload_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
56
  url = f"{base_url}/upload"
57
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
58
  if auth_bearer:
59
  headers["Authorization"] = f"Bearer {auth_bearer}"
60
  if app_check:
 
71
 
72
  def colorize_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
73
  url = f"{base_url}/colorize"
74
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
75
  if auth_bearer:
76
  headers["Authorization"] = f"Bearer {auth_bearer}"
77
  if app_check:
 
88
 
89
  def download_result(base_url: str, result_id: str, output_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> None:
90
  url = f"{base_url}/download/{result_id}"
91
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
92
  if auth_bearer:
93
  headers["Authorization"] = f"Bearer {auth_bearer}"
94
  if app_check: