"""Gradio app for LDA topic prediction.

Loads a pre-trained LDA model, its vectorizer and a topic-label mapping
(produced offline and shipped as ``*.joblib`` files next to this script),
then serves a web UI where a user pastes text or uploads a ``.txt`` file
and gets back the predicted topic, its keywords, a probability chart,
a downloadable prediction file, and CSV prediction/feedback logs.
"""

import datetime
import os
import re
import time
import zipfile
from io import BytesIO

import joblib
import matplotlib

# Force the headless backend BEFORE pyplot is imported: this code runs in
# Gradio worker threads on a server, where a GUI backend would fail.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

import gradio as gr
import nltk
import pandas as pd
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import TreebankWordTokenizer
from PIL import Image
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

# Download wordnet once so the first lemmatize call cannot raise LookupError.
nltk.download("wordnet", quiet=True)

# Load models and label mapping (trained offline; FileNotFoundError here is
# a deployment error, so it is deliberately allowed to propagate).
lda = joblib.load("lda_model.joblib")
vectorizer = joblib.load("vectorizer.joblib")
auto_labels = joblib.load("topic_labels.joblib")

# Optional human-readable summaries, keyed by the auto-generated topic label.
topic_summaries = {
    "Politics & Gun Rights": "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware": "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software": "Programming terms, file handling, software output.",
    "Sports & Games": "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine": "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy": "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA": "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security": "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking": "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts": "Topics involving Israel, Armenia, conflict regions.",
}

# Tokenizer and lemmatizer are stateless after construction, so module-level
# singletons are safe to share across requests.
tokenizer = TreebankWordTokenizer()
lemmatizer = WordNetLemmatizer()


# --- Utility Functions ---

def preprocess(text):
    """Lower-case, strip non-word chars, tokenize, drop stop/short/non-alpha
    tokens, lemmatize, and return the result as a single space-joined string
    (the form the fitted vectorizer expects)."""
    text = re.sub(r'\W+', ' ', text.lower())
    tokens = tokenizer.tokenize(text)
    tokens = [
        lemmatizer.lemmatize(w)
        for w in tokens
        if w not in ENGLISH_STOP_WORDS and len(w) > 2 and w.isalpha()
    ]
    return ' '.join(tokens)


def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the ``top_n`` highest-weighted vocabulary terms for one topic.

    Args:
        model: fitted LDA model exposing ``components_``.
        vectorizer: the vectorizer whose vocabulary indexes ``components_``.
        topic_idx: row of ``components_`` to inspect.
        top_n: number of terms to return, best first.
    """
    feature_names = vectorizer.get_feature_names_out()
    topic = model.components_[topic_idx]
    # argsort is ascending; slice from the end to get the top_n largest.
    top_indices = topic.argsort()[:-top_n - 1:-1]
    return [feature_names[i] for i in top_indices]


def plot_topic_distribution(distribution, labels):
    """Render a bar chart of the topic probability distribution.

    Returns a PIL Image (Gradio's ``gr.Image(type="pil")`` consumes it
    directly) built via an in-memory PNG buffer — nothing touches disk.
    """
    plt.figure(figsize=(8, 4))
    plt.bar(range(len(distribution)), distribution, tick_label=labels)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Probability")
    plt.title("Topic Distribution")
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()  # release the figure; pyplot keeps figures alive otherwise
    buf.seek(0)
    return Image.open(buf)


def save_prediction_file(text):
    """Write the prediction text to a timestamped .txt file and return its
    name (served to the user as a download; reaped by
    ``cleanup_old_predictions``)."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"lda_prediction_{timestamp}.txt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    return filename


def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Delete stale ``lda_prediction_*.txt`` download files.

    Best-effort housekeeping run at the start of each prediction; deletion
    failures are logged and ignored so a locked file never breaks a request.
    """
    now = time.time()
    max_age = max_age_minutes * 60
    for fname in os.listdir(directory):
        if fname.endswith(extension) and fname.startswith("lda_prediction_"):
            full_path = os.path.join(directory, fname)
            if os.path.isfile(full_path) and (now - os.path.getmtime(full_path)) > max_age:
                try:
                    os.remove(full_path)
                except Exception as e:
                    print(f"Failed to delete {fname}: {e}")


def download_log():
    """Zip the prediction log CSV (if any) and return the archive's name."""
    zip_filename = "lda_predictions_log.zip"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        if os.path.exists("lda_predictions_log.csv"):
            zipf.write("lda_predictions_log.csv")
    return zip_filename


def save_feedback(text, feedback):
    """Append one feedback row (timestamp, rating, text excerpt) to the
    feedback CSV and return a confirmation message for the UI."""
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "feedback": feedback,
        "text_excerpt": text[:300].replace('\n', ' ') + "..."
    }])
    feedback_log = "lda_feedback_log.csv"
    # Write the header only on first creation so the CSV stays well-formed.
    log_entry.to_csv(feedback_log, mode='a',
                     header=not os.path.exists(feedback_log), index=False)
    return " Feedback recorded. Thank you!"


def _read_uploaded_text(file_input):
    """Return the text content of a Gradio file upload.

    Depending on the Gradio version, ``gr.File`` hands the callback a
    filesystem path (str), a tempfile-like object exposing ``.name``, or a
    binary file object — handle all three.  NOTE(review): the original code
    assumed only the binary-file form (``.read().decode()``), which crashes
    on current Gradio releases that pass a filepath.
    """
    if isinstance(file_input, (str, os.PathLike)):
        path = os.fspath(file_input)
    else:
        path = getattr(file_input, "name", None)
    if path and os.path.exists(path):
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    # Fall back to the legacy file-object protocol.
    data = file_input.read()
    return data.decode("utf-8") if isinstance(data, bytes) else data


# --- Main Prediction Function ---

def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded file.

    Returns a 3-tuple matching the Gradio outputs: (markdown result string,
    PIL chart image, path of a downloadable prediction file).  On missing
    input the chart and file slots are ``None``.  Side effects: reaps old
    download files and appends a row to the prediction log CSV.
    """
    cleanup_old_predictions()

    if file_input is not None:
        text = _read_uploaded_text(file_input)
    elif text_input and text_input.strip():  # guard: Gradio may pass None
        text = text_input
    else:
        return "Please provide input", None, None

    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    # Cast the NumPy scalar so dict lookups and f-strings behave predictably.
    dominant_topic = int(topic_distribution.argmax())
    label = auto_labels.get(dominant_topic, f"Topic {dominant_topic+1}")
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Confidence threshold warning: flag flat distributions where the model
    # is effectively guessing.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    # Log entry (header only written when the CSV does not exist yet).
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "..."
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode='a',
                     header=not os.path.exists(log_path), index=False)

    chart = plot_topic_distribution(
        topic_distribution,
        [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))]
    )

    result = f" **Predicted Topic:** {label}\n\n"
    result += f" **Summary:** {summary}\n\n"
    result += f" **Top Words:** {', '.join(top_words)}\n\n"
    result += " **Topic Distribution:**\n"
    for idx, prob in enumerate(topic_distribution):
        tlabel = auto_labels.get(idx, f"Topic {idx+1}")
        result += f"{tlabel}: {prob:.3f}\n"

    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file


# --- Gradio Interface ---

with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")
            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True
            )
            feedback_btn = gr.Button("Submit Feedback")
            feedback_output = gr.Textbox(visible=False)
        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")

    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction]
    )
    download_btn.click(fn=download_log, outputs=[gr.File()])
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output]
    )

demo.launch()