"""
High-performance real-time speech recognition and command generation system.

Integrates all modules into the complete pipeline from speech detection to
command dispatch:

1. Audio capture (high-performance mode)
2. Audio preprocessing (denoise + AGC)
3. VAD voice-activity detection
4. STT speech recognition
5. Text preprocessing (correction + parameter extraction)
6. Command generation
7. Socket delivery

Performance optimizations:

- Multi-threaded asynchronous processing
- Non-blocking audio capture
- LRU cache optimization
- Low-latency design
"""

import math
import os
import queue
import random
import threading
import time
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

import numpy as np

from voice_drone.core.audio import AudioCapture, AudioPreprocessor
from voice_drone.core.command import Command
from voice_drone.core.configuration import (
    SYSTEM_AUDIO_CONFIG,
    SYSTEM_RECOGNIZER_CONFIG,
    SYSTEM_SOCKET_SERVER_CONFIG,
)
# NOTE(review): module name "scoket_client" looks like a typo of "socket_client";
# keep the import matching the actual module, confirm upstream before renaming.
from voice_drone.core.scoket_client import SocketClient
from voice_drone.core.stt import STT
from voice_drone.core.text_preprocessor import TextPreprocessor, get_preprocessor
from voice_drone.core.tts_ack_cache import (
    compute_ack_pcm_fingerprint,
    load_cached_phrases,
    persist_phrases,
)
from voice_drone.core.vad import VAD
from voice_drone.core.wake_word import WakeWordDetector, get_wake_word_detector
from voice_drone.logging_ import get_logger

if TYPE_CHECKING:
    # Annotation-only import: avoids loading the heavy Kokoro TTS stack at import time.
    from voice_drone.core.tts import KokoroOnnxTTS

logger = get_logger("recognizer")
|
||
|
||
|
||
class VoiceCommandRecognizer:
    """
    High-performance real-time voice command recognizer.

    Complete speech-to-command system covering:
    - audio capture and preprocessing
    - voice-activity detection
    - speech recognition (STT)
    - text preprocessing and parameter extraction
    - command generation
    - socket delivery
    """
