""" 文本预处理模块 - 高性能实时语音转命令文本处理 本模块主要用于对语音识别输出的文本进行清洗、纠错、简繁转换、分词和参数提取, 便于后续命令意图分析和参数解析。 主要功能: 1. 文本清理:去除杂音、特殊字符、多余空格 2. 纠错:同音字纠正、常见错误修正 3. 简繁转换:统一文本格式(繁体转简体) 4. 分词:使用jieba分词,便于关键词匹配 5. 数字提取:提取距离(米)、速度(米/秒)、时间(秒) 6. 关键词识别:识别命令关键词(起飞、降落、前进等) 性能优化: - LRU缓存常用处理结果(分词、中文数字解析、完整预处理) - 预编译正则表达式 - 优化字符串操作(使用正则表达式批量替换) - 延迟加载可选依赖 - 缓存关键词排序结果 """ import re from typing import Dict, Optional, List, Tuple, Set from functools import lru_cache from dataclasses import dataclass from voice_drone.logging_ import get_logger from voice_drone.core.configuration import KEYWORDS_CONFIG, SYSTEM_TEXT_PREPROCESSOR_CONFIG import warnings warnings.filterwarnings("ignore") logger = get_logger("text_preprocessor") # 延迟加载可选依赖 try: from opencc import OpenCC OPENCC_AVAILABLE = True except ImportError: OPENCC_AVAILABLE = False logger.warning("opencc 未安装,将跳过简繁转换功能") try: import jieba JIEBA_AVAILABLE = True # 初始化jieba,加载词典 jieba.initialize() except ImportError: JIEBA_AVAILABLE = False logger.warning("jieba 未安装,将跳过分词功能") try: from pypinyin import lazy_pinyin, Style PYPINYIN_AVAILABLE = True except ImportError: PYPINYIN_AVAILABLE = False logger.warning("pypinyin 未安装,将跳过拼音相关功能") @dataclass class ExtractedParams: """提取的参数信息""" distance: Optional[float] = None # 距离(米) speed: Optional[float] = None # 速度(米/秒) duration: Optional[float] = None # 时间(秒) command_keyword: Optional[str] = None # 识别的命令关键词 @dataclass class PreprocessedText: """预处理后的文本结果""" cleaned_text: str # 清理后的文本 normalized_text: str # 规范化后的文本(简繁转换后) words: List[str] # 分词结果 params: ExtractedParams # 提取的参数 original_text: str # 原始文本 class TextPreprocessor: """ 高性能文本预处理器 针对实时语音转命令场景优化,支持: - 文本清理和规范化 - 同音字纠错 - 简繁转换 - 分词 - 数字和单位提取 - 命令关键词识别 """ def __init__(self, enable_traditional_to_simplified: Optional[bool] = None, enable_segmentation: Optional[bool] = None, enable_correction: Optional[bool] = None, enable_number_extraction: Optional[bool] = None, enable_keyword_detection: Optional[bool] = None, lru_cache_size: Optional[int] = None): """ 初始化文本预处理器 Args: enable_traditional_to_simplified: 是否启用繁简转换(None时从配置读取) enable_segmentation: 是否启用分词(None时从配置读取) enable_correction: 是否启用纠错(None时从配置读取) enable_number_extraction: 是否启用数字提取(None时从配置读取) enable_keyword_detection: 是否启用关键词检测(None时从配置读取) lru_cache_size: LRU缓存大小(None时从配置读取) """ # 从配置读取参数(如果未提供) config = SYSTEM_TEXT_PREPROCESSOR_CONFIG or {} self.enable_traditional_to_simplified = ( enable_traditional_to_simplified if enable_traditional_to_simplified is not None else config.get("enable_traditional_to_simplified", True) ) and OPENCC_AVAILABLE self.enable_segmentation = ( enable_segmentation if enable_segmentation is not None else config.get("enable_segmentation", True) ) and JIEBA_AVAILABLE self.enable_correction = ( enable_correction if enable_correction is not None else config.get("enable_correction", True) ) self.enable_number_extraction = ( enable_number_extraction if enable_number_extraction is not None else config.get("enable_number_extraction", True) ) self.enable_keyword_detection = ( enable_keyword_detection if enable_keyword_detection is not None else config.get("enable_keyword_detection", True) ) cache_size = ( lru_cache_size if lru_cache_size is not None else config.get("lru_cache_size", 512) ) # 初始化OpenCC(如果可用) if self.enable_traditional_to_simplified: self.opencc = OpenCC('t2s') # 繁体转简体 else: self.opencc = None # 加载关键词映射(命令关键词 -> 命令类型) self._load_keyword_mapping() # 预编译正则表达式(性能优化) self._compile_regex_patterns() # 加载纠错字典 self._load_correction_dict() # 设置LRU缓存大小 self._cache_size = cache_size # 创建缓存装饰器(用于分词、中文数字解析、完整预处理) self._segment_text_cached = lru_cache(maxsize=cache_size)(self._segment_text_impl) self._parse_chinese_number_cached = lru_cache(maxsize=128)(self._parse_chinese_number_impl) self._preprocess_cached = lru_cache(maxsize=cache_size)(self._preprocess_impl) self._preprocess_fast_cached = lru_cache(maxsize=cache_size)(self._preprocess_fast_impl) logger.info(f"文本预处理器初始化完成") logger.info(f" 繁简转换: {'启用' if self.enable_traditional_to_simplified else '禁用'}") logger.info(f" 分词: {'启用' if self.enable_segmentation else '禁用'}") logger.info(f" 纠错: {'启用' if self.enable_correction else '禁用'}") logger.info(f" 数字提取: {'启用' if self.enable_number_extraction else '禁用'}") logger.info(f" 关键词检测: {'启用' if self.enable_keyword_detection else '禁用'}") def _load_keyword_mapping(self): """加载关键词映射表(命令关键词 -> 命令类型)""" self.keyword_to_command: Dict[str, str] = {} if KEYWORDS_CONFIG: for command_type, keywords in KEYWORDS_CONFIG.items(): if isinstance(keywords, list): for keyword in keywords: self.keyword_to_command[keyword] = command_type elif isinstance(keywords, str): self.keyword_to_command[keywords] = command_type # 预计算排序结果(按长度降序,优先匹配长关键词) self.sorted_keywords = sorted( self.keyword_to_command.keys(), key=len, reverse=True ) logger.debug(f"加载了 {len(self.keyword_to_command)} 个关键词映射") def _compile_regex_patterns(self): """预编译正则表达式(性能优化)""" # 清理文本:去除特殊字符、多余空格 self.pattern_clean_special = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s米每秒秒分小时\.]') self.pattern_clean_spaces = re.compile(r'\s+') # 数字提取模式 # 距离:数字 + (米|m|M|公尺)(排除速度单位) self.pattern_distance = re.compile( r'(\d+\.?\d*)\s*(?:米|m|M|公尺|meter|meters)(?!\s*[/每]?\s*秒)', re.IGNORECASE ) # 速度:数字 + (米每秒|m/s|米/秒|米秒|mps|MPS)(优先匹配) # 支持"三米每秒"、"5米/秒"、"2.5米每秒"等格式 self.pattern_speed = re.compile( r'(?:速度\s*[::]?\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:米\s*[/每]?\s*秒|m\s*/\s*s|mps|MPS)', re.IGNORECASE ) # 时间:数字 + (秒|s|S|分钟|分|min|小时|时|h|H) # 支持"持续10秒"、"5分钟"等格式 self.pattern_duration = re.compile( r'(?:持续\s*|持续\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:秒|s|S|分钟|分|min|小时|时|h|H)', re.IGNORECASE ) # 中文数字映射(用于识别"十米"、"五秒"等) self.chinese_numbers = { '零': 0, '一': 1, '二': 2, '三': 3, '四': 4, '五': 5, '六': 6, '七': 7, '八': 8, '九': 9, '十': 10, '壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5, '陆': 6, '柒': 7, '捌': 8, '玖': 9, '拾': 10, '百': 100, '千': 1000, '万': 10000 } # 中文数字模式(如"十米"、"五秒"、"二十米"、"三米每秒") # 支持"二十"、"三十"等复合数字 self.pattern_chinese_number = re.compile( r'([零一二三四五六七八九十壹贰叁肆伍陆柒捌玖拾百千万]+)\s*(?:米|秒|分|小时|米\s*[/每]?\s*秒)' ) def _load_correction_dict(self): """加载纠错字典(同音字、常见错误)并编译正则表达式""" # 无人机控制相关的常见同音字/错误字映射(只保留实际需要纠错的) correction_pairs = [ # 动作相关(同音字纠错) ('起非', '起飞'), ('降洛', '降落'), ('悬廷', '悬停'), ('停只', '停止'), ] # 构建正则表达式模式(按长度降序,优先匹配长模式) if correction_pairs: # 按长度降序排序,优先匹配长模式 sorted_pairs = sorted(correction_pairs, key=lambda x: len(x[0]), reverse=True) # 构建替换映射字典(用于快速查找) self.correction_replacements = {wrong: correct for wrong, correct in sorted_pairs} # 编译单一正则表达式 patterns = [re.escape(wrong) for wrong, _ in sorted_pairs] self.correction_pattern = re.compile('|'.join(patterns)) else: self.correction_pattern = None self.correction_replacements = {} logger.debug(f"加载了 {len(correction_pairs)} 个纠错规则") def clean_text(self, text: str) -> str: """ 清理文本:去除特殊字符、多余空格 Args: text: 原始文本 Returns: 清理后的文本 """ if not text: return "" # 去除特殊字符(保留中文、英文、数字、空格、常用标点) text = self.pattern_clean_special.sub('', text) # 统一空格(多个空格合并为一个) text = self.pattern_clean_spaces.sub(' ', text) # 去除首尾空格 text = text.strip() return text def correct_text(self, text: str) -> str: """ 纠错:同音字、常见错误修正(使用正则表达式优化) Args: text: 待纠错文本 Returns: 纠错后的文本 """ if not self.enable_correction or not text or not self.correction_pattern: return text # 使用正则表达式一次性替换所有模式(性能优化) def replacer(match): matched = match.group(0) # 直接从字典中查找替换(O(1)查找) return self.correction_replacements.get(matched, matched) return self.correction_pattern.sub(replacer, text) def traditional_to_simplified(self, text: str) -> str: """ 繁体转简体 Args: text: 待转换文本 Returns: 转换后的文本 """ if not self.enable_traditional_to_simplified or not self.opencc or not text: return text try: return self.opencc.convert(text) except Exception as e: logger.warning(f"繁简转换失败: {e}") return text def _segment_text_impl(self, text: str) -> List[str]: """ 分词实现(内部方法,不带缓存) Args: text: 待分词文本 Returns: 分词结果列表 """ if not self.enable_segmentation or not text: return [text] if text else [] try: words = list(jieba.cut(text, cut_all=False)) # 过滤空字符串 words = [w.strip() for w in words if w.strip()] return words except Exception as e: logger.warning(f"分词失败: {e}") return [text] if text else [] def segment_text(self, text: str) -> List[str]: """ 分词(带缓存) Args: text: 待分词文本 Returns: 分词结果列表 """ return self._segment_text_cached(text) def extract_numbers(self, text: str) -> ExtractedParams: """ 提取数字和单位(距离、速度、时间) Args: text: 待提取文本 Returns: ExtractedParams对象,包含提取的参数 """ params = ExtractedParams() if not self.enable_number_extraction or not text: return params # 优先提取速度(避免被误识别为距离) speed_match = self.pattern_speed.search(text) if speed_match: try: speed_str = speed_match.group(1) # 尝试解析中文数字 if speed_str.isdigit() or '.' in speed_str: params.speed = float(speed_str) else: # 中文数字(使用缓存) chinese_speed = self._parse_chinese_number(speed_str) if chinese_speed is not None: params.speed = float(chinese_speed) except (ValueError, AttributeError): pass # 提取距离(米,排除速度单位) distance_match = self.pattern_distance.search(text) if distance_match: try: params.distance = float(distance_match.group(1)) except ValueError: pass # 提取时间(秒) duration_matches = self.pattern_duration.finditer(text) # 查找所有匹配 for duration_match in duration_matches: try: duration_str = duration_match.group(1) duration_unit = duration_match.group(2).lower() if len(duration_match.groups()) > 1 else '秒' # 解析数字(支持中文数字) if duration_str.isdigit() or '.' in duration_str: duration_value = float(duration_str) else: # 中文数字 chinese_duration = self._parse_chinese_number(duration_str) if chinese_duration is None: continue duration_value = float(chinese_duration) # 转换为秒 if '分' in duration_unit or 'min' in duration_unit: params.duration = duration_value * 60 break # 取第一个匹配 elif '小时' in duration_unit or 'h' in duration_unit: params.duration = duration_value * 3600 break else: # 秒 params.duration = duration_value break except (ValueError, IndexError, AttributeError): continue # 尝试提取中文数字(如"十米"、"五秒"、"二十米"、"三米每秒") chinese_matches = self.pattern_chinese_number.finditer(text) for chinese_match in chinese_matches: try: chinese_num_str = chinese_match.group(1) full_match = chinese_match.group(0) # 解析中文数字(使用缓存) num_value = self._parse_chinese_number(chinese_num_str) if num_value is not None: # 判断单位类型 if '米每秒' in full_match or '米/秒' in full_match or '米每' in full_match: # 速度单位 if params.speed is None: params.speed = float(num_value) elif '米' in full_match and '秒' not in full_match: # 距离单位(不包含"秒") if params.distance is None: params.distance = float(num_value) elif '秒' in full_match and '米' not in full_match: # 时间单位(不包含"米") if params.duration is None: params.duration = float(num_value) elif '分' in full_match and '米' not in full_match: # 时间单位(分钟) if params.duration is None: params.duration = float(num_value) * 60 except (ValueError, IndexError, AttributeError): continue return params def _parse_chinese_number_impl(self, chinese_num: str) -> Optional[int]: """ 解析中文数字实现(内部方法,不带缓存) Args: chinese_num: 中文数字字符串 Returns: 对应的阿拉伯数字,解析失败返回None """ if not chinese_num: return None # 单个数字 if chinese_num in self.chinese_numbers: return self.chinese_numbers[chinese_num] # "十" -> 10 if chinese_num == '十' or chinese_num == '拾': return 10 # "十一" -> 11, "十二" -> 12, ... if chinese_num.startswith('十') or chinese_num.startswith('拾'): rest = chinese_num[1:] if rest in self.chinese_numbers: return 10 + self.chinese_numbers[rest] # "二十" -> 20, "三十" -> 30, ... if chinese_num.endswith('十') or chinese_num.endswith('拾'): prefix = chinese_num[:-1] if prefix in self.chinese_numbers: return self.chinese_numbers[prefix] * 10 # "二十五" -> 25, "三十五" -> 35, ... if '十' in chinese_num or '拾' in chinese_num: parts = chinese_num.replace('拾', '十').split('十') if len(parts) == 2: tens_part = parts[0] if parts[0] else '一' # "十五" -> parts[0]为空 ones_part = parts[1] if parts[1] else '' tens = self.chinese_numbers.get(tens_part, 1) if tens_part else 1 ones = self.chinese_numbers.get(ones_part, 0) if ones_part else 0 return tens * 10 + ones return None def _parse_chinese_number(self, chinese_num: str) -> Optional[int]: """ 解析中文数字(支持"十"、"二十"、"三十"、"五"等,带缓存) Args: chinese_num: 中文数字字符串 Returns: 对应的阿拉伯数字,解析失败返回None """ return self._parse_chinese_number_cached(chinese_num) def detect_keyword(self, text: str, words: Optional[List[str]] = None) -> Optional[str]: """ 检测命令关键词(使用缓存的排序结果) Args: text: 待检测文本 words: 分词结果(如果已分词,可传入以提高性能) Returns: 检测到的命令类型(如"takeoff"、"forward"等),未检测到返回None """ if not self.enable_keyword_detection or not text: return None # 如果已分词,优先使用分词结果匹配 if words: for word in words: if word in self.keyword_to_command: return self.keyword_to_command[word] # 使用缓存的排序结果(按长度降序,优先匹配长关键词) for keyword in self.sorted_keywords: if keyword in text: return self.keyword_to_command[keyword] return None def preprocess(self, text: str) -> PreprocessedText: """ 完整的文本预处理流程(带缓存) Args: text: 原始文本 Returns: PreprocessedText对象,包含所有预处理结果 """ return self._preprocess_cached(text) def _preprocess_impl(self, text: str) -> PreprocessedText: """ 完整的文本预处理流程实现(内部方法,不带缓存) Args: text: 原始文本 Returns: PreprocessedText对象,包含所有预处理结果 """ if not text: return PreprocessedText( cleaned_text="", normalized_text="", words=[], params=ExtractedParams(), original_text=text ) original_text = text # 1. 清理文本 cleaned_text = self.clean_text(text) # 2. 纠错 corrected_text = self.correct_text(cleaned_text) # 3. 繁简转换 normalized_text = self.traditional_to_simplified(corrected_text) # 4. 分词 words = self.segment_text(normalized_text) # 5. 提取数字和单位 params = self.extract_numbers(normalized_text) # 6. 检测关键词 command_keyword = self.detect_keyword(normalized_text, words) params.command_keyword = command_keyword return PreprocessedText( cleaned_text=cleaned_text, normalized_text=normalized_text, words=words, params=params, original_text=original_text ) def preprocess_fast(self, text: str) -> Tuple[str, ExtractedParams]: """ 快速预处理(仅返回规范化文本和参数,不进行分词,带缓存) 适用于实时场景,性能优先 Args: text: 原始文本 Returns: (规范化文本, 提取的参数) """ return self._preprocess_fast_cached(text) def _preprocess_fast_impl(self, text: str) -> Tuple[str, ExtractedParams]: """ 快速预处理实现(内部方法,不带缓存) Args: text: 原始文本 Returns: (规范化文本, 提取的参数) """ if not text: return "", ExtractedParams() # 1. 清理 cleaned = self.clean_text(text) # 2. 纠错 corrected = self.correct_text(cleaned) # 3. 繁简转换 normalized = self.traditional_to_simplified(corrected) # 4. 提取参数(不进行分词,提高性能) params = self.extract_numbers(normalized) # 5. 检测关键词(在完整文本中搜索) params.command_keyword = self.detect_keyword(normalized, words=None) return normalized, params # 全局单例(可选,用于提高性能) _global_preprocessor: Optional[TextPreprocessor] = None def get_preprocessor() -> TextPreprocessor: """获取全局预处理器实例(单例模式)""" global _global_preprocessor if _global_preprocessor is None: _global_preprocessor = TextPreprocessor() return _global_preprocessor if __name__ == "__main__": from command import Command # 测试代码 preprocessor = TextPreprocessor() test_cases = [ "现在起飞往前飞,飞10米,速度为5米每秒", "向前飞二十米,速度三米每秒", "立刻降落", "悬停五秒", "向右飛十米", # 繁体测试 "往左飛,速度2.5米/秒,持續10秒", ] print("=" * 60) print("文本预处理器测试") print("=" * 60) for i, test_text in enumerate(test_cases, 1): print(f"\n测试 {i}: {test_text}") print("-" * 60) # 完整预处理 result = preprocessor.preprocess(test_text) print(f"原始文本: {result.original_text}") print(f"清理后: {result.cleaned_text}") print(f"规范化: {result.normalized_text}") print(f"分词: {result.words}") print(f"提取参数:") print(f" 距离: {result.params.distance} 米") print(f" 速度: {result.params.speed} 米/秒") print(f" 时间: {result.params.duration} 秒") print(f" 命令关键词: {result.params.command_keyword}") command = Command.create(result.params.command_keyword, 1, result.params.distance, result.params.speed, result.params.duration) print(f"命令: {command.to_dict()}") # 快速预处理 fast_text, fast_params = preprocessor.preprocess_fast(test_text) print(f"\n快速预处理结果:") print(f" 规范化文本: {fast_text}") print(f" 命令关键词: {fast_params.command_keyword}")