File size: 4,559 Bytes
465ffc9
 
 
 
621b172
465ffc9
 
fc9c607
465ffc9
621b172
465ffc9
621b172
 
57c5b85
465ffc9
621b172
465ffc9
621b172
 
 
 
465ffc9
5c718c3
465ffc9
57c5b85
621b172
57c5b85
621b172
57c5b85
 
621b172
 
 
 
57c5b85
621b172
57c5b85
621b172
fc9c607
 
621b172
 
 
 
fc9c607
57c5b85
621b172
57c5b85
621b172
57c5b85
 
621b172
 
 
 
465ffc9
 
621b172
71f519c
 
 
 
465ffc9
71f519c
 
 
465ffc9
71f519c
 
 
 
 
fc9c607
57c5b85
465ffc9
71f519c
 
 
465ffc9
71f519c
465ffc9
71f519c
465ffc9
71f519c
465ffc9
 
 
 
71f519c
57c5b85
465ffc9
 
 
621b172
 
 
 
465ffc9
 
621b172
465ffc9
 
 
 
57c5b85
465ffc9
 
 
 
 
 
 
 
621b172
 
57c5b85
621b172
 
 
 
 
 
57c5b85
621b172
 
 
 
 
 
 
465ffc9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import tempfile
import soundfile as sf
import json

# === Parameters ===
SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate when loaded with librosa
CHUNK_SIZE = 60  # model input window length, in seconds
STEP = 10        # sliding-window hop between consecutive windows, in seconds
MUSIC_THRESHOLD = 0.5  # a window needs a music score above this to count as singing
VOICE_THRESHOLD = 0.3  # ...and a voice score above this
MIN_SEG_DURATION = 8  # merged singing spans shorter than this (seconds) are discarded

# === Model loading ===
# Runs at import time: downloads/loads both checkpoints before the UI starts.
print("Loading models...")

# 🎵 Music-detection model (AST architecture)
# NOTE(review): the feature extractor is loaded from the MIT AST AudioSet
# checkpoint, not from music_model_id. If that extractor keeps its own default
# max_length, long (60 s) inputs may be truncated before reaching the model —
# confirm it matches the preprocessing the 60 s music model was trained with.
music_model_id = "AI-Music-Detection/ai_music_detection_large_60s"
music_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
music_model = AutoModelForAudioClassification.from_pretrained(music_model_id)

# 🗣️ Voice model (HuBERT)
# NOTE(review): the "superb-sid" suffix suggests a SUPERB speaker-identification
# head (labels are speakers), not a voice-activity detector — confirm this is
# the intended source of the "voice present" signal.
voice_model_id = "superb/hubert-large-superb-sid"
voice_extractor = AutoFeatureExtractor.from_pretrained(voice_model_id)
voice_model = AutoModelForAudioClassification.from_pretrained(voice_model_id)

print("✅ Models loaded successfully.")


# === Model inference functions ===
def predict_music_score(wav):
    """Return the music model's score for a single audio window.

    Args:
        wav: Waveform sampled at SAMPLE_RATE (the 1-D mono array produced by
            librosa.load in detect_singing). It is zero-padded or truncated to
            exactly CHUNK_SIZE seconds before feature extraction.

    Returns:
        A float in [0, 1]. For a multi-label head this is the softmax
        probability of the last label; for a single-logit head it is the
        sigmoid of that logit.
    """
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = music_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        # Flatten the (batch of 1, num_labels) logits to 1-D. squeeze() would
        # turn a single-logit output into a 0-d tensor that cannot be indexed.
        logits = music_model(**inputs).logits.reshape(-1)

    if logits.numel() == 1:
        # Softmax over one logit is always 1.0; sigmoid gives a real probability.
        return float(torch.sigmoid(logits[0]))

    # NOTE(review): assumes the last label is the "music"/positive class —
    # confirm against music_model.config.id2label.
    return float(torch.softmax(logits, dim=-1)[-1])


def predict_voice_score(wav):
    """Return the voice model's confidence for a single audio window.

    Args:
        wav: Waveform sampled at SAMPLE_RATE (the 1-D mono array produced by
            librosa.load in detect_singing). It is zero-padded or truncated to
            exactly CHUNK_SIZE seconds before feature extraction.

    Returns:
        The largest softmax probability across the voice model's labels,
        a float in (0, 1].
    """
    wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
    inputs = voice_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = voice_model(**inputs).logits.reshape(-1)

    # The mean of a softmax is always exactly 1 / num_labels whatever the
    # input, so it cannot tell windows apart. Use the peak probability (the
    # model's confidence in its top label) instead.
    # NOTE(review): this is a heuristic proxy for "a voice is present";
    # VOICE_THRESHOLD may need recalibrating against real recordings.
    return float(torch.softmax(logits, dim=-1).max())