|
||
|
||
    def __init__(self, auto_connect_socket: bool = True):
        """
        Initialize the voice command recognizer.

        Args:
            auto_connect_socket: whether to connect to the socket server immediately.
        """
        logger.info("初始化语音命令识别系统...")

        # Core pipeline modules.
        self.audio_capture = AudioCapture()
        self.audio_preprocessor = AudioPreprocessor()
        self.vad = VAD()
        self.stt = STT()
        self.text_preprocessor = get_preprocessor()  # global singleton
        self.wake_word_detector = get_wake_word_detector()  # global singleton

        # Socket client for delivering generated commands.
        self.socket_client = SocketClient(SYSTEM_SOCKET_SERVER_CONFIG)
        self.auto_connect_socket = auto_connect_socket
        if self.auto_connect_socket:
            if not self.socket_client.connect():
                logger.warning("Socket连接失败,将在发送命令时重试")

        # Buffer holding the audio chunks of the current speech segment.
        self.speech_buffer: list = []
        self.speech_buffer_lock = threading.Lock()

        # Pre-speech buffer: keeps a short window of audio from just before speech
        # onset so the start of the utterance is not lost.
        # e.g. pre_speech_max_seconds = 0.8 keeps roughly the last 0.8 seconds.
        self.pre_speech_buffer: list = []  # most recent silence/background chunks
        # Read from system config (coerce type: YAML may deliver numbers as strings).
        self.pre_speech_max_seconds: float = float(
            SYSTEM_RECOGNIZER_CONFIG.get("pre_speech_max_seconds", 0.8)
        )
        self.pre_speech_max_chunks: Optional[int] = None  # derived later from sample rate and chunk size

        # TTS acknowledgement after a command is sent successfully
        # (Kokoro is lazy-loaded so startup is not slowed down).
        self.ack_tts_enabled = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_enabled", True))
        self.ack_tts_text = str(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_text", "好的收到")).strip()
        self.ack_tts_phrases: Dict[str, List[str]] = self._normalize_ack_tts_phrases(
            SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_phrases")
        )
        # True: only commands listed in ack_tts_phrases are acknowledged, with a random
        # phrase each time; False: the global ack_tts_text (same reply for every success).
        self._ack_mode_phrases: bool = bool(self.ack_tts_phrases)
        self.ack_tts_prewarm = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm", True))
        self.ack_tts_prewarm_blocking = bool(
            SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm_blocking", True)
        )
        self.ack_pause_mic_for_playback = bool(
            SYSTEM_RECOGNIZER_CONFIG.get("ack_pause_mic_for_playback", True)
        )
        self.ack_tts_disk_cache = bool(
            SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_disk_cache", True)
        )
        self._tts_engine: Optional["KokoroOnnxTTS"] = None
        # Waveforms cached by blocking prewarm: the single global _tts_ack_pcm, or, in
        # per-command random mode, _tts_phrase_pcm_cache (one entry per phrase).
        self._tts_ack_pcm: Optional[Tuple[np.ndarray, int]] = None
        self._tts_phrase_pcm_cache: Dict[str, Tuple[np.ndarray, int]] = {}
        self._tts_lock = threading.Lock()
        # Command threads only enqueue; the main process_audio_stream loop plays
        # (background-thread sd.play on Windows is often completely silent).
        self._ack_playback_queue: queue.Queue = queue.Queue(maxsize=8)

        # STT recognition thread and queue.
        self.stt_queue = queue.Queue(maxsize=5)  # STT recognition queue
        self.stt_thread: Optional[threading.Thread] = None

        # Command-processing thread and queue.
        self.command_queue = queue.Queue(maxsize=10)  # command-processing queue
        self.command_thread: Optional[threading.Thread] = None

        # Run state.
        self.running = False

        # Command sequence number (for de-duplication and ordering).
        self.sequence_id = 0
        self.sequence_lock = threading.Lock()

        logger.info(
            f"应答TTS配置: enabled={self.ack_tts_enabled}, "
            f"mode={'按命令随机短语' if self._ack_mode_phrases else '全局固定文案'}, "
            f"prewarm_blocking={self.ack_tts_prewarm_blocking}, "
            f"pause_mic={self.ack_pause_mic_for_playback}, "
            f"disk_cache={self.ack_tts_disk_cache}"
        )
        if self._ack_mode_phrases:
            logger.info(f" 仅播报命令: {list(self.ack_tts_phrases.keys())}")

        # VAD backend: silero (default) or energy (per-chunk RMS; used when Silero
        # yields no segments for long periods on some onboard microphones).
        _ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in (
            "1",
            "true",
            "yes",
        )
        _yaml_backend = str(
            SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero")
        ).lower()
        self._use_energy_vad: bool = _ev_env or _yaml_backend == "energy"
        self._energy_rms_high: float = float(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_high", 280)
        )
        self._energy_rms_low: float = float(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_low", 150)
        )
        self._energy_start_chunks: int = int(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_start_chunks", 4)
        )
        self._energy_end_chunks: int = int(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_chunks", 15)
        )
        # With a high noise floor / AGC the RMS rarely drops below energy_vad_rms_low;
        # additionally detect segment end relative to the utterance's own peak.
        self._energy_end_peak_ratio: float = float(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_peak_ratio", 0.88)
        )
        # While speaking, decay the utterance peak each chunk and take max with the
        # current RMS; otherwise an extra-loud first few syllables would make the rest
        # of the sentence look "relatively decayed" and cut the segment too early.
        self._energy_utt_peak_decay: float = float(
            SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_utt_peak_decay", 0.988)
        )
        self._energy_utt_peak_decay = max(0.95, min(0.9999, self._energy_utt_peak_decay))
        self._ev_speaking: bool = False
        self._ev_high_run: int = 0
        self._ev_low_run: int = 0
        self._ev_rms_peak: float = 0.0
        self._ev_last_diag_time: float = 0.0
        self._ev_utt_peak: float = 0.0
        # Optional callback fired when the energy VAD enters "speaking" (e.g. to
        # reset PROMPT_LISTEN timers on the vehicle side).
        self._vad_speech_start_hook: Optional[Callable[[], None]] = None

        _trail_raw = SYSTEM_RECOGNIZER_CONFIG.get("trailing_silence_seconds")
        if _trail_raw is not None:
            _trail = float(_trail_raw)
            if _trail > 0:
                fs = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024))
                sr = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000))
                if fs > 0 and sr > 0:
                    # Convert trailing-silence seconds to a consecutive-chunk count,
                    # shared by both the energy VAD and Silero.
                    n_end = max(1, int(math.ceil(_trail * sr / fs)))
                    self._energy_end_chunks = n_end
                    self.vad.silence_end_frames = n_end
                    logger.info(
                        "VAD 句尾切段:trailing_silence_seconds=%.2f → 连续静音块数=%d "
                        "(每块≈%.0fms,Silero 与 energy 共用)",
                        _trail,
                        n_end,
                        1000.0 * fs / sr,
                    )

        if self._use_energy_vad:
            logger.info(
                "VAD 后端: energy(RMS)"
                f" high={self._energy_rms_high} low={self._energy_rms_low} "
                f"start_chunks={self._energy_start_chunks} end_chunks={self._energy_end_chunks}"
                f" end_peak_ratio={self._energy_end_peak_ratio}"
                f" utt_peak_decay={self._energy_utt_peak_decay}"
            )

        logger.info("语音命令识别系统初始化完成")
