# DroneMind/voice_drone/core/text_preprocessor.py

"""
文本预处理模块 - 高性能实时语音转命令文本处理
本模块主要用于对语音识别输出的文本进行清洗、纠错、简繁转换、分词和参数提取,
便于后续命令意图分析和参数解析。
主要功能:
1. 文本清理:去除杂音、特殊字符、多余空格
2. 纠错:同音字纠正、常见错误修正
3. 简繁转换:统一文本格式(繁体转简体)
4. 分词:使用jieba分词,便于关键词匹配
5. 数字提取:提取距离(米)、速度(米/秒)、时间(秒)
6. 关键词识别:识别命令关键词(起飞、降落、前进等)
性能优化:
- LRU缓存常用处理结果(分词、中文数字解析、完整预处理)
- 预编译正则表达式
- 优化字符串操作(使用正则表达式批量替换)
- 延迟加载可选依赖
- 缓存关键词排序结果
"""
import re
from typing import Dict, Optional, List, Tuple, Set
from functools import lru_cache
from dataclasses import dataclass
from voice_drone.logging_ import get_logger
from voice_drone.core.configuration import KEYWORDS_CONFIG, SYSTEM_TEXT_PREPROCESSOR_CONFIG
import warnings
warnings.filterwarnings("ignore")
logger = get_logger("text_preprocessor")
# Optional dependencies are loaded lazily; each feature degrades gracefully
# (with a warning) when its library is missing.
try:
    from opencc import OpenCC
    OPENCC_AVAILABLE = True
except ImportError:
    OPENCC_AVAILABLE = False
    logger.warning("opencc 未安装,将跳过简繁转换功能")
try:
    import jieba
    JIEBA_AVAILABLE = True
    # Initialize jieba eagerly so the dictionary load cost is paid at import
    # time instead of on the first segmentation call.
    jieba.initialize()
except ImportError:
    JIEBA_AVAILABLE = False
    logger.warning("jieba 未安装,将跳过分词功能")
try:
    from pypinyin import lazy_pinyin, Style
    PYPINYIN_AVAILABLE = True
except ImportError:
    PYPINYIN_AVAILABLE = False
    logger.warning("pypinyin 未安装,将跳过拼音相关功能")
@dataclass
class ExtractedParams:
    """Numeric parameters and command keyword extracted from one utterance."""
    distance: Optional[float] = None  # distance in meters
    speed: Optional[float] = None  # speed in meters/second
    duration: Optional[float] = None  # duration in seconds
    command_keyword: Optional[str] = None  # detected command type (e.g. "takeoff")
@dataclass
class PreprocessedText:
    """Full result bundle produced by TextPreprocessor.preprocess()."""
    cleaned_text: str  # text after noise/special-character cleanup
    normalized_text: str  # text after correction and Traditional->Simplified conversion
    words: List[str]  # jieba segmentation result
    params: ExtractedParams  # extracted numeric parameters and keyword
    original_text: str  # the untouched input text
