LogicGoInfotechSpaces committed on
Commit
c9d2859
·
1 Parent(s): ec7bfd1

Fix colorization: Add ResNet generator architecture and fix minimum image size to prevent kernel errors

Browse files
Files changed (1) hide show
  1. app/pytorch_colorizer.py +116 -20
app/pytorch_colorizer.py CHANGED
@@ -16,6 +16,74 @@ from huggingface_hub import hf_hub_download
16
  logger = logging.getLogger(__name__)
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class UNetGenerator(nn.Module):
20
  """
21
  U-Net Generator for Image Colorization
@@ -143,34 +211,51 @@ class PyTorchColorizer:
143
 
144
  # Log state dict keys to understand model structure
145
  if isinstance(state_dict, dict):
146
- keys = list(state_dict.keys())[:10] # First 10 keys
147
  logger.info(f"Model state_dict keys (sample): {keys}")
148
  logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
 
 
 
 
 
 
149
 
150
  except Exception as e:
151
  logger.error(f"Failed to load model file: {e}")
152
  raise
153
 
154
  # Try different model architectures with state_dict
 
155
  model_configs = [
156
- {"input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
157
- {"input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
158
- {"input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 32},
159
- {"input_nc": 1, "output_nc": 3, "num_downs": 6, "ngf": 64},
 
 
 
 
160
  ]
161
 
162
  loaded = False
163
  for config in model_configs:
164
  try:
165
- model = UNetGenerator(**config)
 
 
 
 
 
 
166
  # Try strict loading first
167
  try:
168
  model.load_state_dict(state_dict, strict=True)
169
- logger.info(f"✅ Successfully loaded model with strict matching: {config}")
170
  except:
171
  # If strict fails, try non-strict
172
  model.load_state_dict(state_dict, strict=False)
173
- logger.info(f"✅ Successfully loaded model with non-strict matching: {config}")
174
 
175
  model.eval()
176
  model.to(self.device)
@@ -178,25 +263,25 @@ class PyTorchColorizer:
178
  loaded = True
179
  break
180
  except Exception as e:
181
- logger.debug(f"Failed to load with config {config}: {e}")
182
  continue
183
 
184
  if not loaded:
185
- # Last resort: try with default config and non-strict loading
186
  try:
187
- logger.warning("Attempting to load model with default config and non-strict matching")
188
- model = UNetGenerator(input_nc=1, output_nc=3, num_downs=8, ngf=64)
189
  model.load_state_dict(state_dict, strict=False)
190
  model.eval()
191
  model.to(self.device)
192
  self.model = model
193
- logger.info("✅ Model loaded with fallback method")
194
  except Exception as e:
195
  logger.error(f"Failed to load model: {e}")
196
  raise RuntimeError(
197
- f"Could not load PyTorch model. Tried multiple architectures. "
198
  f"Last error: {e}. "
199
- f"The model architecture may not match the expected U-Net structure."
200
  )
201
 
202
  except Exception as e:
@@ -222,16 +307,27 @@ class PyTorchColorizer:
222
  if image.mode != "L":
223
  image = image.convert("L")
224
 
225
- # Try to maintain aspect ratio and use a better resize
226
- # Many GAN models work better with 256x256 or 512x512
227
- target_size = 256
228
- if max(original_size) > 512:
229
- # Scale down proportionally but keep max dimension reasonable
 
 
 
 
 
 
 
230
  scale = target_size / max(original_size)
231
  new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
232
  else:
 
233
  new_size = original_size
234
 
 
 
 
235
  # Transform to tensor
236
  # GAN colorization models typically expect normalized input
237
  transform = transforms.Compose([
 
16
  logger = logging.getLogger(__name__)
17
 
18
 
19
class ResNetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection.

    The convolutional branch is
    ``pad -> conv3x3 -> InstanceNorm -> ReLU -> pad -> conv3x3 -> InstanceNorm``.
    Each 3x3 convolution is preceded by a reflection pad of 1 and uses
    ``padding=0``, so the branch keeps the spatial size and channel count of
    its input and can be added to it element-wise.

    The order of the sub-modules inside ``conv_block`` must not change: the
    positional indices become part of the state_dict keys
    (``conv_block.1.*`` and ``conv_block.5.*`` for the two convolutions).

    Args:
        dim: Number of channels of both the input and the output.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Return the convolutional branch as an ``nn.Sequential``.

        Args:
            dim: Number of input and output channels of both convolutions.
        """
        return nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)