|
||
|
||
@staticmethod
|
||
def _normalize_ack_tts_phrases(raw) -> Dict[str, List[str]]:
|
||
"""YAML: ack_tts_phrases: { takeoff: [\"...\", ...], ... }"""
|
||
result: Dict[str, List[str]] = {}
|
||
if not isinstance(raw, dict):
|
||
return result
|
||
for k, v in raw.items():
|
||
key = str(k).strip()
|
||
if not key:
|
||
continue
|
||
if isinstance(v, list):
|
||
phrases = [str(x).strip() for x in v if str(x).strip()]
|
||
elif isinstance(v, str) and v.strip():
|
||
phrases = [v.strip()]
|
||
else:
|
||
phrases = []
|
||
if phrases:
|
||
result[key] = phrases
|
||
return result
|
||
|
||
def _has_ack_tts_content(self) -> bool:
|
||
if self._ack_mode_phrases:
|
||
return any(bool(v) for v in self.ack_tts_phrases.values())
|
||
return bool(self.ack_tts_text)
|
||
|
||
def _pick_ack_phrase(self, command_name: str) -> Optional[str]:
|
||
if self._ack_mode_phrases:
|
||
phrases = self.ack_tts_phrases.get(command_name)
|
||
if not phrases:
|
||
return None
|
||
return random.choice(phrases)
|
||
return self.ack_tts_text or None
|
||
|
||
def _get_cached_pcm_for_phrase(self, phrase: str) -> Optional[Tuple[np.ndarray, int]]:
|
||
"""若启动阶段已预合成该句,则返回缓存,播报时不再跑 ONNX(低延迟)。"""
|
||
if self._ack_mode_phrases:
|
||
return self._tts_phrase_pcm_cache.get(phrase)
|
||
if self._tts_ack_pcm is not None:
|
||
return self._tts_ack_pcm
|
||
return None
|
||
|
||
    def _ensure_tts_engine(self) -> "KokoroOnnxTTS":
        """Lazily load Kokoro (double-checked locking so concurrent callers load it once)."""
        from voice_drone.core.tts import KokoroOnnxTTS

        # Fast path: already loaded, no lock needed.
        if self._tts_engine is not None:
            return self._tts_engine
        with self._tts_lock:
            # Re-check under the lock: another thread may have finished loading.
            if self._tts_engine is None:
                logger.info("TTS: 正在加载 Kokoro 模型(首次约需十余秒)…")
                self._tts_engine = KokoroOnnxTTS()
                logger.info("TTS: Kokoro 模型加载完成")
        assert self._tts_engine is not None
        return self._tts_engine
|
||
|
||
    def _enqueue_ack_playback(self, command_name: str) -> None:
        """
        After a command was sent successfully, hand the ack audio to the main-thread queue.

        Never call sounddevice from this worker thread: on Windows, playback from a
        background thread frequently produces no sound at all.
        """
        if not self.ack_tts_enabled:
            return
        phrase = self._pick_ack_phrase(command_name)
        if not phrase:
            return
        try:
            cached = self._get_cached_pcm_for_phrase(phrase)
            if cached is not None:
                audio, sr = cached
                # Copy so later cache mutation cannot affect the queued waveform.
                self._ack_playback_queue.put(("pcm", audio.copy(), sr), block=False)
                logger.info(
                    f"命令已发送,已排队语音应答(主线程播放,预缓存): {phrase!r}"
                )
                print(f"[TTS] 已排队语音应答(主线程播放,预缓存): {phrase!r}", flush=True)
            else:
                # No prewarmed waveform: the main thread synthesizes on demand (slower).
                self._ack_playback_queue.put(("synth", phrase), block=False)
                logger.info(
                    f"命令已发送,已排队语音应答(主线程合成+播放,无缓存,可能有数秒延迟): {phrase!r}"
                )
                print(
                    f"[TTS] 已排队语音应答(主线程合成+播放,无缓存): {phrase!r}",
                    flush=True,
                )
        except queue.Full:
            logger.warning("应答语音播放队列已满,跳过本次")
|
||
|
||
    def _before_audio_iteration(self) -> None:
        """Start of each main-loop iteration (main thread); subclasses may extend this to play other queued TTS."""
        self._drain_ack_playback_queue()
|
||
|
||
    def _drain_ack_playback_queue(self, recover_mic: bool = True) -> None:
        """Play all queued acks on the main thread (same process/loop thread as mic capture).

        Args:
            recover_mic: whether to resume the microphone afterwards; must be False
                during shutdown to avoid racing the stream teardown in stop().
        """
        from voice_drone.core.tts import play_tts_audio, speak_text

        # Drain the queue completely before touching the audio device.
        items: list = []
        while True:
            try:
                items.append(self._ack_playback_queue.get_nowait())
            except queue.Empty:
                break
        if not items:
            return

        mic_stopped = False
        if self.ack_pause_mic_for_playback:
            try:
                logger.info(
                    "TTS: 已暂停麦克风采集以便扬声器播放(避免 Windows 下输入/输出同时开无声)"
                )
                self.audio_capture.stop_stream()
                mic_stopped = True
            except Exception as e:
                logger.warning(f"暂停麦克风失败,将尝试直接播放: {e}")

        try:
            for item in items:
                try:
                    kind = item[0]
                    if kind == "pcm":
                        # Prewarmed waveform: play directly, no synthesis needed.
                        _, audio, sr = item
                        logger.info("TTS: 主线程播放应答(预缓存波形)")
                        play_tts_audio(audio, sr)
                        logger.info("TTS: 播放完成")
                    elif kind == "synth":
                        # No cached PCM: synthesize now (may take seconds on first use).
                        logger.info("TTS: 主线程合成并播放应答(无预缓存)")
                        tts = self._ensure_tts_engine()
                        text = item[1] if len(item) >= 2 else (self.ack_tts_text or "")
                        speak_text(text, tts=tts)
                except Exception as e:
                    # One failed item must not prevent the rest from playing.
                    logger.warning(f"应答语音播放失败: {e}", exc_info=True)
        finally:
            if mic_stopped and recover_mic:
                try:
                    self.audio_capture.start_stream()
                    try:
                        self.audio_preprocessor.reset()
                    except Exception as e:  # noqa: BLE001
                        logger.debug("audio_preprocessor.reset: %s", e)
                    # If the trailing-silence frames never accumulated while the mic was
                    # paused, the VAD stays stuck at is_speaking=True; after resuming,
                    # detect_speech_start then always bails out — capture recovers but
                    # recognition never triggers again. Hence the full VAD reset here.
                    self.vad.reset()
                    with self.speech_buffer_lock:
                        self.speech_buffer.clear()
                    self.pre_speech_buffer.clear()
                    logger.info("TTS: 麦克风采集已恢复(已重置 VAD 与语音缓冲)")
                except Exception as e:
                    logger.error(f"麦克风采集恢复失败,请重启程序: {e}", exc_info=True)
