|
|
|
|
|
""" |
|
|
Carga y gestión de modelos (LayoutLMv3 y DocTR) |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification |
|
|
from doctr.models import ocr_predictor |
|
|
import warnings |
|
|
from config import HUGGINGFACE_MODEL, ID2LABEL, LABEL2ID |
|
|
|
|
|
# NOTE(review): this runs at import time and installs a catch-all "ignore"
# filter at the front of the process-wide warnings filter list, so it affects
# every module in the process, not only this one.
warnings.filterwarnings(action='ignore')
|
|
|
|
|
|
|
|
class ModelManager:
    """Load the LayoutLMv3 and DocTR models once and expose them to callers.

    Attributes:
        device: ``torch.device`` the LayoutLMv3 model is moved to.
        processor: LayoutLMv3 processor, created with ``apply_ocr=False``.
        model: LayoutLMv3 token-classification model whose config label maps
            are overwritten with ``ID2LABEL`` / ``LABEL2ID``.
        ocr_model: DocTR OCR predictor (``db_resnet50`` detection,
            ``crnn_vgg16_bn`` recognition). This class never moves it to
            ``device``.
    """

    def __init__(self, force_cpu: bool = True) -> None:
        """Choose the inference device, then load both models.

        Args:
            force_cpu: If true, CPU is used even when CUDA is available. If
                false, CUDA is used when ``torch.cuda.is_available()`` is
                true, otherwise CPU.

        Raises:
            Exception: Any error from the loaders is re-raised after an
                error message is printed.
        """
        # CPU unless the caller opted out of it *and* CUDA is actually present.
        if force_cpu or not torch.cuda.is_available():
            device_name = "cpu"
        else:
            device_name = "cuda"
        self.device = torch.device(device_name)
        print(f"Inferencia en dispositivo: {self.device}")

        # self.device has to exist before this call: the loader reads it.
        self.processor, self.model = self._load_layoutlmv3()
        self.ocr_model = self._load_doctr()

    def _load_layoutlmv3(self):
        """Build the LayoutLMv3 processor and token-classification model.

        The processor is created with ``apply_ocr=False`` (presumably because
        OCR is supplied by DocTR instead). The model is moved to
        ``self.device`` and its config label maps are replaced by the
        project's ``ID2LABEL`` / ``LABEL2ID``.

        Returns:
            tuple: ``(processor, model)``.

        Raises:
            Exception: Any loading error, re-raised after printing a message.
        """
        try:
            proc = AutoProcessor.from_pretrained(HUGGINGFACE_MODEL, apply_ocr=False)
            network = LayoutLMv3ForTokenClassification.from_pretrained(HUGGINGFACE_MODEL)
            network = network.to(self.device)
            cfg = network.config
            cfg.id2label = ID2LABEL
            cfg.label2id = LABEL2ID
            print(f"✓ Modelo LayoutLMv3 cargado: {HUGGINGFACE_MODEL}")
            return proc, network
        except Exception as exc:
            print(f"✗ Error al cargar LayoutLMv3: {exc}")
            raise

    def _load_doctr(self):
        """Build the DocTR end-to-end OCR predictor with pretrained weights.

        Returns:
            The predictor returned by ``doctr.models.ocr_predictor``.

        Raises:
            Exception: Any loading error, re-raised after printing a message.
        """
        predictor_kwargs = {
            "det_arch": 'db_resnet50',     # detection architecture
            "reco_arch": 'crnn_vgg16_bn',  # recognition architecture
            "pretrained": True,
        }
        try:
            predictor = ocr_predictor(**predictor_kwargs)
            print("✓ Modelo DocTR cargado")
            return predictor
        except Exception as exc:
            print(f"✗ Error al cargar DocTR: {exc}")
            raise

    def get_processor(self):
        """Return the LayoutLMv3 processor built during initialization."""
        return self.processor

    def get_model(self):
        """Return the LayoutLMv3 token-classification model."""
        return self.model

    def get_ocr_model(self):
        """Return the DocTR OCR predictor."""
        return self.ocr_model

    def get_device(self) -> torch.device:
        """Return the device chosen in ``__init__``."""
        return self.device