Spaces:
Sleeping
Sleeping
File size: 7,557 Bytes
76d2426 a4b6d10 76d2426 a4b6d10 76d2426 a4b6d10 76d2426 a4b6d10 76d2426 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from PIL import Image
import gradio as gr
import re
import pandas as pd
import joblib
import datetime
import matplotlib.pyplot as plt
from io import BytesIO
from nltk.tokenize import TreebankWordTokenizer
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
import os
import time
import zipfile
import nltk
# Download wordnet resource to avoid LookupError
# (WordNetLemmatizer lazily loads the 'wordnet' corpus on first use.)
nltk.download('wordnet')
# Load models and label mapping
# Pre-trained artifacts produced offline: the fitted LDA model, the
# bag-of-words vectorizer used at training time, and a topic-index ->
# human-readable-label mapping. All three must exist in the working dir.
lda = joblib.load("lda_model.joblib")
vectorizer = joblib.load("vectorizer.joblib")
auto_labels = joblib.load("topic_labels.joblib")
# Optional topic summaries
# Keyed by the human-readable label (not the topic index); looked up in
# predict_topic to show a one-line description under the predicted label.
topic_summaries = {
    "Politics & Gun Rights": "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware": "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software": "Programming terms, file handling, software output.",
    "Sports & Games": "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine": "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy": "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA": "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security": "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking": "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts": "Topics involving Israel, Armenia, conflict regions."
}
# Tokenizer and lemmatizer
# TreebankWordTokenizer needs no extra corpus download; the lemmatizer
# relies on the 'wordnet' corpus fetched above.
tokenizer = TreebankWordTokenizer()
lemmatizer = WordNetLemmatizer()
# --- Utility Functions ---
def preprocess(text):
    """Normalize raw text for the LDA vectorizer.

    Lowercases, collapses non-word characters to spaces, tokenizes, then
    keeps only alphabetic tokens longer than 2 chars that are not English
    stop words, lemmatizing each kept token. Returns a space-joined string.
    """
    normalized = re.sub(r'\W+', ' ', text.lower())
    kept = []
    for token in tokenizer.tokenize(normalized):
        if len(token) > 2 and token.isalpha() and token not in ENGLISH_STOP_WORDS:
            kept.append(lemmatizer.lemmatize(token))
    return ' '.join(kept)
def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the top_n highest-weighted vocabulary terms for one LDA topic.

    model must expose components_ (topic-term weight matrix) and vectorizer
    must expose get_feature_names_out(), as scikit-learn estimators do.
    """
    vocab = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    # argsort is ascending; reverse it and take the first top_n (largest first)
    ranked = weights.argsort()[::-1][:top_n]
    return [vocab[i] for i in ranked]
def plot_topic_distribution(distribution, labels):
    """Bar-chart the per-topic probabilities and return the plot as a PIL Image.

    The figure is rendered to an in-memory PNG so no temp file is created;
    the matplotlib figure is closed before returning to avoid leaking state.
    """
    positions = range(len(distribution))
    plt.figure(figsize=(8, 4))
    plt.bar(positions, distribution, tick_label=labels)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Probability")
    plt.title("Topic Distribution")
    plt.tight_layout()
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)
def save_prediction_file(text):
    """Write the prediction report to a timestamped .txt file; return its name.

    The filename embeds the current local time so successive predictions do
    not overwrite each other (cleanup_old_predictions prunes stale ones).
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_name = f"lda_prediction_{stamp}.txt"
    with open(out_name, "w", encoding="utf-8") as handle:
        handle.write(text)
    return out_name
