"""
语音识别(Speech-to-Text)类 - 纯 ONNX Runtime 极致性能推理

针对 RK3588 等 ARM 设备进行了深度优化,完全移除 FunASR 依赖。
前处理(fbank + CMVN + LFR)与解码均手写实现。
"""

import platform
import os
import multiprocessing

import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional

import onnx
import onnxruntime as ort
from voice_drone.logging_ import get_logger
from voice_drone.tools.wrapper import time_cost
import scipy.special
from voice_drone.core.configuration import SYSTEM_STT_CONFIG, SYSTEM_AUDIO_CONFIG
||
# voice_drone/core/stt.py -> 工程根(含 voice_drone_assistant 与本仓库根两种布局)
|
||
_STT_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||
|
||
|
||
def _stt_path_candidates(path: Path) -> List[Path]:
|
||
"""相对配置路径的候选绝对路径:优先工程目录,其次嵌套在上一级仓库时的 src/models/。"""
|
||
if path.is_absolute():
|
||
return [path]
|
||
out: List[Path] = [_STT_PROJECT_ROOT / path]
|
||
if path.parts and path.parts[0] == "models":
|
||
out.append(_STT_PROJECT_ROOT.parent / "src" / path)
|
||
return out
|
||
|
||
|
||
class STT:
    """
    Speech-to-Text engine running purely on ONNX Runtime.

    Deeply optimized for ARM devices such as the RK3588; the FunASR
    dependency is fully removed.  The audio front-end (fbank + CMVN + LFR)
    and greedy CTC decoding are implemented by hand.
    """

    def __init__(self):
        """
        Initialize the STT model.

        Reads SYSTEM_STT_CONFIG, resolves the model / warmup-file paths,
        applies ARM thread tuning when applicable, loads the ONNX session,
        and optionally warms the model up with a sample recording.
        """
        stt_conf = SYSTEM_STT_CONFIG
        self.logger = get_logger("stt.onnx")

        # Model locations from the configuration.
        self.model_dir = stt_conf.get("model_dir")
        self.model_path = stt_conf.get("model_path")
        self.prefer_int8 = stt_conf.get("prefer_int8", True)
        _wf = stt_conf.get("warmup_file")
        # Resolved absolute warmup-file path, or None when unavailable.
        self.warmup_file: Optional[str] = None
        if _wf:
            wf_path = Path(_wf)
            if wf_path.is_absolute() and wf_path.is_file():
                self.warmup_file = str(wf_path)
            else:
                # Relative path: probe the known project layouts.
                for c in _stt_path_candidates(wf_path):
                    if c.is_file():
                        self.warmup_file = str(c)
                        break

        # Audio front-end parameters (coerced to the correct numeric types).
        self.sample_rate = int(stt_conf.get("sample_rate", SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)))
        self.n_mels = int(stt_conf.get("n_mels", 80))
        self.frame_length_ms = float(stt_conf.get("frame_length_ms", 25))
        self.frame_shift_ms = float(stt_conf.get("frame_shift_ms", 10))
        self.log_eps = float(stt_conf.get("log_eps", 1e-10))

        # ARM optimization settings.
        arm_conf = stt_conf.get("arm_optimization", {})
        self.arm_enabled = arm_conf.get("enabled", True)
        self.arm_max_threads = arm_conf.get("max_threads", 4)

        # CTC decoding settings.
        ctc_conf = stt_conf.get("ctc_decode", {})
        self.blank_id = ctc_conf.get("blank_id", 0)

        # Language / text-normalization token IDs (config defaults; ONNX
        # metadata values, when present, take precedence in _load_onnx_model).
        lang_conf = stt_conf.get("language", {})
        text_norm_conf = stt_conf.get("text_norm", {})
        self.lang_zh_default = lang_conf.get("zh_id", 3)
        self.with_itn_default = text_norm_conf.get("with_itn_id", 14)
        self.without_itn_default = text_norm_conf.get("without_itn_id", 15)

        # Post-processing: special tokens stripped from the decoded text.
        postprocess_conf = stt_conf.get("postprocess", {})
        self.special_tokens = postprocess_conf.get("special_tokens", [
            "<|zh|>", "<|NEUTRAL|>", "<|Speech|>", "<|woitn|>", "<|withitn|>"
        ])

        # Detect ARM / RK3588 hardware.
        # NOTE(review): the RK3588 check also matches any board exposing
        # /proc/device-tree/compatible, not only RK3588 — confirm intent.
        ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
        RK3588 = 'rk3588' in platform.platform().lower() or os.path.exists('/proc/device-tree/compatible')

        # Thread tuning for ARM devices.
        if self.arm_enabled and (ARM or RK3588):
            cpu_count = multiprocessing.cpu_count()
            optimal_threads = min(self.arm_max_threads, cpu_count)

            # Pin OpenMP / MKL thread counts and affinity for stable latency.
            os.environ['OMP_NUM_THREADS'] = str(optimal_threads)
            os.environ['MKL_NUM_THREADS'] = str(optimal_threads)
            os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0'
            os.environ['OMP_DYNAMIC'] = 'FALSE'
            os.environ['MKL_DYNAMIC'] = 'FALSE'

            self.logger.info("ARM/RK3588 优化已启用")
            self.logger.info(f" CPU 核心数: {cpu_count}")
            self.logger.info(f" 优化线程数: {optimal_threads}")

        # Resolve the ONNX model file to load.
        onnx_model_path = self._resolve_model_path()

        # Remember the model directory (tokens.txt lives next to the model).
        self.onnx_model_dir = onnx_model_path.parent

        self.logger.info(f"加载 ONNX 模型: {onnx_model_path}")
        self._load_onnx_model(str(onnx_model_path))

        # Warm the model up so the first real request is fast.
        if self.warmup_file and os.path.exists(self.warmup_file):
            try:
                self.logger.info(f"正在预热模型(使用: {self.warmup_file})...")
                _ = self.invoke(self.warmup_file)
                self.logger.info("模型预热完成")
            except Exception as e:
                self.logger.warning(f"预热失败(可忽略): {e}")
        elif self.warmup_file:
            self.logger.warning(f"预热文件不存在: {self.warmup_file},跳过预热步骤")

    def _resolve_existing_model_file(self, raw: Optional[str]) -> Optional[Path]:
        """Return the first existing file among the candidates for *raw*, or None."""
        if not raw:
            return None
        p = Path(raw)
        for c in _stt_path_candidates(p):
            if c.is_file():
                return c
        return None

    def _resolve_existing_model_dir(self, raw: Optional[str]) -> Optional[Path]:
        """Return the first existing directory among the candidates for *raw*, or None."""
        if not raw:
            return None
        p = Path(raw)
        for c in _stt_path_candidates(p):
            if c.is_dir():
                return c
        return None

    def _resolve_model_path(self) -> Path:
        """
        Resolve the ONNX model file to load.

        Precedence: an explicit, existing ``model_path`` first; otherwise
        inside ``model_dir`` the INT8-quantized model (when ``prefer_int8``)
        before the plain float model.

        Returns:
            Path of the chosen model file.

        Raises:
            ValueError: neither model_path nor model_dir is configured.
            FileNotFoundError: the directory or a model file cannot be found.
        """
        if self.model_path:
            hit = self._resolve_existing_model_file(self.model_path)
            if hit is not None:
                return hit

        if not self.model_dir:
            raise ValueError("配置中必须指定 model_path 或 model_dir")

        model_dir = self._resolve_existing_model_dir(self.model_dir)
        if model_dir is None:
            tried = ", ".join(str(x) for x in _stt_path_candidates(Path(self.model_dir)))
            raise FileNotFoundError(
                f"ONNX 模型目录不存在。config model_dir={self.model_dir!r},已尝试: {tried}。"
                f"请将 SenseVoice 放入 {_STT_PROJECT_ROOT / 'models'},或 ln -s ../src/models "
                f"{_STT_PROJECT_ROOT / 'models'}(见 models/README.txt)。"
            )

        # Prefer the INT8-quantized model when enabled.
        if self.prefer_int8:
            int8_path = model_dir / "model.int8.onnx"
            if int8_path.exists():
                return int8_path

        # Fall back to the plain float model.
        onnx_path = model_dir / "model.onnx"
        if onnx_path.exists():
            return onnx_path

        raise FileNotFoundError(f"ONNX 模型文件不存在: 在 {model_dir} 中未找到 model.int8.onnx 或 model.onnx")

    def _load_onnx_model(self, onnx_model_path: str):
        """
        Create the ONNX Runtime session and parse the model metadata.

        Populates the LFR parameters, vocabulary size, CMVN statistics and
        the language / ITN token IDs from the model's ``metadata_props``.
        """
        # Session options.
        sess_options = ort.SessionOptions()

        # ARM thread tuning (mirrors the detection done in __init__).
        ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
        if self.arm_enabled and ARM:
            cpu_count = multiprocessing.cpu_count()
            optimal_threads = min(self.arm_max_threads, cpu_count)
            sess_options.intra_op_num_threads = optimal_threads
            sess_options.inter_op_num_threads = optimal_threads

        # Enable all graph optimizations for best performance.
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Inference session (CPU only).
        self.onnx_session = ort.InferenceSession(
            onnx_model_path,
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )

        # Read the model's metadata properties.
        onnx_model = onnx.load(onnx_model_path)
        self.model_metadata = {prop.key: prop.value for prop in onnx_model.metadata_props}

        # LFR / vocabulary parameters.
        self.lfr_window_size = int(self.model_metadata.get('lfr_window_size', 7))
        self.lfr_window_shift = int(self.model_metadata.get('lfr_window_shift', 6))
        self.vocab_size = int(self.model_metadata.get('vocab_size', 25055))

        # CMVN statistics (comma-separated floats in the metadata).
        neg_mean_str = self.model_metadata.get('neg_mean', '')
        inv_stddev_str = self.model_metadata.get('inv_stddev', '')
        self.neg_mean = np.array([float(x) for x in neg_mean_str.split(',')]) if neg_mean_str else None
        self.inv_stddev = np.array([float(x) for x in inv_stddev_str.split(',')]) if inv_stddev_str else None

        # Language / text-normalization IDs (metadata first, config defaults otherwise).
        self.lang_zh = int(self.model_metadata.get('lang_zh', self.lang_zh_default))
        self.with_itn = int(self.model_metadata.get('with_itn', self.with_itn_default))
        self.without_itn = int(self.model_metadata.get('without_itn', self.without_itn_default))

        self.logger.info("ONNX 模型加载完成")
        self.logger.info(f" LFR窗口大小: {self.lfr_window_size}")
        self.logger.info(f" LFR窗口偏移: {self.lfr_window_shift}")
        self.logger.info(f" 词汇表大小: {self.vocab_size}")

    def _load_tokens(self):
        """
        Load the token-id -> token mapping from tokens.txt next to the model.

        Returns True on success, False when tokens.txt is missing.

        NOTE(review): on failure ``self.tokens`` is left unset, so a later
        ``token_id in self.tokens`` in _ctc_decode raises AttributeError —
        confirm tokens.txt is guaranteed to ship with the model.
        """
        tokens_file = self.onnx_model_dir / "tokens.txt"
        if tokens_file.exists():
            self.tokens = {}
            with open(tokens_file, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 2:
                        # Format per line: token ... id (last column is the id).
                        token = parts[0]
                        token_id = int(parts[-1])
                        self.tokens[token_id] = token
            return True
        return False

    def _preprocess_audio_array(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> tuple:
        """
        Preprocess an audio array into the ONNX model's input features
        (pure numpy/librosa implementation; suitable for real-time streams).

        Pipeline:
            1. Accept a 16 kHz mono numpy array (int16 or float32).
            2. Compute an 80-dim log-mel fbank.
            3. Apply CMVN (neg_mean / inv_stddev from the ONNX metadata).
            4. Apply LFR stacking (lfr_m, lfr_n), yielding 560-dim features.

        Args:
            audio_array: audio samples (numpy array, int16 or float32).
            sample_rate: sampling rate; the configured rate is used when None.

        Returns:
            (features, lengths): float32 features of shape (1, T_lfr, 560)
            and an int32 length array of shape (1,).

        Raises:
            ValueError: the input array is empty.
        """
        import librosa

        sr = sample_rate if sample_rate is not None else self.sample_rate

        # 1. Convert int16 PCM to float32 in [-1, 1).
        if audio_array.dtype == np.int16:
            audio = audio_array.astype(np.float32) / 32768.0
        else:
            audio = audio_array.astype(np.float32)

        # Downmix to mono if necessary.
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)

        if audio.size == 0:
            raise ValueError("音频数组为空")

        # 2. fbank feature extraction.
        n_fft = int(self.frame_length_ms / 1000.0 * sr)
        hop_length = int(self.frame_shift_ms / 1000.0 * sr)

        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=self.n_mels,
            window="hann",
            center=True,
            power=1.0,  # linear energy
        )

        # log-mel, clipped below by log_eps to avoid log(0).
        log_mel = np.log(np.maximum(mel_spec, self.log_eps)).T  # (T, n_mels)

        # 3. CMVN using the ONNX metadata's neg_mean / inv_stddev.
        if self.neg_mean is not None and self.inv_stddev is not None:
            if self.neg_mean.shape[0] == log_mel.shape[1]:
                log_mel = (log_mel + self.neg_mean) * self.inv_stddev

        # 4. LFR: stack lfr_window_size frames with stride lfr_window_shift.
        T, D = log_mel.shape
        m = self.lfr_window_size
        n = self.lfr_window_shift

        if T < m:
            # Too few frames: pad with copies of the last frame up to m.
            pad = np.tile(log_mel[-1], (m - T, 1))
            log_mel = np.vstack([log_mel, pad])
            T = m

        # Number of frames after LFR.
        T_lfr = 1 + (T - m) // n
        lfr_feats = []
        for i in range(T_lfr):
            start = i * n
            end = start + m
            chunk = log_mel[start:end, :]  # (m, D)
            lfr_feats.append(chunk.reshape(-1))  # flatten to 560 dims

        lfr_feats = np.stack(lfr_feats, axis=0)  # (T_lfr, m*D=560)

        # Add the batch dimension: (1, T_lfr, 560).
        lfr_feats = lfr_feats[np.newaxis, :, :].astype(np.float32)
        lengths = np.array([lfr_feats.shape[1]], dtype=np.int32)

        return lfr_feats, lengths

    def _preprocess_audio(self, audio_path: str) -> tuple:
        """
        Preprocess an audio file into ONNX model inputs (pure numpy).

        Pipeline:
            1. Load a 16 kHz mono wav.
            2. Compute the 80-dim log-mel fbank.
            3. Apply CMVN (neg_mean / inv_stddev from the ONNX metadata).
            4. Apply LFR (lfr_m, lfr_n) stacking, yielding 560-dim features.
        (Steps 2-4 are delegated to _preprocess_audio_array.)

        Raises:
            ValueError: the decoded audio is empty.
        """
        import librosa

        # 1. Load the audio, resampled to the configured rate, mono.
        audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
        if audio.size == 0:
            raise ValueError(f"音频为空: {audio_path}")

        # Reuse the array-based pipeline.
        return self._preprocess_audio_array(audio, sr)

    def _ctc_decode(self, logits: np.ndarray, length: np.ndarray) -> str:
        """
        Greedy CTC decoding: convert logits into text.

        Args:
            logits: CTC logits of shape (N, T, vocab_size); only batch 0
                is decoded.
            length: valid sequence lengths, shape (N,).

        Returns:
            The decoded, post-processed text.
        """
        # Lazily load the token table on first use.
        if not hasattr(self, 'tokens') or len(self.tokens) == 0:
            self._load_tokens()

        # Greedy CTC decode:
        # softmax over the valid frames of batch 0.
        probs = scipy.special.softmax(logits[0][:length[0]], axis=-1)

        # Best token per time step.
        token_ids = np.argmax(probs, axis=-1)

        # CTC collapse: drop blanks and consecutive repeats.
        prev_token = -1
        decoded_tokens = []

        for token_id in token_ids:
            if token_id != self.blank_id and token_id != prev_token:
                decoded_tokens.append(token_id)
            prev_token = token_id

        # Map token IDs back to text.
        text_parts = []
        for token_id in decoded_tokens:
            if token_id in self.tokens:
                token = self.tokens[token_id]
                # SentencePiece word-boundary marker.
                if token.startswith('▁'):
                    if text_parts:  # not the first token: insert a space
                        text_parts.append(' ')
                    text_parts.append(token[1:])
                elif not token.startswith('<|'):  # skip special markers
                    text_parts.append(token)

        text = ''.join(text_parts)

        # Post-processing: strip any remaining special tokens.
        for special in self.special_tokens:
            text = text.replace(special, '')

        return text.strip()

    @time_cost("STT-语音识别推理耗时")
    def invoke(self, audio_path: str) -> List[Dict[str, Any]]:
        """
        Run speech recognition on an audio file.

        Args:
            audio_path: path to the audio file.

        Returns:
            Recognition results in the form [{"text": "recognized text"}].
        """
        # Extract features.
        features, features_length = self._preprocess_audio(audio_path)

        # Run inference.
        text = self._inference(features, features_length)

        return [{"text": text}]

    def invoke_numpy(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> str:
        """
        Run speech recognition on an in-memory numpy array (real-time path).

        Args:
            audio_array: audio samples (numpy array, int16 or float32).
            sample_rate: sampling rate; the configured rate is used when None.

        Returns:
            The recognized text.
        """
        # Extract features from the array.
        features, features_length = self._preprocess_audio_array(audio_array, sample_rate)

        # Run inference.
        text = self._inference(features, features_length)

        return text

    def _inference(self, features: np.ndarray, features_length: np.ndarray) -> str:
        """
        Run the ONNX forward pass and decode (internal helper).

        Args:
            features: input features, shape (N, T, 560).
            features_length: per-item feature lengths.

        Returns:
            The recognized text.
        """
        # Prepare the ONNX model inputs.
        N, T, C = features.shape

        # Language ID (Chinese).
        language = np.array([self.lang_zh], dtype=np.int32)

        # Text normalization (inverse text normalization enabled).
        text_norm = np.array([self.with_itn], dtype=np.int32)

        # ONNX inference.
        inputs = {
            'x': features.astype(np.float32),
            'x_length': features_length.astype(np.int32),
            'language': language,
            'text_norm': text_norm
        }

        outputs = self.onnx_session.run(None, inputs)
        logits = outputs[0]  # shape: (N, T, vocab_size)

        # CTC decode.
        text = self._ctc_decode(logits, features_length)

        return text
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: run recognition repeatedly on a fixed sample so the
    # post-warmup latency can be observed via the @time_cost log output.
    # (Fixed defects: removed the redundant `import os` that shadowed the
    # module-level import, and dropped the dead commented-out call.)
    stt = STT()

    for i in range(10):
        result = stt.invoke("/home/lktx/projects/audio_controll_drone_without_llm/test/测试音频.wav")
        print(f"第{i+1}次识别结果: {result}")