import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import tempfile
import soundfile as sf
import json
# === Parameters ===
SAMPLE_RATE = 16000
CHUNK_SIZE = 60          # model input window, in seconds
STEP = 10                # sliding-window hop, in seconds
MUSIC_THRESHOLD = 0.5
VOICE_THRESHOLD = 0.3
MIN_SEG_DURATION = 8     # minimum singing-segment length, in seconds
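# Window arithmetic: with CHUNK_SIZE = 60 and STEP = 10, consecutive windows
# overlap by 50 s, so any instant of audio is covered by up to
# CHUNK_SIZE / STEP = 6 windows, and merged segments grow in 10 s increments.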
# === Model loading ===
print("Loading models...")

# 🎵 Music-detection model (AST architecture). The feature extractor comes
# from the base AST checkpoint and matches the fine-tuned model's input format.
music_model_id = "AI-Music-Detection/ai_music_detection_large_60s"
music_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
music_model = AutoModelForAudioClassification.from_pretrained(music_model_id)

# 🗣️ Voice model (HuBERT). Note: this SUPERB checkpoint is trained for speaker
# identification (SID), not voice activity detection, so its output is only a
# rough proxy for "voice present".
voice_model_id = "superb/hubert-large-superb-sid"
voice_extractor = AutoFeatureExtractor.from_pretrained(voice_model_id)
voice_model = AutoModelForAudioClassification.from_pretrained(voice_model_id)

print("✅ Models loaded successfully.")
# === Inference helpers ===
def predict_music_score(wav):
    """Predict the probability that a clip contains music."""
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = music_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = music_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
    # Assumes the last class of the classification head is "music".
    score = float(probs[-1]) if probs.numel() > 1 else float(probs[0])
    return score
def predict_voice_score(wav):
    """Predict a rough confidence that a clip contains voice."""
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = voice_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = voice_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
    # The mean of a softmax is always 1 / num_classes and carries no signal;
    # use the top class probability as a confidence proxy instead.
    score = float(probs.max())
    return score
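# Quick sanity check (illustrative; uncomment to run after the models load):
# one minute of silence should score low on the music head.
# silence = np.zeros(SAMPLE_RATE * CHUNK_SIZE, dtype=np.float32)
# print("music:", predict_music_score(silence), "voice:", predict_voice_score(silence))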
def detect_singing(audio_path):
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    total_len = len(wav)
    chunk_len = SAMPLE_RATE * CHUNK_SIZE
    step_len = SAMPLE_RATE * STEP
    results = []
    # Slide a CHUNK_SIZE window over the audio in STEP hops; "+ 1" keeps the
    # final window that starts exactly at total_len - chunk_len.
    for start_idx in range(0, max(1, total_len - chunk_len + 1), step_len):
        end_idx = start_idx + chunk_len
        snippet = wav[start_idx:end_idx]
        # Fix: zero-pad the tail so every window has exactly chunk_len samples.
        if len(snippet) < chunk_len:
            snippet = np.pad(snippet, (0, chunk_len - len(snippet)))
        # Model inference on this window.
        music_score = predict_music_score(snippet)
        voice_score = predict_voice_score(snippet)
        start_t = start_idx / SAMPLE_RATE
        # Clamp so padded windows do not report times past the end of the file.
        end_t = min(start_t + CHUNK_SIZE, total_len / SAMPLE_RATE)
        if music_score > MUSIC_THRESHOLD and voice_score > VOICE_THRESHOLD:
            results.append((float(start_t), float(end_t)))
    # Merge overlapping or adjacent windows into contiguous segments.
    merged = []
    for seg in results:
        if not merged or seg[0] > merged[-1][1]:
            merged.append(list(seg))
        else:
            merged[-1][1] = max(merged[-1][1], seg[1])
    # Drop segments shorter than the minimum singing duration.
    merged = [(s, e) for s, e in merged if e - s >= MIN_SEG_DURATION]
    return merged
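# Merge example: detected windows (0, 60), (10, 70), (120, 180) collapse to
# (0, 70) and (120, 180), because (10, 70) starts before the previous segment's
# end (60) and so extends it, while (120, 180) starts a new segment.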
# === Main inference function ===
def analyze_audio(file_path):
    if file_path is None:
        return "⚠️ Please upload an audio file", None
    # Resample to 16 kHz and round-trip through a temporary WAV so detection
    # always reads a uniform format.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        data, sr = librosa.load(file_path, sr=SAMPLE_RATE)
        sf.write(tmp.name, data, sr)
        segments = detect_singing(tmp.name)
    if not segments:
        return "No singing segments detected", json.dumps([], indent=2)
    json_output = json.dumps(
        [{"start": s, "end": e, "duration": round(e - s, 2)} for s, e in segments],
        indent=2
    )
    return f"Detected {len(segments)} singing segment(s)", json_output
# === Gradio UI ===
with gr.Blocks(title="🎵 Singing Segment Detector (Final)") as demo:
    gr.Markdown(
        """
        # 🎤 Automatic Singing-Segment Detector (AI-Music + HuBERT)
        - Automatically detects the time ranges in which singing occurs
        - Fuses two models: `AI-Music-Detection/ai_music_detection_large_60s` + `HuBERT`
        - Outputs the start time, end time, and duration of each segment
        """
    )
    audio_input = gr.Audio(type="filepath", label="Upload audio (extracted from video)")
    run_btn = gr.Button("🚀 Start analysis")
    status_box = gr.Textbox(label="Status", interactive=False)
    json_output = gr.Code(label="Singing-segment timestamps (JSON)", language="json")
    run_btn.click(fn=analyze_audio, inputs=[audio_input], outputs=[status_box, json_output])

demo.launch()
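# For deployment (assumption: the defaults are otherwise fine), Gradio accepts
# host/port overrides, e.g.:
# demo.launch(server_name="0.0.0.0", server_port=7860)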