429 lines
16 KiB
Python
429 lines
16 KiB
Python
"""
|
||
语音活动检测(VAD)模块 - 纯 ONNX Runtime 版 Silero VAD
|
||
|
||
使用 Silero VAD 的 ONNX 模型检测语音活动,识别语音的开始和结束。
|
||
不依赖 PyTorch/silero_vad 包,只依赖 onnxruntime + numpy。
|
||
"""
|
||
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
import multiprocessing
|
||
from voice_drone.core.configuration import (
|
||
SYSTEM_AUDIO_CONFIG,
|
||
SYSTEM_RECOGNIZER_CONFIG,
|
||
SYSTEM_VAD_CONFIG,
|
||
)
|
||
from voice_drone.logging_ import get_logger
|
||
from voice_drone.tools.wrapper import time_cost
|
||
# Module-level logger for this VAD component.
logger = get_logger("vad.silero_onnx")
|
||
|
||
|
||
class VAD:
    """Voice activity detector backed by Silero VAD via ONNX Runtime.

    Scores fixed-size PCM windows with the Silero ONNX model and debounces
    the per-frame decisions with consecutive-frame counters to detect the
    start (:meth:`detect_speech_start`) and end (:meth:`detect_speech_end`)
    of speech segments.
    """

    def __init__(self):
        """Initialize the detector from ``system.yaml`` configuration.

        Thresholds come from ``SYSTEM_VAD_CONFIG``; the sample rate from
        ``SYSTEM_AUDIO_CONFIG``.  When energy-based VAD is selected (env var
        ``ROCKET_ENERGY_VAD`` or ``recognizer.vad_backend: energy``), the
        Silero ONNX model is not loaded at all and :meth:`is_speech` always
        returns False.

        Raises:
            FileNotFoundError: if ``vad.model_path`` is unset or the model
                file does not exist (Silero mode only).
            RuntimeError: if the ONNX session fails to load.
        """
        # Expected system.yaml layout:
        # vad:
        #   threshold: 0.65
        #   start_frame: 3
        #   end_frame: 10
        #   min_silence_duration_s: 0.5
        #   max_silence_duration_s: 30
        #   model_path: "src/models/silero_vad.onnx"
        vad_conf = SYSTEM_VAD_CONFIG

        # Speech probability threshold (YAML values may arrive as strings).
        self.speech_threshold = float(vad_conf.get("threshold", 0.5))

        # Consecutive speech frames required before declaring "speech started".
        self.speech_start_frames = int(vad_conf.get("start_frame", 3))

        # Consecutive silence frames required before declaring "speech ended".
        self.silence_end_frames = int(vad_conf.get("end_frame", 10))

        # Optional min/max segment durations (seconds) for upper layers.
        # NOTE(review): the attributes say "speech" but the config keys say
        # "silence" -- kept as-is for backward compatibility; confirm intent.
        self.min_speech_duration = vad_conf.get("min_silence_duration_s")
        self.max_speech_duration = vad_conf.get("max_silence_duration_s")

        # Sample rate comes from the audio section (may be None if unset).
        self.sample_rate = SYSTEM_AUDIO_CONFIG.get("sample_rate")

        # Mirror the recognizer's backend choice: in energy-VAD mode we skip
        # loading Silero entirely (avoids requiring the model file).
        _ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in (
            "1",
            "true",
            "yes",
        )
        _yaml_backend = str(
            SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero")
        ).lower()
        if _ev_env or _yaml_backend == "energy":
            self.onnx_session = None
            self.vad_model_path = None
            # Silero convention: 512-sample windows at 16 kHz, 256 at 8 kHz.
            self.window_size = 512 if int(self.sample_rate or 16000) == 16000 else 256
            self.input_name = None
            self.sr_input_name = None
            self.state_input_name = None
            self.output_name = None
            self.state = None
            self.speech_frame_count = 0
            self.silence_frame_count = 0
            self.is_speaking = False
            logger.info(
                "VAD:能量(RMS)分段模式,跳过 Silero ONNX(与 ROCKET_ENERGY_VAD / vad_backend 一致)"
            )
            return

        # ---- Load the Silero VAD ONNX model ----
        raw_mp = SYSTEM_VAD_CONFIG.get("model_path")
        if not raw_mp:
            raise FileNotFoundError(
                "vad.model_path 未配置。若只用能量 VAD,请在 system.yaml 中设 "
                "recognizer.vad_backend: energy 并设置 ROCKET_ENERGY_VAD=1"
            )
        mp = Path(raw_mp)
        if not mp.is_absolute():
            # Resolve relative paths against the project root (two dirs up).
            mp = Path(__file__).resolve().parents[2] / mp
        self.vad_model_path = str(mp)
        if not mp.is_file():
            raise FileNotFoundError(
                f"Silero VAD 模型不存在: {self.vad_model_path}。请下载 silero_vad.onnx 到该路径,"
                "或改用能量 VAD:recognizer.vad_backend: energy 且 ROCKET_ENERGY_VAD=1"
            )

        try:
            logger.info(f"正在加载 Silero VAD ONNX 模型: {self.vad_model_path}")
            sess_options = ort.SessionOptions()
            # Cap at 4 threads: VAD inference is tiny, more adds overhead.
            optimal_threads = min(4, multiprocessing.cpu_count())
            sess_options.intra_op_num_threads = optimal_threads
            sess_options.inter_op_num_threads = optimal_threads
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

            self.onnx_session = ort.InferenceSession(
                str(mp),
                sess_options=sess_options,
                providers=["CPUExecutionProvider"],
            )

            inputs = self.onnx_session.get_inputs()
            outputs = self.onnx_session.get_outputs()
            if not inputs:
                raise RuntimeError("VAD ONNX 模型没有输入节点")

            # ---- Resolve input / output node names ----
            # A typical Silero VAD ONNX graph exposes:
            #   inputs: audio ("input"), sample rate ("sr"), state ("state"/h/c)
            #   outputs: speech probability + optionally the new state
            self.input_name: Optional[str] = None
            self.sr_input_name: Optional[str] = None
            self.state_input_name: Optional[str] = None

            for inp in inputs:
                name = inp.name
                # Prefer well-known audio input names.  (Bug fix: the
                # original assigned the first input name in BOTH branches,
                # so the preference list never had any effect.)
                if self.input_name is None and name in ("input", "audio", "waveform"):
                    self.input_name = name
                if name == "sr":
                    self.sr_input_name = name
                if name in ("state", "h", "c", "hidden"):
                    self.state_input_name = name

            # Fallback: use the first input if no preferred name matched.
            if self.input_name is None:
                self.input_name = inputs[0].name

            self.output_name = outputs[0].name if outputs else None

            # Pre-allocate the recurrent state vector if the model needs one.
            self.state: Optional[np.ndarray] = None
            if self.state_input_name is not None:
                state_inp = next(i for i in inputs if i.name == self.state_input_name)
                # State shape is usually [1, N] or [N]; dynamic/symbolic
                # dimensions are replaced with 1 and the state zero-filled.
                state_shape = [
                    int(d) if isinstance(d, int) and d > 0 else 1
                    for d in (state_inp.shape or [1])
                ]
                self.state = np.zeros(state_shape, dtype=np.float32)

            # Infer the window size from the resolved audio input's shape
            # ((batch, samples) or (samples,)).  Bug fix: previously read
            # inputs[0].shape even when the audio input was not the first.
            audio_inp = next(i for i in inputs if i.name == self.input_name)
            input_shape = audio_inp.shape
            win_size = None
            if isinstance(input_shape, (list, tuple)) and len(input_shape) >= 1:
                last_dim = input_shape[-1]
                if isinstance(last_dim, int):
                    win_size = last_dim
            if win_size is None:
                # Dynamic shape: fall back to Silero's conventional sizes.
                win_size = 512 if self.sample_rate == 16000 else 256
            self.window_size = int(win_size)

            logger.info(
                f"Silero VAD ONNX 模型加载完成: 输入={self.input_name}, 输出={self.output_name}, "
                f"window_size={self.window_size}, sample_rate={self.sample_rate}"
            )
        except Exception as e:
            logger.error(f"Silero VAD ONNX 模型加载失败: {e}")
            raise RuntimeError(
                f"无法加载 Silero VAD: {e}。若无需 Silero,请设 ROCKET_ENERGY_VAD=1 且 "
                "recognizer.vad_backend: energy"
            ) from e

        # Debounce state tracking.
        self.speech_frame_count = 0  # consecutive speech frames seen
        self.silence_frame_count = 0  # consecutive silence frames seen
        self.is_speaking = False  # currently inside a speech segment

        logger.info(
            "VADDetector 初始化完成: "
            f"speech_threshold={self.speech_threshold}, "
            f"speech_start_frames={self.speech_start_frames}, "
            f"silence_end_frames={self.silence_end_frames}, "
            f"sample_rate={self.sample_rate}Hz"
        )

    def _run_model(self, chunk: np.ndarray) -> float:
        """Run one ONNX inference on a single window; return speech prob.

        Updates ``self.state`` in place when the model emits a new state
        (second output), so consecutive calls carry the RNN state forward.

        Args:
            chunk: float32 audio of exactly ``self.window_size`` samples.

        Returns:
            Speech probability as a float (model output, nominally [0, 1]).
        """
        # Models typically expect input shaped (1, samples).
        ort_inputs = {self.input_name: chunk[np.newaxis, :].astype(np.float32)}
        if self.sr_input_name is not None:
            # Silero VAD expects the sample rate as int64.
            ort_inputs[self.sr_input_name] = np.array(
                [self.sample_rate], dtype=np.int64
            )
        if self.state_input_name is not None and self.state is not None:
            ort_inputs[self.state_input_name] = self.state

        outputs = self.onnx_session.run(None, ort_inputs)

        # Propagate the recurrent state if the model returned one.
        if self.state_input_name is not None and len(outputs) > 1:
            self.state = outputs[1]

        return float(outputs[0].reshape(-1)[0])

    # @time_cost("VAD-语音检测耗时")
    def is_speech(self, audio_chunk: bytes) -> bool:
        """Check whether an audio chunk contains speech.

        Args:
            audio_chunk: raw PCM bytes; expected 16 kHz, 16-bit, mono.

        Returns:
            True when speech probability >= threshold; False for silence,
            on any inference error, or in energy-VAD mode (no ONNX session).
        """
        try:
            if self.onnx_session is None:
                return False
            # bytes -> int16 (explicit little-endian) -> float32 in [-1, 1]
            audio_array = np.frombuffer(audio_chunk, dtype="<i2")
            audio_float = audio_array.astype(np.float32) / 32768.0

            required_samples = getattr(
                self, "window_size", 512 if self.sample_rate == 16000 else 256
            )

            # Zero-pad short chunks up to one full window.
            if len(audio_float) < required_samples:
                audio_float = np.pad(
                    audio_float, (0, required_samples - len(audio_float)), mode="constant"
                )

            if len(audio_float) > required_samples:
                # Long chunk: score each full window and average the
                # probabilities (a trailing partial window is dropped).
                num_chunks = len(audio_float) // required_samples
                speech_probs = [
                    self._run_model(
                        audio_float[i * required_samples : (i + 1) * required_samples]
                    )
                    for i in range(num_chunks)
                ]
                speech_prob = float(np.mean(speech_probs))
            else:
                speech_prob = self._run_model(audio_float[:required_samples])

            return speech_prob >= self.speech_threshold
        except Exception as e:
            # Best-effort: treat any inference failure as silence.
            logger.error(f"VAD detection failed: {e}")
            return False

    def is_speech_numpy(self, audio_array: np.ndarray) -> bool:
        """Check whether an audio array contains speech.

        Args:
            audio_array: audio samples (numpy array, dtype=int16).

        Returns:
            True when speech is detected, False for silence.
        """
        # Delegate to the bytes-based path.
        audio_bytes = audio_array.tobytes()
        return self.is_speech(audio_bytes)

    def detect_speech_start(self, audio_chunk: bytes) -> bool:
        """Detect the start of speech.

        Requires ``speech_start_frames`` consecutive speech frames before
        declaring that speech has started; a single silence frame resets
        the counter.

        Args:
            audio_chunk: one frame of raw PCM bytes.

        Returns:
            True exactly once, on the frame where speech start is declared.
        """
        if self.is_speaking:
            # Already inside a segment; nothing to start.
            return False

        if self.is_speech(audio_chunk):
            self.speech_frame_count += 1
            self.silence_frame_count = 0

            if self.speech_frame_count >= self.speech_start_frames:
                self.is_speaking = True
                self.speech_frame_count = 0
                logger.info("Speech start detected")
                return True
        else:
            self.speech_frame_count = 0

        return False

    def detect_speech_end(self, audio_chunk: bytes) -> bool:
        """Detect the end of speech.

        Requires ``silence_end_frames`` consecutive silence frames before
        declaring that speech has ended; a single speech frame resets the
        counter.

        Args:
            audio_chunk: one frame of raw PCM bytes.

        Returns:
            True exactly once, on the frame where speech end is declared.
        """
        if not self.is_speaking:
            # Not inside a segment; nothing to end.
            return False

        if not self.is_speech(audio_chunk):
            self.silence_frame_count += 1
            self.speech_frame_count = 0

            if self.silence_frame_count >= self.silence_end_frames:
                self.is_speaking = False
                self.silence_frame_count = 0
                logger.info("Speech end detected")
                return True
        else:
            self.silence_frame_count = 0

        return False

    def reset(self) -> None:
        """Reset detector state.

        Clears the frame counters, the speaking flag, and Silero's RNN
        state (after a long gap the state should be zeroed so it does not
        leak into unrelated audio).
        """
        self.speech_frame_count = 0
        self.silence_frame_count = 0
        self.is_speaking = False
        if self.state is not None:
            self.state.fill(0)
        logger.debug("VAD detector state reset")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
