# Authentica / textPreprocess.py
import torch
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
# Ensure local detree package is importable
# This allows the script to find the 'detree' package if it sits in the same directory
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)
try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check)
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None
# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512
MODEL_DIR = None
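# Note: snapshot_download() caches files in the local Hugging Face cache directory,
# so repeated runs reuse the existing download instead of re-fetching the files.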
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")
# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")
detector = None
try:
    if Detector is not None and MODEL_DIR is not None:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
        # Initialize the DETree Detector
        # This loads the model from MODEL_DIR and the embeddings from database_path
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max",  # Default pooling
        )
        print("Text classification model (DETree) loaded successfully")
    else:
        print("DETree detector could not be initialized (missing 'detree' package or model directory).")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference function ──────────────────────────────────────────────────────
def predict_text(text: str, max_length: int = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify
        max_length (int): Ignored in this implementation as DETree handles it globally,
                          but kept for compatibility.

    Returns:
        dict: Contains predicted_class, confidence_ai and confidence_human
    """
    if detector is None:
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }
    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])
        print(f"DETree prediction output: {predictions}")
        if not predictions:
            return {
                "predicted_class": "Human",
                "confidence_ai": -100.0,
                "confidence_human": -100.0
            }
        pred = predictions[0]
        # Determine predicted_class based on the higher confidence
        predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
        return {
            "predicted_class": predicted_class,
            "confidence_ai": float(pred.probability_ai),
            "confidence_human": float(pred.probability_human)
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        }
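# Illustrative example (assumes the detector loaded successfully; values are hypothetical):
#   predict_text("The quick brown fox jumps over the lazy dog.")
#   -> {"predicted_class": "Human", "confidence_ai": 0.12, "confidence_human": 0.88}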
# ── 4) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify
        batch_size (int): Batch size for processing

    Returns:
        list: List of prediction dictionaries
    """
    if detector is None:
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
    # Temporarily override the detector's batch size so the argument is respected,
    # then restore the original value afterwards
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Determine predicted_class based on the higher confidence
            predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
            results.append({
                "text": text,
                "predicted_class": predicted_class,
                "confidence_ai": float(pred.probability_ai),
                "confidence_human": float(pred.probability_human)
            })
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0
        } for _ in texts]
    finally:
        detector.batch_size = original_batch_size
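# ── 5) Usage sketch ────────────────────────────────────────────────────────────
# Minimal, illustrative-only demo of the two helpers above; the sample strings
# are hypothetical and the printed scores depend on the loaded model.
if __name__ == "__main__":
    sample = "This paragraph was written to demonstrate the detector interface."
    print(predict_text(sample))
    batch = [sample, "Another short hypothetical example sentence."]
    for result in predict_batch(batch, batch_size=2):
        print(result)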