# Hugging Face Space: LDA topic-modeling demo (Gradio app)
| from PIL import Image | |
| import gradio as gr | |
| import re | |
| import pandas as pd | |
| import joblib | |
| import datetime | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| from nltk.tokenize import TreebankWordTokenizer | |
| from nltk.stem import WordNetLemmatizer | |
| from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS | |
| import os | |
| import time | |
| import zipfile | |
| import nltk | |
# Download the WordNet corpus up front so WordNetLemmatizer does not raise
# a LookupError on first use.
nltk.download('wordnet')

# Load the fitted LDA model, its bag-of-words vectorizer, and the
# topic-index -> human-readable label mapping produced at training time.
lda = joblib.load("lda_model.joblib")
vectorizer = joblib.load("vectorizer.joblib")
auto_labels = joblib.load("topic_labels.joblib")

# Optional one-line descriptions shown next to a predicted topic label.
# Keys must match the labels stored in topic_labels.joblib; unmatched
# labels fall back to "No summary available." in predict_topic.
topic_summaries = {
    "Politics & Gun Rights": "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware": "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software": "Programming terms, file handling, software output.",
    "Sports & Games": "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine": "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy": "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA": "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security": "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking": "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts": "Topics involving Israel, Armenia, conflict regions."
}

# Shared NLP helpers used by preprocess().
tokenizer = TreebankWordTokenizer()
lemmatizer = WordNetLemmatizer()
| # --- Utility Functions --- | |
def preprocess(text):
    """Normalize *text* for the LDA vectorizer.

    Lowercases, collapses non-word characters to spaces, tokenizes, then
    keeps only alphabetic tokens longer than two characters that are not
    English stop words, lemmatizing each. Returns the tokens space-joined.
    """
    normalized = re.sub(r'\W+', ' ', text.lower())
    kept = []
    for token in tokenizer.tokenize(normalized):
        # Skip stop words, short fragments, and anything non-alphabetic.
        if token in ENGLISH_STOP_WORDS or len(token) <= 2 or not token.isalpha():
            continue
        kept.append(lemmatizer.lemmatize(token))
    return ' '.join(kept)
def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the *top_n* highest-weighted vocabulary terms for one topic.

    Reads the topic's component weights from *model* and maps the largest
    weights back to terms through *vectorizer*'s feature names.
    """
    vocab = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    ranked = weights.argsort()[::-1][:top_n]
    return [vocab[i] for i in ranked]
def plot_topic_distribution(distribution, labels):
    """Draw *distribution* as a labelled bar chart and return a PIL Image.

    The figure is rendered off-screen into an in-memory PNG buffer, so no
    display is required; the pyplot figure is closed before returning.
    """
    positions = range(len(distribution))
    plt.figure(figsize=(8, 4))
    plt.bar(positions, distribution, tick_label=labels)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Probability")
    plt.title("Topic Distribution")
    plt.tight_layout()
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)
def save_prediction_file(text):
    """Write *text* to a timestamped ``lda_prediction_*.txt`` file.

    Returns the generated filename so the UI can offer it for download.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_name = f"lda_prediction_{stamp}.txt"
    with open(out_name, "w", encoding="utf-8") as handle:
        handle.write(text)
    return out_name
