import json

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Analysis parameters
SAMPLE_RATE = 16000    # all audio is resampled to 16 kHz
CHUNK_SIZE = 60        # analysis window length in seconds (matches the 60 s music model)
STEP = 10              # hop between consecutive windows, in seconds
MUSIC_THRESHOLD = 0.5  # minimum music score for a window to count as singing
VOICE_THRESHOLD = 0.3  # minimum voice score for a window to count as singing
MIN_SEG_DURATION = 8   # discard merged segments shorter than this, in seconds
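# With CHUNK_SIZE=60 and STEP=10, consecutive windows overlap by 50 s, so any
# instant away from the edges of the file is scored by up to six windows.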
print("Loading models...") |
|
|
|
|
|
|
|
|
music_model_id = "AI-Music-Detection/ai_music_detection_large_60s" |
|
|
music_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593") |
|
|
music_model = AutoModelForAudioClassification.from_pretrained(music_model_id) |
|
|
|
|
|
|
|
|
voice_model_id = "superb/hubert-large-superb-sid" |
|
|
voice_extractor = AutoFeatureExtractor.from_pretrained(voice_model_id) |
|
|
voice_model = AutoModelForAudioClassification.from_pretrained(voice_model_id) |
|
|
|
|
|
print("✅ Models loaded successfully.") |
|
|
|
|
|
|
|
|
|
|
|
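# A minimal sketch for GPU inference (an assumption, not part of the original
# flow; the predict_* helpers below would also need to move their inputs):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# music_model.to(device)
# voice_model.to(device)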
def predict_music_score(wav):
    """Return the probability that a 60 s chunk contains music."""
    # Pad or truncate to exactly the window length the model expects.
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = music_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = music_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
    # Assumes the last label index is the positive ("music") class; check the
    # model's id2label mapping if scores look inverted.
    score = float(probs[-1]) if probs.numel() > 1 else float(probs[0])
    return score
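# Quick smoke test (the printed value depends entirely on the model):
# silence = np.zeros(SAMPLE_RATE * CHUNK_SIZE, dtype=np.float32)
# print(predict_music_score(silence))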
def predict_voice_score(wav):
    """Return a heuristic score for whether a 60 s chunk contains voice."""
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = voice_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = voice_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
    # The SID model outputs a distribution over speaker classes, so the *mean*
    # softmax probability is a constant 1/num_classes and would never clear
    # VOICE_THRESHOLD. Use the peak probability as a heuristic instead: a
    # peaked distribution suggests a clear voice, a flat one suggests none.
    score = float(probs.max())
    return score
def detect_singing(audio_path):
    """Slide a 60 s window over the audio and return merged (start, end) singing segments."""
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    total_len = len(wav)
    chunk_len = SAMPLE_RATE * CHUNK_SIZE
    step_len = SAMPLE_RATE * STEP
    duration = total_len / SAMPLE_RATE
    results = []

    # Step the window so the tail of the audio is also covered; the final,
    # shorter window is zero-padded below.
    for start_idx in range(0, max(1, total_len - chunk_len + step_len), step_len):
        end_idx = start_idx + chunk_len
        snippet = wav[start_idx:end_idx]

        if len(snippet) < chunk_len:
            snippet = np.pad(snippet, (0, chunk_len - len(snippet)))

        music_score = predict_music_score(snippet)
        voice_score = predict_voice_score(snippet)

        start_t = start_idx / SAMPLE_RATE
        end_t = min(start_t + CHUNK_SIZE, duration)  # don't run past the file

        # A window counts as singing only if both models agree.
        if music_score > MUSIC_THRESHOLD and voice_score > VOICE_THRESHOLD:
            results.append((float(start_t), float(end_t)))

    # Merge overlapping or adjacent windows into contiguous segments.
    merged = []
    for seg in results:
        if not merged or seg[0] > merged[-1][1]:
            merged.append(list(seg))
        else:
            merged[-1][1] = max(merged[-1][1], seg[1])

    # Discard merged segments shorter than MIN_SEG_DURATION seconds.
    merged = [(s, e) for s, e in merged if e - s >= MIN_SEG_DURATION]
    return merged
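# Example of the merge step: with CHUNK_SIZE=60 and STEP=10, positive windows
# at (30, 90) and (40, 100) overlap, so they merge into a single (30, 100)
# segment; a lone 60 s window already satisfies MIN_SEG_DURATION.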
def analyze_audio(file_path):
    if file_path is None:
        return "⚠️ Please upload an audio file", None

    # detect_singing loads and resamples the file itself, so no temporary
    # WAV round-trip is needed here.
    segments = detect_singing(file_path)

    if not segments:
        return "No singing segments detected", json.dumps([], indent=2)

    json_output = json.dumps(
        [{"start": s, "end": e, "duration": round(e - s, 2)} for s, e in segments],
        indent=2
    )
    return f"Detected {len(segments)} singing segment(s)", json_output
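# To produce a suitable input from a video, one option (assuming ffmpeg is
# installed) is:
#   ffmpeg -i input.mp4 -vn -ac 1 -ar 16000 audio.wav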
with gr.Blocks(title="🎵 Singing Segment Detector (Final)") as demo:
    gr.Markdown(
        """
        # 🎤 Automatic Singing Segment Detector (AI-Music + HuBERT)
        - Automatically detects singing time ranges in a video's audio track
        - Fuses two models: `AI-Music-Detection/ai_music_detection_large_60s` + `HuBERT`
        - Reports the start time, end time, and duration of each segment
        """
    )

    audio_input = gr.Audio(type="filepath", label="Upload audio (extracted from a video)")
    run_btn = gr.Button("🚀 Run analysis")
    status_box = gr.Textbox(label="Analysis status", interactive=False)
    json_output = gr.Code(label="Singing segment timestamps (JSON)", language="json")

    run_btn.click(fn=analyze_audio, inputs=[audio_input], outputs=[status_box, json_output])

demo.launch()
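# Illustrative result for a clip with one detected segment (values made up):
#   Status: "Detected 1 singing segment(s)"
#   JSON:   [{"start": 30.0, "end": 100.0, "duration": 70.0}]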