File size: 17,044 Bytes
e4599d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f40725d
 
e4599d1
 
 
 
 
 
0a1a3e1
87f9058
e4599d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a1a3e1
e4599d1
0a1a3e1
e4599d1
 
 
87f9058
0a1a3e1
87f9058
 
 
 
a7ddf76
87f9058
 
 
 
0a1a3e1
 
 
e4599d1
0a1a3e1
e4599d1
0a1a3e1
 
e4599d1
0a1a3e1
e4599d1
 
0a1a3e1
 
 
 
 
 
 
 
 
 
 
 
e4599d1
0a1a3e1
e4599d1
 
 
 
 
0a1a3e1
e4599d1
 
0a1a3e1
 
87f9058
e4599d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87f9058
e4599d1
87f9058
e4599d1
 
 
 
 
 
87f9058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4599d1
 
87f9058
e4599d1
115b125
e4599d1
 
115b125
0a1a3e1
115b125
 
e4599d1
 
 
115b125
 
 
 
 
87f9058
 
 
 
 
 
 
 
 
 
e4599d1
 
f40725d
 
a315115
f40725d
a315115
f40725d
 
 
 
 
 
 
a315115
f40725d
 
 
 
a315115
f40725d
 
a315115
 
 
f40725d
a315115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f40725d
 
a315115
f40725d
 
a315115
 
 
 
 
f40725d
 
 
e4599d1
 
0a1a3e1
 
 
 
 
 
 
 
e4599d1
0a1a3e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4599d1
0a1a3e1
 
 
e4599d1
0a1a3e1
f40725d
115b125
f40725d
e4599d1
 
 
87f9058
e4599d1
 
 
 
 
 
 
87f9058
 
 
 
 
 
 
 
 
f40725d
 
 
e4599d1
 
87f9058
 
 
 
 
 
 
 
e4599d1
 
 
 
 
 
 
 
 
87f9058
 
e4599d1
 
 
 
 
 
87f9058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4599d1
 
 
 
 
 
 
87f9058
 
 
 
 
 
 
 
 
 
 
e4599d1
 
 
 
 
 
 
 
 
f40725d
e4599d1
 
 
 
 
115b125
 
e4599d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
"""
FastAPI application for FastAI GAN Image Colorization
with Firebase Authentication and Gradio UI
"""
import os
# Environment configuration. It is applied right after ``import os`` so that
# it is already in place BEFORE any of the imports that follow are executed.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HUB_CACHE": "/tmp/hf_cache",
        "HUGGINGFACE_HUB_CACHE": "/tmp/hf_cache",
        "XDG_CACHE_HOME": "/tmp/hf_cache",
        "MPLCONFIGDIR": "/tmp/matplotlib_config",
    }
)

import io
import uuid
import logging
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import firebase_admin
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image
import torch
import uvicorn
import gradio as gr
import numpy as np
import cv2

# FastAI imports
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai

from app.config import settings
from app.pytorch_colorizer import PyTorchColorizer
from app.database import get_database, log_api_call, log_image_upload, log_colorization, close_connection

# Logging: INFO and above, formatted as "time - logger name - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Make sure every scratch directory used by this module exists; creating a
# directory that is already present is not an error.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)

# Initialize FastAPI app
app = FastAPI(
    title="FastAI Image Colorizer API",
    description="Image colorization using FastAI GAN model with Firebase authentication",
    version="1.0.0",
)

# CORS middleware.
# The allowed origins can be narrowed with the comma-separated
# CORS_ALLOW_ORIGINS environment variable. When it is unset or contains no
# usable entry, every origin ("*") is allowed, which is the previous behaviour.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True; consider configuring explicit
# origins for production deployments.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize Firebase Admin SDK.
# Preferred path: a service-account JSON file whose location is taken from
# FIREBASE_CREDENTIALS_PATH. When that file is missing or cannot be used, fall
# back to initialize_app() without an explicit credential. The fallback is
# best-effort: a failure is logged and module import continues.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
_firebase_initialized = False
if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
        _firebase_initialized = True
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")

if not _firebase_initialized:
    try:
        firebase_admin.initialize_app()
    except Exception as e:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate, and record the
        # reason instead of discarding it silently.
        logger.warning("Firebase default initialization failed: %s", str(e))

# Filesystem locations for uploaded inputs and generated outputs.
RESULT_DIR = Path("/tmp/colorize_results")
UPLOAD_DIR = Path("/tmp/colorize_uploads")

# Serve both directories as static files. The URL prefix and the route name
# come from the same label ("results" -> "/results", "uploads" -> "/uploads").
for _label, _directory in (("results", RESULT_DIR), ("uploads", UPLOAD_DIR)):
    app.mount(f"/{_label}", StaticFiles(directory=str(_directory)), name=_label)

