Lucas Gagneten commited on
Commit
31755b3
1 Parent(s): 083efaa

LayoutLMv3 fine-tuneado cargado directamente de Hugging Face: lucasgagneten/layoutlmv3-argentine-invoices

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. app.py +34 -69
.gitignore CHANGED
@@ -1,4 +1,6 @@
1
  .env
2
  /venv/
3
  /__pycache__/
4
- *.bat
 
 
 
1
  .env
2
  /venv/
3
  /__pycache__/
4
+ *.bat
5
+ app_with_state_dict.py
6
+ layoutlmv3_state_dict.pth
app.py CHANGED
@@ -12,16 +12,16 @@ from io import BytesIO
12
  warnings.filterwarnings('ignore')
13
 
14
  # --- 1. Carga de Modelo y Procesador (CPU Habilitada) ---
15
-
16
- # --- CONFIGURACI脫N DE ARCHIVOS ---
17
- STATE_DICT_PATH = "./layoutlmv3_state_dict.pth"
18
- BASE_MODEL = "microsoft/layoutlmv3-base" # Usamos este para la arquitectura base
19
 
20
  # Define el dispositivo como CPU
21
  device = torch.device("cpu")
22
  print(f"Inferencia forzada al dispositivo: {device}")
23
 
24
  # Definir las etiquetas utilizadas durante el entrenamiento
 
 
