DroneMind/voicellmcloud/test/test_01_basic.py
2026-04-14 10:08:41 +08:00

310 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 01: 基础功能测试
测试 WebSocket 连接、会话建立、闲聊和飞控指令识别
"""
import asyncio
import json
import sys
import time
from pathlib import Path
import websockets
import numpy as np
import sounddevice as sd
# 添加测试目录到路径
sys.path.insert(0, str(Path(__file__).parent))
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID, TEST_CASES
from test_utils import play_audio, print_json, calculate_metrics, print_metrics, TestResult
async def test_basic():
    """
    Run the basic end-to-end functional test against ``SERVER_URL``.

    All steps share one WebSocket session:
      1. send ``session.start`` and require a ``session.ready`` reply;
      2. a chitchat turn (first entry of ``TEST_CASES["chitchat"]``) that must
         route to ``chitchat``;
      3. a flight-control turn that must route to ``flight_intent``;
      4. a "返航" turn whose flight intent must contain a ``return_home`` action;
      5. send ``session.end``.

    A failing step is recorded on the ``TestResult`` and the run continues with
    the next step. An unexpected exception aborts the remaining steps and is
    recorded as well.

    Returns:
        The value of ``result.finalize().print_report()``. Every exit path
        returns this same expression; ``main()`` uses its truthiness as the
        pass/fail flag.
    """
    result = TestResult("基础功能测试")
    # Set whenever a step records an error, so success is only marked for a
    # fully clean run.
    failed = False

    def record_error(message: str) -> None:
        """Record an error on the result and remember that the run failed."""
        nonlocal failed
        failed = True
        result.add_error(message)

    print(f"\n🔗 连接到: {SERVER_URL}")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            print("✅ WebSocket 连接成功\n")

            # ========== 1. Establish the session ==========
            print("=" * 60)
            print(" 测试 1: 建立会话")
            print("=" * 60)
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-session-001",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {
                        "dialog_result": "cloud_voice_dialog_v1",
                    },
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))
            print("→ session.start")

            # The first reply must be session.ready.
            ready_msg = await ws.recv()
            ready_data = json.loads(ready_msg)
            if ready_data.get("type") != "session.ready":
                record_error(f"期望 session.ready收到: {ready_data.get('type')}")
                # Use the same return expression as the normal exit path, so
                # main() always receives the same kind of value.
                return result.finalize().print_report()
            print("← session.ready")
            print(f" 服务端能力: {json.dumps(ready_data.get('server_caps', {}), ensure_ascii=False)}")
            print("✅ 会话建立成功\n")

            # ========== 2. Chitchat ==========
            print("=" * 60)
            print(" 测试 2: 闲聊对话")
            print("=" * 60)
            test_text = TEST_CASES["chitchat"][0]
            turn_result = await test_turn(ws, test_text, expect_routing="chitchat")
            if not turn_result["success"]:
                record_error(f"闲聊测试失败: {turn_result['error']}")
            else:
                print("✅ 闲聊测试通过\n")

            # ========== 3. Flight-control intent ==========
            print("=" * 60)
            print(" 测试 3: 飞控指令识别")
            print("=" * 60)
            test_text = "起飞然后在前方十米悬停"
            turn_result = await test_turn(ws, test_text, expect_routing="flight_intent")
            if not turn_result["success"]:
                record_error(f"飞控指令测试失败: {turn_result['error']}")
            else:
                print("✅ 飞控指令测试通过\n")
                # Show the recognised flight-intent structure.
                flight_intent = turn_result.get("flight_intent")
                if flight_intent:
                    print(" 📋 飞控意图详情:")
                    print(f" Summary: {flight_intent.get('summary')}")
                    print(f" Actions: {json.dumps(flight_intent.get('actions', []), ensure_ascii=False, indent=6)}")

            # ========== 4. Return home ==========
            print("\n" + "=" * 60)
            print(" 测试 4: 返航指令")
            print("=" * 60)
            turn_result = await test_turn(ws, "返航", expect_routing="flight_intent")
            if not turn_result["success"]:
                # A failed turn used to be silently ignored here.
                record_error(f"返航测试失败: {turn_result['error']}")
            else:
                # "flight_intent" / "actions" may be present but None, so guard
                # both levels instead of relying on .get() defaults.
                flight_intent = turn_result.get("flight_intent") or {}
                actions = flight_intent.get("actions") or []
                has_return_home = any(
                    isinstance(a, dict) and a.get("type") == "return_home"
                    for a in actions
                )
                if has_return_home:
                    print("✅ 返航指令识别正确\n")
                else:
                    record_error("返航指令未识别到 return_home 动作")

            # ========== 5. End the session ==========
            print("=" * 60)
            print(" 测试 5: 结束会话")
            print("=" * 60)
            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": "test-session-001",
            }
            await ws.send(json.dumps(session_end, ensure_ascii=False))
            print("→ session.end")
            print("✅ 会话已结束\n")

            # Only a run with no recorded errors counts as a success.
            if not failed:
                result.mark_success()
            result.add_detail("test_cases", "闲聊 + 飞控 + 返航")
    except Exception as e:
        # Top-level boundary of the test: record the failure and show the traceback.
        record_error(f"测试异常: {e}")
        import traceback
        traceback.print_exc()
    return result.finalize().print_report()
async def test_turn(ws, text: str, expect_routing: str) -> dict:
"""
测试单轮对话
Returns:
dict: {
"success": bool,
"routing": str,
"flight_intent": dict or None,
"chat_reply": str or None,
"audio_data": np.ndarray or None,
"metrics": dict,
"error": str or None
}
"""
print(f"\n 输入: {text}")
print(f" 期望路由: {expect_routing}")
turn_id = f"turn-{int(time.time())}"
audio_chunks = []
dialog_result = None
metrics = {}
t_start = time.time()
try:
# 发送 turn.text
turn_text = {
"type": "turn.text",
"proto_version": "1.0",
"transport_profile": "text_uplink",
"turn_id": turn_id,
"text": text,
"is_final": True,
"source": "device_stt"
}
await ws.send(json.dumps(turn_text, ensure_ascii=False))
print(f" → turn.text (turn_id={turn_id[:8]}...)")
# 接收响应
while True:
msg = await asyncio.wait_for(ws.recv(), timeout=60)
if isinstance(msg, bytes):
# 音频数据
audio_chunks.append(msg)
else:
data = json.loads(msg)
msg_type = data.get("type")
if msg_type == "dialog_result":
dialog_result = data
routing = data.get("routing")
print(f" ← dialog_result (routing={routing})")
if routing == "flight_intent":
intent = data.get("flight_intent")
if intent:
print(f" Summary: {intent.get('summary')}")
elif routing == "chitchat":
reply = data.get("chat_reply")
print(f" Reply: {reply}")
elif msg_type == "llm.text_delta":
# 流式 LLM 增量;旧测试仅忽略,仍以 dialog_result 为准
if data.get("done"):
print(f" ← llm.text_delta (done)")
elif msg_type == "tts_audio_chunk":
seq = data.get("seq")
is_final = data.get("is_final")
if seq == 0:
print(f" ← TTS 音频流开始")
if is_final:
print(f" ← TTS 音频流结束 (seq={seq})")
elif msg_type == "turn.complete":
metrics = data.get("metrics", {})
print(f" ← turn.complete")
print(f" LLM: {metrics.get('llm_ms')}ms")
print(f" TTS: {metrics.get('tts_first_byte_ms')}ms")
break
elif msg_type == "error":
print(f" ❌ 错误: {data.get('code')} - {data.get('message')}")
return {
"success": False,
"error": f"{data.get('code')}: {data.get('message')}"
}
t_end = time.time()
total_time = t_end - t_start
# 拼接音频
audio_data = None
if audio_chunks:
full_pcm = b"".join(audio_chunks)
audio_data = np.frombuffer(full_pcm, dtype=np.int16)
audio_length = len(audio_data) / 24000
print(f"\n 📊 生成音频: {audio_length:.2f}s ({len(audio_data)} samples)")
# 验证路由类型
actual_routing = dialog_result.get("routing") if dialog_result else None
if actual_routing != expect_routing:
return {
"success": False,
"error": f"路由不匹配: 期望 {expect_routing}, 实际 {actual_routing}"
}
# 播放音频
if audio_data is not None and len(audio_data) > 0:
print(f"\n 🔊 播放 TTS 音频...")
play_audio(audio_data, sample_rate=24000)
# 计算指标
perf_metrics = calculate_metrics(
llm_ms=metrics.get("llm_ms"),
tts_ms=metrics.get("tts_first_byte_ms"),
audio_length_s=len(audio_data) / 24000 if audio_data is not None else 0,
total_time_s=total_time
)
print_metrics(perf_metrics)
return {
"success": True,
"routing": actual_routing,
"flight_intent": dialog_result.get("flight_intent") if dialog_result else None,
"chat_reply": dialog_result.get("chat_reply") if dialog_result else None,
"audio_data": audio_data,
"metrics": perf_metrics
}
except asyncio.TimeoutError:
return {
"success": False,
"error": "响应超时 (60s)"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}
async def main():
    """Print a banner, run the basic test, report the verdict and exit.

    The process exit status is 0 when ``test_basic()`` returns a truthy value
    and 1 otherwise (``sys.exit`` is called from inside the coroutine).
    """
    rule = "=" * 60
    for banner_line in (rule, " 测试 01: 基础功能测试", rule):
        print(banner_line)

    success = await test_basic()

    print("\n" + rule)
    print(" ✅ 测试通过!" if success else " ❌ 测试失败!")
    print(rule)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    asyncio.run(main())