MAS-AI-0000 committed
Commit 8eaaeee · verified · 1 Parent(s): 2e29cb0

Update textPreprocess.py

Files changed (1): textPreprocess.py (+130 -66)
textPreprocess.py CHANGED
@@ -1,93 +1,157 @@
- import os
import torch
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
- from huggingface_hub import snapshot_download # <-- needed to pull the folder
+ import os
+ import sys
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
+
+ # Ensure local detree package is importable
+ # This allows the script to find the 'detree' package if it sits in the same directory
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ if current_dir not in sys.path:
+     sys.path.append(current_dir)
+
+ try:
+     from detree.inference import Detector
+ except ImportError:
+     # Fallback if detree is not found (e.g. during initial setup check)
+     print("Warning: 'detree' package not found. Please ensure the 'detree' folder is in the same directory.")
+     Detector = None

- # ── 1) PATHS / VARS ────────────────────────────────────────────────────────────
+ # ── 1) Configuration ────────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text" # where config.json/model.safetensors live in the repo
+ EMBEDDING_FILE = "priori1_center10k.pt"
+ MAX_LEN = 512

- # download a local snapshot of just the Text folder and point MODEL_DIR at it
- _snapshot_dir = snapshot_download(
-     repo_id=REPO_ID,
-     allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
- )
- MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
-
- # individual file paths (in case you need them elsewhere)
- CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
- MODEL_SAFETENSORS_PATH = os.path.join(MODEL_DIR, "model.safetensors")
- TOKENIZER_JSON_PATH = os.path.join(MODEL_DIR, "tokenizer.json")
- TOKENIZER_CONFIG_PATH = os.path.join(MODEL_DIR, "tokenizer_config.json")
- SPECIAL_TOKENS_MAP_PATH = os.path.join(MODEL_DIR, "special_tokens_map.json")
- TRAINING_ARGS_BIN_PATH = os.path.join(MODEL_DIR, "training_args.bin") # optional
- TEXT_TXT_PATH = os.path.join(MODEL_DIR, "text.txt") # optional
+ MODEL_DIR = None

- MAX_LEN = 512
+ try:
+     # download a local snapshot of just the Text folder and point MODEL_DIR at it
+     print(f"Downloading/Checking model from {REPO_ID}...")
+     _snapshot_dir = snapshot_download(
+         repo_id=REPO_ID,
+         allow_patterns=[f"{TEXT_SUBFOLDER}/*"]
+     )
+     MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
+     print(f"Model directory set to: {MODEL_DIR}")
+ except Exception as e:
+     print(f"Error downloading model from Hugging Face: {e}")

# ── 2) Load model & tokenizer ──────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

- tokenizer = None
- model = None
- ID2LABEL = {0: "human", 1: "ai"}
+ detector = None

try:
-     # load directly from the local MODEL_DIR
-     config = AutoConfig.from_pretrained(MODEL_DIR)
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
-     model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, config=config)
-     model.eval().to(device)
+     if Detector and MODEL_DIR:
+         database_path = os.path.join(MODEL_DIR, EMBEDDING_FILE)
+
+         if not os.path.exists(MODEL_DIR):
+             print(f"Warning: Model directory not found at {MODEL_DIR}")
+         if not os.path.exists(database_path):
+             print(f"Warning: Embedding file not found at {database_path}")
+             print(f"Please ensure '{EMBEDDING_FILE}' is present in '{TEXT_SUBFOLDER}' of the Hugging Face repo.")

-     # override labels from config if present
-     if getattr(model.config, "id2label", None):
-         ID2LABEL = {int(k): v for k, v in model.config.id2label.items()}
+         # Initialize DETree Detector
+         # This loads the model from MODEL_DIR and the embeddings from database_path
+         detector = Detector(
+             database_path=database_path,
+             model_name_or_path=MODEL_DIR,
+             device=device,
+             max_length=MAX_LEN,
+             pooling="max" # Default pooling
+         )
+         print(f"Text classification model (DETree) loaded successfully")
+     else:
+         if not Detector:
+             print("DETree detector could not be initialized due to missing package.")
+         if not MODEL_DIR:
+             print("DETree detector could not be initialized due to missing model directory.")

-     print("Text classification model loaded successfully")
-     print("MODEL_DIR:", MODEL_DIR)
-     print("Labels:", ID2LABEL)
except Exception as e:
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")

- # ── 3) Inference ───────────────────────────────────────────────────────────────
- @torch.inference_mode()
- def predict_text(text: str, max_length: int | None = None):
-     if model is None or tokenizer is None:
-         print("Issue 1")
+ # ── 3) Inference function ──────────────────────────────────────────────────────
+ def predict_text(text: str, max_length: int = None):
+     """
+     Predict whether the given text is human-written or AI-generated using DETree.
+
+     Args:
+         text (str): The text to classify
+         max_length (int): Ignored in this implementation as DETree handles it globally,
+                           but kept for compatibility.
+
+     Returns:
+         dict: Contains predicted_class and confidence
+     """
+     if detector is None:
        return {"predicted_class": "Human", "confidence": -100.0}
-
-     if max_length is None:
-         max_length = MAX_LEN
-
+
    try:
-         enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
-         enc = {k: v.to(device) for k, v in enc.items()}
-         logits = model(**enc).logits
-         probs = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()
-         pred_id = int(probs.argmax(-1))
-         label = ID2LABEL.get(pred_id, str(pred_id)).capitalize()
-         return {"predicted_class": label, "confidence": float(probs[pred_id])}
+         # detector.predict expects a list of strings
+         predictions = detector.predict([text])
+         if not predictions:
+             return {"predicted_class": "Human", "confidence": -100.0}
+
+         pred = predictions[0]
+         # pred.label is "Human" or "AI"
+         # Map to "Human" or "Ai" to match previous API
+         label = pred.label
+         if label == "AI":
+             label = "Ai"
+
+         # Confidence logic:
+         # If label is Human, use probability_human
+         # If label is Ai, use probability_ai
+         confidence = pred.probability_human if label == "Human" else pred.probability_ai
+
+         return {
+             "predicted_class": label,
+             "confidence": float(confidence)
+         }
    except Exception as e:
        print(f"Error during text prediction: {e}")
        return {"predicted_class": "Human", "confidence": -100.0}

- # ── 4) Batch (optional) ────────────────────────────────────────────────────────
- @torch.inference_mode()
+ # ── 4) Batch prediction ────────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
-     if model is None or tokenizer is None:
-         print("Issue 2")
+     """
+     Predict multiple texts in batches.
+
+     Args:
+         texts (list): List of text strings to classify
+         batch_size (int): Batch size for processing
+
+     Returns:
+         list: List of prediction dictionaries
+     """
+     if detector is None:
        return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
-
-     results = []
-     for i in range(0, len(texts), batch_size):
-         chunk = texts[i:i+batch_size]
-         enc = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_LEN, padding=True)
-         enc = {k: v.to(device) for k, v in enc.items()}
-         probs = torch.softmax(model(**enc).logits, dim=-1).detach().cpu().numpy()
-         ids = probs.argmax(-1)
-         for t, pid, p in zip(chunk, ids, probs):
-             label = ID2LABEL.get(int(pid), str(int(pid))).capitalize()
-             results.append({"text": t, "predicted_class": label, "confidence": float(p[int(pid)])})
-     return results
+
+     # Temporarily update batch size if needed, or just use the detector's default
+     # We'll update it to respect the argument
+     original_batch_size = detector.batch_size
+     detector.batch_size = batch_size
+
+     try:
+         predictions = detector.predict(texts)
+         results = []
+         for text, pred in zip(texts, predictions):
+             label = pred.label
+             if label == "AI":
+                 label = "Ai"
+             confidence = pred.probability_human if label == "Human" else pred.probability_ai
+
+             results.append({
+                 "text": text,
+                 "predicted_class": label,
+                 "confidence": float(confidence)
+             })
+         return results
+     except Exception as e:
+         print(f"Error during batch prediction: {e}")
+         return [{"predicted_class": "Human", "confidence": -100.0} for _ in texts]
+     finally:
+         detector.batch_size = original_batch_size
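
For reference, a minimal usage sketch of the updated module follows; it is not part of the commit. It assumes the file is importable as textPreprocess, that the detree package and the Hugging Face snapshot are available at import time, and that predict_text / predict_batch return {"predicted_class": "Human" or "Ai", "confidence": float} with -100.0 marking the fallback path, as in the diff above. The sample strings are made up.

# Hypothetical usage sketch (assumes textPreprocess.py sits next to the 'detree'
# package and that the model snapshot downloads successfully on import).
import textPreprocess

# Single text: returns {"predicted_class": "Human" | "Ai", "confidence": float}
single = textPreprocess.predict_text("The quick brown fox jumps over the lazy dog.")
print(single)

# Batch of texts: each result also carries the original text
batch = textPreprocess.predict_batch(
    ["First sample text.", "Second sample text."],
    batch_size=8,
)
for item in batch:
    print(item["predicted_class"], round(item["confidence"], 3), item["text"][:40])

# A confidence of -100.0 indicates the detector failed to load or predict
if single["confidence"] == -100.0:
    print("Detector unavailable; fallback response returned.")

Note that predict_batch temporarily overrides detector.batch_size with the batch_size argument and restores the original value in a finally block, so a per-call batch size does not leak into later calls.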