# -- Repository-viewer chrome accidentally captured with the source (not code) --
# 2026-04-14 09:54:26 +08:00
# 429 lines, 16 KiB, Python
# Raw Blame History
# Viewer warning: "This file contains ambiguous Unicode characters ... Use the
# Escape button to reveal them."
"""
语音活动检测(VAD)模块 - 纯 ONNX Runtime 版 Silero VAD
使用 Silero VAD 的 ONNX 模型检测语音活动,识别语音的开始和结束。
不依赖 PyTorch/silero_vad 包,只依赖 onnxruntime + numpy。
"""
import os
from pathlib import Path
from typing import Optional
import numpy as np
import onnxruntime as ort
import multiprocessing
from voice_drone.core.configuration import (
SYSTEM_AUDIO_CONFIG,
SYSTEM_RECOGNIZER_CONFIG,
SYSTEM_VAD_CONFIG,
)
from voice_drone.logging_ import get_logger
from voice_drone.tools.wrapper import time_cost
logger = get_logger("vad.silero_onnx")
class VAD:
    """
    Voice activity detector (VAD) backed by the Silero VAD ONNX model.

    Runs on ONNX Runtime only (no PyTorch / silero_vad package dependency).
    Computes a per-frame speech probability and tracks speech start/end
    transitions via consecutive-frame counters. In "energy VAD" mode
    (ROCKET_ENERGY_VAD env var, or recognizer.vad_backend: energy) no model
    is loaded and is_speech() always reports silence.
    """

    def __init__(self):
        """
        Initialize the VAD detector from the system configuration.

        Thresholds and frame counts come from SYSTEM_VAD_CONFIG, the sample
        rate from SYSTEM_AUDIO_CONFIG. Unless energy-VAD mode is selected,
        the Silero ONNX model is loaded and its input/output node names and
        recurrent state are resolved here.

        Raises:
            FileNotFoundError: Silero is required but model_path is missing
                or the file does not exist.
            RuntimeError: the ONNX session could not be created.
        """
        # ---- Defaults come from the "vad" section of system.yaml ----
        # system.yaml:
        # vad:
        #   threshold: 0.65
        #   start_frame: 3
        #   end_frame: 10
        #   min_silence_duration_s: 0.5
        #   max_silence_duration_s: 30
        #   model_path: "src/models/silero_vad.onnx"
        vad_conf = SYSTEM_VAD_CONFIG
        # Speech probability threshold (YAML may deliver it as a string).
        self.speech_threshold = float(vad_conf.get("threshold", 0.5))
        # Consecutive speech frames required before declaring "speech started".
        self.speech_start_frames = int(vad_conf.get("start_frame", 3))
        # Consecutive silent frames required before declaring "speech ended".
        self.silence_end_frames = int(vad_conf.get("end_frame", 10))
        # Optional min/max segment durations (seconds) for upper layers.
        # NOTE(review): the attribute names say *speech* duration but the config
        # keys read are the *silence* durations — confirm intended semantics.
        self.min_speech_duration = vad_conf.get("min_silence_duration_s")
        self.max_speech_duration = vad_conf.get("max_silence_duration_s")
        # Sample rate comes from the audio config section.
        self.sample_rate = SYSTEM_AUDIO_CONFIG.get("sample_rate")
        # Match the recognizer: in energy-VAD mode skip loading Silero so a
        # missing model file does not break startup.
        _ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in (
            "1",
            "true",
            "yes",
        )
        _yaml_backend = str(
            SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero")
        ).lower()
        if _ev_env or _yaml_backend == "energy":
            # Energy mode: keep the same attribute surface but with no session,
            # so is_speech() short-circuits to False.
            self.onnx_session = None
            self.vad_model_path = None
            self.window_size = 512 if int(self.sample_rate or 16000) == 16000 else 256
            self.input_name = None
            self.sr_input_name = None
            self.state_input_name = None
            self.output_name = None
            self.state = None
            self.speech_frame_count = 0
            self.silence_frame_count = 0
            self.is_speaking = False
            logger.info(
                "VAD能量RMS分段模式跳过 Silero ONNX与 ROCKET_ENERGY_VAD / vad_backend 一致)"
            )
            return
        # ---- Load the Silero VAD ONNX model ----
        raw_mp = SYSTEM_VAD_CONFIG.get("model_path")
        if not raw_mp:
            raise FileNotFoundError(
                "vad.model_path 未配置。若只用能量 VAD请在 system.yaml 中设 "
                "recognizer.vad_backend: energy 并设置 ROCKET_ENERGY_VAD=1"
            )
        mp = Path(raw_mp)
        if not mp.is_absolute():
            # Relative paths resolve against the project root (two directory
            # levels above this file).
            mp = Path(__file__).resolve().parents[2] / mp
        self.vad_model_path = str(mp)
        if not mp.is_file():
            raise FileNotFoundError(
                f"Silero VAD 模型不存在: {self.vad_model_path}。请下载 silero_vad.onnx 到该路径,"
                "或改用能量 VADrecognizer.vad_backend: energy 且 ROCKET_ENERGY_VAD=1"
            )
        try:
            logger.info(f"正在加载 Silero VAD ONNX 模型: {self.vad_model_path}")
            sess_options = ort.SessionOptions()
            # Cap session threads at 4 — VAD inference is tiny, more threads
            # would only add contention.
            cpu_count = multiprocessing.cpu_count()
            optimal_threads = min(4, cpu_count)
            sess_options.intra_op_num_threads = optimal_threads
            sess_options.inter_op_num_threads = optimal_threads
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            self.onnx_session = ort.InferenceSession(
                str(mp),
                sess_options=sess_options,
                providers=["CPUExecutionProvider"],
            )
            inputs = self.onnx_session.get_inputs()
            outputs = self.onnx_session.get_outputs()
            if not inputs:
                raise RuntimeError("VAD ONNX 模型没有输入节点")
            # ---- Resolve input / output node names ----
            # A typical Silero VAD ONNX graph exposes:
            #   - inputs: audio (input), sample rate (sr), state (state or h/c)
            #   - outputs: speech probability + optionally the new state
            self.input_name = None
            self.sr_input_name: Optional[str] = None
            self.state_input_name: Optional[str] = None
            for inp in inputs:
                name = inp.name
                if self.input_name is None:
                    # Intended: prefer common audio-input names, else fall back
                    # to the first input. NOTE(review): both branches assign
                    # `name`, so in practice the first input always wins.
                    if name in ("input", "audio", "waveform"):
                        self.input_name = name
                    else:
                        self.input_name = name
                if name == "sr":
                    self.sr_input_name = name
                if name in ("state", "h", "c", "hidden"):
                    self.state_input_name = name
            # Fallback: use the first input if nothing was chosen.
            if self.input_name is None:
                self.input_name = inputs[0].name
            self.output_name = outputs[0].name if outputs else None
            # Pre-allocate the recurrent state vector if the model needs one.
            self.state: Optional[np.ndarray] = None
            if self.state_input_name is not None:
                state_inp = next(i for i in inputs if i.name == self.state_input_name)
                # state shape is usually [1, N] or [N]; dynamic (non-int)
                # dimensions collapse to 1, zero-initialized.
                state_shape = [
                    int(d) if isinstance(d, int) and d > 0 else 1
                    for d in (state_inp.shape or [1])
                ]
                self.state = np.zeros(state_shape, dtype=np.float32)
            # Infer the window size from the input shape:
            # (batch, samples) or (samples,).
            input_shape = inputs[0].shape
            win_size = None
            if isinstance(input_shape, (list, tuple)) and len(input_shape) >= 1:
                last_dim = input_shape[-1]
                if isinstance(last_dim, int):
                    win_size = last_dim
            if win_size is None:
                # Dynamic input shape: fall back to the Silero convention of
                # 512 samples @ 16 kHz, 256 otherwise.
                win_size = 512 if self.sample_rate == 16000 else 256
            self.window_size = int(win_size)
            logger.info(
                f"Silero VAD ONNX 模型加载完成: 输入={self.input_name}, 输出={self.output_name}, "
                f"window_size={self.window_size}, sample_rate={self.sample_rate}"
            )
        except Exception as e:
            logger.error(f"Silero VAD ONNX 模型加载失败: {e}")
            raise RuntimeError(
                f"无法加载 Silero VAD: {e}。若无需 Silero请设 ROCKET_ENERGY_VAD=1 且 "
                "recognizer.vad_backend: energy"
            ) from e
        # State tracking
        self.speech_frame_count = 0  # Consecutive speech frame count
        self.silence_frame_count = 0  # Consecutive silence frame count
        self.is_speaking = False  # Currently speaking
        logger.info(
            "VADDetector 初始化完成: "
            f"speech_threshold={self.speech_threshold}, "
            f"speech_start_frames={self.speech_start_frames}, "
            f"silence_end_frames={self.silence_end_frames}, "
            f"sample_rate={self.sample_rate}Hz"
        )

    # @time_cost("VAD-语音检测耗时")
    def is_speech(self, audio_chunk: bytes) -> bool:
        """
        Check whether an audio chunk contains speech.

        Args:
            audio_chunk: raw PCM bytes; must be 16 kHz, 16-bit, mono.

        Returns:
            True if speech was detected, False for silence. Always False in
            energy-VAD mode (no session) and on any internal error.
        """
        try:
            if self.onnx_session is None:
                # Energy-VAD mode: no model loaded, never report speech here.
                return False
            # bytes -> int16 numpy array, explicitly little-endian.
            audio_array = np.frombuffer(audio_chunk, dtype="<i2")
            # int16 -> float32 normalized to [-1, 1].
            audio_float = audio_array.astype(np.float32) / 32768.0
            required_samples = getattr(self, "window_size", 512 if self.sample_rate == 16000 else 256)
            # Chunk shorter than the model window: zero-pad at the end.
            if len(audio_float) < required_samples:
                audio_float = np.pad(
                    audio_float, (0, required_samples - len(audio_float)), mode="constant"
                )
            # Chunk longer than the window: split into sub-windows and average
            # the per-window speech probabilities. (A trailing remainder
            # shorter than one window is dropped.)
            if len(audio_float) > required_samples:
                num_chunks = len(audio_float) // required_samples
                speech_probs = []
                for i in range(num_chunks):
                    start_idx = i * required_samples
                    end_idx = start_idx + required_samples
                    chunk = audio_float[start_idx:end_idx]
                    # The model typically expects input of shape (1, samples).
                    input_data = chunk[np.newaxis, :].astype(np.float32)
                    ort_inputs = {self.input_name: input_data}
                    # Supply sr/state inputs when the model requires them.
                    if getattr(self, "sr_input_name", None) is not None:
                        # Silero VAD expects the sample rate as int64.
                        ort_inputs[self.sr_input_name] = np.array(
                            [self.sample_rate], dtype=np.int64
                        )
                    if getattr(self, "state_input_name", None) is not None and self.state is not None:
                        ort_inputs[self.state_input_name] = self.state
                    outputs = self.onnx_session.run(None, ort_inputs)
                    # Carry the recurrent state forward when the model emits it.
                    if (
                        getattr(self, "state_input_name", None) is not None
                        and len(outputs) > 1
                    ):
                        self.state = outputs[1]
                    prob = float(outputs[0].reshape(-1)[0])
                    speech_probs.append(prob)
                speech_prob = float(np.mean(speech_probs))
            else:
                # Exactly one window (possibly zero-padded above).
                input_data = audio_float[:required_samples][np.newaxis, :].astype(np.float32)
                ort_inputs = {self.input_name: input_data}
                if getattr(self, "sr_input_name", None) is not None:
                    ort_inputs[self.sr_input_name] = np.array(
                        [self.sample_rate], dtype=np.int64
                    )
                if getattr(self, "state_input_name", None) is not None and self.state is not None:
                    ort_inputs[self.state_input_name] = self.state
                outputs = self.onnx_session.run(None, ort_inputs)
                if (
                    getattr(self, "state_input_name", None) is not None
                    and len(outputs) > 1
                ):
                    self.state = outputs[1]
                speech_prob = float(outputs[0].reshape(-1)[0])
            return speech_prob >= self.speech_threshold
        except Exception as e:
            # Best-effort: treat any failure as silence rather than crashing
            # the audio loop.
            logger.error(f"VAD detection failed: {e}")
            return False

    def is_speech_numpy(self, audio_array: np.ndarray) -> bool:
        """
        Check whether an audio array contains speech.

        Args:
            audio_array: audio samples (numpy array, dtype=int16).

        Returns:
            True if speech was detected, False for silence.
        """
        # Serialize to bytes and reuse the bytes-based path.
        audio_bytes = audio_array.tobytes()
        return self.is_speech(audio_bytes)

    def detect_speech_start(self, audio_chunk: bytes) -> bool:
        """
        Detect the start of speech.

        Speech is only declared started after `speech_start_frames`
        consecutive speech frames.

        Args:
            audio_chunk: audio data block.

        Returns:
            True exactly on the frame where speech start is declared.
        """
        if self.is_speaking:
            # Already inside a speech segment; nothing to start.
            return False
        if self.is_speech(audio_chunk):
            self.speech_frame_count += 1
            self.silence_frame_count = 0
            if self.speech_frame_count >= self.speech_start_frames:
                self.is_speaking = True
                self.speech_frame_count = 0
                logger.info("Speech start detected")
                return True
        else:
            # Any silent frame resets the consecutive-speech counter.
            self.speech_frame_count = 0
        return False

    def detect_speech_end(self, audio_chunk: bytes) -> bool:
        """
        Detect the end of speech.

        Speech is only declared ended after `silence_end_frames`
        consecutive silent frames.

        Args:
            audio_chunk: audio data block.

        Returns:
            True exactly on the frame where speech end is declared.
        """
        if not self.is_speaking:
            # Not inside a speech segment; nothing to end.
            return False
        if not self.is_speech(audio_chunk):
            self.silence_frame_count += 1
            self.speech_frame_count = 0
            if self.silence_frame_count >= self.silence_end_frames:
                self.is_speaking = False
                self.silence_frame_count = 0
                logger.info("Speech end detected")
                return True
        else:
            # Any speech frame resets the consecutive-silence counter.
            self.silence_frame_count = 0
        return False

    def reset(self) -> None:
        """
        Reset the detector state.

        Clears the frame counters, the is_speaking flag, and Silero's
        recurrent state (which should be zeroed after a long gap to avoid
        misalignment with subsequent audio).
        """
        self.speech_frame_count = 0
        self.silence_frame_count = 0
        self.is_speaking = False
        if self.state is not None:
            self.state.fill(0)
        logger.debug("VAD detector state reset")
