File size: 10,386 Bytes
60c56d7
5e6062c
 
60c56d7
8f6f449
 
 
60c56d7
7471c96
8f6f449
 
6108abf
 
 
 
 
 
 
 
 
60c56d7
 
5e6062c
a2d6cd7
8f6f449
2ae242d
60c56d7
 
 
8f6f449
 
6108abf
8d0a1ae
 
5e6062c
8d0a1ae
6108abf
8d0a1ae
 
 
 
6108abf
8d0a1ae
 
 
60c56d7
5e6062c
8f6f449
 
8d0a1ae
f79a7fe
7471c96
 
5e6062c
 
 
 
 
 
80080e1
 
 
 
 
 
 
a2d6cd7
 
 
 
 
0454a91
a2d6cd7
0454a91
 
 
a2d6cd7
 
0454a91
 
 
 
 
 
 
 
 
a2d6cd7
 
 
0454a91
 
80080e1
a2d6cd7
80080e1
 
 
 
 
 
a2d6cd7
80080e1
 
0454a91
 
 
 
 
80080e1
a2d6cd7
 
80080e1
 
 
0454a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80080e1
a2d6cd7
80080e1
 
 
 
 
5e6062c
 
 
 
 
 
 
 
 
f79a7fe
8d0a1ae
5e6062c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
Colorization model wrapper around a FastAI GAN colorization learner loaded
from the Hugging Face Hub (e.g. Hammad712/GAN-Colorization-Model).
"""

from __future__ import annotations

import logging
import os
from typing import Tuple

# Point every Hugging Face / XDG cache variable at a single directory before
# any HF library is imported. main.py is expected to do this as well; doing it
# again here keeps this module correct when imported on its own.
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)

import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files

from app.config import settings

logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    # Ensure all cache env vars point to this directory
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    os.environ["XDG_CACHE_HOME"] = cache_dir
    return cache_dir


class ColorizeModel:
    """Colorization model backed by a FastAI learner from the Hugging Face Hub.

    Loading tries ``from_pretrained_fastai`` first and, if that fails, lists the
    repository and downloads a ``.pkl`` (or ``.pt``) file manually. Only FastAI
    ``.pkl`` exports can be loaded; finding a ``.pt`` file raises an error.
    """

    # Filenames probed when the repository listing fails or lists no
    # .pkl/.pt files.
    _FALLBACK_FILENAMES = (
        "model.pkl",
        "export.pkl",
        "learner.pkl",
        "model_export.pkl",
        "generator.pt",
    )

    def __init__(self, model_id: str | None = None) -> None:
        """Resolve the model id and load the FastAI learner into ``self.learn``.

        Args:
            model_id: Hugging Face repo id; falls back to ``settings.MODEL_ID``.

        Raises:
            RuntimeError: If the model cannot be loaded by either strategy.
        """
        self.cache_dir = _ensure_cache_dir()
        # NOTE(review): not used anywhere in this class; the learner is never
        # moved to this device here — confirm that is intended.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch is already imported at this point, so this likely
        # does not change torch's own thread pool — confirm intent.
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # Use FastAI model ID from config or default
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")

        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            try:
                # NOTE(review): presumably unpickles the learner like
                # load_learner does — only use trusted repositories.
                self.learn = from_pretrained_fastai(self.model_id)
                logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            except Exception as e1:
                logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
                self.learn = self._load_from_repo_files(e1)
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _load_from_repo_files(self, original_error: Exception):
        """Download the first available model file from the repo and load it.

        Args:
            original_error: The ``from_pretrained_fastai`` failure, quoted in the
                error message when no model file can be found.

        Returns:
            The loaded FastAI learner.

        Raises:
            RuntimeError: If no candidate file could be downloaded, or the file
                found is a PyTorch ``.pt`` file rather than a FastAI export.
        """
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        candidates = self._candidate_filenames(hf_token)
        model_path, filename = self._download_first(candidates, hf_token)

        if not model_path or not os.path.exists(model_path):
            raise RuntimeError(
                f"Could not find model file in repository '{self.model_id}'. "
                f"Tried: {', '.join(candidates)}. "
                f"Original error: {str(original_error)}"
            )

        # Every candidate name ends in .pkl or .pt, so the extension decides the type.
        if filename.endswith(".pt"):
            logger.info("Loading PyTorch model from: %s", model_path)
            logger.warning("PyTorch model loading not fully implemented. This model may not work correctly.")
            raise RuntimeError(
                f"Repository '{self.model_id}' contains a PyTorch model ({filename}), "
                f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
                f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
            )

        logger.info("Loading FastAI model from: %s", model_path)
        # SECURITY: load_learner unpickles the file, which can execute arbitrary
        # code — only point MODEL_ID at repositories you trust.
        learn = load_learner(model_path)
        logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
        return learn

    def _candidate_filenames(self, hf_token: str | None) -> list:
        """Return the model filenames to try: repo .pkl files, else .pt files, else defaults."""
        try:
            repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
        except Exception as list_err:
            logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
            return list(self._FALLBACK_FILENAMES)

        logger.info("Repository files: %s", repo_files)
        pkl_files = [f for f in repo_files if f.endswith(".pkl")]
        if pkl_files:
            logger.info("Found .pkl files in repository: %s", pkl_files)
            return pkl_files
        pt_files = [f for f in repo_files if f.endswith(".pt")]
        if pt_files:
            logger.info("Found .pt files in repository: %s", pt_files)
            return pt_files
        return list(self._FALLBACK_FILENAMES)

    def _download_first(self, filenames, hf_token: str | None) -> Tuple[str | None, str | None]:
        """Download the first of ``filenames`` that succeeds; return ``(path, filename)`` or ``(None, None)``."""
        for filename in filenames:
            try:
                path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=filename,
                    cache_dir=self.cache_dir,
                    token=hf_token,
                )
            except Exception as dl_err:
                logger.debug("Failed to download %s: %s", filename, str(dl_err))
                continue
            logger.info("Found model file: %s", filename)
            return path, filename
        return None, None

    @staticmethod
    def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
        """Convert a model output tensor to a PIL image.

        Accepts ``(C, H, W)`` or ``(N, C, H, W)`` (only the first batch item is
        used) with ``C`` in {1, 3, 4}. Floating-point values are assumed to lie
        in [0, 1]; integer values are assumed to already be on the 0-255 scale.
        One- and four-channel inputs come back as ``L``/``RGBA`` images.

        Raises:
            ValueError: If the tensor rank or channel count is unsupported.
        """
        # detach() so .numpy() also works on tensors that track gradients.
        tensor = tensor.detach().cpu()
        if tensor.dim() == 4:
            tensor = tensor[0]  # Remove batch dimension
        if tensor.dim() != 3 or tensor.shape[0] not in (1, 3, 4):
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

        if tensor.is_floating_point():
            # Any float dtype (incl. float64/bfloat16): clamp to [0, 1], scale to bytes.
            tensor = (torch.clamp(tensor, 0, 1) * 255).byte()
        elif tensor.dtype != torch.uint8:
            # Other integer dtypes (e.g. int64): clamp into the byte range.
            tensor = tensor.long().clamp(0, 255).byte()

        if tensor.shape[0] == 1:
            array = tensor[0].numpy()  # HxW -> mode "L"
        else:
            array = tensor.permute(1, 2, 0).contiguous().numpy()  # CHW -> HWC
        # Mode is inferred from the uint8 array shape: L, RGB, or RGBA.
        return Image.fromarray(array)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored for FastAI model (kept for API compatibility)

        Returns:
            Tuple of (colorized RGB PIL Image at the input's size, caption string)

        Raises:
            RuntimeError: Wrapping any failure during prediction or conversion.
        """
        try:
            original_size = image.size

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)

            # FastAI may return the output directly or as the first element of a tuple/list.
            if isinstance(prediction, (list, tuple)):
                if not prediction:
                    # Previously the input was returned unchanged here, reporting
                    # an uncolorized image as success.
                    raise ValueError("Model returned an empty prediction")
                colorized = prediction[0]
            else:
                colorized = prediction

            if not isinstance(colorized, Image.Image):
                if not isinstance(colorized, torch.Tensor):
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
                colorized = self._tensor_to_pil(colorized)

            # Ensure RGB mode
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")

            # Resize back to original size if needed
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

            logger.info("Colorization completed successfully")
            return colorized, self.output_caption

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e