"""
|
||
使用测试音频按帧扫描,统计语音帧比例,更直观地验证 VAD 是否工作正常。
|
||
"""
|
||
import wave
|
||
|
||
vad = VAD()
|
||
audio_file = "test/测试音频.wav"
|
||
|
||
# 1. 读取 wav
|
||
with wave.open(audio_file, "rb") as wf:
|
||
n_channels = wf.getnchannels()
|
||
sampwidth = wf.getsampwidth()
|
||
framerate = wf.getframerate()
|
||
n_frames = wf.getnframes()
|
||
raw = wf.readframes(n_frames)
|
||
|
||
# 2. 转成 int16 数组
|
||
audio = np.frombuffer(raw, dtype="<i2")
|
||
|
||
# 3. 双声道 -> 单声道
|
||
if n_channels == 2:
|
||
audio = audio.reshape(-1, 2)
|
||
audio = audio.mean(axis=1).astype(np.int16)
|
||
|
||
# 4. 重采样到 VAD 所需采样率(通常 16k)
|
||
target_sr = vad.sample_rate
|
||
if framerate != target_sr:
|
||
x_old = np.linspace(0, 1, num=len(audio), endpoint=False)
|
||
x_new = np.linspace(0, 1, num=int(len(audio) * target_sr / framerate), endpoint=False)
|
||
audio = np.interp(x_new, x_old, audio).astype(np.int16)
|
||
|
||
print("wav info:", n_channels, "ch,", framerate, "Hz")
|
||
print("audio len (samples):", len(audio), " target_sr:", target_sr)
|
||
|
||
# 5. 按 VAD 窗口大小逐帧扫描
|
||
frame_samples = vad.window_size
|
||
frame_bytes = frame_samples * 2 # int16 -> 2 字节
|
||
audio_bytes = audio.tobytes()
|
||
num_frames = len(audio_bytes) // frame_bytes
|
||
|
||
speech_frames = 0
|
||
for i in range(num_frames):
|
||
chunk = audio_bytes[i * frame_bytes : (i + 1) * frame_bytes]
|
||
if vad.is_speech(chunk):
|
||
speech_frames += 1
|
||
|
||
speech_ratio = speech_frames / num_frames if num_frames > 0 else 0.0
|
||
print("total frames:", num_frames)
|
||
print("speech frames:", speech_frames)
|
||
print("speech ratio:", speech_ratio)
|
||
print("has_any_speech:", speech_frames > 0) |