class TextPreprocessor:
"""
高性能文本预处理器
针对实时语音转命令场景优化,支持:
- 文本清理和规范化
- 同音字纠错
- 简繁转换
- 分词
- 数字和单位提取
- 命令关键词识别
"""
    def __init__(self,
                 enable_traditional_to_simplified: Optional[bool] = None,
                 enable_segmentation: Optional[bool] = None,
                 enable_correction: Optional[bool] = None,
                 enable_number_extraction: Optional[bool] = None,
                 enable_keyword_detection: Optional[bool] = None,
                 lru_cache_size: Optional[int] = None):
        """
        Initialize the text preprocessor.

        Each feature flag falls back to SYSTEM_TEXT_PREPROCESSOR_CONFIG when
        passed as None, and conversion/segmentation are additionally gated on
        the optional library actually being importable (OpenCC, jieba).

        Args:
            enable_traditional_to_simplified: enable Traditional->Simplified
                conversion (None = read from config).
            enable_segmentation: enable jieba segmentation (None = config).
            enable_correction: enable homophone correction (None = config).
            enable_number_extraction: enable number/unit extraction (None = config).
            enable_keyword_detection: enable command keyword detection (None = config).
            lru_cache_size: LRU cache size for the cached entry points
                (None = config, default 512).
        """
        # Read defaults from config for any argument left as None.
        config = SYSTEM_TEXT_PREPROCESSOR_CONFIG or {}
        self.enable_traditional_to_simplified = (
            enable_traditional_to_simplified
            if enable_traditional_to_simplified is not None
            else config.get("enable_traditional_to_simplified", True)
        ) and OPENCC_AVAILABLE
        self.enable_segmentation = (
            enable_segmentation
            if enable_segmentation is not None
            else config.get("enable_segmentation", True)
        ) and JIEBA_AVAILABLE
        self.enable_correction = (
            enable_correction
            if enable_correction is not None
            else config.get("enable_correction", True)
        )
        self.enable_number_extraction = (
            enable_number_extraction
            if enable_number_extraction is not None
            else config.get("enable_number_extraction", True)
        )
        self.enable_keyword_detection = (
            enable_keyword_detection
            if enable_keyword_detection is not None
            else config.get("enable_keyword_detection", True)
        )
        cache_size = (
            lru_cache_size
            if lru_cache_size is not None
            else config.get("lru_cache_size", 512)
        )
        # Build the OpenCC converter only when the feature is enabled
        # (the flag above already implies OPENCC_AVAILABLE).
        if self.enable_traditional_to_simplified:
            self.opencc = OpenCC('t2s')  # Traditional -> Simplified
        else:
            self.opencc = None
        # Keyword -> command-type lookup table.
        self._load_keyword_mapping()
        # Precompile all regexes used by extraction (hot path).
        self._compile_regex_patterns()
        # Homophone correction table + combined pattern.
        self._load_correction_dict()
        self._cache_size = cache_size
        # Wrap the heavy implementations in per-instance LRU caches: binding
        # lru_cache around the bound methods here keeps each cache local to
        # this instance (no cross-instance leaks via a class-level decorator).
        self._segment_text_cached = lru_cache(maxsize=cache_size)(self._segment_text_impl)
        self._parse_chinese_number_cached = lru_cache(maxsize=128)(self._parse_chinese_number_impl)
        self._preprocess_cached = lru_cache(maxsize=cache_size)(self._preprocess_impl)
        self._preprocess_fast_cached = lru_cache(maxsize=cache_size)(self._preprocess_fast_impl)
        logger.info(f"文本预处理器初始化完成")
        logger.info(f"  繁简转换: {'启用' if self.enable_traditional_to_simplified else '禁用'}")
        logger.info(f"  分词: {'启用' if self.enable_segmentation else '禁用'}")
        logger.info(f"  纠错: {'启用' if self.enable_correction else '禁用'}")
        logger.info(f"  数字提取: {'启用' if self.enable_number_extraction else '禁用'}")
        logger.info(f"  关键词检测: {'启用' if self.enable_keyword_detection else '禁用'}")
