Lucas Gagneten
Interfaz mejorada
809b92e
# model_loader.py
"""
Carga y gestión de modelos (LayoutLMv3 y DocTR)
"""
import torch
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification
from doctr.models import ocr_predictor
import warnings
from config import HUGGINGFACE_MODEL, ID2LABEL, LABEL2ID
warnings.filterwarnings('ignore')
class ModelManager:
"""Clase para gestionar la carga y acceso a los modelos."""
def __init__(self, force_cpu=True):
"""
Inicializa y carga los modelos necesarios.
Args:
force_cpu (bool): Si True, fuerza el uso de CPU para inferencia
"""
self.device = torch.device("cpu" if force_cpu else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Inferencia en dispositivo: {self.device}")
# Cargar LayoutLMv3
self.processor, self.model = self._load_layoutlmv3()
# Cargar DocTR
self.ocr_model = self._load_doctr()
def _load_layoutlmv3(self):
"""Carga el modelo LayoutLMv3 y su procesador."""
try:
processor = AutoProcessor.from_pretrained(HUGGINGFACE_MODEL, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(HUGGINGFACE_MODEL).to(self.device)
model.config.id2label = ID2LABEL
model.config.label2id = LABEL2ID
print(f"✓ Modelo LayoutLMv3 cargado: {HUGGINGFACE_MODEL}")
return processor, model
except Exception as e:
print(f"✗ Error al cargar LayoutLMv3: {e}")
raise
def _load_doctr(self):
"""Carga el modelo OCR de DocTR."""
try:
ocr_model = ocr_predictor(
det_arch='db_resnet50',
reco_arch='crnn_vgg16_bn',
pretrained=True
)
print("✓ Modelo DocTR cargado")
return ocr_model
except Exception as e:
print(f"✗ Error al cargar DocTR: {e}")
raise
def get_processor(self):
"""Retorna el procesador de LayoutLMv3."""
return self.processor
def get_model(self):
"""Retorna el modelo de LayoutLMv3."""
return self.model
def get_ocr_model(self):
"""Retorna el modelo OCR de DocTR."""
return self.ocr_model
def get_device(self):
"""Retorna el dispositivo utilizado."""
return self.device