2026-04-14 10:08:41 +08:00

215 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试客户端 - 验证云端语音服务功能
"""
import asyncio
import json
import sys
from pathlib import Path
import websockets
import numpy as np
import sounddevice as sd
# Add the directory containing this script to sys.path so the local `app` package can be imported.
sys.path.insert(0, str(Path(__file__).parent))
from app.config import settings
class TestClient:
    """Interactive WebSocket client that exercises the cloud voice service.

    The client opens one session, sends a fixed list of text turns, prints
    every server message it receives, and plays back any binary audio frames
    that arrived during a turn once the turn completes.
    """

    # The name starts with "Test", which pytest would otherwise try to collect.
    __test__ = False

    def __init__(self, server_url: str = None):
        """Create a client.

        Args:
            server_url: Full WebSocket URL of the service. When omitted, a
                localhost URL is built from ``settings.WS_PORT`` and
                ``settings.WS_PATH``.
        """
        self.server_url = server_url or f"ws://localhost:{settings.WS_PORT}{settings.WS_PATH}"
        self.session_id = "test-session-001"
        # NOTE(review): the same turn_id is sent for every turn — confirm the
        # server does not require unique turn ids.
        self.turn_id = "test-turn-001"
        # Raw binary frames received during the current turn.
        self.audio_buffer = []

    async def connect_and_test(self):
        """Connect to the server and run the whole scripted test sequence.

        Any exception is reported on stdout together with a traceback; it is
        not re-raised, so the caller always returns normally.
        """
        print(f"连接到: {self.server_url}")
        try:
            async with websockets.connect(self.server_url) as ws:
                print("✓ WebSocket 连接成功")

                # 1. Open the session and wait for the server handshake.
                await self._send_session_start(ws)
                await self._receive_session_ready(ws)

                # 2. Scripted turns: (display name, text sent to the server).
                test_cases = [
                    ("闲聊测试", "你好,今天天气怎么样?"),
                    ("飞控测试 - 起飞悬停", "起飞然后在前方十米悬停"),
                    ("飞控测试 - 返航", "返航"),
                    ("飞控测试 - 降落", "降落"),
                ]
                for name, text in test_cases:
                    print(f"\n{'='*60}")
                    print(f"测试: {name}")
                    print(f"输入: {text}")
                    print(f"{'='*60}")
                    await self._send_turn_text(ws, text)
                    await self._receive_turn_response(ws)
                    # Short pause between turns.
                    await asyncio.sleep(1)

                # 3. Close the session.
                await self._send_session_end(ws)
                print("\n✓ 所有测试完成")
        except Exception as e:
            print(f"❌ 测试失败: {e}")
            import traceback
            traceback.print_exc()

    async def _send_session_start(self, ws):
        """Send the ``session.start`` handshake message over *ws*."""
        msg = {
            "type": "session.start",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "session_id": self.session_id,
            "auth_token": settings.BEARER_TOKEN,
            "client": {
                "device_id": "test-drone-001",
                "locale": "zh-CN",
                "capabilities": {
                    "playback_sample_rate_hz": 24000,
                    "prefer_tts_codec": "pcm_s16le",
                },
                "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
            },
        }
        await ws.send(json.dumps(msg, ensure_ascii=False))
        print("→ session.start")

    async def _send_session_end(self, ws):
        """Send a ``session.end`` message to close the session on *ws*.

        ``connect_and_test`` has always called this method, but it was never
        defined, so every run ended with an ``AttributeError``.
        """
        # NOTE(review): the payload shape is assumed from the sibling
        # session.start / turn.text messages — confirm against the protocol spec.
        msg = {
            "type": "session.end",
            "proto_version": "1.0",
            "session_id": self.session_id,
        }
        await ws.send(json.dumps(msg, ensure_ascii=False))
        print("→ session.end")

    async def _receive_session_ready(self, ws):
        """Read one message from *ws* and report whether it is ``session.ready``.

        A mismatch is only printed; the test sequence continues either way.
        """
        msg = await ws.recv()
        data = json.loads(msg)
        if data.get("type") == "session.ready":
            print("← session.ready")
            print(f" 服务端能力: {json.dumps(data.get('server_caps', {}), ensure_ascii=False)}")
        else:
            print(f"❌ 期望 session.ready收到: {data}")

    async def _send_turn_text(self, ws, text: str):
        """Send *text* to the server as a final ``turn.text`` message."""
        msg = {
            "type": "turn.text",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "turn_id": self.turn_id,
            "text": text,
            "is_final": True,
            "source": "device_stt",
        }
        await ws.send(json.dumps(msg, ensure_ascii=False))
        print(f"→ turn.text: {text}")

    async def _receive_turn_response(self, ws):
        """Consume server messages for one turn until it completes or fails.

        Binary frames are collected into ``self.audio_buffer``. JSON messages
        are printed according to their ``type``. The loop ends on
        ``turn.complete`` (after playing any collected audio) or on ``error``.
        """
        self.audio_buffer = []
        while True:
            msg = await ws.recv()

            if isinstance(msg, bytes):
                # Binary frame: raw audio payload.
                self.audio_buffer.append(msg)
                if len(self.audio_buffer) == 1:
                    print(f"← 音频数据开始... ({len(msg)} bytes)")
                else:
                    print(f"← 音频数据继续... ({len(msg)} bytes)")
                continue

            # Text frame: JSON control message.
            data = json.loads(msg)
            msg_type = data.get("type")

            if msg_type == "dialog_result":
                print("← dialog_result")
                routing = data.get("routing")
                print(f" routing: {routing}")
                if routing == "flight_intent":
                    intent = data.get("flight_intent")
                    if intent:
                        print(f" summary: {intent.get('summary')}")
                        print(f" actions: {json.dumps(intent.get('actions', []), ensure_ascii=False)}")
                elif routing == "chitchat":
                    reply = data.get("chat_reply")
                    print(f" chat_reply: {reply}")
            elif msg_type == "tts_audio_chunk":
                print(f"← tts_audio_chunk (seq={data.get('seq')}, final={data.get('is_final')})")
            elif msg_type == "turn.complete":
                print("← turn.complete")
                # `or {}` also covers an explicit JSON null for "metrics".
                metrics = data.get("metrics") or {}
                print(f" LLM 耗时: {metrics.get('llm_ms')}ms")
                print(f" TTS 耗时: {metrics.get('tts_first_byte_ms')}ms")
                if self.audio_buffer:
                    await self._play_audio()
                break
            elif msg_type == "error":
                print(f"❌ error: {data.get('code')} - {data.get('message')}")
                break

    @staticmethod
    def _pcm_to_float32(pcm: bytes):
        """Convert little-endian signed 16-bit PCM bytes to float32 samples.

        A trailing odd byte (an incomplete sample) is dropped instead of
        making the conversion fail.

        Args:
            pcm: Raw PCM bytes; ``pcm_s16le`` is the codec this client requests.

        Returns:
            A 1-D ``numpy`` float32 array with each sample divided by 32768.
        """
        usable = len(pcm) - (len(pcm) % 2)
        if usable == 0:
            return np.zeros(0, dtype=np.float32)
        samples = np.frombuffer(pcm[:usable], dtype="<i2")
        return samples.astype(np.float32) / 32768.0

    def _play_blocking(self, samples):
        """Play *samples* synchronously; intended to run in a worker thread."""
        # NOTE(review): session.start advertises 24000 Hz playback; confirm
        # settings.TTS_SAMPLE_RATE matches what the server actually sends.
        sd.play(samples, samplerate=settings.TTS_SAMPLE_RATE, blocking=True)

    async def _play_audio(self):
        """Play the audio collected in ``self.audio_buffer``.

        Playback runs in the default executor so the event loop (and with it
        the WebSocket connection handling) keeps running while audio plays.
        Failures are printed and swallowed: playback is best-effort.
        """
        if not self.audio_buffer:
            return
        try:
            full_pcm = b"".join(self.audio_buffer)
            audio = self._pcm_to_float32(full_pcm)
            print(f"\n🔊 播放音频: {len(audio)} samples, {len(full_pcm)} bytes")
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._play_blocking, audio)
            print("✓ 音频播放完成\n")
        except Exception as e:
            print(f"❌ 音频播放失败: {e}")
async def main():
    """Print the banner, then run the scripted test sequence with a default client."""
    rule = "=" * 60
    for banner_line in (rule, " 云端无人机语音服务 - 测试客户端", rule):
        print(banner_line)
    print()

    # No URL is passed, so TestClient falls back to its default server URL.
    await TestClient().connect_and_test()
# Script entry point: run the async main() to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(main())