25
  label_list = [
26
  'B-ALICUOTA',
27
  'B-COMPROBANTE_NUMERO',
@@ -44,6 +44,7 @@ label_list = [
44
  ]
45
  id2label = {i: label for i, label in enumerate(label_list)}
46
  label2id = {label: i for i, label in enumerate(label_list)}
 
47
  # 1. Definir una paleta de colores robusta
48
  color_palette = [
49
  'red', 'blue', 'green', 'purple', 'orange', 'brown',
@@ -52,55 +53,43 @@ color_palette = [
52
  ]
53
 
54
  # 2. Extraer las etiquetas ra铆z 煤nicas
55
- # La etiqueta 'O' (Outside) se ignora ya que no es una entidad
56
  root_labels = set()
57
  for label in label_list:
58
  if label != 'O':
59
- # Split solo por el primer '-' para manejar etiquetas tipo 'B-ETIQUETA'
60
  root_label = label.split('-', 1)[-1]
61
  root_labels.add(root_label)
62
 
63
  # 3. Crear el diccionario de asignaci贸n de color
64
  label2color = {}
65
- for i, root_label in enumerate(sorted(list(root_labels))): # Ordenar para consistencia
66
- # Asigna un color de la paleta usando el operador m贸dulo (%) para reciclar colores
67
  label2color[root_label] = color_palette[i % len(color_palette)]
68
 
69
  # Cargar el modelo/procesador
70
  try:
71
- # 1. Cargar la configuraci贸n de procesamiento de imagen, FORZANDO apply_ocr=False
72
- image_processor = LayoutLMv3ImageProcessor.from_pretrained(BASE_MODEL, apply_ocr=False)
73
-
74
- # 2. Inicializar AutoProcessor con el procesador de imagen ya configurado
75
  loaded_processor = AutoProcessor.from_pretrained(
76
- BASE_MODEL, image_processor=image_processor
 
77
  )
78
 
79
- # 2. Cargar la arquitectura base de LayoutLMv3 (sin los pesos)
80
- # Se a帽ade la configuraci贸n de las etiquetas personalizadas
81
  loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(
82
- BASE_MODEL,
83
- num_labels=len(label_list),
84
- id2label=id2label,
85
- label2id=label2id
86
- ).to(device)
87
-
88
- # 3. Cargar los pesos fine-tuneados desde el archivo .pth
89
- if os.path.exists(STATE_DICT_PATH):
90
- # Mapear a la CPU para asegurar la compatibilidad
91
- state_dict = torch.load(STATE_DICT_PATH, map_location=device)
92
-
93
- # Inyectar los pesos en el modelo
94
- loaded_model.load_state_dict(state_dict)
95
- print(f"Modelo fine-tuneado cargado exitosamente desde {STATE_DICT_PATH} en CPU.")
96
- else:
97
- print(f"Advertencia: No se encontr贸 el archivo de pesos: {STATE_DICT_PATH}. Usando pesos iniciales del modelo base.")
98
 
99
  except Exception as e:
100
- print(f"Error fatal al cargar el modelo o procesador: {e}")
101
- # En un entorno de producci贸n, puedes optar por salir o cargar el modelo base como fallback.
102
- # Por simplicidad, el c贸digo anterior se salta el fallback del modelo base,
103
- # ya que la arquitectura base ya fue cargada, solo fall贸 la inyecci贸n de pesos.
104
 
105
  # Cargar el predictor OCR de DocTR
106
  doctr_model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
@@ -121,7 +110,6 @@ def process_invoice(image: Image.Image):
121
 
122
  # 2. Guardar la imagen en un buffer de memoria como si fuera un archivo JPG
123
  img_byte_arr = BytesIO()
124
- # Nota: Aseg煤rate de que PIL pueda guardar como 'jpeg' o 'png'
125
  rgb_image.save(img_byte_arr, format='JPEG')
126
 
127
  # 3. Mover el puntero al inicio del buffer y obtener los bytes
@@ -129,11 +117,9 @@ def process_invoice(image: Image.Image):
129
  image_bytes = img_byte_arr.read()
130
 
131
  # 4. DocTR soporta la carga de una lista de bytes de im谩genes
132
- # NOTA: Usamos from_images y le pasamos los bytes de UNA imagen
133
  doctr_doc = DocumentFile.from_images([image_bytes])
134
 
135
  except Exception as e:
136
- # Imprime el error completo en tu consola para depuraci贸n
137
  print(f"Error detallado al cargar imagen en DocTR: {e}")
138
  return None, f"Error al procesar la imagen con DocTR (conversi贸n): {e}", None, None
139
 
@@ -179,7 +165,6 @@ def process_invoice(image: Image.Image):
179
  pixel_values = encoding["pixel_values"].to(device) # LayoutLMv3 usa 'pixel_values'
180
 
181
  # 3. Inferencia del Modelo LayoutLMv3
182
- # Aseg煤rate de poner el modelo en modo de evaluaci贸n
183
  loaded_model.eval()
184
  with torch.no_grad():
185
  outputs = loaded_model(
@@ -191,8 +176,7 @@ def process_invoice(image: Image.Image):
191
 
192
  predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()
193
 
194
- # --- Mapeo Correcto de Predicciones a Palabras del OCR (Similar al Colab) ---
195
- # Esto asegura que haya una predicci贸n limpia por cada palabra extra铆da por DocTR.
196
  word_ids = encoding.word_ids()
197
  predictions_final = []
198
  current_word_index = None
@@ -200,38 +184,32 @@ def process_invoice(image: Image.Image):
200
  for idx, pred_id in enumerate(predictions):
201
  word_idx = word_ids[idx]
202
 
203
- # Solo procesar tokens que se mapean a palabras (no CLS, SEP, etc.)
204
  if word_idx is not None:
205
- # Solo tomar la predicci贸n del primer sub-token de cada palabra
206
  if word_idx != current_word_index:
207
  if len(predictions_final) < len(words):
208
- predictions_final.append(id2label[pred_id])
 
209
 
210
  current_word_index = word_idx
211
 
212
  # --- Fin del Mapeo ---
213
 
214
 
215
- # 4. Agrupaci贸n de Resultados BIO (Recolecta todos los candidatos, incluidos duplicados)
216
- # ner_candidates almacenar谩 una lista de entidades para cada etiqueta ra铆z.
217
- # Structure: {'ETIQUETA': [{'valor': '...', 'bbox_entity': [...]}, {...}]}
218
  ner_candidates = {}
219
 
220
  current_entity = []
221
- current_label = None # Almacena la etiqueta ra铆z (ej. 'TOTAL')
222
  current_bbox_group = []
223
 
224
- # Funci贸n auxiliar para guardar la entidad actual
225
  def save_current_entity(entity_list, label, bbox_list):
226
  if not entity_list or not label:
227
  return
228
 
229
- # 1. Calcular el BBox final de la entidad (min/max de todos los bboxes de las palabras)
230
  all_x = [b[0] for b in bbox_list] + [b[2] for b in bbox_list]
231
  all_y = [b[1] for b in bbox_list] + [b[3] for b in bbox_list]
232
  bbox_normalized = [min(all_x), min(all_y), max(all_x), max(all_y)]
233
 
234
- # 2. Guardar en ner_candidates (permite duplicados)
235
  if label not in ner_candidates:
236
  ner_candidates[label] = []
237
 
@@ -240,7 +218,6 @@ def process_invoice(image: Image.Image):
240
  'bbox_entity': bbox_normalized
241
  })
242
 
243
- # Iterar sobre palabras y sus predicciones finales
244
  for word_data, pred_label in zip(words_data, predictions_final):
245
  word_text = word_data["text"]
246
  word_box = word_data["box"]
@@ -249,22 +226,17 @@ def process_invoice(image: Image.Image):
249
  root_label = tag_parts[1] if len(tag_parts) > 1 else None
250
 
251
  if tag_type == 'B':
252
- # 1. Si hay una entidad previa, guardarla.
253
  save_current_entity(current_entity, current_label, current_bbox_group)
254
 
255
- # 2. Iniciar la nueva entidad.
256
  current_label = root_label
257
  current_entity = [word_text]
258
  current_bbox_group = [word_box]
259
 
260
  elif tag_type == 'I':
261
- # Continuar solo si el I- tag corresponde a la entidad B- tag actual
262
  if current_label == root_label:
263
  current_entity.append(word_text)
264
  current_bbox_group.append(word_box)
265
  else:
266
- # Si no coincide (error BIO), guardar la entidad previa (si existe) y
267
- # tratar el I- tag desalineado como el inicio de una nueva entidad.
268
  save_current_entity(current_entity, current_label, current_bbox_group)
269
 
270
  current_label = root_label
@@ -272,15 +244,12 @@ def process_invoice(image: Image.Image):
272
  current_bbox_group = [word_box]
273
 
274
  elif tag_type == 'O':
275
- # Si se encuentra 'O', finalizar la entidad actual si existe.
276
  save_current_entity(current_entity, current_label, current_bbox_group)
277
 
278
- # Resetear
279
  current_entity = []
280
  current_label = None
281
  current_bbox_group = []
282
 
283
- # A帽adir la 煤ltima entidad despu茅s del bucle
284
  save_current_entity(current_entity, current_label, current_bbox_group)
285
 
286
 
@@ -291,13 +260,9 @@ def process_invoice(image: Image.Image):
291
  if not candidates:
292
  continue
293
 
294
- # Ordenar por longitud de la cadena de valor (mayor a menor)
295
  sorted_candidates = sorted(candidates, key=lambda x: len(x['valor']), reverse=True)
296
-
297
- # El mejor candidato es el primero (el m谩s largo)
298
  best_candidate = sorted_candidates[0]
299
 
300
- # Agregar al resultado final (ya desduplicado)
301
  final_ner_results.append({
302
  'etiqueta': label,
303
  'valor': best_candidate['valor'],
@@ -305,10 +270,10 @@ def process_invoice(image: Image.Image):
305
  })
306
 
307
 
308
- # Preparar tabla de resultados (Usando final_ner_results)
309
  table_data = [[res['etiqueta'], res['valor']] for res in final_ner_results]
310
 
311
- # 6. Dibujar Bounding Boxes en la Imagen (para visualizaci贸n)
312
  annotated_image = image.copy()
313
  draw = ImageDraw.Draw(annotated_image)
314
 
@@ -317,7 +282,7 @@ def process_invoice(image: Image.Image):
317
  except IOError:
318
  font = ImageFont.load_default()
319
 
320
- for res in final_ner_results: # Usar final_ner_results
321
  label = res['etiqueta']
322
  min_x_norm, min_y_norm, max_x_norm, max_y_norm = res['bbox_entity']
323
 
@@ -336,7 +301,7 @@ def process_invoice(image: Image.Image):
336
  # 7. Devolver resultados
337
  return annotated_image, "Extracci贸n de Entidades Nombradas completada.", table_data, [
338
  {'etiqueta': r['etiqueta'], 'valor': r['valor'], 'bbox_entity': r['bbox_entity']}
339
- for r in final_ner_results # Usar final_ner_results
340
  ]
341
 
342
  # --- 3. Interfaz Gradio ---
@@ -359,7 +324,7 @@ with gr.Blocks(title="NER de Facturas Argentinas con LayoutLMv3 y DocTR") as dem
359
  f"""
360
  # 馃嚘馃嚪 Extracci贸n de Datos de Facturas Argentinas (LayoutLMv3 + DocTR)
361
  Carga una imagen de factura para realizar OCR (DocTR) y Reconocimiento de Entidades Nombradas (NER)
362
- con un modelo **LayoutLMv3 fine-tuneado** cargado desde **`{STATE_DICT_PATH}`**, forzando la **ejecuci贸n en CPU**.
363
  """
364
  )
