"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""
from __future__ import annotations
import logging
import os
from typing import Tuple
# Ensure cache directory is set before any HF imports
# (main.py should have set these, but ensure they're set here too)
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir
import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download
from app.config import settings
logger = logging.getLogger(__name__)


def _ensure_cache_dir() -> str:
    cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except Exception as exc:
        logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
    # Ensure all cache env vars point to this directory
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    os.environ["XDG_CACHE_HOME"] = cache_dir
    return cache_dir


class ColorizeModel:
    """Colorization model using FastAI GAN model."""

    def __init__(self, model_id: str | None = None) -> None:
        self.cache_dir = _ensure_cache_dir()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        # Use FastAI model ID from config or default
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")
        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            # Try using from_pretrained_fastai first
            try:
                self.learn = from_pretrained_fastai(self.model_id)
                logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
            except Exception as e1:
                logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
                # Fallback: manually download and load the model file.
                # Try common FastAI model file names.
                model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
                model_path = None
                for filename in model_filenames:
                    try:
                        model_path = hf_hub_download(
                            repo_id=self.model_id,
                            filename=filename,
                            cache_dir=self.cache_dir,
                            token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN"),
                        )
                        logger.info("Found model file: %s", filename)
                        break
                    except Exception:
                        continue
                if model_path and os.path.exists(model_path):
                    # Load the model using FastAI's load_learner
                    logger.info("Loading model from: %s", model_path)
                    self.learn = load_learner(model_path)
                    logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
                else:
                    # None of the known model file names exist in the repository; surface a clear error
                    raise RuntimeError(
                        f"Could not find model file in repository '{self.model_id}'. "
                        f"Tried: {', '.join(model_filenames)}. "
                        f"Original error: {str(e1)}"
                    )
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored for FastAI model (kept for API compatibility)

        Returns:
            Tuple of (colorized PIL Image, caption string)
        """
        try:
            original_size = image.size
            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            # FastAI predict expects a PIL Image
            logger.info("Running FastAI GAN colorization...")
            # Use the model's predict method.
            # FastAI predict for image models typically returns the output image directly
            # or as the first element of a tuple.
            prediction = self.learn.predict(image)
            # Extract the colorized image from prediction;
            # handle different return types from FastAI.
            if isinstance(prediction, (list, tuple)):
                # If tuple/list, first element is usually the prediction
                colorized = prediction[0] if len(prediction) > 0 else image
            else:
                # Direct return
                colorized = prediction
            # Ensure we have a PIL Image
            if not isinstance(colorized, Image.Image):
                # If it's a tensor, convert to PIL
                if isinstance(colorized, torch.Tensor):
                    if colorized.dim() == 4:
                        colorized = colorized[0]  # Remove batch dimension
                    if colorized.dim() == 3:
                        # Convert CHW to HWC and move to CPU
                        colorized = colorized.permute(1, 2, 0).cpu()
                        # Float tensors are assumed to be in [0, 1]: clamp and scale to 8-bit.
                        # uint8 tensors are already in [0, 255] and need no scaling.
                        if colorized.dtype == torch.float32 or colorized.dtype == torch.float16:
                            colorized = torch.clamp(colorized, 0, 1)
                            colorized = (colorized * 255).byte()
                        colorized = Image.fromarray(colorized.numpy(), 'RGB')
                    else:
                        raise ValueError(f"Unexpected tensor shape: {colorized.shape}")
                else:
                    raise ValueError(f"Unexpected prediction type: {type(colorized)}")
            # Ensure RGB mode
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            # Resize back to original size if needed
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.error("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e