|
||
|
||
    def _prewarm_tts_async(self) -> None:
        """Preload the TTS engine in the background (only when blocking prewarm is disabled)."""
        if not self.ack_tts_enabled or not self._has_ack_tts_content() or not self.ack_tts_prewarm:
            return

        def _run() -> None:
            # Background worker: load the engine; failures are non-fatal because
            # _ensure_tts_engine will retry on first playback.
            try:
                self._ensure_tts_engine()
                if self._ack_mode_phrases:
                    # Engine is warm but per-phrase waveforms are not synthesized here,
                    # so first playback of each phrase may still be slow.
                    logger.warning(
                        "TTS: 当前为「按命令随机短语」且未使用阻塞预加载,"
                        "各句首次播报可能仍有数秒延迟;若需低延迟请将 ack_tts_prewarm_blocking 设为 true。"
                    )
            except Exception as e:
                logger.warning(f"TTS 预加载失败(将在首次播报时重试): {e}", exc_info=True)

        threading.Thread(target=_run, daemon=True, name="tts-prewarm").start()
|
||
|
||
    def _prewarm_tts_blocking(self) -> None:
        """Prepare acknowledgement PCM at startup: prefer the disk cache (skip synthesis
        when the text and TTS config are unchanged); otherwise load Kokoro and synthesize."""
        if not self.ack_tts_enabled or not self._has_ack_tts_content() or not self.ack_tts_prewarm:
            return
        use_disk = self.ack_tts_disk_cache
        logger.info("TTS: 正在准备语音反馈(磁盘缓存 / 合成)…")
        print("正在加载语音反馈…")
        try:
            if self._ack_mode_phrases:
                self._tts_phrase_pcm_cache.clear()
                # De-duplicate phrases across all commands, preserving first-seen order.
                seen: set = set()
                unique: List[str] = []
                for lst in self.ack_tts_phrases.values():
                    for t in lst:
                        p = str(t).strip()
                        if p and p not in seen:
                            seen.add(p)
                            unique.append(p)
                if not unique:
                    return

                # Fingerprint invalidates the disk cache when phrases/TTS config change.
                fingerprint = compute_ack_pcm_fingerprint(unique, mode_phrases=True)
                missing = list(unique)
                if use_disk:
                    loaded, missing = load_cached_phrases(unique, fingerprint)
                    for ph, pcm in loaded.items():
                        self._tts_phrase_pcm_cache[ph] = pcm

                if not missing:
                    # Full disk-cache hit: no need to load Kokoro at all.
                    self._tts_ack_pcm = None
                    logger.info(
                        "TTS: 已从磁盘加载全部应答波形(%d 句),跳过 Kokoro 加载与合成",
                        len(unique),
                    )
                    print("语音反馈已就绪(本地缓存),可以开始说话下指令。")
                    return

                self._ensure_tts_engine()
                assert self._tts_engine is not None
                # Synthesize only the phrases the disk cache did not cover.
                need = [p for p in unique if p not in self._tts_phrase_pcm_cache]
                for j, phrase in enumerate(need, start=1):
                    logger.info(
                        f"TTS: 合成应答句 {j}/{len(need)}: {phrase!r}"
                    )
                    audio, sr = self._tts_engine.synthesize(phrase)
                    self._tts_phrase_pcm_cache[phrase] = (audio, sr)
                self._tts_ack_pcm = None
                if use_disk:
                    persist_phrases(fingerprint, dict(self._tts_phrase_pcm_cache))
                logger.info(
                    "TTS: 语音反馈已就绪(随机应答已缓存,播报低延迟)"
                )
                print("语音反馈引擎已就绪,可以开始说话下指令。")
            else:
                # Global single-ack mode: one text, one cached waveform.
                text = (self.ack_tts_text or "").strip()
                if not text:
                    return
                fingerprint = compute_ack_pcm_fingerprint(
                    [], global_text=text, mode_phrases=False
                )
                missing = [text]
                if use_disk:
                    loaded, missing = load_cached_phrases([text], fingerprint)
                    if text in loaded:
                        self._tts_ack_pcm = loaded[text]

                if not missing:
                    logger.info(
                        "TTS: 已从磁盘加载全局应答波形,跳过 Kokoro 加载与合成"
                    )
                    print("语音反馈已就绪(本地缓存),可以开始说话下指令。")
                    return

                self._ensure_tts_engine()
                assert self._tts_engine is not None
                audio, sr = self._tts_engine.synthesize(text)
                self._tts_ack_pcm = (audio, sr)
                if use_disk:
                    persist_phrases(fingerprint, {text: self._tts_ack_pcm})
                logger.info(
                    "TTS: 语音反馈引擎已就绪;已缓存应答语音,命令成功后将快速播报"
                )
                print("语音反馈引擎已就绪,可以开始说话下指令。")
        except Exception as e:
            # Best-effort: startup must not fail just because prewarm did.
            logger.warning(
                f"TTS: 启动阶段预加载失败,命令成功后可能延迟或无语音反馈: {e}",
                exc_info=True,
            )