365
 
 
12
  warnings.filterwarnings('ignore')
13
 
14
  # --- 1. Carga de Modelo y Procesador (CPU Habilitada) ---
15
+ # MODELO DE HUGGING FACE FINE-TUNEADO
16
+ HUGGINGFACE_MODEL = "lucasgagneten/layoutlmv3-argentine-invoices"
 
 
17
 
18
  # Define el dispositivo como CPU
19
  device = torch.device("cpu")
20
  print(f"Inferencia forzada al dispositivo: {device}")
21
 
22
  # Definir las etiquetas utilizadas durante el entrenamiento
23
+ # Estas son necesarias para la l贸gica de visualizaci贸n y la deduplicaci贸n,
24
+ # aunque el modelo cargado ya contendr谩 esta informaci贸n en su configuraci贸n.
25
  label_list = [
26
  'B-ALICUOTA',
27
  'B-COMPROBANTE_NUMERO',
 
44
  ]
45
  id2label = {i: label for i, label in enumerate(label_list)}
46
  label2id = {label: i for i, label in enumerate(label_list)}
47
+
48
  # 1. Definir una paleta de colores robusta
49
  color_palette = [
50
  'red', 'blue', 'green', 'purple', 'orange', 'brown',
 
53
  ]
54
 
55
  # 2. Extraer las etiquetas ra铆z 煤nicas
 
