"""
Wake-word detection module - high-performance real-time wake-word recognition.

Supports:
- exact matching
- fuzzy matching (homophones, pinyin)
- partial matching
- configurable variant mapping

Performance optimizations:
- precompiled regular expressions
- cached match results
- optimized string operations
"""
import re
from typing import Optional, List, Tuple
from functools import lru_cache
from voice_drone.logging_ import get_logger
from voice_drone.core.configuration import (
    WAKE_WORD_PRIMARY,
    WAKE_WORD_VARIANTS,
    WAKE_WORD_MATCHING_CONFIG
)

# Module-level logger for this component.
logger = get_logger("wake_word")

# Optional dependency: pypinyin enables pinyin-based fuzzy matching; the
# detector degrades gracefully (exact/containment matching only) without it.
try:
    from pypinyin import lazy_pinyin, Style
    PYPINYIN_AVAILABLE = True
except ImportError:
    PYPINYIN_AVAILABLE = False
    logger.warning("pypinyin 未安装,拼音匹配功能将受限")
class WakeWordDetector:
    """
    Wake-word detector.

    Matching modes:
    - exact: full match against the precompiled variant pattern
    - fuzzy: homophone / pinyin equivalence
    - partial: match only the leading half of the primary wake word
    """

    # Upper bound for the per-instance detect() memo.  A per-instance dict
    # replaces the original ``@lru_cache`` on the bound method, which keyed
    # on ``self`` and kept instances alive for the cache lifetime (B019).
    _CACHE_MAXSIZE = 256

    def __init__(self):
        """Load wake-word configuration and precompile matching patterns."""
        logger.info("初始化唤醒词检测器...")

        # Configuration (module-level constants from voice_drone.core.configuration).
        self.primary = WAKE_WORD_PRIMARY
        self.variants = WAKE_WORD_VARIANTS or []
        self.matching_config = WAKE_WORD_MATCHING_CONFIG or {}

        # Matching options with permissive defaults.
        self.enable_fuzzy = self.matching_config.get("enable_fuzzy", True)
        self.enable_partial = self.matching_config.get("enable_partial", True)
        self.ignore_case = self.matching_config.get("ignore_case", True)
        self.ignore_spaces = self.matching_config.get("ignore_spaces", True)
        self.min_match_length = self.matching_config.get("min_match_length", 2)
        self.similarity_threshold = self.matching_config.get("similarity_threshold", 0.7)

        # Bounded memo for detect(); see _CACHE_MAXSIZE above.
        self._detect_cache: dict = {}

        # Build the precompiled matching patterns.
        self._build_patterns()

        logger.info("唤醒词检测器初始化完成")
        logger.info(f" 主唤醒词: {self.primary}")
        logger.info(f" 变体数量: {len(self.variants)}")
        logger.info(f" 模糊匹配: {'启用' if self.enable_fuzzy else '禁用'}")
        logger.info(f" 部分匹配: {'启用' if self.enable_partial else '禁用'}")

    def _build_patterns(self):
        """Build the exact-match regex from the primary wake word and variants."""
        # Fix: include the primary wake word itself -- previously it was only
        # matchable when duplicated in the variants list.
        candidates = ([self.primary] if self.primary else []) + list(self.variants)

        # Normalize and de-duplicate while keeping order deterministic
        # (the original list(set(...)) shuffled ties arbitrarily per run).
        seen = set()
        self.normalized_variants = []
        for candidate in candidates:
            normalized = self._normalize_text(candidate)
            if normalized and normalized not in seen:
                seen.add(normalized)
                self.normalized_variants.append(normalized)

        # Longest first so the regex alternation prefers the longest match.
        self.normalized_variants.sort(key=len, reverse=True)

        # Compile a single alternation pattern of escaped literals.
        if self.normalized_variants:
            joined = '|'.join(re.escape(v) for v in self.normalized_variants)
            self.pattern = re.compile(joined, re.IGNORECASE if self.ignore_case else 0)
        else:
            self.pattern = None

        logger.debug(f"构建了 {len(self.normalized_variants)} 个匹配模式")

    def _normalize_text(self, text: str) -> str:
        """
        Normalize text for matching.

        Args:
            text: raw text

        Returns:
            Normalized text (case-folded / whitespace-removed per config).
        """
        if not text:
            return ""

        normalized = text

        # Case folding.
        if self.ignore_case:
            normalized = normalized.lower()

        # Whitespace removal.
        if self.ignore_spaces:
            normalized = normalized.replace(' ', '').replace('\t', '').replace('\n', '')

        return normalized.strip()

    def _get_pinyin(self, text: str) -> str:
        """
        Convert text to a pinyin string for pinyin-based matching.

        Args:
            text: Chinese text

        Returns:
            Lower-case pinyin with no separators, or "" when pypinyin is
            unavailable or conversion fails (best effort by design).
        """
        if not PYPINYIN_AVAILABLE:
            return ""

        try:
            pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
            return ''.join(pinyin_list).lower()
        except Exception as e:  # pypinyin failure is non-fatal; degrade quietly
            logger.debug(f"拼音转换失败: {e}")
            return ""

    def _fuzzy_match(self, text: str, variant: str) -> bool:
        """
        Fuzzy match (homophones / pinyin).

        Args:
            text: candidate text (may already be normalized; normalization
                is idempotent)
            variant: wake-word variant

        Returns:
            True when the texts match exactly, by pinyin, or by containment.
        """
        normalized_text = self._normalize_text(text)
        normalized_variant = self._normalize_text(variant)

        # 1. Exact match on normalized forms.
        if normalized_text == normalized_variant:
            return True

        # 2. Pinyin match (homophones share the same pinyin).
        if PYPINYIN_AVAILABLE:
            text_pinyin = self._get_pinyin(text)
            variant_pinyin = self._get_pinyin(variant)

            if text_pinyin and variant_pinyin:
                if text_pinyin == variant_pinyin:
                    return True

                # Containment plus a length-ratio similarity gate.
                if len(variant_pinyin) >= 2 and (
                    variant_pinyin in text_pinyin or text_pinyin in variant_pinyin
                ):
                    similarity = (
                        min(len(variant_pinyin), len(text_pinyin))
                        / max(len(variant_pinyin), len(text_pinyin))
                    )
                    if similarity >= self.similarity_threshold:
                        return True

        # 3. Character-level containment for sufficiently long strings.
        if (
            len(normalized_text) >= self.min_match_length
            and len(normalized_variant) >= self.min_match_length
            and (normalized_variant in normalized_text
                 or normalized_text in normalized_variant)
        ):
            return True

        return False

    def _partial_match(self, text: str) -> bool:
        """
        Partial match: accept the leading half of a long primary wake word
        (short forms should be listed explicitly in the variants config).

        Args:
            text: candidate text

        Returns:
            True when the leading half of the primary wake word occurs in text.
        """
        if not self.enable_partial or not self.primary:
            return False

        normalized_text = self._normalize_text(text)
        primary_normalized = self._normalize_text(self.primary)

        # Only split words long enough that the half still carries meaning.
        if len(primary_normalized) < self.min_match_length * 2:
            return False

        half_wake_word = primary_normalized[:len(primary_normalized) // 2]
        return (len(half_wake_word) >= self.min_match_length
                and half_wake_word in normalized_text)

    def detect(self, text: str) -> Tuple[bool, Optional[str]]:
        """
        Detect whether the text contains a wake word.

        Args:
            text: text to check

        Returns:
            (matched, matched wake word or None)
        """
        if not text:
            return False, None

        # Bounded per-instance memo (replaces @lru_cache on the method,
        # which leaks instances -- see _CACHE_MAXSIZE).
        cached = self._detect_cache.get(text)
        if cached is not None:
            return cached

        result = self._detect_impl(text)

        if len(self._detect_cache) >= self._CACHE_MAXSIZE:
            # Cheap wholesale eviction; recognizer inputs rarely repeat
            # enough for a true LRU policy to pay off here.
            self._detect_cache.clear()
        self._detect_cache[text] = result
        return result

    def _detect_impl(self, text: str) -> Tuple[bool, Optional[str]]:
        """Uncached detection: exact regex, then fuzzy, then partial match."""
        normalized_text = self._normalize_text(text)

        # 1. Exact match via the precompiled alternation.
        if self.pattern:
            match = self.pattern.search(normalized_text)
            if match:
                matched_text = match.group(0)
                logger.debug(f"精确匹配到唤醒词: {matched_text}")
                return True, matched_text

        # 2. Fuzzy match (homophones / pinyin).
        if self.enable_fuzzy:
            for variant in self.normalized_variants:
                if self._fuzzy_match(normalized_text, variant):
                    logger.debug(f"模糊匹配到唤醒词变体: {variant}")
                    return True, variant

        # 3. Partial match.  Fix: no longer skipped when no variants are
        # configured (the old code returned early whenever pattern was None).
        if self._partial_match(normalized_text):
            logger.debug("部分匹配到唤醒词")
            return True, self.primary[:len(self.primary) // 2] if self.primary else None

        return False, None

    def extract_command_text(self, text: str) -> Optional[str]:
        """
        Strip the wake word from the text and return the command part.

        Args:
            text: full text that may contain a wake word

        Returns:
            The command text following the wake word, or None when no wake
            word (or no trailing command) is found.
        """
        is_wake, matched_wake_word = self.detect(text)
        if not is_wake:
            return None

        # Normalized forms for locating the wake word.
        normalized_text = self._normalize_text(text)
        normalized_wake = self._normalize_text(matched_wake_word) if matched_wake_word else ""

        if not normalized_wake or normalized_wake not in normalized_text:
            return None

        idx = normalized_text.find(normalized_wake)
        if idx < 0:
            return None

        # Strategy 1: locate a configured surface form of the wake word
        # verbatim in the original text and cut right after it.
        command = self._cut_after_known_variant(text, normalized_wake)
        if command is not None:
            return command

        # Strategy 2 (fallback): map the wake word's end offset in the
        # normalized text back onto the original text by counting characters
        # that survive normalization.
        return self._cut_after_normalized_offset(text, idx + len(normalized_wake))

    def _cut_after_known_variant(self, original_text: str, normalized_wake: str) -> Optional[str]:
        """Return the command after the longest configured form matching
        normalized_wake found (case-insensitively) in original_text, or
        None when no configured form is present or the remainder is empty."""
        text_lower = original_text.lower()
        best_match_idx = -1
        best_match_length = 0

        # Fix: consider the primary wake word as well as every variant
        # (the original scan only looked at self.variants).
        for variant in ([self.primary] if self.primary else []) + list(self.variants):
            if self._normalize_text(variant) != normalized_wake:
                continue
            variant_idx = text_lower.find(variant.lower())
            # Prefer the longest surface form: more precise cut point.
            if variant_idx >= 0 and len(variant) > best_match_length:
                best_match_idx = variant_idx
                best_match_length = len(variant)

        if best_match_idx < 0:
            return None

        command_text = original_text[best_match_idx + best_match_length:].strip()
        # Drop leading punctuation between the wake word and the command.
        command_text = command_text.lstrip(',。、,.').strip()
        return command_text if command_text else None

    def _cut_after_normalized_offset(self, original_text: str, wake_end_in_normalized: int) -> Optional[str]:
        """Approximate fallback: walk the original text, counting characters
        that survive normalization, and cut once the wake word's normalized
        end offset is reached."""
        char_count = 0
        for i, char in enumerate(original_text):
            # Characters erased by normalization (e.g. spaces) don't count.
            if self._normalize_text(char):
                if char_count >= wake_end_in_normalized:
                    command_text = original_text[i:].strip()
                    command_text = command_text.lstrip(',。、,.').strip()
                    return command_text if command_text else None
                char_count += 1
        return None
# Module-level singleton instance.
_global_detector: Optional[WakeWordDetector] = None


def get_wake_word_detector() -> WakeWordDetector:
    """Return the process-wide WakeWordDetector, constructing it lazily."""
    global _global_detector
    detector = _global_detector
    if detector is None:
        detector = WakeWordDetector()
        _global_detector = detector
    return detector
if __name__ == "__main__":
    # Manual smoke test: run a handful of phrases through the detector and
    # print the match verdict plus the extracted command for each.
    detector = WakeWordDetector()

    test_cases = [
        ("无人机,现在起飞", True),
        ("wu ren ji 现在起飞", True),
        ("Wu Ren Ji 现在起飞", True),
        ("五人机,现在起飞", True),
        ("现在起飞", False),
        ("无人,现在起飞", True),  # short form listed in the configured variants
        ("人机 前进", False),  # bare "人机" variant was removed to avoid substring false wakes
        ("无人机 前进", True),
    ]

    banner = "=" * 60
    print(banner)
    print("唤醒词检测测试")
    print(banner)

    for phrase, should_wake in test_cases:
        woke, matched_word = detector.detect(phrase)
        extracted = detector.extract_command_text(phrase)
        verdict = "OK" if woke == should_wake else "FAIL"
        print(f"{verdict} 文本: {phrase}")
        print(f" 匹配: {woke} (期望: {should_wake})")
        print(f" 匹配词: {matched_word}")
        print(f" 提取命令: {extracted}")
        print()