"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""

from __future__ import annotations

import logging
import os
from typing import Tuple

# Ensure cache directory is set before any HF imports
# (main.py should have set these, but ensure they're set here too)
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir

import torch
from PIL import Image
from fastai.vision.all import *  # noqa: F401,F403 - broad import is conventional for fastai; keeps fastai types available when the Learner loads
from huggingface_hub import from_pretrained_fastai

from app.config import settings

logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    # Ensure all cache env vars point to this directory
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    os.environ["XDG_CACHE_HOME"] = cache_dir
    return cache_dir


class ColorizeModel:
    """Colorization model using FastAI GAN model."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # Use FastAI model ID from config or default
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")

        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            self.learn = from_pretrained_fastai(self.model_id)
            logger.info("FastAI GAN Colorization model loaded successfully")
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using the FastAI GAN model.
        
        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored by the FastAI model (kept for API compatibility)
            
        Returns:
            Tuple of (colorized PIL Image, caption string)
        """
        try:
            original_size = image.size
            
            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            
            # FastAI predict expects a PIL Image
            logger.info("Running FastAI GAN colorization...")
            
            # Use the model's predict method
            # FastAI predict for image models typically returns the output image directly
            # or as the first element of a tuple
            prediction = self.learn.predict(image)
            
            # Extract the colorized image from prediction
            # Handle different return types from FastAI
            if isinstance(prediction, (list, tuple)):
                # If tuple/list, first element is usually the prediction
                colorized = prediction[0] if len(prediction) > 0 else image
            else:
                # Direct return
                colorized = prediction
            
            # Ensure we have a PIL Image
            if not isinstance(colorized, Image.Image):
                # If it's a tensor, convert to PIL
                if isinstance(colorized, torch.Tensor):
                    # Handle tensor conversion
                    if colorized.dim() == 4:
                        colorized = colorized[0]  # Remove batch dimension
                    if colorized.dim() == 3:
                        # Convert CHW to HWC and denormalize if needed
                        colorized = colorized.permute(1, 2, 0).cpu()
                        # Clamp values to [0, 1] if float, or [0, 255] if uint8
                        if colorized.dtype == torch.float32 or colorized.dtype == torch.float16:
                            colorized = torch.clamp(colorized, 0, 1)
                            colorized = (colorized * 255).byte()
                        colorized = Image.fromarray(colorized.numpy(), 'RGB')
                    else:
                        raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
                else:
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
            
            # Ensure RGB mode
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            
            # Resize back to original size if needed
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
            
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e