File size: 2,431 Bytes
809b92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# model_loader.py
"""
Carga y gestión de modelos (LayoutLMv3 y DocTR)
"""

import torch
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification
from doctr.models import ocr_predictor
import warnings
from config import HUGGINGFACE_MODEL, ID2LABEL, LABEL2ID

warnings.filterwarnings('ignore')


class ModelManager:
    """Clase para gestionar la carga y acceso a los modelos."""
    
    def __init__(self, force_cpu=True):
        """
        Inicializa y carga los modelos necesarios.
        
        Args:
            force_cpu (bool): Si True, fuerza el uso de CPU para inferencia
        """
        self.device = torch.device("cpu" if force_cpu else "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Inferencia en dispositivo: {self.device}")
        
        # Cargar LayoutLMv3
        self.processor, self.model = self._load_layoutlmv3()
        
        # Cargar DocTR
        self.ocr_model = self._load_doctr()
    
    def _load_layoutlmv3(self):
        """Carga el modelo LayoutLMv3 y su procesador."""
        try:
            processor = AutoProcessor.from_pretrained(HUGGINGFACE_MODEL, apply_ocr=False)
            model = LayoutLMv3ForTokenClassification.from_pretrained(HUGGINGFACE_MODEL).to(self.device)
            model.config.id2label = ID2LABEL
            model.config.label2id = LABEL2ID
            print(f"✓ Modelo LayoutLMv3 cargado: {HUGGINGFACE_MODEL}")
            return processor, model
        except Exception as e:
            print(f"✗ Error al cargar LayoutLMv3: {e}")
            raise
    
    def _load_doctr(self):
        """Carga el modelo OCR de DocTR."""
        try:
            ocr_model = ocr_predictor(
                det_arch='db_resnet50', 
                reco_arch='crnn_vgg16_bn', 
                pretrained=True
            )
            print("✓ Modelo DocTR cargado")
            return ocr_model
        except Exception as e:
            print(f"✗ Error al cargar DocTR: {e}")
            raise
    
    def get_processor(self):
        """Retorna el procesador de LayoutLMv3."""
        return self.processor
    
    def get_model(self):
        """Retorna el modelo de LayoutLMv3."""
        return self.model
    
    def get_ocr_model(self):
        """Retorna el modelo OCR de DocTR."""
        return self.ocr_model
    
    def get_device(self):
        """Retorna el dispositivo utilizado."""
        return self.device