|
||
|
||
    @staticmethod
    def _init_sounddevice_output_probe() -> None:
        """Probe the default output device on the main thread; ack playback must
        call sd.play from the main thread as well."""
        try:
            from voice_drone.core.tts import log_sounddevice_output_devices

            log_sounddevice_output_devices()
            import sounddevice as sd  # type: ignore

            from voice_drone.core.tts import _sounddevice_default_output_index

            out_idx = _sounddevice_default_output_index()
            if out_idx is not None and int(out_idx) >= 0:
                info = sd.query_devices(int(out_idx))
                logger.info(
                    f"sounddevice 默认输出设备: {info.get('name', '?')} (index={out_idx})"
                )
                # Verify the device accepts the playback format used by the TTS path.
                sd.check_output_settings(samplerate=24000, channels=1, dtype="float32")
            # Pre-resolve tts.output_device so the startup log shows the device
            # actually used for playback.
            from voice_drone.core.tts import get_playback_output_device_id

            get_playback_output_device_id()
        except Exception as e:
            logger.warning(f"sounddevice 输出设备探测失败,可能导致无法播音: {e}")
|
||
|
||
def _get_next_sequence_id(self) -> int:
|
||
"""获取下一个命令序列号"""
|
||
with self.sequence_lock:
|
||
self.sequence_id += 1
|
||
return self.sequence_id
|
||
|
||
@staticmethod
|
||
def _int16_chunk_rms(chunk: np.ndarray) -> float:
|
||
if chunk.size == 0:
|
||
return 0.0
|
||
return float(np.sqrt(np.mean(chunk.astype(np.float64) ** 2)))
|
||
|
||
def _submit_concatenated_speech_to_stt(self) -> None:
|
||
"""在持有 speech_buffer_lock 时调用:合并 speech_buffer 并送 STT,然后清空。"""
|
||
if len(self.speech_buffer) == 0:
|
||
return
|
||
speech_audio = np.concatenate(self.speech_buffer)
|
||
self.speech_buffer.clear()
|
||
min_samples = int(self.audio_capture.sample_rate * 0.5)
|
||
if len(speech_audio) >= min_samples:
|
||
try:
|
||
self.stt_queue.put(speech_audio.copy(), block=False)
|
||
logger.debug(
|
||
f"提交语音段到STT队列,长度: {len(speech_audio)} 采样点"
|
||
)
|
||
if os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
|
||
"1",
|
||
"true",
|
||
"yes",
|
||
):
|
||
print(
|
||
f"[VAD] 已送 STT,{len(speech_audio)} 采样点(≈{len(speech_audio) / float(self.audio_capture.sample_rate):.2f}s)",
|
||
flush=True,
|
||
)
|
||
except queue.Full:
|
||
logger.warning("STT队列已满,跳过本次识别")
|
||
elif os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
|
||
"1",
|
||
"true",
|
||
"yes",
|
||
):
|
||
print(
|
||
f"[VAD] 语音段太短已丢弃({len(speech_audio)} < {min_samples} 采样)",
|
||
flush=True,
|
||
)
|
||
|
||
    def _energy_vad_on_chunk(self, processed_chunk: np.ndarray) -> None:
        """Per-chunk energy (RMS) VAD with hysteresis.

        Onset requires ``_energy_start_chunks`` consecutive chunks above the high
        threshold; segment end requires ``_energy_end_chunks`` consecutive chunks
        below the low threshold or below the decayed utterance peak. On segment
        end the collected buffer is submitted to STT.
        """
        rms = self._int16_chunk_rms(processed_chunk)
        _vad_diag = os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
            "1",
            "true",
            "yes",
        )
        if _vad_diag:
            # Diagnostics: print the 3-second RMS peak so thresholds can be tuned.
            self._ev_rms_peak = max(self._ev_rms_peak, rms)
            now = time.monotonic()
            if now - self._ev_last_diag_time >= 3.0:
                print(
                    f"[VAD] energy 诊断:近 3s 块 RMS 峰值≈{self._ev_rms_peak:.0f} "
                    f"(high={self._energy_rms_high} low={self._energy_rms_low})",
                    flush=True,
                )
                self._ev_rms_peak = 0.0
                self._ev_last_diag_time = now

        if not self._ev_speaking:
            # Onset detection: count consecutive chunks above the high threshold.
            if rms >= self._energy_rms_high:
                self._ev_high_run += 1
            else:
                self._ev_high_run = 0
            if self._ev_high_run >= self._energy_start_chunks:
                self._ev_speaking = True
                self._ev_high_run = 0
                self._ev_low_run = 0
                self._ev_utt_peak = rms
                hook = self._vad_speech_start_hook
                if hook is not None:
                    try:
                        hook()
                    except Exception as e:  # noqa: BLE001
                        logger.debug("vad_speech_start_hook: %s", e, exc_info=True)
                with self.speech_buffer_lock:
                    # Seed the segment with the pre-speech buffer so the first
                    # syllables are not lost.
                    if self.pre_speech_buffer:
                        self.speech_buffer = list(self.pre_speech_buffer)
                    else:
                        self.speech_buffer.clear()
                    self.speech_buffer.append(processed_chunk)
                logger.debug(
                    "energy VAD: 开始收集语音段(含预缓冲约 %.2f s)",
                    self.pre_speech_max_seconds,
                )
            return

        # Currently speaking: keep collecting chunks.
        with self.speech_buffer_lock:
            self.speech_buffer.append(processed_chunk)

        # Decay the utterance peak each chunk so a very loud opening does not make
        # the rest of the sentence look "relatively quiet" and cut the segment early.
        self._ev_utt_peak = max(rms, self._ev_utt_peak * self._energy_utt_peak_decay)
        below_abs = rms <= self._energy_rms_low
        # Relative end condition: helps when a high noise floor / AGC keeps the RMS
        # from ever dropping below the absolute low threshold.
        below_rel = (
            self._energy_end_peak_ratio > 0
            and self._ev_utt_peak >= self._energy_rms_high
            and rms <= self._ev_utt_peak * self._energy_end_peak_ratio
        )
        if below_abs or below_rel:
            self._ev_low_run += 1
        else:
            self._ev_low_run = 0

        if self._ev_low_run >= self._energy_end_chunks:
            # Segment end: reset state and hand the utterance to STT.
            self._ev_speaking = False
            self._ev_low_run = 0
            self._ev_utt_peak = 0.0
            with self.speech_buffer_lock:
                self._submit_concatenated_speech_to_stt()
            self._reset_agc_after_utterance_end()
            logger.debug("energy VAD: 语音段结束,已提交")
