"""Template Demo for IBM Granite Hugging Face spaces."""

import os
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread

import gradio as gr
import langid
import spaces
import torch
import torchaudio
from punctuators.models import PunctCapSegModelONNX
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, TextIteratorStreamer
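
# English punctuation/capitalization/segmentation model from the `punctuators` package,
# used to restore casing and punctuation in the transcription text.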
pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")
today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
MODEL_ID = "ibm-granite/granite-speech-3.3-2b"
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite Speech 3.3 2B ASR Demo"
DESCRIPTION = "Record or upload an audio file and try one of the prompts. Press the play button to transcribe the file."
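
# Load the Granite Speech processor and model once at startup; weights are kept in bfloat16
# and placed across the available device(s) by device_map="auto".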
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
)


def delete_file(path: str) -> None:
    """Delete a file if it exists.

    Args:
        path (str): Path to the file to delete.

    Returns:
        None
    """
    if path and os.path.exists(path):
        try:
            os.remove(path)
            print(f"Deleted old audio file: {path}")
        except Exception as e:
            print(f"Warning: could not delete {path}: {e}")
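

# spaces.GPU requests a ZeroGPU slot on Hugging Face Spaces for the duration of each call.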
@spaces.GPU
def transcribe(audio_file: str, user_prompt: str, prev_file: str) -> Iterator[tuple[str, str]]:
    """Transcribe or translate an audio file for the ASR demo.

    Args:
        audio_file (str): Path of the audio file from the user.
        user_prompt (str): Instruction from the user (transcription or translation).
        prev_file (str): Path of the previously uploaded audio file.

    Yields:
        tuple[str, str]: The transcription/translation generated so far and the current audio file path.
    """
    # Load the audio; Granite Speech expects 16 kHz mono input
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != 16000:
        # resample + convert to mono if needed
        wav = torch.mean(wav, dim=0, keepdim=True)  # mono
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    # SAFE POINT: new audio is good → delete old audio if different
    if prev_file != "" and prev_file != audio_file:
        delete_file(prev_file)

    # Update prev_file to the *current* file
    prev_file = audio_file
    # Build the chat messages; the <|audio|> placeholder marks where the processor inserts the audio features
    chat = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": f"<|audio|>{user_prompt}"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
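
    # Generation runs in a background thread; TextIteratorStreamer yields partial text
    # so the UI can update as tokens are produced.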
    # run model
    model_inputs = processor(prompt, wav, device=model.device, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(**model_inputs, streamer=streamer, max_new_tokens=512, do_sample=False, num_beams=1)
    t = Thread(target=model.generate, kwargs=kwargs)
    t.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text, prev_file

    # Apply cap+punct for English-only
    if langid.classify(text)[0] == "en":
        text = pc_model.infer([text])
        text = " ".join(text[0]).replace("<unk>", " ").replace("<Unk>", " ")  # map <unk> to space
    yield text, prev_file
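

# Custom CSS and extra <head> markup loaded from files next to this script.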
css_file_path = Path(Path(__file__).parent / "app.css")
head_file_path = Path(Path(__file__).parent / "app_head.html")
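
# Gradio UI: audio input alongside a prompt dropdown and a streaming transcription box.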
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # State to store the previously uploaded audio file
    prev_audio = gr.State(value="")

    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription", lines=5)
            choices = [
                "Transcribe the speech to text",
                "Translate the speech to French",
                "Translate the speech to German",
                "Translate the speech to Spanish",
                "Translate the speech to Portuguese",
            ]
            user_prompt = gr.Dropdown(
                label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0]
            )
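
    # Pressing play on the audio component triggers transcription and streams the result into the textbox.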
    audio_input.play(transcribe, inputs=[audio_input, user_prompt, prev_audio], outputs=[output_text, prev_audio])

if __name__ == "__main__":
    demo.launch()