def _load_keyword_mapping(self):
"""加载关键词映射表(命令关键词 -> 命令类型)"""
self.keyword_to_command: Dict[str, str] = {}
if KEYWORDS_CONFIG:
for command_type, keywords in KEYWORDS_CONFIG.items():
if isinstance(keywords, list):
for keyword in keywords:
self.keyword_to_command[keyword] = command_type
elif isinstance(keywords, str):
self.keyword_to_command[keywords] = command_type
# 预计算排序结果(按长度降序,优先匹配长关键词)
self.sorted_keywords = sorted(
self.keyword_to_command.keys(),
key=len,
reverse=True
)
logger.debug(f"加载了 {len(self.keyword_to_command)} 个关键词映射")
def _compile_regex_patterns(self):
"""预编译正则表达式(性能优化)"""
# 清理文本:去除特殊字符、多余空格
self.pattern_clean_special = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s米每秒秒分小时\.]')
self.pattern_clean_spaces = re.compile(r'\s+')
# 数字提取模式
# 距离:数字 + (米|m|M|公尺)(排除速度单位)
self.pattern_distance = re.compile(
r'(\d+\.?\d*)\s*(?:米|m|M|公尺|meter|meters)(?!\s*[/每]?\s*秒)',
re.IGNORECASE
)
# 速度:数字 + (米每秒|m/s|米/秒|米秒|mps|MPS)(优先匹配)
# 支持"三米每秒"、"5米/秒"、"2.5米每秒"等格式
self.pattern_speed = re.compile(
r'(?:速度\s*[:]?\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:米\s*[/每]?\s*秒|m\s*/\s*s|mps|MPS)',
re.IGNORECASE
)
# 时间:数字 + (秒|s|S|分钟|分|min|小时|时|h|H)
# 支持"持续10秒"、"5分钟"等格式
self.pattern_duration = re.compile(
r'(?:持续\s*|持续\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:秒|s|S|分钟|分|min|小时|时|h|H)',
re.IGNORECASE
)
# 中文数字映射(用于识别"十米"、"五秒"等)
self.chinese_numbers = {
'': 0, '': 1, '': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10,
'': 1, '': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10,
'': 100, '': 1000, '': 10000
}
# 中文数字模式(如"十米"、"五秒"、"二十米"、"三米每秒")
# 支持"二十"、"三十"等复合数字
self.pattern_chinese_number = re.compile(
r'([零一二三四五六七八九十壹贰叁肆伍陆柒捌玖拾百千万]+)\s*(?:米|秒|分|小时|米\s*[/每]?\s*秒)'
)
def _load_correction_dict(self):
"""加载纠错字典(同音字、常见错误)并编译正则表达式"""
# 无人机控制相关的常见同音字/错误字映射(只保留实际需要纠错的)
correction_pairs = [
# 动作相关(同音字纠错)
('起非', '起飞'),
('降洛', '降落'),
('悬廷', '悬停'),
('停只', '停止'),
]
# 构建正则表达式模式(按长度降序,优先匹配长模式)
if correction_pairs:
# 按长度降序排序,优先匹配长模式
sorted_pairs = sorted(correction_pairs, key=lambda x: len(x[0]), reverse=True)
# 构建替换映射字典(用于快速查找)
self.correction_replacements = {wrong: correct for wrong, correct in sorted_pairs}
# 编译单一正则表达式
patterns = [re.escape(wrong) for wrong, _ in sorted_pairs]
self.correction_pattern = re.compile('|'.join(patterns))
else:
self.correction_pattern = None
self.correction_replacements = {}
logger.debug(f"加载了 {len(correction_pairs)} 个纠错规则")
def clean_text(self, text: str) -> str:
"""
清理文本:去除特殊字符、多余空格
Args:
text: 原始文本
Returns:
清理后的文本
"""
if not text:
return ""
# 去除特殊字符(保留中文、英文、数字、空格、常用标点)
text = self.pattern_clean_special.sub('', text)
# 统一空格(多个空格合并为一个)
text = self.pattern_clean_spaces.sub(' ', text)
# 去除首尾空格
text = text.strip()
return text
def correct_text(self, text: str) -> str:
"""
纠错:同音字、常见错误修正(使用正则表达式优化)
Args:
text: 待纠错文本
Returns:
纠错后的文本
"""
if not self.enable_correction or not text or not self.correction_pattern:
return text
# 使用正则表达式一次性替换所有模式(性能优化)
def replacer(match):
matched = match.group(0)
# 直接从字典中查找替换(O(1)查找)
return self.correction_replacements.get(matched, matched)
return self.correction_pattern.sub(replacer, text)
def traditional_to_simplified(self, text: str) -> str:
"""
繁体转简体
Args:
text: 待转换文本
Returns:
转换后的文本
"""
if not self.enable_traditional_to_simplified or not self.opencc or not text:
return text
try:
return self.opencc.convert(text)
except Exception as e:
logger.warning(f"繁简转换失败: {e}")
return text
def _segment_text_impl(self, text: str) -> List[str]:
"""
分词实现(内部方法,不带缓存)
Args:
text: 待分词文本
Returns:
分词结果列表
"""
if not self.enable_segmentation or not text:
return [text] if text else []
try:
words = list(jieba.cut(text, cut_all=False))
# 过滤空字符串
words = [w.strip() for w in words if w.strip()]
return words
except Exception as e:
logger.warning(f"分词失败: {e}")
return [text] if text else []
    def segment_text(self, text: str) -> List[str]:
        """
        Segment text into words (LRU-cached wrapper).

        Args:
            text: text to segment.
        Returns:
            List of tokens; see _segment_text_impl for the fallback behavior
            when segmentation is disabled or fails.
        """
        return self._segment_text_cached(text)
