261 lines
8.3 KiB
Python
261 lines
8.3 KiB
Python
"""
测试 03: 错误处理测试 (Test 03: error-handling tests)

测试各种错误场景 (Exercises various error scenarios.)
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import websockets
|
||
|
||
# Put this script's own directory first on the import path so the sibling
# helper modules (test_config, test_utils) resolve when run as a script.
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from test_config import SERVER_URL, AUTH_TOKEN, DEVICE_ID
|
||
from test_utils import TestResult
|
||
|
||
|
||
async def test_errors():
    """Run every error-handling scenario in order and print a summary.

    Scenarios (awaited one after another):
      1. session.start with a wrong auth token
      2. turn.text missing required fields
      3. turn.audio_chunk on a text_uplink session
      4. turn.text sent before any session.start

    Returns:
        True if every scenario's result reports success, otherwise False.
    """
    scenarios = (
        test_unauthorized,
        test_invalid_message,
        test_unsupported_audio,
        test_missing_session_start,
    )
    results = [await scenario() for scenario in scenarios]

    banner = "=" * 60
    print("\n" + banner)
    print(" 错误处理测试汇总")
    print(banner)

    for outcome in results:
        outcome.print_report()

    return all(outcome.success for outcome in results)
|
||
|
||
|
||
async def test_unauthorized() -> TestResult:
    """Check that a session.start carrying a wrong auth token is rejected.

    Sends session.start with ``auth_token="wrong-token"`` and expects the
    first server message to be an ``error`` whose ``code`` is
    ``UNAUTHORIZED``.

    Returns:
        The finalized TestResult. Mismatches, timeouts and other exceptions
        are recorded on it rather than raised.
    """
    result = TestResult("鉴权失败测试")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Deliberately wrong token; the remaining fields match the
            # handshake payload used by the other scenarios.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-unauth-001",
                "auth_token": "wrong-token",
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }

            await ws.send(json.dumps(session_start, ensure_ascii=False))
            print("→ session.start (错误 token)")

            # The reply should be an error, not a session acknowledgement.
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)

            if data.get("type") == "error" and data.get("code") == "UNAUTHORIZED":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝未授权请求")
                result.mark_success()
            else:
                result.add_error(f"期望 UNAUTHORIZED 错误,收到: {data}")

    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty, so the generic handler below would
        # record a message with no detail at all.
        result.add_error("测试异常: 等待服务器响应超时")
    except Exception as e:
        result.add_error(f"测试异常: {e}")

    return result.finalize()
|
||
|
||
|
||
async def test_invalid_message() -> TestResult:
    """Check that a turn.text missing required fields is rejected.

    Opens a valid session, then sends a ``turn.text`` without ``turn_id``
    and ``text``, and expects an ``error`` with code ``INVALID_MESSAGE``.

    Returns:
        The finalized TestResult. Mismatches, a failed handshake, timeouts
        and other exceptions are recorded on it rather than raised.
    """
    result = TestResult("非法消息格式测试")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Establish a normal session first.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-invalid-001",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }

            await ws.send(json.dumps(session_start))

            # Bounded wait for the handshake reply, and confirm it really is
            # session.ready. An unbounded recv() could hang the suite. An
            # unchecked reply would let a failed handshake go unnoticed, and
            # the scenario could still pass.
            ready = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
            if ready.get("type") != "session.ready":
                result.add_error(f"期望 session.ready,收到: {ready}")
                return result.finalize()

            # Message lacking the required turn_id and text fields.
            invalid_msg = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
            }

            await ws.send(json.dumps(invalid_msg))
            print("→ turn.text (缺少必需字段)")

            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)

            if data.get("type") == "error" and data.get("code") == "INVALID_MESSAGE":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝非法消息")
                result.mark_success()
            else:
                result.add_error(f"期望 INVALID_MESSAGE 错误,收到: {data}")

    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty, so the generic handler below would
        # record a message with no detail at all.
        result.add_error("测试异常: 等待服务器响应超时")
    except Exception as e:
        result.add_error(f"测试异常: {e}")

    return result.finalize()
|
||
|
||
|
||
async def test_unsupported_audio() -> TestResult:
    """Check that turn.audio_chunk is rejected on a text_uplink session.

    Opens a valid ``text_uplink`` session, sends a ``turn.audio_chunk``, and
    expects an ``error`` with code ``INVALID_MESSAGE``.

    Returns:
        The finalized TestResult. Mismatches, a failed handshake, timeouts
        and other exceptions are recorded on it rather than raised.
    """
    result = TestResult("不支持的音频消息测试")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Establish a text-only session.
            session_start = {
                "type": "session.start",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "session_id": "test-audio-001",
                "auth_token": AUTH_TOKEN,
                "client": {
                    "device_id": DEVICE_ID,
                    "locale": "zh-CN",
                    "capabilities": {
                        "playback_sample_rate_hz": 24000,
                        "prefer_tts_codec": "pcm_s16le"
                    },
                    "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
                }
            }

            await ws.send(json.dumps(session_start))

            # Bounded wait for the handshake reply, and confirm it really is
            # session.ready. An unbounded recv() could hang the suite. An
            # unchecked reply would let a failed handshake go unnoticed, and
            # the scenario could still pass.
            ready = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
            if ready.get("type") != "session.ready":
                result.add_error(f"期望 session.ready,收到: {ready}")
                return result.finalize()

            # Audio chunk: not permitted under the text_uplink profile.
            audio_chunk_msg = {
                "type": "turn.audio_chunk",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": "test-turn-001",
                "seq": 0,
            }

            await ws.send(json.dumps(audio_chunk_msg))
            print("→ turn.audio_chunk (不允许的消息类型)")

            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)

            if data.get("type") == "error" and data.get("code") == "INVALID_MESSAGE":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确拒绝音频消息")
                result.mark_success()
            else:
                result.add_error(f"期望 INVALID_MESSAGE 错误,收到: {data}")

    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty, so the generic handler below would
        # record a message with no detail at all.
        result.add_error("测试异常: 等待服务器响应超时")
    except Exception as e:
        result.add_error(f"测试异常: {e}")

    return result.finalize()
|
||
|
||
|
||
async def test_missing_session_start() -> TestResult:
    """Check that a turn.text sent before session.start yields an error.

    Connects and immediately sends a well-formed ``turn.text`` without any
    handshake. Any reply of type ``error`` counts as success; the error code
    is printed but not checked.

    Returns:
        The finalized TestResult. Mismatches, timeouts and other exceptions
        are recorded on it rather than raised.
    """
    result = TestResult("缺少 session.start 测试")

    try:
        async with websockets.connect(SERVER_URL) as ws:
            # Skip session.start entirely and go straight to a turn.
            turn_text = {
                "type": "turn.text",
                "proto_version": "1.0",
                "transport_profile": "text_uplink",
                "turn_id": "test-turn-001",
                "text": "你好",
                "is_final": True,
                "source": "device_stt"
            }

            await ws.send(json.dumps(turn_text))
            print("→ turn.text (未发送 session.start)")

            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            data = json.loads(msg)

            if data.get("type") == "error":
                print(f"← error (code={data.get('code')})")
                print("✅ 正确处理缺少 session.start 的情况")
                result.mark_success()
            else:
                result.add_error(f"期望 error,收到: {data}")

    except asyncio.TimeoutError:
        # str(TimeoutError()) is empty, so the generic handler below would
        # record a message with no detail at all.
        result.add_error("测试异常: 等待服务器响应超时")
    except Exception as e:
        result.add_error(f"测试异常: {e}")

    return result.finalize()
|
||
|
||
|
||
async def main():
    """Script entry point: run the error-handling suite and exit.

    Prints a header, runs test_errors(), prints an overall verdict, and then
    calls sys.exit with status 0 if every scenario passed, or 1 otherwise.
    """
    rule = "=" * 60
    print(rule)
    print(" 测试 03: 错误处理测试")
    print(rule)
    print()

    passed = await test_errors()

    print("\n" + rule)
    print(" ✅ 所有错误处理测试通过!" if passed else " ❌ 部分测试失败!")
    print(rule)

    sys.exit(0 if passed else 1)


# Only run the suite when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|