"""
|
||
TTS(Text-to-Speech)模块 - 基于 Kokoro ONNX 的中文实时合成
|
||
|
||
使用 Kokoro-82M-v1.1-zh-ONNX 模型进行文本转语音合成:
|
||
1. 文本 -> (可选)使用 misaki[zh] 做 G2P,得到音素串
|
||
2. 音素字符 -> 根据 tokenizer vocab 映射为 token id 序列
|
||
3. 通过 ONNX Runtime 推理生成 24kHz 单声道语音
|
||
|
||
说明:
|
||
- 主要依赖: onnxruntime + numpy
|
||
- 如果已安装 misaki[zh] (推荐),效果更好:
|
||
pip install "misaki[zh]" cn2an pypinyin jieba
|
||
"""
|
||
|
||
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import onnxruntime as ort

# Keep only ERROR-level ORT logs: loading Kokoro otherwise floods the console
# with ConstantFolding/Reciprocal warnings (they do not affect inference).
try:
    ort.set_default_logger_severity(3)
except Exception:
    pass

from voice_drone.core.configuration import SYSTEM_TTS_CONFIG
from voice_drone.logging_ import get_logger

# voice_drone/core/tts.py -> voice_drone_assistant project root
_PROJECT_ROOT = Path(__file__).resolve().parents[2]

logger = get_logger("tts.kokoro_onnx")

def _tts_model_dir_candidates(rel: Path) -> List[Path]:
|
||
if rel.is_absolute():
|
||
return [rel]
|
||
out: List[Path] = [_PROJECT_ROOT / rel]
|
||
if rel.parts and rel.parts[0] == "models":
|
||
out.append(_PROJECT_ROOT.parent / "src" / rel)
|
||
return out
|
||
|
||
|
||
def _resolve_kokoro_model_dir(raw: str | Path) -> Path:
    """Resolve the directory that should contain tokenizer.json.

    Supports falling back to the parent repository's src/models/ tree when
    the sub-project's models/ directory is missing.
    """
    path = Path(raw)
    candidates = _tts_model_dir_candidates(path)

    # Best case: a candidate that actually ships tokenizer.json.
    for cand in candidates:
        if (cand / "tokenizer.json").is_file():
            return cand.resolve()

    # Second best: any existing directory, even without the tokenizer.
    for cand in candidates:
        if not cand.is_dir():
            continue
        logger.warning(
            "Kokoro 目录存在但未找到 tokenizer.json: %s(将仍使用该路径,后续可能报错)",
            cand,
        )
        return cand.resolve()

    # Nothing exists: return the project-root-anchored path anyway.
    return (_PROJECT_ROOT / path).resolve()

