DroneMind/voicellmcloud/test/test_utils.py
2026-04-14 10:08:41 +08:00

118 lines
3.3 KiB
Python

"""
测试辅助函数
"""
import json
import time
import numpy as np
import sounddevice as sd
from typing import Optional
def _to_playback_float32(audio_data: np.ndarray) -> np.ndarray:
    """Return *audio_data* as float32 samples for playback.

    Integer input is treated as 16-bit PCM and scaled by 1/32768 into roughly
    [-1, 1). Floating-point input is assumed to be normalized already and is
    only cast to float32; scaling it again would make it ~32768x quieter.
    """
    samples = np.asarray(audio_data)
    if np.issubdtype(samples.dtype, np.floating):
        return samples.astype(np.float32)
    return samples.astype(np.float32) / 32768.0


def play_audio(audio_data: np.ndarray, sample_rate: int = 24000):
    """Play an audio buffer and block until playback has finished.

    Args:
        audio_data: Sample array. Integer arrays are treated as 16-bit PCM;
            floating-point arrays are assumed to be already in [-1, 1].
        sample_rate: Sample rate in Hz passed to ``sd.play``.

    Playback is best-effort: ``None`` or empty input is skipped with a
    warning, and any exception raised while converting or playing is printed
    instead of propagated.
    """
    if audio_data is None or len(audio_data) == 0:
        print(" ⚠️ 音频数据为空,跳过播放")
        return
    try:
        audio_float = _to_playback_float32(audio_data)
        # blocking=True: only return once the device has finished playing.
        sd.play(audio_float, samplerate=sample_rate, blocking=True)
        print(f" 🔊 音频播放完成 ({len(audio_data)} samples, {len(audio_data)/sample_rate:.2f}s)")
    except Exception as e:
        print(f" ❌ 音频播放失败: {e}")
def print_json(data: dict, indent: int = 2):
    """Pretty-print *data* as JSON, keeping non-ASCII text (e.g. Chinese) readable.

    NumPy scalars and arrays are converted to their plain Python equivalents
    instead of making ``json.dumps`` raise; any other unsupported type still
    raises ``TypeError``.

    Args:
        data: Mapping to serialize and print.
        indent: Indentation width passed to ``json.dumps``.
    """
    def _numpy_default(obj):
        # json.dumps calls this only for objects it cannot serialize natively.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    print(json.dumps(data, ensure_ascii=False, indent=indent, default=_numpy_default))
def calculate_metrics(
    llm_ms: Optional[int],
    tts_ms: Optional[int],
    audio_length_s: float,
    total_time_s: float,
) -> dict:
    """Build a performance-metrics dict for one LLM + TTS run.

    Args:
        llm_ms: LLM latency in milliseconds, or None if not available.
        tts_ms: TTS latency in milliseconds, or None if not available.
        audio_length_s: Length of the produced audio in seconds.
        total_time_s: Total elapsed time in seconds.

    Returns:
        A dict that always contains ``llm_latency_ms``, ``tts_latency_ms``,
        ``audio_length_s`` and ``total_time_s`` (the latencies may be None).
        It also contains ``llm_latency_s`` / ``tts_latency_s`` when the
        corresponding latency is not None, and ``realtime_factor``
        (``audio_length_s / total_time_s``) when both durations are positive.
    """
    metrics = {
        "llm_latency_ms": llm_ms,
        "tts_latency_ms": tts_ms,
        "audio_length_s": audio_length_s,
        "total_time_s": total_time_s,
    }
    # Compare against None rather than relying on truthiness: a measured 0 ms
    # latency is a real value and still deserves its seconds counterpart.
    if llm_ms is not None:
        metrics["llm_latency_s"] = llm_ms / 1000
    if tts_ms is not None:
        metrics["tts_latency_s"] = tts_ms / 1000
    if audio_length_s > 0 and total_time_s > 0:
        metrics["realtime_factor"] = audio_length_s / total_time_s
    return metrics
def print_metrics(metrics: dict):
    """Print a metrics dict (as built by ``calculate_metrics``) in readable form.

    LLM/TTS latency lines are printed only when the ``*_latency_ms`` value is
    present and not None. ``audio_length_s`` and ``total_time_s`` are
    required keys; ``realtime_factor`` is printed when present.
    """
    print("\n 📊 性能指标:")
    # calculate_metrics always stores the *_ms keys, possibly as None, so check
    # the value instead of key presence to avoid printing "Nonems".
    llm_ms = metrics.get("llm_latency_ms")
    if llm_ms is not None:
        llm_s = metrics.get("llm_latency_s", llm_ms / 1000)
        print(f" LLM 推理: {llm_ms}ms ({llm_s:.2f}s)")
    tts_ms = metrics.get("tts_latency_ms")
    if tts_ms is not None:
        tts_s = metrics.get("tts_latency_s", tts_ms / 1000)
        print(f" TTS 合成: {tts_ms}ms ({tts_s:.2f}s)")
    print(f" 音频长度: {metrics['audio_length_s']:.2f}s")
    print(f" 总耗时: {metrics['total_time_s']:.2f}s")
    if "realtime_factor" in metrics:
        print(f" 实时率: {metrics['realtime_factor']:.2f}x")
class TestResult:
    """Collect the outcome of one test case: pass/fail, errors, details and timing."""

    # The class name starts with "Test" and it lives in a test_*.py module, so
    # pytest would otherwise try to collect it as a test class (and warn,
    # because it defines __init__). Opt out explicitly.
    __test__ = False

    def __init__(self, test_name: str):
        self.test_name = test_name
        self.success = False
        self.errors = []  # error message strings, printed on failure
        self.details = {}  # free-form key -> value info shown in the report
        self.start_time = time.time()
        # Set by finalize(); None means the run has not been finalized yet.
        self.elapsed: Optional[float] = None

    def add_error(self, error: str):
        """Record an error message."""
        self.errors.append(error)

    def add_detail(self, key: str, value):
        """Record (or overwrite) a detail entry shown in the report."""
        self.details[key] = value

    def add_warning(self, warning: str):
        """Append a warning to the ``"warnings"`` list inside ``details``."""
        self.details.setdefault("warnings", []).append(warning)

    def mark_success(self):
        """Mark the test as passed."""
        self.success = True

    def finalize(self):
        """Freeze the elapsed time since construction and return ``self`` for chaining."""
        self.elapsed = time.time() - self.start_time
        return self

    def print_report(self):
        """Print a formatted report and return whether the test passed.

        May be called before ``finalize()``; the elapsed time shown is then
        measured up to the moment of the call.
        """
        if self.elapsed is not None:
            elapsed = self.elapsed
        else:
            elapsed = time.time() - self.start_time
        print(f"\n{'='*60}")
        print(f" 测试: {self.test_name}")
        print(f" 耗时: {elapsed:.2f}s")
        print(f"{'='*60}")
        if self.success:
            print(" ✅ 通过")
        else:
            print(" ❌ 失败")
            for error in self.errors:
                print(f" - {error}")
        if self.details:
            print("\n 详细信息:")
            for key, value in self.details.items():
                print(f" {key}: {value}")
        print()
        return self.success