2026-04-14 10:08:41 +08:00

362 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Kokoro TTS 服务实现 - 基于 ONNX 的本地文字转语音
完整复用香橙派项目的 KokoroOnnxTTS 实现
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Generator, Optional, List
import numpy as np
import onnxruntime as ort
from loguru import logger
from app.services.tts_service import TTSServiceInterface
from app.config import settings
class KokoroTTSService(TTSServiceInterface):
    """
    Kokoro TTS service implementation (fully reuses the Orange Pi project code).

    Basic usage:
        service = KokoroTTSService()
        await service.initialize()
        for audio_chunk in service.synthesize("你好"):
            yield audio_chunk
    """

    def __init__(self):
        # Runtime state populated by initialize(); all methods other than
        # initialize() require _initialized to be True.
        self._session: Optional[ort.InferenceSession] = None
        self._vocab: Optional[dict] = None
        self._voices: Optional[np.ndarray] = None
        self._g2p = None  # misaki G2P; None -> raw-character fallback mapping
        self._initialized = False
        self._model_dir: Optional[Path] = None
        self._onnx_path: Optional[Path] = None
        self._voice_path: Optional[Path] = None
        self._tokenizer_path: Optional[Path] = None
        self._voice_name: str = "zf_001"
        self._speed: float = 1.15
        self._sample_rate: int = 24000  # Kokoro outputs 24 kHz audio

    async def initialize(self) -> bool:
        """Load the Kokoro model and all of its components.

        Returns:
            True on success; False on any failure. Failures are logged and
            never raised to the caller.
        """
        try:
            # Locate the model directory.
            model_dir = Path(settings.TTS_MODEL_DIR) / "Kokoro-82M-v1.1-zh-ONNX"
            if not model_dir.exists():
                logger.error(f"Kokoro 模型目录不存在: {model_dir}")
                return False
            self._model_dir = model_dir

            # Tokenizer path.
            self._tokenizer_path = model_dir / "tokenizer.json"
            if not self._tokenizer_path.exists():
                logger.error(f"tokenizer.json 不存在: {self._tokenizer_path}")
                return False

            # ONNX model path - prefer model_q4.onnx (fastest), then fall back.
            onnx_dir = model_dir / "onnx"
            model_preferences = [
                "model_q4.onnx",
                "model_fp16.onnx",
                "model.onnx",
                "model_int8.onnx",
                "model_uint8.onnx",
            ]
            self._onnx_path = next(
                (
                    candidate
                    for candidate in (onnx_dir / name for name in model_preferences)
                    if candidate.exists()
                ),
                None,
            )
            if not self._onnx_path:
                logger.error(f"未找到 ONNX 模型文件,搜索目录: {onnx_dir}")
                return False

            # Voice style file; fall back to the first available .bin if the
            # configured one is missing.
            self._voice_name = settings.TTS_VOICE_NAME or "zf_001"
            self._voice_path = model_dir / "voices" / f"{self._voice_name}.bin"
            if not self._voice_path.exists():
                logger.warning(f"语音文件不存在: {self._voice_path},尝试使用第一个可用文件")
                voices_dir = model_dir / "voices"
                voice_files = list(voices_dir.glob("*.bin"))
                if voice_files:
                    self._voice_path = voice_files[0]
                    self._voice_name = self._voice_path.stem
                    logger.info(f"使用语音文件: {self._voice_name}")
                else:
                    logger.error("没有任何语音文件")
                    return False

            # Load all components.
            self._load_tokenizer_vocab()
            self._load_voices()
            self._load_onnx_session()
            self._init_g2p()
            logger.info(
                f"Kokoro TTS 初始化成功: "
                f"model={self._onnx_path.name}, voice={self._voice_name}, "
                f"speed={self._speed}, sr={self._sample_rate}Hz"
            )
            self._initialized = True
            return True
        except Exception as e:
            # BUGFIX: loguru's logger.error() does not understand stdlib
            # logging's exc_info=True kwarg, so the traceback was silently
            # dropped. logger.exception() logs at ERROR level with the
            # traceback attached.
            logger.exception(f"Kokoro TTS 初始化失败: {e}")
            return False

    def _load_tokenizer_vocab(self):
        """Load the vocab mapping (symbol -> token id) from tokenizer.json.

        Raises:
            ValueError: if tokenizer.json has no model.vocab dict.
        """
        with open(self._tokenizer_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        model = data.get("model") or {}
        vocab = model.get("vocab")
        if not isinstance(vocab, dict):
            raise ValueError("tokenizer.json 格式不正确,未找到 model.vocab 字段")
        self._vocab = {k: int(v) for k, v in vocab.items()}
        logger.info(f"tokenizer vocab 加载完成,词表大小: {len(self._vocab)}")

    def _load_voices(self):
        """Load voice style vectors: flat float32 file reshaped to (-1, 1, 256).

        Raises:
            ValueError: if the file size is not a multiple of 256 floats.
        """
        voices = np.fromfile(str(self._voice_path), dtype=np.float32)
        try:
            voices = voices.reshape(-1, 1, 256)
        except ValueError as e:
            raise ValueError(
                f"语音风格文件形状不符合预期,无法 reshape 为 (-1,1,256): {self._voice_path}"
            ) from e
        self._voices = voices
        logger.info(f"语音风格文件加载完成: {self._voice_name}, 可用 style 数量: {voices.shape[0]}")

    def _load_onnx_session(self):
        """Create the ONNX Runtime inference session (CPU only)."""
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Thread counts are tunable via env vars; values <= 0 keep ORT defaults.
        intra_threads = int(os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "4"))
        inter_threads = int(os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "2"))
        if intra_threads > 0:
            sess_options.intra_op_num_threads = intra_threads
        if inter_threads > 0:
            sess_options.inter_op_num_threads = inter_threads
        # CPU inference.
        self._session = ort.InferenceSession(
            str(self._onnx_path),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        logger.info(f"ONNX 模型加载完成: {self._onnx_path}")

    def _init_g2p(self):
        """Initialize Chinese G2P (misaki[zh]); optional, falls back to raw text."""
        try:
            from misaki import zh

            try:
                self._g2p = zh.ZHG2P(version="1.1")
            except TypeError:
                # Older misaki versions do not accept a version kwarg.
                self._g2p = zh.ZHG2P()
            logger.info("已启用 misaki[zh] G2P将使用音素级别映射")
        except Exception as e:
            self._g2p = None
            logger.warning(
                "未安装或无法导入 misaki[zh],将直接基于原始文本字符做 token 映射,"
                "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
            )
            logger.debug(f"G2P 初始化失败原因: {e!r}")

    def _text_to_phonemes(self, text: str) -> str:
        """Convert text to a phoneme string (or return stripped text if no G2P)."""
        if self._g2p is None:
            return text.strip()
        result = self._g2p(text)
        # Some misaki versions return (phonemes, tokens); take the first item.
        ps = result[0] if isinstance(result, (tuple, list)) else result
        if not ps:
            logger.warning("G2P 结果为空,回退为原始文本")
            return text.strip()
        return ps

    def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
        """Map a phoneme string to token ids, skipping newlines and unknowns."""
        assert self._vocab is not None, "vocab 尚未初始化"
        vocab = self._vocab
        token_ids: List[int] = []
        unknown_chars = set()
        for ch in phonemes:
            if ch == "\n":
                continue
            tid = vocab.get(ch)
            if tid is None:
                unknown_chars.add(ch)
                continue
            token_ids.append(int(tid))
        if unknown_chars:
            logger.debug(f"存在无法映射到 vocab 的字符: {unknown_chars}")
        return token_ids

    def synthesize(
        self,
        text: str,
        sample_rate: int = 24000,
    ) -> Generator[np.ndarray, None, None]:
        """
        Stream synthesized speech.

        Args:
            text: input text (simplified Chinese recommended)
            sample_rate: target sample rate. NOTE(review): this only sizes the
                emitted chunks — the model output stays at self._sample_rate
                and is not resampled here; confirm callers pass 24000.

        Yields:
            Audio chunks (numpy float32 arrays).

        Raises:
            RuntimeError: if the service is not initialized.
            ValueError: if text is empty or whitespace-only.
        """
        if not self._initialized:
            raise RuntimeError("TTS 服务未初始化")
        if not text or not text.strip():
            raise ValueError("TTS 输入文本不能为空")
        try:
            # Text -> phonemes -> token ids.
            phonemes = self._text_to_phonemes(text)
            token_ids = self._phonemes_to_token_ids(phonemes)
            logger.debug(
                f"[Kokoro] 文本: {text[:50]}, 音素: {phonemes[:50]}, token数: {len(token_ids)}"
            )
            if len(token_ids) == 0:
                logger.warning(f"文本在当前 vocab 下无法映射到任何 token: {text!r}")
                return
            # Per the official Kokoro example, token sequences are capped at 510.
            if len(token_ids) > 510:
                logger.warning(f"token 长度 {len(token_ids)} > 510将被截断")
                token_ids = token_ids[:510]
            # Surround with pad tokens (id 0).
            tokens = np.array([[0, *token_ids, 0]], dtype=np.int64)
            # Select the style vector indexed by token length (clamped).
            voices = self._voices
            idx = min(len(token_ids), voices.shape[0] - 1)
            style = voices[idx]
            speed = np.array([self._speed], dtype=np.float32)
            # ONNX inference.
            logger.debug(
                f"[Kokoro] 开始推理: input_ids.shape={tokens.shape}, style.shape={style.shape}"
            )
            audio = self._session.run(
                None,
                {
                    "input_ids": tokens,
                    "style": style,
                    "speed": speed,
                },
            )[0]
            logger.debug(
                f"[Kokoro] 推理完成: audio.shape={audio.shape}, dtype={audio.dtype}, "
                f"max={np.max(np.abs(audio)):.4f}"
            )
            # Normalize output shape to a 1-D float32 array.
            audio = audio.astype(np.float32)
            if audio.ndim == 2 and audio.shape[0] == 1:
                audio = audio[0]
            elif audio.ndim > 2:
                audio = np.squeeze(audio)
            # Debug WAV is off by default: writing to disk delays the first
            # PCM chunk.
            if os.environ.get("ROCKET_KOKORO_DEBUG_WAV", "").strip().lower() in (
                "1",
                "true",
                "yes",
            ):
                try:
                    import soundfile as sf

                    debug_path = Path("debug_tts_output.wav")
                    sf.write(str(debug_path), audio, self._sample_rate)
                    logger.info(f"[Kokoro] 已保存调试音频到: {debug_path}")
                except Exception as e:
                    logger.debug(f"[Kokoro] 保存调试音频失败: {e}")
            # A slightly shorter first chunk gets the first bytes out faster;
            # subsequent chunks are ~100ms.
            # NOTE(review): despite the "_MS" env var name, the value is in
            # SECONDS (default 0.05 s = 50 ms, clamped to 0.02-0.2 s). The env
            # var name and default are kept for backward compatibility.
            first_chunk_sec = float(
                os.environ.get("ROCKET_KOKORO_FIRST_CHUNK_MS", "0.05")
            )
            first_chunk_sec = max(0.02, min(first_chunk_sec, 0.2))
            chunk_size = int(sample_rate * 0.1)
            first_chunk_size = int(sample_rate * first_chunk_sec)
            pos = 0
            first = True
            while pos < len(audio):
                n = first_chunk_size if first else chunk_size
                first = False
                chunk = audio[pos : pos + n]
                pos += len(chunk)
                yield chunk
        except Exception as e:
            logger.error(f"Kokoro TTS 合成失败: text='{text[:50]}...', 错误={e}")
            raise

    def synthesize_complete(self, text: str) -> tuple[np.ndarray, int]:
        """
        Synthesize the full utterance at once.

        Returns:
            (audio samples as a float32 ndarray, sample rate in Hz)

        Raises:
            RuntimeError: if the service is not initialized.
        """
        if not self._initialized:
            raise RuntimeError("TTS 服务未初始化")
        audio_parts = list(self.synthesize(text))
        if not audio_parts:
            return np.array([], dtype=np.float32), self._sample_rate
        return np.concatenate(audio_parts), self._sample_rate

    async def shutdown(self):
        """Release all loaded components and mark the service uninitialized."""
        self._session = None
        self._vocab = None
        self._voices = None
        self._g2p = None
        self._initialized = False
        logger.info("Kokoro TTS 服务已关闭")