# Author: JamesKingsley
# Commit a4b6d10: "Add nltk.download for wordnet resource"
import datetime
import os
import re
import tempfile
import time
import zipfile
from io import BytesIO

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import nltk
import pandas as pd
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import TreebankWordTokenizer
from PIL import Image
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
# Fetch the WordNet corpus used by WordNetLemmatizer; without it the
# lemmatizer raises LookupError.
nltk.download("wordnet")

# Load the LDA model, the vectorizer that produces its bag-of-words input,
# and the mapping from topic index to human-readable label (in that order).
lda, vectorizer, auto_labels = (
    joblib.load(artifact)
    for artifact in ("lda_model.joblib", "vectorizer.joblib", "topic_labels.joblib")
)
# Optional one-line description for each topic label. Labels that are not
# listed here simply have no summary.
topic_summaries = dict(
    (
        ("Politics & Gun Rights", "Discussions about government policies, laws, gun control, and rights."),
        ("Computing & Hardware", "Technical issues and terms related to computer hardware and drivers."),
        ("Programming & Software", "Programming terms, file handling, software output."),
        ("Sports & Games", "Topics related to teams, players, seasons, and matches."),
        ("Health & Medicine", "Diseases, treatment, healthcare, and medical facilities."),
        ("Religion & Philosophy", "Talks involving faith, belief systems, philosophical views."),
        ("Space & NASA", "Space exploration, NASA missions, satellites, and astronomy."),
        ("Cryptography & Security", "Discussions on encryption, digital security, and data protection."),
        ("Internet & Networking", "Terms around internet use, FTP, web versions, and networks."),
        ("Middle East Politics & Conflicts", "Topics involving Israel, Armenia, conflict regions."),
    )
)
# NLP helpers shared by preprocess(); constructed once at import time.
lemmatizer = WordNetLemmatizer()
tokenizer = TreebankWordTokenizer()
# --- Utility Functions ---
def preprocess(text):
    """Normalise raw text into the cleaned, space-joined token string.

    Steps:
    1. Lower-case the text and turn every run of non-word characters into a space.
    2. Tokenize with the module-level Treebank tokenizer.
    3. Drop scikit-learn English stop words, tokens of two characters or fewer,
       and tokens that are not purely alphabetic.
    4. Lemmatize the surviving tokens.

    The result is what the caller feeds to the vectorizer.
    """
    normalised = re.sub(r'\W+', ' ', text.lower())
    kept = []
    for token in tokenizer.tokenize(normalised):
        if token in ENGLISH_STOP_WORDS:
            continue
        if len(token) <= 2 or not token.isalpha():
            continue
        kept.append(lemmatizer.lemmatize(token))
    return ' '.join(kept)
def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the *top_n* highest-weighted vocabulary terms for one topic.

    Args:
        model: Topic model exposing ``components_``, indexed by topic.
        vectorizer: Object whose ``get_feature_names_out()`` maps column
            indices of a ``components_`` row to terms.
        topic_idx: Which row of ``components_`` to inspect.
        top_n: How many terms to return.

    Returns:
        List of terms, heaviest weight first.
    """
    vocabulary = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    # Ascending argsort reversed -> heaviest first, then keep the first top_n.
    ranked = weights.argsort()[::-1][:top_n]
    return [vocabulary[column] for column in ranked]
def plot_topic_distribution(distribution, labels):
    """Render a bar chart of per-topic probabilities and return it as a PIL image.

    Args:
        distribution: Sequence of topic probabilities, one bar per entry.
        labels: Tick label for each bar; same length as *distribution*.

    Returns:
        A PIL image of the chart rendered as PNG.
    """
    # Work on an explicit Figure rather than pyplot's implicit "current figure",
    # and close it in `finally`. An exception mid-render (e.g. a labels/length
    # mismatch) then cannot leave the figure open in pyplot's global registry,
    # where figures would accumulate across requests.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.bar(range(len(distribution)), distribution, tick_label=labels)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_ylabel("Probability")
        ax.set_title("Topic Distribution")
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def save_prediction_file(text):
    """Write *text* to a new, uniquely named prediction file in the working directory.

    The name keeps the ``lda_prediction_<YYYYmmdd_HHMMSS>`` prefix and the
    ``.txt`` suffix, so cleanup_old_predictions() still recognises it.
    tempfile.mkstemp adds a random component and creates the file
    exclusively. Two predictions saved within the same second (e.g. by
    concurrent users) therefore never overwrite each other.

    Args:
        text: Content to write, encoded as UTF-8.

    Returns:
        The file name, relative to the current working directory.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    fd, path = tempfile.mkstemp(prefix=f"lda_prediction_{timestamp}_", suffix=".txt", dir=".")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(text)
    return os.path.basename(path)