def extract_numbers(self, text: str) -> ExtractedParams:
"""
提取数字和单位(距离、速度、时间)
Args:
text: 待提取文本
Returns:
ExtractedParams对象,包含提取的参数
"""
params = ExtractedParams()
if not self.enable_number_extraction or not text:
return params
# 优先提取速度(避免被误识别为距离)
speed_match = self.pattern_speed.search(text)
if speed_match:
try:
speed_str = speed_match.group(1)
# 尝试解析中文数字
if speed_str.isdigit() or '.' in speed_str:
params.speed = float(speed_str)
else:
# 中文数字(使用缓存)
chinese_speed = self._parse_chinese_number(speed_str)
if chinese_speed is not None:
params.speed = float(chinese_speed)
except (ValueError, AttributeError):
pass
# 提取距离(米,排除速度单位)
distance_match = self.pattern_distance.search(text)
if distance_match:
try:
params.distance = float(distance_match.group(1))
except ValueError:
pass
# 提取时间(秒)
duration_matches = self.pattern_duration.finditer(text) # 查找所有匹配
for duration_match in duration_matches:
try:
duration_str = duration_match.group(1)
duration_unit = duration_match.group(2).lower() if len(duration_match.groups()) > 1 else ''
# 解析数字(支持中文数字)
if duration_str.isdigit() or '.' in duration_str:
duration_value = float(duration_str)
else:
# 中文数字
chinese_duration = self._parse_chinese_number(duration_str)
if chinese_duration is None:
continue
duration_value = float(chinese_duration)
# 转换为秒
if '' in duration_unit or 'min' in duration_unit:
params.duration = duration_value * 60
break # 取第一个匹配
elif '小时' in duration_unit or 'h' in duration_unit:
params.duration = duration_value * 3600
break
else: # 秒
params.duration = duration_value
break
except (ValueError, IndexError, AttributeError):
continue
# 尝试提取中文数字(如"十米"、"五秒"、"二十米"、"三米每秒")
chinese_matches = self.pattern_chinese_number.finditer(text)
for chinese_match in chinese_matches:
try:
chinese_num_str = chinese_match.group(1)
full_match = chinese_match.group(0)
# 解析中文数字(使用缓存)
num_value = self._parse_chinese_number(chinese_num_str)
if num_value is not None:
# 判断单位类型
if '米每秒' in full_match or '米/秒' in full_match or '米每' in full_match:
# 速度单位
if params.speed is None:
params.speed = float(num_value)
elif '' in full_match and '' not in full_match:
# 距离单位(不包含"秒")
if params.distance is None:
params.distance = float(num_value)
elif '' in full_match and '' not in full_match:
# 时间单位(不包含"米")
if params.duration is None:
params.duration = float(num_value)
elif '' in full_match and '' not in full_match:
# 时间单位(分钟)
if params.duration is None:
params.duration = float(num_value) * 60
except (ValueError, IndexError, AttributeError):
continue
return params
def _parse_chinese_number_impl(self, chinese_num: str) -> Optional[int]:
"""
解析中文数字实现(内部方法,不带缓存)
Args:
chinese_num: 中文数字字符串
Returns:
对应的阿拉伯数字,解析失败返回None
"""
if not chinese_num:
return None
# 单个数字
if chinese_num in self.chinese_numbers:
return self.chinese_numbers[chinese_num]
# "十" -> 10
if chinese_num == '' or chinese_num == '':
return 10
# "十一" -> 11, "十二" -> 12, ...
if chinese_num.startswith('') or chinese_num.startswith(''):
rest = chinese_num[1:]
if rest in self.chinese_numbers:
return 10 + self.chinese_numbers[rest]
# "二十" -> 20, "三十" -> 30, ...
if chinese_num.endswith('') or chinese_num.endswith(''):
prefix = chinese_num[:-1]
if prefix in self.chinese_numbers:
return self.chinese_numbers[prefix] * 10
# "二十五" -> 25, "三十五" -> 35, ...
if '' in chinese_num or '' in chinese_num:
parts = chinese_num.replace('', '').split('')
if len(parts) == 2:
tens_part = parts[0] if parts[0] else '' # "十五" -> parts[0]为空
ones_part = parts[1] if parts[1] else ''
tens = self.chinese_numbers.get(tens_part, 1) if tens_part else 1
ones = self.chinese_numbers.get(ones_part, 0) if ones_part else 0
return tens * 10 + ones
return None
    def _parse_chinese_number(self, chinese_num: str) -> Optional[int]:
        """
        Parse a Chinese numeral string (e.g. "十五", "二十") with LRU caching.

        Args:
            chinese_num: Chinese numeral string.
        Returns:
            The corresponding integer, or None if it cannot be parsed.
        """
        return self._parse_chinese_number_cached(chinese_num)
def detect_keyword(self, text: str, words: Optional[List[str]] = None) -> Optional[str]:
"""
检测命令关键词(使用缓存的排序结果)
Args:
text: 待检测文本
words: 分词结果(如果已分词,可传入以提高性能)
Returns:
检测到的命令类型(如"takeoff""forward"等),未检测到返回None
"""
if not self.enable_keyword_detection or not text:
return None
# 如果已分词,优先使用分词结果匹配
if words:
for word in words:
if word in self.keyword_to_command:
return self.keyword_to_command[word]
# 使用缓存的排序结果(按长度降序,优先匹配长关键词)
for keyword in self.sorted_keywords:
if keyword in text:
return self.keyword_to_command[keyword]
return None
    def preprocess(self, text: str) -> PreprocessedText:
        """
        Run the full preprocessing pipeline (LRU-cached wrapper).

        Args:
            text: raw recognized text.
        Returns:
            PreprocessedText with cleaned/normalized text, tokens, and
            extracted parameters.
        """
        return self._preprocess_cached(text)
