118 lines
3.3 KiB
Python
118 lines
3.3 KiB
Python
"""
|
|
测试辅助函数
|
|
"""
|
|
import json
|
|
import time
|
|
import numpy as np
|
|
import sounddevice as sd
|
|
from typing import Optional
|
|
|
|
|
|
def play_audio(audio_data: np.ndarray, sample_rate: int = 24000):
    """Play an audio buffer on the default output device and wait until it finishes.

    Args:
        audio_data: Sample array. Integer arrays are treated as 16-bit PCM and
            scaled by 1/32768 into float32; floating-point arrays are assumed to
            already be normalized and are only cast to float32.
        sample_rate: Playback rate in Hz (default 24000).

    Empty input is skipped with a warning. Any playback failure is reported on
    stdout instead of being raised, because this is a best-effort test helper.
    """
    if len(audio_data) == 0:
        print(" ⚠️ 音频数据为空,跳过播放")
        return

    try:
        # Only integer PCM needs rescaling; dividing already-normalized float
        # audio by 32768 would make it practically inaudible.
        if np.issubdtype(audio_data.dtype, np.floating):
            audio_float = audio_data.astype(np.float32)
        else:
            audio_float = audio_data.astype(np.float32) / 32768.0

        # blocking=True: return only after playback has completed.
        sd.play(audio_float, samplerate=sample_rate, blocking=True)
        print(f" 🔊 音频播放完成 ({len(audio_data)} samples, {len(audio_data)/sample_rate:.2f}s)")
    except Exception as e:
        print(f" ❌ 音频播放失败: {e}")
|
|
|
|
|
|
def print_json(data: dict, indent: int = 2):
    """Pretty-print ``data`` as JSON to stdout.

    Non-ASCII characters (e.g. Chinese text) are emitted as-is rather than
    as ``\\uXXXX`` escapes.

    Args:
        data: JSON-serializable mapping to print.
        indent: Number of spaces per nesting level (default 2).
    """
    rendered = json.dumps(data, indent=indent, ensure_ascii=False)
    print(rendered)
|
|
|
|
|
|
def calculate_metrics(
    llm_ms: Optional[int],
    tts_ms: Optional[int],
    audio_length_s: float,
    total_time_s: float,
) -> dict:
    """Collect latency and throughput numbers for one run into a dict.

    Args:
        llm_ms: LLM inference latency in milliseconds, or None if not measured.
        tts_ms: TTS synthesis latency in milliseconds, or None if not measured.
        audio_length_s: Duration of the produced audio in seconds.
        total_time_s: Wall-clock duration of the whole run in seconds.

    Returns:
        A dict that always holds the four raw inputs under ``llm_latency_ms``,
        ``tts_latency_ms``, ``audio_length_s`` and ``total_time_s``.
        ``llm_latency_s`` / ``tts_latency_s`` are added whenever the matching
        millisecond value is not None (0 ms included). ``realtime_factor``
        (audio seconds produced per wall-clock second, so > 1 means faster than
        real time) is added only when both durations are positive.
    """
    metrics = {
        "llm_latency_ms": llm_ms,
        "tts_latency_ms": tts_ms,
        "audio_length_s": audio_length_s,
        "total_time_s": total_time_s,
    }

    # Compare against None instead of relying on truthiness: 0 ms is a valid
    # measurement, only None means "not measured".
    if llm_ms is not None:
        metrics["llm_latency_s"] = llm_ms / 1000

    if tts_ms is not None:
        metrics["tts_latency_s"] = tts_ms / 1000

    # Guard against division by zero and meaningless negative/empty runs.
    if audio_length_s > 0 and total_time_s > 0:
        metrics["realtime_factor"] = audio_length_s / total_time_s

    return metrics
|
|
|
|
|
|
def print_metrics(metrics: dict):
    """Print a metrics dict (such as the one built by ``calculate_metrics``) to stdout.

    ``audio_length_s`` and ``total_time_s`` must be present (KeyError otherwise).
    The LLM / TTS latency lines are printed only when the matching
    ``*_latency_ms`` value exists and is not None. When the ``*_latency_s``
    key is absent, the seconds figure is derived as ms / 1000. The
    ``realtime_factor`` line is printed only when that key is present.
    """
    print("\n 📊 性能指标:")

    # calculate_metrics always stores the *_ms keys, possibly as None, so a
    # plain membership test would print "Nonems"; skip unmeasured stages.
    llm_ms = metrics.get("llm_latency_ms")
    if llm_ms is not None:
        llm_s = metrics.get("llm_latency_s", llm_ms / 1000)
        print(f" LLM 推理: {llm_ms}ms ({llm_s:.2f}s)")

    tts_ms = metrics.get("tts_latency_ms")
    if tts_ms is not None:
        tts_s = metrics.get("tts_latency_s", tts_ms / 1000)
        print(f" TTS 合成: {tts_ms}ms ({tts_s:.2f}s)")

    print(f" 音频长度: {metrics['audio_length_s']:.2f}s")
    print(f" 总耗时: {metrics['total_time_s']:.2f}s")

    if "realtime_factor" in metrics:
        print(f" 实时率: {metrics['realtime_factor']:.2f}x")
|
|
|
|
|
|
class TestResult:
    """Accumulates the outcome of one named test: pass/fail flag, errors, details and timing.

    Typical flow: construct (starts the timer), record errors / details /
    warnings, call ``mark_success()`` if it passed, ``finalize()`` to freeze
    the elapsed time, then ``print_report()``.
    """

    # The class name starts with "Test"; stop pytest from trying to collect
    # this helper as a test class (it would warn because of __init__).
    __test__ = False

    def __init__(self, test_name: str):
        self.test_name = test_name
        self.success = False
        self.errors = []
        self.details = {}
        # The timer starts at construction time.
        self.start_time = time.time()
        # Filled in by finalize(); None until then.
        self.elapsed: Optional[float] = None

    def add_error(self, error: str):
        """Record an error message (printed in the report when the test failed)."""
        self.errors.append(error)

    def add_detail(self, key: str, value):
        """Store (or overwrite) an arbitrary detail shown in the report."""
        self.details[key] = value

    def add_warning(self, warning: str):
        """Append a warning message to the ``"warnings"`` list inside ``details``."""
        self.details.setdefault("warnings", []).append(warning)

    def mark_success(self):
        """Flag the test as passed."""
        self.success = True

    def finalize(self):
        """Freeze the elapsed time since construction and return ``self`` for chaining."""
        self.elapsed = time.time() - self.start_time
        return self

    def print_report(self):
        """Print a human-readable summary and return ``self.success``.

        If ``finalize()`` has not been called yet, the time elapsed so far is
        reported instead of raising AttributeError.
        """
        elapsed = self.elapsed if self.elapsed is not None else time.time() - self.start_time
        bar = "=" * 60

        print(f"\n{bar}")
        print(f" 测试: {self.test_name}")
        print(f" 耗时: {elapsed:.2f}s")
        print(bar)

        if self.success:
            print(" ✅ 通过")
        else:
            print(" ❌ 失败")
            for error in self.errors:
                print(f" - {error}")

        if self.details:
            print("\n 详细信息:")
            for key, value in self.details.items():
                print(f" {key}: {value}")

        print()
        return self.success
|