"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""

from __future__ import annotations

import logging
import os
from typing import Tuple

# Ensure cache directory is set before any HF imports
# (main.py should have set these, but ensure they're set here too)
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir

import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download

from app.config import settings

logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    # Ensure all cache env vars point to this directory
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    os.environ["XDG_CACHE_HOME"] = cache_dir
    return cache_dir


class ColorizeModel:
    """Colorization model using FastAI GAN model."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # Use FastAI model ID from config or default
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")

        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            # Try using from_pretrained_fastai first
            try:
                self.learn = from_pretrained_fastai(self.model_id)
                logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            except Exception as e1:
                logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
                # Fallback: manually download and load the model file
                # Try common FastAI model file names
                model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
                model_path = None
                
                for filename in model_filenames:
                    try:
                        model_path = hf_hub_download(
                            repo_id=self.model_id,
                            filename=filename,
                            cache_dir=self.cache_dir,
                            token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
                        )
                        logger.info("Found model file: %s", filename)
                        break
                    except Exception:
                        continue
                
                if model_path and os.path.exists(model_path):
                    # Load the model using FastAI's load_learner
                    logger.info("Loading model from: %s", model_path)
                    self.learn = load_learner(model_path)
                    logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
                else:
                    # No model file was found; report the filenames that were tried
                    raise RuntimeError(
                        f"Could not find model file in repository '{self.model_id}'. "
                        f"Tried: {', '.join(model_filenames)}. "
                        f"Original error: {str(e1)}"
                    )
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.
        
        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored for FastAI model (kept for API compatibility)
            
        Returns:
            Tuple of (colorized PIL Image, caption string)
        """
        try:
            original_size = image.size
            
            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            
            # FastAI predict expects a PIL Image
            logger.info("Running FastAI GAN colorization...")
            
            # Use the learner's predict method.
            # fastai's Learner.predict usually returns a tuple whose first element
            # is the decoded prediction, but handle a direct image return as well.
            prediction = self.learn.predict(image)
            
            # Extract the colorized image from prediction
            # Handle different return types from FastAI
            if isinstance(prediction, (list, tuple)):
                # If tuple/list, first element is usually the prediction
                colorized = prediction[0] if len(prediction) > 0 else image
            else:
                # Direct return
                colorized = prediction
            
            # Ensure we have a PIL Image
            if not isinstance(colorized, Image.Image):
                # If it's a tensor, convert to PIL
                if isinstance(colorized, torch.Tensor):
                    # Handle tensor conversion
                    if colorized.dim() == 4:
                        colorized = colorized[0]  # Remove batch dimension
                    if colorized.dim() == 3:
                        # Convert CHW to HWC and move off the GPU
                        colorized = colorized.permute(1, 2, 0).cpu()
                        # Float tensors are assumed to lie in [0, 1]; clamp and scale to uint8
                        if colorized.dtype == torch.float32 or colorized.dtype == torch.float16:
                            colorized = torch.clamp(colorized, 0, 1)
                            colorized = (colorized * 255).byte()
                        colorized = Image.fromarray(colorized.numpy(), 'RGB')
                    else:
                        raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
                else:
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
            
            # Ensure RGB mode
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            
            # Resize back to original size if needed
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
            
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e