Spaces:
Running
on
Zero
Running
on
Zero
| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| import os | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import langid | |
| import spaces | |
| import torch | |
| import torchaudio | |
| from punctuators.models import PunctCapSegModelONNX | |
| from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, TextIteratorStreamer | |
# Punctuation / capitalisation model applied to English transcripts in `transcribe`.
pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")

# Format the date by hand instead of strftime("%B %-d, %Y"): the "%-d" flag
# (day of month without zero padding) is a platform extension and is rejected
# on Windows. The resulting text is the same, e.g. "April 3, 2025".
_today = datetime.today()  # noqa: DTZ002
today_date = f"{_today:%B} {_today.day}, {_today.year}"

MODEL_ID = "ibm-granite/granite-speech-3.3-2b"

# System prompt sent with every request; the date is fixed at import time.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""

TITLE = "IBM Granite Speech 3.3 2B ASR Demo"
DESCRIPTION = "Record or upload an audio file and try one of the prompts. Press the play button to transcribe the file."

# Processor, tokenizer and speech model are loaded once at import time and
# shared by all requests.
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
)
def delete_file(path: str) -> None:
    """Delete a file on a best-effort basis.

    Args:
        path (str): Path to the file to delete. Empty/falsy values are ignored.

    Returns:
        None. A missing file is ignored silently; any other failure is
        reported on stdout and never raised to the caller.
    """
    if not path:
        return
    # Attempt the removal directly instead of checking os.path.exists() first:
    # the file can disappear between the check and the removal.
    try:
        os.remove(path)
    except (FileNotFoundError, NotADirectoryError):
        # Nothing at that path (never existed or already removed): no-op.
        return
    except (OSError, ValueError) as e:
        # OSError: permissions, path is a directory, file in use, ...
        # ValueError: malformed path (e.g. embedded NUL character).
        # Cleanup is best-effort, so report and carry on.
        print(f"Warning: could not delete {path}: {e}")
    else:
        print(f"Deleted old audio file: {path}")
def transcribe(audio_file: str, user_prompt: str, prev_file: str) -> Iterator[tuple[str, str]]:
    """Stream a transcription/translation of an audio file for the ASR demo.

    Args:
        audio_file (str): Name of audio file from the user.
        user_prompt (str): Instruction from the user (transcription or translation).
        prev_file (str): Previously uploaded audio file ("" when there is none).

    Yields:
        tuple[str, str]: The text generated so far, and the audio file that
        should be remembered as the "previous" one (always ``audio_file``).
        The last item carries the final text, with capitalisation and
        punctuation restored when the output is detected as English.
    """
    target_sr = 16000

    # Load the audio; convert to mono / resample to 16 kHz if needed.
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != target_sr:
        wav = torch.mean(wav, dim=0, keepdim=True)  # mono
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    # SAFE POINT: the new audio loaded fine, so the old file (if it is a
    # different one) can be removed.
    if prev_file and prev_file != audio_file:
        delete_file(prev_file)
    # From here on the current file is the one to clean up next time.
    prev_file = audio_file

    # Build messages
    chat = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": f"<|audio|>{user_prompt}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # Run generation in a worker thread so tokens can be streamed to the UI.
    model_inputs = processor(prompt, wav, device=model.device, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(**model_inputs, streamer=streamer, max_new_tokens=512, do_sample=False, num_beams=1)
    t = Thread(target=model.generate, kwargs=kwargs)
    t.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text, prev_file
    # The streamer has signalled end-of-stream; wait for the worker to finish
    # before post-processing.
    t.join()

    # Apply cap+punct for English-only. Skip empty output: language
    # identification on an empty string is meaningless.
    if text.strip() and langid.classify(text)[0] == "en":
        segments = pc_model.infer([text])[0]
        text = " ".join(segments).replace("<unk>", " ").replace("<Unk>", " ")  # map <unk> to space

    # Always emit a final item so the caller receives the final text and the
    # updated prev_file even if nothing was streamed.
    yield text, prev_file
# Static assets (stylesheet and <head> snippet) live next to this script.
app_dir = Path(__file__).parent
css_file_path = app_dir / "app.css"
head_file_path = app_dir / "app_head.html"

# Prompts offered in the dropdown; custom values are also allowed.
choices = [
    "Transcribe the speech to text",
    "Translate the speech to French",
    "Translate the speech to German",
    "Translate the speech to Spanish",
    "Translate the speech to Portuguese",
]

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # Per-session state holding the previously uploaded audio file.
    prev_audio = gr.State(value="")

    # NOTE(review): widget nesting reconstructed from flattened source — confirm layout.
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription", lines=5)
            user_prompt = gr.Dropdown(
                label="Prompt",
                choices=choices,
                interactive=True,
                allow_custom_value=True,
                value=choices[0],
            )

    # Pressing play on the audio widget starts the transcription; the
    # generator streams text into the textbox and updates the stored file.
    audio_input.play(
        transcribe,
        inputs=[audio_input, user_prompt, prev_audio],
        outputs=[output_text, prev_audio],
    )

if __name__ == "__main__":
    demo.launch()