def _window_starts(total_len, chunk_len, step_len):
    """Return window start offsets (samples) whose windows together cover every sample."""
    last_start = max(total_len - chunk_len, 0)
    # +1 so a window starting exactly at last_start is included.
    starts = list(range(0, last_start + 1, step_len))
    # If the hop does not land on last_start, add one final window flush with
    # the end of the audio so the tail is still scored.
    if starts[-1] < last_start:
        starts.append(last_start)
    return starts


def _merge_segments(segments):
    """Merge (start, end) spans, given in ascending start order, that overlap or touch."""
    merged = []
    for start, end in segments:
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return [(s, e) for s, e in merged]


def detect_singing(audio_path):
    """Return time spans (in seconds) judged to contain singing.

    The file is decoded and resampled to SAMPLE_RATE, then scanned with
    windows of CHUNK_SIZE seconds advancing by STEP seconds; the final window
    is aligned with the end of the audio so the tail is not skipped. A window
    counts as singing when its music score exceeds MUSIC_THRESHOLD and its
    voice score exceeds VOICE_THRESHOLD. Overlapping or touching windows are
    merged, and merged spans shorter than MIN_SEG_DURATION are dropped.

    Args:
        audio_path: Path to an audio file that librosa can decode.

    Returns:
        List of (start_seconds, end_seconds) float tuples in ascending order,
        with end clamped to the audio duration. Empty when nothing qualifies
        or the file contains no samples.
    """
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    total_len = len(wav)
    if total_len == 0:
        return []

    total_dur = total_len / SAMPLE_RATE
    chunk_len = SAMPLE_RATE * CHUNK_SIZE
    step_len = SAMPLE_RATE * STEP

    hits = []
    for start_idx in _window_starts(total_len, chunk_len, step_len):
        snippet = wav[start_idx:start_idx + chunk_len]
        # Audio shorter than one window: pad with silence to the model length.
        if len(snippet) < chunk_len:
            snippet = np.pad(snippet, (0, chunk_len - len(snippet)))

        music_score = predict_music_score(snippet)
        voice_score = predict_voice_score(snippet)

        if music_score > MUSIC_THRESHOLD and voice_score > VOICE_THRESHOLD:
            start_t = start_idx / SAMPLE_RATE
            # Never report time past the end of the audio (padding is not audio).
            end_t = min(start_t + CHUNK_SIZE, total_dur)
            hits.append((float(start_t), float(end_t)))

    return [(s, e) for s, e in _merge_segments(hits) if e - s >= MIN_SEG_DURATION]


# === Main inference entry point ===
def analyze_audio(file_path):
    """Gradio callback: detect singing in an uploaded file and report the result.

    Args:
        file_path: Local path of the uploaded audio (the Audio component uses
            type="filepath"), or None when nothing was uploaded.

    Returns:
        A (status_message, json_text) pair for the status Textbox and the JSON
        Code component. json_text is None when no file was provided, "[]" when
        no singing was found, and otherwise a list of
        {"start", "end", "duration"} objects (seconds, rounded to 2 decimals).
    """
    if file_path is None:
        return "⚠️ 请上传音频文件", None

    # detect_singing() decodes and resamples to SAMPLE_RATE itself, so the
    # upload is passed straight through; no intermediate temp WAV is written
    # (and therefore none is left behind on disk).
    segments = detect_singing(file_path)

    if not segments:
        return "未检测到唱歌片段", json.dumps([], indent=2)

    json_output = json.dumps(
        [
            {"start": round(s, 2), "end": round(e, 2), "duration": round(e - s, 2)}
            for s, e in segments
        ],
        indent=2,
    )
    return f"检测到 {len(segments)} 段唱歌片段", json_output


# === Gradio UI ===
# Layout, top to bottom: intro text, audio upload, run button, status line and
# JSON result. Clicking the button runs analyze_audio on the uploaded file and
# fills the status line and the JSON panel.
with gr.Blocks(title="🎵 Singing Segment Detector (Final)") as demo:
    gr.Markdown(
        """
        # 🎤 唱歌片段自动检测器(AI-Music + HuBERT)
        - 自动检测视频中的演唱时间段  
        - 采用 `AI-Music-Detection/ai_music_detection_large_60s` + `HuBERT` 双模型融合  
        - 输出每段的开始、结束时间与时长  
        """
    )

    # Inputs
    audio_input = gr.Audio(label="上传音频(从视频提取)", type="filepath")
    run_btn = gr.Button("🚀 开始分析")

    # Outputs
    status_box = gr.Textbox(interactive=False, label="分析状态")
    json_output = gr.Code(language="json", label="唱歌片段时间戳(JSON)")

    run_btn.click(
        fn=analyze_audio,
        inputs=[audio_input],
        outputs=[status_box, json_output],
    )

demo.launch()