2026-04-14 09:54:26 +08:00

696 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TTS(Text-to-Speech)模块 - 基于 Kokoro ONNX 的中文实时合成
使用 Kokoro-82M-v1.1-zh-ONNX 模型进行文本转语音合成:
1. 文本 -> (可选)使用 misaki[zh] 做 G2P,得到音素串
2. 音素字符 -> 根据 tokenizer vocab 映射为 token id 序列
3. 通过 ONNX Runtime 推理生成 24kHz 单声道语音
说明:
- 主要依赖: onnxruntime + numpy
- 如果已安装 misaki[zh] (推荐),效果更好:
pip install "misaki[zh]" cn2an pypinyin jieba
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import onnxruntime as ort
# 仅保留 ERROR避免加载 Kokoro 时大量 ConstantFolding/Reciprocal 警告刷屏(不影响推理结果)
try:
ort.set_default_logger_severity(3)
except Exception:
pass
from voice_drone.core.configuration import SYSTEM_TTS_CONFIG
from voice_drone.logging_ import get_logger
# voice_drone/core/tts.py -> voice_drone_assistant 根目录
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
logger = get_logger("tts.kokoro_onnx")
def _tts_model_dir_candidates(rel: Path) -> List[Path]:
    """List candidate absolute locations for a (possibly relative) model dir."""
    if rel.is_absolute():
        return [rel]
    candidates: List[Path] = [_PROJECT_ROOT / rel]
    first_part = rel.parts[0] if rel.parts else None
    if first_part == "models":
        # Sub-project checkout may lack models/; also try the parent repo's src/.
        candidates.append(_PROJECT_ROOT.parent / "src" / rel)
    return candidates
def _resolve_kokoro_model_dir(raw: str | Path) -> Path:
    """Locate the directory containing tokenizer.json.

    Falls back from the sub-project's models/ directory to the parent repo's
    src/models/ when needed; as a last resort returns the raw path anchored
    at the project root.
    """
    path = Path(raw)
    candidates = _tts_model_dir_candidates(path)
    # First pass: a candidate that actually contains tokenizer.json wins.
    for candidate in candidates:
        if (candidate / "tokenizer.json").is_file():
            return candidate.resolve()
    # Second pass: accept any existing directory, but warn about the missing file.
    for candidate in candidates:
        if candidate.is_dir():
            logger.warning(
                "Kokoro 目录存在但未找到 tokenizer.json: %s(将仍使用该路径,后续可能报错)",
                candidate,
            )
            return candidate.resolve()
    return (_PROJECT_ROOT / path).resolve()
