LogicGoInfotechSpaces commited on
Commit
6960c4e
·
1 Parent(s): 0517019

Add fallback logic to handle PyTorch models and find actual model files - Add manual download fallback when from_pretrained_fastai fails - List repository files to find .pkl or .pt files - Provide clear error message for PyTorch models

Browse files
Files changed (1) hide show
  1. app/main.py +86 -5
app/main.py CHANGED
@@ -31,7 +31,7 @@ import gradio as gr
31
 
32
  # FastAI imports
33
  from fastai.vision.all import *
34
- from huggingface_hub import from_pretrained_fastai
35
 
36
  from app.config import settings
37
 
@@ -103,13 +103,94 @@ async def startup_event():
103
  try:
104
  model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
105
  logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
106
- learn = from_pretrained_fastai(model_id)
107
- logger.info("✅ Model loaded successfully!")
108
- model_load_error = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  except Exception as e:
110
  error_msg = str(e)
 
 
111
  logger.error("❌ Failed to load model: %s", error_msg)
112
- model_load_error = error_msg
113
  # Don't raise - allow health check to work
114
 
115
  @app.on_event("shutdown")
 
31
 
32
  # FastAI imports
33
  from fastai.vision.all import *
34
+ from huggingface_hub import from_pretrained_fastai, hf_hub_download, list_repo_files
35
 
36
  from app.config import settings
37
 
 
103
  try:
104
  model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")
105
  logger.info("🔄 Loading FastAI GAN Colorization Model: %s", model_id)
106
+
107
+ # Try using from_pretrained_fastai first
108
+ try:
109
+ learn = from_pretrained_fastai(model_id)
110
+ logger.info("✅ Model loaded successfully via from_pretrained_fastai!")
111
+ model_load_error = None
112
+ except Exception as e1:
113
+ logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
114
+ # Fallback: manually download and load the model file
115
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
116
+
117
+ # List repository files to find the actual model file
118
+ model_filenames = []
119
+ model_type = "fastai"
120
+
121
+ try:
122
+ repo_files = list_repo_files(repo_id=model_id, token=hf_token)
123
+ logger.info("Repository files: %s", repo_files)
124
+ pkl_files = [f for f in repo_files if f.endswith('.pkl')]
125
+ pt_files = [f for f in repo_files if f.endswith('.pt')]
126
+
127
+ if pkl_files:
128
+ model_filenames = pkl_files
129
+ logger.info("Found .pkl files in repository: %s", pkl_files)
130
+ model_type = "fastai"
131
+ elif pt_files:
132
+ model_filenames = pt_files
133
+ logger.info("Found .pt files in repository: %s", pt_files)
134
+ model_type = "pytorch"
135
+ else:
136
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
137
+ model_type = "fastai"
138
+ except Exception as list_err:
139
+ logger.warning("Could not list repository files: %s. Trying common filenames...", str(list_err))
140
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl", "generator.pt"]
141
+ model_type = "fastai"
142
+
143
+ # Try to download and load the model file
144
+ cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")
145
+ model_path = None
146
+ for filename in model_filenames:
147
+ try:
148
+ model_path = hf_hub_download(
149
+ repo_id=model_id,
150
+ filename=filename,
151
+ cache_dir=cache_dir,
152
+ token=hf_token
153
+ )
154
+ logger.info("Found model file: %s", filename)
155
+ if filename.endswith('.pt'):
156
+ model_type = "pytorch"
157
+ elif filename.endswith('.pkl'):
158
+ model_type = "fastai"
159
+ break
160
+ except Exception as dl_err:
161
+ logger.debug("Failed to download %s: %s", filename, str(dl_err))
162
+ continue
163
+
164
+ if model_path and os.path.exists(model_path):
165
+ if model_type == "pytorch":
166
+ error_msg = (
167
+ f"Repository '{model_id}' contains a PyTorch model (.pt file), "
168
+ f"not a FastAI model. FastAI models must be .pkl files created with FastAI's export. "
169
+ f"Please use a FastAI-compatible colorization model, or switch to a different model backend."
170
+ )
171
+ logger.error(error_msg)
172
+ model_load_error = error_msg
173
+ raise RuntimeError(error_msg)
174
+ else:
175
+ logger.info("Loading FastAI model from: %s", model_path)
176
+ learn = load_learner(model_path)
177
+ logger.info("✅ Model loaded successfully from %s", model_path)
178
+ model_load_error = None
179
+ else:
180
+ error_msg = (
181
+ f"Could not find model file in repository '{model_id}'. "
182
+ f"Tried: {', '.join(model_filenames)}. "
183
+ f"Original error: {str(e1)}"
184
+ )
185
+ logger.error(error_msg)
186
+ model_load_error = error_msg
187
+ raise RuntimeError(error_msg)
188
+
189
  except Exception as e:
190
  error_msg = str(e)
191
+ if not model_load_error:
192
+ model_load_error = error_msg
193
  logger.error("❌ Failed to load model: %s", error_msg)
 
194
  # Don't raise - allow health check to work
195
 
196
  @app.on_event("shutdown")