karesaeedff committed
Commit fc9c607 · verified · 1 Parent(s): cfb6b7a

Update app.py

Files changed (1)
1. app.py +29 -30
app.py CHANGED
@@ -2,33 +2,39 @@ import gradio as gr
import librosa
import numpy as np
import torch
- from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor
- from tqdm import tqdm
+ from transformers import AutoModelForAudioClassification, AutoFeatureExtractor, pipeline
import tempfile
- import json
import soundfile as sf
+ import json

- # ==== Parameters ====
SAMPLE_RATE = 16000
- WINDOW = 5
- STEP = 2
+ WINDOW = 10
+ STEP = 5
MUSIC_THRESHOLD = 0.4
VOICE_THRESHOLD = 0.3
MIN_SING_DURATION = 8

- # ==== Model loading ====
+ # === Model loading ===
music_model_id = "AI-Music-Detection/ai_music_detection_large_60s"
music_feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
music_model = AutoModelForAudioClassification.from_pretrained(music_model_id)
- music_pipe = pipeline(
-     task="audio-classification",
-     model=music_model,
-     feature_extractor=music_feature_extractor
- )
- voice_pipe = pipeline(
-     "audio-classification",
-     model="superb/hubert-large-superb-sid"
- )
+ voice_pipe = pipeline("audio-classification", model="superb/hubert-large-superb-sid")
+
+ def predict_music_score(snippet):
+     """
+     Run the feature extractor and the model by hand
+     to avoid the pipeline's automatic chunking.
+     """
+     inputs = music_feature_extractor(snippet, sampling_rate=SAMPLE_RATE, return_tensors="pt", truncation=True, padding="max_length")
+     with torch.no_grad():
+         outputs = music_model(**inputs)
+     scores = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
+     labels = music_model.config.id2label
+     label_scores = {labels[i].lower(): float(scores[i]) for i in range(len(scores))}
+     # Look for music- or singing-related labels
+     music_score = max([v for k, v in label_scores.items() if "music" in k or "sing" in k] or [0])
+     return music_score
+

def detect_singing(audio_path):
    wav, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
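Note: the new predict_music_score() bypasses the audio-classification pipeline so that padding/truncation to the AST model's fixed input happens in one place. A minimal standalone sketch of the same manual path (the model IDs come from this diff; the random test snippet and the top-5 printout are illustrative only, not part of app.py):

import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

SAMPLE_RATE = 16000
extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = AutoModelForAudioClassification.from_pretrained("AI-Music-Detection/ai_music_detection_large_60s")

# Placeholder input: 10 s of noise standing in for a real snippet.
snippet = np.random.randn(SAMPLE_RATE * 10).astype(np.float32)

# The AST feature extractor pads/truncates to its fixed input length internally.
inputs = extractor(snippet, sampling_rate=SAMPLE_RATE, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.nn.functional.softmax(logits, dim=-1)[0]

# Print the top-5 labels with their scores.
top = sorted(enumerate(probs.tolist()), key=lambda x: -x[1])[:5]
for i, p in top:
    print(model.config.id2label[i], round(p, 3))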
@@ -39,27 +45,22 @@ def detect_singing(audio_path):
        end = start + WINDOW
        snippet = wav[int(start * SAMPLE_RATE):int(end * SAMPLE_RATE)]

-         # === Fix: the AST model expects a fixed 60-second input ===
+         # Safe input length for the model
        max_len = SAMPLE_RATE * 60
-         if len(snippet) < max_len:
-             pad = np.zeros(max_len)
-             pad[:len(snippet)] = snippet
-             snippet = pad
-         elif len(snippet) > max_len:
+         if len(snippet) < SAMPLE_RATE * 3:  # skip snippets that are too short
+             continue
+         if len(snippet) > max_len:
            snippet = snippet[:max_len]

-         # === Music detection ===
-         music_pred = music_pipe(snippet, sampling_rate=SAMPLE_RATE)
-         music_score = max([p['score'] for p in music_pred if 'music' in p['label'].lower()] or [0])
+         music_score = predict_music_score(snippet)

-         # === Vocal detection ===
        voice_pred = voice_pipe(snippet, sampling_rate=SAMPLE_RATE)
        voice_score = max([p['score'] for p in voice_pred if 'speech' in p['label'].lower()] or [0])

        if music_score > MUSIC_THRESHOLD and voice_score > VOICE_THRESHOLD:
            results.append((float(start), float(end)))

-     # === Merge consecutive segments ===
+     # Merge consecutive windows
    merged = []
    for seg in results:
        if not merged or seg[0] > merged[-1][1]:
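The hunk above cuts off inside the merge loop. With WINDOW = 10 and STEP = 5, consecutive positive windows overlap by 5 s, so adjacent detections chain into one segment. A runnable sketch of the merge as the visible lines suggest it works (the else branch and the MIN_SING_DURATION filter are assumptions, since only the first lines of the loop appear in this diff):

MIN_SING_DURATION = 8
results = [(0.0, 10.0), (5.0, 15.0), (30.0, 40.0)]  # example positive windows

merged = []
for seg in results:
    if not merged or seg[0] > merged[-1][1]:
        merged.append(list(seg))                     # gap: start a new segment
    else:
        merged[-1][1] = max(merged[-1][1], seg[1])   # overlap: extend the last one

# Assumed final filter against MIN_SING_DURATION:
segments = [(s, e) for s, e in merged if e - s >= MIN_SING_DURATION]
print(segments)  # [(0.0, 15.0), (30.0, 40.0)]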
@@ -74,8 +75,7 @@ def analyze_audio(file):
    if file is None:
        return "Please upload an audio file", None

-     audio_path = file  # type="filepath" returns a path string
-
+     audio_path = file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        data, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
        sf.write(tmp.name, data, sr)
@@ -91,7 +91,6 @@ def analyze_audio(file):
    return f"Detected {len(segments)} singing segment(s)", json_output


- # ==== Gradio UI ====
with gr.Blocks(title="🎵 Singing Segment Detector") as demo:
    gr.Markdown("# 🎤 Automatic singing-segment detection\nUpload an audio file (extracted from a video) to get the detected singing time ranges as JSON.")
    audio_in = gr.Audio(type="filepath", label="Upload an audio file (WAV)")
 
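For a quick test without the UI, the detector can be called directly, assuming detect_singing() returns the merged list of (start, end) pairs as the code above suggests ("sample.wav" is a placeholder path, and the demo.launch() call is an assumption, since the diff does not show the entry point):

import json

segments = detect_singing("sample.wav")  # placeholder input file
print(json.dumps([{"start": s, "end": e} for s, e in segments], indent=2))

# Or serve the Gradio app:
# demo.launch()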