DroneMind/voicellmcloud/test/test_03_errors.py
2026-04-14 10:08:41 +08:00

261 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 03: 错误处理测试
测试各种错误场景
"""
import asyncio
import json
import sys
from pathlib import Path
import websockets
sys.path.insert(0, str(Path(__file__).parent))
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID
from test_utils import TestResult
async def test_errors():
    """Run every error-handling scenario in order and print a summary.

    Each scenario coroutine returns a finalized TestResult. All reports are
    printed after a summary banner.

    Returns:
        True if every scenario's result reports ``success``, else False.
    """
    scenarios = (
        test_unauthorized,           # 1: authentication failure
        test_invalid_message,        # 2: malformed message
        test_unsupported_audio,      # 3: audio message not allowed in text_uplink
        test_missing_session_start,  # 4: turn sent before session.start
    )
    # Awaited one after another, in the order listed above.
    outcomes = [await scenario() for scenario in scenarios]

    banner = "=" * 60
    print("\n" + banner)
    print(" 错误处理测试汇总")
    print(banner)
    for outcome in outcomes:
        outcome.print_report()
    return all(outcome.success for outcome in outcomes)
async def test_unauthorized() -> TestResult:
    """Check that a session.start carrying a wrong auth token is rejected.

    Opens a fresh connection to ``SERVER_URL`` and sends ``session.start``
    with ``auth_token="wrong-token"``. The first reply must be an ``error``
    message whose ``code`` is ``"UNAUTHORIZED"``.

    Returns:
        The finalized TestResult. It is marked successful only when the
        expected UNAUTHORIZED error arrives within 10 seconds. Any exception
        (connection failure, timeout, bad JSON) is recorded as an error.
    """
    result = TestResult("鉴权失败测试")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Deliberately use an invalid token.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-unauth-001",
                "auth_token": "wrong-token",
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }
            await ws.send(json.dumps(session_start, ensure_ascii=False))
            print("→ session.start (错误 token)")
            # The server is expected to answer with an error message.
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") == "error" and data.get("code") == "UNAUTHORIZED":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝未授权请求")
                result.mark_success()
            else:
                result.add_error(f"期望 UNAUTHORIZED 错误,收到: {data}")
    except Exception as e:
        # Use repr(): str() of the TimeoutError raised by wait_for is empty,
        # which would hide the cause of the failure.
        result.add_error(f"测试异常: {e!r}")
    return result.finalize()
async def test_invalid_message() -> TestResult:
    """Check that a turn.text missing required fields is rejected.

    First establishes a normal session using ``AUTH_TOKEN``. Then sends a
    ``turn.text`` without ``turn_id`` and ``text`` and expects an ``error``
    reply with ``code == "INVALID_MESSAGE"``.

    Returns:
        The finalized TestResult. It fails if:
        - the handshake reply is not ``session.ready``;
        - the reply to the bad message is not the expected error;
        - any step raises (including a 10-second receive timeout).
    """
    result = TestResult("非法消息格式测试")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Establish a normal session first.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-invalid-001",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }
            await ws.send(json.dumps(session_start))
            # Bound the handshake wait and confirm it succeeded. An unbounded
            # recv() could wait indefinitely, and an unchecked reply would
            # blame a failed handshake on the invalid-message check below.
            ready = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
            if ready.get("type") != "session.ready":
                result.add_error(f"期望 session.ready,收到: {ready}")
                return result.finalize()
            # Send a message that lacks required fields.
            invalid_msg = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                # turn_id and text are intentionally omitted
            }
            await ws.send(json.dumps(invalid_msg))
            print("→ turn.text (缺少必需字段)")
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") == "error" and data.get("code") == "INVALID_MESSAGE":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝非法消息")
                result.mark_success()
            else:
                result.add_error(f"期望 INVALID_MESSAGE 错误,收到: {data}")
    except Exception as e:
        # Use repr(): str() of the TimeoutError raised by wait_for is empty.
        result.add_error(f"测试异常: {e!r}")
    return result.finalize()
async def test_unsupported_audio() -> TestResult:
    """Check that an audio chunk is rejected in a text_uplink session.

    Establishes a ``text_uplink`` session using ``AUTH_TOKEN``, then sends a
    ``turn.audio_chunk`` message. The reply must be an ``error`` with
    ``code == "INVALID_MESSAGE"``.

    Returns:
        The finalized TestResult. It fails if:
        - the handshake reply is not ``session.ready``;
        - the reply to the audio chunk is not the expected error;
        - any step raises (including a 10-second receive timeout).
    """
    result = TestResult("不支持的音频消息测试")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Establish the session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-audio-001",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }
            await ws.send(json.dumps(session_start))
            # Bound the handshake wait and confirm it succeeded, so a failed
            # handshake is reported as such instead of hanging or being
            # misattributed to the audio-chunk check.
            ready = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
            if ready.get("type") != "session.ready":
                result.add_error(f"期望 session.ready,收到: {ready}")
                return result.finalize()
            # Send an audio chunk (not allowed in the text_uplink profile).
            audio_chunk_msg = {
                "type": "turn.audio_chunk",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": "test-turn-001",
                "seq": 0,
            }
            await ws.send(json.dumps(audio_chunk_msg))
            print("→ turn.audio_chunk (不允许的消息类型)")
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") == "error" and data.get("code") == "INVALID_MESSAGE":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝音频消息")
                result.mark_success()
            else:
                result.add_error(f"期望 INVALID_MESSAGE 错误,收到: {data}")
    except Exception as e:
        # Use repr(): str() of the TimeoutError raised by wait_for is empty.
        result.add_error(f"测试异常: {e!r}")
    return result.finalize()
async def test_missing_session_start() -> TestResult:
    """Check that a turn sent before session.start is answered with an error.

    Opens a connection and, without sending ``session.start``, immediately
    sends a ``turn.text``. Any reply of ``type == "error"`` counts as a pass;
    the error code is only printed, not checked.

    Returns:
        The finalized TestResult. It fails if the first reply is not an
        error or if any step raises (including a 10-second receive timeout).
    """
    result = TestResult("缺少 session.start 测试")
    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Skip session.start and send turn.text directly.
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": "test-turn-001",
                "text": "你好",
                "is_final": True,
                "source": "device_stt"
            }
            await ws.send(json.dumps(turn_text))
            print("→ turn.text (未发送 session.start)")
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)
            if data.get("type") == "error":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确处理缺少 session.start 的情况")
                result.mark_success()
            else:
                # Separator restored to match the sibling tests' messages.
                result.add_error(f"期望 error,收到: {data}")
    except Exception as e:
        # Use repr(): str() of the TimeoutError raised by wait_for is empty.
        result.add_error(f"测试异常: {e!r}")
    return result.finalize()
async def main():
    """Script entry point: run the error-handling suite and exit with its status.

    Prints a header, awaits ``test_errors()``, and prints a final verdict.
    Then calls ``sys.exit`` with 0 when everything passed and 1 otherwise.
    """
    rule = "=" * 60
    print(rule)
    print(" 测试 03: 错误处理测试")
    print(rule)
    print()

    passed = await test_errors()

    print("\n" + rule)
    print(" ✅ 所有错误处理测试通过!" if passed else " ❌ 部分测试失败!")
    print(rule)
    sys.exit(0 if passed else 1)
if __name__ == "__main__":
    # Only when executed as a script: drive the async entry point to completion.
    entry = main()
    asyncio.run(entry)