# Module-level model state. Everything starts out empty; the startup handler
# fills these in once it has tried to load a model.
model_type: str = "none"  # "fastai", "pytorch", or "none"
model_load_error: Optional[str] = None
learn = None
pytorch_colorizer = None

@app.on_event("startup")
async def startup_event():
    """Initialize MongoDB and load a colorization model when the app starts.

    The FastAI loader is attempted first; if it raises, the PyTorch colorizer
    is tried. The outcome is recorded in the module globals ``learn`` /
    ``pytorch_colorizer``, ``model_type`` and ``model_load_error``. Failures
    are logged and never re-raised, so the application still starts (and the
    health check still answers) when no model could be loaded.
    """
    global learn, pytorch_colorizer, model_load_error, model_type

    # Database setup is best-effort: a failure is logged and otherwise ignored.
    try:
        if get_database() is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as db_exc:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(db_exc))

    model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")

    # Preferred backend: FastAI.
    try:
        logger.info("🔄 Attempting to load FastAI GAN Colorization Model: %s", model_id)
        learn = from_pretrained_fastai(model_id)
    except Exception as fastai_exc:
        logger.warning(
            "⚠️ FastAI model loading failed: %s. Trying PyTorch fallback...",
            str(fastai_exc),
        )
    else:
        logger.info("✅ FastAI model loaded successfully!")
        model_type, model_load_error = "fastai", None
        return

    # Secondary backend: PyTorch generator weights from the same model id.
    try:
        logger.info("🔄 Attempting to load PyTorch GAN Colorization Model: %s", model_id)
        pytorch_colorizer = PyTorchColorizer(model_id=model_id, model_filename="generator.pt")
    except Exception as pytorch_exc:
        failure = str(pytorch_exc)
        logger.error("❌ Failed to load both FastAI and PyTorch models: %s", failure)
        # Keep the reason around so it can be reported later; do not raise.
        model_load_error = failure
        model_type = "none"
    else:
        logger.info("✅ PyTorch model loaded successfully!")
        model_type, model_load_error = "pytorch", None

@app.on_event("shutdown")
async def shutdown_event():
    """Release the loaded models and close the database connection.

    The model globals are rebound to ``None`` rather than removed with
    ``del``: deleting a module global unbinds the name, so any later read
    such as ``learn is not None`` would raise ``NameError`` instead of
    reporting that no model is loaded.
    """
    global learn, pytorch_colorizer
    learn = None
    pytorch_colorizer = None
    close_connection()
    logger.info("Application shutdown")

def _extract_bearer_token(authorization_header: str | None) -> str | None:
    if not authorization_header:
        return None
    parts = authorization_header.split(" ", 1)
    if len(parts) == 2 and parts[0].lower() == "bearer":
        return parts[1].strip()
    return None

async def verify_request(request: Request) -> bool:
    """
    Verify Firebase authentication for an incoming request.

    Accept either:
      - Firebase Auth id_token via Authorization: Bearer <id_token>
      - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)

    Returns:
        True whenever the request is allowed to proceed. After a successful
        id_token check the decoded claims are stored on ``request.state.user``.

    Raises:
        HTTPException: 401, only on the App Check path (token missing or
            failing verification).
    """
    # If Firebase is not initialized or auth is explicitly disabled, allow
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    # Try Firebase Auth id_token first if present
    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True
        except Exception as e:
            # A bad id_token is only logged here; the request is not rejected
            # at this point and continues to the App Check handling below.
            logger.warning("Auth token verification failed: %s", str(e))

    # If App Check is enabled, require valid App Check token
    if settings.ENABLE_APP_CHECK:
        app_check_token = request.headers.get("X-Firebase-AppCheck")
        if not app_check_token:
            raise HTTPException(status_code=401, detail="Missing App Check token")
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
            raise HTTPException(status_code=401, detail="Invalid App Check token")

    # Neither token required nor provided → allow (App Check disabled)
    # NOTE(review): a request whose Bearer token FAILED verification above also
    # reaches this line when App Check is disabled, so it is allowed through
    # as well — confirm that this fail-open behaviour is intended.
    return True

@app.get("/api")
async def api_info(request: Request):
    """Describe the service and the paths of its main endpoints.

    The call is recorded through ``log_api_call`` together with the caller's
    IP address and, when ``request.state.user`` is set, the user's uid.
    """
    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    user_id = request.state.user.get("uid") if has_user else None

    if request.client:
        caller_ip = request.client.host
    else:
        caller_ip = None

    response_data = {
        "app": "FastAI Image Colorizer API",
        "version": "1.0.0",
        "health": "/health",
        "colorize": "/colorize",
        "gradio": "/"
    }

    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=user_id,
        ip_address=caller_ip,
    )
    return response_data