def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Delete prediction files older than *max_age_minutes* from *directory*.

    Only regular files whose name starts with ``lda_prediction_`` and ends
    with *extension* are considered. If a file vanishes while it is being
    examined (e.g. a concurrent cleanup removed it first), it is skipped
    silently. Any other OS error is reported and cleanup continues.

    Args:
        directory: Directory to scan.
        extension: Required file-name suffix.
        max_age_minutes: Files last modified longer ago than this are removed.
    """
    cutoff = time.time() - max_age_minutes * 60
    with os.scandir(directory) as entries:
        for entry in entries:
            name = entry.name
            if not (name.startswith("lda_prediction_") and name.endswith(extension)):
                continue
            try:
                # stat() and remove() both sit inside the try: either can hit a
                # file that another request has just deleted.
                if not entry.is_file() or entry.stat().st_mtime >= cutoff:
                    continue
                os.remove(entry.path)
            except FileNotFoundError:
                # Already gone -- the desired end state.
                continue
            except OSError as e:
                print(f"Failed to delete {name}: {e}")
def download_log():
    """Bundle the app's CSV logs into ``lda_predictions_log.zip`` and return its name.

    The archive contains the prediction log (``lda_predictions_log.csv``) and
    the feedback log (``lda_feedback_log.csv``). A log that does not exist
    yet is skipped, so the archive may be empty. The ZIP is rewritten on
    every call.
    """
    zip_filename = "lda_predictions_log.zip"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        for log_name in ("lda_predictions_log.csv", "lda_feedback_log.csv"):
            if os.path.exists(log_name):
                zipf.write(log_name)
    return zip_filename
def save_feedback(text, feedback):
    """Append one user-feedback row to ``lda_feedback_log.csv``.

    The CSV header is written only when the file is first created.

    Args:
        text: Textbox contents the feedback refers to (may be empty or None).
        feedback: The selected radio choice, or None/empty if nothing was chosen.

    Returns:
        A status message for the UI. Nothing is logged when no choice was made.
    """
    if not feedback:
        # Avoid writing rows with an empty feedback value.
        return "Please select a feedback option before submitting."
    text = text or ""
    # Single-line excerpt; "..." marks truncation only when something was cut.
    excerpt = text[:300].replace("\r", " ").replace("\n", " ")
    if len(text) > 300:
        excerpt += "..."
    log_entry = pd.DataFrame([{
        "timestamp": datetime.datetime.now().isoformat(),
        "feedback": feedback,
        "text_excerpt": excerpt,
    }])
    feedback_log = "lda_feedback_log.csv"
    log_entry.to_csv(feedback_log, mode="a", header=not os.path.exists(feedback_log), index=False)
    return " Feedback recorded. Thank you!"
# --- Main Prediction Function ---
def _read_uploaded_text(file_input):
    """Return the text content of a Gradio ``gr.File`` value, whatever its form.

    The handler may receive any of these forms, depending on the Gradio
    version and the ``gr.File(type=...)`` setting:
    - a path string (e.g. Gradio 4's default ``type="filepath"``);
    - a tempfile wrapper whose ``.name`` is the uploaded file's path (Gradio 3);
    - raw ``bytes``;
    - a readable file object.

    Undecodable bytes are replaced instead of raising, so a non-UTF-8 upload
    still yields a prediction.
    """
    if isinstance(file_input, (bytes, bytearray)):
        return bytes(file_input).decode("utf-8", errors="replace")
    if isinstance(file_input, (str, os.PathLike)):
        path = file_input
    else:
        path = getattr(file_input, "name", None)
        if not (isinstance(path, (str, os.PathLike)) and os.path.isfile(path)):
            # No usable path: read the file object's contents directly.
            data = file_input.read()
            return data.decode("utf-8", errors="replace") if isinstance(data, bytes) else data
    with open(path, "r", encoding="utf-8", errors="replace") as fh:
        return fh.read()


def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded .txt file.

    The uploaded file takes precedence. If it is missing or blank, the
    textbox is used instead. Each prediction does three things:
    - appends a row to ``lda_predictions_log.csv``;
    - is rendered as a bar chart;
    - is saved to its own downloadable text file.

    Args:
        text_input: Text from the textbox (may be empty or None).
        file_input: The ``gr.File`` value, or None when nothing was uploaded.

    Returns:
        ``(result_text, chart_image, prediction_file_name)``, or
        ``("Please provide input", None, None)`` when neither input has content.
    """
    cleanup_old_predictions()

    text = _read_uploaded_text(file_input) if file_input is not None else ""
    if not text.strip():
        text = text_input or ""
    if not text.strip():
        return "Please provide input", None, None

    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    # Plain int so dict lookups, arithmetic and the CSV log all see a Python int.
    dominant_topic = int(topic_distribution.argmax())
    # Labels for every topic; used for the dominant label, the chart and the text.
    topic_labels = [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))]
    label = topic_labels[dominant_topic]
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Flag weak predictions where the dominant topic's probability is below 0.4.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    # Single-line excerpt for the log; "..." only when the text was truncated.
    excerpt = text[:300].replace("\r", " ").replace("\n", " ")
    if len(text) > 300:
        excerpt += "..."
    log_entry = pd.DataFrame([{
        "timestamp": datetime.datetime.now().isoformat(),
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": excerpt,
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode="a", header=not os.path.exists(log_path), index=False)

    chart = plot_topic_distribution(topic_distribution, topic_labels)

    result = (
        f" **Predicted Topic:** {label}\n\n"
        f" **Summary:** {summary}\n\n"
        f" **Top Words:** {', '.join(top_words)}\n\n"
        " **Topic Distribution:**\n"
    )
    result += "".join(
        f"{tlabel}: {prob:.3f}\n" for tlabel, prob in zip(topic_labels, topic_distribution)
    )

    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")
            # Explicit, labelled target for the ZIP returned by download_log,
            # placed next to the button that produces it.
            log_archive = gr.File(label="⬇ All Logs (ZIP)")
            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True,
            )
            feedback_btn = gr.Button("Submit Feedback")
            # Visible (read-only) so the user actually sees save_feedback's status message.
            feedback_output = gr.Textbox(label="Feedback status", interactive=False)
        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")

    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction],
    )
    download_btn.click(fn=download_log, outputs=[log_archive])
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output],
    )

demo.launch()