DroneMind/voice_drone/core/tts_ack_cache.py
2026-04-14 09:54:26 +08:00

153 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
应答 TTS 波形磁盘缓存:文案与 TTS 配置未变时跳过逐条合成,加快启动。
缓存目录:项目根下 cache/ack_tts_pcm/
"""
from __future__ import annotations
import hashlib
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from voice_drone.core.configuration import SYSTEM_TTS_CONFIG
# Project root, resolved the same way as in voice_drone.core.configuration:
# this file is <root>/voice_drone/core/tts_ack_cache.py, so parents[2] is <root>.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the per-phrase .npz waveforms and the manifest.
ACK_PCM_CACHE_DIR = _PROJECT_ROOT.joinpath("cache", "ack_tts_pcm")
# File name of the manifest inside ACK_PCM_CACHE_DIR.
MANIFEST_NAME = "manifest.json"
# On-disk layout version; mixed into the fingerprint and checked against the
# manifest, so changing it invalidates every existing cache entry.
CACHE_FORMAT = 1
def _tts_signature() -> dict:
    """Collect the TTS settings that determine the synthesized waveform.

    Reads SYSTEM_TTS_CONFIG (an empty/falsy config is treated as ``{}``).
    Missing keys fall back to "" for the string fields, 1.0 for ``speed``
    and 24000 for ``sample_rate``. ``speed`` is rounded to 6 decimals so
    tiny float noise does not change the cache fingerprint.

    Returns:
        Dict with keys model_dir, model_name, voice, speed, sample_rate.
    """
    cfg = SYSTEM_TTS_CONFIG or {}
    signature = {
        field: str(cfg.get(field, ""))
        for field in ("model_dir", "model_name", "voice")
    }
    signature["speed"] = round(float(cfg.get("speed", 1.0)), 6)
    signature["sample_rate"] = int(cfg.get("sample_rate", 24000))
    return signature
def compute_ack_pcm_fingerprint(
    unique_phrases: List[str],
    *,
    global_text: Optional[str] = None,
    mode_phrases: bool = True,
) -> str:
    """Hash the ack texts plus the TTS settings into a cache fingerprint.

    Any change to the hashed text, to the TTS signature, or to CACHE_FORMAT
    produces a different fingerprint, which invalidates the disk cache.

    Args:
        unique_phrases: Phrases to synthesize. They are hashed in sorted
            order, so input order does not matter. Only used when
            ``mode_phrases`` is true.
        global_text: Single text hashed (whitespace-stripped, ``None`` -> "")
            when ``mode_phrases`` is false.
        mode_phrases: Chooses which of the two inputs is hashed.

    Returns:
        Lowercase hex SHA-256 digest of the canonical JSON payload.
    """
    base = {
        "cache_format": CACHE_FORMAT,
        "tts": _tts_signature(),
        "mode_phrases": mode_phrases,
    }
    text_part = (
        {"phrases": sorted(unique_phrases)}
        if mode_phrases
        else {"global_text": (global_text or "").strip()}
    )
    # sort_keys gives a canonical serialization independent of dict order.
    canonical = json.dumps({**base, **text_part}, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
def _phrase_file_stem(fingerprint: str, phrase: str) -> str:
    """Derive the 40-hex-character file stem for *phrase* under *fingerprint*.

    The NUL byte between the two parts keeps pairs such as ("ab", "c") and
    ("a", "bc") from hashing to the same stem.
    """
    material = b"\0".join((fingerprint.encode("utf-8"), phrase.encode("utf-8")))
    return hashlib.sha256(material).hexdigest()[:40]
def _load_one_npz(path: Path) -> Optional[Tuple[np.ndarray, int]]:
    """Load one cached phrase waveform from an ``.npz`` file.

    Args:
        path: File expected to contain the arrays ``audio`` and ``sr``
            (the layout written by ``persist_phrases``).

    Returns:
        ``(audio, sr)``, where ``audio`` is a squeezed float32 array and
        ``sr`` an int. Returns ``None`` in any of these cases:

        - the file cannot be read or parsed;
        - a key is missing;
        - the audio is empty;
        - the sample rate is not positive.

        Callers treat ``None`` as a cache miss.
    """
    try:
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it deterministically. The arrays read
        # from it are in-memory copies and stay valid after closing.
        with np.load(path, allow_pickle=False) as archive:
            audio = np.asarray(archive["audio"], dtype=np.float32).squeeze()
            sr = int(np.asarray(archive["sr"]).reshape(-1)[0])
    except Exception:
        # Best-effort cache: any read/parse failure just means "re-synthesize".
        return None
    if audio.size == 0 or sr <= 0:
        return None
    return audio, sr
def _read_valid_manifest(manifest_path: Path, fingerprint: str) -> Optional[Dict[str, object]]:
    """Return the manifest's phrase -> file-name map if it matches *fingerprint*, else None."""
    try:
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, bad UTF-8 or invalid JSON: no usable cache.
        return None
    # A truncated or hand-edited manifest may parse to a non-object; treat it
    # as a cache miss instead of crashing on .get().
    if not isinstance(manifest, dict):
        return None
    try:
        fmt = int(manifest.get("format", 0))
    except (TypeError, ValueError, OverflowError):
        return None
    if fmt != CACHE_FORMAT or manifest.get("fingerprint") != fingerprint:
        return None
    files = manifest.get("files") or {}
    return files if isinstance(files, dict) else None


