""" TTS(Text-to-Speech)模块 - 基于 Kokoro ONNX 的中文实时合成 使用 Kokoro-82M-v1.1-zh-ONNX 模型进行文本转语音合成: 1. 文本 -> (可选)使用 misaki[zh] 做 G2P,得到音素串 2. 音素字符 -> 根据 tokenizer vocab 映射为 token id 序列 3. 通过 ONNX Runtime 推理生成 24kHz 单声道语音 说明: - 主要依赖: onnxruntime + numpy - 如果已安装 misaki[zh] (推荐),效果更好: pip install "misaki[zh]" cn2an pypinyin jieba """ from __future__ import annotations import json import os from pathlib import Path from typing import List, Optional, Tuple import numpy as np import onnxruntime as ort # 仅保留 ERROR,避免加载 Kokoro 时大量 ConstantFolding/Reciprocal 警告刷屏(不影响推理结果) try: ort.set_default_logger_severity(3) except Exception: pass from voice_drone.core.configuration import SYSTEM_TTS_CONFIG from voice_drone.logging_ import get_logger # voice_drone/core/tts.py -> voice_drone_assistant 根目录 _PROJECT_ROOT = Path(__file__).resolve().parents[2] logger = get_logger("tts.kokoro_onnx") def _tts_model_dir_candidates(rel: Path) -> List[Path]: if rel.is_absolute(): return [rel] out: List[Path] = [_PROJECT_ROOT / rel] if rel.parts and rel.parts[0] == "models": out.append(_PROJECT_ROOT.parent / "src" / rel) return out def _resolve_kokoro_model_dir(raw: str | Path) -> Path: """含 tokenizer.json 的目录;支持子工程 models/ 缺失时回退到上级仓库 src/models/。""" p = Path(raw) for c in _tts_model_dir_candidates(p): if (c / "tokenizer.json").is_file(): return c.resolve() for c in _tts_model_dir_candidates(p): if c.is_dir(): logger.warning( "Kokoro 目录存在但未找到 tokenizer.json: %s(将仍使用该路径,后续可能报错)", c, ) return c.resolve() return (_PROJECT_ROOT / p).resolve() class KokoroOnnxTTS: """ Kokoro 中文 ONNX 文本转语音封装 基本用法: tts = KokoroOnnxTTS() audio, sr = tts.synthesize("你好,世界") 返回: audio: np.ndarray[float32] 形状为 (N,), 范围约 [-1, 1] sr: int 采样率(默认 24000) """ def __init__(self, config: Optional[dict] = None) -> None: # 读取系统 TTS 配置 self.config = config or SYSTEM_TTS_CONFIG or {} # 模型根目录(包含 onnx/、tokenizer.json、voices/) _raw_dir = self.config.get( "model_dir", "models/Kokoro-82M-v1.1-zh-ONNX" ) model_dir = _resolve_kokoro_model_dir(_raw_dir) self.model_dir = model_dir # ONNX 模型文件名(位于 model_dir/onnx 下;若 onnx/ 下没有可改配置为根目录文件名) self.model_name = self.config.get("model_name", "model_q4.onnx") self.onnx_path = model_dir / "onnx" / self.model_name if not self.onnx_path.is_file(): alt = model_dir / self.model_name if alt.is_file(): self.onnx_path = alt # 语音风格(voices 子目录下的 *.bin, 这里不含扩展名) self.voice_name = self.config.get("voice", "zf_001") self.voice_path = model_dir / "voices" / f"{self.voice_name}.bin" # 语速与输出采样率 self.speed = float(self.config.get("speed", 1.0)) self.sample_rate = int(self.config.get("sample_rate", 24000)) # tokenizer.json 路径(本地随 ONNX 模型一起提供) self.tokenizer_path = model_dir / "tokenizer.json" # 初始化组件 self._session: Optional[ort.InferenceSession] = None self._vocab: Optional[dict] = None self._voices: Optional[np.ndarray] = None self._g2p = None # misaki[zh] G2P, 如不可用则退化为直接使用原始文本 self._load_all() # ------------------------------------------------------------------ # # 对外主接口 # ------------------------------------------------------------------ # def synthesize(self, text: str) -> Tuple[np.ndarray, int]: """ 文本转语音 Args: text: 输入文本(推荐为简体中文) Returns: (audio, sample_rate) """ if not text or not text.strip(): raise ValueError("TTS 输入文本不能为空") phonemes = self._text_to_phonemes(text) token_ids = self._phonemes_to_token_ids(phonemes) if len(token_ids) == 0: raise ValueError(f"TTS: 文本在当前 vocab 下无法映射到任何 token, text={text!r}") # 按 Kokoro-ONNX 官方示例约定: # - token 序列长度 <= 510 # - 前后各添加 pad token 0 if len(token_ids) > 510: logger.warning(f"TTS: token 长度 {len(token_ids)} > 510, 将被截断为 510") token_ids = token_ids[:510] tokens = np.array([[0, *token_ids, 0]], dtype=np.int64) # shape: (1, <=512) # 根据 token 数量选择 style 向量 assert self._voices is not None, "TTS: voices 尚未初始化" voices = self._voices # shape: (N, 1, 256) idx = min(len(token_ids), voices.shape[0] - 1) style = voices[idx] # shape: (1, 256) speed = np.array([self.speed], dtype=np.float32) # ONNX 输入名约定: input_ids, style, speed assert self._session is not None, "TTS: ONNX Session 尚未初始化" session = self._session audio = session.run( None, { "input_ids": tokens, "style": style, "speed": speed, }, )[0] # 兼容不同导出形状: # - 标准 Kokoro ONNX: (1, N) # - 也有可能是 (N,) audio = audio.astype(np.float32) if audio.ndim == 2 and audio.shape[0] == 1: audio = audio[0] elif audio.ndim > 2: # 极端情况: 压缩多余维度 audio = np.squeeze(audio) return audio, self.sample_rate def synthesize_to_file(self, text: str, wav_path: str) -> str: """ 文本合成并保存为 wav 文件(16-bit PCM) 需要依赖 scipy, 可选: pip install scipy """ try: import scipy.io.wavfile as wavfile # type: ignore except Exception as e: # pragma: no cover - 仅在未安装时触发 raise RuntimeError("保存到 wav 需要安装 scipy, 请先执行: pip install scipy") from e audio, sr = self.synthesize(text) # 简单归一化并转为 int16 max_val = float(np.max(np.abs(audio)) or 1.0) audio_int16 = np.clip(audio / max_val, -1.0, 1.0) audio_int16 = (audio_int16 * 32767.0).astype(np.int16) # 某些 SciPy 版本对一维/零维数组支持不统一,这里显式加上通道维度 if audio_int16.ndim == 0: audio_to_save = audio_int16.reshape(-1, 1) # 标量 -> (1,1) elif audio_int16.ndim == 1: audio_to_save = audio_int16.reshape(-1, 1) # (N,) -> (N,1) 单声道 else: audio_to_save = audio_int16 wavfile.write(wav_path, sr, audio_to_save) return wav_path # ------------------------------------------------------------------ # # 内部初始化 # ------------------------------------------------------------------ # def _load_all(self) -> None: self._load_tokenizer_vocab() self._load_voices() self._load_onnx_session() self._init_g2p() def _load_tokenizer_vocab(self) -> None: """ 从本地 tokenizer.json 载入 vocab 映射: token(str) -> id(int) """ if not self.tokenizer_path.exists(): raise FileNotFoundError(f"TTS: 未找到 tokenizer.json: {self.tokenizer_path}") with open(self.tokenizer_path, "r", encoding="utf-8") as f: data = json.load(f) model = data.get("model") or {} vocab = model.get("vocab") if not isinstance(vocab, dict): raise ValueError("TTS: tokenizer.json 格式不正确, 未找到 model.vocab 字段") # 保存为: 字符 -> id self._vocab = {k: int(v) for k, v in vocab.items()} logger.info(f"TTS: tokenizer vocab 加载完成, 词表大小: {len(self._vocab)}") def _load_voices(self) -> None: """ 载入语音风格向量(voices/*.bin) """ if not self.voice_path.exists(): raise FileNotFoundError(f"TTS: 未找到语音风格文件: {self.voice_path}") voices = np.fromfile(self.voice_path, dtype=np.float32) try: voices = voices.reshape(-1, 1, 256) except ValueError as e: raise ValueError( f"TTS: 语音风格文件形状不符合预期, 无法 reshape 为 (-1,1,256): {self.voice_path}" ) from e self._voices = voices logger.info( f"TTS: 语音风格文件加载完成: {self.voice_name}, 可用 style 数量: {voices.shape[0]}" ) def _load_onnx_session(self) -> None: """ 创建 ONNX Runtime 推理会话 """ if not self.onnx_path.exists(): raise FileNotFoundError(f"TTS: 未找到 ONNX 模型文件: {self.onnx_path}") sess_options = ort.SessionOptions() # 启用所有图优化 sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL # RK3588 等多核 CPU:可用环境变量固定 ORT 线程,避免过小/过大(默认 0 表示交给 ORT 自动) _ti = os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "").strip() if _ti.isdigit() and int(_ti) > 0: sess_options.intra_op_num_threads = int(_ti) _te = os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "").strip() if _te.isdigit() and int(_te) > 0: sess_options.inter_op_num_threads = int(_te) # 简单的 CPU 推理(如需 GPU, 可在此扩展 providers) self._session = ort.InferenceSession( str(self.onnx_path), sess_options=sess_options, providers=["CPUExecutionProvider"], ) logger.info(f"TTS: Kokoro ONNX 模型加载完成: {self.onnx_path}") def _init_g2p(self) -> None: """ 初始化中文 G2P (基于 misaki[zh])。 如果环境中未安装 misaki, 则退化为直接使用原始文本字符做映射。 """ try: from misaki import zh # type: ignore # 兼容不同版本的 misaki: # - 新版: ZHG2P(version=...) 可用 # - 旧版: ZHG2P() 不接受参数 try: self._g2p = zh.ZHG2P(version="1.1") # type: ignore[call-arg] except TypeError: self._g2p = zh.ZHG2P() # type: ignore[call-arg] logger.info("TTS: 已启用 misaki[zh] G2P, 将使用音素级别映射") except Exception as e: self._g2p = None logger.warning( "TTS: 未安装或无法导入 misaki[zh], 将直接基于原始文本字符做 token 映射, " "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba" ) logger.debug(f"TTS: G2P 初始化失败原因: {e!r}") # ------------------------------------------------------------------ # # 文本 -> 音素 / token # ------------------------------------------------------------------ # def _text_to_phonemes(self, text: str) -> str: """ 文本 -> 音素串 - 若 misaki[zh] 可用, 则使用 ZHG2P(version='1.1') 得到音素串 - 否则, 直接返回原始文本(后续按字符映射) """ if self._g2p is None: return text.strip() # 兼容不同版本的 misaki: # - 有的返回 (phonemes, tokens) # - 有的只返回 phonemes 字符串 result = self._g2p(text) if isinstance(result, tuple) or isinstance(result, list): ps = result[0] else: ps = result if not ps: # 回退: 如果 G2P 返回空, 使用原始文本 logger.warning("TTS: G2P 结果为空, 回退为原始文本") return text.strip() return ps def _phonemes_to_token_ids(self, phonemes: str) -> List[int]: """ 将音素串映射为 token id 序列 直接按字符级别查表: - 每个字符在 vocab 中有唯一 id - 空格本身也是一个 token (id=16) """ assert self._vocab is not None, "TTS: vocab 尚未初始化" vocab = self._vocab token_ids: List[int] = [] unknown_chars = set() for ch in phonemes: if ch == "\n": continue tid = vocab.get(ch) if tid is None: unknown_chars.add(ch) continue token_ids.append(int(tid)) if unknown_chars: logger.debug(f"TTS: 存在无法映射到 vocab 的字符: {unknown_chars}") return token_ids def _resolve_output_device_id(raw: object) -> Optional[int]: """ 将配置中的 output_device 解析为 sounddevice 设备索引。 None / 空:返回 None,表示使用 sd 默认输出。 """ import sounddevice as sd # type: ignore if raw is None: return None if isinstance(raw, bool): return None if isinstance(raw, int): return raw if raw >= 0 else None s = str(raw).strip() if not s or s.lower() in ("null", "none", "default", ""): return None if s.isdigit(): return int(s) needle = s.lower() devices = sd.query_devices() matches: List[int] = [] for i, d in enumerate(devices): if int(d.get("max_output_channels", 0) or 0) <= 0: continue name = (d.get("name") or "").lower() if needle in name: matches.append(i) if not matches: logger.warning( f"TTS: 未找到名称包含 {s!r} 的输出设备,将使用系统默认输出。" "请检查 system.yaml 中 tts.output_device 或查看启动日志中的设备列表。" ) return None if len(matches) > 1: logger.info( f"TTS: 名称 {s!r} 匹配到多个输出设备索引 {matches},使用第一个 {matches[0]}" ) return matches[0] _playback_dev_cache: Optional[int] = None _playback_dev_cache_key: Optional[str] = None _playback_dev_cache_ready: bool = False def get_playback_output_device_id() -> Optional[int]: """从 SYSTEM_TTS_CONFIG 解析并缓存播放设备索引(None=默认输出)。""" global _playback_dev_cache, _playback_dev_cache_key, _playback_dev_cache_ready cfg = SYSTEM_TTS_CONFIG or {} raw = cfg.get("output_device") key = repr(raw) if _playback_dev_cache_ready and _playback_dev_cache_key == key: return _playback_dev_cache dev_id = _resolve_output_device_id(raw) _playback_dev_cache = dev_id _playback_dev_cache_key = key _playback_dev_cache_ready = True if dev_id is not None: import sounddevice as sd # type: ignore info = sd.query_devices(dev_id) logger.info( f"TTS: 播放将使用输出设备 index={dev_id} name={info.get('name', '?')!r}" ) else: logger.info("TTS: 播放使用系统默认输出设备(未指定或无法匹配 tts.output_device)") return dev_id def _sounddevice_default_output_index(): """sounddevice 0.5+ 的 default.device 可能是 _InputOutputPair,需取 [1] 为输出索引。""" import sounddevice as sd # type: ignore default = sd.default.device if isinstance(default, (list, tuple)): return int(default[1]) if hasattr(default, "__getitem__"): try: return int(default[1]) except (IndexError, TypeError, ValueError): pass try: return int(default) except (TypeError, ValueError): return None def log_sounddevice_output_devices() -> None: """列出所有可用输出设备及当前默认输出,便于配置 tts.output_device。""" try: import sounddevice as sd # type: ignore out_idx = _sounddevice_default_output_index() logger.info("sounddevice 输出设备列表(用于配置 tts.output_device 索引或名称子串):") for i, d in enumerate(sd.query_devices()): if int(d.get("max_output_channels", 0) or 0) <= 0: continue mark = " <- 当前默认输出" if out_idx is not None and int(out_idx) == i else "" logger.info(f" [{i}] {d.get('name', '?')}{mark}") except Exception as e: logger.warning(f"无法枚举 sounddevice 输出设备: {e}") def _resample_playback_audio(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray: """将波形重采样到设备采样率;优先 librosa(kaiser_best),失败则回退 scipy 多相。""" from math import gcd from scipy.signal import resample # type: ignore from scipy.signal import resample_poly # type: ignore if abs(sr_in - sr_out) < 1: return np.asarray(audio, dtype=np.float32) try: import librosa # type: ignore return librosa.resample( np.asarray(audio, dtype=np.float32), orig_sr=sr_in, target_sr=sr_out, res_type="kaiser_best", ).astype(np.float32) except Exception as e: logger.debug(f"TTS: librosa 重采样不可用,使用多相重采样: {e!r}") try: g = gcd(int(sr_in), int(sr_out)) if g > 0: up = int(sr_out) // g down = int(sr_in) // g return resample_poly(audio, up, down).astype(np.float32) except Exception as e2: logger.debug(f"TTS: resample_poly 失败,回退 FFT resample: {e2!r}") num = max(1, int(len(audio) * float(sr_out) / float(sr_in))) return resample(audio, num).astype(np.float32) def _fade_playback_edges(audio: np.ndarray, sample_rate: int, fade_ms: float) -> np.ndarray: """极短线性淡入淡出,减轻扬声器/驱动在段首段尾的爆音与杂音感。""" if fade_ms <= 0 or audio.size < 16: return audio n = int(float(sample_rate) * fade_ms / 1000.0) n = min(n, len(audio) // 4) if n <= 0: return audio out = np.asarray(audio, dtype=np.float32, order="C").copy() fade = np.linspace(0.0, 1.0, n, dtype=np.float32) out[:n] *= fade out[-n:] *= fade[::-1] return out def play_tts_audio( audio: np.ndarray, sample_rate: int, *, output_device: Optional[object] = None, ) -> None: """ 使用 sounddevice 播放单声道 float32 音频(阻塞至播放结束)。 在 Windows 上 PortAudio/sounddevice 从非主线程调用时经常出现「无声音、无报错」, 因此本项目中应答播报应在主线程(采集循环所在线程)调用本函数。 另:多数 Realtek/WASAPI 设备对 24000Hz 播放会「完全无声」且不报错,需重采样到设备 default_samplerate(常见 48000/44100),并用 OutputStream 写出。 Args: output_device: 若指定,覆盖 system.yaml 的 tts.output_device(设备索引或名称子串)。 """ import sounddevice as sd # type: ignore cfg = SYSTEM_TTS_CONFIG or {} force_native = bool(cfg.get("playback_resample_to_device_native", True)) do_normalize = bool(cfg.get("playback_peak_normalize", True)) gain = float(cfg.get("playback_gain", 1.0)) if gain <= 0: gain = 1.0 fade_ms = float(cfg.get("playback_edge_fade_ms", 8.0)) latency = cfg.get("playback_output_latency", "low") if latency not in ("low", "medium", "high"): latency = "low" audio = np.asarray(audio, dtype=np.float32).squeeze() if audio.ndim > 1: audio = np.squeeze(audio) if audio.size == 0: logger.warning("TTS: 播放跳过,音频长度为 0") return if output_device is not None: dev = _resolve_output_device_id(output_device) else: dev = get_playback_output_device_id() if dev is None: dev = _sounddevice_default_output_index() if dev is None: logger.warning("TTS: 无法解析输出设备索引,使用 sounddevice 默认输出") else: dev = int(dev) info = sd.query_devices(dev) if dev is not None else sd.query_devices(kind="output") native_sr = int(float(info.get("default_samplerate", 48000))) sr_out = int(sample_rate) if force_native and native_sr > 0 and abs(native_sr - sr_out) > 1: audio = _resample_playback_audio(audio, sr_out, native_sr) sr_out = native_sr logger.info( f"TTS: 播放重采样 {sample_rate}Hz -> {sr_out}Hz(匹配设备 default_samplerate,避免 Windows 无声)" ) peak_before = float(np.max(np.abs(audio))) if do_normalize and peak_before > 1e-8 and peak_before > 0.95: audio = (audio / peak_before * 0.92).astype(np.float32, copy=False) if gain != 1.0: audio = (audio * np.float32(gain)).astype(np.float32, copy=False) audio = _fade_playback_edges(audio, sr_out, fade_ms) peak = float(np.max(np.abs(audio))) rms = float(np.sqrt(np.mean(np.square(audio)))) dname = info.get("name", "?") if isinstance(info, dict) else "?" logger.info( f"TTS: 播放 峰值={peak:.5f} RMS={rms:.5f} sr={sr_out}Hz 设备={dev!r} ({dname!r})" ) if peak < 1e-8: logger.warning("TTS: 波形接近静音,请检查合成是否异常") audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False) block = audio.reshape(-1, 1) try: with sd.OutputStream( device=dev, channels=1, samplerate=sr_out, dtype="float32", latency=latency, ) as stream: stream.write(block) except Exception as e: logger.warning(f"TTS: OutputStream 失败,回退 sd.play: {e}", exc_info=True) sd.play(block, samplerate=sr_out, device=dev) sd.wait() def play_wav_path( path: str | Path, *, output_device: Optional[object] = None, ) -> None: """ 播放 16-bit PCM WAV(单声道或立体声下混为单声道),走与 synthesize + play_tts_audio 相同的 sounddevice 输出路径(含 ROCKET_TTS_DEVICE / yaml 设备解析)。 """ import wave p = Path(path) with wave.open(str(p), "rb") as wf: ch = int(wf.getnchannels()) sw = int(wf.getsampwidth()) sr = int(wf.getframerate()) nframes = int(wf.getnframes()) if sw != 2: raise ValueError(f"仅支持 16-bit PCM: {p}") raw = wf.readframes(nframes) mono = np.frombuffer(raw, dtype=" None: """ 合成并播放一段语音;失败时仅打日志,不向外抛异常(适合命令成功后的反馈)。 """ if not text or not str(text).strip(): return try: engine = tts or KokoroOnnxTTS() t = str(text).strip() logger.info(f"TTS: 开始合成并播放: {t!r}") audio, sr = engine.synthesize(t) play_tts_audio(audio, sr, output_device=output_device) logger.info("TTS: 播放完成") except Exception as e: logger.warning(f"TTS 播放失败: {e}", exc_info=True) __all__ = [ "KokoroOnnxTTS", "play_tts_audio", "play_wav_path", "speak_text", "get_playback_output_device_id", "log_sounddevice_output_devices", ] if __name__ == "__main__": # 与主程序一致:使用 play_tts_audio(含重采样到设备 native 采样率) tts = KokoroOnnxTTS() text = "任务执行完成,开始返航降落" print(f"正在合成语音: {text}") audio, sr = tts.synthesize(text) print("正在播放(与主程序相同 play_tts_audio 路径)...") try: play_tts_audio(audio, sr) print("播放结束。") except Exception as e: print(f"播放失败: {e}") # === 保存为 WAV 文件(可选)=== try: output_path = "任务执行完成,开始返航降落.wav" tts.synthesize_to_file(text, output_path) print(f"音频已保存至: {output_path}") except RuntimeError as e: print(f"保存失败(可能缺少 scipy): {e}")