LogicGoInfotechSpaces committed on
Commit
ae29148
·
1 Parent(s): c9d2859

Improve weight loading verification and skip configs with too many missing keys

Browse files
Files changed (1) hide show
  1. app/pytorch_colorizer.py +18 -9
app/pytorch_colorizer.py CHANGED
@@ -78,10 +78,11 @@ class ResNetGenerator(nn.Module):
78
  model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
79
  model += [nn.Tanh()]
80
 
81
- self.model = nn.Sequential(*model)
 
82
 
83
  def forward(self, input):
84
- return self.model(input)
85
 
86
 
87
  class UNetGenerator(nn.Module):
@@ -211,15 +212,18 @@ class PyTorchColorizer:
211
 
212
  # Log state dict keys to understand model structure
213
  if isinstance(state_dict, dict):
214
- keys = list(state_dict.keys())[:20] # First 20 keys
215
  logger.info(f"Model state_dict keys (sample): {keys}")
216
  logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
217
 
218
  # Try to infer architecture from key names
 
 
219
  if any('down' in k.lower() or 'up' in k.lower() for k in keys):
220
  logger.info("Detected U-Net style architecture")
221
  if any('resnet' in k.lower() for k in keys):
222
  logger.info("Detected ResNet style architecture")
 
223
 
224
  except Exception as e:
225
  logger.error(f"Failed to load model file: {e}")
@@ -250,12 +254,17 @@ class PyTorchColorizer:
250
 
251
  # Try strict loading first
252
  try:
253
- model.load_state_dict(state_dict, strict=True)
254
- logger.info(f"✅ Successfully loaded {model_type} model with strict matching: {config_copy}")
255
- except:
256
- # If strict fails, try non-strict
257
- model.load_state_dict(state_dict, strict=False)
258
- logger.info(f"✅ Successfully loaded {model_type} model with non-strict matching: {config_copy}")
 
 
 
 
 
259
 
260
  model.eval()
261
  model.to(self.device)
 
78
  model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
79
  model += [nn.Tanh()]
80
 
81
+ # Wrap in Sequential with 'layers' attribute to match state_dict structure
82
+ self.layers = nn.Sequential(*model)
83
 
84
  def forward(self, input):
85
+ return self.layers(input)
86
 
87
 
88
  class UNetGenerator(nn.Module):
 
212
 
213
  # Log state dict keys to understand model structure
214
  if isinstance(state_dict, dict):
215
+ keys = list(state_dict.keys())[:30] # First 30 keys
216
  logger.info(f"Model state_dict keys (sample): {keys}")
217
  logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
218
 
219
  # Try to infer architecture from key names
220
+ if any('layers' in k.lower() for k in keys):
221
+ logger.info("Detected sequential 'layers' structure")
222
  if any('down' in k.lower() or 'up' in k.lower() for k in keys):
223
  logger.info("Detected U-Net style architecture")
224
  if any('resnet' in k.lower() for k in keys):
225
  logger.info("Detected ResNet style architecture")
226
+
227
 
228
  except Exception as e:
229
  logger.error(f"Failed to load model file: {e}")
 
254
 
255
  # Load non-strictly, then inspect missing/unexpected keys to judge the match
256
  try:
257
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
258
+ if not missing_keys and not unexpected_keys:
259
+ logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
260
+ else:
261
+ logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
262
+ if len(missing_keys) > len(state_dict) * 0.5: # If more than 50% missing, skip
263
+ logger.warning(f"Skipping this config - too many missing keys ({len(missing_keys)}/{len(state_dict)})")
264
+ continue
265
+ except Exception as e:
266
+ logger.debug(f"Failed to load state_dict: {e}")
267
+ continue
268
 
269
  model.eval()
270
  model.to(self.device)