|
||
|
||
def _reset_agc_after_utterance_end(self) -> None:
|
||
"""VAD 句尾:清 AGC 滑窗,避免巨响后 RMS 卡死。"""
|
||
try:
|
||
self.audio_preprocessor.reset_agc_state()
|
||
except AttributeError:
|
||
pass
|
||
|
||
def discard_pending_stt_segments(self) -> int:
|
||
"""丢弃尚未被 STT 线程取走的整句,避免唤醒/播 TTS 关麦后仍识别旧段。"""
|
||
n = 0
|
||
while True:
|
||
try:
|
||
self.stt_queue.get_nowait()
|
||
self.stt_queue.task_done()
|
||
n += 1
|
||
except queue.Empty:
|
||
break
|
||
if n:
|
||
logger.info(
|
||
"已丢弃 %s 条待 STT 的语音段(流程切换,避免与播 TTS 重叠)",
|
||
n,
|
||
)
|
||
return n
|
||
|
||
    def _stt_worker_thread(self):
        """STT worker thread: consumes speech segments asynchronously so the capture
        loop never blocks on recognition."""
        logger.info("STT识别线程已启动")
        while self.running:
            try:
                # Short timeout so the `running` flag is re-checked regularly.
                audio_data = self.stt_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"STT工作线程错误: {e}", exc_info=True)
                continue

            try:
                # None is the shutdown sentinel posted by stop().
                if audio_data is None:
                    break

                try:
                    text = self.stt.invoke_numpy(audio_data)

                    if os.environ.get("ROCKET_PRINT_STT", "").lower() in (
                        "1",
                        "true",
                        "yes",
                    ):
                        print(
                            f"[STT] {text!r}"
                            if (text and text.strip())
                            else "[STT] <空或不识别>",
                            flush=True,
                        )

                    if text and text.strip():
                        logger.info(f"🎤 STT识别结果: {text}")

                        try:
                            self.command_queue.put(text, block=False)
                            logger.debug(f"文本已提交到命令处理队列: {text}")
                        except queue.Full:
                            logger.warning("命令处理队列已满,跳过本次识别结果")

                except Exception as e:
                    logger.error(f"STT识别失败: {e}", exc_info=True)
            finally:
                # Always mark the item done (sentinel and errors included) so a
                # potential queue.join() cannot hang.
                self.stt_queue.task_done()

        logger.info("STT识别线程已停止")
