File size: 7,557 Bytes
76d2426
 
 
 
 
 
 
 
 
 
 
 
 
 
a4b6d10
 
 
 
76d2426
 
 
 
 
 
a4b6d10
76d2426
 
 
 
 
 
 
 
 
 
 
 
 
a4b6d10
76d2426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4b6d10
 
76d2426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from PIL import Image
import gradio as gr
import re
import pandas as pd
import joblib
import datetime
import matplotlib.pyplot as plt
from io import BytesIO
from nltk.tokenize import TreebankWordTokenizer
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
import os
import time
import zipfile
import nltk

# Download the WordNet corpus up front so WordNetLemmatizer does not raise
# a LookupError on first use.
nltk.download('wordnet')

# Load the pre-trained artifacts produced by the training pipeline: the
# fitted LDA model, its bag-of-words vectorizer, and the mapping from
# topic index -> human-readable label.
lda = joblib.load("lda_model.joblib")
vectorizer = joblib.load("vectorizer.joblib")
auto_labels = joblib.load("topic_labels.joblib")

# Optional one-line summaries keyed by topic label. Used to enrich the
# prediction output; labels missing from this dict fall back to
# "No summary available." in predict_topic().
topic_summaries = {
    "Politics & Gun Rights": "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware": "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software": "Programming terms, file handling, software output.",
    "Sports & Games": "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine": "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy": "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA": "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security": "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking": "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts": "Topics involving Israel, Armenia, conflict regions."
}

# Shared tokenizer and lemmatizer instances, reused by preprocess().
tokenizer = TreebankWordTokenizer()
lemmatizer = WordNetLemmatizer()

# --- Utility Functions ---

def preprocess(text):
    """Normalize raw text for the LDA pipeline.

    Lowercases, strips non-word characters, tokenizes, drops stop words,
    short tokens (<= 2 chars) and non-alphabetic tokens, then lemmatizes.
    Returns the kept tokens joined by single spaces.
    """
    normalized = re.sub(r'\W+', ' ', text.lower())
    kept = []
    for word in tokenizer.tokenize(normalized):
        # Skip stop words, very short tokens, and anything non-alphabetic.
        if word in ENGLISH_STOP_WORDS or len(word) <= 2 or not word.isalpha():
            continue
        kept.append(lemmatizer.lemmatize(word))
    return ' '.join(kept)

def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the top_n highest-weighted vocabulary terms for one topic.

    Reads the topic-word weight row ``model.components_[topic_idx]`` and
    maps the largest weights back to terms via the vectorizer's vocabulary.
    """
    vocab = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    # Sort ascending, reverse to descending, keep the first top_n indices.
    ranked = weights.argsort()[::-1][:top_n]
    return [vocab[i] for i in ranked]

def plot_topic_distribution(distribution, labels):
    """Render per-topic probabilities as a bar chart and return a PIL image.

    The chart is drawn with matplotlib's pyplot state machine, saved to an
    in-memory PNG buffer, and handed back as a ``PIL.Image`` for Gradio.
    """
    plt.figure(figsize=(8, 4))
    positions = range(len(distribution))
    plt.bar(positions, distribution, tick_label=labels)
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Probability")
    plt.title("Topic Distribution")
    plt.tight_layout()
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    # Close the figure so repeated predictions do not accumulate figures.
    plt.close()
    png_buffer.seek(0)
    return Image.open(png_buffer)

def save_prediction_file(text):
    """Write *text* to a timestamped UTF-8 .txt file and return its name.

    The ``lda_prediction_`` prefix is what cleanup_old_predictions() keys
    on when purging stale report files.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_name = "lda_prediction_{}.txt".format(stamp)
    with open(out_name, "w", encoding="utf-8") as out_file:
        out_file.write(text)
    return out_name

def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10):
    """Best-effort deletion of stale ``lda_prediction_*`` report files.

    Any matching regular file in *directory* whose mtime is older than
    *max_age_minutes* is removed; deletion failures are logged and skipped
    so one bad file cannot abort the sweep.
    """
    cutoff_seconds = max_age_minutes * 60
    current = time.time()
    for entry in os.listdir(directory):
        if not (entry.startswith("lda_prediction_") and entry.endswith(extension)):
            continue
        path = os.path.join(directory, entry)
        if not os.path.isfile(path):
            continue
        if current - os.path.getmtime(path) <= cutoff_seconds:
            continue
        try:
            os.remove(path)
        except Exception as err:
            print(f"Failed to delete {entry}: {err}")

