215 lines
7.2 KiB
Python
215 lines
7.2 KiB
Python
"""
|
||
测试客户端 - 验证云端语音服务功能
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
import websockets
|
||
import numpy as np
|
||
import sounddevice as sd
|
||
|
||
# Add the directory containing this script to sys.path so the local "app" package can be imported
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from app.config import settings
|
||
|
||
|
||
class TestClient:
    """Scripted test client for the cloud voice service.

    Connects over WebSocket and performs the session.start / session.ready
    handshake. It then sends a fixed list of text turns, printing each server
    message and playing back any binary audio returned for a turn. Finally it
    ends the session.
    """

    # Tell pytest not to collect this class just because its name starts
    # with "Test".
    __test__ = False

    # Seconds to wait for each server message before giving up; None waits
    # forever (the original behaviour).
    recv_timeout = 60.0

    def __init__(self, server_url: "str | None" = None):
        """Create a client.

        Args:
            server_url: WebSocket URL of the service. Defaults to
                ``ws://localhost:{settings.WS_PORT}{settings.WS_PATH}``.
        """
        self.server_url = server_url or f"ws://localhost:{settings.WS_PORT}{settings.WS_PATH}"
        self.session_id = "test-session-001"
        # Id of the current turn. connect_and_test() assigns a fresh value per
        # test case so the server can tell turns apart.
        self.turn_id = "test-turn-001"
        # Binary audio frames received during the current turn.
        self.audio_buffer = []

    async def connect_and_test(self):
        """Connect to the server and run the whole scripted test sequence.

        Returns:
            True if every step completed. False if the handshake was not
            answered with session.ready, or if any exception occurred. Such an
            exception is printed with its traceback instead of propagating.
        """
        print(f"连接到: {self.server_url}")

        try:
            async with websockets.connect(self.server_url) as ws:
                print("✓ WebSocket 连接成功")

                # 1. Send session.start
                await self._send_session_start(ws)

                # 2. Wait for session.ready; sending turns is pointless if the
                #    handshake was rejected.
                if not await self._receive_session_ready(ws):
                    return False

                # 3. Test cases: (display name, utterance text)
                test_cases = [
                    ("闲聊测试", "你好,今天天气怎么样?"),
                    ("飞控测试 - 起飞悬停", "起飞然后在前方十米悬停"),
                    ("飞控测试 - 返航", "返航"),
                    ("飞控测试 - 降落", "降落"),
                ]

                for index, (name, text) in enumerate(test_cases, start=1):
                    print(f"\n{'='*60}")
                    print(f"测试: {name}")
                    print(f"输入: {text}")
                    print(f"{'='*60}")

                    # A distinct id per turn instead of reusing one for all.
                    self.turn_id = f"test-turn-{index:03d}"
                    await self._send_turn_text(ws, text)
                    await self._receive_turn_response(ws)

                    # Short pause before the next case
                    await asyncio.sleep(1)

                # 4. End the session
                await self._send_session_end(ws)

                print("\n✓ 所有测试完成")
                return True

        except Exception as e:
            # Top-level boundary of the test run: report and signal failure.
            print(f"❌ 测试失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    async def _send_session_start(self, ws):
        """Send the session.start message that opens the session."""
        msg = {
            "type": "session.start",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "session_id": self.session_id,
            "auth_token": settings.BEARER_TOKEN,
            "client": {
                "device_id": "test-drone-001",
                "locale": "zh-CN",
                "capabilities": {
                    # NOTE(review): the advertised rate is hard-coded to 24000,
                    # but _play_audio plays at settings.TTS_SAMPLE_RATE. These
                    # should agree — confirm.
                    "playback_sample_rate_hz": 24000,
                    "prefer_tts_codec": "pcm_s16le"
                },
                "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
            }
        }

        await ws.send(json.dumps(msg, ensure_ascii=False))
        print("→ session.start")

    async def _receive_session_ready(self, ws):
        """Receive the server's reply to session.start.

        Returns:
            True if the reply has type ``session.ready``. Otherwise the
            unexpected message is printed and False is returned.

        Raises:
            asyncio.TimeoutError: if nothing arrives within ``recv_timeout``.
        """
        data = json.loads(await asyncio.wait_for(ws.recv(), self.recv_timeout))

        if data.get("type") != "session.ready":
            print(f"❌ 期望 session.ready,收到: {data}")
            return False

        print("← session.ready")
        print(f" 服务端能力: {json.dumps(data.get('server_caps', {}), ensure_ascii=False)}")
        return True

    async def _send_turn_text(self, ws, text: str):
        """Send one final user utterance as a turn.text message.

        Args:
            ws: Open WebSocket connection.
            text: Utterance text, tagged with the current ``self.turn_id``.
        """
        msg = {
            "type": "turn.text",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "turn_id": self.turn_id,
            "text": text,
            "is_final": True,
            "source": "device_stt"
        }

        await ws.send(json.dumps(msg, ensure_ascii=False))
        print(f"→ turn.text: {text}")

    async def _receive_turn_response(self, ws):
        """Consume server messages for the current turn until it ends.

        Binary frames are collected into ``self.audio_buffer``, and JSON
        messages are printed. The loop stops on ``turn.complete``, after
        playing any collected audio, or on ``error``.

        Raises:
            asyncio.TimeoutError: if no message arrives within
                ``recv_timeout`` seconds (instead of hanging forever).
        """
        self.audio_buffer = []
        audio_chunks = 0

        while True:
            msg = await asyncio.wait_for(ws.recv(), self.recv_timeout)

            if isinstance(msg, bytes):
                # Binary frame: audio payload
                audio_chunks += 1
                self.audio_buffer.append(msg)
                if audio_chunks == 1:
                    print(f"← 音频数据开始... ({len(msg)} bytes)")
                else:
                    print(f"← 音频数据继续... ({len(msg)} bytes)")
                continue

            # Text frame: JSON message
            data = json.loads(msg)
            msg_type = data.get("type")

            if msg_type == "dialog_result":
                print("← dialog_result")
                routing = data.get("routing")
                print(f" routing: {routing}")

                if routing == "flight_intent":
                    intent = data.get("flight_intent")
                    if intent:
                        print(f" summary: {intent.get('summary')}")
                        print(f" actions: {json.dumps(intent.get('actions', []), ensure_ascii=False)}")
                elif routing == "chitchat":
                    reply = data.get("chat_reply")
                    print(f" chat_reply: {reply}")

            elif msg_type == "tts_audio_chunk":
                print(f"← tts_audio_chunk (seq={data.get('seq')}, final={data.get('is_final')})")

            elif msg_type == "turn.complete":
                print("← turn.complete")
                # `or {}` also covers an explicit null "metrics" field.
                metrics = data.get("metrics") or {}
                print(f" LLM 耗时: {metrics.get('llm_ms')}ms")
                # The metric is first-byte latency, not total TTS time.
                print(f" TTS 首字节耗时: {metrics.get('tts_first_byte_ms')}ms")

                # Play the audio collected for this turn
                if self.audio_buffer:
                    await self._play_audio()
                break

            elif msg_type == "error":
                print(f"❌ error: {data.get('code')} - {data.get('message')}")
                break

            else:
                # Surface unexpected messages instead of dropping them silently.
                print(f"← {msg_type} (未处理)")

    async def _play_audio(self):
        """Play the audio frames collected in ``self.audio_buffer``.

        The frames are joined and decoded as signed 16-bit little-endian PCM,
        matching the ``pcm_s16le`` codec requested in session.start. They are
        scaled to float32 in [-1, 1) and played at ``settings.TTS_SAMPLE_RATE``.
        Playback blocks a worker thread rather than the event loop, so the
        WebSocket stays serviced. Failures are printed, not raised, because
        playback is best-effort.
        """
        if not self.audio_buffer:
            return

        try:
            full_pcm = b"".join(self.audio_buffer)
            # int16 decoding needs an even byte count; drop a dangling byte
            # rather than failing the whole playback.
            if len(full_pcm) % 2:
                full_pcm = full_pcm[:-1]

            audio = np.frombuffer(full_pcm, dtype="<i2")
            print(f"\n🔊 播放音频: {len(audio)} samples, {len(full_pcm)} bytes")

            samples = audio.astype(np.float32) / 32768.0
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: sd.play(samples, samplerate=settings.TTS_SAMPLE_RATE, blocking=True),
            )

            print("✓ 音频播放完成\n")

        except Exception as e:
            print(f"❌ 音频播放失败: {e}")

    async def _send_session_end(self, ws):
        """Send session.end to close the session.

        NOTE(review): the fields mirror session.start (type, protocol version,
        transport profile, session id). Confirm the exact schema against the
        server's protocol handler.
        """
        msg = {
            "type": "session.end",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "session_id": self.session_id,
        }

        await ws.send(json.dumps(msg, ensure_ascii=False))
        print("→ session.end")
|
||
|
||
|
||
async def main():
    """Entry point: print a banner, then run the full client test sequence."""
    banner = "=" * 60
    print(banner)
    print(" 云端无人机语音服务 - 测试客户端")
    print(banner)
    print()

    await TestClient().connect_and_test()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|