"""Template Demo for IBM Granite Hugging Face spaces."""

import os
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread

import gradio as gr
import langid
import spaces
import torch
import torchaudio
from punctuators.models import PunctCapSegModelONNX
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, TextIteratorStreamer

# Punctuation/capitalization/segmentation model; applied to English output only (see transcribe).
pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")

# NOTE: "%-d" (day of month without a leading zero) is a glibc strftime extension and is
# not supported on Windows.
today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002

MODEL_ID = "ibm-granite/granite-speech-3.3-2b"
SYS_PROMPT = (
    "Knowledge Cutoff Date: April 2024. "
    f"Today's Date: {today_date}. "
    "You are Granite, developed by IBM. You are a helpful AI assistant"
)

TITLE = "IBM Granite Speech 3.3 2B ASR Demo"
DESCRIPTION = "Record or upload an audio file and try one of the prompts. Press the play button to transcribe the file."

# Audio is converted to this sample rate (and to mono) before being passed to the processor.
TARGET_SAMPLE_RATE = 16000

processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
)


def delete_file(path: str) -> None:
    """Delete a file if it exists (best effort).

    Failures are reported with a printed warning rather than raised, so a stale temp
    file can never break a transcription request.

    Args:
        path (str): Path to the file to delete. Empty/None is ignored.

    Returns:
        None
    """
    if path and os.path.exists(path):
        try:
            os.remove(path)
            print(f"Deleted old audio file: {path}")
        except OSError as e:
            print(f"Warning: could not delete {path}: {e}")


def _restore_punctuation(text: str) -> str:
    """Run the punctuation/capitalization model on ``text`` and flatten its output.

    ``pc_model.infer`` is called with a single-element batch; the sentences of its
    first (only) result are joined with spaces.

    Args:
        text (str): Raw model output to punctuate.

    Returns:
        str: Punctuated text with single spaces between words.
    """
    joined = " ".join(pc_model.infer([text])[0])
    # NOTE(review): the previous code called `.replace("", " ")`, which inserts a space
    # between every character. Its comment ("map ... to space") suggests the intent was
    # to map unknown-token markers to spaces; "<unk>" and U+2047 ("⁇", SentencePiece's
    # rendering of unknown pieces) are assumed here — confirm against punctuators output.
    for unk_marker in ("<unk>", "\u2047"):
        joined = joined.replace(unk_marker, " ")
    # Collapse the whitespace runs that the replacements above can introduce.
    return " ".join(joined.split())


@spaces.GPU
def transcribe(audio_file: str, user_prompt: str, prev_file: str) -> Iterator[tuple[str, str]]:
    """Transcribe or translate an audio file with Granite Speech, streaming the result.

    Args:
        audio_file (str): Path of the audio file from the user (Gradio ``type="filepath"``).
        user_prompt (str): Instruction from the user (transcription or translation).
        prev_file (str): Previously processed audio file. It is deleted once the new
            file has loaded successfully and differs from it.

    Yields:
        tuple[str, str]: The text generated so far, and the value to store as the new
        ``prev_file`` state (the current ``audio_file``).
    """
    if not audio_file:
        # Nothing to transcribe; leave the output empty and the state untouched.
        yield "", prev_file
        return

    # Load the audio; downmix to mono and resample when the file is not already 16 kHz mono.
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != TARGET_SAMPLE_RATE:
        wav = torch.mean(wav, dim=0, keepdim=True)
        wav = torchaudio.functional.resample(wav, sr, TARGET_SAMPLE_RATE)

    # SAFE POINT: the new audio loaded successfully, so the previous upload can go.
    if prev_file and prev_file != audio_file:
        delete_file(prev_file)
    prev_file = audio_file

    chat = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": f"<|audio|>{user_prompt}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    model_inputs = processor(prompt, wav, device=model.device, return_tensors="pt").to(model.device)
    # The streamer's iterator raises if no new text arrives within `timeout` seconds.
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    # Greedy decoding (no sampling, single beam).
    generate_kwargs = dict(**model_inputs, streamer=streamer, max_new_tokens=512, do_sample=False, num_beams=1)

    # generate() blocks until done, so run it in a background thread and consume the
    # streamer here to push partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text, prev_file
    thread.join()

    # Apply cap+punct only when the output is (detected as) English.
    if text.strip() and langid.classify(text)[0] == "en":
        text = _restore_punctuation(text)
    yield text, prev_file


css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # Per-session state holding the path of the previously processed audio file.
    prev_audio = gr.State(value="")

    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription", lines=5)
            choices = [
                "Transcribe the speech to text",
                "Translate the speech to French",
                "Translate the speech to German",
                "Translate the speech to Spanish",
                "Translate the speech to Portuguese",
            ]
            user_prompt = gr.Dropdown(
                label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0]
            )

    # Transcription runs when the user presses play on the audio component.
    audio_input.play(transcribe, inputs=[audio_input, user_prompt, prev_audio], outputs=[output_text, prev_audio])

if __name__ == "__main__":
    demo.launch()