DroneMind/voicellmcloud/test/test_04_concurrent.py
2026-04-14 10:08:41 +08:00

195 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 04: 多会话并发测试
模拟多架无人机同时连接
"""
import asyncio
import json
import sys
import time
from pathlib import Path
import websockets
sys.path.insert(0, str(Path(__file__).parent))
from test_config import SERVER_URL, AUTH_TOKEN
async def test_concurrent(num_sessions: int = 4):
    """Run several simulated drone sessions concurrently and print a summary.

    Session ``k`` (1-based) uses device id ``drone-kkk`` and is driven by
    ``run_session(k, device_id)``. All sessions are started before any is
    awaited, so they overlap on the event loop.

    Args:
        num_sessions: Number of simulated drones to connect in parallel.
            Defaults to 4, the value this test originally hard-coded.

    Returns:
        True when every session reported success, False otherwise.
    """
    print("=" * 60)
    print(" 测试 04: 多会话并发测试")
    print("=" * 60)
    print()

    # Build the id list once so the launch loop and the summary loop are
    # guaranteed to use the same labels.
    device_ids = [f"drone-{i:03d}" for i in range(1, num_sessions + 1)]

    print(f"🚀 启动 {num_sessions} 个并发会话...")
    print()

    # Create every task up front so the sessions actually run concurrently.
    tasks = [
        asyncio.create_task(run_session(session_num, device_id))
        for session_num, device_id in enumerate(device_ids, start=1)
    ]
    # return_exceptions=True: one failing session does not abort the others;
    # its exception object is placed in `results` instead.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    print("\n" + "=" * 60)
    print(" 并发测试汇总")
    print("=" * 60)
    success_count = 0
    fail_count = 0
    for device_id, result in zip(device_ids, results):
        # gather() can return BaseException subclasses such as
        # asyncio.CancelledError, which are not ``Exception`` instances and
        # have no ``.get`` method.
        if isinstance(result, BaseException):
            print(f"{device_id}: 异常 - {result}")
            fail_count += 1
        elif result.get("success"):
            print(f"{device_id}: 通过 (LLM={result.get('llm_ms')}ms, TTS={result.get('tts_ms')}ms)")
            success_count += 1
        else:
            print(f"{device_id}: 失败 - {result.get('error')}")
            fail_count += 1
    print()
    print(f" 总计: {success_count} 通过, {fail_count} 失败")
    print("=" * 60)
    return fail_count == 0
def _describe_error(exc: BaseException) -> str:
    """Return ``"<TypeName>: <message>"`` for *exc*, or just the type name.

    ``asyncio.wait_for`` raises a ``TimeoutError`` whose ``str()`` is empty,
    so reporting ``str(exc)`` alone would produce a blank failure reason.
    """
    message = str(exc)
    type_name = type(exc).__name__
    return f"{type_name}: {message}" if message else type_name


async def run_session(session_num: int, device_id: str) -> dict:
    """Drive one simulated drone through a single text turn.

    Steps: send ``session.start``, wait for ``session.ready``, send one
    ``turn.text``, read frames until ``turn.complete`` (or ``error``), then
    send ``session.end``.

    Args:
        session_num: 1-based index used to build the session/turn ids and to
            choose the test utterance.
        device_id: Device id placed in the ``session.start`` client block.

    Returns:
        On success: a dict with ``success=True``, ``device_id``, ``routing``,
        ``llm_ms``, ``tts_ms``, ``total_time`` (seconds) and ``audio_chunks``
        (number of binary frames received). On failure:
        ``{"success": False, "error": <description>}``. ``Exception``
        subclasses raised while talking to the server are converted to the
        failure dict rather than propagated.
    """
    session_id = f"test-session-{session_num:03d}"
    turn_id = f"test-turn-{session_num:03d}"
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # 1. Open the session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": session_id,
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": device_id,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))

            # The first reply must be session.ready, within 10 seconds.
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") != "session.ready":
                return {"success": False, "error": f"期望 session.ready收到 {data.get('type')}"}
            print(f" [{device_id}] ✅ 会话建立")

            # 2. Send one final text turn; sessions rotate through the list.
            test_texts = [
                "你好",
                "起飞",
                "返航",
                "降落",
            ]
            test_text = test_texts[(session_num - 1) % len(test_texts)]
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": turn_id,
                "text": test_text,
                "is_final": True,
                "source": "device_stt"
            }
            # perf_counter is monotonic, so the measured duration is not
            # affected by wall-clock adjustments.
            t_start = time.perf_counter()
            await ws.send(json.dumps(turn_text, ensure_ascii=False))
            print(f" [{device_id}] → {test_text}")

            # 3. Read frames until turn.complete. Binary frames are only
            # counted; their payloads are not needed by the result.
            dialog_result = None
            audio_chunk_count = 0
            metrics = {}
            while True:
                # Each frame individually gets up to 60 s; no overall cap.
                msg = await asyncio.wait_for(ws.recv(), timeout=60)
                if isinstance(msg, bytes):
                    audio_chunk_count += 1
                    continue
                data = json.loads(msg)
                msg_type = data.get("type")
                if msg_type == "dialog_result":
                    dialog_result = data
                    routing = data.get("routing")
                    print(f" [{device_id}] ← routing={routing}")
                elif msg_type == "turn.complete":
                    # `or {}` also tolerates an explicit "metrics": null.
                    metrics = data.get("metrics") or {}
                    break
                elif msg_type == "error":
                    return {
                        "success": False,
                        "error": f"{data.get('code')}: {data.get('message')}"
                    }
            total_time = time.perf_counter() - t_start
            print(f" [{device_id}] ✅ 完成 (耗时={total_time:.2f}s)")

            # Tell the server this session is finished.
            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_id
            }
            await ws.send(json.dumps(session_end))
            return {
                "success": True,
                "device_id": device_id,
                "routing": dialog_result.get("routing") if dialog_result else None,
                "llm_ms": metrics.get("llm_ms"),
                "tts_ms": metrics.get("tts_first_byte_ms"),
                "total_time": total_time,
                "audio_chunks": audio_chunk_count
            }
    except Exception as e:
        # Connection failures, timeouts, bad JSON and dropped sockets all
        # become a failure dict so callers always receive a result.
        return {
            "success": False,
            "error": _describe_error(e)
        }
async def main():
    """Script entry point: run the concurrency test, exit 0 if it passed, else 1."""
    all_passed = await test_concurrent()
    if all_passed:
        sys.exit(0)
    sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())