@app.get("/health")
async def health_check(request: Request):
    """Report service health and which colorization backend is active.

    The status is always "healthy"; the remaining fields say whether a model
    is loaded, which kind, and (when loading failed) the recorded error. The
    call is recorded through ``log_api_call``.
    """
    model_loaded = learn is not None or pytorch_colorizer is not None

    # Pick the human-readable summary: a recorded load error takes precedence
    # over the plain "nothing loaded" case.
    if model_load_error:
        message = "Model failed to load. Using fallback colorization method."
    elif model_loaded:
        message = f"Model loaded successfully ({model_type})"
    else:
        message = "No model loaded. Using fallback colorization method."

    response = {
        "status": "healthy",
        "model_loaded": model_loaded,
        "model_type": model_type,
        "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model"),
        "using_fallback": not model_loaded
    }
    if model_load_error:
        response["model_error"] = model_load_error
    response["message"] = message

    caller_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=response,
        ip_address=caller_ip,
    )
    return response

def simple_colorize_fallback(image: Image.Image) -> Image.Image:
    """
    Heuristic colorization used when no trained model is available.

    The image is converted to LAB, the lightness channel is contrast-enhanced
    with CLAHE, a brightness-dependent offset is added to the a and b
    channels, and finally the saturation of the result is multiplied by 1.2.
    This is a simple heuristic and will not match the output of a trained
    model.

    Args:
        image: Input PIL image; converted to RGB first if it is in another mode.

    Returns:
        A new RGB PIL image.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # RGB -> LAB, then work on the three channels separately.
    lab = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)

    # Enhance lightness with CLAHE (Contrast Limited Adaptive Histogram Equalization).
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lightness_enhanced = clahe.apply(lightness)

    # Weight in [0, 1] derived from the ORIGINAL lightness (scaled to 0..1):
    # 0 up to 0.3, then rising linearly and saturating at 1 from 0.8 upwards.
    brightness_mask = np.clip((lightness.astype(np.float32) / 255.0 - 0.3) * 2, 0, 1)

    # Channel offsets. Both are always positive: a gets +2 (mask 0) to +8
    # (mask 1), b gets +3 to +12, so brighter pixels receive the stronger
    # shift. The float -> uint8 cast truncates after clipping to 0..255.
    a_hint = np.clip(
        a_chan.astype(np.float32) + brightness_mask * 8 + (1 - brightness_mask) * 2, 0, 255
    ).astype(np.uint8)
    b_hint = np.clip(
        b_chan.astype(np.float32) + brightness_mask * 12 + (1 - brightness_mask) * 3, 0, 255
    ).astype(np.uint8)

    # Merge channels and convert back to RGB.
    lab_colored = cv2.merge([lightness_enhanced, a_hint, b_hint])
    colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)

    # Apply a slight saturation boost in HSV space.
    hsv = cv2.cvtColor(colored_rgb, cv2.COLOR_RGB2HSV)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1].astype(np.float32) * 1.2, 0, 255).astype(np.uint8)
    colored_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    return Image.fromarray(colored_rgb)


def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a channel-first image tensor (3-D, or 4-D with a batch axis) to PIL."""
    if tensor.dim() == 4:
        # Batched output: keep the first item only.
        tensor = tensor[0]
    if tensor.dim() != 3:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

    # detach() so a tensor that still tracks gradients can be turned into a
    # NumPy array; permute moves the channel axis last.
    tensor = tensor.detach().permute(1, 2, 0).cpu()
    if tensor.is_floating_point():
        # Covers every float dtype (float16/32/64, bfloat16), not just two.
        # assumes float predictions are scaled to [0, 1] — TODO confirm
        # against the model's actual output range.
        tensor = (torch.clamp(tensor, 0, 1) * 255).byte()
    elif tensor.dtype != torch.uint8:
        # Other integer dtypes are clamped into the 8-bit range.
        tensor = tensor.clamp(0, 255).to(torch.uint8)
    return Image.fromarray(tensor.numpy(), 'RGB')