56
  root_labels = set()
57
  for label in label_list:
58
  if label != 'O':
 
59
  root_label = label.split('-', 1)[-1]
60
  root_labels.add(root_label)
61
 
62
  # 3. Crear el diccionario de asignaci贸n de color
63
  label2color = {}
64
+ for i, root_label in enumerate(sorted(list(root_labels))):
 
65
  label2color[root_label] = color_palette[i % len(color_palette)]
66
 
67
  # Cargar el modelo/procesador
68
  try:
69
+ # 1. Cargar el procesador directamente desde el modelo de HF.
70
+ # El procesador de LayoutLMv3 siempre requiere que apply_ocr=False si se usa con OCR externo.
71
+ # AutoProcessor se encargar谩 de cargar ImageProcessor, Tokenizer, y FeatureExtractor.
 
72
  loaded_processor = AutoProcessor.from_pretrained(
73
+ HUGGINGFACE_MODEL,
74
+ apply_ocr=False # Importante para usar los resultados de DocTR
75
  )
76
 
77
+ # 2. Cargar el modelo de Clasificaci贸n de Tokens directamente desde el repositorio de HF.
78
+ # Esto carga tanto la arquitectura como los pesos fine-tuneados.
79
  loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(
80
+ HUGGINGFACE_MODEL
81
+ ).to(device) # Mover a la CPU
82
+
83
+ # Sobrescribir id2label/label2id para consistencia, aunque ya deber铆an estar cargados
84
+ # en la configuraci贸n del modelo de HF. Esto es una precauci贸n.
85
+ loaded_model.config.id2label = id2label
86
+ loaded_model.config.label2id = label2id
87
+
88
+ print(f"Modelo fine-tuneado cargado exitosamente desde Hugging Face: {HUGGINGFACE_MODEL} en CPU.")
 
 
 
 
 
 
 
89
 
90
  except Exception as e:
91
+ print(f"Error fatal al cargar el modelo o procesador desde Hugging Face: {e}")
92
+ # Nota: Aqu铆 la aplicaci贸n fallar铆a si no puede descargar el modelo.
 
 
93
 
94
  # Cargar el predictor OCR de DocTR
95
  doctr_model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
 
110
 
111
  # 2. Guardar la imagen en un buffer de memoria como si fuera un archivo JPG
112
  img_byte_arr = BytesIO()
 
113
  rgb_image.save(img_byte_arr, format='JPEG')
114
 
115
  # 3. Mover el puntero al inicio del buffer y obtener los bytes
 
117
  image_bytes = img_byte_arr.read()
118
 
119
  # 4. DocTR soporta la carga de una lista de bytes de im谩genes
 
120
  doctr_doc = DocumentFile.from_images([image_bytes])
121
 
122
  except Exception as e:
 
123
  print(f"Error detallado al cargar imagen en DocTR: {e}")
124
  return None, f"Error al procesar la imagen con DocTR (conversi贸n): {e}", None, None
125
 
 
165
  pixel_values = encoding["pixel_values"].to(device) # LayoutLMv3 usa 'pixel_values'
166
 
167
  # 3. Inferencia del Modelo LayoutLMv3
 
168
  loaded_model.eval()
169
  with torch.no_grad():
