File size: 5,746 Bytes
60c56d7 5e6062c 60c56d7 8f6f449 60c56d7 7471c96 8f6f449 6108abf 60c56d7 5e6062c 8f6f449 2ae242d 60c56d7 8f6f449 6108abf 8d0a1ae 5e6062c 8d0a1ae 6108abf 8d0a1ae 6108abf 8d0a1ae 60c56d7 5e6062c 8f6f449 8d0a1ae f79a7fe 7471c96 5e6062c f79a7fe 8d0a1ae 5e6062c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""
from __future__ import annotations
import logging
import os
from typing import Tuple
# Point every Hugging Face / XDG cache variable at a single directory. This has
# to happen before the torch/fastai/huggingface_hub imports below so that those
# imports see the configured location. main.py is expected to have set these
# already; repeating it here keeps this module safe to import on its own.
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai
from app.config import settings
logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    """Create the Hugging Face cache directory and point all cache env vars at it.

    The directory is read from ``HF_HOME``, falling back to ``/tmp/hf_cache``.
    Creation is best-effort: any failure is logged as a warning rather than
    raised, and the environment variables are updated either way.

    Returns:
        The cache directory path that ``HF_HOME``, ``TRANSFORMERS_CACHE``,
        ``HUGGINGFACE_HUB_CACHE``, ``HF_HUB_CACHE`` and ``XDG_CACHE_HOME``
        now reference.
    """
    target = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(target, exist_ok=True)
    except Exception as exc:
        # Deliberately non-fatal: only a warning is emitted.
        logger.warning("Could not create cache directory %s: %s", target, exc)

    os.environ.update(
        dict.fromkeys(
            (
                "HF_HOME",
                "TRANSFORMERS_CACHE",
                "HUGGINGFACE_HUB_CACHE",
                "HF_HUB_CACHE",
                "XDG_CACHE_HOME",
            ),
            target,
        )
    )
    return target
class ColorizeModel:
    """Colorization model using a FastAI GAN model.

    ``__init__`` loads a FastAI learner from the Hugging Face Hub via
    ``from_pretrained_fastai``. :meth:`colorize` runs one image through
    ``Learner.predict`` and returns an RGB PIL image at the input's original
    size, together with a caption string.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Load the FastAI GAN colorization learner.

        Args:
            model_id: Hugging Face Hub repo id of the FastAI model. When
                ``None`` (or empty), ``settings.MODEL_ID`` is used.

        Raises:
            RuntimeError: If the model cannot be loaded. The underlying
                exception is chained as ``__cause__``.
        """
        self.cache_dir = _ensure_cache_dir()
        # NOTE(review): this device is recorded but never used below; the
        # learner is not moved to it here. Confirm whether GPU inference is
        # intended.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch is already imported when this runs, so this may
        # no longer influence torch's thread pool -- confirm the intent.
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # Use the FastAI model ID from the caller, else from config.
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(
            settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model"
        )

        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            # NOTE(review): FastAI learners are normally stored as pickles,
            # which can execute code when loaded -- ensure MODEL_ID only ever
            # points at trusted repositories.
            self.learn = from_pretrained_fastai(self.model_id)
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {e}\n"
                "Please check the MODEL_ID environment variable. "
                "Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e
        logger.info("FastAI GAN Colorization model loaded successfully")

    def colorize(
        self, image: Image.Image, num_inference_steps: int | None = None
    ) -> Tuple[Image.Image, str]:
        """Colorize a grayscale or color image using the FastAI GAN model.

        Args:
            image: PIL image (grayscale or color). Non-RGB images are converted
                to RGB before prediction.
            num_inference_steps: Ignored for the FastAI model; kept for API
                compatibility.

        Returns:
            Tuple of (colorized RGB PIL image at the input's original size,
            caption string).

        Raises:
            RuntimeError: On any failure during prediction or post-processing
                (the original exception is chained as ``__cause__``).
        """
        try:
            original_size = image.size
            if image.mode != "RGB":
                image = image.convert("RGB")

            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)
            colorized = self._to_pil(self._first_output(prediction))

            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            # The model output may not match the input resolution; restore the
            # caller's size so the result lines up with the original image.
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)

            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error during colorization: %s", e)
            raise RuntimeError(f"Colorization failed: {e}") from e

    @staticmethod
    def _first_output(prediction: object) -> object:
        """Return the primary output from a ``Learner.predict`` result.

        ``predict`` may return the output directly, or a tuple/list whose
        first element is the prediction.

        Raises:
            ValueError: If ``prediction`` is an empty tuple/list. Previously this
                silently returned the uncolored input image as the "result".
        """
        if isinstance(prediction, (list, tuple)):
            if not prediction:
                raise ValueError("Model returned an empty prediction")
            return prediction[0]
        return prediction

    @staticmethod
    def _to_pil(output: object) -> Image.Image:
        """Convert a model output (PIL image or image tensor) to a PIL image.

        Tensors must be ``(C, H, W)`` or ``(1, C, H, W)`` with C in {1, 3, 4}.
        Floating-point tensors of any precision are clamped to ``[0, 1]`` and
        scaled to ``[0, 255]``. Integer tensors are clamped to ``[0, 255]``.
        The PIL mode is inferred from the channel count (L / RGB / RGBA).

        Raises:
            ValueError: For unsupported output types or tensor shapes.
        """
        if isinstance(output, Image.Image):
            return output
        if not isinstance(output, torch.Tensor):
            raise ValueError(f"Unexpected prediction type: {type(output)}")

        # detach() so .numpy() also works on grad-tracking tensors.
        tensor = output.detach().cpu()
        if tensor.dim() == 4:
            tensor = tensor[0]  # Remove batch dimension
        if tensor.dim() != 3 or tensor.shape[0] not in (1, 3, 4):
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

        channels = tensor.shape[0]
        hwc = tensor.permute(1, 2, 0)  # CHW -> HWC
        if hwc.is_floating_point():
            # NOTE(review): assumes float outputs are already denormalised to
            # [0, 1]; a [0, 255] float tensor would saturate to white here.
            # .float() upcasts half/bfloat16 and normalises float64 handling.
            hwc = hwc.float().clamp(0, 1) * 255
        else:
            hwc = hwc.clamp(0, 255)
        array = hwc.to(torch.uint8).contiguous().numpy()
        if channels == 1:
            array = array[:, :, 0]  # (H, W) so Pillow infers mode "L"
        # Let Pillow infer the mode from the array shape instead of forcing
        # "RGB", which mis-reads 1- and 4-channel data.
        return Image.fromarray(array)
|