def colorize_pil(image: Image.Image) -> Image.Image:
    """Colorize *image* with the best available backend and return an RGB image.

    Backends are tried in order: the FastAI learner, then the PyTorch
    colorizer, then the heuristic ``simple_colorize_fallback``.

    Args:
        image: Input PIL image.

    Returns:
        The colorized PIL image. On the FastAI path the result is always
        converted to RGB; the other paths return whatever their backend yields.

    Raises:
        ValueError: If the FastAI prediction is neither a PIL image nor a
            tensor, or is a tensor that is not 3-D after removing a batch axis.
    """
    if learn is None:
        if pytorch_colorizer is not None:
            return pytorch_colorizer.colorize(image)
        # Final fallback: simple colorization
        logger.info("No model loaded, using enhanced colorization fallback (LAB color space method)")
        return simple_colorize_fallback(image)

    if image.mode != "RGB":
        image = image.convert("RGB")
    pred = learn.predict(image)

    # predict() may hand back a sequence; its first element is used as the
    # image, and an empty sequence falls back to the (RGB) input.
    if isinstance(pred, (list, tuple)):
        colorized = pred[0] if len(pred) > 0 else image
    else:
        colorized = pred

    if isinstance(colorized, torch.Tensor):
        colorized = _tensor_to_pil(colorized)
    elif not isinstance(colorized, Image.Image):
        raise ValueError(f"Unexpected prediction type: {type(colorized)}")

    if colorized.mode != "RGB":
        colorized = colorized.convert("RGB")
    return colorized

@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    verified: bool = Depends(verify_request)
):
    """
    Upload a black & white image -> returns colorized image.

    Access control is delegated to the ``verify_request`` dependency.
    Colorization is attempted even when no model is loaded, because
    ``colorize_pil`` falls back to a heuristic method in that case.

    Args:
        request: Incoming request; used for the caller's IP address and, when
            present, the user stored on ``request.state.user``.
        file: Uploaded image; its content type must start with ``image/``.
        verified: Result of ``verify_request``. Not used in the body; declared
            so that the dependency runs.

    Returns:
        A ``FileResponse`` for the colorized PNG written under ``RESULT_DIR``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            cannot be decoded as one; 500 for any other failure.
    """
    import time
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    user_id = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        user_id = request.state.user.get("uid")

    ip_address = request.client.host if request.client else None

    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()

        # Bytes that cannot be decoded are the client's mistake, so they are
        # answered with 400 instead of falling through to the generic 500.
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, ValueError) as decode_error:
            logger.warning("Uploaded file could not be decoded as an image: %s", decode_error)
            log_api_call(
                endpoint="/colorize",
                method="POST",
                status_code=400,
                error="Invalid image file",
                user_id=user_id,
                ip_address=ip_address
            )
            raise HTTPException(status_code=400, detail="Invalid image file") from decode_error

        logger.info("Colorizing image...")
        colorized = colorize_pil(image)

        processing_time = time.perf_counter() - start_time

        result_id = str(uuid.uuid4())
        output_filename = f"{result_id}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")

        logger.info("Colorized image saved: %s", output_filename)

        # Log to MongoDB
        log_colorization(
            result_id=result_id,
            model_type=model_type,
            processing_time=processing_time,
            user_id=user_id,
            ip_address=ip_address
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data={"result_id": result_id, "filename": output_filename},
            user_id=user_id,
            ip_address=ip_address
        )

        # Return the image file
        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"colorized_{output_filename}"
        )
    except HTTPException:
        # Already a deliberate HTTP error (the 400 above): pass it through
        # untouched rather than re-wrapping it as a 500.
        raise
    except Exception as e:
        error_msg = str(e)
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error colorizing image: %s", error_msg)
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}") from e

# ==========================================================
# Gradio Interface (for Space UI)
# ==========================================================
def gradio_colorize(image):
    """Callback for the Gradio UI: colorize *image*, or return None.

    None is returned both when no image was supplied and when colorization
    raised; in the latter case the error is logged.
    """
    if image is not None:
        try:
            # colorize_pil picks the backend itself, including the fallback.
            return colorize_pil(image)
        except Exception as exc:
            logger.error("Gradio colorization error: %s", str(exc))
    return None

# Heading and blurb displayed by the Gradio page.
title = "🎨 Image Colorizer"
description = (
    "Upload a black & white photo to generate a colorized version. "
    "Uses AI model when available, otherwise uses enhanced colorization fallback."
)

# Single-function UI: one PIL image in, one PIL image out.
iface = gr.Interface(
    gradio_colorize,
    gr.Image(type="pil", label="Upload B&W Image"),
    gr.Image(type="pil", label="Colorized Image"),
    title=title,
    description=description,
)

# Attach the Gradio UI at the root path (this is what the Space shows).
# Because it takes over "/", API information is served from /api instead.
app = gr.mount_gradio_app(app, iface, path="/")

# ==========================================================
# Run Server
# ==========================================================
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    uvicorn.run(app, port=int(os.getenv("PORT", "7860")), host="0.0.0.0")