# Authentica/textPreprocess.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import os
# ── 1) Configuration ────────────────────────────────────────────────────────────
BASE_DIR = "MAS-AI-0000/Authentica"
MODEL_DIR = os.path.join(BASE_DIR, "Lib/Models/Text") # Update this path to your model location
MAX_LEN = 512
# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")
# Global variables for model and tokenizer
tokenizer = None
model = None
ID2LABEL = {0: "human", 1: "ai"}
try:
    # Config carries id2label/label2id if you saved them
    config = AutoConfig.from_pretrained(MODEL_DIR)
    # Loads tokenizer.json + special_tokens_map.json automatically
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    # Loads model.safetensors automatically (no extra flags needed)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, config=config)
    model.eval().to(device)
    # Update label mapping from config if available
    ID2LABEL = model.config.id2label if getattr(model.config, "id2label", None) else {0: "human", 1: "ai"}
    print("Text classification model loaded successfully")
    print("Labels:", ID2LABEL)
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference function ──────────────────────────────────────────────────────
@torch.inference_mode()
def predict_text(text: str, max_length: int = None):
    """
    Predict whether the given text is human-written or AI-generated.

    Args:
        text (str): The text to classify.
        max_length (int): Maximum sequence length for tokenization (defaults to MAX_LEN).

    Returns:
        dict: Contains predicted_class and confidence.
    """
    if model is None or tokenizer is None:
        return {"predicted_class": "Human", "confidence": 0}
    if max_length is None:
        max_length = MAX_LEN
    try:
        # Tokenize input
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        # Get predictions
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()
        pred_id = int(probs.argmax(-1))
        # Get label (capitalize first letter for consistency)
        label = ID2LABEL.get(pred_id, str(pred_id))
        label = label.capitalize()  # "human" -> "Human", "ai" -> "Ai"
        return {
            "predicted_class": label,
            "confidence": float(probs[pred_id]),
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {"predicted_class": "Human", "confidence": 0}
# ── 4) Batch prediction (optional, for future use) ─────────────────────────────
@torch.inference_mode()
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify.
        batch_size (int): Batch size for processing.

    Returns:
        list: List of prediction dictionaries.
    """
    if model is None or tokenizer is None:
        return [{"predicted_class": "Human", "confidence": 0} for _ in texts]
    results = []
    for i in range(0, len(texts), batch_size):
        chunk = texts[i:i + batch_size]
        enc = tokenizer(
            chunk,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LEN,
            padding=True,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        ids = probs.argmax(-1)
        for t, pid, p in zip(chunk, ids, probs):
            label = ID2LABEL.get(int(pid), str(int(pid))).capitalize()
            results.append({
                "text": t,
                "predicted_class": label,
                "confidence": float(p[int(pid)]),
            })
    return results
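
# ── 5) Example usage (illustrative sketch, not part of the original pipeline) ──
# A minimal smoke test, assuming the model directory above is available locally;
# the sample strings are placeholders and can be replaced with any text.
if __name__ == "__main__":
    sample = "This is a short paragraph written to sanity-check the classifier."
    print(predict_text(sample))

    batch = [
        "First example text for batch prediction.",
        "Second example text for batch prediction.",
    ]
    for item in predict_batch(batch, batch_size=2):
        print(item["predicted_class"], round(item["confidence"], 4))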