2026-04-14 09:54:26 +08:00

494 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
语音识别(Speech-to-Text)类 - 纯 ONNX Runtime 极致性能推理
针对 RK3588 等 ARM 设备进行了深度优化,完全移除 FunASR 依赖。
前处理(fbank + CMVN + LFR)与解码均手写实现。
"""
import platform
import os
import multiprocessing
import numpy as np
from pathlib import Path
from typing import List, Dict, Any, Optional
import onnx
import onnxruntime as ort
from voice_drone.logging_ import get_logger
from voice_drone.tools.wrapper import time_cost
import scipy.special
from voice_drone.core.configuration import SYSTEM_STT_CONFIG, SYSTEM_AUDIO_CONFIG
# voice_drone/core/stt.py -> project root (two layouts are supported: the
# nested voice_drone_assistant checkout and this repository's own root)
_STT_PROJECT_ROOT = Path(__file__).resolve().parents[2]
def _stt_path_candidates(path: Path) -> List[Path]:
    """Candidate absolute paths for a (possibly relative) config path.

    An absolute path is returned unchanged. A relative path is first
    resolved against the project root; when it is rooted at ``models``,
    the sibling-repo layout ``<parent>/src/models`` is added as a fallback.
    """
    if path.is_absolute():
        return [path]
    candidates: List[Path] = [_STT_PROJECT_ROOT / path]
    first_part = path.parts[0] if path.parts else None
    if first_part == "models":
        candidates.append(_STT_PROJECT_ROOT.parent / "src" / path)
    return candidates
class STT:
    """
    Speech-to-Text class.

    Pure ONNX Runtime inference with hand-written pre-processing
    (fbank + CMVN + LFR) and greedy CTC decoding; tuned for ARM
    devices such as the RK3588. FunASR is not required.
    """
    def __init__(self):
        """
        Initialize the STT model.

        Reads parameters from SYSTEM_STT_CONFIG, applies ARM thread
        tuning (env vars + thread counts), resolves and loads the ONNX
        model, then optionally warms it up with a sample file.
        """
        stt_conf = SYSTEM_STT_CONFIG
        self.logger = get_logger("stt.onnx")
        # --- model location from config ---
        self.model_dir = stt_conf.get("model_dir")
        self.model_path = stt_conf.get("model_path")
        self.prefer_int8 = stt_conf.get("prefer_int8", True)
        _wf = stt_conf.get("warmup_file")
        self.warmup_file: Optional[str] = None
        if _wf:
            wf_path = Path(_wf)
            if wf_path.is_absolute() and wf_path.is_file():
                self.warmup_file = str(wf_path)
            else:
                # Relative warmup path: probe the project-layout candidates.
                for c in _stt_path_candidates(wf_path):
                    if c.is_file():
                        self.warmup_file = str(c)
                        break
        # --- audio pre-processing parameters (coerced to numeric types) ---
        self.sample_rate = int(stt_conf.get("sample_rate", SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)))
        self.n_mels = int(stt_conf.get("n_mels", 80))
        self.frame_length_ms = float(stt_conf.get("frame_length_ms", 25))
        self.frame_shift_ms = float(stt_conf.get("frame_shift_ms", 10))
        self.log_eps = float(stt_conf.get("log_eps", 1e-10))
        # --- ARM optimization config ---
        arm_conf = stt_conf.get("arm_optimization", {})
        self.arm_enabled = arm_conf.get("enabled", True)
        self.arm_max_threads = arm_conf.get("max_threads", 4)
        # --- CTC decoding config ---
        ctc_conf = stt_conf.get("ctc_decode", {})
        self.blank_id = ctc_conf.get("blank_id", 0)
        # --- language / text-normalization defaults ---
        # (may be overridden by ONNX model metadata in _load_onnx_model)
        lang_conf = stt_conf.get("language", {})
        text_norm_conf = stt_conf.get("text_norm", {})
        self.lang_zh_default = lang_conf.get("zh_id", 3)
        self.with_itn_default = text_norm_conf.get("with_itn_id", 14)
        self.without_itn_default = text_norm_conf.get("without_itn_id", 15)
        # --- post-processing config ---
        postprocess_conf = stt_conf.get("postprocess", {})
        self.special_tokens = postprocess_conf.get("special_tokens", [
            "<|zh|>", "<|NEUTRAL|>", "<|Speech|>", "<|woitn|>", "<|withitn|>"
        ])
        # Detect ARM / RK3588 hardware.
        # NOTE(review): the RK3588 test is very loose — /proc/device-tree/compatible
        # exists on essentially any device-tree Linux system, not only RK3588
        # boards, so RK3588 may be True on other hardware; confirm intent.
        ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
        RK3588 = 'rk3588' in platform.platform().lower() or os.path.exists('/proc/device-tree/compatible')
        # Thread tuning for ARM devices. The env vars are set here so they
        # are in place before the math libraries spin up thread pools.
        if self.arm_enabled and (ARM or RK3588):
            cpu_count = multiprocessing.cpu_count()
            optimal_threads = min(self.arm_max_threads, cpu_count)
            # Pin OpenMP / MKL thread counts and disable dynamic adjustment.
            os.environ['OMP_NUM_THREADS'] = str(optimal_threads)
            os.environ['MKL_NUM_THREADS'] = str(optimal_threads)
            os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0'
            os.environ['OMP_DYNAMIC'] = 'FALSE'
            os.environ['MKL_DYNAMIC'] = 'FALSE'
            self.logger.info("ARM/RK3588 优化已启用")
            self.logger.info(f"  CPU 核心数: {cpu_count}")
            self.logger.info(f"  优化线程数: {optimal_threads}")
        # Resolve the model file (int8 preferred when configured).
        onnx_model_path = self._resolve_model_path()
        # Remember the model directory (used later to load tokens.txt).
        self.onnx_model_dir = onnx_model_path.parent
        self.logger.info(f"加载 ONNX 模型: {onnx_model_path}")
        self._load_onnx_model(str(onnx_model_path))
        # Optional model warm-up: one dummy inference to prime caches/JIT.
        if self.warmup_file and os.path.exists(self.warmup_file):
            try:
                self.logger.info(f"正在预热模型(使用: {self.warmup_file})...")
                _ = self.invoke(self.warmup_file)
                self.logger.info("模型预热完成")
            except Exception as e:
                # Warm-up is best-effort; a failure must not abort startup.
                self.logger.warning(f"预热失败(可忽略): {e}")
        elif self.warmup_file:
            self.logger.warning(f"预热文件不存在: {self.warmup_file},跳过预热步骤")
def _resolve_existing_model_file(self, raw: Optional[str]) -> Optional[Path]:
if not raw:
return None
p = Path(raw)
for c in _stt_path_candidates(p):
if c.is_file():
return c
return None
def _resolve_existing_model_dir(self, raw: Optional[str]) -> Optional[Path]:
if not raw:
return None
p = Path(raw)
for c in _stt_path_candidates(p):
if c.is_dir():
return c
return None
def _resolve_model_path(self) -> Path:
"""
解析模型路径
Returns:
模型文件路径
"""
if self.model_path:
hit = self._resolve_existing_model_file(self.model_path)
if hit is not None:
return hit
if not self.model_dir:
raise ValueError("配置中必须指定 model_path 或 model_dir")
model_dir = self._resolve_existing_model_dir(self.model_dir)
if model_dir is None:
tried = ", ".join(str(x) for x in _stt_path_candidates(Path(self.model_dir)))
raise FileNotFoundError(
f"ONNX 模型目录不存在。config model_dir={self.model_dir!r},已尝试: {tried}"
f"请将 SenseVoice 放入 {_STT_PROJECT_ROOT / 'models'},或 ln -s ../src/models "
f"{_STT_PROJECT_ROOT / 'models'}(见 models/README.txt"
)
# 优先使用 INT8 量化模型(如果启用)
if self.prefer_int8:
int8_path = model_dir / "model.int8.onnx"
if int8_path.exists():
return int8_path
# 回退到普通模型
onnx_path = model_dir / "model.onnx"
if onnx_path.exists():
return onnx_path
raise FileNotFoundError(f"ONNX 模型文件不存在: 在 {model_dir} 中未找到 model.int8.onnx 或 model.onnx")
    def _load_onnx_model(self, onnx_model_path: str):
        """
        Create the ONNX Runtime session and parse model metadata.

        Sets ``self.onnx_session`` plus the pre-processing parameters
        (LFR window/shift, CMVN vectors) and the language / ITN token
        ids embedded in the model's metadata_props.

        Args:
            onnx_model_path: absolute path of the .onnx file.
        """
        # ONNX Runtime session options.
        sess_options = ort.SessionOptions()
        # ARM thread tuning (mirrors the detection done in __init__).
        ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
        if self.arm_enabled and ARM:
            cpu_count = multiprocessing.cpu_count()
            optimal_threads = min(self.arm_max_threads, cpu_count)
            sess_options.intra_op_num_threads = optimal_threads
            sess_options.inter_op_num_threads = optimal_threads
        # Enable all graph optimizations for best throughput.
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # CPU-only inference session.
        self.onnx_session = ort.InferenceSession(
            onnx_model_path,
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )
        # Read the model's metadata key/value pairs.
        # NOTE(review): this loads the full model a second time just for
        # metadata; onnx_session.get_modelmeta() may avoid that — confirm.
        onnx_model = onnx.load(onnx_model_path)
        self.model_metadata = {prop.key: prop.value for prop in onnx_model.metadata_props}
        # LFR / vocabulary parameters (with SenseVoice defaults).
        self.lfr_window_size = int(self.model_metadata.get('lfr_window_size', 7))
        self.lfr_window_shift = int(self.model_metadata.get('lfr_window_shift', 6))
        self.vocab_size = int(self.model_metadata.get('vocab_size', 25055))
        # CMVN parameters: comma-separated floats, or None when absent.
        neg_mean_str = self.model_metadata.get('neg_mean', '')
        inv_stddev_str = self.model_metadata.get('inv_stddev', '')
        self.neg_mean = np.array([float(x) for x in neg_mean_str.split(',')]) if neg_mean_str else None
        self.inv_stddev = np.array([float(x) for x in inv_stddev_str.split(',')]) if inv_stddev_str else None
        # Language / text-normalization ids: metadata wins, config is fallback.
        self.lang_zh = int(self.model_metadata.get('lang_zh', self.lang_zh_default))
        self.with_itn = int(self.model_metadata.get('with_itn', self.with_itn_default))
        self.without_itn = int(self.model_metadata.get('without_itn', self.without_itn_default))
        self.logger.info("ONNX 模型加载完成")
        self.logger.info(f"  LFR窗口大小: {self.lfr_window_size}")
        self.logger.info(f"  LFR窗口偏移: {self.lfr_window_shift}")
        self.logger.info(f"  词汇表大小: {self.vocab_size}")
def _load_tokens(self):
"""加载 tokens 映射"""
tokens_file = self.onnx_model_dir / "tokens.txt"
if tokens_file.exists():
self.tokens = {}
with open(tokens_file, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
token = parts[0]
token_id = int(parts[-1])
self.tokens[token_id] = token
return True
return False
    def _preprocess_audio_array(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> tuple:
        """
        Pre-process an audio array into ONNX model input features (pure numpy/librosa).

        Supports real-time stream processing (numpy array input).

        Pipeline:
        1. Input: 16 kHz mono numpy array (int16 or float32).
        2. Compute 80-dim log-mel fbank.
        3. Apply CMVN (neg_mean / inv_stddev from the ONNX metadata).
        4. Apply LFR stacking (lfr_window_size, lfr_window_shift) -> 560-dim features.

        Args:
            audio_array: audio samples (numpy array, int16 or float32).
            sample_rate: sample rate; None means use the configured value.

        Returns:
            (features, lengths): features of shape (1, T_lfr, 560) float32
            and lengths of shape (1,) int32.

        Raises:
            ValueError: if the audio array is empty.
        """
        import librosa
        sr = sample_rate if sample_rate is not None else self.sample_rate
        # 1. Convert to float32 in [-1, 1] (int16 PCM -> /32768).
        if audio_array.dtype == np.int16:
            audio = audio_array.astype(np.float32) / 32768.0
        else:
            audio = audio_array.astype(np.float32)
        # Down-mix to mono if multi-channel (assumes shape (samples, channels)).
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)
        if audio.size == 0:
            raise ValueError("音频数组为空")
        # 2. Mel spectrogram; frame/hop sizes derived from the ms settings.
        n_fft = int(self.frame_length_ms / 1000.0 * sr)
        hop_length = int(self.frame_shift_ms / 1000.0 * sr)
        # NOTE(review): power=1.0 (magnitude) and center=True differ from
        # Kaldi-style fbank (power spectrum, no centering) — presumably
        # matched to how this model was exported; confirm against training.
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=self.n_mels,
            window="hann",
            center=True,
            power=1.0,  # linear (magnitude) energy
        )
        # log-mel, floored at log_eps to avoid log(0); transpose to (T, n_mels).
        log_mel = np.log(np.maximum(mel_spec, self.log_eps)).T  # (T, n_mels)
        # 3. CMVN using neg_mean / inv_stddev from the ONNX metadata
        # (skipped silently when absent or dimension-mismatched).
        if self.neg_mean is not None and self.inv_stddev is not None:
            if self.neg_mean.shape[0] == log_mel.shape[1]:
                log_mel = (log_mel + self.neg_mean) * self.inv_stddev
        # 4. LFR: stack windows of lfr_window_size frames, stride lfr_window_shift.
        T, D = log_mel.shape
        m = self.lfr_window_size
        n = self.lfr_window_shift
        if T < m:
            # Too few frames: pad with copies of the last frame up to m.
            pad = np.tile(log_mel[-1], (m - T, 1))
            log_mel = np.vstack([log_mel, pad])
            T = m
        # Number of LFR output frames.
        T_lfr = 1 + (T - m) // n
        lfr_feats = []
        for i in range(T_lfr):
            start = i * n
            end = start + m
            chunk = log_mel[start:end, :]  # (m, D)
            lfr_feats.append(chunk.reshape(-1))  # flatten to m*D (= 560) dims
        lfr_feats = np.stack(lfr_feats, axis=0)  # (T_lfr, m*D=560)
        # Add batch dimension: (1, T_lfr, 560).
        lfr_feats = lfr_feats[np.newaxis, :, :].astype(np.float32)
        lengths = np.array([lfr_feats.shape[1]], dtype=np.int32)
        return lfr_feats, lengths
def _preprocess_audio(self, audio_path: str) -> tuple:
"""
预处理音频:提取特征并转换为 ONNX 模型输入格式(纯 numpy 实现)
流程:
1. 读入 16k 单声道 wav
2. 计算 80 维 log-mel fbank
3. 应用 CMVN(使用 ONNX 元数据中的 neg_mean / inv_stddev)
4. 应用 LFR(lfr_m, lfr_n)堆叠,得到 560 维特征
"""
import librosa
# 1. 读入音频
audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
if audio.size == 0:
raise ValueError(f"音频为空: {audio_path}")
# 使用_preprocess_audio_array处理
return self._preprocess_audio_array(audio, sr)
def _ctc_decode(self, logits: np.ndarray, length: np.ndarray) -> str:
"""
CTC 解码:将 logits 转换为文本
Args:
logits: CTC logits,形状为 (N, T, vocab_size)
length: 序列长度,形状为 (N,)
Returns:
解码后的文本
"""
# 加载 tokens(如果还没加载)
if not hasattr(self, 'tokens') or len(self.tokens) == 0:
self._load_tokens()
# Greedy CTC 解码
# 应用 softmax 获取概率
probs = scipy.special.softmax(logits[0][:length[0]], axis=-1)
# 获取每个时间步的最大概率 token
token_ids = np.argmax(probs, axis=-1)
# CTC 解码:移除空白和重复
prev_token = -1
decoded_tokens = []
for token_id in token_ids:
if token_id != self.blank_id and token_id != prev_token:
decoded_tokens.append(token_id)
prev_token = token_id
# Token ID 转文本
text_parts = []
for token_id in decoded_tokens:
if token_id in self.tokens:
token = self.tokens[token_id]
# 处理 SentencePiece 标记
if token.startswith(''):
if text_parts: # 如果不是第一个token,添加空格
text_parts.append(' ')
text_parts.append(token[1:])
elif not token.startswith('<|'): # 忽略特殊标记
text_parts.append(token)
text = ''.join(text_parts)
# 后处理:移除残留的特殊标记
for special in self.special_tokens:
text = text.replace(special, '')
return text.strip()
@time_cost("STT-语音识别推理耗时")
def invoke(self, audio_path: str) -> List[Dict[str, Any]]:
"""
执行语音识别推理(从文件)
Args:
audio_path: 音频文件路径
Returns:
识别结果列表,格式: [{"text": "识别文本"}]
"""
# 预处理音频
features, features_length = self._preprocess_audio(audio_path)
# 执行推理
text = self._inference(features, features_length)
return [{"text": text}]
def invoke_numpy(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> str:
"""
执行语音识别推理(从numpy数组,实时处理)
Args:
audio_array: 音频数据(numpy array,int16 或 float32)
sample_rate: 采样率,None时使用配置值
Returns:
识别文本
"""
# 预处理音频数组
features, features_length = self._preprocess_audio_array(audio_array, sample_rate)
# 执行推理
text = self._inference(features, features_length)
return text
def _inference(self, features: np.ndarray, features_length: np.ndarray) -> str:
"""
执行ONNX推理(内部方法)
Args:
features: 特征数组
features_length: 特征长度
Returns:
识别文本
"""
# 准备 ONNX 模型输入
N, T, C = features.shape
# 语言ID
language = np.array([self.lang_zh], dtype=np.int32)
# 文本规范化
text_norm = np.array([self.with_itn], dtype=np.int32)
# ONNX 推理
inputs = {
'x': features.astype(np.float32),
'x_length': features_length.astype(np.int32),
'language': language,
'text_norm': text_norm
}
outputs = self.onnx_session.run(None, inputs)
logits = outputs[0] # 形状: (N, T, vocab_size)
# CTC 解码
text = self._ctc_decode(logits, features_length)
return text
if __name__ == "__main__":
    # Ad-hoc smoke test: run ONNX inference repeatedly on a fixed sample file.
    import os
    stt = STT()
    for run_idx in range(10):
        result = stt.invoke("/home/lktx/projects/audio_controll_drone_without_llm/test/测试音频.wav")
        # result = stt.invoke_numpy(np.random.rand(16000), 16000)
        print(f"{run_idx+1}次识别结果: {result}")