| """Template Demo for IBM Granite Hugging Face spaces.""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import torchaudio | |
| from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, TextIteratorStreamer | |
| import langid | |
| from punctuators.models import PunctCapSegModelONNX | |
| pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en") | |
| from themes.research_monochrome import theme | |
| today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002 | |
MODEL_ID = "ibm-granite/granite-speech-3.3-2b"
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite Speech 3.3 2B ASR Demo"
DESCRIPTION = "Record or upload an audio file and try one of the prompts. Press the play button to transcribe the file."
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
)
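

# Assumption: this Space runs on ZeroGPU (hence the `spaces` import), so the
# GPU-bound entry point is decorated with @spaces.GPU, which attaches a GPU
# only for the duration of each call and is a no-op when run locally.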
@spaces.GPU
def transcribe(audio_file: str, user_prompt: str) -> Iterator[str]:
    """Transcribe or translate an audio file for the ASR demo.

    Args:
        audio_file (str): Path to the audio file recorded or uploaded by the user.
        user_prompt (str): Instruction from the user (transcription or translation).

    Yields:
        str: The partial, then final, transcription/translation of the audio file.
    """
    # Load the waveform, normalized to float values in [-1, 1].
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1:
        # Down-mix multi-channel audio to mono by averaging channels.
        wav = torch.mean(wav, dim=0, keepdim=True)
    if sr != 16000:
        # The model expects 16 kHz input; resample anything else.
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000
    # Build messages; <|audio|> marks where the audio is spliced into the prompt.
    chat = [
        dict(role="system", content=SYS_PROMPT),
        dict(role="user", content=f"<|audio|>{user_prompt}"),
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
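    # Note: the exact role markers in the rendered prompt string come from the
    # model's own chat template, so none are hand-written here.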
    # Run the processor on the prompt and waveform, then generate on a
    # background thread so partial text can be yielded as it streams in.
    model_inputs = processor(prompt, wav, device=model.device, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding for deterministic transcripts
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=kwargs)
    t.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text
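
    # langid.classify returns a (language, confidence) pair; only the language
    # code is checked here.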
    # Apply capitalization + punctuation for English only: the pcs_en
    # punctuator does not support other languages.
    if langid.classify(text)[0] == "en":
        segments = pc_model.infer([text])
        yield " ".join(segments[0])

css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription", lines=5)
            choices = [
                "Transcribe the speech to text",
                "Translate the speech to French",
                "Translate the speech to German",
                "Translate the speech to Spanish",
                "Translate the speech to Portuguese",
            ]
            user_prompt = gr.Dropdown(
                label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0]
            )
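
    # Pressing play on the Audio widget invokes the generator; each yielded
    # string overwrites the Textbox, so the transcript appears incrementally.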
    audio_input.play(
        transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()