39
+
40
+
41
class ResNetGenerator(nn.Module):
    """ResNet-style generator for image colorization.

    The network is a single ``nn.Sequential`` stored as ``self.model``:

    1. Stem: reflection pad 3, 7x7 conv to ``ngf`` channels, InstanceNorm, ReLU.
    2. Two stride-2 3x3 convolutions, each doubling the channel count.
    3. ``n_blocks`` ``ResNetBlock`` modules at ``ngf * 4`` channels.
    4. Two stride-2 3x3 transposed convolutions, each halving the channel count.
    5. Head: reflection pad 3, 7x7 conv to ``output_nc`` channels, Tanh.

    Because the Sequential is registered under the attribute ``model``, the
    state_dict keys have the form ``model.<index>.<param>`` (the first
    convolution is ``model.1.weight``).
    NOTE(review): a checkpoint whose keys start with ``layers.`` will not
    match these names under strict loading - confirm against the real
    checkpoint before relying on this class for it.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the generated image.
        ngf: Number of filters produced by the first convolution.
        n_blocks: Number of residual blocks between down- and upsampling.
    """

    def __init__(self, input_nc=1, output_nc=3, ngf=64, n_blocks=9):
        super().__init__()

        n_downsampling = 2

        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: every step halves the resolution and doubles the channels
        for step in range(n_downsampling):
            in_ch = ngf * 2 ** step
            out_ch = in_ch * 2
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Residual blocks at the lowest resolution
        bottleneck_ch = ngf * 2 ** n_downsampling
        layers += [ResNetBlock(bottleneck_ch) for _ in range(n_blocks)]

        # Upsampling: mirror of the downsampling path
        for step in range(n_downsampling):
            in_ch = ngf * 2 ** (n_downsampling - step)
            # Integer floor division keeps the channel count an int without
            # a float round-trip; in_ch is always even here.
            out_ch = in_ch // 2
            layers += [
                nn.ConvTranspose2d(
                    in_ch,
                    out_ch,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=True,
                ),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Output layer: Tanh bounds the result to [-1, 1]
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential generator and return the result."""
        return self.model(input)
85
+
86
+
87
  class UNetGenerator(nn.Module):
88
  """
89
  U-Net Generator for Image Colorization
 
211
 
212
  # Log state dict keys to understand model structure
213
  if isinstance(state_dict, dict):
214
+ keys = list(state_dict.keys())[:20] # First 20 keys
215
  logger.info(f"Model state_dict keys (sample): {keys}")
216
  logger.info(f"Total state_dict keys: {len(state_dict.keys())}")
217
+
218
+ # Try to infer architecture from key names
219
+ if any('down' in k.lower() or 'up' in k.lower() for k in keys):
220
+ logger.info("Detected U-Net style architecture")
221
+ if any('resnet' in k.lower() for k in keys):
222
+ logger.info("Detected ResNet style architecture")
223
 
224
  except Exception as e:
225
  logger.error(f"Failed to load model file: {e}")
226
  raise
227
 
228
  # Try different model architectures with state_dict
229
+ # Based on state_dict keys showing "layers" structure, try ResNet first
230
  model_configs = [
231
+ # ResNet Generator (matches "layers" structure)
232
+ {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 9},
233
+ {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 32, "n_blocks": 6},
234
+ {"type": "resnet", "input_nc": 1, "output_nc": 3, "ngf": 64, "n_blocks": 6},
235
+ # U-Net Generator (fallback)
236
+ {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 64},
237
+ {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 7, "ngf": 64},
238
+ {"type": "unet", "input_nc": 1, "output_nc": 3, "num_downs": 8, "ngf": 32},
239
  ]
240
 
241
  loaded = False
242
  for config in model_configs:
243
  try:
244
+ config_copy = config.copy() # Don't modify original
245
+ model_type = config_copy.pop("type")
246
+ if model_type == "resnet":
247
+ model = ResNetGenerator(**config_copy)
248
+ else:
249
+ model = UNetGenerator(**config_copy)
250
+
251
  # Try strict loading first
252
  try:
253
  model.load_state_dict(state_dict, strict=True)
254
+ logger.info(f"✅ Successfully loaded {model_type} model with strict matching: {config_copy}")
255
  except:
256
  # If strict fails, try non-strict
257
  model.load_state_dict(state_dict, strict=False)
258
+ logger.info(f"✅ Successfully loaded {model_type} model with non-strict matching: {config_copy}")
259
 
260
  model.eval()
261
  model.to(self.device)
 
263
  loaded = True
264
  break
265
  except Exception as e:
266
+ logger.debug(f"Failed to load {config.get('type', 'unknown')} model with config {config}: {e}")
267
  continue
268
 
269
  if not loaded:
270
+ # Last resort: try with default ResNet config and non-strict loading
271
  try:
272
+ logger.warning("Attempting to load model with default ResNet config and non-strict matching")
273
+ model = ResNetGenerator(input_nc=1, output_nc=3, ngf=64, n_blocks=9)
274
  model.load_state_dict(state_dict, strict=False)
275
  model.eval()
276
  model.to(self.device)
277
  self.model = model
278
+ logger.info("✅ Model loaded with fallback ResNet method")
279
  except Exception as e:
280
  logger.error(f"Failed to load model: {e}")
281
  raise RuntimeError(
282
+ f"Could not load PyTorch model. Tried multiple architectures (ResNet and U-Net). "
283
  f"Last error: {e}. "
284
+ f"The model architecture may not match the expected structures."
285
  )
286
 
287
  except Exception as e:
 
307
  if image.mode != "L":
308
  image = image.convert("L")
309
 
310
+ # Ensure minimum size - models need at least 64x64, preferably 256x256
311
+ # Many GAN models work better with 256x256
312
+ min_size = 64 # Minimum size to avoid kernel errors
313
+ target_size = 256 # Preferred size for GAN models
314
+
315
+ # Calculate new size maintaining aspect ratio
316
+ if max(original_size) < min_size:
317
+ # If image is too small, scale it up
318
+ scale = min_size / max(original_size)
319
+ new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
320
+ elif max(original_size) > 512:
321
+ # If image is too large, scale it down
322
  scale = target_size / max(original_size)
323
  new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
324
  else:
325
+ # Use original size if it's in a reasonable range
326
  new_size = original_size
327
 
328
+ # Ensure minimum dimensions
329
+ new_size = (max(new_size[0], min_size), max(new_size[1], min_size))
330
+
331
  # Transform to tensor
332
  # GAN colorization models typically expect normalized input
333
  transform = transforms.Compose([