def download_log():
    """Bundle the cumulative predictions CSV into a zip and return its name.

    If the CSV does not exist yet, an empty (but valid) zip archive is
    still produced so the download button always has something to serve.
    """
    archive_name = "lda_predictions_log.zip"
    log_csv = "lda_predictions_log.csv"
    with zipfile.ZipFile(archive_name, "w", zipfile.ZIP_DEFLATED) as archive:
        if os.path.exists(log_csv):
            archive.write(log_csv)
    return archive_name

def save_feedback(text, feedback):
    """Append one user-feedback row to the feedback CSV log.

    Stores an ISO timestamp, the chosen feedback label, and a 300-char
    excerpt of the judged text. The header row is written only when the
    log file does not exist yet. Returns a confirmation message.
    """
    feedback_log = "lda_feedback_log.csv"
    row = {
        "timestamp": datetime.datetime.now().isoformat(),
        "feedback": feedback,
        "text_excerpt": text[:300].replace('\n', ' ') + "...",
    }
    write_header = not os.path.exists(feedback_log)
    pd.DataFrame([row]).to_csv(feedback_log, mode='a', header=write_header, index=False)
    return " Feedback recorded. Thank you!"

# --- Main Prediction Function ---

def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded file.

    Args:
        text_input: Text pasted into the textbox; may be None or blank.
        file_input: Uploaded file — depending on the Gradio version this is
            either a file-like object or a plain filepath string; may be None.

    Returns:
        (result_text, chart_image, prediction_file): the formatted prediction
        string, a PIL bar chart of the topic distribution, and the path of a
        downloadable .txt report. When no input is given, returns a prompt
        string with (None, None).
    """
    # Opportunistically purge stale downloadable reports from earlier runs.
    cleanup_old_predictions()

    if file_input is not None:
        # FIX: support both file-like objects and plain path strings — newer
        # Gradio versions deliver the File component value as a filepath, on
        # which the original `.read()` call would fail. Also tolerate reads
        # that already return str rather than bytes.
        if hasattr(file_input, "read"):
            raw = file_input.read()
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        else:
            with open(file_input, "r", encoding="utf-8") as fh:
                text = fh.read()
    elif text_input and text_input.strip():  # FIX: guard against None from Gradio
        text = text_input
    else:
        return "Please provide input", None, None

    # Vectorize the cleaned text and pick the most probable topic.
    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    topic_distribution = lda.transform(bow)[0]
    dominant_topic = topic_distribution.argmax()
    label = auto_labels.get(dominant_topic, f"Topic {dominant_topic+1}")
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Flag weak predictions so users know to provide more context.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    # Append this prediction to the cumulative CSV log (header only on first write).
    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "predicted_topic": label,
        "dominant_topic_index": dominant_topic,
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "..."
    }])
    log_path = "lda_predictions_log.csv"
    log_entry.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)

    chart = plot_topic_distribution(topic_distribution, [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))])

    # Build the markdown-style result shown in the output textbox.
    result = f" **Predicted Topic:** {label}\n\n"
    result += f" **Summary:** {summary}\n\n"
    result += f" **Top Words:** {', '.join(top_words)}\n\n"
    result += " **Topic Distribution:**\n"
    for idx, prob in enumerate(topic_distribution):
        tlabel = auto_labels.get(idx, f"Topic {idx+1}")
        result += f"{tlabel}: {prob:.3f}\n"

    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file

# --- Gradio Interface ---

with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")

    with gr.Row():
        with gr.Column():
            # Left column: inputs (text or file), action buttons, and the
            # feedback radio group.
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")

            feedback_input = gr.Radio(
                choices=["Accurate", " Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True
            )
            feedback_btn = gr.Button("Submit Feedback")
            # Hidden textbox that receives save_feedback()'s confirmation string.
            feedback_output = gr.Textbox(visible=False)

        with gr.Column():
            # Right column: prediction text, distribution chart, and the
            # per-prediction downloadable report.
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")

    # Main flow: (pasted text, uploaded file) -> (result, chart, report file).
    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction]
    )

    # NOTE(review): gr.File() is instantiated inline in `outputs`, outside the
    # layout context — confirm the download component actually renders where
    # intended in the target Gradio version.
    download_btn.click(fn=download_log, outputs=[gr.File()])

    # Feedback flow: logs the pasted text excerpt plus the chosen rating.
    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output]
    )

demo.launch()