def load_cached_phrases(
    unique_phrases: List[str],
    fingerprint: str,
) -> Tuple[Dict[str, Tuple[np.ndarray, int]], List[str]]:
    """Load cached waveforms from disk whose manifest matches *fingerprint*.

    Args:
        unique_phrases: Phrases the caller needs audio for.
        fingerprint: Value from ``compute_ack_pcm_fingerprint``; the manifest
            must carry the same fingerprint and CACHE_FORMAT to be used.

    Returns:
        ``(loaded, missing)``:

        - ``loaded`` maps phrase -> ``(audio, sr)`` for every phrase found in
          the cache;
        - ``missing`` lists, in input order, the phrases that still need
          synthesis.

        Every phrase is reported missing when the manifest is absent,
        unreadable or malformed, or when its format or fingerprint does not
        match.
    """
    if not unique_phrases:
        return {}, []
    cache_dir = ACK_PCM_CACHE_DIR
    files = _read_valid_manifest(cache_dir / MANIFEST_NAME, fingerprint)
    if files is None:
        return {}, list(unique_phrases)

    loaded: Dict[str, Tuple[np.ndarray, int]] = {}
    missing: List[str] = []
    for phrase in unique_phrases:
        fname = files.get(phrase)
        pcm: Optional[Tuple[np.ndarray, int]] = None
        # Ignore non-string/empty entries rather than letting `cache_dir / fname` raise.
        if isinstance(fname, str) and fname:
            path = cache_dir / fname
            if path.is_file():
                pcm = _load_one_npz(path)
        if pcm is None:
            missing.append(phrase)
        else:
            loaded[phrase] = pcm
    return loaded, missing
def persist_phrases(fingerprint: str, phrase_pcm: Dict[str, Tuple[np.ndarray, int]]) -> None:
    """Write one ``.npz`` per phrase and replace the cache manifest.

    Each phrase's audio is stored as float32 (squeezed) with its sample rate
    as a 1-element int32 array, under a file name derived from
    ``_phrase_file_stem(fingerprint, phrase)``. The manifest is rebuilt from
    *phrase_pcm* alone and overwrites any existing one. It is written to a
    ``.tmp`` sibling first and then moved over the real file with
    ``Path.replace`` (an atomic rename on POSIX). Does nothing when
    *phrase_pcm* is empty.
    """
    if not phrase_pcm:
        return
    target_dir = ACK_PCM_CACHE_DIR
    target_dir.mkdir(parents=True, exist_ok=True)

    file_map: Dict[str, str] = {}
    for phrase, (samples, rate) in phrase_pcm.items():
        npz_name = _phrase_file_stem(fingerprint, phrase) + ".npz"
        np.savez_compressed(
            target_dir / npz_name,
            audio=np.asarray(samples, dtype=np.float32).squeeze(),
            sr=np.array([int(rate)], dtype=np.int32),
        )
        file_map[phrase] = npz_name

    manifest_doc = {"format": CACHE_FORMAT, "fingerprint": fingerprint, "files": file_map}
    staging = target_dir / f"{MANIFEST_NAME}.tmp"
    with staging.open("w", encoding="utf-8") as fh:
        json.dump(manifest_doc, fh, ensure_ascii=False, indent=0)
    staging.replace(target_dir / MANIFEST_NAME)