|
||
|
||
    def _command_worker_thread(self):
        """Command worker thread: text preprocessing + command generation + socket delivery."""
        logger.info("命令处理线程已启动")
        while self.running:
            try:
                # Short timeout so the `running` flag is re-checked regularly.
                text = self.command_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"命令处理线程错误: {e}", exc_info=True)
                continue

            try:
                # None is the shutdown sentinel posted by stop().
                if text is None:
                    break

                try:
                    # 1. Wake-word detection.
                    is_wake, matched_wake_word = self.wake_word_detector.detect(text)

                    if not is_wake:
                        logger.debug(f"未检测到唤醒词,忽略文本: {text}")
                        continue

                    logger.info(f"🔔 检测到唤醒词: {matched_wake_word}")

                    # 2. Extract the command text (strip the wake word).
                    command_text = self.wake_word_detector.extract_command_text(text)
                    if not command_text or not command_text.strip():
                        logger.warning(f"唤醒词后无命令内容: {text}")
                        continue

                    logger.debug(f"提取的命令文本: {command_text}")

                    # 3. Text preprocessing (fast mode, no word segmentation).
                    normalized_text, params = self.text_preprocessor.preprocess_fast(command_text)

                    logger.debug(f"文本预处理结果:")
                    logger.debug(f" 规范化文本: {normalized_text}")
                    logger.debug(f" 命令关键词: {params.command_keyword}")
                    logger.debug(f" 距离: {params.distance} 米")
                    logger.debug(f" 速度: {params.speed} 米/秒")
                    logger.debug(f" 时间: {params.duration} 秒")

                    # 4. Require a recognized command keyword.
                    if not params.command_keyword:
                        logger.warning(f"未识别到有效命令关键词: {normalized_text}")
                        continue

                    # 5. Build the command.
                    sequence_id = self._get_next_sequence_id()
                    command = Command.create(
                        command=params.command_keyword,
                        sequence_id=sequence_id,
                        distance=params.distance,
                        speed=params.speed,
                        duration=params.duration
                    )

                    logger.info(f"📝 生成命令: {command.command}")
                    logger.debug(f"命令详情: {command.to_dict()}")

                    # 6. Send to the socket server; acknowledge (TTS) only on success.
                    if self.socket_client.send_command_with_retry(command):
                        logger.info(f"✅ 命令已发送: {command.command} (序列号: {sequence_id})")
                        self._enqueue_ack_playback(command.command)
                    else:
                        logger.warning(
                            "命令未送达(已达 max_retries): %s (序列号: %s)",
                            command.command,
                            sequence_id,
                        )

                except Exception as e:
                    logger.error(f"命令处理失败: {e}", exc_info=True)

            finally:
                # Always mark the item done (sentinel and errors included).
                self.command_queue.task_done()

        logger.info("命令处理线程已停止")
|
||
|
||
    def start(self):
        """Start the recognizer: blocking TTS prewarm first, then worker threads and the mic."""
        if self.running:
            logger.warning("语音命令识别系统已在运行")
            return

        # Finish blocking TTS prewarm before opening the mic / starting workers;
        # otherwise a command issued during prewarm would find no cached waveform
        # and its acknowledgement would play late.
        print("[TTS] 探测扬声器并预加载应答语音(可能需十余秒,请勿说话)…", flush=True)
        self._init_sounddevice_output_probe()
        if self.ack_tts_enabled and self._has_ack_tts_content() and self.ack_tts_prewarm:
            if self.ack_tts_prewarm_blocking:
                self._prewarm_tts_blocking()
        else:
            print(
                "[TTS] 已跳过启动预加载(ack_tts_enabled/应答文案/ack_tts_prewarm)",
                flush=True,
            )

        self.running = True

        # STT recognition worker.
        self.stt_thread = threading.Thread(target=self._stt_worker_thread, daemon=True)
        self.stt_thread.start()

        # Command-processing worker.
        self.command_thread = threading.Thread(target=self._command_worker_thread, daemon=True)
        self.command_thread.start()

        # Microphone capture.
        self.audio_capture.start_stream()

        # Non-blocking prewarm variant runs in the background after capture starts.
        if self.ack_tts_enabled and self._has_ack_tts_content() and self.ack_tts_prewarm:
            if not self.ack_tts_prewarm_blocking:
                self._prewarm_tts_async()

        logger.info("语音命令识别系统已启动")
        print("\n" + "=" * 70)
        print("🎙️ 高性能实时语音命令识别系统已启动")
        print("=" * 70)
        print("💡 功能说明:")
        print(" - 系统会自动检测语音并识别")
        print(f" - 🔔 唤醒词: {self.wake_word_detector.primary}")
        print(" - 只有包含唤醒词的语音才会被处理")
        print(" - 识别结果会自动转换为无人机控制命令")
        print(" - 命令会自动发送到Socket服务器")
        print(" - 按 Ctrl+C 退出")
        print("=" * 70 + "\n")
