Files changed (2)
  1. .pre-commit-config.yaml +0 -1
  2. src/app.py +41 -57
.pre-commit-config.yaml CHANGED
@@ -49,4 +49,3 @@ repos:
       - id: poetry-export
         name: poetry export for base requirements
         args: [-f, requirements.txt, -o, requirements.txt, -n, --only=main, --without-hashes]
-        stages: [manual]
 
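Note: with stages: [manual], pre-commit only ran the poetry-export hook when it was invoked explicitly (pre-commit run --hook-stage manual poetry-export); removing the line puts the hook back on the default stages, so requirements.txt is regenerated on every commit.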
src/app.py CHANGED
@@ -1,21 +1,21 @@
 """Template Demo for IBM Granite Hugging Face spaces."""
 
-import os
 from collections.abc import Iterator
 from datetime import datetime
 from pathlib import Path
 from threading import Thread
 
 import gradio as gr
-import langid
 import spaces
 import torch
 import torchaudio
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, TextIteratorStreamer
+import langid
 from punctuators.models import PunctCapSegModelONNX
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, TextIteratorStreamer
-
 pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")
 
+from themes.research_monochrome import theme
+
 today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
 
 MODEL_ID = "ibm-granite/granite-speech-3.3-2b"
@@ -31,36 +31,18 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
     MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, offload_folder="offload/"
 )
 
-
-def delete_file(path: str) -> None:
-    """Delete a file if it exists.
-
-    Args:
-        path (str): Path to the file to delete.
-
-    Returns:
-        None
-    """
-    if path and os.path.exists(path):
-        try:
-            os.remove(path)
-            print(f"Deleted old audio file: {path}")
-        except Exception as e:
-            print(f"Warning: could not delete {path}: {e}")
-
-
 @spaces.GPU
-def transcribe(audio_file: str, user_prompt: str, prev_file: str) -> Iterator[str]:
-    """Transcribe function for ASR demo.
+def transcribe(audio_file: str, user_prompt: str) -> Iterator[str]:
+    """transcribe function for ASR demo.
 
     Args:
         audio_file (str): Name of audio file from the user.
         user_prompt (str): Instruction from the user (transcription or translation).
-        prev_file (str): Previously uploaded audio file.
 
     Returns:
         str: The generated transcription/translation of the audio file.
     """
+
     # load wav file
     wav, sr = torchaudio.load(audio_file, normalize=True)
     if wav.shape[0] != 1 or sr != 16000:
@@ -68,40 +50,42 @@ def transcribe(audio_file: str, user_prompt: str, prev_file: str) -> Iterator[st
         wav = torch.mean(wav, dim=0, keepdim=True)  # mono
         wav = torchaudio.functional.resample(wav, sr, 16000)
         sr = 16000
-
-    # SAFE POINT: new audio is good → delete old audio if different
-    if prev_file != "" and prev_file != audio_file:
-        delete_file(prev_file)
-
-    # Update prev_file to the *current* file
-    prev_file = audio_file
-
+
     # Build messages
     chat = [
-        {"role": "system", "content": SYS_PROMPT},
-        {"role": "user", "content": f"<|audio|>{user_prompt}"},
+        dict(role="system", content=SYS_PROMPT),
+        dict(role="user", content=f"<|audio|>{user_prompt}"),
     ]
-    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    prompt = tokenizer.apply_chat_template(
+        chat, tokenize=False, add_generation_prompt=True)
 
     # run model
-    model_inputs = processor(prompt, wav, device=model.device, return_tensors="pt").to(model.device)
+    model_inputs = processor(
+        prompt,
+        wav,
+        device=model.device,
+        return_tensors="pt").to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-
-    kwargs = dict(**model_inputs, streamer=streamer, max_new_tokens=512, do_sample=False, num_beams=1)
+
+    kwargs = dict(
+        **model_inputs,
+        streamer=streamer,
+        max_new_tokens=512,
+        do_sample=False,
+        num_beams=1
+    )
     t = Thread(target=model.generate, kwargs=kwargs)
     t.start()
 
     text = ""
     for chunk in streamer:
         text += chunk
-        yield text, prev_file
+        yield text
 
     # Apply cap+punct for English-only
-    if langid.classify(text)[0] == "en":
+    if langid.classify(text)[0] == 'en':
         text = pc_model.infer([text])
-        text = " ".join(text[0]).replace("<unk>", " ").replace("<Unk>", " ")  # map <unk> to space
-        yield text, prev_file
-
+        yield " ".join(text[0])
 
 css_file_path = Path(Path(__file__).parent / "app.css")
 head_file_path = Path(Path(__file__).parent / "app_head.html")
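
Note: model.generate blocks until decoding finishes, which is why the hunk above runs it on a worker thread while the handler drains the TextIteratorStreamer. A minimal self-contained sketch of the same pattern, using gpt2 as a stand-in model so it runs without audio or a GPU:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The weather today is", return_tensors="pt")
streamer = TextIteratorStreamer(tok, timeout=30.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so run it on a worker thread and consume
# the streamer from the caller as decoded fragments arrive
worker = Thread(target=lm.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
worker.start()

text = ""
for chunk in streamer:  # each chunk is a newly decoded piece of text
    text += chunk
    print(text)
worker.join()
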
@@ -109,25 +93,25 @@ head_file_path = Path(Path(__file__).parent / "app_head.html")
 with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, title=TITLE) as demo:
     gr.Markdown(f"# {TITLE}")
     gr.Markdown(DESCRIPTION)
-
-    # State to store the previously uploaded audio file
-    prev_audio = gr.State(value="")
-
     with gr.Row():
-        audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
+        audio_input = gr.Audio(type="filepath",
+                               label="Upload Audio (16kHz mono preferred)")
         with gr.Column():
             output_text = gr.Textbox(label="Transcription", lines=5)
-            choices = [
+            choices = [
                 "Transcribe the speech to text",
-                "Translate the speech to French",
-                "Translate the speech to German",
+                "Translate the speech to French",
+                "Translate the speech to German",
                 "Translate the speech to Spanish",
-                "Translate the speech to Portuguese",
+                "Translate the speech to Portuguese"
             ]
-            user_prompt = gr.Dropdown(
-                label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0]
-            )
-    audio_input.play(transcribe, inputs=[audio_input, user_prompt, prev_audio], outputs=[output_text, prev_audio])
+            user_prompt = gr.Dropdown(label="Prompt", choices=choices, interactive=True, allow_custom_value=True, value=choices[0])
+    audio_input.play(
+        transcribe,
+        inputs=[
+            audio_input,
+            user_prompt],
+        outputs=output_text)
 
 if __name__ == "__main__":
     demo.launch()
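
Note: the final hunk of transcribe gates post-processing on langid and re-punctuates with punctuators. A small sketch of those two calls in isolation, assuming the API the diff relies on (infer returns one list of restored sentences per input string); the sample text is made up:

import langid
from punctuators.models import PunctCapSegModelONNX

pc_model = PunctCapSegModelONNX.from_pretrained("pcs_en")  # English punctuation/caps/segmentation

raw = "hi my name is jane i work on speech models"
lang, score = langid.classify(raw)  # returns (language code, score)
if lang == "en":
    sentences = pc_model.infer([raw])  # one list of sentence strings per input text
    print(" ".join(sentences[0]))
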
 
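Note: audio_input.play wires transcribe in as a generator handler, so every yield replaces the current value of the output textbox. The same streaming mechanism in a self-contained sketch (slow_echo is a made-up handler):

import time

import gradio as gr


def slow_echo(message: str):
    # a generator event handler: each yield pushes a new value
    # to the bound output component
    text = ""
    for ch in message:
        text += ch
        time.sleep(0.05)
        yield text


with gr.Blocks() as demo:
    box_in = gr.Textbox(label="Input")
    box_out = gr.Textbox(label="Streamed output")
    box_in.submit(slow_echo, inputs=box_in, outputs=box_out)

if __name__ == "__main__":
    demo.launch()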