def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Best-effort removal of stale ``lda_prediction_*`` files in *directory*.

    Files whose modification time is older than *max_age_minutes* are
    deleted; deletion errors are printed and swallowed so the caller is
    never interrupted.
    """
    cutoff_age = max_age_minutes * 60
    current = time.time()
    for fname in os.listdir(directory):
        if not (fname.startswith("lda_prediction_") and fname.endswith(extension)):
            continue
        full_path = os.path.join(directory, fname)
        if os.path.isfile(full_path) and current - os.path.getmtime(full_path) > cutoff_age:
            try:
                os.remove(full_path)
            except Exception as e:
                print(f"Failed to delete {fname}: {e}")
def download_log():
    """Zip the prediction log CSV (when present) and return the archive name.

    An archive is always created; it is simply empty if no log exists yet.
    """
    archive_name = "lda_predictions_log.zip"
    with zipfile.ZipFile(archive_name, "w", zipfile.ZIP_DEFLATED) as archive:
        if os.path.exists("lda_predictions_log.csv"):
            archive.write("lda_predictions_log.csv")
    return archive_name
def save_feedback(text, feedback):
    """Append one feedback row (timestamp, rating, excerpt) to the feedback CSV.

    The header is written only when the file does not exist yet; returns a
    confirmation message for the UI.
    """
    row = {
        "timestamp": datetime.datetime.now().isoformat(),
        "feedback": feedback,
        "text_excerpt": text[:300].replace('\n', ' ') + "...",
    }
    feedback_log = "lda_feedback_log.csv"
    write_header = not os.path.exists(feedback_log)
    pd.DataFrame([row]).to_csv(feedback_log, mode='a', header=write_header, index=False)
    return " Feedback recorded. Thank you!"
| # --- Main Prediction Function --- | |
def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded file.

    Parameters
    ----------
    text_input : str or None
        Text pasted into the textbox. Gradio may pass None when empty,
        so this must not be dereferenced unconditionally.
    file_input : file-like, str path, or None
        Uploaded .txt file. Depending on the Gradio version this arrives
        as an open file-like object or as a filepath/tempfile wrapper.

    Returns
    -------
    tuple
        (markdown result string, PIL chart image, path of the saved
        prediction file), or ("Please provide input", None, None) when
        no usable input was supplied.
    """
    cleanup_old_predictions()
    # Resolve the input text; a file upload takes precedence over pasted text.
    if file_input is not None:
        if hasattr(file_input, "read"):
            # Older Gradio versions hand over an open file-like object;
            # its read() may yield bytes or str.
            raw = file_input.read()
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        else:
            # Newer Gradio versions pass a filepath string (or an object
            # exposing .name) — NOTE(review): confirm against the
            # installed gradio version.
            path = getattr(file_input, "name", file_input)
            with open(path, encoding="utf-8") as f:
                text = f.read()
    elif text_input and text_input.strip():  # guard: text_input may be None
        text = text_input
    else:
        return "Please provide input", None, None
    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    dominant_topic = topic_distribution.argmax()
    label = auto_labels.get(dominant_topic, f"Topic {dominant_topic+1}")
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")
    # Flag weak predictions instead of presenting them as confident.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."
    # Append this prediction to the CSV log (header only on first write).
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "...",
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)
    # Bar chart over every topic, using the human-readable labels.
    chart = plot_topic_distribution(
        topic_distribution,
        [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))],
    )
    # Assemble the markdown report shown in the UI and saved to disk.
    result = f" **Predicted Topic:** {label}\n\n"
    result += f" **Summary:** {summary}\n\n"
    result += f" **Top Words:** {', '.join(top_words)}\n\n"
    result += " **Topic Distribution:**\n"
    for idx, prob in enumerate(topic_distribution):
        tlabel = auto_labels.get(idx, f"Topic {idx+1}")
        result += f"{tlabel}: {prob:.3f}\n"
    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file
| # --- Gradio Interface --- | |
# Build the web UI: input controls (text / file / feedback) in the left
# column, prediction output, chart, and download link in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")
    with gr.Row():
        with gr.Column():
            # Either input source works; predict_topic prefers the file.
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")
            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True
            )
            feedback_btn = gr.Button("Submit Feedback")
            # NOTE(review): visible=False means the confirmation string
            # returned by save_feedback is never shown to the user —
            # confirm this is intentional.
            feedback_output = gr.Textbox(visible=False)
        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")
    # Wire the three buttons to their handlers.
    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction]
    )
    # NOTE(review): gr.File() is instantiated inline in outputs= rather than
    # placed in the layout above; on recent Gradio versions an un-rendered
    # component may not display the returned zip — verify the download
    # actually appears in the UI.
    download_btn.click(fn=download_log, outputs=[gr.File()])
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output]
    )

# Start the app (blocking call).
demo.launch()