|
||
|
||
def stop(self):
|
||
"""停止语音命令识别系统"""
|
||
if not self.running:
|
||
return
|
||
|
||
self.running = False
|
||
|
||
# 先通知工作线程结束,再播放尚未 drain 的应答(避免 Ctrl+C 时主循环未跑下一轮导致无声)
|
||
if self.stt_thread is not None:
|
||
self.stt_queue.put(None)
|
||
if self.command_thread is not None:
|
||
self.command_queue.put(None)
|
||
if self.stt_thread is not None:
|
||
self.stt_thread.join(timeout=2.0)
|
||
if self.command_thread is not None:
|
||
self.command_thread.join(timeout=2.0)
|
||
|
||
if self.ack_tts_enabled:
|
||
try:
|
||
self._drain_ack_playback_queue(recover_mic=False)
|
||
except Exception as e:
|
||
logger.warning(f"退出前播放应答失败: {e}", exc_info=True)
|
||
|
||
self.audio_capture.stop_stream()
|
||
|
||
# 断开Socket连接
|
||
if self.socket_client.connected:
|
||
self.socket_client.disconnect()
|
||
|
||
logger.info("语音命令识别系统已停止")
|
||
print("\n语音命令识别系统已停止")
|
||
|
||
    def process_audio_stream(self):
        """
        Process the audio stream (main loop).

        High-performance real-time pipeline:
        1. capture an audio chunk (non-blocking)
        2. preprocess (denoise + AGC)
        3. VAD speech start/end detection
        4. collect the speech segment
        5. asynchronous STT (never blocks this loop)
        """
        try:
            while self.running:
                # 0. Play queued command acks on this (main) thread — sd.play must run
                #    in the capture-loop thread, see tts.play_tts_audio.
                self._before_audio_iteration()

                # 1. Capture a chunk (non-blocking, high-performance mode).
                chunk = self.audio_capture.read_chunk_numpy(timeout=0.1)
                if chunk is None:
                    continue

                # 2. Preprocess (denoise + AGC).
                processed_chunk = self.audio_preprocessor.process(chunk)

                # Compute the pre-speech ring-buffer capacity once.
                if self.pre_speech_max_chunks is None:
                    # Samples contained in each chunk.
                    samples_per_chunk = processed_chunk.shape[0]
                    if samples_per_chunk > 0:
                        # chunks needed = pre-buffer seconds * sample rate / samples per chunk
                        chunks = int(
                            self.pre_speech_max_seconds * self.audio_capture.sample_rate
                            / samples_per_chunk
                        )
                        # Keep at least one chunk so the capacity cannot round to 0.
                        self.pre_speech_max_chunks = max(chunks, 1)
                    else:
                        self.pre_speech_max_chunks = 1

                # Append to the pre-speech ring buffer.
                # Note: it always holds the "most recent audio", speaking or not.
                self.pre_speech_buffer.append(processed_chunk)
                if (
                    self.pre_speech_max_chunks is not None
                    and len(self.pre_speech_buffer) > self.pre_speech_max_chunks
                ):
                    # Over capacity: drop the oldest chunk.
                    self.pre_speech_buffer.pop(0)

                # 3. VAD: Silero or energy (RMS) segmentation.
                if self._use_energy_vad:
                    self._energy_vad_on_chunk(processed_chunk)
                else:
                    chunk_bytes = processed_chunk.tobytes()

                    if self.vad.detect_speech_start(chunk_bytes):
                        hook = self._vad_speech_start_hook
                        if hook is not None:
                            try:
                                hook()
                            except Exception as e:  # noqa: BLE001
                                logger.debug(
                                    "vad_speech_start_hook: %s", e, exc_info=True
                                )
                        with self.speech_buffer_lock:
                            # Seed the segment with the pre-speech buffer so the
                            # beginning of the utterance survives.
                            if self.pre_speech_buffer:
                                self.speech_buffer = list(self.pre_speech_buffer)
                            else:
                                self.speech_buffer.clear()
                            self.speech_buffer.append(processed_chunk)
                        logger.debug(
                            "检测到语音开始,使用预缓冲音频(约 %.2f 秒)作为前缀,开始收集语音段",
                            self.pre_speech_max_seconds,
                        )

                    elif self.vad.is_speaking:
                        with self.speech_buffer_lock:
                            self.speech_buffer.append(processed_chunk)

                        if self.vad.detect_speech_end(chunk_bytes):
                            with self.speech_buffer_lock:
                                self._submit_concatenated_speech_to_stt()
                            self._reset_agc_after_utterance_end()
                            logger.debug("检测到语音结束,提交识别")

                # Optional per-chunk hook for subclasses (attribute may not exist).
                hook = getattr(self, "_after_processed_audio_chunk", None)
                if hook is not None:
                    try:
                        hook(processed_chunk)
                    except Exception as e:  # noqa: BLE001
                        logger.debug(
                            "after_processed_audio_chunk: %s", e, exc_info=True
                        )

        except KeyboardInterrupt:
            logger.info("用户中断")
        except Exception as e:
            logger.error(f"处理音频流时发生错误: {e}", exc_info=True)
            raise
|
||
|
||
    def run(self):
        """Run the recognizer end to end; stop() always executes, even on error or Ctrl+C."""
        try:
            self.start()
            self.process_audio_stream()
        finally:
            self.stop()
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: run the full pipeline until interrupted.
    recognizer = VoiceCommandRecognizer()
    recognizer.run()
|