Lucas Gagneten commited on
Commit
dbeb758
·
1 Parent(s): 376f3e3

first version

Browse files
Files changed (5) hide show
  1. .gitignore +4 -0
  2. README.md +3 -3
  3. app.py +366 -0
  4. layoutlmv3_state_dict.pth +3 -0
  5. requirements.txt +20 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .env
2
+ /venv/
3
+ /__pycache__/
4
+ *.bat
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Layoutlmv3 Facturas Extractor
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
 
1
  ---
2
  title: Layoutlmv3 Facturas Extractor
3
+ emoji: 🏃
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import torch
5
+ from transformers import LayoutLMv3ImageProcessor, AutoProcessor, LayoutLMv3ForTokenClassification
6
+ from doctr.models import ocr_predictor
7
+ from doctr.io import DocumentFile
8
+ from doctr.utils.visualization import visualize_page
9
+ import os
10
+ import warnings
11
+ from io import BytesIO
12
+ warnings.filterwarnings('ignore')
13
+
14
+ # --- 1. Carga de Modelo y Procesador (CPU Habilitada) ---
15
+
16
+ # --- CONFIGURACIÓN DE ARCHIVOS ---
17
+ STATE_DICT_PATH = "./layoutlmv3_state_dict.pth"
18
+ BASE_MODEL = "microsoft/layoutlmv3-base" # Usamos este para la arquitectura base
19
+
20
+ # Define el dispositivo como CPU
21
+ device = torch.device("cpu")
22
+ print(f"Inferencia forzada al dispositivo: {device}")
23
+
24
+ # Definir las etiquetas utilizadas durante el entrenamiento
25
+ label_list = [
26
+ 'B-ALICUOTA', 'B-COMPROBANTE_NUMERO', 'B-CONCEPTO_GASTO', 'B-FECHA', 'B-INGRESOS_BRUTOS', 'B-IVA', 'B-JURISDICCION_GASTO', 'B-NETO', 'B-PROVEEDOR_CUIT', 'B-PROVEEDOR_RAZON_SOCIAL', 'B-TIPO', 'B-TOTAL', 'I-COMPROBANTE_NUMERO', 'I-CONCEPTO_GASTO', 'I-INGRESOS_BRUTOS', 'I-JURISDICCION_GASTO', 'I-PROVEEDOR_CUIT', 'I-PROVEEDOR_RAZON_SOCIAL', 'I-TOTAL', 'O'
27
+ ]
28
+ id2label = {i: label for i, label in enumerate(label_list)}
29
+ label2id = {label: i for i, label in enumerate(label_list)}
30
+ # 1. Definir una paleta de colores robusta
31
+ color_palette = [
32
+ 'red', 'blue', 'green', 'purple', 'orange', 'brown',
33
+ 'pink', 'cyan', 'lime', 'olive', 'teal', 'magenta',
34
+ 'navy', 'maroon', 'gold', 'silver', 'indigo', 'turquoise'
35
+ ]
36
+
37
+ # 2. Extraer las etiquetas raíz únicas
38
+ # La etiqueta 'O' (Outside) se ignora ya que no es una entidad
39
+ root_labels = set()
40
+ for label in label_list:
41
+ if label != 'O':
42
+ # Split solo por el primer '-' para manejar etiquetas tipo 'B-ETIQUETA'
43
+ root_label = label.split('-', 1)[-1]
44
+ root_labels.add(root_label)
45
+
46
+ # 3. Crear el diccionario de asignación de color
47
+ label2color = {}
48
+ for i, root_label in enumerate(sorted(list(root_labels))): # Ordenar para consistencia
49
+ # Asigna un color de la paleta usando el operador módulo (%) para reciclar colores
50
+ label2color[root_label] = color_palette[i % len(color_palette)]
51
+
52
+ # Cargar el modelo/procesador
53
+ try:
54
+ # 1. Cargar la configuración de procesamiento de imagen, FORZANDO apply_ocr=False
55
+ image_processor = LayoutLMv3ImageProcessor.from_pretrained(BASE_MODEL, apply_ocr=False)
56
+
57
+ # 2. Inicializar AutoProcessor con el procesador de imagen ya configurado
58
+ loaded_processor = AutoProcessor.from_pretrained(
59
+ BASE_MODEL, image_processor=image_processor
60
+ )
61
+
62
+ # 2. Cargar la arquitectura base de LayoutLMv3 (sin los pesos)
63
+ # Se añade la configuración de las etiquetas personalizadas
64
+ loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(
65
+ BASE_MODEL,
66
+ num_labels=len(label_list),
67
+ id2label=id2label,
68
+ label2id=label2id
69
+ ).to(device)
70
+
71
+ # 3. Cargar los pesos fine-tuneados desde el archivo .pth
72
+ if os.path.exists(STATE_DICT_PATH):
73
+ # Mapear a la CPU para asegurar la compatibilidad
74
+ state_dict = torch.load(STATE_DICT_PATH, map_location=device)
75
+
76
+ # Inyectar los pesos en el modelo
77
+ loaded_model.load_state_dict(state_dict)
78
+ print(f"Modelo fine-tuneado cargado exitosamente desde {STATE_DICT_PATH} en CPU.")
79
+ else:
80
+ print(f"Advertencia: No se encontró el archivo de pesos: {STATE_DICT_PATH}. Usando pesos iniciales del modelo base.")
81
+
82
+ except Exception as e:
83
+ print(f"Error fatal al cargar el modelo o procesador: {e}")
84
+ # En un entorno de producción, puedes optar por salir o cargar el modelo base como fallback.
85
+ # Por simplicidad, el código anterior se salta el fallback del modelo base,
86
+ # ya que la arquitectura base ya fue cargada, solo falló la inyección de pesos.
87
+
88
+ # Cargar el predictor OCR de DocTR
89
+ doctr_model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
90
+
91
+ # --- 2. Función Principal de Inferencia y Visualización ---
92
+
93
+ def process_invoice(image: Image.Image):
94
+ """
95
+ Realiza OCR con DocTR, NER con LayoutLMv3 y genera los resultados en tabla y imagen.
96
+ """
97
+ if image is None:
98
+ return None, "Por favor, carga una imagen de factura.", None, None
99
+
100
+ # 1. OCR con DocTR (obtener texto y bboxes)
101
+ try:
102
+ # 1. Asegurar el formato RGB
103
+ rgb_image = image.convert("RGB")
104
+
105
+ # 2. Guardar la imagen en un buffer de memoria como si fuera un archivo JPG
106
+ img_byte_arr = BytesIO()
107
+ # Nota: Asegúrate de que PIL pueda guardar como 'jpeg' o 'png'
108
+ rgb_image.save(img_byte_arr, format='JPEG')
109
+
110
+ # 3. Mover el puntero al inicio del buffer y obtener los bytes
111
+ img_byte_arr.seek(0)
112
+ image_bytes = img_byte_arr.read()
113
+
114
+ # 4. DocTR soporta la carga de una lista de bytes de imágenes
115
+ # NOTA: Usamos from_images y le pasamos los bytes de UNA imagen
116
+ doctr_doc = DocumentFile.from_images([image_bytes])
117
+
118
+ except Exception as e:
119
+ # Imprime el error completo en tu consola para depuración
120
+ print(f"Error detallado al cargar imagen en DocTR: {e}")
121
+ return None, f"Error al procesar la imagen con DocTR (conversión): {e}", None, None
122
+
123
+ doctr_result = doctr_model(doctr_doc)
124
+
125
+ if not doctr_result.pages:
126
+ return None, "DocTR no pudo extraer ninguna página de la imagen.", None, None
127
+
128
+ page = doctr_result.pages[0]
129
+
130
+ # Extraer texto, bboxes normalizados y fusionar a nivel de palabra
131
+ words_data = []
132
+ # La geometría de DocTR es [x_min, y_min] y [x_max, y_max] normalizada a [0, 1]
133
+ for block in page.blocks:
134
+ for line in block.lines:
135
+ for word in line.words:
136
+ text = word.value
137
+ # Coordenadas normalizadas a [0, 1000]
138
+ geom = np.array(word.geometry) * 1000
139
+ xmin, ymin = map(int, geom[0])
140
+ xmax, ymax = map(int, geom[1])
141
+ words_data.append({"text": text, "box": [xmin, ymin, xmax, ymax]})
142
+
143
+ words = [wd["text"] for wd in words_data]
144
+ boxes = [wd["box"] for wd in words_data]
145
+ image_width, image_height = image.size
146
+
147
+ # 2. Preprocesamiento para LayoutLMv3 (usando los resultados del OCR)
148
+ encoding = loaded_processor(
149
+ image,
150
+ words,
151
+ boxes=boxes,
152
+ max_length=512,
153
+ truncation=True,
154
+ padding="max_length",
155
+ return_tensors="pt"
156
+ )
157
+
158
+ # Mover los tensores de entrada a la CPU antes de la inferencia
159
+ input_ids = encoding["input_ids"].to(device)
160
+ attention_mask = encoding["attention_mask"].to(device)
161
+ bbox = encoding["bbox"].to(device)
162
+ pixel_values = encoding["pixel_values"].to(device) # LayoutLMv3 usa 'pixel_values'
163
+
164
+ # 3. Inferencia del Modelo LayoutLMv3
165
+ # Asegúrate de poner el modelo en modo de evaluación
166
+ loaded_model.eval()
167
+ with torch.no_grad():
168
+ outputs = loaded_model(
169
+ input_ids=input_ids,
170
+ attention_mask=attention_mask,
171
+ bbox=bbox,
172
+ pixel_values=pixel_values
173
+ )
174
+
175
+ predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()
176
+
177
+ # --- Mapeo Correcto de Predicciones a Palabras del OCR (Similar al Colab) ---
178
+ # Esto asegura que haya una predicción limpia por cada palabra extraída por DocTR.
179
+ word_ids = encoding.word_ids()
180
+ predictions_final = []
181
+ current_word_index = None
182
+
183
+ for idx, pred_id in enumerate(predictions):
184
+ word_idx = word_ids[idx]
185
+
186
+ # Solo procesar tokens que se mapean a palabras (no CLS, SEP, etc.)
187
+ if word_idx is not None:
188
+ # Solo tomar la predicción del primer sub-token de cada palabra
189
+ if word_idx != current_word_index:
190
+ if len(predictions_final) < len(words):
191
+ predictions_final.append(id2label[pred_id])
192
+
193
+ current_word_index = word_idx
194
+
195
+ # --- Fin del Mapeo ---
196
+
197
+
198
+ # 4. Agrupación de Resultados BIO (Recolecta todos los candidatos, incluidos duplicados)
199
+ # ner_candidates almacenará una lista de entidades para cada etiqueta raíz.
200
+ # Structure: {'ETIQUETA': [{'valor': '...', 'bbox_entity': [...]}, {...}]}
201
+ ner_candidates = {}
202
+
203
+ current_entity = []
204
+ current_label = None # Almacena la etiqueta raíz (ej. 'TOTAL')
205
+ current_bbox_group = []
206
+
207
+ # Función auxiliar para guardar la entidad actual
208
+ def save_current_entity(entity_list, label, bbox_list):
209
+ if not entity_list or not label:
210
+ return
211
+
212
+ # 1. Calcular el BBox final de la entidad (min/max de todos los bboxes de las palabras)
213
+ all_x = [b[0] for b in bbox_list] + [b[2] for b in bbox_list]
214
+ all_y = [b[1] for b in bbox_list] + [b[3] for b in bbox_list]
215
+ bbox_normalized = [min(all_x), min(all_y), max(all_x), max(all_y)]
216
+
217
+ # 2. Guardar en ner_candidates (permite duplicados)
218
+ if label not in ner_candidates:
219
+ ner_candidates[label] = []
220
+
221
+ ner_candidates[label].append({
222
+ 'valor': " ".join(entity_list),
223
+ 'bbox_entity': bbox_normalized
224
+ })
225
+
226
+ # Iterar sobre palabras y sus predicciones finales
227
+ for word_data, pred_label in zip(words_data, predictions_final):
228
+ word_text = word_data["text"]
229
+ word_box = word_data["box"]
230
+ tag_parts = pred_label.split('-', 1)
231
+ tag_type = tag_parts[0]
232
+ root_label = tag_parts[1] if len(tag_parts) > 1 else None
233
+
234
+ if tag_type == 'B':
235
+ # 1. Si hay una entidad previa, guardarla.
236
+ save_current_entity(current_entity, current_label, current_bbox_group)
237
+
238
+ # 2. Iniciar la nueva entidad.
239
+ current_label = root_label
240
+ current_entity = [word_text]
241
+ current_bbox_group = [word_box]
242
+
243
+ elif tag_type == 'I':
244
+ # Continuar solo si el I- tag corresponde a la entidad B- tag actual
245
+ if current_label == root_label:
246
+ current_entity.append(word_text)
247
+ current_bbox_group.append(word_box)
248
+ else:
249
+ # Si no coincide (error BIO), guardar la entidad previa (si existe) y
250
+ # tratar el I- tag desalineado como el inicio de una nueva entidad.
251
+ save_current_entity(current_entity, current_label, current_bbox_group)
252
+
253
+ current_label = root_label
254
+ current_entity = [word_text]
255
+ current_bbox_group = [word_box]
256
+
257
+ elif tag_type == 'O':
258
+ # Si se encuentra 'O', finalizar la entidad actual si existe.
259
+ save_current_entity(current_entity, current_label, current_bbox_group)
260
+
261
+ # Resetear
262
+ current_entity = []
263
+ current_label = None
264
+ current_bbox_group = []
265
+
266
+ # Añadir la última entidad después del bucle
267
+ save_current_entity(current_entity, current_label, current_bbox_group)
268
+
269
+
270
+ # --- 5: DESDUPLICACIÓN (Seleccionar el valor más largo) ---
271
+ final_ner_results = []
272
+
273
+ for label, candidates in ner_candidates.items():
274
+ if not candidates:
275
+ continue
276
+
277
+ # Ordenar por longitud de la cadena de valor (mayor a menor)
278
+ sorted_candidates = sorted(candidates, key=lambda x: len(x['valor']), reverse=True)
279
+
280
+ # El mejor candidato es el primero (el más largo)
281
+ best_candidate = sorted_candidates[0]
282
+
283
+ # Agregar al resultado final (ya desduplicado)
284
+ final_ner_results.append({
285
+ 'etiqueta': label,
286
+ 'valor': best_candidate['valor'],
287
+ 'bbox_entity': best_candidate['bbox_entity']
288
+ })
289
+
290
+
291
+ # Preparar tabla de resultados (Usando final_ner_results)
292
+ table_data = [[res['etiqueta'], res['valor']] for res in final_ner_results]
293
+
294
+ # 6. Dibujar Bounding Boxes en la Imagen (para visualización)
295
+ annotated_image = image.copy()
296
+ draw = ImageDraw.Draw(annotated_image)
297
+
298
+ try:
299
+ font = ImageFont.truetype("arial.ttf", 20)
300
+ except IOError:
301
+ font = ImageFont.load_default()
302
+
303
+ for res in final_ner_results: # Usar final_ner_results
304
+ label = res['etiqueta']
305
+ min_x_norm, min_y_norm, max_x_norm, max_y_norm = res['bbox_entity']
306
+
307
+ # Desnormalizar el bbox [0-1000] a píxeles
308
+ min_x = int(min_x_norm * image_width / 1000)
309
+ min_y = int(min_y_norm * image_height / 1000)
310
+ max_x = int(max_x_norm * image_width / 1000)
311
+ max_y = int(max_y_norm * image_height / 1000)
312
+
313
+ color = label2color.get(label, 'yellow')
314
+
315
+ draw.rectangle([min_x, min_y, max_x, max_y], outline=color, width=3)
316
+ draw.text((min_x, min_y - 20), label, fill=color, font=font)
317
+
318
+
319
+ # 7. Devolver resultados
320
+ return annotated_image, "Extracción de Entidades Nombradas completada.", table_data, [
321
+ {'etiqueta': r['etiqueta'], 'valor': r['valor'], 'bbox_entity': r['bbox_entity']}
322
+ for r in final_ner_results # Usar final_ner_results
323
+ ]
324
+
325
+ # --- 3. Interfaz Gradio ---
326
+
327
+ # Elementos de entrada y salida
328
+ image_input = gr.Image(type="pil", label="Cargar Imagen de Factura", interactive=True)
329
+ image_output = gr.Image(type="pil", label="Factura con Entidades Resaltadas")
330
+ status_output = gr.Textbox(label="Estado", value="Carga una imagen y haz clic en 'Procesar'")
331
+ table_output = gr.Dataframe(
332
+ headers=["Etiqueta", "Valor"],
333
+ label="Resultados de NER",
334
+ interactive=False,
335
+ col_count=(2, "fixed")
336
+ )
337
+ json_output = gr.JSON(label="Datos JSON Crudos (Incluye BBox Normalizados)", visible=True)
338
+
339
+ # Interfaz
340
+ with gr.Blocks(title="NER de Facturas Argentinas con LayoutLMv3 y DocTR") as demo:
341
+ gr.Markdown(
342
+ f"""
343
+ # 🇦🇷 Extracción de Datos de Facturas Argentinas (LayoutLMv3 + DocTR)
344
+ Carga una imagen de factura para realizar OCR (DocTR) y Reconocimiento de Entidades Nombradas (NER)
345
+ con un modelo **LayoutLMv3 fine-tuneado** cargado desde **`{STATE_DICT_PATH}`**, forzando la **ejecución en CPU**.
346
+ """
347
+ )
348
+
349
+ with gr.Row():
350
+ with gr.Column(scale=1):
351
+ image_input.render()
352
+ process_button = gr.Button("🚀 Procesar Factura", variant="primary")
353
+ status_output.render()
354
+ with gr.Column(scale=2):
355
+ image_output.render()
356
+ table_output.render()
357
+ json_output.render()
358
+
359
+ process_button.click(
360
+ fn=process_invoice,
361
+ inputs=[image_input],
362
+ outputs=[image_output, status_output, table_output, json_output]
363
+ )
364
+
365
+ # Lanzar la aplicación
366
+ demo.launch()
layoutlmv3_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:807818c88ce85767b337f03ce6ca7fd89ea14ce559c2981ea404cafc13557025
3
+ size 503825075
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Requerimientos del Frameworks y Utilidades ---
2
+ gradio>=4.0.0 # Interfaz de usuario
3
+ pillow # Manipulación de imágenes (PIL)
4
+ numpy # Operaciones numéricas
5
+
6
+ # --- Requerimientos de OCR (DocTR) y NER (Transformers) ---
7
+ # Usamos una versión más moderna de DocTR para asegurar compatibilidad
8
+ python-doctr[viz,html]>=1.0.0 # Librería DocTR (incluye dependencias de CPU como Pillow)
9
+ transformers>=4.30.0 # Librería principal para LayoutLMv3
10
+ torch
11
+ matplotlib # Añadir esta línea
12
+
13
+ # --- Requerimientos de PyTorch ---
14
+ # El archivo .pth requiere torch. Si lo instalas manualmente, puedes omitirlo.
15
+ # Si quieres que pip lo instale (incluso la versión CPU), descomenta:
16
+ # torch>=2.0.0
17
+
18
+ # --- Requerimientos CRÍTICOS (ya incluidos o buena práctica) ---
19
+ # protobuf se maneja internamente en transformers/torch.
20
+ # Solo añadir si hay problemas específicos.