File size: 10,386 Bytes
60c56d7 5e6062c 60c56d7 8f6f449 60c56d7 7471c96 8f6f449 6108abf 60c56d7 5e6062c a2d6cd7 8f6f449 2ae242d 60c56d7 8f6f449 6108abf 8d0a1ae 5e6062c 8d0a1ae 6108abf 8d0a1ae 6108abf 8d0a1ae 60c56d7 5e6062c 8f6f449 8d0a1ae f79a7fe 7471c96 5e6062c 80080e1 a2d6cd7 0454a91 a2d6cd7 0454a91 a2d6cd7 0454a91 a2d6cd7 0454a91 80080e1 a2d6cd7 80080e1 a2d6cd7 80080e1 0454a91 80080e1 a2d6cd7 80080e1 0454a91 80080e1 a2d6cd7 80080e1 5e6062c f79a7fe 8d0a1ae 5e6062c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
"""
Colorize model wrapper using FastAI GAN Colorization Model
Hammad712/GAN-Colorization-Model
"""
from __future__ import annotations
import logging
import os
from typing import Tuple
# Point every Hugging Face / XDG cache variable at one writable directory.
# This must run before torch, fastai and huggingface_hub are imported below,
# because those libraries may read their cache location at import time.
# main.py is expected to set these already; repeating it here keeps this
# module safe to import on its own.
# An empty HF_HOME is treated like an unset one; otherwise every cache
# variable would be set to the empty string.
cache_dir = os.environ.get("HF_HOME") or "/tmp/hf_cache"
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_HUB_CACHE",
            "XDG_CACHE_HOME",
        ),
        cache_dir,
    )
)
import torch
from PIL import Image
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
from app.config import settings
logger = logging.getLogger(__name__)
def _ensure_cache_dir() -> str:
cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
try:
os.makedirs(cache_dir, exist_ok=True)
except Exception as exc:
logger.warning("Could not create cache directory %s: %s", cache_dir, exc)
# Ensure all cache env vars point to this directory
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["HF_HUB_CACHE"] = cache_dir
os.environ["XDG_CACHE_HOME"] = cache_dir
return cache_dir
class ColorizeModel:
    """Colorization model backed by a FastAI GAN learner from the Hugging Face Hub.

    Attributes:
        cache_dir: Directory used for Hugging Face downloads.
        device: CUDA device when available, otherwise CPU. NOTE(review): this
            value is computed but never used to place the learner; confirm
            whether the learner should be moved to it.
        model_id: Hugging Face repository id of the model.
        output_caption: Caption returned alongside every colorized image.
        learn: The loaded FastAI learner.
    """

    # Filenames tried when the repository listing fails or contains neither
    # ``.pkl`` nor ``.pt`` files.
    _FALLBACK_FILENAMES = (
        "model.pkl",
        "export.pkl",
        "learner.pkl",
        "model_export.pkl",
        "generator.pt",
    )

    def __init__(self, model_id: str | None = None) -> None:
        """Load the learner for ``model_id`` (defaults to ``settings.MODEL_ID``).

        Args:
            model_id: Hugging Face repository id; falls back to
                ``settings.MODEL_ID`` when falsy.

        Raises:
            RuntimeError: If the model cannot be loaded by any strategy.
        """
        self.cache_dir = _ensure_cache_dir()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch is already imported when this runs, so this
        # setdefault probably does not resize torch's own thread pool; it only
        # affects code that reads the variable afterwards.
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        self.model_id = model_id or settings.MODEL_ID
        self.output_caption = getattr(settings, "FASTAI_OUTPUT_CAPTION", "Colorized using GAN-Colorization-Model")
        logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
        try:
            self.learn = self._load_learner()
        except Exception as e:
            error_msg = (
                f"Failed to load FastAI model '{self.model_id}'. "
                f"Error: {str(e)}\n"
                f"Please check the MODEL_ID environment variable. "
                f"Default model: 'Hammad712/GAN-Colorization-Model'"
            )
            logger.exception(error_msg)
            raise RuntimeError(error_msg) from e

    def _load_learner(self):
        """Return a learner, trying ``from_pretrained_fastai`` before a manual download."""
        try:
            learn = from_pretrained_fastai(self.model_id)
        except Exception as e1:
            logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
            return self._load_learner_from_repo_files(e1)
        logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
        return learn

    def _candidate_filenames(self, hf_token: str | None) -> list[str]:
        """List the repository's model files, preferring ``.pkl`` over ``.pt``.

        Falls back to ``_FALLBACK_FILENAMES`` when listing fails or the
        repository contains no ``.pkl``/``.pt`` files.
        """
        try:
            repo_files = list_repo_files(repo_id=self.model_id, token=hf_token)
        except Exception as list_err:
            logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
            return list(self._FALLBACK_FILENAMES)
        logger.info("Repository files: %s", repo_files)
        pkl_files = [f for f in repo_files if f.endswith(".pkl")]
        if pkl_files:
            logger.info("Found .pkl files in repository: %s", pkl_files)
            return pkl_files
        pt_files = [f for f in repo_files if f.endswith(".pt")]
        if pt_files:
            logger.info("Found .pt files in repository: %s", pt_files)
            return pt_files
        return list(self._FALLBACK_FILENAMES)

    def _download_first(self, filenames: list[str], hf_token: str | None) -> tuple[str, str] | None:
        """Download the first available file; return ``(filename, local_path)`` or None."""
        for filename in filenames:
            try:
                model_path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=filename,
                    cache_dir=self.cache_dir,
                    token=hf_token,
                )
            except Exception as dl_err:
                logger.debug("Failed to download %s: %s", filename, str(dl_err))
                continue
            logger.info("Found model file: %s", filename)
            return filename, model_path
        return None

    def _load_learner_from_repo_files(self, original_error: Exception):
        """Download a model file from the repository and load it with ``load_learner``.

        Args:
            original_error: The ``from_pretrained_fastai`` failure, included in
                the error message if no file can be found.

        Raises:
            RuntimeError: If no model file can be downloaded, or the file found
                is a plain PyTorch ``.pt`` file rather than a FastAI export.
        """
        hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
        model_filenames = self._candidate_filenames(hf_token)
        found = self._download_first(model_filenames, hf_token)
        if found is None or not os.path.exists(found[1]):
            raise RuntimeError(
                f"Could not find model file in repository '{self.model_id}'. "
                f"Tried: {', '.join(model_filenames)}. "
                f"Original error: {str(original_error)}"
            )
        filename, model_path = found
        if filename.endswith(".pt"):
            # A bare .pt generator needs its architecture to be rebuilt in
            # code, which this wrapper does not do, so refuse it explicitly.
            logger.info("Loading PyTorch model from: %s", model_path)
            raise RuntimeError(
                f"Repository '{self.model_id}' contains a PyTorch model ({filename}), "
                f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
                f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
            )
        logger.info("Loading FastAI model from: %s", model_path)
        # SECURITY: load_learner deserializes a pickle, which can execute
        # arbitrary code. Only point MODEL_ID at repositories you trust.
        learn = load_learner(model_path)
        logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
        return learn

    @staticmethod
    def _prediction_to_image(prediction) -> Image.Image:
        """Extract a PIL image from the value returned by ``Learner.predict``.

        Accepts a PIL image or a tensor, or a list/tuple whose first element
        is one of those.

        Raises:
            ValueError: For an empty prediction, an unsupported type, or a
                tensor that is not 3-D (or 4-D with a batch dimension).
        """
        if isinstance(prediction, (list, tuple)):
            if not prediction:
                raise ValueError("Empty prediction returned by model")
            prediction = prediction[0]
        if isinstance(prediction, Image.Image):
            return prediction
        if not isinstance(prediction, torch.Tensor):
            raise ValueError(f"Unexpected prediction type: {type(prediction)}")
        tensor = prediction.detach().cpu()
        if tensor.dim() == 4:
            tensor = tensor[0]  # drop batch dimension
        if tensor.dim() != 3:
            raise ValueError(f"Unexpected tensor shape: {tensor.shape}")
        # Assumes channel-first (C, H, W) layout as produced by FastAI — TODO confirm.
        if tensor.is_floating_point():
            # Float outputs are assumed to be in [0, 1]; scale to 8-bit.
            tensor = torch.clamp(tensor, 0, 1) * 255
        else:
            tensor = torch.clamp(tensor, 0, 255)
        array = tensor.to(torch.uint8).permute(1, 2, 0).numpy()
        if array.shape[2] == 1:
            # Single channel -> grayscale image; the caller converts it to RGB.
            array = array[:, :, 0]
        # Let Pillow infer the mode (L / RGB / RGBA) from the array shape.
        return Image.fromarray(array)

    def colorize(self, image: Image.Image, num_inference_steps: int | None = None) -> Tuple[Image.Image, str]:
        """
        Colorize a grayscale or color image using FastAI GAN model.

        Args:
            image: PIL Image (grayscale or color)
            num_inference_steps: Ignored for FastAI model (kept for API compatibility)

        Returns:
            Tuple of (colorized RGB PIL Image at the input's size, caption string)

        Raises:
            RuntimeError: If prediction or output conversion fails.
        """
        try:
            original_size = image.size
            if image.mode != "RGB":
                image = image.convert("RGB")
            logger.info("Running FastAI GAN colorization...")
            prediction = self.learn.predict(image)
            colorized = self._prediction_to_image(prediction)
            if colorized.mode != "RGB":
                colorized = colorized.convert("RGB")
            # The model may work at a fixed resolution; restore the input size.
            if colorized.size != original_size:
                colorized = colorized.resize(original_size, Image.Resampling.LANCZOS)
            logger.info("Colorization completed successfully")
            return colorized, self.output_caption
        except Exception as e:
            logger.exception("Error during colorization: %s", str(e))
            raise RuntimeError(f"Colorization failed: {str(e)}") from e
|