LogicGoInfotechSpaces commited on
Commit
845bd8d
·
1 Parent(s): ae29148

Fix: Try ResNet first, improve size validation, and skip configs with too many mismatches

Browse files
Files changed (1) hide show
  1. app/pytorch_colorizer.py +27 -16
app/pytorch_colorizer.py CHANGED
@@ -231,15 +231,16 @@ class PyTorchColorizer:
231
 
232
  # Try different model architectures with state_dict
233
  # Based on state_dict keys showing "layers" structure, try ResNet first
 
234
  model_configs = [
235
- # ResNet Generator (matches "layers" structure)
236
  {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 9},
237
- {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 6},
238
  {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 6},
239
- # U-Net Generator (fallback)
 
 
240
  {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
241
  {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
242
- {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 32},
243
  ]
244
 
245
  loaded = False
@@ -259,8 +260,9 @@ class PyTorchColorizer:
259
  logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
260
  else:
261
  logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
262
- if len(missing_keys) > len(state_dict) * 0.5: # If more than 50% missing, skip
263
- logger.warning(f"Skipping this config - too many missing keys ({len(missing_keys)}/{len(state_dict)})")
 
264
  continue
265
  except Exception as e:
266
  logger.debug(f"Failed to load state_dict: {e}")
@@ -318,24 +320,33 @@ class PyTorchColorizer:
318
 
319
  # Ensure minimum size - models need at least 64x64, preferably 256x256
320
  # Many GAN models work better with 256x256
321
- min_size = 64 # Minimum size to avoid kernel errors
322
  target_size = 256 # Preferred size for GAN models
323
 
324
  # Calculate new size maintaining aspect ratio
325
- if max(original_size) < min_size:
 
 
 
 
326
  # If image is too small, scale it up
327
- scale = min_size / max(original_size)
328
- new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
329
- elif max(original_size) > 512:
 
 
330
  # If image is too large, scale it down
331
- scale = target_size / max(original_size)
332
- new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
 
 
333
  else:
334
- # Use original size if it's in a reasonable range
335
- new_size = original_size
336
 
337
- # Ensure minimum dimensions
338
  new_size = (max(new_size[0], min_size), max(new_size[1], min_size))
 
339
 
340
  # Transform to tensor
341
  # GAN colorization models typically expect normalized input
 
231
 
232
  # Try different model architectures with state_dict
233
  # Based on state_dict keys showing "layers" structure, try ResNet first
234
+ # The keys like 'layers.0.4.0.conv1.weight' suggest ResNet blocks in a Sequential
235
  model_configs = [
236
+ # ResNet Generator (matches "layers" structure) - try these first
237
  {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 9},
 
238
  {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 6},
239
+ {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 9},
240
+ {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 6},
241
+ # U-Net Generator (fallback only if ResNet fails)
242
  {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
243
  {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
 
244
  ]
245
 
246
  loaded = False
 
260
  logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
261
  else:
262
  logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
263
+ # If more than 30% missing or if unexpected keys > 50% of state_dict, skip
264
+ if len(missing_keys) > len(state_dict) * 0.3 or len(unexpected_keys) > len(state_dict) * 0.5:
265
+ logger.warning(f"Skipping this config - too many mismatches (Missing: {len(missing_keys)}/{len(state_dict)}, Unexpected: {len(unexpected_keys)}/{len(state_dict)})")
266
  continue
267
  except Exception as e:
268
  logger.debug(f"Failed to load state_dict: {e}")
 
320
 
321
  # Ensure minimum size - models need at least 64x64, preferably 256x256
322
  # Many GAN models work better with 256x256
323
+ min_size = 64 # Minimum size to avoid kernel errors (must be >= 4 for kernel size)
324
  target_size = 256 # Preferred size for GAN models
325
 
326
  # Calculate new size maintaining aspect ratio
327
+ width, height = original_size
328
+ max_dim = max(width, height)
329
+ min_dim = min(width, height)
330
+
331
+ if max_dim < min_size:
332
  # If image is too small, scale it up
333
+ scale = min_size / max_dim
334
+ new_width = max(int(width * scale), min_size)
335
+ new_height = max(int(height * scale), min_size)
336
+ new_size = (new_width, new_height)
337
+ elif max_dim > 512:
338
  # If image is too large, scale it down
339
+ scale = target_size / max_dim
340
+ new_width = max(int(width * scale), min_size)
341
+ new_height = max(int(height * scale), min_size)
342
+ new_size = (new_width, new_height)
343
  else:
344
+ # Use original size but ensure minimum dimensions
345
+ new_size = (max(width, min_size), max(height, min_size))
346
 
347
+ # Double-check minimum dimensions are met
348
  new_size = (max(new_size[0], min_size), max(new_size[1], min_size))
349
+ logger.debug(f"Resizing image from {original_size} to {new_size}")
350
 
351
  # Transform to tensor
352
  # GAN colorization models typically expect normalized input