Fix cache directory permissions and OMP_NUM_THREADS warnings - Set HF cache env vars before imports - Set MPLCONFIGDIR for matplotlib - Fix OMP_NUM_THREADS in Dockerfile
6108abf
| """ | |
| Colorize model wrapper using FastAI GAN Colorization Model | |
| Hammad712/GAN-Colorization-Model | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from typing import Tuple | |
# Point every Hugging Face / XDG cache variable at a single directory *before*
# the HF-related imports below run, so those libraries see this cache location.
# main.py should already have done this; repeating it keeps the module correct
# when it is imported on its own.
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
| import torch | |
| from PIL import Image | |
| from fastai.vision.all import * | |
| from huggingface_hub import from_pretrained_fastai | |
| from app.config import settings | |
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
| def _ensure_cache_dir() -> str: | |
| cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache") | |
| try: | |
| os.makedirs(cache_dir, exist_ok=True) | |
| except Exception as exc: | |
| logger.warning("Could not create cache directory %s: %s", cache_dir, exc) | |
| # Ensure all cache env vars point to this directory | |
| os.environ["HF_HOME"] = cache_dir | |
| os.environ["TRANSFORMERS_CACHE"] = cache_dir | |
| os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir | |
| os.environ["HF_HUB_CACHE"] = cache_dir | |
| os.environ["XDG_CACHE_HOME"] = cache_dir | |
| return cache_dir | |
class ColorizeModel:
    """Colorization model using FastAI GAN model.

    Loads a FastAI learner from the Hugging Face Hub on construction and
    exposes :meth:`colorize`, which returns an RGB image resized back to the
    input's dimensions together with a caption string.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Load the FastAI GAN colorization learner.

        Args:
            model_id: Hub repository id of the FastAI model. Falls back to
                ``settings.MODEL_ID`` when ``None`` or empty.

        Raises:
            RuntimeError: If the model cannot be loaded; the original exception
                is chained as ``__cause__``.
        """
        self.cache_dir = _ensure_cache_dir()
        # NOTE(review): not used inside this class and the learner is not moved
        # to this device here; presumably kept for callers — confirm.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch is already imported at this point, so this may be
        # too late to influence its thread pool; the Dockerfile is the reliable
        # place to set OMP_NUM_THREADS — confirm whether this is still needed.
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # Use the FastAI model ID from the argument, else from config.
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")

        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            self.learn = from_pretrained_fastai(self.model_id)
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {e}\n"
                "Please check the MODEL_ID environment variable. "
                "Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            # logger.exception also records the traceback of the load failure.
            logger.exception(error_msg)
            raise RuntimeError(error_msg) from e
        logger.info("FastAI GAN Colorization model loaded successfully")

    @staticmethod
    def _tensor_to_hwc_uint8(tensor):
        """Convert a CHW (or batched NCHW) image tensor to an HWC ``uint8`` NumPy array.

        Floating-point tensors of any precision are treated as intensities in
        [0, 1], scaled to [0, 255] and rounded. Integer tensors of any width
        are treated as already being in [0, 255]. Out-of-range values are
        clamped in both cases. A single-channel result is squeezed to HxW so
        Pillow infers a grayscale image instead of mis-reading the buffer.

        Raises:
            ValueError: If the tensor is not 3-D after dropping a batch dimension.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]  # take the first item of the batch
        if tensor.dim() != 3:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

        # detach(): .numpy() refuses tensors that require grad.
        hwc = tensor.detach().permute(1, 2, 0).cpu()
        if hwc.is_floating_point():
            # Upcast first so float16/bfloat16 support clamp/round everywhere.
            hwc = hwc.float().clamp(0, 1).mul(255).round()
        else:
            # Wider integer types (e.g. int64) must be narrowed to uint8;
            # handing them to Pillow as-is reinterprets their raw bytes.
            hwc = hwc.clamp(0, 255)
        array = hwc.to(torch.uint8).numpy()
        if array.shape[-1] == 1:
            array = array[..., 0]
        return array

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.

        Args:
            image: PIL Image (grayscale or color); converted to RGB before
                inference.
            num_inference_steps: Ignored for FastAI model (kept for API
                compatibility).

        Returns:
            Tuple of (colorized RGB PIL Image at the input's original size,
            caption string).

        Raises:
            RuntimeError: If prediction or post-processing fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            original_size = image.size
            if image.mode != "RGB":
                image = image.convert("RGB")

            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)

            # FastAI predict may return the output image directly or as the
            # first element of a tuple/list.
            if isinstance(prediction, (list, tuple)):
                colorized = prediction[0] if len(prediction) > 0 else image
            else:
                colorized = prediction

            if not isinstance(colorized, Image.Image):
                if not isinstance(colorized, torch.Tensor):
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
                # Pillow infers the mode (RGB, RGBA, L, ...) from the uint8
                # array; any non-RGB mode is normalised just below.
                colorized = Image.fromarray(self._tensor_to_hwc_uint8(colorized))

            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.exception("Error during colorization: %s", e)
            raise RuntimeError(f"Colorization failed: {e}") from e