def _preprocess_impl(self, text: str) -> PreprocessedText:
"""
完整的文本预处理流程实现(内部方法,不带缓存)
Args:
text: 原始文本
Returns:
PreprocessedText对象,包含所有预处理结果
"""
if not text:
return PreprocessedText(
cleaned_text="",
normalized_text="",
words=[],
params=ExtractedParams(),
original_text=text
)
original_text = text
# 1. 清理文本
cleaned_text = self.clean_text(text)
# 2. 纠错
corrected_text = self.correct_text(cleaned_text)
# 3. 繁简转换
normalized_text = self.traditional_to_simplified(corrected_text)
# 4. 分词
words = self.segment_text(normalized_text)
# 5. 提取数字和单位
params = self.extract_numbers(normalized_text)
# 6. 检测关键词
command_keyword = self.detect_keyword(normalized_text, words)
params.command_keyword = command_keyword
return PreprocessedText(
cleaned_text=cleaned_text,
normalized_text=normalized_text,
words=words,
params=params,
original_text=original_text
)
    def preprocess_fast(self, text: str) -> Tuple[str, ExtractedParams]:
        """
        Fast preprocessing (LRU-cached wrapper): skips segmentation.

        Intended for latency-sensitive real-time use.

        Args:
            text: raw recognized text.
        Returns:
            Tuple of (normalized text, extracted parameters).
        """
        return self._preprocess_fast_cached(text)
def _preprocess_fast_impl(self, text: str) -> Tuple[str, ExtractedParams]:
"""
快速预处理实现(内部方法,不带缓存)
Args:
text: 原始文本
Returns:
(规范化文本, 提取的参数)
"""
if not text:
return "", ExtractedParams()
# 1. 清理
cleaned = self.clean_text(text)
# 2. 纠错
corrected = self.correct_text(cleaned)
# 3. 繁简转换
normalized = self.traditional_to_simplified(corrected)
# 4. 提取参数(不进行分词,提高性能)
params = self.extract_numbers(normalized)
# 5. 检测关键词(在完整文本中搜索)
params.command_keyword = self.detect_keyword(normalized, words=None)
return normalized, params
# Module-level singleton (optional; avoids re-initializing jieba/OpenCC per use)
_global_preprocessor: Optional[TextPreprocessor] = None
def get_preprocessor() -> TextPreprocessor:
    """Return the global preprocessor instance, creating it on first use."""
    global _global_preprocessor
    if _global_preprocessor is None:
        _global_preprocessor = TextPreprocessor()
    return _global_preprocessor
if __name__ == "__main__":
    # NOTE(review): this bare import looks module-local; presumably it should
    # be `from voice_drone.core.command import Command` — verify before running.
    from command import Command
    # Manual smoke test of the preprocessing pipeline.
    preprocessor = TextPreprocessor()
    test_cases = [
        "现在起飞往前飞,飞10米,速度为5米每秒",
        "向前飞二十米,速度三米每秒",
        "立刻降落",
        "悬停五秒",
        "向右飛十米",  # traditional-Chinese input test
        "往左飛,速度2.5米/秒,持續10秒",
    ]
    print("=" * 60)
    print("文本预处理器测试")
    print("=" * 60)
    for i, test_text in enumerate(test_cases, 1):
        print(f"\n测试 {i}: {test_text}")
        print("-" * 60)
        # Full pipeline.
        result = preprocessor.preprocess(test_text)
        print(f"原始文本: {result.original_text}")
        print(f"清理后: {result.cleaned_text}")
        print(f"规范化: {result.normalized_text}")
        print(f"分词: {result.words}")
        print(f"提取参数:")
        print(f"  距离: {result.params.distance}")
        print(f"  速度: {result.params.speed} 米/秒")
        print(f"  时间: {result.params.duration}")
        print(f"  命令关键词: {result.params.command_keyword}")
        command = Command.create(result.params.command_keyword, 1, result.params.distance, result.params.speed, result.params.duration)
        print(f"命令: {command.to_dict()}")
        # Fast path (no segmentation).
        fast_text, fast_params = preprocessor.preprocess_fast(test_text)
        print(f"\n快速预处理结果:")
        print(f"  规范化文本: {fast_text}")
        print(f"  命令关键词: {fast_params.command_keyword}")