"""DETree-based human-vs-AI text classification.

At import time this module downloads the ``Lib/Models/Text`` folder of the
``MAS-AI-0000/Authentica`` Hugging Face repo and builds a DETree ``Detector``
from it. It exposes :func:`predict_text` and :func:`predict_batch`. If the
``detree`` package, the download, or the model load is unavailable, the
prediction functions return sentinel "fallback" results instead of raising.
"""
import os
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import torch
from huggingface_hub import snapshot_download

# Ensure the local detree package is importable.
# This allows the script to find the 'detree' package if it sits in the same directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check).
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None

# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"  # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512

MODEL_DIR = None
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it.
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")

# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

detector = None
try:
    if Detector is None:
        print("DETree detector could not be initialized due to missing package.")
    elif MODEL_DIR is None:
        # The download above failed. Report that directly instead of letting
        # os.path.join(None, ...) raise an unrelated-looking TypeError.
        print("DETree detector could not be initialized: model directory is unavailable.")
        print("Text prediction will return fallback responses")
    else:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        # Missing files only produce warnings; Detector() below is still attempted
        # and any resulting failure is reported by the except clause.
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")

        # Initialize the DETree Detector: it loads the model from MODEL_DIR
        # and the embeddings from database_path.
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max",  # Default pooling
        )
        print("Text classification model (DETree) loaded successfully")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")


# ── 3) Result helpers ──────────────────────────────────────────────────────────
def _fallback_prediction() -> Dict[str, Any]:
    """Return the sentinel result used whenever a real prediction is unavailable.

    The class defaults to "Human" and both confidences are -100.0, a sentinel
    rather than a real probability. Callers can use it to detect the fallback.
    A fresh dict is returned on every call, so callers may mutate it safely.
    """
    return {
        "predicted_class": "Human",
        "confidence_ai": -100.0,
        "confidence_human": -100.0,
    }


def _format_prediction(pred: Any) -> Dict[str, Any]:
    """Convert one DETree prediction object into the public result dict.

    Args:
        pred: An element returned by ``detector.predict``. It must expose
            ``probability_ai`` and ``probability_human`` attributes.

    Returns:
        dict: ``predicted_class`` ("AI" only when ``probability_ai`` is strictly
        greater, so ties go to "Human"), plus both probabilities as floats.
    """
    prob_ai = float(pred.probability_ai)
    prob_human = float(pred.probability_human)
    return {
        "predicted_class": "AI" if prob_ai > prob_human else "Human",
        "confidence_ai": prob_ai,
        "confidence_human": prob_human,
    }


# ── 4) Inference function ──────────────────────────────────────────────────────
def predict_text(text: str, max_length: Optional[int] = None) -> Dict[str, Any]:
    """Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (int): Ignored. The detector's max length is fixed globally
            (``MAX_LEN``) when it is constructed; the parameter is kept for
            compatibility.

    Returns:
        dict: ``predicted_class`` ("AI" or "Human"), ``confidence_ai`` and
        ``confidence_human``. If the detector is unavailable, returns nothing,
        or raises, the fallback result is returned instead (class "Human",
        both confidences -100.0).
    """
    if detector is None:
        return _fallback_prediction()

    try:
        # detector.predict expects a list of strings.
        predictions = detector.predict([text])
        print(f"DETree prediction output: {predictions}")
        if not predictions:
            return _fallback_prediction()
        return _format_prediction(predictions[0])
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return _fallback_prediction()


# ── 5) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts: Iterable[str], batch_size: int = 16) -> List[Dict[str, Any]]:
    """Predict multiple texts in batches.

    Args:
        texts (iterable of str): Texts to classify. Any iterable is accepted
            (including a generator); it is materialized once up front.
        batch_size (int): Batch size for processing. The detector's own
            ``batch_size`` is overridden for this call and restored afterwards.

    Returns:
        list: One dict per input text, in input order. Each has ``text``,
        ``predicted_class``, ``confidence_ai`` and ``confidence_human``. When a
        real prediction is unavailable (no detector, an error, or the detector
        returned too few predictions), the entry carries the fallback values
        (class "Human", both confidences -100.0).
    """
    # Materialize once: the texts are used both for the model call and for
    # pairing with results, which would silently exhaust a one-shot iterator.
    texts = list(texts)

    if detector is None:
        return [{"text": text, **_fallback_prediction()} for text in texts]
    if not texts:
        return []

    # Temporarily override the detector's batch size to respect the argument.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = list(detector.predict(texts))
        if len(predictions) != len(texts):
            # zip() would silently drop unmatched texts. Keep one result per
            # input and mark the missing ones as fallbacks instead.
            print(
                f"Warning: DETree returned {len(predictions)} predictions "
                f"for {len(texts)} texts"
            )

        results = []
        for index, text in enumerate(texts):
            if index < len(predictions):
                results.append({"text": text, **_format_prediction(predictions[index])})
            else:
                results.append({"text": text, **_fallback_prediction()})
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{"text": text, **_fallback_prediction()} for text in texts]
    finally:
        detector.batch_size = original_batch_size