"""
ColorizeNet model wrapper for image colorization
"""
import logging
import os
import torch
import numpy as np
from PIL import Image
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLControlNetPipeline,
)
from app.config import settings

logger = logging.getLogger(__name__)

class ColorizeModel:
    """Wrapper for ColorizeNet model"""
    
    def __init__(self, model_id: str | None = None):
        """
        Initialize the ColorizeNet model
        
        Args:
            model_id: Hugging Face model ID for ColorizeNet
        """
        if model_id is None:
            model_id = settings.MODEL_ID
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", self.device)
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Check for a Hugging Face token (either environment variable name)
        self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

        # Configure a writable cache to avoid permission issues on Spaces.
        # Prefer DATA_DIR if available, otherwise fall back to /tmp.
        data_dir = os.getenv("DATA_DIR") or "/tmp"
        hf_cache_dir = os.path.join(data_dir, "hf_cache")
        
        # Try candidate cache locations in order: the DATA_DIR-based path,
        # /tmp (if DATA_DIR was set but not writable), then the user's home
        # cache (local dev). Point all HF cache environment variables at the
        # first directory we can actually create.
        def _set_cache_env(path: str) -> None:
            os.environ["HF_HOME"] = path
            os.environ["HUGGINGFACE_HUB_CACHE"] = path
            os.environ["TRANSFORMERS_CACHE"] = path

        cache_candidates = [
            hf_cache_dir,
            os.path.join("/tmp", "hf_cache"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
        ]
        last_error: Exception | None = None
        for candidate in cache_candidates:
            try:
                os.makedirs(candidate, exist_ok=True)
                hf_cache_dir = candidate
                _set_cache_env(hf_cache_dir)
                logger.info("HF cache directory: %s", hf_cache_dir)
                break
            except Exception as e:
                last_error = e
                logger.warning("Failed to create cache in %s: %s", candidate, str(e))
        else:
            logger.error("Failed to create cache directory: %s", str(last_error))
            raise RuntimeError(f"Cannot create Hugging Face cache directory: {last_error}")
        # Avoid libgomp warning by setting a valid integer
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        
        try:
            # Decide whether to use ControlNet based on model_id
            wants_controlnet = "control" in self.model_id.lower()

            if wants_controlnet:
                # Try loading as ControlNet with Stable Diffusion
                logger.info("Attempting to load model as ControlNet: %s", self.model_id)
                try:
                    # Load ControlNet model
                    self.controlnet = ControlNetModel.from_pretrained(
                        self.model_id,
                        torch_dtype=self.dtype,
                        token=self.hf_token,
                        cache_dir=hf_cache_dir
                    )
                    
                    # Try SDXL first, fallback to SD 1.5
                    try:
                        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                            "stabilityai/stable-diffusion-xl-base-1.0",
                            controlnet=self.controlnet,
                            torch_dtype=self.dtype,
                            safety_checker=None,
                            requires_safety_checker=False,
                            token=self.hf_token,
                            cache_dir=hf_cache_dir
                        )
                        logger.info("Loaded with SDXL base model")
                    except Exception:
                        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
                            "runwayml/stable-diffusion-v1-5",
                            controlnet=self.controlnet,
                            torch_dtype=self.dtype,
                            safety_checker=None,
                            requires_safety_checker=False,
                            token=self.hf_token,
                            cache_dir=hf_cache_dir
                        )
                        logger.info("Loaded with SD 1.5 base model")
                    
                    self.pipe.to(self.device)
                    
                    # Enable memory efficient attention if available
                    if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
                        try:
                            self.pipe.enable_xformers_memory_efficient_attention()
                            logger.info("XFormers memory efficient attention enabled")
                        except Exception as e:
                            logger.warning("Could not enable XFormers: %s", str(e))
                    
                    logger.info("ColorizeNet model loaded successfully as ControlNet")
                    self.model_type = "controlnet"
                except Exception as e:
                    logger.warning("Failed to load as ControlNet: %s", str(e))
                    wants_controlnet = False  # fall through to pipeline

            if not wants_controlnet:
                # Load as image-to-image pipeline
                logger.info("Trying to load as image-to-image pipeline...")
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=self.dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                    use_safetensors=True,
                    cache_dir=hf_cache_dir,
                    token=self.hf_token
                ).to(self.device)
                logger.info("ColorizeNet model loaded using image-to-image pipeline")
                self.model_type = "pipeline"
                
        except Exception as e:
            logger.error("Failed to load ColorizeNet model: %s", str(e))
            raise RuntimeError(f"Could not load ColorizeNet model: {e}") from e
    
    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """
        Preprocess image for colorization
        
        Args:
            image: PIL Image
            
        Returns:
            Preprocessed PIL Image
        """
        # Convert to grayscale if needed
        if image.mode != "L":
            # Convert to grayscale
            image = image.convert("L")
        
        # Convert back to RGB (grayscale image with 3 channels)
        image = image.convert("RGB")
        
        # Resize to the SD-native 512x512; non-square inputs are stretched
        # here and restored to their original size after inference
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
        
        return image
    
    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Image.Image:
        """
        Colorize a grayscale image
        
        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Number of inference steps (auto-adjusted for CPU/GPU)
            
        Returns:
            Colorized PIL Image
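
        Example (illustrative; "photo.jpg" is a placeholder path, not part
        of the application):
            >>> model = ColorizeModel()
            >>> colorized = model.colorize(Image.open("photo.jpg"))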
        """
        try:
            # Optimize inference steps based on device
            if num_inference_steps is None:
                # Use fewer steps on CPU for faster processing
                num_inference_steps = 8 if self.device == "cpu" else 20
            
            # Preprocess image
            control_image = self.preprocess_image(image)
            original_size = image.size
            
            # Prepare prompt for colorization
            prompt = "colorize this black and white image, high quality, detailed, vibrant colors, natural colors"
            negative_prompt = "black and white, grayscale, monochrome, low quality, blurry, desaturated"
            
            # Adjust guidance scale for CPU (lower = faster)
            guidance_scale = 5.0 if self.device == "cpu" else 7.5
            
            # Generate the colorized image based on model type
            if self.model_type == "controlnet":
                # Use the ControlNet pipeline with a fixed seed for reproducibility
                result = self.pipe(
                    prompt=prompt,
                    image=control_image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    controlnet_conditioning_scale=1.0,
                    generator=torch.Generator(device=self.device).manual_seed(42)
                )
            else:
                # Use the image-to-image pipeline directly
                result = self.pipe(
                    prompt=prompt,
                    image=control_image,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    strength=0.75  # below 1.0 so the output keeps the input's structure; 1.0 would re-noise the image entirely
                )

            # Extract the first image from the pipeline output
            # (diffusers pipelines return a dict-like BaseOutput with an "images" key)
            if isinstance(result, dict) and "images" in result:
                colorized = result["images"][0]
            elif isinstance(result, list) and len(result) > 0:
                colorized = result[0]
            else:
                colorized = result
            
            # Ensure we have a PIL Image
            if not isinstance(colorized, Image.Image):
                if isinstance(colorized, np.ndarray):
                    # Handle numpy array
                    if colorized.dtype != np.uint8:
                        colorized = (colorized * 255).astype(np.uint8)
                    if len(colorized.shape) == 3 and colorized.shape[2] == 3:
                        colorized = Image.fromarray(colorized, 'RGB')
                    else:
                        colorized = Image.fromarray(colorized)
                elif torch.is_tensor(colorized):
                    # Handle torch tensor
                    colorized = colorized.cpu().permute(1, 2, 0).numpy()
                    colorized = (colorized * 255).astype(np.uint8)
                    colorized = Image.fromarray(colorized, 'RGB')
                else:
                    raise ValueError(f"Unexpected output type: {type(colorized)}")
            
            # Resize back to original size
            if original_size != (512, 512):
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            
            return colorized
            
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise
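

# Minimal usage sketch (assumption: "input.jpg" and "output.png" are
# illustrative local paths, not part of the application code).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model = ColorizeModel()  # defaults to settings.MODEL_ID when no ID is given
    colorized = model.colorize(Image.open("input.jpg"))
    colorized.save("output.png")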