File size: 7,591 Bytes
60c56d7 5e6062c 60c56d7 8f6f449 60c56d7 7471c96 8f6f449 6108abf 60c56d7 5e6062c 80080e1 8f6f449 2ae242d 60c56d7 8f6f449 6108abf 8d0a1ae 5e6062c 8d0a1ae 6108abf 8d0a1ae 6108abf 8d0a1ae 60c56d7 5e6062c 8f6f449 8d0a1ae f79a7fe 7471c96 5e6062c 80080e1 5e6062c f79a7fe 8d0a1ae 5e6062c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
"""
Colorization model wrapper around a FastAI GAN colorization learner.

Reference model on the Hugging Face Hub: Hammad712/GAN-Colorization-Model
"""
from __future__ import annotations
import logging
import os
from typing import Tuple
# Point every Hugging Face / XDG cache variable at a single directory *before*
# torch, fastai and huggingface_hub are imported below, since those libraries
# may read these variables when they are imported. main.py is expected to have
# done this already; repeating it keeps this module safe to import on its own.
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download
from app.config import settings
logger = logging.getLogger(__name__)
def _ensure_cache_dir() -> str:
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
try:
os.makedirs(cache_dir, exist_ok=True)
except Exception as exc:
logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
# Ensure all cache env vars point to this directory
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir
return cache_dir
class ColorizeModel:
    """Colorization model backed by a FastAI GAN learner from the Hugging Face Hub.

    The learner is loaded once in ``__init__`` and reused by :meth:`colorize`.
    """

    # Candidate file names tried, in order, when ``from_pretrained_fastai`` fails.
    _MODEL_FILENAMES = ("model.pkl", "export.pkl", "learner.pkl", "model_export.pkl")

    def __init__(self, model_id: str | None = None) -> None:
        """Load the FastAI colorization learner.

        Loading first tries ``from_pretrained_fastai``; if that fails, it
        downloads one of ``_MODEL_FILENAMES`` from the repository and loads it
        with ``load_learner``.

        Args:
            model_id: Hugging Face Hub repository id. Falls back to
                ``settings.MODEL_ID`` when not given (or empty).

        Raises:
            RuntimeError: If the learner cannot be loaded by either strategy.
        """
        self.cache_dir = _ensure_cache_dir()
        # NOTE(review): ``self.device`` is not used anywhere in this class (the
        # learner is never moved to it) -- confirm whether that is intended.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch is already imported when this runs; whether this
        # still limits threads depends on when the OpenMP runtime initializes.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        # Use FastAI model ID from the argument, otherwise from config.
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")
        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            try:
                self.learn = from_pretrained_fastai(self.model_id)
                logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            except Exception as e1:
                logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
                self.learn = self._load_from_downloaded_file(e1)
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _download_model_file(self) -> str | None:
        """Download the first available candidate model file; return its local path or None."""
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        for filename in self._MODEL_FILENAMES:
            try:
                path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=filename,
                    cache_dir=self.cache_dir,
                    token=token,
                )
            except Exception as exc:
                # Missing candidates are expected; keep trying the others but
                # leave a trace so auth/network problems are diagnosable.
                logger.debug("Model file %s not available in %s: %s", filename, self.model_id, exc)
                continue
            logger.info("Found model file: %s", filename)
            return path
        return None

    def _load_from_downloaded_file(self, original_error: Exception):
        """Fallback loader: download a model file manually and load it with ``load_learner``.

        Raises:
            RuntimeError: If none of the candidate files could be downloaded.
        """
        model_path = self._download_model_file()
        if not model_path or not os.path.exists(model_path):
            raise RuntimeError(
                f"Could not find model file in repository '{self.model_id}'. "
                f"Tried: {', '.join(self._MODEL_FILENAMES)}. "
                f"Original error: {str(original_error)}"
            )
        logger.info("Loading model from: %s", model_path)
        # SECURITY: load_learner unpickles the downloaded file, and pickle can
        # execute arbitrary code -- only point MODEL_ID at trusted repositories.
        learn = load_learner(model_path)
        logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
        return learn

    @staticmethod
    def _extract_output(prediction, fallback):
        """Return the model output from a ``learn.predict`` result.

        ``predict`` may return the output directly or as the first element of
        a tuple/list; an empty tuple/list yields ``fallback``.
        """
        if isinstance(prediction, (list, tuple)):
            return prediction[0] if prediction else fallback
        return prediction

    @staticmethod
    def _to_uint8_hwc(tensor: torch.Tensor) -> torch.Tensor:
        """Convert a (C, H, W) or (1, C, H, W) image tensor to a contiguous uint8 (H, W, C) tensor.

        Float tensors (any precision) are assumed to be in [0, 1]; they are
        clamped and scaled to [0, 255]. Other non-uint8 tensors are assumed to
        already be in [0, 255] and are clamped to it. Decoded fastai image
        outputs may be int64 tensors, which Pillow cannot consume as RGB bytes.

        Raises:
            ValueError: If the tensor is not 3-D after dropping a batch dimension.
        """
        if tensor.dim() == 4:
            tensor = tensor[0]  # Remove batch dimension
        if tensor.dim() != 3:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
        hwc = tensor.detach().permute(1, 2, 0).cpu()
        if hwc.is_floating_point():
            hwc = (torch.clamp(hwc, 0, 1) * 255).to(torch.uint8)
        elif hwc.dtype != torch.uint8:
            hwc = torch.clamp(hwc, 0, 255).to(torch.uint8)
        # permute() yields a non-contiguous view; make the memory layout HWC.
        return hwc.contiguous()

    @classmethod
    def _tensor_to_pil(cls, tensor: torch.Tensor) -> Image.Image:
        """Convert an image tensor to a PIL image, letting Pillow infer the mode from the channel count."""
        pixels = cls._to_uint8_hwc(tensor).numpy()
        if pixels.shape[2] == 1:
            # A 2-D uint8 array becomes an "L" image; colorize() converts to RGB.
            pixels = pixels[:, :, 0]
        return Image.fromarray(pixels)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using the FastAI GAN model.

        Args:
            image: PIL image in any mode; it is converted to RGB before inference.
            num_inference_steps: Ignored for FastAI model (kept for API compatibility).

        Returns:
            Tuple of (colorized RGB PIL image at the input's size, caption string).

        Raises:
            RuntimeError: Wrapping any error raised during inference or conversion.
        """
        try:
            original_size = image.size
            if image.mode != "RGB":
                image = image.convert("RGB")
            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)
            colorized = self._extract_output(prediction, image)
            if not isinstance(colorized, Image.Image):
                if not isinstance(colorized, torch.Tensor):
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
                colorized = self._tensor_to_pil(colorized)
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            # The model may work at a different resolution; match the input size.
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e
|