Commit
·
845bd8d
1
Parent(s):
ae29148
Fix: Try ResNet first, improve size validation, and skip configs with too many mismatches
Browse files- app/pytorch_colorizer.py +27 -16
app/pytorch_colorizer.py
CHANGED
|
@@ -231,15 +231,16 @@ class PyTorchColorizer:
|
|
| 231 |
|
| 232 |
# Try different model architectures with state_dict
|
| 233 |
# Based on state_dict keys showing "layers" structure, try ResNet first
|
|
|
|
| 234 |
model_configs = [
|
| 235 |
-
# ResNet Generator (matches "layers" structure)
|
| 236 |
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 9},
|
| 237 |
-
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 6},
|
| 238 |
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 6},
|
| 239 |
-
|
|
|
|
|
|
|
| 240 |
{"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
|
| 241 |
{"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
|
| 242 |
-
{"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 32},
|
| 243 |
]
|
| 244 |
|
| 245 |
loaded = False
|
|
@@ -259,8 +260,9 @@ class PyTorchColorizer:
|
|
| 259 |
logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
|
| 260 |
else:
|
| 261 |
logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
|
| 262 |
-
|
| 263 |
-
|
|
|
|
| 264 |
continue
|
| 265 |
except Exception as e:
|
| 266 |
logger.debug(f"Failed to load state_dict: {e}")
|
|
@@ -318,24 +320,33 @@ class PyTorchColorizer:
|
|
| 318 |
|
| 319 |
# Ensure minimum size - models need at least 64x64, preferably 256x256
|
| 320 |
# Many GAN models work better with 256x256
|
| 321 |
-
min_size = 64 # Minimum size to avoid kernel errors
|
| 322 |
target_size = 256 # Preferred size for GAN models
|
| 323 |
|
| 324 |
# Calculate new size maintaining aspect ratio
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
# If image is too small, scale it up
|
| 327 |
-
scale = min_size /
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
| 330 |
# If image is too large, scale it down
|
| 331 |
-
scale = target_size /
|
| 332 |
-
|
|
|
|
|
|
|
| 333 |
else:
|
| 334 |
-
# Use original size
|
| 335 |
-
new_size =
|
| 336 |
|
| 337 |
-
#
|
| 338 |
new_size = (max(new_size[0], min_size), max(new_size[1], min_size))
|
|
|
|
| 339 |
|
| 340 |
# Transform to tensor
|
| 341 |
# GAN colorization models typically expect normalized input
|
|
|
|
| 231 |
|
| 232 |
# Try different model architectures with state_dict
|
| 233 |
# Based on state_dict keys showing "layers" structure, try ResNet first
|
| 234 |
+
# The keys like 'layers.0.4.0.conv1.weight' suggest ResNet blocks in a Sequential
|
| 235 |
model_configs = [
|
| 236 |
+
# ResNet Generator (matches "layers" structure) - try these first
|
| 237 |
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 9},
|
|
|
|
| 238 |
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 6},
|
| 239 |
+
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 9},
|
| 240 |
+
{"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 6},
|
| 241 |
+
# U-Net Generator (fallback only if ResNet fails)
|
| 242 |
{"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
|
| 243 |
{"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
|
|
|
|
| 244 |
]
|
| 245 |
|
| 246 |
loaded = False
|
|
|
|
| 260 |
logger.info(f"✅ Successfully loaded {model_type} model with perfect matching: {config_copy}")
|
| 261 |
else:
|
| 262 |
logger.warning(f"⚠️ Loaded {model_type} model with mismatches - Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
|
| 263 |
+
# If more than 30% missing or if unexpected keys > 50% of state_dict, skip
|
| 264 |
+
if len(missing_keys) > len(state_dict) * 0.3 or len(unexpected_keys) > len(state_dict) * 0.5:
|
| 265 |
+
logger.warning(f"Skipping this config - too many mismatches (Missing: {len(missing_keys)}/{len(state_dict)}, Unexpected: {len(unexpected_keys)}/{len(state_dict)})")
|
| 266 |
continue
|
| 267 |
except Exception as e:
|
| 268 |
logger.debug(f"Failed to load state_dict: {e}")
|
|
|
|
| 320 |
|
| 321 |
# Ensure minimum size - models need at least 64x64, preferably 256x256
|
| 322 |
# Many GAN models work better with 256x256
|
| 323 |
+
min_size = 64 # Minimum size to avoid kernel errors (must be >= 4 for kernel size)
|
| 324 |
target_size = 256 # Preferred size for GAN models
|
| 325 |
|
| 326 |
# Calculate new size maintaining aspect ratio
|
| 327 |
+
width, height = original_size
|
| 328 |
+
max_dim = max(width, height)
|
| 329 |
+
min_dim = min(width, height)
|
| 330 |
+
|
| 331 |
+
if max_dim < min_size:
|
| 332 |
# If image is too small, scale it up
|
| 333 |
+
scale = min_size / max_dim
|
| 334 |
+
new_width = max(int(width * scale), min_size)
|
| 335 |
+
new_height = max(int(height * scale), min_size)
|
| 336 |
+
new_size = (new_width, new_height)
|
| 337 |
+
elif max_dim > 512:
|
| 338 |
# If image is too large, scale it down
|
| 339 |
+
scale = target_size / max_dim
|
| 340 |
+
new_width = max(int(width * scale), min_size)
|
| 341 |
+
new_height = max(int(height * scale), min_size)
|
| 342 |
+
new_size = (new_width, new_height)
|
| 343 |
else:
|
| 344 |
+
# Use original size but ensure minimum dimensions
|
| 345 |
+
new_size = (max(width, min_size), max(height, min_size))
|
| 346 |
|
| 347 |
+
# Double-check minimum dimensions are met
|
| 348 |
new_size = (max(new_size[0], min_size), max(new_size[1], min_size))
|
| 349 |
+
logger.debug(f"Resizing image from {original_size} to {new_size}")
|
| 350 |
|
| 351 |
# Transform to tensor
|
| 352 |
# GAN colorization models typically expect normalized input
|