class KokoroOnnxTTS:
    """
    Kokoro Chinese ONNX text-to-speech wrapper.

    Basic usage:
        tts = KokoroOnnxTTS()
        audio, sr = tts.synthesize("你好,世界")

    Returns:
        audio: np.ndarray[float32] of shape (N,), values roughly in [-1, 1]
        sr: int sample rate (default 24000)
    """
    def __init__(self, config: Optional[dict] = None) -> None:
        # Read the system TTS configuration (explicit arg > global config > {}).
        self.config = config or SYSTEM_TTS_CONFIG or {}
        # Model root directory (contains onnx/, tokenizer.json, voices/).
        _raw_dir = self.config.get(
            "model_dir", "models/Kokoro-82M-v1.1-zh-ONNX"
        )
        model_dir = _resolve_kokoro_model_dir(_raw_dir)
        self.model_dir = model_dir
        # ONNX model filename (expected under model_dir/onnx; if absent there,
        # a same-named file directly under model_dir is used instead).
        self.model_name = self.config.get("model_name", "model_q4.onnx")
        self.onnx_path = model_dir / "onnx" / self.model_name
        if not self.onnx_path.is_file():
            alt = model_dir / self.model_name
            if alt.is_file():
                self.onnx_path = alt
        # Voice style name (*.bin under the voices/ subdir, extension omitted).
        self.voice_name = self.config.get("voice", "zf_001")
        self.voice_path = model_dir / "voices" / f"{self.voice_name}.bin"
        # Speaking-speed multiplier and output sample rate.
        self.speed = float(self.config.get("speed", 1.0))
        self.sample_rate = int(self.config.get("sample_rate", 24000))
        # tokenizer.json path (shipped locally alongside the ONNX model).
        self.tokenizer_path = model_dir / "tokenizer.json"
        # Components populated by _load_all() below.
        self._session: Optional[ort.InferenceSession] = None
        self._vocab: Optional[dict] = None
        self._voices: Optional[np.ndarray] = None
        self._g2p = None  # misaki[zh] G2P; degrades to raw-text mapping if unavailable
        self._load_all()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def synthesize(self, text: str) -> Tuple[np.ndarray, int]:
        """
        Convert text to speech.

        Args:
            text: input text (simplified Chinese recommended)

        Returns:
            (audio, sample_rate)

        Raises:
            ValueError: if text is empty or maps to no tokens in the vocab.
        """
        if not text or not text.strip():
            raise ValueError("TTS 输入文本不能为空")
        phonemes = self._text_to_phonemes(text)
        token_ids = self._phonemes_to_token_ids(phonemes)
        if len(token_ids) == 0:
            raise ValueError(f"TTS: 文本在当前 vocab 下无法映射到任何 token, text={text!r}")
        # Per the official Kokoro-ONNX example conventions:
        # - token sequence length <= 510
        # - pad token 0 added at both ends
        if len(token_ids) > 510:
            logger.warning(f"TTS: token 长度 {len(token_ids)} > 510, 将被截断为 510")
            token_ids = token_ids[:510]
        tokens = np.array([[0, *token_ids, 0]], dtype=np.int64)  # shape: (1, <=512)
        # Select the style vector indexed by token count.
        assert self._voices is not None, "TTS: voices 尚未初始化"
        voices = self._voices  # shape: (N, 1, 256)
        idx = min(len(token_ids), voices.shape[0] - 1)
        style = voices[idx]  # shape: (1, 256)
        speed = np.array([self.speed], dtype=np.float32)
        # ONNX input-name convention: input_ids, style, speed
        assert self._session is not None, "TTS: ONNX Session 尚未初始化"
        session = self._session
        audio = session.run(
            None,
            {
                "input_ids": tokens,
                "style": style,
                "speed": speed,
            },
        )[0]
        # Accommodate differently exported output shapes:
        # - standard Kokoro ONNX: (1, N)
        # - possibly (N,)
        audio = audio.astype(np.float32)
        if audio.ndim == 2 and audio.shape[0] == 1:
            audio = audio[0]
        elif audio.ndim > 2:
            # Edge case: squeeze away extra dimensions.
            audio = np.squeeze(audio)
        return audio, self.sample_rate

    def synthesize_to_file(self, text: str, wav_path: str) -> str:
        """
        Synthesize text and save it as a WAV file (16-bit PCM).

        Requires the optional scipy dependency:
            pip install scipy

        Returns:
            wav_path, for caller convenience.

        Raises:
            RuntimeError: if scipy is not installed.
        """
        try:
            import scipy.io.wavfile as wavfile  # type: ignore
        except Exception as e:  # pragma: no cover - only triggers when scipy is missing
            raise RuntimeError("保存到 wav 需要安装 scipy, 请先执行: pip install scipy") from e
        audio, sr = self.synthesize(text)
        # Simple peak normalization, then convert to int16.
        max_val = float(np.max(np.abs(audio)) or 1.0)
        audio_int16 = np.clip(audio / max_val, -1.0, 1.0)
        audio_int16 = (audio_int16 * 32767.0).astype(np.int16)
        # Some SciPy versions handle 0-d/1-d arrays inconsistently, so an
        # explicit channel dimension is added here.
        if audio_int16.ndim == 0:
            audio_to_save = audio_int16.reshape(-1, 1)  # scalar -> (1,1)
        elif audio_int16.ndim == 1:
            audio_to_save = audio_int16.reshape(-1, 1)  # (N,) -> (N,1) mono
        else:
            audio_to_save = audio_int16
        wavfile.write(wav_path, sr, audio_to_save)
        return wav_path

    # ------------------------------------------------------------------ #
    # Internal initialization
    # ------------------------------------------------------------------ #
    def _load_all(self) -> None:
        # All components must exist before synthesize() can be called.
        self._load_tokenizer_vocab()
        self._load_voices()
        self._load_onnx_session()
        self._init_g2p()

    def _load_tokenizer_vocab(self) -> None:
        """
        Load the vocab mapping from the local tokenizer.json:
            token(str) -> id(int)
        """
        if not self.tokenizer_path.exists():
            raise FileNotFoundError(f"TTS: 未找到 tokenizer.json: {self.tokenizer_path}")
        with open(self.tokenizer_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        model = data.get("model") or {}
        vocab = model.get("vocab")
        if not isinstance(vocab, dict):
            raise ValueError("TTS: tokenizer.json 格式不正确, 未找到 model.vocab 字段")
        # Stored as: character -> id
        self._vocab = {k: int(v) for k, v in vocab.items()}
        logger.info(f"TTS: tokenizer vocab 加载完成, 词表大小: {len(self._vocab)}")

    def _load_voices(self) -> None:
        """
        Load the voice style vectors (voices/*.bin).
        """
        if not self.voice_path.exists():
            raise FileNotFoundError(f"TTS: 未找到语音风格文件: {self.voice_path}")
        voices = np.fromfile(self.voice_path, dtype=np.float32)
        try:
            voices = voices.reshape(-1, 1, 256)
        except ValueError as e:
            raise ValueError(
                f"TTS: 语音风格文件形状不符合预期, 无法 reshape 为 (-1,1,256): {self.voice_path}"
            ) from e
        self._voices = voices
        logger.info(
            f"TTS: 语音风格文件加载完成: {self.voice_name}, 可用 style 数量: {voices.shape[0]}"
        )

    def _load_onnx_session(self) -> None:
        """
        Create the ONNX Runtime inference session.
        """
        if not self.onnx_path.exists():
            raise FileNotFoundError(f"TTS: 未找到 ONNX 模型文件: {self.onnx_path}")
        sess_options = ort.SessionOptions()
        # Enable all graph optimizations.
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # On multi-core CPUs (e.g. RK3588), ORT thread counts can be pinned via
        # environment variables to avoid too few/too many threads
        # (default 0 means "let ORT decide").
        _ti = os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "").strip()
        if _ti.isdigit() and int(_ti) > 0:
            sess_options.intra_op_num_threads = int(_ti)
        _te = os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "").strip()
        if _te.isdigit() and int(_te) > 0:
            sess_options.inter_op_num_threads = int(_te)
        # Plain CPU inference (extend providers here if GPU is ever needed).
        self._session = ort.InferenceSession(
            str(self.onnx_path),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        logger.info(f"TTS: Kokoro ONNX 模型加载完成: {self.onnx_path}")

    def _init_g2p(self) -> None:
        """
        Initialize Chinese G2P (based on misaki[zh]).

        If misaki is not installed, degrade to mapping raw text characters
        directly against the vocab.
        """
        try:
            from misaki import zh  # type: ignore
            # Accommodate different misaki versions:
            # - newer: ZHG2P(version=...) is accepted
            # - older: ZHG2P() takes no arguments
            try:
                self._g2p = zh.ZHG2P(version="1.1")  # type: ignore[call-arg]
            except TypeError:
                self._g2p = zh.ZHG2P()  # type: ignore[call-arg]
            logger.info("TTS: 已启用 misaki[zh] G2P, 将使用音素级别映射")
        except Exception as e:
            self._g2p = None
            logger.warning(
                "TTS: 未安装或无法导入 misaki[zh], 将直接基于原始文本字符做 token 映射, "
                "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
            )
            logger.debug(f"TTS: G2P 初始化失败原因: {e!r}")

    # ------------------------------------------------------------------ #
    # Text -> phonemes / tokens
    # ------------------------------------------------------------------ #
    def _text_to_phonemes(self, text: str) -> str:
        """
        Text -> phoneme string.

        - If misaki[zh] is available, use ZHG2P(version='1.1') to get phonemes.
        - Otherwise return the raw text (mapped per character downstream).
        """
        if self._g2p is None:
            return text.strip()
        # Accommodate different misaki versions:
        # - some return (phonemes, tokens)
        # - some return just the phoneme string
        result = self._g2p(text)
        if isinstance(result, tuple) or isinstance(result, list):
            ps = result[0]
        else:
            ps = result
        if not ps:
            # Fallback: if G2P returned nothing, use the raw text.
            logger.warning("TTS: G2P 结果为空, 回退为原始文本")
            return text.strip()
        return ps

    def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
        """
        Map a phoneme string to a token-id sequence.

        Lookup is per character:
        - each character has a unique id in the vocab
        - the space character is itself a token (id=16)
        """
        assert self._vocab is not None, "TTS: vocab 尚未初始化"
        vocab = self._vocab
        token_ids: List[int] = []
        unknown_chars = set()
        for ch in phonemes:
            if ch == "\n":
                continue
            tid = vocab.get(ch)
            if tid is None:
                unknown_chars.add(ch)
                continue
            token_ids.append(int(tid))
        if unknown_chars:
            logger.debug(f"TTS: 存在无法映射到 vocab 的字符: {unknown_chars}")
        return token_ids
def _resolve_output_device_id(raw: object) -> Optional[int]:
"""
将配置中的 output_device 解析为 sounddevice 设备索引。
None / 空:返回 None表示使用 sd 默认输出。
"""
import sounddevice as sd # type: ignore
if raw is None:
return None
if isinstance(raw, bool):
return None
if isinstance(raw, int):
return raw if raw >= 0 else None
s = str(raw).strip()
if not s or s.lower() in ("null", "none", "default", ""):
return None
if s.isdigit():
return int(s)
needle = s.lower()
devices = sd.query_devices()
matches: List[int] = []
for i, d in enumerate(devices):
if int(d.get("max_output_channels", 0) or 0) <= 0:
continue
name = (d.get("name") or "").lower()
if needle in name:
matches.append(i)
if not matches:
logger.warning(
f"TTS: 未找到名称包含 {s!r} 的输出设备,将使用系统默认输出。"
"请检查 system.yaml 中 tts.output_device 或查看启动日志中的设备列表。"
)
return None
if len(matches) > 1:
logger.info(
f"TTS: 名称 {s!r} 匹配到多个输出设备索引 {matches},使用第一个 {matches[0]}"
)
return matches[0]
# Process-wide cache for the resolved playback device index. The cache is
# keyed by repr() of the raw tts.output_device config value so that a config
# change invalidates it; _ready distinguishes "cached None" from "not cached".
_playback_dev_cache: Optional[int] = None
_playback_dev_cache_key: Optional[str] = None
_playback_dev_cache_ready: bool = False
def get_playback_output_device_id() -> Optional[int]:
    """Resolve the playback device index from SYSTEM_TTS_CONFIG, with caching.

    Returns None to mean "use the system default output device".
    """
    global _playback_dev_cache, _playback_dev_cache_key, _playback_dev_cache_ready
    raw_value = (SYSTEM_TTS_CONFIG or {}).get("output_device")
    cache_key = repr(raw_value)
    if _playback_dev_cache_ready and _playback_dev_cache_key == cache_key:
        return _playback_dev_cache
    resolved = _resolve_output_device_id(raw_value)
    _playback_dev_cache = resolved
    _playback_dev_cache_key = cache_key
    _playback_dev_cache_ready = True
    if resolved is None:
        logger.info("TTS: 播放使用系统默认输出设备(未指定或无法匹配 tts.output_device")
    else:
        import sounddevice as sd  # type: ignore

        info = sd.query_devices(resolved)
        logger.info(
            f"TTS: 播放将使用输出设备 index={resolved} name={info.get('name', '?')!r}"
        )
    return resolved
def _sounddevice_default_output_index():
    """Best-effort lookup of the default output device index.

    On sounddevice 0.5+, ``default.device`` may be an ``_InputOutputPair``;
    element [1] is the output index.
    """
    import sounddevice as sd  # type: ignore

    configured = sd.default.device
    if isinstance(configured, (list, tuple)):
        return int(configured[1])
    if hasattr(configured, "__getitem__"):
        try:
            return int(configured[1])
        except (IndexError, TypeError, ValueError):
            pass
    try:
        return int(configured)
    except (TypeError, ValueError):
        return None
def log_sounddevice_output_devices() -> None:
    """Log all output-capable devices and the current default, to aid
    configuring tts.output_device (by index or name substring)."""
    try:
        import sounddevice as sd  # type: ignore

        default_out = _sounddevice_default_output_index()
        logger.info("sounddevice 输出设备列表(用于配置 tts.output_device 索引或名称子串):")
        for idx, dev in enumerate(sd.query_devices()):
            if int(dev.get("max_output_channels", 0) or 0) <= 0:
                continue
            is_default = default_out is not None and int(default_out) == idx
            mark = " <- 当前默认输出" if is_default else ""
            logger.info(f" [{idx}] {dev.get('name', '?')}{mark}")
    except Exception as e:
        logger.warning(f"无法枚举 sounddevice 输出设备: {e}")
def _resample_playback_audio(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
"""将波形重采样到设备采样率;优先 librosakaiser_best失败则回退 scipy 多相。"""
from math import gcd
from scipy.signal import resample # type: ignore
from scipy.signal import resample_poly # type: ignore
if abs(sr_in - sr_out) < 1:
return np.asarray(audio, dtype=np.float32)
try:
import librosa # type: ignore
return librosa.resample(
np.asarray(audio, dtype=np.float32),
orig_sr=sr_in,
target_sr=sr_out,
res_type="kaiser_best",
).astype(np.float32)
except Exception as e:
logger.debug(f"TTS: librosa 重采样不可用,使用多相重采样: {e!r}")
try:
g = gcd(int(sr_in), int(sr_out))
if g > 0:
up = int(sr_out) // g
down = int(sr_in) // g
return resample_poly(audio, up, down).astype(np.float32)
except Exception as e2:
logger.debug(f"TTS: resample_poly 失败,回退 FFT resample: {e2!r}")
num = max(1, int(len(audio) * float(sr_out) / float(sr_in)))
return resample(audio, num).astype(np.float32)
def _fade_playback_edges(audio: np.ndarray, sample_rate: int, fade_ms: float) -> np.ndarray:
"""极短线性淡入淡出,减轻扬声器/驱动在段首段尾的爆音与杂音感。"""
if fade_ms <= 0 or audio.size < 16:
return audio
n = int(float(sample_rate) * fade_ms / 1000.0)
n = min(n, len(audio) // 4)
if n <= 0:
return audio
out = np.asarray(audio, dtype=np.float32, order="C").copy()
fade = np.linspace(0.0, 1.0, n, dtype=np.float32)
out[:n] *= fade
out[-n:] *= fade[::-1]
return out
def play_tts_audio(
    audio: np.ndarray,
    sample_rate: int,
    *,
    output_device: Optional[object] = None,
) -> None:
    """
    Play mono float32 audio via sounddevice (blocks until playback finishes).

    On Windows, PortAudio/sounddevice often yields "no sound, no error" when
    called from a non-main thread, so spoken feedback in this project should
    be invoked from the main thread (the capture-loop thread).

    Also: many Realtek/WASAPI devices play 24000Hz completely silently without
    raising, so audio is resampled to the device's default_samplerate
    (commonly 48000/44100) and written through an OutputStream.

    Args:
        output_device: if given, overrides tts.output_device from system.yaml
            (device index or name substring)
    """
    import sounddevice as sd  # type: ignore
    cfg = SYSTEM_TTS_CONFIG or {}
    # Playback tuning knobs from config, with safe defaults.
    force_native = bool(cfg.get("playback_resample_to_device_native", True))
    do_normalize = bool(cfg.get("playback_peak_normalize", True))
    gain = float(cfg.get("playback_gain", 1.0))
    if gain <= 0:
        gain = 1.0
    fade_ms = float(cfg.get("playback_edge_fade_ms", 8.0))
    latency = cfg.get("playback_output_latency", "low")
    if latency not in ("low", "medium", "high"):
        latency = "low"
    audio = np.asarray(audio, dtype=np.float32).squeeze()
    if audio.ndim > 1:
        audio = np.squeeze(audio)
    if audio.size == 0:
        logger.warning("TTS: 播放跳过,音频长度为 0")
        return
    # Resolve the output device: explicit argument > yaml config > sd default.
    if output_device is not None:
        dev = _resolve_output_device_id(output_device)
    else:
        dev = get_playback_output_device_id()
    if dev is None:
        dev = _sounddevice_default_output_index()
        if dev is None:
            logger.warning("TTS: 无法解析输出设备索引,使用 sounddevice 默认输出")
    else:
        dev = int(dev)
    info = sd.query_devices(dev) if dev is not None else sd.query_devices(kind="output")
    native_sr = int(float(info.get("default_samplerate", 48000)))
    sr_out = int(sample_rate)
    # Resample to the device-native rate to avoid the silent-playback issue.
    if force_native and native_sr > 0 and abs(native_sr - sr_out) > 1:
        audio = _resample_playback_audio(audio, sr_out, native_sr)
        sr_out = native_sr
        logger.info(
            f"TTS: 播放重采样 {sample_rate}Hz -> {sr_out}Hz匹配设备 default_samplerate避免 Windows 无声)"
        )
    # Peak-normalize only when close to clipping (>0.95), leaving headroom.
    peak_before = float(np.max(np.abs(audio)))
    if do_normalize and peak_before > 1e-8 and peak_before > 0.95:
        audio = (audio / peak_before * 0.92).astype(np.float32, copy=False)
    if gain != 1.0:
        audio = (audio * np.float32(gain)).astype(np.float32, copy=False)
    audio = _fade_playback_edges(audio, sr_out, fade_ms)
    peak = float(np.max(np.abs(audio)))
    rms = float(np.sqrt(np.mean(np.square(audio))))
    dname = info.get("name", "?") if isinstance(info, dict) else "?"
    logger.info(
        f"TTS: 播放 峰值={peak:.5f} RMS={rms:.5f} sr={sr_out}Hz 设备={dev!r} ({dname!r})"
    )
    if peak < 1e-8:
        logger.warning("TTS: 波形接近静音,请检查合成是否异常")
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)
    block = audio.reshape(-1, 1)  # (N,) -> (N,1) mono column for the stream
    try:
        with sd.OutputStream(
            device=dev,
            channels=1,
            samplerate=sr_out,
            dtype="float32",
            latency=latency,
        ) as stream:
            stream.write(block)
    except Exception as e:
        # Last-ditch fallback: the simpler sd.play path.
        logger.warning(f"TTS: OutputStream 失败,回退 sd.play: {e}", exc_info=True)
        sd.play(block, samplerate=sr_out, device=dev)
        sd.wait()
def play_wav_path(
    path: str | Path,
    *,
    output_device: Optional[object] = None,
) -> None:
    """Play a 16-bit PCM WAV (mono, or stereo downmixed to mono) through the
    same sounddevice output path as synthesize + play_tts_audio (including
    ROCKET_TTS_DEVICE / yaml device resolution).
    """
    import wave

    wav_file = Path(path)
    with wave.open(str(wav_file), "rb") as wf:
        channels = int(wf.getnchannels())
        sample_width = int(wf.getsampwidth())
        rate = int(wf.getframerate())
        frame_count = int(wf.getnframes())
        if sample_width != 2:
            raise ValueError(f"仅支持 16-bit PCM: {wav_file}")
        payload = wf.readframes(frame_count)
    # Little-endian int16 -> float32, then downmix stereo if needed.
    samples = np.frombuffer(payload, dtype="<i2").astype(np.float32, copy=False)
    if channels == 2:
        samples = samples.reshape(-1, 2).mean(axis=1).astype(np.float32, copy=False)
    elif channels != 1:
        raise ValueError(f"仅支持 1 或 2 通道: {wav_file} (ch={channels})")
    samples = samples * np.float32(1.0 / 32768.0)
    logger.info("TTS: 播放预生成 WAV %s (%sHz, %s 采样)", wav_file.name, rate, samples.size)
    play_tts_audio(samples, rate, output_device=output_device)
def speak_text(
    text: str,
    tts: Optional["KokoroOnnxTTS"] = None,
    *,
    output_device: Optional[object] = None,
) -> None:
    """Synthesize and play one utterance.

    Failures are logged, never raised — suitable for best-effort feedback
    after a command succeeds.
    """
    if not text or not str(text).strip():
        return
    try:
        engine = tts or KokoroOnnxTTS()
        stripped = str(text).strip()
        logger.info(f"TTS: 开始合成并播放: {stripped!r}")
        waveform, rate = engine.synthesize(stripped)
        play_tts_audio(waveform, rate, output_device=output_device)
        logger.info("TTS: 播放完成")
    except Exception as e:
        logger.warning(f"TTS 播放失败: {e}", exc_info=True)
# Public API of this module.
__all__ = [
    "KokoroOnnxTTS",
    "play_tts_audio",
    "play_wav_path",
    "speak_text",
    "get_playback_output_device_id",
    "log_sounddevice_output_devices",
]
if __name__ == "__main__":
    # Manual smoke test: same path as the main program — play_tts_audio
    # (includes resampling to the device-native sample rate).
    tts = KokoroOnnxTTS()
    text = "任务执行完成,开始返航降落"
    print(f"正在合成语音: {text}")
    audio, sr = tts.synthesize(text)
    print("正在播放(与主程序相同 play_tts_audio 路径)...")
    try:
        play_tts_audio(audio, sr)
        print("播放结束。")
    except Exception as e:
        print(f"播放失败: {e}")
    # === Save to a WAV file (optional; requires scipy) ===
    try:
        output_path = "任务执行完成,开始返航降落.wav"
        tts.synthesize_to_file(text, output_path)
        print(f"音频已保存至: {output_path}")
    except RuntimeError as e:
        print(f"保存失败(可能缺少 scipy: {e}")