if __name__ == "__main__":
    """
    Scan a test audio file frame by frame and report the ratio of speech
    frames — a quick, visual sanity check that the VAD works.
    """
    import wave

    vad = VAD()
    audio_file = "test/测试音频.wav"
    # 1. Read the wav file.
    with wave.open(audio_file, "rb") as wf:
        n_channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        framerate = wf.getframerate()
        n_frames = wf.getnframes()
        raw = wf.readframes(n_frames)
    # Fix: the decode below assumes 16-bit PCM; previously `sampwidth` was
    # read but never checked, so 8/24/32-bit files were silently misparsed.
    if sampwidth != 2:
        raise ValueError(
            f"Expected 16-bit PCM (sampwidth=2), got sampwidth={sampwidth}"
        )
    # 2. Decode to an int16 array (little-endian).
    audio = np.frombuffer(raw, dtype="<i2")
    # 3. Stereo -> mono by averaging the two channels.
    if n_channels == 2:
        audio = audio.reshape(-1, 2)
        audio = audio.mean(axis=1).astype(np.int16)
    # 4. Resample (linear interpolation) to the VAD's sample rate (usually 16 kHz).
    target_sr = vad.sample_rate
    if framerate != target_sr:
        x_old = np.linspace(0, 1, num=len(audio), endpoint=False)
        x_new = np.linspace(0, 1, num=int(len(audio) * target_sr / framerate), endpoint=False)
        audio = np.interp(x_new, x_old, audio).astype(np.int16)
    print("wav info:", n_channels, "ch,", framerate, "Hz")
    print("audio len (samples):", len(audio), " target_sr:", target_sr)
    # 5. Scan frame by frame using the VAD window size.
    frame_samples = vad.window_size
    frame_bytes = frame_samples * 2  # int16 -> 2 bytes per sample
    audio_bytes = audio.tobytes()
    num_frames = len(audio_bytes) // frame_bytes
    speech_frames = 0
    for i in range(num_frames):
        chunk = audio_bytes[i * frame_bytes : (i + 1) * frame_bytes]
        if vad.is_speech(chunk):
            speech_frames += 1
    speech_ratio = speech_frames / num_frames if num_frames > 0 else 0.0
    print("total frames:", num_frames)
    print("speech frames:", speech_frames)
    print("speech ratio:", speech_ratio)
    print("has_any_speech:", speech_frames > 0)