Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import whisper | |
| import PyPDF2 | |
| import docx | |
| from transformers import pipeline | |
| import io | |
| import tempfile | |
| import os | |
| import numpy as np | |
class TextSummarizer:
    """Extract text from documents or audio and summarize it.

    Holds a Hugging Face ``summarization`` pipeline and a Whisper model.
    Public methods report failures by *returning* a string that starts with
    ``"Error ..."`` instead of raising, so they can be wired directly to UI
    callbacks.
    """

    # Sample rate (Hz) of the waveform handed to Whisper for streamed audio.
    WHISPER_SAMPLE_RATE = 16000
    # ``max_length`` passed to the summarizer for each summary-length option.
    SUMMARY_MAX_LENGTHS = {"Short": 100, "Medium": 150, "Long": 250}
    DEFAULT_SUMMARY_LENGTH = "Medium"
    # Texts with fewer (stripped) characters than this are not summarized.
    MIN_TEXT_CHARS = 50
    # File extensions routed to Whisper instead of a text extractor.
    AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".flac")
    # Exact prefixes produced by the extractor methods below on failure.
    # Matching these (rather than any text starting with "Error") keeps a
    # document whose content begins with the word "Error" from being
    # mistaken for a failure.
    _ERROR_PREFIXES = (
        "Error reading PDF:",
        "Error reading DOCX:",
        "Error reading TXT file:",
        "Error transcribing audio:",
    )

    def __init__(
        self,
        summarization_model="facebook/bart-large-cnn",
        whisper_model_name="base",
        cache_dir="/code/cache",
    ):
        """Load the summarization pipeline and the Whisper model.

        Args:
            summarization_model: Model id given to the summarization pipeline.
            whisper_model_name: Whisper model size/name to load.
            cache_dir: Directory Whisper downloads its weights into; it must
                be writable.
        """
        self.summarizer = pipeline("summarization", model=summarization_model)
        # Ensure whisper uses a writable cache directory
        self.whisper_model = whisper.load_model(
            whisper_model_name, download_root=cache_dir
        )

    @staticmethod
    def _path_of(file):
        """Return the filesystem path of *file* (a path or an object with ``.name``)."""
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        return file.name

    def _is_error_message(self, text):
        """Return True if *text* is one of the extractors' error strings."""
        return isinstance(text, str) and text.startswith(self._ERROR_PREFIXES)

    def _extract_text(self, file):
        """Dispatch *file* to the extractor matching its extension.

        Returns:
            ``(extension, text)``. ``text`` is ``None`` when the extension is
            not supported; otherwise it is the extractor's return value
            (which may itself be an error string).
        """
        path = self._path_of(file)
        extension = os.path.splitext(path)[1].lower()
        text_extractors = {
            ".txt": self.process_text_file,
            ".pdf": self.extract_text_from_pdf,
            ".docx": self.extract_text_from_docx,
        }
        if extension in text_extractors:
            return extension, text_extractors[extension](file)
        if extension in self.AUDIO_EXTENSIONS:
            return extension, self.transcribe_audio(path)
        return extension, None

    def extract_text_from_pdf(self, pdf_file):
        """Extract text from a PDF file object.

        Returns the concatenated text of all pages (pages without extractable
        text contribute nothing), or an ``"Error reading PDF: ..."`` string.
        """
        try:
            reader = PyPDF2.PdfReader(pdf_file)
            return "".join(page.extract_text() or "" for page in reader.pages)
        except Exception as e:
            return f"Error reading PDF: {e}"

    def extract_text_from_docx(self, docx_file):
        """Extract text from a DOCX file object.

        Returns every paragraph followed by a newline, or an
        ``"Error reading DOCX: ..."`` string.
        """
        try:
            doc = docx.Document(docx_file)
            return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)
        except Exception as e:
            return f"Error reading DOCX: {e}"

    def process_text_file(self, txt_file):
        """Extract text from a TXT file object.

        The file is re-opened by path and decoded as UTF-8. Returns its
        content, or an ``"Error reading TXT file: ..."`` string.
        """
        try:
            # The file from Gradio is a temporary file, we can read it directly
            with open(self._path_of(txt_file), "r", encoding="utf-8") as f:
                return f.read()
        except Exception as e:
            return f"Error reading TXT file: {e}"

    def transcribe_audio(self, audio_file):
        """Transcribe audio file to text using Whisper.

        Returns the transcript, or an ``"Error transcribing audio: ..."``
        string.
        """
        try:
            result = self.whisper_model.transcribe(audio_file)
            return result["text"]
        except Exception as e:
            return f"Error transcribing audio: {e}"

    def summarize_text(self, text, max_length=150, min_length=50):
        """Summarize text using the summarization pipeline.

        Args:
            text: Text to summarize.
            max_length: Upper bound on the generated summary length.
            min_length: Lower bound on the generated summary length.

        Returns:
            The summary; ``"Text is too short to summarize."`` for empty or
            very short input; or an ``"Error summarizing text: ..."`` string.
        """
        try:
            if not text or len(text.strip()) < self.MIN_TEXT_CHARS:
                return "Text is too short to summarize."
            summary = self.summarizer(
                text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                # Inputs longer than the model's maximum input size would
                # otherwise make the pipeline fail; truncate them instead.
                truncation=True,
            )
            return summary[0]["summary_text"]
        except Exception as e:
            return f"Error summarizing text: {e}"

    def process_file(self, file, summary_length):
        """Process uploaded file and return summary.

        Args:
            file: Uploaded file (path, or object exposing ``.name``), or None.
            summary_length: ``"Short"``, ``"Medium"`` or ``"Long"``; any other
                value falls back to ``"Medium"``.

        Returns:
            A Markdown string with the original text length and the summary,
            or a message describing why no summary was produced.
        """
        if file is None:
            return "No file uploaded."
        max_length = self.SUMMARY_MAX_LENGTHS.get(
            summary_length, self.SUMMARY_MAX_LENGTHS[self.DEFAULT_SUMMARY_LENGTH]
        )
        min_length = max_length // 3
        extension, text = self._extract_text(file)
        if text is None:
            return f"Unsupported file format: {extension}"
        if self._is_error_message(text):
            return text
        summary = self.summarize_text(text, max_length, min_length)
        return (
            f"**Original Text Length:** {len(text)} characters"
            f"\n\n**Summary:**\n{summary}"
        )

    @classmethod
    def _to_whisper_audio(cls, sample_rate, data):
        """Convert a raw audio chunk to a mono float32 waveform at 16 kHz.

        Args:
            sample_rate: Sample rate of *data* in Hz.
            data: Samples shaped ``(n,)`` or ``(n, channels)``; integer dtypes
                are scaled to [-1, 1), float dtypes are used as-is.
                NOTE(review): the ``(n, channels)`` layout is assumed from
                Gradio's audio format - confirm for other callers.
        """
        samples = np.asarray(data)
        if np.issubdtype(samples.dtype, np.integer):
            # For int16 this is the usual division by 32768.
            full_scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = samples.astype(np.float32) / full_scale
        else:
            samples = samples.astype(np.float32)
        if samples.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            samples = samples.mean(axis=1)
        target_rate = cls.WHISPER_SAMPLE_RATE
        if sample_rate != target_rate and samples.size > 1:
            # Plain linear interpolation: dependency-free, and sufficient for
            # speech even though it applies no anti-aliasing filter.
            out_len = max(1, int(round(samples.size * target_rate / sample_rate)))
            source_times = np.arange(samples.size) / sample_rate
            target_times = np.arange(out_len) / target_rate
            samples = np.interp(target_times, source_times, samples)
        return np.ascontiguousarray(samples, dtype=np.float32)

    def transcribe_stream(self, audio_chunk, current_transcript):
        """Transcribe a stream of audio chunks and append to the transcript.

        Args:
            audio_chunk: ``(sample_rate, samples)`` tuple, or None.
            current_transcript: Transcript accumulated so far.

        Returns:
            ``(display_text, new_state)``. Both are the updated transcript on
            success; on failure ``display_text`` is an error message and
            ``new_state`` is the unchanged transcript.
        """
        if audio_chunk is None:
            return current_transcript, current_transcript
        try:
            sample_rate, data = audio_chunk
            audio = self._to_whisper_audio(sample_rate, data)
            if audio.size == 0:
                return current_transcript, current_transcript
            # Transcribe the audio chunk
            result = self.whisper_model.transcribe(audio, fp16=False)
            new_text = result["text"].strip()
            updated_transcript = current_transcript or ""
            if new_text:
                updated_transcript += new_text + " "
            return updated_transcript, updated_transcript
        except Exception as e:
            return f"Error during transcription: {e}", current_transcript

    def convert_file_to_text(self, file):
        """Extract text from any supported file format.

        Returns the extracted text, an extractor error string, or a message
        when no file was given or the extension is unsupported.
        """
        if file is None:
            return "No file uploaded for conversion."
        extension, text = self._extract_text(file)
        if text is None:
            return f"Unsupported file format for conversion: {extension}"
        return text
