"""
文本预处理模块 - 高性能实时语音转命令文本处理

本模块对语音识别输出的文本进行清洗、纠错、简繁转换、分词和参数提取,
便于后续的命令意图分析和参数解析。

主要功能:
1. 文本清理:去除杂音字符、特殊字符和多余空格
2. 纠错:修正同音字及常见识别错误
3. 简繁转换:统一文本格式(繁体转简体)
4. 分词:使用 jieba 分词,便于关键词匹配
5. 数字提取:提取距离(米)、速度(米/秒)、时间(秒)
6. 关键词识别:识别命令关键词(起飞、降落、前进等)

性能优化:
- 使用 LRU 缓存常用处理结果(分词、中文数字解析、完整预处理)
- 预编译正则表达式
- 使用单个正则表达式批量替换,减少字符串操作
- 可选依赖(opencc、jieba、pypinyin)缺失时自动降级,跳过对应功能
- 预先计算并缓存关键词排序结果
"""

import dataclasses
import re
import unicodedata
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Optional, Set, Tuple

from voice_drone.core.configuration import KEYWORDS_CONFIG, SYSTEM_TEXT_PREPROCESSOR_CONFIG
from voice_drone.logging_ import get_logger

# NOTE(review): this suppresses *every* warning for the whole process, not only
# those emitted by the optional dependencies imported below. Consider scoping
# it with warnings.catch_warnings() around those imports instead.
warnings.filterwarnings("ignore")

logger = get_logger("text_preprocessor")

# Optional third-party dependencies. Each is imported eagerly when this module
# is imported. If one is missing, its *_AVAILABLE flag is set to False, a
# warning is logged, and the corresponding feature is skipped at runtime.

# opencc: traditional -> simplified Chinese conversion.
try:
    from opencc import OpenCC
except ImportError:
    OPENCC_AVAILABLE = False
    logger.warning("opencc 未安装,将跳过简繁转换功能")
else:
    OPENCC_AVAILABLE = True

# jieba: word segmentation. initialize() loads jieba's dictionary here, at
# import time, rather than on the first segmentation call.
try:
    import jieba
    jieba.initialize()
except ImportError:
    JIEBA_AVAILABLE = False
    logger.warning("jieba 未安装,将跳过分词功能")
else:
    JIEBA_AVAILABLE = True

# pypinyin: pinyin helpers (not used by the code visible in this module).
try:
    from pypinyin import lazy_pinyin, Style
except ImportError:
    PYPINYIN_AVAILABLE = False
    logger.warning("pypinyin 未安装,将跳过拼音相关功能")
else:
    PYPINYIN_AVAILABLE = True

@dataclass
class ExtractedParams:
    """
    Numeric parameters and command information extracted from one utterance.

    Every field defaults to None, meaning "not found in the text".
    """

    # Distance, in metres.
    distance: Optional[float] = None
    # Speed, in metres per second.
    speed: Optional[float] = None
    # Duration, in seconds.
    duration: Optional[float] = None
    # Result of command-keyword detection (the command type the keyword maps to).
    command_keyword: Optional[str] = None

@dataclass
class PreprocessedText:
    """
    Complete result of running the preprocessing pipeline on one input string.
    """

    # Text after special-character removal and whitespace normalisation.
    cleaned_text: str
    # Text after correction and traditional -> simplified conversion.
    normalized_text: str
    # Word segmentation of the normalised text.
    words: List[str]
    # Distance / speed / duration / command keyword extracted from the text.
    params: ExtractedParams
    # The input exactly as it was received.
    original_text: str

