import torch
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download

# Ensure local detree package is importable
# This allows the script to find the 'detree' package if it sits in the same directory
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check)
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None


# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"   # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"  # pre-computed embedding database used by the detector
MAX_LEN = 512                            # maximum token length passed to the detector

MODEL_DIR = None

try:
    # download a local snapshot of just the Text folder and point MODEL_DIR at it
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")


# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

detector = None

try:
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")

        # Initialize DETree Detector
        # This loads the model from MODEL_DIR and the embeddings from database_path
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max" # Default pooling
        )
        print(f"Text classification model (DETree) loaded successfully")
    else:
        print("DETree detector could not be initialized due to missing package.")

except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")

# ── 3) Inference function ──────────────────────────────────────────────────────
def predict_text(text: str, max_length: int = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.
    
    Args:
        text (str): The text to classify
        max_length (int): Ignored in this implementation as DETree handles it globally, 
                          but kept for compatibility.
        
    Returns:
        dict: Contains predicted_class, confidence_ai and confidence_human
    """
    if detector is None:
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }
    
    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])

        print(f"DETree prediction output: {predictions}")

        if not predictions:
            return {
                "predicted_class": "Human",
                "confidence_ai": -100.0,
                "confidence_human": -100.0
            }
        
        pred = predictions[0]

        # Determine predicted_class based on higher confidence
        predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
        
        return {
            "predicted_class": predicted_class,
            "confidence_ai": float(pred.probability_ai),
            "confidence_human": float(pred.probability_human)
        }
        
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }

# ── 4) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.
    
    Args:
        texts (list): List of text strings to classify
        batch_size (int): Batch size for processing
    if detector is None:
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
        list: List of prediction dictionaries
    """
    if detector is None:
        return [{"predicted_class": "Human", "confidence": 0} for _ in texts]
    
    # Temporarily override the detector's batch size so the batch_size argument
    # is respected; the original value is restored in the finally block below.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    
    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Determine predicted_class based on higher confidence
            predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
            
            results.append({
                "text": text,
                "predicted_class": predicted_class,
                "confidence_ai": float(pred.probability_ai),
                "confidence_human": float(pred.probability_human)
            })
        return results
    
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
    finally:
        detector.batch_size = original_batch_size
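
# ── 5) Example usage ───────────────────────────────────────────────────────────
# Minimal sketch of how the functions above can be called when this module is
# run directly. The sample strings below are hypothetical placeholders, not part
# of any dataset shipped with the model.
if __name__ == "__main__":
    sample = "The quick brown fox jumps over the lazy dog."
    print(predict_text(sample))

    batch = [
        "This sentence was written by a person.",
        "As an AI language model, I can generate text on many topics.",
    ]
    for result in predict_batch(batch, batch_size=2):
        print(result["predicted_class"], result["confidence_ai"], result["confidence_human"])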