class KokoroOnnxTTS:
    """Kokoro Chinese ONNX text-to-speech wrapper.

    Basic usage:
        tts = KokoroOnnxTTS()
        audio, sr = tts.synthesize("你好,世界")

    synthesize() returns:
        audio: np.ndarray[float32] of shape (N,), roughly within [-1, 1]
        sr: int sample rate (24000 by default)
    """

    def __init__(self, config: Optional[dict] = None) -> None:
        # System-level TTS configuration; an explicit dict argument overrides it.
        self.config = config or SYSTEM_TTS_CONFIG or {}

        # Model root directory (contains onnx/, tokenizer.json, voices/).
        _raw_dir = self.config.get(
            "model_dir", "models/Kokoro-82M-v1.1-zh-ONNX"
        )
        model_dir = _resolve_kokoro_model_dir(_raw_dir)
        self.model_dir = model_dir

        # ONNX model file name (expected under model_dir/onnx; if absent there,
        # the same file name directly under the model root is also accepted).
        self.model_name = self.config.get("model_name", "model_q4.onnx")
        self.onnx_path = model_dir / "onnx" / self.model_name
        if not self.onnx_path.is_file():
            alt = model_dir / self.model_name
            if alt.is_file():
                self.onnx_path = alt

        # Voice style (a *.bin under voices/; configured without extension).
        self.voice_name = self.config.get("voice", "zf_001")
        self.voice_path = model_dir / "voices" / f"{self.voice_name}.bin"

        # Speech speed and output sample rate.
        self.speed = float(self.config.get("speed", 1.0))
        self.sample_rate = int(self.config.get("sample_rate", 24000))

        # tokenizer.json path (shipped locally alongside the ONNX model).
        self.tokenizer_path = model_dir / "tokenizer.json"

        # Components populated by _load_all() below.
        self._session: Optional[ort.InferenceSession] = None
        self._vocab: Optional[dict] = None
        self._voices: Optional[np.ndarray] = None
        self._g2p = None  # misaki[zh] G2P; None means fall back to raw text

        self._load_all()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def synthesize(self, text: str) -> Tuple[np.ndarray, int]:
        """
        Convert text to speech.

        Args:
            text: input text (Simplified Chinese recommended)

        Returns:
            (audio, sample_rate)

        Raises:
            ValueError: if *text* is empty or maps to no vocab tokens.
        """
        if not text or not text.strip():
            raise ValueError("TTS 输入文本不能为空")

        phonemes = self._text_to_phonemes(text)
        token_ids = self._phonemes_to_token_ids(phonemes)

        if len(token_ids) == 0:
            raise ValueError(f"TTS: 文本在当前 vocab 下无法映射到任何 token, text={text!r}")

        # Per the official Kokoro-ONNX example:
        # - token sequence length must be <= 510
        # - a pad token 0 is added at both ends
        if len(token_ids) > 510:
            logger.warning(f"TTS: token 长度 {len(token_ids)} > 510, 将被截断为 510")
            token_ids = token_ids[:510]

        tokens = np.array([[0, *token_ids, 0]], dtype=np.int64)  # shape: (1, <=512)

        # Select the style vector indexed by the token count.
        assert self._voices is not None, "TTS: voices 尚未初始化"
        voices = self._voices  # shape: (N, 1, 256)
        idx = min(len(token_ids), voices.shape[0] - 1)
        style = voices[idx]  # shape: (1, 256)

        speed = np.array([self.speed], dtype=np.float32)

        # ONNX input names by convention: input_ids, style, speed
        assert self._session is not None, "TTS: ONNX Session 尚未初始化"
        session = self._session
        audio = session.run(
            None,
            {
                "input_ids": tokens,
                "style": style,
                "speed": speed,
            },
        )[0]

        # Tolerate different export shapes:
        # - standard Kokoro ONNX: (1, N)
        # - possibly (N,)
        audio = audio.astype(np.float32)
        if audio.ndim == 2 and audio.shape[0] == 1:
            audio = audio[0]
        elif audio.ndim > 2:
            # Extreme case: squeeze out any redundant dimensions.
            audio = np.squeeze(audio)

        return audio, self.sample_rate

    def synthesize_to_file(self, text: str, wav_path: str) -> str:
        """
        Synthesize *text* and save it as a wav file (16-bit PCM).

        Requires the optional scipy dependency:
            pip install scipy
        """
        try:
            import scipy.io.wavfile as wavfile  # type: ignore
        except Exception as e:  # pragma: no cover - only when scipy is absent
            raise RuntimeError("保存到 wav 需要安装 scipy, 请先执行: pip install scipy") from e

        audio, sr = self.synthesize(text)
        # Simple peak normalization, then conversion to int16.
        max_val = float(np.max(np.abs(audio)) or 1.0)
        audio_int16 = np.clip(audio / max_val, -1.0, 1.0)
        audio_int16 = (audio_int16 * 32767.0).astype(np.int16)

        # Some SciPy versions handle 0-d/1-d arrays inconsistently, so add an
        # explicit channel dimension before writing.
        if audio_int16.ndim == 0:
            audio_to_save = audio_int16.reshape(-1, 1)  # scalar -> (1,1)
        elif audio_int16.ndim == 1:
            audio_to_save = audio_int16.reshape(-1, 1)  # (N,) -> (N,1) mono
        else:
            audio_to_save = audio_int16

        wavfile.write(wav_path, sr, audio_to_save)
        return wav_path

    # ------------------------------------------------------------------ #
    # Internal initialization
    # ------------------------------------------------------------------ #
    def _load_all(self) -> None:
        # Order matters: vocab and voices must exist before the first synth;
        # G2P is last because it degrades gracefully when unavailable.
        self._load_tokenizer_vocab()
        self._load_voices()
        self._load_onnx_session()
        self._init_g2p()

    def _load_tokenizer_vocab(self) -> None:
        """
        Load the vocab mapping from the local tokenizer.json:
            token(str) -> id(int)
        """
        if not self.tokenizer_path.exists():
            raise FileNotFoundError(f"TTS: 未找到 tokenizer.json: {self.tokenizer_path}")

        with open(self.tokenizer_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        model = data.get("model") or {}
        vocab = model.get("vocab")
        if not isinstance(vocab, dict):
            raise ValueError("TTS: tokenizer.json 格式不正确, 未找到 model.vocab 字段")

        # Stored as: character -> id
        self._vocab = {k: int(v) for k, v in vocab.items()}
        logger.info(f"TTS: tokenizer vocab 加载完成, 词表大小: {len(self._vocab)}")

    def _load_voices(self) -> None:
        """
        Load the voice style vectors (voices/*.bin).
        """
        if not self.voice_path.exists():
            raise FileNotFoundError(f"TTS: 未找到语音风格文件: {self.voice_path}")

        voices = np.fromfile(self.voice_path, dtype=np.float32)
        try:
            voices = voices.reshape(-1, 1, 256)
        except ValueError as e:
            raise ValueError(
                f"TTS: 语音风格文件形状不符合预期, 无法 reshape 为 (-1,1,256): {self.voice_path}"
            ) from e

        self._voices = voices
        logger.info(
            f"TTS: 语音风格文件加载完成: {self.voice_name}, 可用 style 数量: {voices.shape[0]}"
        )

    def _load_onnx_session(self) -> None:
        """
        Create the ONNX Runtime inference session.
        """
        if not self.onnx_path.exists():
            raise FileNotFoundError(f"TTS: 未找到 ONNX 模型文件: {self.onnx_path}")

        sess_options = ort.SessionOptions()
        # Enable all graph optimizations.
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # On multi-core CPUs (e.g. RK3588) the ORT thread counts can be pinned
        # via env vars to avoid too-few/too-many threads; unset leaves it to ORT.
        _ti = os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "").strip()
        if _ti.isdigit() and int(_ti) > 0:
            sess_options.intra_op_num_threads = int(_ti)
        _te = os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "").strip()
        if _te.isdigit() and int(_te) > 0:
            sess_options.inter_op_num_threads = int(_te)

        # Plain CPU inference (extend providers here if GPU support is needed).
        self._session = ort.InferenceSession(
            str(self.onnx_path),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        logger.info(f"TTS: Kokoro ONNX 模型加载完成: {self.onnx_path}")

    def _init_g2p(self) -> None:
        """
        Initialize Chinese G2P (based on misaki[zh]).

        If misaki is not installed, degrade to mapping raw text characters.
        """
        try:
            from misaki import zh  # type: ignore

            # Compatibility across misaki versions:
            # - newer: ZHG2P(version=...) is accepted
            # - older: ZHG2P() takes no arguments
            try:
                self._g2p = zh.ZHG2P(version="1.1")  # type: ignore[call-arg]
            except TypeError:
                self._g2p = zh.ZHG2P()  # type: ignore[call-arg]

            logger.info("TTS: 已启用 misaki[zh] G2P, 将使用音素级别映射")
        except Exception as e:
            self._g2p = None
            logger.warning(
                "TTS: 未安装或无法导入 misaki[zh], 将直接基于原始文本字符做 token 映射, "
                "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
            )
            logger.debug(f"TTS: G2P 初始化失败原因: {e!r}")

    # ------------------------------------------------------------------ #
    # Text -> phonemes / tokens
    # ------------------------------------------------------------------ #
    def _text_to_phonemes(self, text: str) -> str:
        """
        Text -> phoneme string.

        - If misaki[zh] is available, use ZHG2P(version='1.1') for phonemes.
        - Otherwise return the raw text (mapped per character downstream).
        """
        if self._g2p is None:
            return text.strip()

        # Compatibility across misaki versions:
        # - some return (phonemes, tokens)
        # - some return just the phoneme string
        result = self._g2p(text)
        if isinstance(result, tuple) or isinstance(result, list):
            ps = result[0]
        else:
            ps = result

        if not ps:
            # Fallback: if G2P yields nothing, use the raw text.
            logger.warning("TTS: G2P 结果为空, 回退为原始文本")
            return text.strip()
        return ps

    def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
        """
        Map a phoneme string to a token id sequence.

        Plain character-level table lookup:
        - each character has a unique id in the vocab
        - the space character is itself a token (id=16)
        """
        assert self._vocab is not None, "TTS: vocab 尚未初始化"
        vocab = self._vocab

        token_ids: List[int] = []
        unknown_chars = set()

        for ch in phonemes:
            if ch == "\n":
                continue  # newlines are not vocab entries; skip silently
            tid = vocab.get(ch)
            if tid is None:
                unknown_chars.add(ch)
                continue
            token_ids.append(int(tid))

        if unknown_chars:
            logger.debug(f"TTS: 存在无法映射到 vocab 的字符: {unknown_chars}")

        return token_ids

def _resolve_output_device_id(raw: object) -> Optional[int]:
|
||
"""
|
||
将配置中的 output_device 解析为 sounddevice 设备索引。
|
||
None / 空:返回 None,表示使用 sd 默认输出。
|
||
"""
|
||
import sounddevice as sd # type: ignore
|
||
|
||
if raw is None:
|
||
return None
|
||
if isinstance(raw, bool):
|
||
return None
|
||
if isinstance(raw, int):
|
||
return raw if raw >= 0 else None
|
||
s = str(raw).strip()
|
||
if not s or s.lower() in ("null", "none", "default", ""):
|
||
return None
|
||
if s.isdigit():
|
||
return int(s)
|
||
needle = s.lower()
|
||
devices = sd.query_devices()
|
||
matches: List[int] = []
|
||
for i, d in enumerate(devices):
|
||
if int(d.get("max_output_channels", 0) or 0) <= 0:
|
||
continue
|
||
name = (d.get("name") or "").lower()
|
||
if needle in name:
|
||
matches.append(i)
|
||
if not matches:
|
||
logger.warning(
|
||
f"TTS: 未找到名称包含 {s!r} 的输出设备,将使用系统默认输出。"
|
||
"请检查 system.yaml 中 tts.output_device 或查看启动日志中的设备列表。"
|
||
)
|
||
return None
|
||
if len(matches) > 1:
|
||
logger.info(
|
||
f"TTS: 名称 {s!r} 匹配到多个输出设备索引 {matches},使用第一个 {matches[0]}"
|
||
)
|
||
return matches[0]
|
||
|
||
|
||
# Cached resolution of tts.output_device (key = repr of the raw config value).
_playback_dev_cache: Optional[int] = None
_playback_dev_cache_key: Optional[str] = None
_playback_dev_cache_ready: bool = False


def get_playback_output_device_id() -> Optional[int]:
    """Resolve and cache the playback device index from SYSTEM_TTS_CONFIG
    (None means the system default output)."""
    global _playback_dev_cache, _playback_dev_cache_key, _playback_dev_cache_ready

    raw = (SYSTEM_TTS_CONFIG or {}).get("output_device")
    cache_key = repr(raw)
    if _playback_dev_cache_ready and _playback_dev_cache_key == cache_key:
        # Same configured value as last time: reuse the resolved index.
        return _playback_dev_cache

    dev_id = _resolve_output_device_id(raw)
    _playback_dev_cache = dev_id
    _playback_dev_cache_key = cache_key
    _playback_dev_cache_ready = True

    if dev_id is None:
        logger.info("TTS: 播放使用系统默认输出设备(未指定或无法匹配 tts.output_device)")
    else:
        import sounddevice as sd  # type: ignore

        info = sd.query_devices(dev_id)
        logger.info(
            f"TTS: 播放将使用输出设备 index={dev_id} name={info.get('name', '?')!r}"
        )
    return dev_id

def _sounddevice_default_output_index():
    """Best-effort extraction of the default *output* device index.

    On sounddevice 0.5+, ``sd.default.device`` may be an _InputOutputPair
    whose element [1] is the output side; older versions may expose a plain
    int.  Returns None when nothing usable is found.
    """
    import sounddevice as sd  # type: ignore

    raw_default = sd.default.device
    if isinstance(raw_default, (list, tuple)):
        return int(raw_default[1])
    if hasattr(raw_default, "__getitem__"):
        # Pair-like object that is neither list nor tuple.
        try:
            return int(raw_default[1])
        except (IndexError, TypeError, ValueError):
            pass
    try:
        return int(raw_default)
    except (TypeError, ValueError):
        return None

def log_sounddevice_output_devices() -> None:
    """Log every output-capable device plus the current default output, to
    help users pick a value for tts.output_device."""
    try:
        import sounddevice as sd  # type: ignore

        default_out = _sounddevice_default_output_index()
        logger.info("sounddevice 输出设备列表(用于配置 tts.output_device 索引或名称子串):")
        for idx, dev in enumerate(sd.query_devices()):
            if int(dev.get("max_output_channels", 0) or 0) <= 0:
                continue  # input-only device, not configurable for playback
            is_default = default_out is not None and int(default_out) == idx
            mark = " <- 当前默认输出" if is_default else ""
            logger.info(f" [{idx}] {dev.get('name', '?')}{mark}")
    except Exception as e:
        logger.warning(f"无法枚举 sounddevice 输出设备: {e}")

def _resample_playback_audio(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
|
||
"""将波形重采样到设备采样率;优先 librosa(kaiser_best),失败则回退 scipy 多相。"""
|
||
from math import gcd
|
||
|
||
from scipy.signal import resample # type: ignore
|
||
from scipy.signal import resample_poly # type: ignore
|
||
|
||
if abs(sr_in - sr_out) < 1:
|
||
return np.asarray(audio, dtype=np.float32)
|
||
try:
|
||
import librosa # type: ignore
|
||
|
||
return librosa.resample(
|
||
np.asarray(audio, dtype=np.float32),
|
||
orig_sr=sr_in,
|
||
target_sr=sr_out,
|
||
res_type="kaiser_best",
|
||
).astype(np.float32)
|
||
except Exception as e:
|
||
logger.debug(f"TTS: librosa 重采样不可用,使用多相重采样: {e!r}")
|
||
try:
|
||
g = gcd(int(sr_in), int(sr_out))
|
||
if g > 0:
|
||
up = int(sr_out) // g
|
||
down = int(sr_in) // g
|
||
return resample_poly(audio, up, down).astype(np.float32)
|
||
except Exception as e2:
|
||
logger.debug(f"TTS: resample_poly 失败,回退 FFT resample: {e2!r}")
|
||
num = max(1, int(len(audio) * float(sr_out) / float(sr_in)))
|
||
return resample(audio, num).astype(np.float32)
|
||
|
||
|
||
def _fade_playback_edges(audio: np.ndarray, sample_rate: int, fade_ms: float) -> np.ndarray:
|
||
"""极短线性淡入淡出,减轻扬声器/驱动在段首段尾的爆音与杂音感。"""
|
||
if fade_ms <= 0 or audio.size < 16:
|
||
return audio
|
||
n = int(float(sample_rate) * fade_ms / 1000.0)
|
||
n = min(n, len(audio) // 4)
|
||
if n <= 0:
|
||
return audio
|
||
out = np.asarray(audio, dtype=np.float32, order="C").copy()
|
||
fade = np.linspace(0.0, 1.0, n, dtype=np.float32)
|
||
out[:n] *= fade
|
||
out[-n:] *= fade[::-1]
|
||
return out
|
||
|
||
|
||
def play_tts_audio(
    audio: np.ndarray,
    sample_rate: int,
    *,
    output_device: Optional[object] = None,
) -> None:
    """
    Play mono float32 audio via sounddevice (blocks until playback finishes).

    On Windows, PortAudio/sounddevice called from a non-main thread often
    yields "no sound, no error"; in this project spoken feedback should be
    invoked from the main thread (the capture-loop thread).

    Also: many Realtek/WASAPI devices play 24000Hz as complete silence
    without any error, so the signal is resampled to the device's
    default_samplerate (commonly 48000/44100) and written via OutputStream.

    Args:
        output_device: if given, overrides tts.output_device from system.yaml
            (device index or name substring).
    """
    import sounddevice as sd  # type: ignore

    # Playback knobs from the system TTS config.
    cfg = SYSTEM_TTS_CONFIG or {}
    force_native = bool(cfg.get("playback_resample_to_device_native", True))
    do_normalize = bool(cfg.get("playback_peak_normalize", True))
    gain = float(cfg.get("playback_gain", 1.0))
    if gain <= 0:
        gain = 1.0  # non-positive gain is nonsensical; fall back to unity
    fade_ms = float(cfg.get("playback_edge_fade_ms", 8.0))
    latency = cfg.get("playback_output_latency", "low")
    if latency not in ("low", "medium", "high"):
        latency = "low"

    # Flatten to a 1-D float32 signal; refuse empty input.
    audio = np.asarray(audio, dtype=np.float32).squeeze()
    if audio.ndim > 1:
        audio = np.squeeze(audio)
    if audio.size == 0:
        logger.warning("TTS: 播放跳过,音频长度为 0")
        return

    # Device resolution: explicit argument > yaml config > system default.
    if output_device is not None:
        dev = _resolve_output_device_id(output_device)
    else:
        dev = get_playback_output_device_id()
    if dev is None:
        dev = _sounddevice_default_output_index()
    if dev is None:
        logger.warning("TTS: 无法解析输出设备索引,使用 sounddevice 默认输出")
    else:
        dev = int(dev)

    # Resample to the device-native rate (avoids silent playback on Windows).
    info = sd.query_devices(dev) if dev is not None else sd.query_devices(kind="output")
    native_sr = int(float(info.get("default_samplerate", 48000)))
    sr_out = int(sample_rate)
    if force_native and native_sr > 0 and abs(native_sr - sr_out) > 1:
        audio = _resample_playback_audio(audio, sr_out, native_sr)
        sr_out = native_sr
        logger.info(
            f"TTS: 播放重采样 {sample_rate}Hz -> {sr_out}Hz(匹配设备 default_samplerate,避免 Windows 无声)"
        )

    # Peak-normalize only when the signal is close to clipping (> 0.95).
    peak_before = float(np.max(np.abs(audio)))
    if do_normalize and peak_before > 1e-8 and peak_before > 0.95:
        audio = (audio / peak_before * 0.92).astype(np.float32, copy=False)

    if gain != 1.0:
        audio = (audio * np.float32(gain)).astype(np.float32, copy=False)

    # Short edge fades reduce clicks at segment boundaries.
    audio = _fade_playback_edges(audio, sr_out, fade_ms)

    # Diagnostics: log level/energy so silent playback is easy to spot.
    peak = float(np.max(np.abs(audio)))
    rms = float(np.sqrt(np.mean(np.square(audio))))
    dname = info.get("name", "?") if isinstance(info, dict) else "?"
    logger.info(
        f"TTS: 播放 峰值={peak:.5f} RMS={rms:.5f} sr={sr_out}Hz 设备={dev!r} ({dname!r})"
    )
    if peak < 1e-8:
        logger.warning("TTS: 波形接近静音,请检查合成是否异常")

    # Clip to [-1, 1] and add the channel dimension OutputStream expects.
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)
    block = audio.reshape(-1, 1)

    try:
        with sd.OutputStream(
            device=dev,
            channels=1,
            samplerate=sr_out,
            dtype="float32",
            latency=latency,
        ) as stream:
            stream.write(block)
    except Exception as e:
        # Fallback: some backends reject the OutputStream parameters.
        logger.warning(f"TTS: OutputStream 失败,回退 sd.play: {e}", exc_info=True)
        sd.play(block, samplerate=sr_out, device=dev)
        sd.wait()

def play_wav_path(
    path: str | Path,
    *,
    output_device: Optional[object] = None,
) -> None:
    """Play a 16-bit PCM WAV (mono, or stereo downmixed to mono) through the
    same sounddevice output path as synthesize + play_tts_audio (including
    ROCKET_TTS_DEVICE / yaml device resolution)."""
    import wave

    wav_file = Path(path)
    with wave.open(str(wav_file), "rb") as wf:
        n_channels = int(wf.getnchannels())
        sample_width = int(wf.getsampwidth())
        frame_rate = int(wf.getframerate())
        frame_count = int(wf.getnframes())
        if sample_width != 2:
            raise ValueError(f"仅支持 16-bit PCM: {wav_file}")
        payload = wf.readframes(frame_count)

    samples = np.frombuffer(payload, dtype="<i2").astype(np.float32, copy=False)
    if n_channels == 2:
        # Downmix stereo by averaging the two channels.
        samples = samples.reshape(-1, 2).mean(axis=1).astype(np.float32, copy=False)
    elif n_channels != 1:
        raise ValueError(f"仅支持 1 或 2 通道: {wav_file} (ch={n_channels})")
    samples = samples * np.float32(1.0 / 32768.0)
    logger.info("TTS: 播放预生成 WAV %s (%sHz, %s 采样)", wav_file.name, frame_rate, samples.size)
    play_tts_audio(samples, frame_rate, output_device=output_device)

def speak_text(
    text: str,
    tts: Optional["KokoroOnnxTTS"] = None,
    *,
    output_device: Optional[object] = None,
) -> None:
    """Synthesize *text* and play it; failures are logged, never raised
    (appropriate for best-effort feedback after a command succeeds)."""
    t = str(text).strip() if text else ""
    if not t:
        return
    try:
        engine = tts or KokoroOnnxTTS()
        logger.info(f"TTS: 开始合成并播放: {t!r}")
        audio, sr = engine.synthesize(t)
        play_tts_audio(audio, sr, output_device=output_device)
        logger.info("TTS: 播放完成")
    except Exception as e:
        logger.warning(f"TTS 播放失败: {e}", exc_info=True)

# Public API of this module.
__all__ = [
    "KokoroOnnxTTS",
    "play_tts_audio",
    "play_wav_path",
    "speak_text",
    "get_playback_output_device_id",
    "log_sounddevice_output_devices",
]

if __name__ == "__main__":
    # Manual smoke test. Consistent with the main program: play_tts_audio
    # resamples to the device-native sample rate before playback.
    tts = KokoroOnnxTTS()

    text = "任务执行完成,开始返航降落"

    print(f"正在合成语音: {text}")
    audio, sr = tts.synthesize(text)

    print("正在播放(与主程序相同 play_tts_audio 路径)...")
    try:
        play_tts_audio(audio, sr)
        print("播放结束。")
    except Exception as e:
        print(f"播放失败: {e}")

    # === Save as a WAV file (optional) ===
    try:
        output_path = "任务执行完成,开始返航降落.wav"
        tts.synthesize_to_file(text, output_path)
        print(f"音频已保存至: {output_path}")
    except RuntimeError as e:
        print(f"保存失败(可能缺少 scipy): {e}")