# invoice_processor.py
"""
Invoice processing: OCR, NER and visualization
"""
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
from doctr.io import DocumentFile
from io import BytesIO

from config import LABEL2COLOR, MAX_LENGTH, NORMALIZATION_FACTOR
from validator import InvoiceValidator


class InvoiceProcessor:
    """Processes invoices and extracts entities."""

    def __init__(self, model_manager):
        """
        Initialize the invoice processor.

        Args:
            model_manager: ModelManager instance with the models loaded
        """
        self.model_manager = model_manager
        self.processor = model_manager.get_processor()
        self.model = model_manager.get_model()
        self.ocr_model = model_manager.get_ocr_model()
        self.device = model_manager.get_device()
        self.validator = InvoiceValidator()  # ✅ ADDED

    def extract_ocr_data(self, image: Image.Image):
        """
        Extract text and bounding boxes using DocTR.

        Args:
            image: PIL image of the invoice

        Returns:
            tuple: (words_data, image_width, image_height),
                   or (None, None, None) on error
        """
        try:
            rgb_image = image.convert("RGB")
            img_byte_arr = BytesIO()
            rgb_image.save(img_byte_arr, format='JPEG')
            img_byte_arr.seek(0)
            image_bytes = img_byte_arr.read()

            doctr_doc = DocumentFile.from_images([image_bytes])
            doctr_result = self.ocr_model(doctr_doc)

            if not doctr_result.pages:
                return None, None, None

            page = doctr_result.pages[0]
            words_data = []
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        text = word.value
                        # DocTR geometry is relative (0-1); scale it into the
                        # normalized coordinate space expected downstream
                        geom = np.array(word.geometry) * NORMALIZATION_FACTOR
                        xmin, ymin = map(int, geom[0])
                        xmax, ymax = map(int, geom[1])
                        words_data.append({"text": text, "box": [xmin, ymin, xmax, ymax]})

            image_width, image_height = image.size
            return words_data, image_width, image_height
        except Exception as e:
            print(f"OCR error: {e}")
            return None, None, None

    def perform_ner(self, image: Image.Image, words_data: list):
        """
        Run NER over the extracted words.

        Args:
            image: PIL image
            words_data: List of dicts with 'text' and 'box'

        Returns:
            list: Predictions for each word
        """
        words = [wd["text"] for wd in words_data]
        boxes = [wd["box"] for wd in words_data]

        # Preprocessing
        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        bbox = encoding["bbox"].to(self.device)
        pixel_values = encoding["pixel_values"].to(self.device)

        # Inference
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values
            )

        predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()

        # Map token-level predictions back to words: a word may be split into
        # several subword tokens, so keep only the first token's prediction
        word_ids = encoding.word_ids()
        predictions_final = []
        current_word_index = None
        for idx, pred_id in enumerate(predictions):
            word_idx = word_ids[idx]
            if word_idx is not None:
                if word_idx != current_word_index:
                    if len(predictions_final) < len(words):
                        predictions_final.append(self.model.config.id2label[pred_id])
                    current_word_index = word_idx

        return predictions_final
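    # Illustrative sketch of the BIO inputs group_entities consumes (the label
    # names here are assumptions, not taken from the real model config):
    # given words ["Invoice", "No.", "12345"] with predictions
    # ["B-INVOICE_NUM", "I-INVOICE_NUM", "I-INVOICE_NUM"], it would produce one
    # entity {'etiqueta': 'INVOICE_NUM', 'valor': 'Invoice No. 12345',
    # 'bbox_entity': <union of the three word boxes>}.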
    def group_entities(self, words_data: list, predictions: list):
        """
        Group entities using the BIO scheme plus deduplication.

        Args:
            words_data: List of words with their bboxes
            predictions: NER predictions for each word

        Returns:
            list: Final entities with label, value and bbox
        """
        ner_candidates = {}
        current_entity = []
        current_label = None
        current_bbox_group = []

        def save_current_entity(entity_list, label, bbox_list):
            if not entity_list or not label:
                return
            # Union box: the smallest rectangle covering every word box
            all_x = [b[0] for b in bbox_list] + [b[2] for b in bbox_list]
            all_y = [b[1] for b in bbox_list] + [b[3] for b in bbox_list]
            bbox_normalized = [min(all_x), min(all_y), max(all_x), max(all_y)]
            if label not in ner_candidates:
                ner_candidates[label] = []
            ner_candidates[label].append({
                'valor': " ".join(entity_list),
                'bbox_entity': bbox_normalized
            })

        for word_data, pred_label in zip(words_data, predictions):
            word_text = word_data["text"]
            word_box = word_data["box"]
            tag_parts = pred_label.split('-', 1)
            tag_type = tag_parts[0]
            root_label = tag_parts[1] if len(tag_parts) > 1 else None

            if tag_type == 'B':
                save_current_entity(current_entity, current_label, current_bbox_group)
                current_label = root_label
                current_entity = [word_text]
                current_bbox_group = [word_box]
            elif tag_type == 'I':
                if current_label == root_label:
                    current_entity.append(word_text)
                    current_bbox_group.append(word_box)
                else:
                    # Stray I- tag: close the open entity and start a new one
                    save_current_entity(current_entity, current_label, current_bbox_group)
                    current_label = root_label
                    current_entity = [word_text]
                    current_bbox_group = [word_box]
            elif tag_type == 'O':
                save_current_entity(current_entity, current_label, current_bbox_group)
                current_entity = []
                current_label = None
                current_bbox_group = []

        save_current_entity(current_entity, current_label, current_bbox_group)

        # Deduplication: keep the longest value for each label
        final_ner_results = []
        for label, candidates in ner_candidates.items():
            if not candidates:
                continue
            sorted_candidates = sorted(candidates, key=lambda x: len(x['valor']), reverse=True)
            best_candidate = sorted_candidates[0]
            final_ner_results.append({
                'etiqueta': label,
                'valor': best_candidate['valor'],
                'bbox_entity': best_candidate['bbox_entity']
            })

        return final_ner_results

    def draw_annotations(self, image: Image.Image, entities: list):
        """
        Draw bounding boxes and labels on the image.

        Args:
            image: Original PIL image
            entities: List of entities with bbox

        Returns:
            Image: Annotated image
        """
        annotated_image = image.copy()
        draw = ImageDraw.Draw(annotated_image)
        image_width, image_height = image.size

        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except IOError:
            font = ImageFont.load_default()

        for entity in entities:
            label = entity['etiqueta']
            min_x_norm, min_y_norm, max_x_norm, max_y_norm = entity['bbox_entity']

            # Denormalize coordinates back to pixel space
            min_x = int(min_x_norm * image_width / NORMALIZATION_FACTOR)
            min_y = int(min_y_norm * image_height / NORMALIZATION_FACTOR)
            max_x = int(max_x_norm * image_width / NORMALIZATION_FACTOR)
            max_y = int(max_y_norm * image_height / NORMALIZATION_FACTOR)

            color = LABEL2COLOR.get(label, 'yellow')
            draw.rectangle([min_x, min_y, max_x, max_y], outline=color, width=3)
            draw.text((min_x, min_y - 20), label, fill=color, font=font)

        return annotated_image
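    # Coordinate round-trip sketch (assuming NORMALIZATION_FACTOR = 1000 in
    # config.py, the usual LayoutLM-style convention): a DocTR word at relative
    # x = 0.25 becomes 250 in extract_ocr_data, and draw_annotations above maps
    # it back to pixels as int(250 * image_width / 1000) == image_width // 4.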
    def process_invoice(self, image: Image.Image, filename: str):
        """
        Process a complete invoice: OCR + NER + visualization + validation.

        Args:
            image: PIL image of the invoice
            filename: File name

        Returns:
            tuple: (filename, annotated_image, table_data, json_data)
        """
        # 1. OCR
        words_data, image_width, image_height = self.extract_ocr_data(image)
        if words_data is None:
            return filename, None, [["ERROR", "OCR could not be performed"]], []
        if not words_data:
            return filename, None, [["ERROR", "No text found in the image"]], []

        # Extract the word list for the validator
        ocr_words = [wd["text"] for wd in words_data]

        # 2. NER
        try:
            predictions = self.perform_ner(image, words_data)
        except Exception as e:
            return filename, None, [["ERROR", f"NER error: {e}"]], []

        # 3. Group entities
        entities = self.group_entities(words_data, predictions)

        # 4. Validate and correct entities
        validated_table, validation_errors = self.validator.validate_and_correct(entities, ocr_words)

        # 5. Draw annotations (only the originally detected entities)
        annotated_image = self.draw_annotations(image, entities)

        # 6. Prepare results
        # validated_table already comes as [etiqueta, valor] rows (no validation column)
        json_data = [
            {
                'etiqueta': row[0],
                'valor': row[1]
            }
            for row in validated_table
        ]

        return filename, annotated_image, validated_table, json_data
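

if __name__ == "__main__":
    # Minimal usage sketch. The "model_manager" module name and the sample
    # image path are assumptions; adjust them to the actual project layout.
    import sys
    from model_manager import ModelManager  # assumed module exposing ModelManager

    path = sys.argv[1] if len(sys.argv) > 1 else "invoice.jpg"  # hypothetical sample
    processor = InvoiceProcessor(ModelManager())
    name, annotated, table, json_data = processor.process_invoice(Image.open(path), path)
    if annotated is not None:
        annotated.save("annotated_" + name)
    print(table)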