273 lines
8.7 KiB
Python
273 lines
8.7 KiB
Python
"""
|
||
测试 07: 香橙派客户端模拟
|
||
完整模拟香橙派客户端的行为
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import websockets
|
||
import numpy as np
|
||
import sounddevice as sd
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from test_config import SERVER_URL, AUTH_TOKEN
|
||
|
||
|
||
class OrangePiClient:
    """Simulated Orange Pi device client for the cloud voice-dialog service.

    Uses the ``text_uplink`` transport profile: the device uploads text it has
    already transcribed locally (``turn.text``) and, per turn, receives a
    ``dialog_result`` JSON message, zero or more binary audio frames and a
    terminating ``turn.complete`` message.
    """

    # Playback rate advertised in ``session.start`` and used by ``play_audio``;
    # a single constant keeps the two from drifting apart.
    PLAYBACK_SAMPLE_RATE_HZ = 24000

    def __init__(self, server_url: str, auth_token: str, device_id: str):
        """Store connection settings; no network I/O happens here.

        Args:
            server_url: WebSocket URL of the cloud service.
            auth_token: Token sent in ``session.start``.
            device_id: Device identifier sent in ``session.start``.
        """
        self.server_url = server_url
        self.auth_token = auth_token
        self.device_id = device_id
        self.ws = None  # open connection, or None while disconnected
        self.session_id = None
        # Sequence number appended to turn ids so that two turns started
        # within the same wall-clock second still get distinct ids.
        self._turn_seq = 0

    async def connect(self):
        """Open the WebSocket and perform the session handshake.

        Sends ``session.start`` and expects ``session.ready`` as the first
        reply. ``self.ws`` is only set once the handshake has succeeded; on
        failure the socket is closed before the exception propagates.

        Raises:
            Exception: if the first server message is not ``session.ready``.
        """
        ws = await websockets.connect(self.server_url)
        self.session_id = f"session-{int(time.time())}"

        session_start = {
            "type": "session.start",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "session_id": self.session_id,
            "auth_token": self.auth_token,
            "client": {
                "device_id": self.device_id,
                "locale": "zh-CN",
                "capabilities": {
                    "playback_sample_rate_hz": self.PLAYBACK_SAMPLE_RATE_HZ,
                    "prefer_tts_codec": "pcm_s16le",
                },
                "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
            },
        }

        try:
            await ws.send(json.dumps(session_start, ensure_ascii=False))
            data = json.loads(await ws.recv())
            if data.get("type") != "session.ready":
                raise Exception(f"期望 session.ready,收到 {data.get('type')}")
        except BaseException:
            # Do not leak the half-open socket when the handshake fails.
            await ws.close()
            raise

        self.ws = ws
        print(f"✅ 已连接云端 (session={self.session_id})")

    async def disconnect(self):
        """Send ``session.end`` and close the socket; no-op when not connected.

        ``self.ws`` is cleared before any I/O so a repeated call does nothing,
        and the socket is closed even if sending ``session.end`` raises.
        """
        if self.ws is None:
            return
        ws, self.ws = self.ws, None

        session_end = {
            "type": "session.end",
            "proto_version": "1.0",
            "session_id": self.session_id,
        }
        try:
            await ws.send(json.dumps(session_end))
        finally:
            await ws.close()
        print("✅ 已断开连接")

    async def send_command(self, text: str, timeout=None) -> dict:
        """Send one spoken command (text as if produced by on-device STT).

        Reads server messages until ``turn.complete``. Binary frames are
        collected as audio; text frames are parsed as JSON. Message types other
        than ``dialog_result``, ``turn.complete`` and ``error`` are ignored.

        Args:
            text: The recognized utterance.
            timeout: Optional per-message receive timeout in seconds; ``None``
                (the default) waits indefinitely.

        Returns:
            A dict with keys ``turn_id``, ``routing``, ``flight_intent``,
            ``chat_reply``, ``audio_data`` (an ``np.int16`` array, or None when
            no audio arrived) and ``metrics`` (dict, empty if not provided).

        Raises:
            RuntimeError: if called before ``connect``.
            Exception: if the server sends an ``error`` message.
            asyncio.TimeoutError: if ``timeout`` is set and a receive exceeds it.
        """
        if self.ws is None:
            raise RuntimeError("尚未连接云端,请先调用 connect()")

        self._turn_seq += 1
        turn_id = f"turn-{int(time.time())}-{self._turn_seq}"

        turn_text = {
            "type": "turn.text",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "turn_id": turn_id,
            "text": text,
            "is_final": True,
            "source": "device_stt",
        }
        await self.ws.send(json.dumps(turn_text, ensure_ascii=False))
        print(f"→ {text}")

        result = {
            "turn_id": turn_id,
            "routing": None,
            "flight_intent": None,
            "chat_reply": None,
            "audio_data": None,
            "metrics": {},
        }
        audio_chunks = []

        while True:
            if timeout is None:
                msg = await self.ws.recv()
            else:
                msg = await asyncio.wait_for(self.ws.recv(), timeout)

            if isinstance(msg, bytes):
                # Binary frames are treated as raw audio for this turn.
                audio_chunks.append(msg)
                continue

            data = json.loads(msg)
            msg_type = data.get("type")

            if msg_type == "dialog_result":
                result["routing"] = data.get("routing")
                result["flight_intent"] = data.get("flight_intent")
                result["chat_reply"] = data.get("chat_reply")

                print(f"← routing={data.get('routing')}")

                if data.get("routing") == "flight_intent":
                    intent = data.get("flight_intent")
                    if intent:
                        print(f" Summary: {intent.get('summary')}")
                        print(f" Actions: {json.dumps(intent.get('actions', []), ensure_ascii=False)}")
                else:
                    reply = data.get("chat_reply")
                    print(f" Reply: {reply}")

            elif msg_type == "turn.complete":
                # `or {}` also covers an explicit JSON null.
                metrics = data.get("metrics") or {}
                result["metrics"] = metrics
                print(f" LLM: {metrics.get('llm_ms')}ms, TTS: {metrics.get('tts_first_byte_ms')}ms")
                break

            elif msg_type == "error":
                raise Exception(f"服务端错误: {data.get('code')} - {data.get('message')}")

        # Concatenate audio as little-endian int16 PCM (the codec requested in
        # session.start). A trailing odd byte cannot form a sample; drop it
        # instead of letting np.frombuffer raise ValueError.
        pcm = b"".join(audio_chunks)
        usable = len(pcm) - len(pcm) % 2
        if usable:
            result["audio_data"] = np.frombuffer(pcm[:usable], dtype=np.int16)

        return result

    async def play_audio(self, audio_data: np.ndarray, samplerate=None):
        """Play TTS audio through sounddevice (best-effort).

        The blocking playback call runs in the default executor so the event
        loop is not stalled for the duration of playback. Failures are
        printed and swallowed.

        Args:
            audio_data: int16 samples as returned by ``send_command``; None or
                empty input is a no-op.
            samplerate: Playback rate in Hz; defaults to
                ``PLAYBACK_SAMPLE_RATE_HZ``.
        """
        if audio_data is None or len(audio_data) == 0:
            return

        rate = self.PLAYBACK_SAMPLE_RATE_HZ if samplerate is None else samplerate
        try:
            # Scale the int16 range into float32 roughly within [-1.0, 1.0).
            audio_float = audio_data.astype(np.float32) / 32768.0
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: sd.play(audio_float, samplerate=rate, blocking=True),
            )
            print("🔊 播放完成")
        except Exception as e:
            print(f"❌ 播放失败: {e}")

    def execute_flight_actions(self, actions: list):
        """Print a simulated execution of each flight action.

        Args:
            actions: List of dicts with a ``type`` key and optional ``args``
                (``x``/``y``/``z`` for ``goto``). None is treated as empty.
        """
        print("\n🎮 飞控指令执行:")
        for action in actions or []:
            action_type = action.get("type")

            if action_type == "takeoff":
                print(" ✈️ 起飞")
            elif action_type == "land":
                print(" 🛬 降落")
            elif action_type == "return_home":
                print(" 🏠 返航")
            elif action_type == "hover":
                print(" ⏸️ 悬停")
            elif action_type == "goto":
                # `or {}` also covers an explicit JSON null for args.
                args = action.get("args") or {}
                x = args.get("x", 0)
                y = args.get("y", 0)
                z = args.get("z", 0)
                print(f" 🎯 前往坐标 (x={x}, y={y}, z={z})")
            else:
                print(f" ⚙️ 未知动作: {action_type}")
|
||
|
||
|
||
async def test_orangepi_workflow():
    """Run the end-to-end Orange Pi client simulation against the cloud service.

    Steps:
    1. Connect.
    2. Replay a fixed list of command scenarios and check the server's
       routing decision for each.
    3. Play the returned TTS audio and simulate any flight actions.
    4. Disconnect.

    Returns:
        True if every scenario was routed as expected; False if any routing
        mismatched or an exception occurred.
    """
    print("=" * 60)
    print(" 测试 07: 香橙派客户端模拟")
    print("=" * 60)
    print()

    client = OrangePiClient(
        server_url=SERVER_URL,
        auth_token=AUTH_TOKEN,
        device_id="orange-pi-drone-001",
    )

    try:
        # Step 1: connect.
        print("步骤 1: 连接云端")
        print("-" * 60)
        await client.connect()
        print()

        # Step 2: scenarios as (title, utterance, expected routing).
        test_scenarios = [
            ("场景 1: 闲聊", "你好,今天天气怎么样?", "chitchat"),
            ("场景 2: 起飞指令", "起飞", "flight_intent"),
            ("场景 3: 复杂飞控", "起飞然后在前方十米悬停", "flight_intent"),
            ("场景 4: 返航", "返航", "flight_intent"),
            ("场景 5: 降落", "降落", "flight_intent"),
        ]
        # Names of scenarios whose routing did not match; any entry fails the test.
        mismatches = []

        for scenario_name, text, expected_routing in test_scenarios:
            print("=" * 60)
            print(f" {scenario_name}")
            print("=" * 60)

            result = await client.send_command(text)

            if result["routing"] != expected_routing:
                print(f"❌ 路由不匹配: 期望 {expected_routing}, 实际 {result['routing']}")
                mismatches.append(scenario_name)
                continue

            print("\n🔊 播放 TTS:")
            await client.play_audio(result["audio_data"])

            # For flight commands, simulate executing the returned actions.
            if result["routing"] == "flight_intent" and result["flight_intent"]:
                actions = result["flight_intent"].get("actions") or []
                client.execute_flight_actions(actions)

            print()

        # Step 3: disconnect.
        print("=" * 60)
        print("步骤 3: 断开连接")
        print("=" * 60)
        await client.disconnect()

        if mismatches:
            print("\n" + "=" * 60)
            print(f" ❌ {len(mismatches)} 个场景路由不匹配: {', '.join(mismatches)}")
            print("=" * 60)
            return False

        print("\n" + "=" * 60)
        print(" ✅ 香橙派客户端模拟测试完成!")
        print("=" * 60)
        return True

    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Release the socket even when a step above raised. Closing a
        # websockets connection that is already closed is a no-op.
        if client.ws is not None:
            await client.ws.close()
|
||
|
||
|
||
async def main():
    """Script entry coroutine: run the workflow, exit 0 on success and 1 on failure."""
    passed = await test_orangepi_workflow()
    sys.exit(int(not passed))


if __name__ == "__main__":
    # Run the async entry point on a fresh event loop.
    asyncio.run(main())
|