170
  outputs = loaded_model(
 
176
 
177
  predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()
178
 
179
+ # --- Mapeo Correcto de Predicciones a Palabras del OCR ---
 
180
  word_ids = encoding.word_ids()
181
  predictions_final = []
182
  current_word_index = None
 
184
  for idx, pred_id in enumerate(predictions):
185
  word_idx = word_ids[idx]
186
 
 
187
  if word_idx is not None:
 
188
  if word_idx != current_word_index:
189
  if len(predictions_final) < len(words):
190
+ # Usar el id2label del modelo cargado, que ahora es la fuente de verdad
191
+ predictions_final.append(loaded_model.config.id2label[pred_id])
192
 
193
  current_word_index = word_idx
194
 
195
  # --- Fin del Mapeo ---
196
 
197
 
198
+ # 4. Agrupaci贸n de Resultados BIO
 
 
199
  ner_candidates = {}
200
 
201
  current_entity = []
202
+ current_label = None
203
  current_bbox_group = []
204
 
 
205
  def save_current_entity(entity_list, label, bbox_list):
206
  if not entity_list or not label:
207
  return
208
 
 
209
  all_x = [b[0] for b in bbox_list] + [b[2] for b in bbox_list]
210
  all_y = [b[1] for b in bbox_list] + [b[3] for b in bbox_list]
211
  bbox_normalized = [min(all_x), min(all_y), max(all_x), max(all_y)]
212
 
 
213
  if label not in ner_candidates:
214
  ner_candidates[label] = []
215
 
 
218
  'bbox_entity': bbox_normalized
219
  })
220
 
 
221
  for word_data, pred_label in zip(words_data, predictions_final):
222
  word_text = word_data["text"]
223
  word_box = word_data["box"]
 
226
  root_label = tag_parts[1] if len(tag_parts) > 1 else None
227
 
228
  if tag_type == 'B':
 
229
  save_current_entity(current_entity, current_label, current_bbox_group)
230
 
 
231
  current_label = root_label
232
  current_entity = [word_text]
233
  current_bbox_group = [word_box]
234
 
235
  elif tag_type == 'I':
 
236
  if current_label == root_label:
237
  current_entity.append(word_text)
238
  current_bbox_group.append(word_box)
239
  else:
 
 
240
  save_current_entity(current_entity, current_label, current_bbox_group)
241
 
242
  current_label = root_label
 
244
  current_bbox_group = [word_box]
245
 
246
  elif tag_type == 'O':
 
247
  save_current_entity(current_entity, current_label, current_bbox_group)
248
 
 
249
  current_entity = []
250
  current_label = None
251
  current_bbox_group = []
252
 
 
253
  save_current_entity(current_entity, current_label, current_bbox_group)
254
 
255
 
 
260
  if not candidates:
261
  continue
262
 
 
263
  sorted_candidates = sorted(candidates, key=lambda x: len(x['valor']), reverse=True)
 
 
264
  best_candidate = sorted_candidates[0]
265
 
 
266
  final_ner_results.append({
267
  'etiqueta': label,
268
  'valor': best_candidate['valor'],
 
270
  })
271
 
272
 
273
+ # Preparar tabla de resultados
274
  table_data = [[res['etiqueta'], res['valor']] for res in final_ner_results]
275
 
276
+ # 6. Dibujar Bounding Boxes en la Imagen
277
  annotated_image = image.copy()
278
  draw = ImageDraw.Draw(annotated_image)
279
 
 
282
  except IOError:
283
  font = ImageFont.load_default()
284
 
285
+ for res in final_ner_results:
286
  label = res['etiqueta']
287
  min_x_norm, min_y_norm, max_x_norm, max_y_norm = res['bbox_entity']
288
 
 
301
  # 7. Devolver resultados
302
  return annotated_image, "Extracci贸n de Entidades Nombradas completada.", table_data, [
303
  {'etiqueta': r['etiqueta'], 'valor': r['valor'], 'bbox_entity': r['bbox_entity']}
304
+ for r in final_ner_results
305
  ]
306
 
307
  # --- 3. Interfaz Gradio ---
 
324
  f"""
325
  # 馃嚘馃嚪 Extracci贸n de Datos de Facturas Argentinas (LayoutLMv3 + DocTR)
326
  Carga una imagen de factura para realizar OCR (DocTR) y Reconocimiento de Entidades Nombradas (NER)
327
+ con un modelo **LayoutLMv3 fine-tuneado** cargado directamente de Hugging Face: **`{HUGGINGFACE_MODEL}`**, forzando la **ejecuci贸n en CPU**.
328
  """
329
  )
330