def create_interface():
    """Build the Gradio Blocks UI and wire it to a ``TextSummarizer``.

    The interface has three tabs: file upload/conversion, summarization of
    the uploaded file, and live microphone transcription with summarization.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    summarizer = TextSummarizer()
    # Maximum summary length (model tokens) per dropdown option; used for the
    # live summary and for the dropdown help text so the two cannot drift.
    length_to_max_tokens = {"Short": 100, "Medium": 150, "Long": 250}
    length_info = ", ".join(
        f"{label}: up to ~{tokens} tokens"
        for label, tokens in length_to_max_tokens.items()
    )

    with gr.Blocks(title="Text Summarization Dashboard") as interface:
        gr.Markdown("Text Summarization Dashboard")
        gr.Markdown("Manage files, and interact with specialized AI agents for various tasks.")
        # State component to store the uploaded file
        uploaded_file_state = gr.State(None)

        with gr.Tabs():
            with gr.TabItem("📄 File Management & Conversion"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Upload File")
                        file_input = gr.File(
                            label="Select a file",
                            file_types=[".txt", ".pdf", ".docx", ".mp3", ".wav", ".m4a", ".flac"],
                        )
                        uploaded_file_name = gr.Textbox(label="Current File", interactive=False)

                        def store_file(file):
                            """Remember the uploaded file and show its name."""
                            if file:
                                return file, file.name
                            return None, "No file uploaded"

                        def reset_file():
                            """Forget the stored file once the upload widget is cleared."""
                            return None, "No file uploaded"

                        file_input.upload(
                            fn=store_file,
                            inputs=[file_input],
                            outputs=[uploaded_file_state, uploaded_file_name],
                        )
                        # Without this, a removed file would stay in the state
                        # and still be used by "Convert" and "Generate Summary".
                        file_input.clear(
                            fn=reset_file,
                            inputs=None,
                            outputs=[uploaded_file_state, uploaded_file_name],
                        )
                    with gr.Column(scale=1):
                        gr.Markdown("### Convert to TXT")
                        gr.Markdown("Supported formats for conversion to .txt: `.pdf`, `.docx`, `.mp3`, `.wav`, `.m4a`, `.flac`")
                        convert_btn = gr.Button("Convert to TXT", variant="secondary")
                        conversion_output = gr.Textbox(
                            label="Conversion Output",
                            placeholder="Converted text will appear here...",
                            lines=8,
                            interactive=False,
                        )
                        convert_btn.click(
                            fn=summarizer.convert_file_to_text,
                            inputs=[uploaded_file_state],
                            outputs=[conversion_output],
                        )

            with gr.TabItem("✍️ Meeting Summarization"):
                gr.Markdown("### Meeting Summarization")
                gr.Markdown("Generate summaries from your meeting transcripts and other documents.")
                with gr.Row():
                    with gr.Column(scale=1):
                        summary_length = gr.Dropdown(
                            choices=["Short", "Medium", "Long"],
                            value="Medium",
                            label="Summary Length",
                            info=length_info,
                        )
                        submit_btn = gr.Button("Generate Summary", variant="primary")
                    with gr.Column(scale=2):
                        output = gr.Textbox(
                            label="Summary Output",
                            lines=10,
                            placeholder="Your summary will appear here...",
                        )
                # NOTE: the controls in this accordion are display-only; none
                # of them is passed to any event handler in this function.
                with gr.Accordion("⚙️ Model Settings", open=False):
                    gr.Markdown("### Model Selection & Fine-Tuning")
                    gr.Markdown("Choose different models and configure their parameters.")
                    with gr.Row():
                        gr.Dropdown(
                            label="Select Summarization Model",
                            choices=["facebook/bart-large-cnn", "t5-small", "google/pegasus-xsum"],
                            value="facebook/bart-large-cnn",
                        )
                    with gr.Accordion("Fine-Tuning Options", open=False):
                        gr.Slider(label="Min Tokens", minimum=10, maximum=200, step=5, value=50)
                        gr.Slider(label="Max Tokens", minimum=50, maximum=500, step=10, value=150)
                        gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7)
                        gr.Slider(label="Top-K", minimum=0, maximum=100, step=1, value=50, info="0 to disable")
                        gr.Slider(label="Top-P (Nucleus Sampling)", minimum=0.0, maximum=1.0, step=0.05, value=0.95, info="0 to disable")
                        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2)
                        gr.Slider(label="Number of Beams", minimum=1, maximum=8, step=1, value=4)

            with gr.TabItem("🔴 Live Meeting Recording & Summarization"):
                gr.Markdown("### Live Meeting Transcription & Summarization")
                gr.Markdown("Record audio from your microphone, get a live transcript, and generate a summary.")
                # Transcript accumulated across streamed audio chunks.
                live_transcript_state = gr.State("")
                with gr.Row():
                    with gr.Column(scale=1):
                        audio_input = gr.Audio(
                            label="Live Audio",
                            sources="microphone",
                            streaming=True,
                        )
                    with gr.Column(scale=2):
                        live_transcript_output = gr.Textbox(
                            label="Live Transcript",
                            placeholder="Transcript will appear here...",
                            lines=15,
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        live_summary_length = gr.Dropdown(
                            choices=["Short", "Medium", "Long"],
                            value="Medium",
                            label="Summary Length",
                        )
                        live_summary_btn = gr.Button("Generate Summary", variant="primary")
                    with gr.Column(scale=2):
                        live_summary_output = gr.Textbox(
                            label="Meeting Summary",
                            placeholder="Summary will appear here...",
                            lines=5,
                        )
                audio_input.stream(
                    fn=summarizer.transcribe_stream,
                    inputs=[audio_input, live_transcript_state],
                    outputs=[live_transcript_output, live_transcript_state],
                )

                def generate_live_summary(transcript, length_option):
                    """Summarize the live transcript at the selected length."""
                    max_len = length_to_max_tokens.get(
                        length_option, length_to_max_tokens["Medium"]
                    )
                    min_len = max_len // 3
                    return summarizer.summarize_text(transcript, max_length=max_len, min_length=min_len)

                live_summary_btn.click(
                    fn=generate_live_summary,
                    inputs=[live_transcript_output, live_summary_length],
                    outputs=[live_summary_output],
                )

        # Registered after the tabs are built: the handler needs both the
        # file state and the Meeting Summarization tab's components.
        submit_btn.click(
            fn=summarizer.process_file,
            inputs=[uploaded_file_state, summary_length],
            outputs=output,
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all network
    # interfaces at port 7860. NOTE(review): share=True also asks Gradio for
    # a public share link - confirm this is wanted in the deployment target.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)