import torch
import os
import sys
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download

# Ensure the local detree package is importable.
# This allows the script to find the 'detree' package if it sits in the same directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

try:
    from detree.inference import Detector
except ImportError as e:
    # Fallback if detree is not found (e.g. during initial setup check)
    print(f"Warning: 'detree' package not found. Error: {e}")
    Detector = None
# ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"  # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"
MAX_LEN = 512

MODEL_DIR = None
try:
    # Download a local snapshot of just the Text folder and point MODEL_DIR at it.
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {MODEL_DIR}")
except Exception as e:
    print(f"Error downloading model from Hugging Face: {e}")
# ── 2) Load model & tokenizer ───────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

detector = None
try:
    # Guard against a failed download (MODEL_DIR is None) as well as a missing package.
    if Detector and MODEL_DIR:
        database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
        if not os.path.exists(MODEL_DIR):
            print(f"Warning: Model directory not found at {MODEL_DIR}")
        if not os.path.exists(database_path):
            print(f"Warning: Embedding file not found at {database_path}")
        # Initialize the DETree detector.
        # This loads the model from MODEL_DIR and the embeddings from database_path.
        detector = Detector(
            database_path=database_path,
            model_name_or_path=MODEL_DIR,
            device=device,
            max_length=MAX_LEN,
            pooling="max",  # default pooling
        )
        print("Text classification model (DETree) loaded successfully")
    else:
        print("DETree detector could not be initialized (missing package or model directory).")
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference function ───────────────────────────────────────────────────────
def predict_text(text: str, max_length: Optional[int] = None):
    """
    Predict whether the given text is human-written or AI-generated using DETree.

    Args:
        text (str): The text to classify.
        max_length (Optional[int]): Ignored in this implementation, as DETree
            handles it globally; kept for compatibility.

    Returns:
        dict: Contains predicted_class, confidence_ai, and confidence_human.
            Confidence values of -100.0 are fallback sentinels used when no
            model is loaded or prediction fails.
    """
    if detector is None:
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0,
        }
    try:
        # detector.predict expects a list of strings
        predictions = detector.predict([text])
        print(f"DETree prediction output: {predictions}")
        if not predictions:
            return {
                "predicted_class": "Human",
                "confidence_ai": -100.0,
                "confidence_human": -100.0,
            }
        pred = predictions[0]
        # Pick the class with the higher confidence score
        predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
        return {
            "predicted_class": predicted_class,
            "confidence_ai": float(pred.probability_ai),
            "confidence_human": float(pred.probability_human),
        }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0,
        }
# ── 4) Batch prediction ─────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """
    Predict multiple texts in batches.

    Args:
        texts (list): List of text strings to classify.
        batch_size (int): Batch size for processing.

    Returns:
        list: List of prediction dictionaries.
    """
    if detector is None:
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0,
        } for _ in texts]

    # Temporarily override the detector's batch size so the argument is respected,
    # then restore the original value in the finally block.
    original_batch_size = detector.batch_size
    detector.batch_size = batch_size
    try:
        predictions = detector.predict(texts)
        results = []
        for text, pred in zip(texts, predictions):
            # Pick the class with the higher confidence score
            predicted_class = "AI" if pred.probability_ai > pred.probability_human else "Human"
            results.append({
                "text": text,
                "predicted_class": predicted_class,
                "confidence_ai": float(pred.probability_ai),
                "confidence_human": float(pred.probability_human),
            })
        return results
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        return [{
            "predicted_class": "Human",
            "confidence_ai": -100.0,
            "confidence_human": -100.0,
        } for _ in texts]
    finally:
        detector.batch_size = original_batch_size
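
# ── 5) Example usage ────────────────────────────────────────────────────────────
# A minimal smoke test, assuming this file is run directly after the model
# download above succeeds. The sample strings are hypothetical inputs; if the
# model failed to load, both calls return the -100.0 sentinel confidences.
if __name__ == "__main__":
    sample = "The quick brown fox jumps over the lazy dog."
    print(predict_text(sample))

    batch = [sample, "A second illustrative input for batch prediction."]
    for result in predict_batch(batch, batch_size=8):
        print(result)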