class TextPreprocessor:
    """
    High-performance text preprocessor for real-time voice-command input.

    Pipeline used by ``preprocess`` / ``preprocess_fast``:

    1. ``clean_text``: NFKC-normalise, drop characters outside the whitelist,
       collapse whitespace.
    2. ``correct_text``: replace known homophone mistakes.
    3. ``traditional_to_simplified``: OpenCC ``t2s`` (when installed/enabled).
    4. ``segment_text``: jieba segmentation (``preprocess`` only).
    5. ``extract_numbers``: distance (m), speed (m/s), duration (s).
    6. ``detect_keyword``: map a command keyword to its command type.

    Segmentation, Chinese-numeral parsing and both preprocessing entry points
    are memoised in per-instance LRU caches. The public wrappers return copies
    of cached mutable results, so a caller that mutates its result cannot
    corrupt later cache hits.
    """

    # Seconds per duration unit. Keys are the lower-cased unit spellings
    # captured as group 2 of ``pattern_duration`` and the time units captured
    # by ``pattern_chinese_number``.
    _DURATION_UNIT_SECONDS: Dict[str, int] = {
        "秒": 1, "s": 1,
        "分钟": 60, "分": 60, "min": 60,
        "小时": 3600, "时": 3600, "h": 3600,
    }

    def __init__(self,
                 enable_traditional_to_simplified: Optional[bool] = None,
                 enable_segmentation: Optional[bool] = None,
                 enable_correction: Optional[bool] = None,
                 enable_number_extraction: Optional[bool] = None,
                 enable_keyword_detection: Optional[bool] = None,
                 lru_cache_size: Optional[int] = None):
        """
        Initialise the preprocessor.

        Each argument left as None is read from SYSTEM_TEXT_PREPROCESSOR_CONFIG.
        Feature flags default to True and the cache size defaults to 512 when
        the config lacks the key.

        Args:
            enable_traditional_to_simplified: enable OpenCC t2s conversion
                (forced off when opencc is not installed).
            enable_segmentation: enable jieba segmentation (forced off when
                jieba is not installed).
            enable_correction: enable homophone correction.
            enable_number_extraction: enable distance/speed/duration extraction.
            enable_keyword_detection: enable command keyword detection.
            lru_cache_size: maxsize of the segmentation and preprocessing caches.
        """
        config = SYSTEM_TEXT_PREPROCESSOR_CONFIG or {}

        def _setting(explicit, key, default):
            # An explicit argument wins; otherwise fall back to config, then default.
            return explicit if explicit is not None else config.get(key, default)

        self.enable_traditional_to_simplified = _setting(
            enable_traditional_to_simplified, "enable_traditional_to_simplified", True
        ) and OPENCC_AVAILABLE
        self.enable_segmentation = _setting(
            enable_segmentation, "enable_segmentation", True
        ) and JIEBA_AVAILABLE
        self.enable_correction = _setting(enable_correction, "enable_correction", True)
        self.enable_number_extraction = _setting(
            enable_number_extraction, "enable_number_extraction", True
        )
        self.enable_keyword_detection = _setting(
            enable_keyword_detection, "enable_keyword_detection", True
        )
        cache_size = _setting(lru_cache_size, "lru_cache_size", 512)

        # OpenCC can fail to construct (e.g. broken install). Degrade to "no
        # conversion" instead of making the whole preprocessor unusable.
        self.opencc = None
        if self.enable_traditional_to_simplified:
            try:
                self.opencc = OpenCC('t2s')  # traditional -> simplified
            except Exception as e:
                logger.warning(f"OpenCC 初始化失败,将跳过简繁转换功能: {e}")
                self.enable_traditional_to_simplified = False

        self._load_keyword_mapping()
        self._compile_regex_patterns()
        self._load_correction_dict()

        self._cache_size = cache_size

        # Per-instance memoisation. Wrapping bound methods keeps the caches
        # scoped to this instance; note that this creates a self-reference
        # cycle, so the instance is reclaimed by the cyclic GC, not refcounting.
        self._segment_text_cached = lru_cache(maxsize=cache_size)(self._segment_text_impl)
        self._parse_chinese_number_cached = lru_cache(maxsize=128)(self._parse_chinese_number_impl)
        self._preprocess_cached = lru_cache(maxsize=cache_size)(self._preprocess_impl)
        self._preprocess_fast_cached = lru_cache(maxsize=cache_size)(self._preprocess_fast_impl)

        logger.info("文本预处理器初始化完成")
        logger.info(f" 繁简转换: {'启用' if self.enable_traditional_to_simplified else '禁用'}")
        logger.info(f" 分词: {'启用' if self.enable_segmentation else '禁用'}")
        logger.info(f" 纠错: {'启用' if self.enable_correction else '禁用'}")
        logger.info(f" 数字提取: {'启用' if self.enable_number_extraction else '禁用'}")
        logger.info(f" 关键词检测: {'启用' if self.enable_keyword_detection else '禁用'}")

    def _load_keyword_mapping(self):
        """
        Build ``keyword_to_command`` (keyword -> command type) from KEYWORDS_CONFIG.

        KEYWORDS_CONFIG maps each command type to a single keyword string or a
        list/tuple/set of keywords. Empty or non-string keywords are skipped:
        an empty keyword would match every text in ``detect_keyword``'s
        substring scan, and a non-string one would break the length sort below.
        If a keyword is listed under several command types, the last one wins
        and a warning is logged.
        """
        self.keyword_to_command: Dict[str, str] = {}

        for command_type, keywords in (KEYWORDS_CONFIG or {}).items():
            if isinstance(keywords, str):
                keywords = [keywords]
            elif not isinstance(keywords, (list, tuple, set, frozenset)):
                continue
            for keyword in keywords:
                if not isinstance(keyword, str) or not keyword:
                    continue
                previous = self.keyword_to_command.get(keyword)
                if previous is not None and previous != command_type:
                    logger.warning(
                        f"关键词 '{keyword}' 同时映射到 {previous} 和 {command_type},以后者为准"
                    )
                self.keyword_to_command[keyword] = command_type

        # Longest first, so that substring matching prefers the most specific keyword.
        self.sorted_keywords = sorted(self.keyword_to_command, key=len, reverse=True)

        logger.debug(f"加载了 {len(self.keyword_to_command)} 个关键词映射")

    def _compile_regex_patterns(self):
        """Compile every regular expression used by the pipeline once."""
        # Whitelist: CJK ideographs, ASCII letters/digits, whitespace, '.' (decimals)
        # and '/' (needed so "m/s" and "米/秒" survive until unit matching).
        self.pattern_clean_special = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s./]')
        self.pattern_clean_spaces = re.compile(r'\s+')

        # Distance: Arabic number + metre unit, but not when the unit is really
        # the start of a speed ("5米每秒", "5m/s") or of another ASCII word
        # ("5min", "5mps"). Group 1: the number.
        self.pattern_distance = re.compile(
            r'(\d+\.?\d*)\s*(?:公尺|米|met(?:er|re)s?|m)(?![a-z])'
            r'(?!\s*[/每]?\s*(?:秒|s(?![a-z])))',
            re.IGNORECASE
        )

        # Speed: number (Arabic or simple Chinese) + m/s unit, e.g. "三米每秒",
        # "5米/秒", "2.5米每秒", "5m/s". Group 1: the number.
        self.pattern_speed = re.compile(
            r'(?:速度\s*[::]?\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:米\s*[/每]?\s*秒|m\s*/\s*s|mps|MPS)',
            re.IGNORECASE
        )

        # Duration: number + time unit, e.g. "持续10秒", "5分钟", "2h".
        # Group 1: the number; group 2: the unit (looked up, lower-cased, in
        # _DURATION_UNIT_SECONDS). Alternation order matters: "分钟" before "分".
        self.pattern_duration = re.compile(
            r'(?:持续\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(秒|s|分钟|分|min|小时|时|h)',
            re.IGNORECASE
        )

        # Chinese numeral characters and their values (digits and units).
        self.chinese_numbers = {
            '零': 0, '一': 1, '二': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
            '壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5,
            '陆': 6, '柒': 7, '捌': 8, '玖': 9, '拾': 10,
            '百': 100, '千': 1000, '万': 10000
        }

        # Chinese numeral + unit, e.g. "十米", "五秒", "二十米", "三米每秒".
        # Group 1: the numeral; group 2: the unit. The speed unit is listed
        # first so "三米每秒" is not cut short to the distance "三米".
        self.pattern_chinese_number = re.compile(
            r'([零一二三四五六七八九十壹贰叁肆伍陆柒捌玖拾百千万]+)\s*(米\s*[/每]?\s*秒|米|秒|分钟|分|小时)'
        )

    def _load_correction_dict(self):
        """Build the homophone-correction table and one combined regex for it."""
        # Known ASR mistakes for drone commands: (wrong, correct).
        correction_pairs = [
            ('起非', '起飞'),
            ('降洛', '降落'),
            ('悬廷', '悬停'),
            ('停只', '停止'),
        ]

        if correction_pairs:
            # Longest patterns first so the alternation prefers them.
            sorted_pairs = sorted(correction_pairs, key=lambda pair: len(pair[0]), reverse=True)
            self.correction_replacements = dict(sorted_pairs)
            self.correction_pattern = re.compile(
                '|'.join(re.escape(wrong) for wrong, _ in sorted_pairs)
            )
        else:
            self.correction_pattern = None
            self.correction_replacements = {}

        logger.debug(f"加载了 {len(correction_pairs)} 个纠错规则")

    def clean_text(self, text: str) -> str:
        """
        Normalise and clean raw text.

        Applies NFKC normalisation first, so full-width digits, letters and
        punctuation (common in Chinese ASR output, e.g. "5米") become ASCII.
        Then removes every character outside the whitelist, collapses runs of
        whitespace to one space, and strips both ends.

        Args:
            text: raw text.

        Returns:
            The cleaned text ("" for empty input).
        """
        if not text:
            return ""
        text = unicodedata.normalize("NFKC", text)
        text = self.pattern_clean_special.sub('', text)
        text = self.pattern_clean_spaces.sub(' ', text)
        return text.strip()

    def correct_text(self, text: str) -> str:
        """
        Replace known homophone mistakes in one regex pass.

        Args:
            text: text to correct.

        Returns:
            The corrected text, or ``text`` unchanged when correction is
            disabled, the input is empty, or there are no rules.
        """
        if not self.enable_correction or not text or not self.correction_pattern:
            return text

        def replacer(match):
            matched = match.group(0)
            return self.correction_replacements.get(matched, matched)

        return self.correction_pattern.sub(replacer, text)

    def traditional_to_simplified(self, text: str) -> str:
        """
        Convert traditional Chinese to simplified via OpenCC.

        Args:
            text: text to convert.

        Returns:
            The converted text. ``text`` is returned unchanged when conversion
            is disabled, the input is empty, or OpenCC raises.
        """
        if not self.enable_traditional_to_simplified or not self.opencc or not text:
            return text

        try:
            return self.opencc.convert(text)
        except Exception as e:
            logger.warning(f"繁简转换失败: {e}")
            return text

    def _segment_text_impl(self, text: str) -> List[str]:
        """
        Uncached segmentation.

        Returns ``[text]`` (or ``[]`` for empty text) when segmentation is
        disabled or jieba fails.
        """
        if not self.enable_segmentation or not text:
            return [text] if text else []

        try:
            return [w.strip() for w in jieba.cut(text, cut_all=False) if w.strip()]
        except Exception as e:
            logger.warning(f"分词失败: {e}")
            return [text]

    def segment_text(self, text: str) -> List[str]:
        """
        Segment ``text`` into words (cached).

        Args:
            text: text to segment.

        Returns:
            A fresh list of non-empty, stripped words. It is a copy, so
            mutating it does not affect the cache.
        """
        return list(self._segment_text_cached(text))

    def _to_number(self, raw: str) -> Optional[float]:
        """Convert an Arabic ("2.5") or Chinese ("二十") numeral to float; None if unparseable."""
        try:
            return float(raw)
        except ValueError:
            value = self._parse_chinese_number(raw)
            return None if value is None else float(value)

    def extract_numbers(self, text: str) -> "ExtractedParams":
        """
        Extract distance (m), speed (m/s) and duration (s) from ``text``.

        Precedence: the Arabic/simple-numeral patterns run first (speed, then
        distance, then duration; the first parseable match of each wins). Then
        Chinese-numeral matches fill only the fields still None.

        Args:
            text: text to scan (normally already normalised).

        Returns:
            An ExtractedParams with the fields that were found; the others stay None.
        """
        params = ExtractedParams()

        if not self.enable_number_extraction or not text:
            return params

        # Speed first, so "5米每秒" is never reported as a distance.
        speed_match = self.pattern_speed.search(text)
        if speed_match:
            params.speed = self._to_number(speed_match.group(1))

        distance_match = self.pattern_distance.search(text)
        if distance_match:
            params.distance = self._to_number(distance_match.group(1))

        for duration_match in self.pattern_duration.finditer(text):
            value = self._to_number(duration_match.group(1))
            if value is None:
                continue
            unit = duration_match.group(2).lower()
            params.duration = value * self._DURATION_UNIT_SECONDS.get(unit, 1)
            break

        # Chinese numerals with units ("十米", "五秒", "一百米", "三米每秒").
        for chinese_match in self.pattern_chinese_number.finditer(text):
            value = self._parse_chinese_number(chinese_match.group(1))
            if value is None:
                continue
            unit = chinese_match.group(2)
            if unit == '米':
                if params.distance is None:
                    params.distance = float(value)
            elif unit.startswith('米'):
                # "米每秒" / "米/秒" / "米秒", possibly with inner spaces.
                if params.speed is None:
                    params.speed = float(value)
            else:
                seconds = self._DURATION_UNIT_SECONDS.get(unit)
                if seconds is not None and params.duration is None:
                    params.duration = float(value) * seconds

        return params

    def _parse_chinese_number_impl(self, chinese_num: str) -> Optional[int]:
        """
        Parse a Chinese numeral (uncached).

        Supports the digits 零-九 / 壹-玖 and the units 十/拾, 百, 千 and a
        single 万 group. Examples: "十" -> 10, "二十五" -> 25, "一百零五" -> 105,
        "十万" -> 100000. A bare unit counts as one of that unit ("百" -> 100).
        零 acts as a placeholder.

        Args:
            chinese_num: numeral string.

        Returns:
            The integer value, or None for empty or malformed input: unknown
            characters, digit runs such as "二五", units that do not descend
            such as "十十", or more than one 万.
        """
        if not chinese_num:
            return None

        total = 0         # value of the completed 万 group
        section = 0       # value accumulated below 万
        digit = None      # digit waiting for its unit
        last_unit = None  # previous small unit in this section

        for ch in chinese_num:
            value = self.chinese_numbers.get(ch)
            if value is None:
                return None
            if value == 0:
                # 零 is only a placeholder ("一百零五"); "二零" style digit runs are rejected.
                if digit is not None:
                    return None
                continue
            if value < 10:
                if digit is not None:
                    return None
                digit = value
            elif value < 10000:
                # 十 / 百 / 千 must appear in descending order within a section.
                if last_unit is not None and value >= last_unit:
                    return None
                section += (1 if digit is None else digit) * value
                digit = None
                last_unit = value
            else:
                # 万: close the current section and scale it.
                if total:
                    return None
                section += digit or 0
                total = (section or 1) * value
                section, digit, last_unit = 0, None, None

        return total + section + (digit or 0)

    def _parse_chinese_number(self, chinese_num: str) -> Optional[int]:
        """
        Parse a Chinese numeral (cached).

        Args:
            chinese_num: numeral string.

        Returns:
            The integer value, or None when it cannot be parsed.
        """
        return self._parse_chinese_number_cached(chinese_num)

    def detect_keyword(self, text: str, words: Optional[List[str]] = None) -> Optional[str]:
        """
        Detect a command keyword and return the command type it maps to.

        When ``words`` is given, the first word that is itself a keyword wins,
        scanning in sentence order. Otherwise, or if no word matches, the
        longest keyword that occurs as a substring of ``text`` wins.

        Args:
            text: text to search.
            words: optional segmentation of ``text``.

        Returns:
            The command type, or None when nothing matches or detection is disabled.
        """
        if not self.enable_keyword_detection or not text:
            return None

        if words:
            for word in words:
                command = self.keyword_to_command.get(word)
                if command is not None:
                    return command

        for keyword in self.sorted_keywords:
            if keyword in text:
                return self.keyword_to_command[keyword]

        return None

    def preprocess(self, text: str) -> "PreprocessedText":
        """
        Run the full pipeline (cached).

        Args:
            text: raw text.

        Returns:
            A PreprocessedText. Its ``words`` list and ``params`` object are
            fresh copies of the cached result, so callers may mutate them.
        """
        cached = self._preprocess_cached(text)
        return dataclasses.replace(
            cached,
            words=list(cached.words),
            params=dataclasses.replace(cached.params),
        )

    def _preprocess_impl(self, text: str) -> "PreprocessedText":
        """Uncached full pipeline: clean -> correct -> t2s -> segment -> extract -> detect."""
        if not text:
            return PreprocessedText(
                cleaned_text="",
                normalized_text="",
                words=[],
                params=ExtractedParams(),
                original_text=text
            )

        cleaned_text = self.clean_text(text)
        corrected_text = self.correct_text(cleaned_text)
        normalized_text = self.traditional_to_simplified(corrected_text)
        words = self.segment_text(normalized_text)
        params = self.extract_numbers(normalized_text)
        params.command_keyword = self.detect_keyword(normalized_text, words)

        return PreprocessedText(
            cleaned_text=cleaned_text,
            normalized_text=normalized_text,
            words=words,
            params=params,
            original_text=text
        )

    def preprocess_fast(self, text: str) -> Tuple[str, "ExtractedParams"]:
        """
        Run the pipeline without segmentation (cached); intended for real-time use.

        Args:
            text: raw text.

        Returns:
            ``(normalized_text, params)``. ``params`` is a fresh copy of the
            cached object, so callers may mutate it.
        """
        normalized, params = self._preprocess_fast_cached(text)
        return normalized, dataclasses.replace(params)

    def _preprocess_fast_impl(self, text: str) -> Tuple[str, "ExtractedParams"]:
        """Uncached fast pipeline: clean -> correct -> t2s -> extract -> detect (substring only)."""
        if not text:
            return "", ExtractedParams()

        cleaned = self.clean_text(text)
        corrected = self.correct_text(cleaned)
        normalized = self.traditional_to_simplified(corrected)
        params = self.extract_numbers(normalized)
        params.command_keyword = self.detect_keyword(normalized, words=None)

        return normalized, params

