# karesaeedff's picture
# Update app.py
# 71f519c verified
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import tempfile
import soundfile as sf
import json
# === Configuration ===
# Sample rate (Hz) passed to librosa.load and to both feature extractors.
SAMPLE_RATE = 16_000
# Length in seconds of each analysis window (the models are fed 60 s of audio).
CHUNK_SIZE = 60
# Hop in seconds between the starts of consecutive sliding windows.
STEP = 10
# A window counts as singing only if the music score exceeds this value...
MUSIC_THRESHOLD = 0.5
# ...and the voice score exceeds this value.
VOICE_THRESHOLD = 0.3
# Merged segments shorter than this many seconds are discarded.
MIN_SEG_DURATION = 8
# === Model loading ===
# These statements run at module import time and populate the module-level
# extractor/model globals used by the predict_* functions.
print("Loading models...")
# Music detection model (AST architecture, per the original author's note).
# The feature extractor is loaded from the base AST checkpoint rather than
# from music_model_id.
music_model_id = "AI-Music-Detection/ai_music_detection_large_60s"
music_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
music_model = AutoModelForAudioClassification.from_pretrained(music_model_id)
# Voice model (HuBERT), described by the original author as voice activity
# detection.
# NOTE(review): the id ends in "-sid", which looks like a speaker
# identification checkpoint rather than a VAD model — confirm.
voice_model_id = "superb/hubert-large-superb-sid"
voice_extractor = AutoFeatureExtractor.from_pretrained(voice_model_id)
voice_model = AutoModelForAudioClassification.from_pretrained(voice_model_id)
print("✅ Models loaded successfully.")
# === Model inference functions ===
def predict_music_score(wav):
    """Return the music model's score for one audio window.

    The waveform is padded or truncated to ``SAMPLE_RATE * CHUNK_SIZE``
    samples, passed through ``music_extractor`` and ``music_model`` with
    gradients disabled, and reduced to a single float.

    Args:
        wav: Waveform sampled at ``SAMPLE_RATE`` (assumed to be a 1-D NumPy
            array, as produced by ``librosa.load`` in ``detect_singing``).

    Returns:
        float: For a multi-class head, the softmax probability of the LAST
        class. For a single-logit head, the sigmoid of that logit.
    """
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = music_extractor(
        wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        # One window in, so flattening leaves exactly one value per class.
        logits = music_model(**inputs).logits.reshape(-1)

    if logits.numel() == 1:
        # Softmax over a single logit is always 1.0, and indexing the squeezed
        # 0-dim tensor (as the previous fallback did) raises IndexError.
        # A sigmoid is the usual reading of a single-logit binary head.
        return float(torch.sigmoid(logits[0]))

    # NOTE(review): assumes the last label is the "music" class — confirm
    # against the checkpoint's id2label mapping.
    return float(torch.softmax(logits, dim=-1)[-1])
def predict_voice_score(wav):
    """Return a voice-presence score for one audio window.

    The waveform is padded or truncated to ``SAMPLE_RATE * CHUNK_SIZE``
    samples, passed through ``voice_extractor`` and ``voice_model`` with
    gradients disabled, and reduced to a single float.

    Args:
        wav: Waveform sampled at ``SAMPLE_RATE`` (assumed to be a 1-D NumPy
            array, as produced by ``librosa.load`` in ``detect_singing``).

    Returns:
        float: The highest class probability of the softmax output.
    """
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = voice_extractor(
        wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        logits = voice_model(**inputs).logits.reshape(-1)
        probs = torch.softmax(logits, dim=-1)

    # The previous score was probs.mean(). Softmax outputs sum to 1, so that
    # mean is always 1 / num_classes regardless of the audio and could never
    # exceed VOICE_THRESHOLD for a model with four or more classes.
    # NOTE(review): peak class confidence is used as a proxy for "a voice is
    # present"; confirm this heuristic and the threshold against real data.
    return float(probs.max())
def detect_singing(audio_path):
    """Find the time spans of an audio file that are classified as singing.

    The audio is loaded at ``SAMPLE_RATE`` and scanned with windows of
    ``CHUNK_SIZE`` seconds whose starts are ``STEP`` seconds apart. A window
    is kept when its music score exceeds ``MUSIC_THRESHOLD`` and its voice
    score exceeds ``VOICE_THRESHOLD``. Kept windows that touch or overlap are
    merged, and merged spans shorter than ``MIN_SEG_DURATION`` are dropped.

    Args:
        audio_path: Path of an audio file readable by ``librosa.load``.

    Returns:
        list[tuple[float, float]]: ``(start, end)`` pairs in seconds, ordered
        by start time. ``end`` never exceeds the duration of the audio.
    """
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    total_len = len(wav)
    if total_len == 0:
        # Nothing to analyse; avoid running the models on pure padding.
        return []

    chunk_len = SAMPLE_RATE * CHUNK_SIZE
    step_len = SAMPLE_RATE * STEP
    total_duration = total_len / SAMPLE_RATE

    # Extend the bound by one step so the last window reaches the end of the
    # audio. Stopping at total_len - chunk_len left the tail unanalysed.
    stop = max(1, total_len - chunk_len + step_len)

    results = []
    for start_idx in range(0, stop, step_len):
        snippet = wav[start_idx:start_idx + chunk_len]
        # Force every window to the same length by zero-padding the tail.
        if len(snippet) < chunk_len:
            snippet = np.pad(snippet, (0, chunk_len - len(snippet)))

        # Skip the second model when the first gate already fails; the
        # outcome is the same because both conditions are required.
        if predict_music_score(snippet) <= MUSIC_THRESHOLD:
            continue
        if predict_voice_score(snippet) <= VOICE_THRESHOLD:
            continue

        start_t = start_idx / SAMPLE_RATE
        # Clamp so padded windows do not report time past the end of the file.
        end_t = min(start_t + CHUNK_SIZE, total_duration)
        results.append((float(start_t), float(end_t)))

    # Merge windows that touch or overlap (results are ordered by start time).
    merged = []
    for start_t, end_t in results:
        if merged and start_t <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end_t)
        else:
            merged.append([start_t, end_t])

    return [(s, e) for s, e in merged if e - s >= MIN_SEG_DURATION]
# === Main inference entry point ===
def analyze_audio(file_path):
    """Run singing detection on an uploaded file and format the result.

    The input is decoded at ``SAMPLE_RATE``, written out as a WAV file in a
    temporary directory, and passed to ``detect_singing``. The temporary
    directory is removed when detection finishes, whether or not it raised.

    Args:
        file_path: Path of the uploaded audio file, or ``None`` when nothing
            was uploaded.

    Returns:
        tuple: ``(status_message, json_text)``. ``json_text`` is ``None`` when
        no file was given; otherwise it is a JSON array of objects with
        ``start``, ``end`` and ``duration`` keys (empty when nothing is found).
    """
    if file_path is None:
        return "⚠️ 请上传音频文件", None

    # A TemporaryDirectory is cleaned up automatically. The previous
    # NamedTemporaryFile(delete=False) left one WAV file behind per request
    # and was written to by name while its own handle was still open.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = f"{tmp_dir}/input.wav"
        data, sr = librosa.load(file_path, sr=SAMPLE_RATE)
        sf.write(wav_path, data, sr)
        segments = detect_singing(wav_path)

    if not segments:
        return "未检测到唱歌片段", json.dumps([], indent=2)

    json_output = json.dumps(
        [{"start": s, "end": e, "duration": round(e - s, 2)} for s, e in segments],
        indent=2,
    )
    return f"检测到 {len(segments)} 段唱歌片段", json_output
# === Gradio UI ===
# Introductory markdown shown at the top of the page (same text as before,
# spelled out line by line).
_INTRO_MARKDOWN = (
    "\n"
    "# 🎤 唱歌片段自动检测器(AI-Music + HuBERT)\n"
    "- 自动检测视频中的演唱时间段\n"
    "- 采用 `AI-Music-Detection/ai_music_detection_large_60s` + `HuBERT` 双模型融合\n"
    "- 输出每段的开始、结束时间与时长\n"
)

with gr.Blocks(title="🎵 Singing Segment Detector (Final)") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Input: an audio file handed to the callback as a filesystem path.
    audio_input = gr.Audio(type="filepath", label="上传音频(从视频提取)")
    run_btn = gr.Button("🚀 开始分析")

    # Outputs: a read-only status line and the JSON list of segments.
    status_box = gr.Textbox(label="分析状态", interactive=False)
    json_output = gr.Code(label="唱歌片段时间戳(JSON)", language="json")

    # Clicking the button runs analyze_audio on the uploaded file.
    run_btn.click(
        fn=analyze_audio,
        inputs=[audio_input],
        outputs=[status_box, json_output],
    )

demo.launch()