LogicGoInfotechSpaces commited on
Commit
80080e1
·
1 Parent(s): 93277d4

Add fallback mechanism to manually download and load FastAI model if from_pretrained_fastai fails

Browse files
Files changed (1) hide show
  1. app/colorize_model.py +37 -3
app/colorize_model.py CHANGED
@@ -21,7 +21,7 @@ os.environ["XDG_CACHE_HOME"] = cache_dir
21
  import torch
22
  from PIL import Image
23
  from fastai.vision.all import *
24
- from huggingface_hub import from_pretrained_fastai
25
 
26
  from app.config import settings
27
 
@@ -57,8 +57,42 @@ class ColorizeModel:
57
 
58
  logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
59
  try:
60
- self.learn = from_pretrained_fastai(self.model_id)
61
- logger.info("FastAI GAN Colorization model loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
  error_msg = (
64
  f"Failed to load FastAI model '{self.model_id}'. "
 
21
  import torch
22
  from PIL import Image
23
  from fastai.vision.all import *
24
+ from huggingface_hub import from_pretrained_fastai, hf_hub_download
25
 
26
  from app.config import settings
27
 
 
57
 
58
  logger.info("Loading FastAI GAN Colorization model: %s", self.model_id)
59
  try:
60
+ # Try using from_pretrained_fastai first
61
+ try:
62
+ self.learn = from_pretrained_fastai(self.model_id)
63
+ logger.info("FastAI GAN Colorization model loaded successfully via from_pretrained_fastai")
64
+ except Exception as e1:
65
+ logger.warning("from_pretrained_fastai failed: %s. Trying manual download...", str(e1))
66
+ # Fallback: manually download and load the model file
67
+ # Try common FastAI model file names
68
+ model_filenames = ["model.pkl", "export.pkl", "learner.pkl", "model_export.pkl"]
69
+ model_path = None
70
+
71
+ for filename in model_filenames:
72
+ try:
73
+ model_path = hf_hub_download(
74
+ repo_id=self.model_id,
75
+ filename=filename,
76
+ cache_dir=self.cache_dir,
77
+ token=os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
78
+ )
79
+ logger.info("Found model file: %s", filename)
80
+ break
81
+ except Exception:
82
+ continue
83
+
84
+ if model_path and os.path.exists(model_path):
85
+ # Load the model using FastAI's load_learner
86
+ logger.info("Loading model from: %s", model_path)
87
+ self.learn = load_learner(model_path)
88
+ logger.info("FastAI GAN Colorization model loaded successfully from %s", model_path)
89
+ else:
90
+ # If no model file found, try listing repository files
91
+ raise RuntimeError(
92
+ f"Could not find model file in repository '{self.model_id}'. "
93
+ f"Tried: {', '.join(model_filenames)}. "
94
+ f"Original error: {str(e1)}"
95
+ )
96
  except Exception as e:
97
  error_msg = (
98
  f"Failed to load FastAI model '{self.model_id}'. "