448 lines
16 KiB
Python
448 lines
16 KiB
Python
"""
测试 TTS 音频流完整性和播放

覆盖所有通信模块:session管理、dialog_result、tts_audio_chunk、turn.complete
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import websockets
|
||
import numpy as np
|
||
import sounddevice as sd
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID
|
||
from test_utils import TestResult
|
||
|
||
|
||
async def test_tts_audio_stream():
    """Run every TTS audio-stream scenario plus the playback check, then print a summary.

    Each scenario's report is printed as soon as it finishes; a combined
    pass/fail table follows at the end.

    Returns:
        True when every collected TestResult reports success, otherwise False.
    """
    print("=" * 60)
    print(" 测试 TTS 音频流完整性")
    print("=" * 60)
    print()

    # (display name, utterance sent as turn.text, expected routing, min audio seconds)
    scenarios = [
        ("闲聊场景 TTS", "你好,今天天气怎么样?", "chitchat", 1.0),
        ("飞控指令场景 TTS", "起飞然后在前方十米悬停", "flight_intent", 1.0),
        ("返航指令场景 TTS", "返航", "flight_intent", 0.5),
        ("降落指令场景 TTS", "降落", "flight_intent", 0.5),
        ("短文本 TTS", "你好", "chitchat", 0.3),
    ]

    results = []

    # Tests 1-5: one websocket session per scenario, run sequentially.
    for scenario_name, utterance, routing, min_len_s in scenarios:
        result = await test_scenario(
            scenario_name,
            utterance,
            routing,
            check_audio_length=True,
            min_audio_length_s=min_len_s,
        )
        results.append(result)
        result.print_report()

    # Test 6: dedicated playback verification.
    result = await test_audio_playback(
        "音频流播放验证",
        "这是一段测试语音,用于验证音频流播放功能。",
    )
    results.append(result)
    result.print_report()

    # Summary table.
    print("\n" + "=" * 60)
    print(" TTS 音频流测试汇总")
    print("=" * 60)

    for r in results:
        status = "✅ 通过" if r.success else "❌ 失败"
        print(f" {r.test_name:<30} {status}")
        if not r.success and r.errors:
            for err in r.errors:
                print(f" - {err}")

    passed = sum(1 for r in results if r.success)
    print(f"\n 总计: {passed}/{len(results)} 通过")
    print("=" * 60)

    return all(r.success for r in results)
|
||
|
||
|
||
async def test_scenario(
    name: str,
    text: str,
    expected_routing: str,
    check_audio_length: bool = True,
    min_audio_length_s: float = 0.5,
) -> TestResult:
    """Run one text turn end-to-end and validate the streamed TTS audio.

    Opens a websocket session, sends ``text`` as a final ``turn.text``,
    collects every frame until ``turn.complete`` and then checks:
    routing, tts_audio_chunk metadata (seq / codec / sample rate / is_final),
    decoded PCM length and content, local playback and latency metrics.

    Args:
        name: Human-readable scenario label used in the result name.
        text: Utterance sent to the server.
        expected_routing: Value ``dialog_result["routing"]`` must equal.
        check_audio_length: Whether to enforce ``min_audio_length_s``.
        min_audio_length_s: Minimum acceptable audio duration in seconds.

    Returns:
        The finalized TestResult. It is marked successful only when none of
        the checks recorded an error.
    """
    result = TestResult(f"TTS 音频流 - {name}")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # 1. Open the session. time_ns() keeps ids distinct even when two
            #    scenarios start within the same wall-clock second.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": f"test-tts-{time.time_ns()}",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            # The first reply must be session.ready.
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") != "session.ready":
                result.add_error("未收到 session.ready")
                return result.finalize()

            # 2. Send the utterance as a final text turn.
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": f"turn-{time.time_ns()}",
                "text": text,
                "is_final": True,
                "source": "device_stt",
            }
            await ws.send(json.dumps(turn_text, ensure_ascii=False))

            # 3. Collect everything up to turn.complete: binary frames are raw
            #    audio, text frames are JSON control messages.
            dialog_result = None
            audio_chunks = []
            audio_metadata = []
            metrics = {}

            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=60)
                if isinstance(msg, bytes):
                    audio_chunks.append(msg)
                    continue

                data = json.loads(msg)
                msg_type = data.get("type")
                if msg_type == "dialog_result":
                    dialog_result = data
                elif msg_type == "tts_audio_chunk":
                    audio_metadata.append({
                        "seq": data.get("seq"),
                        "is_final": data.get("is_final"),
                        "codec": data.get("codec"),
                        "sample_rate_hz": data.get("sample_rate_hz"),
                    })
                elif msg_type == "turn.complete":
                    # `or {}` also covers an explicit "metrics": null.
                    metrics = data.get("metrics") or {}
                    break

            # 4. dialog_result must exist and carry the expected routing.
            if dialog_result is None:
                result.add_error("未收到 dialog_result")
                return result.finalize()

            actual_routing = dialog_result.get("routing")
            if actual_routing != expected_routing:
                result.add_error(
                    f"路由不匹配: 期望 {expected_routing}, 实际 {actual_routing}"
                )
                return result.finalize()
            result.add_detail("routing", actual_routing)

            # 5. Audio chunk bookkeeping.
            if not audio_chunks:
                result.add_error("未收到任何音频数据")
                return result.finalize()

            result.add_detail("audio_chunks_count", len(audio_chunks))
            result.add_detail("audio_metadata_count", len(audio_metadata))

            if len(audio_metadata) != len(audio_chunks):
                # Text headers and binary bodies are sent as pairs. A mismatch
                # (including no headers at all, which skips the format checks
                # below) is surfaced as a warning rather than a failure.
                result.add_warning(
                    f"tts_audio_chunk 元数据数量 ({len(audio_metadata)}) "
                    f"与二进制音频块数量 ({len(audio_chunks)}) 不一致"
                )

            # Assumes seq numbering starts at 0 and increases by 1 per chunk.
            for expected_seq, meta in enumerate(audio_metadata):
                if meta["seq"] != expected_seq:
                    result.add_error(
                        f"音频块序号不连续: 期望 {expected_seq}, 实际 {meta['seq']}"
                    )
                    break
                if meta["codec"] != "pcm_s16le":
                    result.add_error(
                        f"音频编码格式错误: 期望 pcm_s16le, 实际 {meta['codec']}"
                    )
                    break
                if meta["sample_rate_hz"] != 24000:
                    result.add_error(
                        f"音频采样率错误: 期望 24000, 实际 {meta['sample_rate_hz']}"
                    )
                    break

            # The last header must close the stream.
            if audio_metadata and not audio_metadata[-1]["is_final"]:
                result.add_error("最后一个音频块未标记为 is_final=True")

            # 6. Decode the concatenated stream as 16-bit little-endian PCM.
            full_pcm = b"".join(audio_chunks)
            if not full_pcm:
                result.add_error("未收到任何音频数据")
                return result.finalize()
            if len(full_pcm) % 2:
                # np.frombuffer would raise an opaque ValueError here.
                result.add_error(
                    f"音频数据长度不是 16-bit 采样的整数倍: {len(full_pcm)} 字节"
                )
                return result.finalize()

            # Explicit '<i2' matches pcm_s16le regardless of host byte order.
            audio_data = np.frombuffer(full_pcm, dtype="<i2")
            audio_length_s = len(audio_data) / 24000.0

            result.add_detail("audio_length_s", f"{audio_length_s:.2f}")
            result.add_detail("audio_samples", len(audio_data))

            # 7. Duration check.
            if check_audio_length and audio_length_s < min_audio_length_s:
                result.add_error(
                    f"音频长度过短: {audio_length_s:.2f}s < {min_audio_length_s:.2f}s"
                )

            # 8. Reject pure silence.
            if not np.any(audio_data):
                result.add_error("音频数据全为零(静音)")

            # 9. Local playback (sd.play with blocking=True waits until done).
            print(f"\n 🔊 播放测试音频 ({audio_length_s:.2f}s)...")
            try:
                audio_float = audio_data.astype(np.float32) / 32768.0
                sd.play(audio_float, samplerate=24000, blocking=True)
                print(" ✅ 音频播放成功")
            except Exception as e:
                result.add_error(f"音频播放失败: {e}")

            # 10. Latency metrics reported by turn.complete.
            llm_ms = metrics.get("llm_ms", 0) or 0
            tts_ms = metrics.get("tts_first_byte_ms", 0) or 0
            result.add_detail("llm_latency_ms", llm_ms)
            result.add_detail("tts_latency_ms", tts_ms)
            if llm_ms == 0 and tts_ms == 0:
                result.add_error("性能指标为空")

            # 11. Close the session.
            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_start["session_id"],
            }
            await ws.send(json.dumps(session_end))

            # Several checks above record errors without returning early, so
            # success is granted only if none of them fired.
            if not result.errors:
                result.mark_success()
            result.add_detail("test_text", text)
            result.add_detail("protocol_modules", [
                "session.start/ready",
                "turn.text",
                "dialog_result",
                "tts_audio_chunk (text+binary)",
                "turn.complete",
                "session.end",
            ])

    except Exception as e:
        result.add_error(f"测试异常: {e}")
        import traceback
        traceback.print_exc()

    return result.finalize()
|
||
|
||
|
||
async def test_audio_playback(name: str, text: str) -> TestResult:
    """Stream one TTS turn, play it locally and check delivery keeps pace.

    The "real-time factor" is measured on the *stream*. It is seconds of audio
    received divided by the wall-clock time between the first and last binary
    frame. Below 1.0 the server produces audio slower than it plays, so a
    streaming player would underrun. Timing the blocking local playback
    instead always yields about 1.0 and says nothing about the stream.

    Args:
        name: Human-readable label used in the result name.
        text: Utterance sent as the final turn.text.

    Returns:
        The finalized TestResult.
    """
    result = TestResult(f"TTS 音频播放 - {name}")

    try:
        audio_chunks = []
        metrics = {}
        t_first_chunk = None
        t_last_chunk = None

        async with websockets.connect(SERVER_URL) as ws:
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": f"test-playback-{time.time_ns()}",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            ready = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
            if ready.get("type") != "session.ready":
                result.add_error("未收到 session.ready")
                return result.finalize()

            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": f"turn-{time.time_ns()}",
                "text": text,
                "is_final": True,
                "source": "device_stt",
            }
            await ws.send(json.dumps(turn_text, ensure_ascii=False))

            # Gather binary audio and timestamp its arrival until turn.complete.
            # Timestamps are taken when recv() returns; frames already buffered
            # by the client library may make the window look slightly shorter.
            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=60)
                if isinstance(msg, bytes):
                    now = time.perf_counter()
                    if t_first_chunk is None:
                        t_first_chunk = now
                    t_last_chunk = now
                    audio_chunks.append(msg)
                    continue
                data = json.loads(msg)
                if data.get("type") == "turn.complete":
                    # `or {}` also covers an explicit "metrics": null.
                    metrics = data.get("metrics") or {}
                    break

            await ws.send(json.dumps({
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_start["session_id"],
            }))

        # Everything below runs with the socket already closed, so the blocking
        # playback does not stall an open connection.
        full_pcm = b"".join(audio_chunks)
        if len(full_pcm) % 2:
            result.add_error(
                f"音频数据长度不是 16-bit 采样的整数倍: {len(full_pcm)} 字节"
            )
            return result.finalize()

        # Explicit '<i2' matches pcm_s16le regardless of host byte order.
        audio_data = np.frombuffer(full_pcm, dtype="<i2")
        if audio_data.size == 0:
            result.add_error("未收到音频数据")
            return result.finalize()

        audio_length_s = len(audio_data) / 24000.0
        # Widen before abs(): abs(int16(-32768)) overflows back to -32768.
        magnitude = np.abs(audio_data.astype(np.int32))

        print("\n 🎵 完整音频信息:")
        print(f" 长度: {audio_length_s:.2f}s")
        print(f" 采样数: {len(audio_data)}")
        print(f" 峰值: {int(magnitude.max())}")
        print(f" 均值: {magnitude.mean():.1f}")

        print("\n 🔊 播放音频...")
        t_start = time.perf_counter()
        audio_float = audio_data.astype(np.float32) / 32768.0
        sd.play(audio_float, samplerate=24000, blocking=True)
        t_play = time.perf_counter() - t_start
        print(f" ✅ 播放完成 (耗时 {t_play:.2f}s)")

        # Stream delivery rate. A single frame (zero-width window) means all
        # audio arrived at once, which cannot underrun.
        stream_s = t_last_chunk - t_first_chunk
        realtime_factor = (
            audio_length_s / stream_s if stream_s > 0 else float("inf")
        )
        print(f" 📊 实时率: {realtime_factor:.2f}x")

        if realtime_factor < 1.0:
            result.add_warning(
                f"实时率低于 1.0: {realtime_factor:.2f}x "
                f"(可能导致播放卡顿)"
            )

        result.mark_success()
        result.add_detail("audio_length_s", f"{audio_length_s:.2f}")
        result.add_detail("stream_time_s", f"{stream_s:.2f}")
        result.add_detail("playback_time_s", f"{t_play:.2f}")
        result.add_detail("realtime_factor", f"{realtime_factor:.2f}")
        result.add_detail("metrics", metrics)

    except Exception as e:
        result.add_error(f"测试异常: {e}")
        import traceback
        traceback.print_exc()

    return result.finalize()
|
||
|
||
|
||
async def main():
    """Script entry point: run the TTS audio-stream suite and exit with its status.

    Prints a banner, awaits the full suite, prints either the list of verified
    protocol features or a failure line, and then terminates the process via
    sys.exit with code 0 on full success and 1 otherwise.
    """
    rule = "=" * 60

    header_lines = (
        "\n" + rule,
        " TTS 音频流完整性测试",
        " 覆盖: session管理 + dialog_result + tts_audio_chunk + turn.complete",
        rule,
        "",
    )
    for line in header_lines:
        print(line)

    all_ok = await test_tts_audio_stream()

    print("\n" + rule)
    if not all_ok:
        print(" ❌ 部分 TTS 音频流测试失败!")
    else:
        print(" ✅ 所有 TTS 音频流测试通过!")
        print(" 已验证:")
        verified_items = (
            " - session.start/ready 会话管理",
            " - turn.text 文本发送",
            " - dialog_result 结构化结果",
            " - tts_audio_chunk 流式音频(text 头 + binary 体)",
            " - turn.complete 完成通知",
            " - 音频数据完整性和长度",
            " - 音频流本地播放",
        )
        for item in verified_items:
            print(item)
    print(rule)

    sys.exit(int(not all_ok))
|
||
|
||
|
||
if __name__ == "__main__":
    # Drive the async suite on a fresh event loop; main() ends with sys.exit.
    asyncio.run(main())
|