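"""Text classification module: detects whether text is human-written or
AI-generated using a DETree detector, with model weights and reference
embeddings downloaded from the MAS-AI-0000/Authentica Hugging Face repo."""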
import torch
import os
import sys
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
# Ensure the local detree package is importable: this lets the script find
# the 'detree' package when it sits in the same directory as this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during an initial setup check)
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None
# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512
MODEL_DIR = None
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it.
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")
# ── 2) Load model & tokenizer ───────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")
detector = None
try:
    # Guard against a failed download above, in which case MODEL_DIR is None.
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
        # Initialize the DETree detector: loads the model from MODEL_DIR
        # and the reference embeddings from database_path.
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max"  # default pooling
        )
        print("Text classification model (DETree) loaded successfully")
    else:
        print("DETree detector could not be initialized: missing package or model directory.")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference function ───────────────────────────────────────────────────────
def predict_text(text: str, max_length: Optional[int] = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (int, optional): Ignored here, since DETree applies MAX_LEN
            globally; kept for API compatibility.

    Returns:
        dict: Contains predicted_class, confidence_ai, and confidence_human.
    """
    if detector is None:
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }
    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])
        print(f"DETree prediction output: {predictions}")
        if not predictions:
            return {
                "predicted_class": "Human",
                "confidence_ai": -100.0,
                "confidence_human": -100.0
            }
        pred = predictions[0]
        # Pick the class with the higher probability.
        predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
        return {
            "predicted_class": predicted_class,
            "confidence_ai": float(pred.probability_ai),
            "confidence_human": float(pred.probability_human)
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }
# ── 4) Batch prediction ─────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify.
        batch_size (int): Batch size for processing.

    Returns:
        list: List of prediction dictionaries.
    """
    if detector is None:
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
    # Temporarily override the detector's batch size so the argument is
    # respected, restoring the original value afterwards.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Pick the class with the higher probability.
            predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
            results.append({
                "text": text,
                "predicted_class": predicted_class,
                "confidence_ai": float(pred.probability_ai),
                "confidence_human": float(pred.probability_human)
            })
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
    finally:
        detector.batch_size = original_batch_size
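# ── 5) Usage example ────────────────────────────────────────────────────────────
# Minimal smoke test when the module is run directly. The sample strings are
# illustrative placeholders only, not part of any evaluation set.
if __name__ == "__main__":
    sample = "The quick brown fox jumps over the lazy dog."
    print(predict_text(sample))
    print(predict_batch([sample, "Another short example sentence."], batch_size=2))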