"""Template Demo for IBM Granite Hugging Face spaces."""
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread
import gradio as gr
import spaces
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, TextIteratorStreamer
import langid
from punctuators.models import PunctCapSegModelONNX

from themes.research_monochrome import theme

# English-only punctuation + capitalization restoration model.
pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")
today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
MODEL_ID = "ibm-granite/granite-speech-3.3-2b"
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite Speech 3.3 2B ASR Demo"
DESCRIPTION = "Record or upload an audio file and try one of the prompts. Press the play button to transcribe the file."
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
)
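# Note: device_map="auto" relies on accelerate to place model layers on the
# available GPU(s) and CPU; offload_folder is only used if weights must spill
# to disk. bfloat16 roughly halves memory versus float32 for inference.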
@spaces.GPU
def transcribe(audio_file: str, user_prompt: str) -> Iterator[str]:
"""transcribe function for ASR demo.
Args:
audio_file (str): Name of audio file from the user.
user_prompt (str): Instruction from the user (transcription or translation).
Returns:
str: The generated transcription/translation of the audio file.
"""
# load wav file
wav, sr = torchaudio.load(audio_file, normalize=True)
if wav.shape[0] != 1 or sr != 16000:
# resample + convert to mono if needed
wav = torch.mean(wav, dim=0, keepdim=True) # mono
wav = torchaudio.functional.resample(wav, sr, 16000)
sr = 16000
# Build messages
chat = [
dict(role="system", content=SYS_PROMPT),
dict(role="user", content=f"<|audio|>{user_prompt}"),
]
prompt = tokenizer.apply_chat_template(
chat, tokenize=False, add_generation_prompt=True)
# run model
model_inputs = processor(
prompt,
wav,
device=model.device,
return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
kwargs = dict(
**model_inputs,
streamer=streamer,
max_new_tokens=512,
do_sample=False,
num_beams=1
)
t = Thread(target=model.generate, kwargs=kwargs)
t.start()
text = ""
for chunk in streamer:
text += chunk
yield text
# Apply cap+punct for English-only
if langid.classify(text)[0] == 'en':
text = pc_model.infer([text])
yield " ".join(text[0])
css_file_path = Path(Path(__file__).parent / "app.css")
head_file_path = Path(Path(__file__).parent / "app_head.html")
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (16 kHz mono preferred)")
        with gr.Column():
            output_text = gr.Textbox(label="Transcription", lines=5)
            choices = [
                "Transcribe the speech to text",
                "Translate the speech to French",
                "Translate the speech to German",
                "Translate the speech to Spanish",
                "Translate the speech to Portuguese",
            ]
            user_prompt = gr.Dropdown(
                label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0]
            )
    audio_input.play(
        transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )
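# gr.Audio's `.play` event fires when the user starts playback, so the
# transcription streams into the textbox while the clip plays.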
if __name__ == "__main__":
    demo.launch()