def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Delete stale lda_prediction_* files older than max_age_minutes.

    Only files whose name starts with 'lda_prediction_' and ends with
    `extension` are considered; deletion failures are logged, not raised.
    """
    cutoff_seconds = max_age_minutes * 60
    current = time.time()
    for entry in os.listdir(directory):
        if not (entry.startswith("lda_prediction_") and entry.endswith(extension)):
            continue
        path = os.path.join(directory, entry)
        if os.path.isfile(path) and current - os.path.getmtime(path) > cutoff_seconds:
            try:
                os.remove(path)
            except Exception as exc:
                # Best-effort cleanup: report and keep going.
                print(f"Failed to delete {entry}: {exc}")
def download_log():
    """Bundle the cumulative prediction CSV log into a zip; return the zip name.

    If the CSV does not exist yet, an empty (but valid) archive is produced
    so the download button never errors out.
    """
    archive_name = "lda_predictions_log.zip"
    source_csv = "lda_predictions_log.csv"
    with zipfile.ZipFile(archive_name, "w", zipfile.ZIP_DEFLATED) as bundle:
        if os.path.exists(source_csv):
            bundle.write(source_csv)
    return archive_name
def save_feedback(text, feedback):
    """Append one feedback row (timestamp, rating, text excerpt) to the CSV log.

    The header is written only when the log file does not exist yet, so the
    CSV stays well-formed across appends. Returns a confirmation string for
    the UI.
    """
    feedback_log = "lda_feedback_log.csv"
    row = {
        "timestamp": datetime.datetime.now().isoformat(),
        "feedback": feedback,
        "text_excerpt": text[:300].replace('\n', ' ') + "...",
    }
    write_header = not os.path.exists(feedback_log)
    pd.DataFrame([row]).to_csv(feedback_log, mode='a', header=write_header, index=False)
    return " Feedback recorded. Thank you!"
# --- Main Prediction Function ---
def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded .txt file.

    Parameters:
        text_input: str or None -- text pasted into the textbox (Gradio may
            pass None for an empty textbox).
        file_input: upload from gr.File -- depending on the Gradio version
            this is a file-like object, an object exposing .name, or a plain
            filepath string; all three are handled.

    Returns:
        (result_markdown, chart_image, prediction_file_path), or
        ("Please provide input", None, None) when neither input was given.
    """
    cleanup_old_predictions()
    # Resolve the raw text; an uploaded file takes precedence over pasted text.
    if file_input is not None:
        if hasattr(file_input, "read"):
            raw = file_input.read()
            # read() may yield bytes or str depending on how Gradio opened it.
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        else:
            # Newer Gradio versions hand over a filepath, not a file object.
            path = file_input if isinstance(file_input, str) else file_input.name
            with open(path, "r", encoding="utf-8") as fh:
                text = fh.read()
    elif text_input and text_input.strip():  # guard: textbox value may be None
        text = text_input
    else:
        return "Please provide input", None, None

    # Vectorize and infer the per-topic probability distribution.
    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    dominant_topic = topic_distribution.argmax()
    label = auto_labels.get(dominant_topic, f"Topic {dominant_topic+1}")
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Confidence threshold warning: flag predictions the model is unsure of.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    # Append this prediction to the cumulative CSV log (header only on create).
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "..."
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)

    chart = plot_topic_distribution(
        topic_distribution,
        [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))],
    )

    # Build the markdown report shown in the UI and offered for download.
    result = f" **Predicted Topic:** {label}\n\n"
    result += f" **Summary:** {summary}\n\n"
    result += f" **Top Words:** {', '.join(top_words)}\n\n"
    result += " **Topic Distribution:**\n"
    for idx, prob in enumerate(topic_distribution):
        tlabel = auto_labels.get(idx, f"Topic {idx+1}")
        result += f"{tlabel}: {prob:.3f}\n"

    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file
# --- Gradio Interface ---
# Two-column layout: inputs + feedback on the left, prediction outputs on
# the right. Event wiring connects the buttons to the functions above.
with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")
    with gr.Row():
        with gr.Column():
            # Either input feeds predict_topic; the file wins when both are set.
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")
            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True
            )
            feedback_btn = gr.Button("Submit Feedback")
            # Hidden sink for save_feedback's confirmation string.
            feedback_output = gr.Textbox(visible=False)
        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")
    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction]
    )
    # NOTE(review): gr.File() is instantiated inline here rather than in the
    # layout above -- confirm it renders where intended in the target Gradio
    # version.
    download_btn.click(fn=download_log, outputs=[gr.File()])
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output]
    )
demo.launch()
|