# model_loader.py """ Carga y gestión de modelos (LayoutLMv3 y DocTR) """ import torch from transformers import AutoProcessor, LayoutLMv3ForTokenClassification from doctr.models import ocr_predictor import warnings from config import HUGGINGFACE_MODEL, ID2LABEL, LABEL2ID warnings.filterwarnings('ignore') class ModelManager: """Clase para gestionar la carga y acceso a los modelos.""" def __init__(self, force_cpu=True): """ Inicializa y carga los modelos necesarios. Args: force_cpu (bool): Si True, fuerza el uso de CPU para inferencia """ self.device = torch.device("cpu" if force_cpu else "cuda" if torch.cuda.is_available() else "cpu") print(f"Inferencia en dispositivo: {self.device}") # Cargar LayoutLMv3 self.processor, self.model = self._load_layoutlmv3() # Cargar DocTR self.ocr_model = self._load_doctr() def _load_layoutlmv3(self): """Carga el modelo LayoutLMv3 y su procesador.""" try: processor = AutoProcessor.from_pretrained(HUGGINGFACE_MODEL, apply_ocr=False) model = LayoutLMv3ForTokenClassification.from_pretrained(HUGGINGFACE_MODEL).to(self.device) model.config.id2label = ID2LABEL model.config.label2id = LABEL2ID print(f"✓ Modelo LayoutLMv3 cargado: {HUGGINGFACE_MODEL}") return processor, model except Exception as e: print(f"✗ Error al cargar LayoutLMv3: {e}") raise def _load_doctr(self): """Carga el modelo OCR de DocTR.""" try: ocr_model = ocr_predictor( det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True ) print("✓ Modelo DocTR cargado") return ocr_model except Exception as e: print(f"✗ Error al cargar DocTR: {e}") raise def get_processor(self): """Retorna el procesador de LayoutLMv3.""" return self.processor def get_model(self): """Retorna el modelo de LayoutLMv3.""" return self.model def get_ocr_model(self): """Retorna el modelo OCR de DocTR.""" return self.ocr_model def get_device(self): """Retorna el dispositivo utilizado.""" return self.device