"""Kokoro TTS service implementation - ONNX-based local text-to-speech.

Fully reuses the KokoroOnnxTTS implementation from the Orange Pi project.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Generator, Optional, List

import numpy as np
import onnxruntime as ort
from loguru import logger

from app.services.tts_service import TTSServiceInterface
from app.config import settings
class KokoroTTSService(TTSServiceInterface):
    """Kokoro TTS service implementation (fully reuses the Orange Pi code).

    Basic usage::

        service = KokoroTTSService()
        await service.initialize()
        for audio_chunk in service.synthesize("你好"):
            yield audio_chunk
    """

    def __init__(self) -> None:
        # Runtime components; all populated by initialize().
        self._session: Optional[ort.InferenceSession] = None  # ONNX Runtime session
        self._vocab: Optional[dict] = None  # phoneme character -> token id
        self._voices: Optional[np.ndarray] = None  # style vectors, shape (N, 1, 256)
        self._g2p = None  # misaki[zh] G2P callable, or None if unavailable
        self._initialized = False

        # Resolved filesystem paths, set during initialize().
        self._model_dir: Optional[Path] = None
        self._onnx_path: Optional[Path] = None
        self._voice_path: Optional[Path] = None
        self._tokenizer_path: Optional[Path] = None

        # Synthesis defaults (voice may be overridden by settings).
        self._voice_name: str = "zf_001"
        self._speed: float = 1.15
        self._sample_rate: int = 24000

    async def initialize(self) -> bool:
        """Load the Kokoro model and all of its components.

        Returns:
            True on success, False on any failure. Errors are logged and
            never propagated to the caller.
        """
        try:
            # Locate the model directory.
            model_dir = Path(settings.TTS_MODEL_DIR) / "Kokoro-82M-v1.1-zh-ONNX"
            if not model_dir.exists():
                logger.error(f"Kokoro 模型目录不存在: {model_dir}")
                return False

            self._model_dir = model_dir

            # tokenizer path
            self._tokenizer_path = model_dir / "tokenizer.json"
            if not self._tokenizer_path.exists():
                logger.error(f"tokenizer.json 不存在: {self._tokenizer_path}")
                return False

            # ONNX model path - prefer model_q4.onnx (fastest variant).
            onnx_dir = model_dir / "onnx"
            model_preferences = [
                "model_q4.onnx",
                "model_fp16.onnx",
                "model.onnx",
                "model_int8.onnx",
                "model_uint8.onnx",
            ]

            self._onnx_path = None
            for pref in model_preferences:
                candidate = onnx_dir / pref
                if candidate.exists():
                    self._onnx_path = candidate
                    break

            if not self._onnx_path:
                logger.error(f"未找到 ONNX 模型文件,搜索目录: {onnx_dir}")
                return False

            # Voice style file; fall back to the first available *.bin.
            self._voice_name = settings.TTS_VOICE_NAME or "zf_001"
            self._voice_path = model_dir / "voices" / f"{self._voice_name}.bin"
            if not self._voice_path.exists():
                logger.warning(f"语音文件不存在: {self._voice_path},尝试使用第一个可用文件")
                voices_dir = model_dir / "voices"
                voice_files = list(voices_dir.glob("*.bin"))
                if voice_files:
                    self._voice_path = voice_files[0]
                    self._voice_name = self._voice_path.stem
                    logger.info(f"使用语音文件: {self._voice_name}")
                else:
                    logger.error("没有任何语音文件")
                    return False

            # Load every component; each helper raises on failure.
            self._load_tokenizer_vocab()
            self._load_voices()
            self._load_onnx_session()
            self._init_g2p()

            logger.info(
                f"Kokoro TTS 初始化成功: "
                f"model={self._onnx_path.name}, voice={self._voice_name}, "
                f"speed={self._speed}, sr={self._sample_rate}Hz"
            )

            self._initialized = True
            return True

        except Exception as e:
            # BUGFIX: loguru does not understand the stdlib `exc_info` kwarg
            # (the original passed exc_info=True, which was silently dropped
            # and no traceback was ever recorded). logger.exception() is the
            # loguru-native way to attach the current traceback.
            logger.exception(f"Kokoro TTS 初始化失败: {e}")
            return False

    def _load_tokenizer_vocab(self) -> None:
        """Load the vocab mapping from tokenizer.json.

        Raises:
            ValueError: when tokenizer.json has no ``model.vocab`` dict.
        """
        with open(self._tokenizer_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        model = data.get("model") or {}
        vocab = model.get("vocab")
        if not isinstance(vocab, dict):
            raise ValueError("tokenizer.json 格式不正确,未找到 model.vocab 字段")

        # Normalize values to int in case the JSON stores them differently.
        self._vocab = {k: int(v) for k, v in vocab.items()}
        logger.info(f"tokenizer vocab 加载完成,词表大小: {len(self._vocab)}")

    def _load_voices(self) -> None:
        """Load the voice style vectors from the raw float32 .bin file.

        Raises:
            ValueError: when the file size is not a multiple of 256 floats.
        """
        voices = np.fromfile(str(self._voice_path), dtype=np.float32)
        try:
            # One style vector of 256 floats per token-count bucket.
            voices = voices.reshape(-1, 1, 256)
        except ValueError as e:
            raise ValueError(
                f"语音风格文件形状不符合预期,无法 reshape 为 (-1,1,256): {self._voice_path}"
            ) from e

        self._voices = voices
        logger.info(f"语音风格文件加载完成: {self._voice_name}, 可用 style 数量: {voices.shape[0]}")

    def _load_onnx_session(self) -> None:
        """Create the ONNX Runtime inference session (CPU only)."""
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Thread counts are tunable via environment; non-positive values
        # leave the ONNX Runtime defaults in place.
        intra_threads = int(os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "4"))
        inter_threads = int(os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "2"))
        if intra_threads > 0:
            sess_options.intra_op_num_threads = intra_threads
        if inter_threads > 0:
            sess_options.inter_op_num_threads = inter_threads

        # CPU inference only.
        self._session = ort.InferenceSession(
            str(self._onnx_path),
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        logger.info(f"ONNX 模型加载完成: {self._onnx_path}")

    def _init_g2p(self) -> None:
        """Initialize the Chinese G2P (misaki[zh]); optional, best-effort.

        On any import/constructor failure the service degrades to mapping
        raw text characters directly to tokens (quality suffers).
        """
        try:
            from misaki import zh

            try:
                # Newer misaki accepts a model version argument.
                self._g2p = zh.ZHG2P(version="1.1")
            except TypeError:
                # Older misaki has no `version` parameter.
                self._g2p = zh.ZHG2P()

            logger.info("已启用 misaki[zh] G2P,将使用音素级别映射")

        except Exception as e:
            self._g2p = None
            logger.warning(
                "未安装或无法导入 misaki[zh],将直接基于原始文本字符做 token 映射,"
                "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
            )
            logger.debug(f"G2P 初始化失败原因: {e!r}")

    def _text_to_phonemes(self, text: str) -> str:
        """Convert text to a phoneme string.

        Falls back to the stripped raw text when no G2P is available or
        when G2P yields an empty result.
        """
        if self._g2p is None:
            return text.strip()

        result = self._g2p(text)
        # misaki may return (phonemes, extras) or the phoneme string alone;
        # guard against an empty sequence before indexing.
        if isinstance(result, (tuple, list)):
            ps = result[0] if result else ""
        else:
            ps = result

        if not ps:
            logger.warning("G2P 结果为空,回退为原始文本")
            return text.strip()

        return ps

    def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
        """Map a phoneme string to a token-id sequence.

        Unknown characters are skipped (collected for a debug log);
        newlines are ignored.
        """
        assert self._vocab is not None, "vocab 尚未初始化"
        vocab = self._vocab

        token_ids: List[int] = []
        unknown_chars = set()

        for ch in phonemes:
            if ch == "\n":
                continue
            tid = vocab.get(ch)
            if tid is None:
                unknown_chars.add(ch)
                continue
            token_ids.append(int(tid))

        if unknown_chars:
            logger.debug(f"存在无法映射到 vocab 的字符: {unknown_chars}")

        return token_ids

    def synthesize(
        self,
        text: str,
        sample_rate: int = 24000,
    ) -> Generator[np.ndarray, None, None]:
        """Stream synthesized speech.

        Args:
            text: Input text (simplified Chinese recommended).
            sample_rate: Target sample rate. NOTE(review): this only sizes
                the yielded chunks; the model always emits audio at
                ``self._sample_rate`` and no resampling is performed —
                confirm callers pass 24000.

        Yields:
            Audio chunks (numpy float32 arrays).

        Raises:
            RuntimeError: when the service is not initialized.
            ValueError: when ``text`` is empty or whitespace.
        """
        if not self._initialized:
            raise RuntimeError("TTS 服务未初始化")

        if not text or not text.strip():
            raise ValueError("TTS 输入文本不能为空")

        try:
            # Text -> phonemes -> token ids.
            phonemes = self._text_to_phonemes(text)
            token_ids = self._phonemes_to_token_ids(phonemes)

            logger.debug(
                f"[Kokoro] 文本: {text[:50]}, 音素: {phonemes[:50]}, token数: {len(token_ids)}"
            )

            if len(token_ids) == 0:
                logger.warning(f"文本在当前 vocab 下无法映射到任何 token: {text!r}")
                return

            # Per the official Kokoro example: token sequence length <= 510.
            if len(token_ids) > 510:
                logger.warning(f"token 长度 {len(token_ids)} > 510,将被截断")
                token_ids = token_ids[:510]

            # Surround with pad tokens (id 0).
            tokens = np.array([[0, *token_ids, 0]], dtype=np.int64)

            # Pick the style vector indexed by the token count (clamped).
            voices = self._voices
            idx = min(len(token_ids), voices.shape[0] - 1)
            style = voices[idx]

            speed = np.array([self._speed], dtype=np.float32)

            # ONNX inference.
            logger.debug(
                f"[Kokoro] 开始推理: input_ids.shape={tokens.shape}, style.shape={style.shape}"
            )
            audio = self._session.run(
                None,
                {
                    "input_ids": tokens,
                    "style": style,
                    "speed": speed,
                },
            )[0]

            logger.debug(
                f"[Kokoro] 推理完成: audio.shape={audio.shape}, dtype={audio.dtype}, "
                f"max={np.max(np.abs(audio)):.4f}"
            )

            # Normalize the output shape to a 1-D float32 waveform.
            audio = audio.astype(np.float32)
            if audio.ndim == 2 and audio.shape[0] == 1:
                audio = audio[0]
            elif audio.ndim > 2:
                audio = np.squeeze(audio)

            # Debug WAV is off by default: writing to disk delays the first
            # PCM packet.
            if os.environ.get("ROCKET_KOKORO_DEBUG_WAV", "").strip().lower() in (
                "1",
                "true",
                "yes",
            ):
                try:
                    import soundfile as sf

                    debug_path = Path("debug_tts_output.wav")
                    sf.write(str(debug_path), audio, self._sample_rate)
                    logger.info(f"[Kokoro] 已保存调试音频到: {debug_path}")
                except Exception as e:
                    logger.debug(f"[Kokoro] 保存调试音频失败: {e}")

            # First chunk is shorter so the first byte goes out sooner; later
            # chunks stay at ~100 ms. NOTE(review): despite the "_MS" suffix
            # the env value is read in SECONDS (default "0.05" -> 50 ms) —
            # renaming would break deployed configs, so only documenting.
            first_chunk_ms = float(
                os.environ.get("ROCKET_KOKORO_FIRST_CHUNK_MS", "0.05")
            )
            first_chunk_ms = max(0.02, min(first_chunk_ms, 0.2))
            chunk_size = int(sample_rate * 0.1)
            first_chunk_size = int(sample_rate * first_chunk_ms)
            pos = 0
            first = True
            while pos < len(audio):
                n = first_chunk_size if first else chunk_size
                first = False
                chunk = audio[pos : pos + n]
                pos += len(chunk)
                yield chunk

        except Exception as e:
            logger.error(f"Kokoro TTS 合成失败: text='{text[:50]}...', 错误={e}")
            raise

    def synthesize_complete(self, text: str) -> tuple[np.ndarray, int]:
        """Synthesize the full utterance in one call.

        Returns:
            (audio data as float32 array, sample rate). The array is empty
            when synthesis produced no audio.

        Raises:
            RuntimeError: when the service is not initialized.
        """
        if not self._initialized:
            raise RuntimeError("TTS 服务未初始化")

        # Drain the streaming generator into one contiguous buffer.
        audio_parts = list(self.synthesize(text))
        if not audio_parts:
            return np.array([], dtype=np.float32), self._sample_rate
        return np.concatenate(audio_parts), self._sample_rate

    async def shutdown(self) -> None:
        """Release all loaded components and mark the service uninitialized."""
        self._session = None
        self._vocab = None
        self._voices = None
        self._g2p = None
        self._initialized = False
        logger.info("Kokoro TTS 服务已关闭")
|