Commit 6960c4e
Parent(s): 0517019

Add fallback logic to handle PyTorch models and find actual model files

- Add manual download fallback when from_pretrained_fastai fails
- List repository files to find .pkl or .pt files
- Provide clear error message for PyTorch models

Files changed: app/main.py (+86 -5)
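For reference, the fallback path added here can be exercised outside the app. The snippet below is a minimal standalone sketch of the same steps (happy path via from_pretrained_fastai, otherwise list the repo's files, download a .pkl, and open it with load_learner); it reuses the default repo id and token environment variables from the diff, but it is illustrative only, not code from this commit:

import os
from fastai.vision.all import load_learner
from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files

repo_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

try:
    # Happy path: the repo is a FastAI export that the Hub helper understands.
    learn = from_pretrained_fastai(repo_id)
    print("loaded via from_pretrained_fastai")
except Exception as err:
    # Fallback: inspect the repo and load a pickled Learner by hand.
    files = list_repo_files(repo_id=repo_id, token=token)
    pkl_files = [f for f in files if f.endswith(".pkl")]
    pt_files = [f for f in files if f.endswith(".pt")]
    if pkl_files:
        path = hf_hub_download(repo_id=repo_id, filename=pkl_files[0], token=token)
        learn = load_learner(path)
        print("loaded via manual download:", pkl_files[0])
    elif pt_files:
        print("repo ships raw PyTorch weights, not a FastAI export:", pt_files)
    else:
        print("no .pkl or .pt files found; original error:", err)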
app/main.py  CHANGED

@@ -31,7 +31,7 @@ import gradio as gr
 
 # FastAI imports
 from fastai.vision.all import *
-from huggingface_hub import from_pretrained_fastai
+from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
 
 from app.config import settings
 
@@ -103,13 +103,94 @@ async def startup_event():
     try:
         model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
         logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
-
-
-
+
+        # Try using from_pretrained_fastai first
+        try:
+            learn = from_pretrained_fastai(model_id)
+            logger.info("✅ Model loaded successfully via from_pretrained_fastai!")
+            model_load_error = None
+        except Exception as e1:
+            logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
+            # Fallback: manually download and load the model file
+            hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
+
+            # List repository files to find the actual model file
+            model_filenames = []
+            model_type = "fastai"
+
+            try:
+                repo_files = list_repo_files(repo_id=model_id, token=hf_token)
+                logger.info("Repository files: %s", repo_files)
+                pkl_files = [f for f in repo_files if f.endswith('.pkl')]
+                pt_files = [f for f in repo_files if f.endswith('.pt')]
+
+                if pkl_files:
+                    model_filenames = pkl_files
+                    logger.info("Found .pkl files in repository: %s", pkl_files)
+                    model_type = "fastai"
+                elif pt_files:
+                    model_filenames = pt_files
+                    logger.info("Found .pt files in repository: %s", pt_files)
+                    model_type = "pytorch"
+                else:
+                    model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
+                    model_type = "fastai"
+            except Exception as list_err:
+                logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
+                model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
+                model_type = "fastai"
+
+            # Try to download and load the model file
+            cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
+            model_path = None
+            for filename in model_filenames:
+                try:
+                    model_path = hf_hub_download(
+                        repo_id=model_id,
+                        filename=filename,
+                        cache_dir=cache_dir,
+                        token=hf_token
+                    )
+                    logger.info("Found model file: %s", filename)
+                    if filename.endswith('.pt'):
+                        model_type = "pytorch"
+                    elif filename.endswith('.pkl'):
+                        model_type = "fastai"
+                    break
+                except Exception as dl_err:
+                    logger.debug("Failed to download %s: %s", filename, str(dl_err))
+                    continue
+
+            if model_path and os.path.exists(model_path):
+                if model_type == "pytorch":
+                    error_msg = (
+                        f"Repository '{model_id}' contains a PyTorch model (.pt file), "
+                        f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
+                        f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
+                    )
+                    logger.error(error_msg)
+                    model_load_error = error_msg
+                    raise RuntimeError(error_msg)
+                else:
+                    logger.info("Loading FastAI model from: %s", model_path)
+                    learn = load_learner(model_path)
+                    logger.info("✅ Model loaded successfully from %s", model_path)
+                    model_load_error = None
+            else:
+                error_msg = (
+                    f"Could not find model file in repository '{model_id}'. "
+                    f"Tried: {', '.join(model_filenames)}. "
+                    f"Original error: {str(e1)}"
+                )
+                logger.error(error_msg)
+                model_load_error = error_msg
+                raise RuntimeError(error_msg)
+
     except Exception as e:
         error_msg = str(e)
+        if not model_load_error:
+            model_load_error = error_msg
         logger.error("❌ Failed to load model: %s", error_msg)
-        model_load_error = error_msg
         # Don't raise - allow health check to work
 
 @app.on_event("shutdown")
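The PyTorch branch above rejects repos that only ship .pt weights because load_learner needs a pickled Learner created by FastAI's own export. As a point of comparison, a repo that from_pretrained_fastai can load is normally published roughly as sketched below; the function name, the learner argument, and the repo id are placeholders (a Learner trained elsewhere is assumed) and are not part of this codebase:

from fastai.learner import Learner
from huggingface_hub import push_to_hub_fastai

def publish_fastai_learner(learn: Learner, repo_id: str) -> None:
    # Writes a pickled Learner locally; load_learner() can reopen this file.
    learn.export("export.pkl")
    # Uploads the Learner to the Hub in the layout from_pretrained_fastai() expects.
    push_to_hub_fastai(learner=learn, repo_id=repo_id)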
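Finally, the handler still swallows the exception ("Don't raise - allow health check to work"), so a failed load leaves the service running with model_load_error set. A hypothetical sketch of how a health probe might surface that state is shown below; the /health route and response shape are assumptions, not taken from app/main.py:

from fastapi import FastAPI

app = FastAPI()
learn = None              # set by the startup handler on success
model_load_error = None   # set by the startup handler on failure

@app.get("/health")
def health() -> dict:
    # Report degraded-but-alive status instead of refusing to boot.
    return {
        "model_loaded": learn is not None,
        "model_load_error": model_load_error,
    }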