310 lines
11 KiB
Python
310 lines
11 KiB
Python
"""
测试 01: 基础功能测试

测试 WebSocket 连接、会话建立、闲聊、飞控指令识别、返航指令以及会话结束
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import websockets
|
||
import numpy as np
|
||
import sounddevice as sd
|
||
|
||
# 添加测试目录到路径
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID, TEST_CASES
|
||
from test_utils import play_audio, print_json, calculate_metrics, print_metrics, TestResult
|
||
|
||
|
||
def _print_section(title: str, *, leading_newline: bool = False) -> None:
    """Print ``title`` framed by two 60-character ``=`` rules."""
    print(("\n" if leading_newline else "") + "=" * 60)
    print(f" {title}")
    print("=" * 60)


async def test_basic():
    """Run the basic end-to-end scenario against the dialog server.

    Opens a WebSocket to ``SERVER_URL``, starts a session, runs a chit-chat
    turn, a flight-command turn and a return-home turn via ``test_turn``, then
    sends ``session.end``.

    Returns:
        The value of ``result.finalize().print_report()``. Every exit path
        (early session failure, normal completion, exception) returns through
        this same expression, so ``main()`` evaluates all outcomes alike.
        NOTE(review): assumes ``print_report()`` returns the pass/fail status
        that ``main()`` tests for truthiness — confirm in test_utils.
    """
    result = TestResult("基础功能测试")
    session_id = "test-session-001"
    # Tracked locally so the result is marked successful only when no check
    # failed (previously mark_success() ran even after add_error calls).
    failed = False

    def fail(message: str) -> None:
        # Record an error on the report and remember that a check failed.
        nonlocal failed
        failed = True
        result.add_error(message)

    print(f"\n🔗 连接到: {SERVER_URL}")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            print("✅ WebSocket 连接成功\n")

            # ---- 1. Start the session ----
            _print_section("测试 1: 建立会话")

            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": session_id,
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le",
                    },
                    "protocol": {
                        "dialog_result": "cloud_voice_dialog_v1",
                    },
                },
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))
            print("→ session.start")

            # The first server frame must be session.ready.
            ready_data = json.loads(await ws.recv())
            if ready_data.get("type") != "session.ready":
                fail(f"期望 session.ready,收到: {ready_data.get('type')}")
                # Same exit expression as the normal path: the original
                # returned the TestResult object here, which main() would
                # treat as truthy regardless of the failure.
                return result.finalize().print_report()

            print("← session.ready")
            print(f" 服务端能力: {json.dumps(ready_data.get('server_caps', {}), ensure_ascii=False)}")
            print("✅ 会话建立成功\n")

            # ---- 2. Chit-chat turn ----
            _print_section("测试 2: 闲聊对话")

            turn_result = await test_turn(ws, TEST_CASES["chitchat"][0], expect_routing="chitchat")
            if turn_result["success"]:
                print("✅ 闲聊测试通过\n")
            else:
                fail(f"闲聊测试失败: {turn_result['error']}")

            # ---- 3. Flight-command recognition ----
            _print_section("测试 3: 飞控指令识别")

            turn_result = await test_turn(ws, "起飞然后在前方十米悬停", expect_routing="flight_intent")
            if not turn_result["success"]:
                fail(f"飞控指令测试失败: {turn_result['error']}")
            else:
                print("✅ 飞控指令测试通过\n")

                # Show the structure of the returned flight intent, if any.
                flight_intent = turn_result.get("flight_intent")
                if flight_intent:
                    print(" 📋 飞控意图详情:")
                    print(f" Summary: {flight_intent.get('summary')}")
                    print(f" Actions: {json.dumps(flight_intent.get('actions', []), ensure_ascii=False, indent=6)}")

            # ---- 4. Return-home command ----
            _print_section("测试 4: 返航指令", leading_newline=True)

            turn_result = await test_turn(ws, "返航", expect_routing="flight_intent")
            if not turn_result["success"]:
                # Previously a failed turn here was silently ignored.
                fail(f"返航测试失败: {turn_result['error']}")
            else:
                # "flight_intent" can be present with value None, so a plain
                # .get(key, {}) default is not enough to avoid AttributeError.
                intent = turn_result.get("flight_intent") or {}
                actions = intent.get("actions") or []
                if any(isinstance(a, dict) and a.get("type") == "return_home" for a in actions):
                    print("✅ 返航指令识别正确\n")
                else:
                    fail("返航指令未识别到 return_home 动作")

            # ---- 5. End the session ----
            _print_section("测试 5: 结束会话")

            session_end = {
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": session_id,
            }
            await ws.send(json.dumps(session_end))
            print("→ session.end")
            print("✅ 会话已结束\n")

            # Only a run with no recorded failure counts as a success.
            if not failed:
                result.mark_success()
            result.add_detail("test_cases", "闲聊 + 飞控 + 返航")

    except Exception as e:
        # Top-level boundary of the scenario: record and show the traceback.
        fail(f"测试异常: {e}")
        import traceback
        traceback.print_exc()

    return result.finalize().print_report()
|
||
|
||
|
||
async def test_turn(ws, text: str, expect_routing: str) -> dict:
|
||
"""
|
||
测试单轮对话
|
||
|
||
Returns:
|
||
dict: {
|
||
"success": bool,
|
||
"routing": str,
|
||
"flight_intent": dict or None,
|
||
"chat_reply": str or None,
|
||
"audio_data": np.ndarray or None,
|
||
"metrics": dict,
|
||
"error": str or None
|
||
}
|
||
"""
|
||
print(f"\n 输入: {text}")
|
||
print(f" 期望路由: {expect_routing}")
|
||
|
||
turn_id = f"turn-{int(time.time())}"
|
||
audio_chunks = []
|
||
dialog_result = None
|
||
metrics = {}
|
||
|
||
t_start = time.time()
|
||
|
||
try:
|
||
# 发送 turn.text
|
||
turn_text = {
|
||
"type": "turn.text",
|
||
"proto_version": "1.0",
|
||
"transport_profile": "text_uplink",
|
||
"turn_id": turn_id,
|
||
"text": text,
|
||
"is_final": True,
|
||
"source": "device_stt"
|
||
}
|
||
|
||
await ws.send(json.dumps(turn_text, ensure_ascii=False))
|
||
print(f" → turn.text (turn_id={turn_id[:8]}...)")
|
||
|
||
# 接收响应
|
||
while True:
|
||
msg = await asyncio.wait_for(ws.recv(), timeout=60)
|
||
|
||
if isinstance(msg, bytes):
|
||
# 音频数据
|
||
audio_chunks.append(msg)
|
||
else:
|
||
data = json.loads(msg)
|
||
msg_type = data.get("type")
|
||
|
||
if msg_type == "dialog_result":
|
||
dialog_result = data
|
||
routing = data.get("routing")
|
||
print(f" ← dialog_result (routing={routing})")
|
||
|
||
if routing == "flight_intent":
|
||
intent = data.get("flight_intent")
|
||
if intent:
|
||
print(f" Summary: {intent.get('summary')}")
|
||
elif routing == "chitchat":
|
||
reply = data.get("chat_reply")
|
||
print(f" Reply: {reply}")
|
||
|
||
elif msg_type == "llm.text_delta":
|
||
# 流式 LLM 增量;旧测试仅忽略,仍以 dialog_result 为准
|
||
if data.get("done"):
|
||
print(f" ← llm.text_delta (done)")
|
||
|
||
elif msg_type == "tts_audio_chunk":
|
||
seq = data.get("seq")
|
||
is_final = data.get("is_final")
|
||
if seq == 0:
|
||
print(f" ← TTS 音频流开始")
|
||
if is_final:
|
||
print(f" ← TTS 音频流结束 (seq={seq})")
|
||
|
||
elif msg_type == "turn.complete":
|
||
metrics = data.get("metrics", {})
|
||
print(f" ← turn.complete")
|
||
print(f" LLM: {metrics.get('llm_ms')}ms")
|
||
print(f" TTS: {metrics.get('tts_first_byte_ms')}ms")
|
||
break
|
||
|
||
elif msg_type == "error":
|
||
print(f" ❌ 错误: {data.get('code')} - {data.get('message')}")
|
||
return {
|
||
"success": False,
|
||
"error": f"{data.get('code')}: {data.get('message')}"
|
||
}
|
||
|
||
t_end = time.time()
|
||
total_time = t_end - t_start
|
||
|
||
# 拼接音频
|
||
audio_data = None
|
||
if audio_chunks:
|
||
full_pcm = b"".join(audio_chunks)
|
||
audio_data = np.frombuffer(full_pcm, dtype=np.int16)
|
||
audio_length = len(audio_data) / 24000
|
||
print(f"\n 📊 生成音频: {audio_length:.2f}s ({len(audio_data)} samples)")
|
||
|
||
# 验证路由类型
|
||
actual_routing = dialog_result.get("routing") if dialog_result else None
|
||
if actual_routing != expect_routing:
|
||
return {
|
||
"success": False,
|
||
"error": f"路由不匹配: 期望 {expect_routing}, 实际 {actual_routing}"
|
||
}
|
||
|
||
# 播放音频
|
||
if audio_data is not None and len(audio_data) > 0:
|
||
print(f"\n 🔊 播放 TTS 音频...")
|
||
play_audio(audio_data, sample_rate=24000)
|
||
|
||
# 计算指标
|
||
perf_metrics = calculate_metrics(
|
||
llm_ms=metrics.get("llm_ms"),
|
||
tts_ms=metrics.get("tts_first_byte_ms"),
|
||
audio_length_s=len(audio_data) / 24000 if audio_data is not None else 0,
|
||
total_time_s=total_time
|
||
)
|
||
print_metrics(perf_metrics)
|
||
|
||
return {
|
||
"success": True,
|
||
"routing": actual_routing,
|
||
"flight_intent": dialog_result.get("flight_intent") if dialog_result else None,
|
||
"chat_reply": dialog_result.get("chat_reply") if dialog_result else None,
|
||
"audio_data": audio_data,
|
||
"metrics": perf_metrics
|
||
}
|
||
|
||
except asyncio.TimeoutError:
|
||
return {
|
||
"success": False,
|
||
"error": "响应超时 (60s)"
|
||
}
|
||
except Exception as e:
|
||
return {
|
||
"success": False,
|
||
"error": str(e)
|
||
}
|
||
|
||
|
||
async def main():
    """Print a banner, run ``test_basic`` and exit the process with its status.

    The value returned by ``test_basic`` is tested for truthiness: truthy
    prints the pass line and exits with code 0, otherwise the fail line and
    exit code 1.
    """
    rule = "=" * 60

    print(rule)
    print(" 测试 01: 基础功能测试")
    print(rule)

    passed = await test_basic()

    print("\n" + rule)
    print(" ✅ 测试通过!" if passed else " ❌ 测试失败!")
    print(rule)

    sys.exit(0 if passed else 1)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    # main() terminates the process itself via sys.exit().
    _entry_coro = main()
    asyncio.run(_entry_coro)
|