"""
测试 04: 多会话并发测试

模拟多架无人机同时连接

Test 04: multi-session concurrency test.
Simulates several drones connecting to the server at the same time.
"""

import asyncio
import json
import sys
import time
from pathlib import Path

import websockets

# Put this file's directory first on sys.path so the local test_config module
# can be imported below.
sys.path.insert(0, str(Path(__file__).parent))

from test_config import SERVER_URL, AUTH_TOKEN  # noqa: E402  (needs sys.path tweak above)


async def test_concurrent(num_sessions: int = 4) -> bool:
    """Run several simulated drone sessions concurrently and print a summary.

    并发会话测试 (concurrent session test).

    One ``run_session`` task is started per drone, with device ids
    ``drone-001``, ``drone-002``, ... All tasks run concurrently and their
    outcomes are reported once every task has finished.

    Args:
        num_sessions: number of simulated drones to connect at once.
            Defaults to 4, the value that used to be hard-coded.

    Returns:
        True if every session reported success, False otherwise.
    """
    print("=" * 60)
    print(" 测试 04: 多会话并发测试")
    print("=" * 60)
    print()

    # Build the ids once; they are reused for task creation and the summary.
    device_ids = [f"drone-{num:03d}" for num in range(1, num_sessions + 1)]

    print(f"🚀 启动 {num_sessions} 个并发会话...")
    print()

    # Start every session up front so they really overlap.
    tasks = [
        asyncio.create_task(run_session(num, device_id))
        for num, device_id in enumerate(device_ids, start=1)
    ]

    # return_exceptions=True keeps one crashing session from aborting the rest;
    # the failure shows up as an exception object in `results` instead.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    print("\n" + "=" * 60)
    print(" 并发测试汇总")
    print("=" * 60)

    success_count = 0
    fail_count = 0

    for device_id, result in zip(device_ids, results):
        # BaseException, not Exception: a cancelled task yields
        # asyncio.CancelledError, which is not an Exception subclass on 3.8+
        # and would otherwise crash on `result.get` below. repr() keeps the
        # line informative even when str(exc) is empty.
        if isinstance(result, BaseException):
            print(f" ❌ {device_id}: 异常 - {result!r}")
            fail_count += 1
        elif result.get("success"):
            print(f" ✅ {device_id}: 通过 (LLM={result.get('llm_ms')}ms, TTS={result.get('tts_ms')}ms)")
            success_count += 1
        else:
            print(f" ❌ {device_id}: 失败 - {result.get('error')}")
            fail_count += 1

    print()
    print(f" 总计: {success_count} 通过, {fail_count} 失败")
    print("=" * 60)

    return fail_count == 0


async def run_session(
    session_num: int,
    device_id: str,
    *,
    ready_timeout: float = 10,
    turn_timeout: float = 60,
) -> dict:
    """Drive one simulated drone through a single text turn.

    运行单个会话 (run a single session).

    Steps, as implemented below:
      1. send ``session.start`` and wait for a ``session.ready`` reply;
      2. send one ``turn.text`` whose text is picked from a fixed list
         using ``session_num``;
      3. read frames until ``turn.complete``. Binary frames are counted as
         audio chunks, a ``dialog_result`` message is remembered, and an
         ``error`` message aborts the session;
      4. send ``session.end``.

    Args:
        session_num: session index, expected to be 1-based. It is used to build
            the session/turn ids and to choose the test utterance.
        device_id: value sent as ``client.device_id``.
        ready_timeout: seconds to wait for the ``session.ready`` reply.
        turn_timeout: seconds to wait for each frame of the turn response.

    Returns:
        A dict that always contains ``success`` (bool) and ``device_id``.
        On success it also contains ``routing``, ``llm_ms``, ``tts_ms``,
        ``total_time`` (seconds) and ``audio_chunks`` (a count). On failure it
        contains ``error`` (str). ``Exception`` subclasses are folded into
        ``error``; ``BaseException``-only errors such as
        ``asyncio.CancelledError`` still propagate.
    """
    session_id = f"test-session-{session_num:03d}"
    turn_id = f"test-turn-{session_num:03d}"

    def _failure(reason: str) -> dict:
        # Same shape for every failure path, including device_id like the
        # success result.
        return {"success": False, "device_id": device_id, "error": reason}

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # 1. Open the session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": session_id,
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": device_id,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            # The first reply must be session.ready.
            msg = await asyncio.wait_for(ws.recv(), timeout=ready_timeout)
            data = json.loads(msg)
            if data.get("type") != "session.ready":
                return _failure(f"期望 session.ready,收到 {data.get('type')}")

            print(f" [{device_id}] ✅ 会话建立")

            # 2. Send one utterance; sessions rotate through these texts.
            test_texts = ("你好", "起飞", "返航", "降落")
            test_text = test_texts[(session_num - 1) % len(test_texts)]

            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": turn_id,
                "text": test_text,
                "is_final": True,
                "source": "device_stt",
            }

            # perf_counter is monotonic, so a wall-clock adjustment cannot
            # distort the measured latency (time.time() could).
            t_start = time.perf_counter()
            await ws.send(json.dumps(turn_text, ensure_ascii=False))
            print(f" [{device_id}] → {test_text}")

            # 3. Collect the response until turn.complete.
            dialog_result = None
            audio_chunk_count = 0
            metrics = {}

            while True:
                msg = await asyncio.wait_for(ws.recv(), timeout=turn_timeout)

                if isinstance(msg, bytes):
                    # Binary frames are audio; only their count is reported.
                    audio_chunk_count += 1
                    continue

                data = json.loads(msg)
                msg_type = data.get("type")

                if msg_type == "dialog_result":
                    dialog_result = data
                    routing = data.get("routing")
                    print(f" [{device_id}] ← routing={routing}")
                elif msg_type == "turn.complete":
                    # `or {}` also covers an explicit "metrics": null.
                    metrics = data.get("metrics") or {}
                    break
                elif msg_type == "error":
                    return _failure(f"{data.get('code')}: {data.get('message')}")
                # Any other message type is ignored.

            total_time = time.perf_counter() - t_start
            print(f" [{device_id}] ✅ 完成 (耗时={total_time:.2f}s)")

            # 4. End the session before the connection is closed.
            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_id,
            }
            await ws.send(json.dumps(session_end))

            return {
                "success": True,
                "device_id": device_id,
                "routing": dialog_result.get("routing") if dialog_result else None,
                "llm_ms": metrics.get("llm_ms"),
                "tts_ms": metrics.get("tts_first_byte_ms"),
                "total_time": total_time,
                "audio_chunks": audio_chunk_count,
            }

    except Exception as exc:
        # Some exceptions (notably asyncio.TimeoutError) have an empty str().
        # Always keep the type name so the summary line is never blank.
        detail = str(exc)
        reason = f"{type(exc).__name__}: {detail}" if detail else type(exc).__name__
        return _failure(reason)


async def main():
    """Script entry point: run the concurrency test and exit with its status.

    The process exits with status 0 when every session passed and 1 otherwise.
    """
    passed = await test_concurrent()
    exit_status = 1 if not passed else 0
    sys.exit(exit_status)


if __name__ == "__main__":
    # asyncio.run creates an event loop, runs main() to completion, then closes it.
    asyncio.run(main())