# Shared module-level instance, created lazily by get_preprocessor().
_global_preprocessor: Optional[TextPreprocessor] = None


def get_preprocessor() -> TextPreprocessor:
    """
    Return the shared TextPreprocessor, constructing it on the first call.

    The instance is built with default (config-driven) settings. No lock is
    taken, so two threads racing on the very first call could each construct one.
    """
    global _global_preprocessor
    shared = _global_preprocessor
    if shared is None:
        shared = TextPreprocessor()
        _global_preprocessor = shared
    return shared

if __name__ == "__main__":
    # NOTE(review): this import presumably only resolves when the script is run
    # from the directory containing command.py. Verify against the package layout.
    from command import Command

    # Manual smoke test: run both pipelines over a few sample utterances and
    # print every intermediate result.
    preprocessor = TextPreprocessor()

    samples = [
        "现在起飞往前飞,飞10米,速度为5米每秒",
        "向前飞二十米,速度三米每秒",
        "立刻降落",
        "悬停五秒",
        "向右飛十米",  # traditional-character input
        "往左飛,速度2.5米/秒,持續10秒",
    ]

    banner = "=" * 60
    print(banner)
    print("文本预处理器测试")
    print(banner)

    for index, sample in enumerate(samples, start=1):
        print(f"\n测试 {index}: {sample}")
        print("-" * 60)

        # Full pipeline.
        result = preprocessor.preprocess(sample)
        extracted = result.params
        report_lines = (
            f"原始文本: {result.original_text}",
            f"清理后: {result.cleaned_text}",
            f"规范化: {result.normalized_text}",
            f"分词: {result.words}",
            "提取参数:",
            f" 距离: {extracted.distance} 米",
            f" 速度: {extracted.speed} 米/秒",
            f" 时间: {extracted.duration} 秒",
            f" 命令关键词: {extracted.command_keyword}",
        )
        for line in report_lines:
            print(line)

        command = Command.create(
            extracted.command_keyword,
            1,
            extracted.distance,
            extracted.speed,
            extracted.duration,
        )
        print(f"命令: {command.to_dict()}")

        # Fast pipeline (no segmentation).
        fast_text, fast_params = preprocessor.preprocess_fast(sample)
        print("\n快速预处理结果:")
        print(f" 规范化文本: {fast_text}")
        print(f" 命令关键词: {fast_params.command_keyword}")