DroneMind/voice_drone/core/wake_word.py
2026-04-14 09:54:26 +08:00

376 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
唤醒词检测模块 - 高性能实时唤醒词识别
支持:
- 精确匹配
- 模糊匹配(同音字、拼音)
- 部分匹配
- 配置化变体映射
性能优化:
- 预编译正则表达式
- LRU缓存匹配结果
- 优化字符串操作
"""
import re
from typing import Optional, List, Tuple
from functools import lru_cache
from voice_drone.logging_ import get_logger
from voice_drone.core.configuration import (
WAKE_WORD_PRIMARY,
WAKE_WORD_VARIANTS,
WAKE_WORD_MATCHING_CONFIG
)
logger = get_logger("wake_word")
# 延迟加载可选依赖
try:
from pypinyin import lazy_pinyin, Style
PYPINYIN_AVAILABLE = True
except ImportError:
PYPINYIN_AVAILABLE = False
logger.warning("pypinyin 未安装,拼音匹配功能将受限")
class WakeWordDetector:
    """
    Wake-word detector.

    Matching modes, tried in order by :meth:`detect`:
    - Exact: the normalized text contains the wake word or one of its
      configured variants (single pre-compiled alternation regex).
    - Fuzzy: homophone / pinyin comparison (requires pypinyin).
    - Partial: the leading half of the primary wake word appears in the
      text (e.g. the first two characters of a four-character word).
    """

    # Upper bound for the per-instance detect() result memo.
    _DETECT_CACHE_MAX = 256

    def __init__(self):
        """Load wake-word configuration and pre-build matching patterns."""
        logger.info("初始化唤醒词检测器...")
        # Configuration values come from voice_drone.core.configuration.
        self.primary = WAKE_WORD_PRIMARY
        self.variants = WAKE_WORD_VARIANTS or []
        self.matching_config = WAKE_WORD_MATCHING_CONFIG or {}
        # Matching switches and thresholds (with safe defaults).
        self.enable_fuzzy = self.matching_config.get("enable_fuzzy", True)
        self.enable_partial = self.matching_config.get("enable_partial", True)
        self.ignore_case = self.matching_config.get("ignore_case", True)
        self.ignore_spaces = self.matching_config.get("ignore_spaces", True)
        self.min_match_length = self.matching_config.get("min_match_length", 2)
        self.similarity_threshold = self.matching_config.get("similarity_threshold", 0.7)
        # Per-instance memo for detect().  A functools.lru_cache on the
        # method itself would key on `self` and keep every detector
        # instance alive for the lifetime of the module-level cache
        # (ruff B019); a bounded dict owned by the instance avoids that.
        self._detect_cache: dict = {}
        # Pre-compile the matching patterns.
        self._build_patterns()
        logger.info("唤醒词检测器初始化完成")
        logger.info(f" 主唤醒词: {self.primary}")
        logger.info(f" 变体数量: {len(self.variants)}")
        logger.info(f" 模糊匹配: {'启用' if self.enable_fuzzy else '禁用'}")
        logger.info(f" 部分匹配: {'启用' if self.enable_partial else '禁用'}")

    def _build_patterns(self):
        """Normalize, dedupe and sort the variants, then compile one regex."""
        # Normalize every configured variant (case/whitespace per config).
        self.normalized_variants = []
        for variant in self.variants:
            normalized = self._normalize_text(variant)
            if normalized:
                self.normalized_variants.append(normalized)
        # De-duplicate.
        self.normalized_variants = list(set(self.normalized_variants))
        # Longest first so the alternation prefers long matches.
        self.normalized_variants.sort(key=len, reverse=True)
        # One alternation pattern over all escaped variants.
        patterns = [re.escape(variant) for variant in self.normalized_variants]
        if patterns:
            self.pattern = re.compile('|'.join(patterns), re.IGNORECASE if self.ignore_case else 0)
        else:
            self.pattern = None
        logger.debug(f"构建了 {len(self.normalized_variants)} 个匹配模式")

    def _normalize_text(self, text: str) -> str:
        """
        Normalize text for matching.

        Args:
            text: raw text
        Returns:
            Normalized text (case-folded / whitespace-stripped per config).
        """
        if not text:
            return ""
        normalized = text
        if self.ignore_case:
            normalized = normalized.lower()
        if self.ignore_spaces:
            # Only ASCII space/tab/newline are removed; full-width spaces
            # (U+3000) pass through — presumably upstream ASR never emits
            # them, TODO confirm.
            normalized = normalized.replace(' ', '').replace('\t', '').replace('\n', '')
        return normalized.strip()

    def _get_pinyin(self, text: str) -> str:
        """
        Convert text to its pinyin string (for pinyin-based matching).

        Args:
            text: Chinese text
        Returns:
            Lower-case pinyin with no separators, or "" when pypinyin is
            unavailable or conversion fails.
        """
        if not PYPINYIN_AVAILABLE:
            return ""
        try:
            pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
            return ''.join(pinyin_list).lower()
        except Exception as e:
            logger.debug(f"拼音转换失败: {e}")
            return ""

    def _fuzzy_match(self, text: str, variant: str) -> bool:
        """
        Fuzzy match (homophones, pinyin).

        Args:
            text: candidate text (may already be normalized; normalization
                is idempotent)
            variant: variant to compare against
        Returns:
            True when the texts match exactly after normalization, by
            pinyin, or by substring containment.
        """
        # 1. Exact match after normalization.
        normalized_text = self._normalize_text(text)
        normalized_variant = self._normalize_text(variant)
        if normalized_text == normalized_variant:
            return True
        # 2. Pinyin comparison.
        if PYPINYIN_AVAILABLE:
            text_pinyin = self._get_pinyin(text)
            variant_pinyin = self._get_pinyin(variant)
            if text_pinyin and variant_pinyin:
                # Full pinyin equality.
                if text_pinyin == variant_pinyin:
                    return True
                # Partial pinyin containment, gated by a length-ratio
                # similarity threshold.
                if len(variant_pinyin) >= 2:
                    if variant_pinyin in text_pinyin or text_pinyin in variant_pinyin:
                        similarity = min(len(variant_pinyin), len(text_pinyin)) / max(len(variant_pinyin), len(text_pinyin))
                        if similarity >= self.similarity_threshold:
                            return True
        # 3. Character-level containment (simple heuristic).
        if len(normalized_text) >= self.min_match_length and len(normalized_variant) >= self.min_match_length:
            if normalized_variant in normalized_text or normalized_text in normalized_variant:
                return True
        return False

    def _partial_match(self, text: str) -> bool:
        """
        Partial match: the leading half of the primary wake word appears
        in the text.  Short forms should be listed in the variants config.

        Args:
            text: candidate text
        Returns:
            True when the leading half of the primary wake word is found.
        """
        if not self.enable_partial:
            return False
        normalized_text = self._normalize_text(text)
        if self.primary:
            # Take the front half of the primary wake word (e.g. the first
            # two characters of a four-character word).
            primary_normalized = self._normalize_text(self.primary)
            if len(primary_normalized) >= self.min_match_length * 2:
                half_length = len(primary_normalized) // 2
                half_wake_word = primary_normalized[:half_length]
                if len(half_wake_word) >= self.min_match_length:
                    if half_wake_word in normalized_text:
                        return True
        return False

    def detect(self, text: str) -> Tuple[bool, Optional[str]]:
        """
        Detect whether the text contains the wake word.

        Results are memoized in a per-instance bounded dict instead of a
        method-level @lru_cache, which would pin instances (ruff B019).

        Args:
            text: text to check
        Returns:
            (matched, matched wake word or None)
        """
        if not text:
            return False, None
        # detect() runs on every recognized utterance; repeated phrases
        # are served from the memo.  Cached values are always 2-tuples,
        # so None is a safe miss sentinel.
        cached = self._detect_cache.get(text)
        if cached is not None:
            return cached
        result = self._detect_uncached(text)
        # Simple bound: drop the whole memo when it fills up.  Cheap and
        # sufficient for short recognizer phrases; no LRU bookkeeping.
        if len(self._detect_cache) >= self._DETECT_CACHE_MAX:
            self._detect_cache.clear()
        self._detect_cache[text] = result
        return result

    def _detect_uncached(self, text: str) -> Tuple[bool, Optional[str]]:
        """Run the three matching stages in order; no caching."""
        normalized_text = self._normalize_text(text)
        # 1. Exact match via the pre-compiled alternation regex.
        # FIX: a missing pattern (empty variants list) no longer aborts
        # detection outright — fuzzy and partial matching below still run,
        # so a primary-word-only configuration can wake via partial match.
        if self.pattern:
            match = self.pattern.search(normalized_text)
            if match:
                matched_text = match.group(0)
                logger.debug(f"精确匹配到唤醒词: {matched_text}")
                return True, matched_text
        # 2. Fuzzy match (homophones / pinyin).
        if self.enable_fuzzy:
            for variant in self.normalized_variants:
                if self._fuzzy_match(normalized_text, variant):
                    logger.debug(f"模糊匹配到唤醒词变体: {variant}")
                    return True, variant
        # 3. Partial match on the primary wake word.
        if self._partial_match(normalized_text):
            logger.debug("部分匹配到唤醒词")
            return True, self.primary[:len(self.primary) // 2] if self.primary else None
        return False, None

    def extract_command_text(self, text: str) -> Optional[str]:
        """
        Extract the command part of the text (wake word removed).

        Args:
            text: full text containing the wake word
        Returns:
            The command text following the wake word, or None when no wake
            word is detected or nothing follows it.
        """
        is_wake, matched_wake_word = self.detect(text)
        if not is_wake:
            return None
        # Normalize both sides to locate the wake word.
        normalized_text = self._normalize_text(text)
        normalized_wake = self._normalize_text(matched_wake_word) if matched_wake_word else ""
        if not normalized_wake or normalized_wake not in normalized_text:
            return None
        # Position of the wake word inside the normalized text.
        idx = normalized_text.find(normalized_wake)
        if idx < 0:
            return None
        # Strategy 1: find the matched variant verbatim in the original
        # (un-normalized) text.
        original_text = text
        text_lower = original_text.lower()
        best_match_idx = -1
        best_match_length = 0
        # Check every configured variant that normalizes to the match.
        for variant in self.variants:
            variant_normalized = self._normalize_text(variant)
            if variant_normalized == normalized_wake:
                # This variant is the one that matched; locate it verbatim.
                variant_lower = variant.lower()
                if variant_lower in text_lower:
                    variant_idx = text_lower.find(variant_lower)
                    if variant_idx >= 0:
                        # Prefer the longest literal match (more precise cut).
                        if len(variant) > best_match_length:
                            best_match_idx = variant_idx
                            best_match_length = len(variant)
        # A verbatim variant was found — cut right after it.
        if best_match_idx >= 0:
            command_start = best_match_idx + best_match_length
            command_text = original_text[command_start:].strip()
            # Drop punctuation between the wake word and the command.
            command_text = command_text.lstrip(',。、,.').strip()
            return command_text if command_text else None
        # Strategy 2 (fallback): walk the original text counting only the
        # characters that survive normalization, mapping the wake word's
        # end position in normalized space back to the original string.
        wake_end_in_normalized = idx + len(normalized_wake)
        char_count = 0
        for i, char in enumerate(original_text):
            normalized_char = self._normalize_text(char)
            if normalized_char:
                if char_count >= wake_end_in_normalized:
                    command_text = original_text[i:].strip()
                    command_text = command_text.lstrip(',。、,.').strip()
                    return command_text if command_text else None
                char_count += 1
        return None
# Module-level singleton instance, created lazily on first access.
_global_detector: Optional[WakeWordDetector] = None


def get_wake_word_detector() -> WakeWordDetector:
    """Return the process-wide WakeWordDetector, creating it on first use."""
    global _global_detector
    if _global_detector is not None:
        return _global_detector
    _global_detector = WakeWordDetector()
    return _global_detector
if __name__ == "__main__":
# 测试代码
detector = WakeWordDetector()
test_cases = [
("无人机,现在起飞", True),
("wu ren ji 现在起飞", True),
("Wu Ren Ji 现在起飞", True),
("五人机,现在起飞", True),
("现在起飞", False),
("无人,现在起飞", True), # 变体列表中的短说
("人机 前进", False), # 已移除单独「人机」变体,避免子串误唤醒
("无人机 前进", True),
]
print("=" * 60)
print("唤醒词检测测试")
print("=" * 60)
for text, expected in test_cases:
is_wake, matched = detector.detect(text)
command_text = detector.extract_command_text(text)
status = "OK" if is_wake == expected else "FAIL"
print(f"{status} 文本: {text}")
print(f" 匹配: {is_wake} (期望: {expected})")
print(f" 匹配词: {matched}")
print(f" 提取命令: {command_text}")
print()