DroneMind/voicellmcloud/test/test_08_tts_stream.py
2026-04-14 10:08:41 +08:00

448 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 TTS 音频流完整性和播放
覆盖所有通信模块session管理、dialog_result、tts_audio_chunk、turn.complete
"""
import asyncio
import json
import sys
import time
from pathlib import Path
import websockets
import numpy as np
import sounddevice as sd
sys.path.insert(0, str(Path(__file__).parent))
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID
from test_utils import TestResult
async def test_tts_audio_stream():
    """Run every TTS audio-stream scenario, then the playback check, and print a summary.

    Each scenario is executed sequentially through ``test_scenario``; the
    dedicated playback check goes through ``test_audio_playback``. Every
    result prints its own report as soon as it is available.

    Returns:
        True when every collected result reports success, otherwise False.
    """
    banner = "=" * 60
    print(banner)
    print(" 测试 TTS 音频流完整性")
    print(banner)
    print()

    # (display name, utterance, expected routing, minimum audio length in seconds)
    scenarios = (
        ("闲聊场景 TTS", "你好,今天天气怎么样?", "chitchat", 1.0),
        ("飞控指令场景 TTS", "起飞然后在前方十米悬停", "flight_intent", 1.0),
        ("返航指令场景 TTS", "返航", "flight_intent", 0.5),
        ("降落指令场景 TTS", "降落", "flight_intent", 0.5),
        ("短文本 TTS", "你好", "chitchat", 0.3),
    )

    results = []
    for label, utterance, routing, min_len_s in scenarios:
        outcome = await test_scenario(
            label,
            utterance,
            routing,
            check_audio_length=True,
            min_audio_length_s=min_len_s,
        )
        results.append(outcome)
        outcome.print_report()

    # Final check: end-to-end playback of a longer sentence.
    playback_outcome = await test_audio_playback(
        "音频流播放验证",
        "这是一段测试语音,用于验证音频流播放功能。"
    )
    results.append(playback_outcome)
    playback_outcome.print_report()

    # Summary table: one line per result, followed by its errors when it failed.
    print("\n" + banner)
    print(" TTS 音频流测试汇总")
    print(banner)
    passed_count = 0
    for r in results:
        if r.success:
            passed_count += 1
            status = "✅ 通过"
        else:
            status = "❌ 失败"
        print(f" {r.test_name:<30} {status}")
        if not r.success and r.errors:
            for err in r.errors:
                print(f" - {err}")
    print(f"\n 总计: {passed_count}/{len(results)} 通过")
    print(banner)
    return all(r.success for r in results)
async def test_scenario(
    name: str,
    text: str,
    expected_routing: str,
    check_audio_length: bool = True,
    min_audio_length_s: float = 0.5,
) -> TestResult:
    """Drive one text turn through the server and validate the TTS audio stream.

    Opens a WebSocket session, sends ``text`` as a final ``turn.text`` message,
    collects the ``dialog_result``, the ``tts_audio_chunk`` headers, the binary
    audio frames and the ``turn.complete`` metrics, then checks routing, chunk
    metadata, audio length, non-silence, local playback and latency metrics.

    Args:
        name: Label appended to the result's test name.
        text: Utterance sent to the server.
        expected_routing: Value that ``dialog_result["routing"]`` must equal.
        check_audio_length: Whether to enforce ``min_audio_length_s``.
        min_audio_length_s: Minimum acceptable audio duration in seconds.

    Returns:
        The finalized result. It is marked successful only when no check
        recorded an error.
    """
    # Audio format requested in session.start and expected back from the server.
    sample_rate_hz = 24000
    codec = "pcm_s16le"

    result = TestResult(f"TTS 音频流 - {name}")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # 1. Open the session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": f"test-tts-{int(time.time())}",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": sample_rate_hz,
                        "prefer_tts_codec": codec,
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            # The first reply must be session.ready.
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") != "session.ready":
                result.add_error("未收到 session.ready")
                return result.finalize()

            # 2. Send the text turn.
            turn_id = f"turn-{int(time.time())}"
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": turn_id,
                "text": text,
                "is_final": True,
                "source": "device_stt",
            }
            await ws.send(json.dumps(turn_text, ensure_ascii=False))

            # 3. Collect responses until turn.complete arrives.
            dialog_result = None
            audio_chunks = []
            audio_metadata = []
            metrics = {}
            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=60)
                if isinstance(msg, bytes):
                    # Binary frame: raw audio payload.
                    audio_chunks.append(msg)
                    continue
                data = json.loads(msg)
                msg_type = data.get("type")
                if msg_type == "dialog_result":
                    dialog_result = data
                elif msg_type == "tts_audio_chunk":
                    # Text frame: metadata header of an audio chunk.
                    audio_metadata.append({
                        "seq": data.get("seq"),
                        "is_final": data.get("is_final"),
                        "codec": data.get("codec"),
                        "sample_rate_hz": data.get("sample_rate_hz"),
                    })
                elif msg_type == "turn.complete":
                    metrics = data.get("metrics", {})
                    break

            # 4. Validate dialog_result.
            if dialog_result is None:
                result.add_error("未收到 dialog_result")
                return result.finalize()
            actual_routing = dialog_result.get("routing")
            if actual_routing != expected_routing:
                result.add_error(
                    f"路由不匹配: 期望 {expected_routing}, 实际 {actual_routing}"
                )
                return result.finalize()
            result.add_detail("routing", actual_routing)

            # 5. Validate the chunk sequence.
            if not audio_chunks:
                result.add_error("未收到任何音频数据")
                return result.finalize()
            result.add_detail("audio_chunks_count", len(audio_chunks))
            result.add_detail("audio_metadata_count", len(audio_metadata))

            # Headers must be numbered 0..n-1 and carry the requested format;
            # only the first violation is reported.
            for i, meta in enumerate(audio_metadata):
                if meta["seq"] != i:
                    result.add_error(f"音频块序号不连续: 期望 {i}, 实际 {meta['seq']}")
                    break
                if meta["codec"] != codec:
                    result.add_error(f"音频编码格式错误: 期望 {codec}, 实际 {meta['codec']}")
                    break
                if meta["sample_rate_hz"] != sample_rate_hz:
                    result.add_error(
                        f"音频采样率错误: 期望 {sample_rate_hz}, 实际 {meta['sample_rate_hz']}"
                    )
                    break

            # The last header must be flagged as final.
            if audio_metadata and not audio_metadata[-1]["is_final"]:
                result.add_error("最后一个音频块未标记为 is_final=True")

            # 6. Concatenate the PCM payload.
            full_pcm = b"".join(audio_chunks)
            audio_data = np.frombuffer(full_pcm, dtype=np.int16)
            audio_length_s = len(audio_data) / float(sample_rate_hz)
            result.add_detail("audio_length_s", f"{audio_length_s:.2f}")
            result.add_detail("audio_samples", len(audio_data))

            # 7. Validate the audio duration.
            if check_audio_length and audio_length_s < min_audio_length_s:
                result.add_error(
                    f"音频长度过短: {audio_length_s:.2f}s < {min_audio_length_s:.2f}s"
                )

            # 8. Validate that the audio is not pure silence. `any()` is safe
            # on an empty array and avoids the int16 overflow of abs(-32768).
            if not audio_data.any():
                result.add_error("音频数据全为零(静音)")

            # 9. Play the audio locally.
            # NOTE(review): the blocking call runs inside the coroutine, so the
            # event loop makes no progress on this connection until playback ends.
            print(f"\n 🔊 播放测试音频 ({audio_length_s:.2f}s)...")
            try:
                audio_float = audio_data.astype(np.float32) / 32768.0
                sd.play(audio_float, samplerate=sample_rate_hz, blocking=True)
                print(" ✅ 音频播放成功")
            except Exception as e:
                result.add_error(f"音频播放失败: {e}")

            # 10. Validate latency metrics (missing or None values count as 0).
            llm_ms = metrics.get("llm_ms", 0) or 0
            tts_ms = metrics.get("tts_first_byte_ms", 0) or 0
            result.add_detail("llm_latency_ms", llm_ms)
            result.add_detail("tts_latency_ms", tts_ms)
            if llm_ms == 0 and tts_ms == 0:
                result.add_error("性能指标为空")

            # 11. Close the session.
            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_start["session_id"],
            }
            await ws.send(json.dumps(session_end))

            # Count the run as a pass only when no check above recorded an
            # error; previously success was set unconditionally.
            if not result.errors:
                result.mark_success()
            result.add_detail("test_text", text)
            result.add_detail("protocol_modules", [
                "session.start/ready",
                "turn.text",
                "dialog_result",
                "tts_audio_chunk (text+binary)",
                "turn.complete",
                "session.end",
            ])
    except Exception as e:
        # Top-level boundary of the test: record the failure and dump the trace.
        result.add_error(f"测试异常: {e}")
        import traceback
        traceback.print_exc()
    return result.finalize()
async def test_audio_playback(name: str, text: str) -> TestResult:
    """Fetch the TTS audio for ``text`` and play it back locally, timing the playback.

    Opens a WebSocket session, sends ``text`` as a final ``turn.text`` message,
    gathers every binary frame until ``turn.complete``, prints basic statistics
    of the audio, plays it with a blocking call and records the ratio of audio
    duration to wall-clock playback time.

    Args:
        name: Label appended to the result's test name.
        text: Utterance sent to the server.

    Returns:
        The finalized result. It is marked successful when audio was received
        and played without an exception.
    """
    sample_rate_hz = 24000

    result = TestResult(f"TTS 音频播放 - {name}")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Open the session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": f"test-playback-{int(time.time())}",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": sample_rate_hz,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            # The first reply must be session.ready, as in test_scenario;
            # previously the reply was read but never checked.
            ready_msg = await asyncio.wait_for(ws.recv(), timeout=10)
            if json.loads(ready_msg).get("type") != "session.ready":
                result.add_error("未收到 session.ready")
                return result.finalize()

            # Send the text turn.
            turn_id = f"turn-{int(time.time())}"
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": turn_id,
                "text": text,
                "is_final": True,
                "source": "device_stt",
            }
            await ws.send(json.dumps(turn_text, ensure_ascii=False))

            # Collect binary audio frames until turn.complete arrives.
            audio_chunks = []
            metrics = {}
            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=60)
                if isinstance(msg, bytes):
                    audio_chunks.append(msg)
                    continue
                data = json.loads(msg)
                if data.get("type") == "turn.complete":
                    metrics = data.get("metrics", {})
                    break

            # Play the audio. Checking the joined payload (not the chunk list)
            # also treats a stream of empty frames as "no audio".
            full_pcm = b"".join(audio_chunks)
            if full_pcm:
                audio_data = np.frombuffer(full_pcm, dtype=np.int16)
                audio_length_s = len(audio_data) / float(sample_rate_hz)
                # Widen before abs(): abs(-32768) overflows in int16.
                magnitudes = np.abs(audio_data.astype(np.int32))
                print(f"\n 🎵 完整音频信息:")
                print(f" 长度: {audio_length_s:.2f}s")
                print(f" 采样数: {len(audio_data)}")
                print(f" 峰值: {np.max(magnitudes)}")
                print(f" 均值: {np.mean(magnitudes):.1f}")

                # Blocking playback, timed with the wall clock.
                print(f"\n 🔊 播放音频...")
                t_start = time.time()
                audio_float = audio_data.astype(np.float32) / 32768.0
                sd.play(audio_float, samplerate=sample_rate_hz, blocking=True)
                t_play = time.time() - t_start
                print(f" ✅ 播放完成 (耗时 {t_play:.2f}s)")

                # Ratio of audio duration to playback time.
                # NOTE(review): blocking playback presumably never finishes
                # faster than the audio itself, so this warning may fire on
                # every run — confirm the intended threshold.
                realtime_factor = audio_length_s / t_play if t_play > 0 else 0
                print(f" 📊 实时率: {realtime_factor:.2f}x")
                if realtime_factor < 1.0:
                    result.add_warning(
                        f"实时率低于 1.0: {realtime_factor:.2f}x "
                        f"(可能导致播放卡顿)"
                    )
                result.mark_success()
                result.add_detail("audio_length_s", f"{audio_length_s:.2f}")
                result.add_detail("playback_time_s", f"{t_play:.2f}")
                result.add_detail("realtime_factor", f"{realtime_factor:.2f}")
                result.add_detail("metrics", metrics)
            else:
                result.add_error("未收到音频数据")

            # Close the session.
            await ws.send(json.dumps({
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_start["session_id"],
            }))
    except Exception as e:
        # Top-level boundary of the test: record the failure and dump the trace.
        result.add_error(f"测试异常: {e}")
        import traceback
        traceback.print_exc()
    return result.finalize()
async def main():
    """Print the banner, run the full TTS stream test suite and exit with its status.

    Exits the process with code 0 when ``test_tts_audio_stream`` returns a
    truthy value and with code 1 otherwise.
    """
    rule = "=" * 60

    print("\n" + rule)
    for header_line in (
        " TTS 音频流完整性测试",
        " 覆盖: session管理 + dialog_result + tts_audio_chunk + turn.complete",
    ):
        print(header_line)
    print(rule)
    print()

    success = await test_tts_audio_stream()

    print("\n" + rule)
    if not success:
        print(" ❌ 部分 TTS 音频流测试失败!")
    else:
        # Checklist of everything the suite exercised.
        for verified_line in (
            " ✅ 所有 TTS 音频流测试通过!",
            " 已验证:",
            " - session.start/ready 会话管理",
            " - turn.text 文本发送",
            " - dialog_result 结构化结果",
            " - tts_audio_chunk 流式音频text 头 + binary 体)",
            " - turn.complete 完成通知",
            " - 音频数据完整性和长度",
            " - 音频流本地播放",
        ):
            print(verified_line)
    print(rule)

    exit_code = 0 if success else 1
    sys.exit(exit_code)
# Script entry point: run the async driver only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())