# File-viewer metadata (paste residue): 320 lines, 9.8 KiB, Python.
"""
|
||
Orange Pi client WebSocket adapter.

Used to replace the existing local LLM + TTS calls.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import time
|
||
from typing import Optional, Callable
|
||
import websockets
|
||
import numpy as np
|
||
import sounddevice as sd
|
||
from loguru import logger
|
||
|
||
|
||
class CloudVoiceClient:
    """Cloud voice client: the Orange Pi side adapter for the voice server.

    Opens a WebSocket session, sends text as ``turn.text`` messages and
    collects the ``dialog_result`` JSON plus the binary audio frames of each
    turn into a ``CloudVoiceResult``.

    Usage::

        client = CloudVoiceClient(
            server_url="ws://192.168.1.100:8765/v1/voice/session",
            auth_token="your-token",
            device_id="drone-001",
        )

        # Note: use ``async with client`` (not ``client.connect()``);
        # ``connect()`` is a coroutine, not an async context manager.
        async with client:
            result = await client.send_text("起飞然后在前方十米悬停")
            print(result.routing)  # "flight_intent"
            print(result.flight_intent["summary"])
            # Play TTS: result.audio_data (numpy int16 array)
            await client.play_audio(result.audio_data)
    """

    def __init__(
        self,
        server_url: str,
        auth_token: str,
        device_id: str,
        sample_rate: int = 24000,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            server_url: WebSocket URL of the voice session endpoint.
            auth_token: Token sent in the ``session.start`` message.
            device_id: Device identifier sent in the ``session.start`` message.
            sample_rate: Playback sample rate in Hz, advertised to the server
                and used by :meth:`play_audio`.
        """
        self.server_url = server_url
        self.auth_token = auth_token
        self.device_id = device_id
        self.sample_rate = sample_rate

        self._ws = None
        self._session_id = None
        self._connected = False
        # Per-client counter that keeps turn ids unique even when several
        # turns start within the same wall-clock second.
        self._turn_seq = 0

    async def connect(self) -> "CloudVoiceClient":
        """Open the WebSocket and perform the session.start/ready handshake.

        Returns:
            This client, so ``client = await CloudVoiceClient(...).connect()``
            works.

        Raises:
            Exception: Whatever the connection or handshake raised. The socket
                is closed and the client is left disconnected in that case.
        """
        try:
            self._ws = await websockets.connect(self.server_url)
            self._session_id = f"session-{int(time.time())}"

            logger.info(f"WebSocket 连接成功: {self.server_url}")

            await self._send_session_start()
            await self._receive_session_ready()

            # Only report "connected" once the server acknowledged the session.
            self._connected = True
            return self

        except Exception as e:
            logger.error(f"连接失败: {e}")
            # Do not leak a half-open socket when the handshake fails.
            await self._close_socket()
            raise

    async def disconnect(self):
        """Send ``session.end`` (best effort) and close the WebSocket.

        Does nothing when there is no open socket. Failures while sending or
        closing are logged at debug level and otherwise ignored; cancellation
        is not swallowed.
        """
        if self._ws is None:
            return

        try:
            await self._ws.send(json.dumps({
                "type": "session.end",
                "proto_version": "1.0",
                "session_id": self._session_id,
            }))
        except Exception as e:
            logger.debug(f"发送 session.end 失败(已忽略): {e}")

        await self._close_socket()
        logger.info("WebSocket 连接已关闭")

    async def _close_socket(self):
        """Close the socket if any and reset the connection state."""
        ws, self._ws = self._ws, None
        self._connected = False
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.debug(f"关闭 WebSocket 失败(已忽略): {e}")

    async def __aenter__(self):
        """Connect on entering ``async with client`` and return the client."""
        return await self.connect()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving the ``async with`` block."""
        await self.disconnect()

    async def _send_session_start(self):
        """Send the ``session.start`` message describing this client."""
        msg = {
            "type": "session.start",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "session_id": self._session_id,
            "auth_token": self.auth_token,
            "client": {
                "device_id": self.device_id,
                "locale": "zh-CN",
                "capabilities": {
                    "playback_sample_rate_hz": self.sample_rate,
                    "prefer_tts_codec": "pcm_s16le",
                },
                "protocol": {"dialog_result": "cloud_voice_dialog_v1"},
            },
        }

        await self._ws.send(json.dumps(msg, ensure_ascii=False))
        logger.debug("→ session.start")

    async def _receive_session_ready(self):
        """Wait for the next message and require it to be ``session.ready``.

        Raises:
            RuntimeError: If the received message has any other type.
        """
        raw = await self._ws.recv()
        data = json.loads(raw)

        if data.get("type") != "session.ready":
            raise RuntimeError(f"期望 session.ready,收到: {data}")

        logger.info("← session.ready - 服务端就绪")

    def _next_turn_id(self) -> str:
        """Return a turn id that is unique within this client instance."""
        self._turn_seq += 1
        return f"turn-{int(time.time())}-{self._turn_seq}"

    @staticmethod
    def _decode_pcm(pcm: bytes) -> np.ndarray:
        """Interpret raw bytes as int16 samples.

        A dangling trailing byte (odd length) is dropped, because
        ``np.frombuffer`` raises when the size is not a multiple of 2.
        """
        usable = len(pcm) - (len(pcm) % 2)
        return np.frombuffer(pcm[:usable], dtype=np.int16)

    async def send_text(self, text: str) -> "CloudVoiceResult":
        """Send one text turn and collect the server's response.

        Binary messages are accumulated as audio; JSON messages are dispatched
        on their ``type`` until ``turn.complete`` arrives. JSON messages of
        other types are ignored.

        Args:
            text: The user's input text.

        Returns:
            A ``CloudVoiceResult`` holding routing data, metrics and, if any
            binary frames were received, the decoded audio.

        Raises:
            RuntimeError: If the client is not connected, or the server sent
                an ``error`` message.
        """
        if not self._connected:
            raise RuntimeError("未连接,请先调用 connect()")

        turn_id = self._next_turn_id()

        await self._ws.send(json.dumps({
            "type": "turn.text",
            "proto_version": "1.0",
            "transport_profile": "text_uplink",
            "turn_id": turn_id,
            "text": text,
            "is_final": True,
            "source": "device_stt",
        }, ensure_ascii=False))

        logger.debug(f"→ turn.text: {text}")

        result = CloudVoiceResult(turn_id=turn_id)
        audio_chunks = []

        while True:
            msg = await self._ws.recv()

            if isinstance(msg, bytes):
                # Binary frames carry audio data.
                audio_chunks.append(msg)
                continue

            data = json.loads(msg)
            msg_type = data.get("type")

            if msg_type == "dialog_result":
                result.parse_dialog_result(data)

            elif msg_type == "turn.complete":
                # ``or {}`` also covers an explicit JSON null.
                result.metrics = data.get("metrics") or {}

                if audio_chunks:
                    full_pcm = b"".join(audio_chunks)
                    if len(full_pcm) % 2:
                        logger.warning(
                            f"音频数据长度为奇数({len(full_pcm)} 字节),已丢弃末尾 1 字节"
                        )
                    result.audio_data = self._decode_pcm(full_pcm)

                logger.info(
                    f"回合完成: LLM={result.metrics.get('llm_ms')}ms, "
                    f"TTS={result.metrics.get('tts_first_byte_ms')}ms"
                )
                break

            elif msg_type == "error":
                raise RuntimeError(
                    f"服务端错误: {data.get('code')} - {data.get('message')}"
                )

        return result

    async def play_audio(self, audio_data: Optional[np.ndarray]):
        """Play int16 audio samples and return when playback finished.

        Playback runs in the default executor so the blocking
        ``sd.play(..., blocking=True)`` call does not stall the event loop.

        Args:
            audio_data: numpy int16 array; ``None`` or an empty array is
                skipped with a warning.

        Raises:
            Exception: Whatever the conversion or playback raised (after
                logging it).
        """
        if audio_data is None or len(audio_data) == 0:
            logger.warning("没有可播放的音频数据")
            return

        try:
            # Scale int16 samples to float32 in [-1.0, 1.0).
            audio_float = audio_data.astype(np.float32) / 32768.0

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: sd.play(
                    audio_float,
                    samplerate=self.sample_rate,
                    blocking=True,
                ),
            )

            logger.info(f"音频播放完成: {len(audio_data)} samples")

        except Exception as e:
            logger.error(f"音频播放失败: {e}")
            raise
|
||
|
||
|
||
class CloudVoiceResult:
    """Result of one cloud voice turn: routing data, audio and metrics."""

    # Fallback acknowledgement used when no better text is available.
    _DEFAULT_ACK = "收到"

    def __init__(self, turn_id: str):
        """Create an empty result for the turn identified by ``turn_id``."""
        self.turn_id = turn_id

        # Intent recognition outcome.
        self.routing: Optional[str] = None  # "flight_intent" or "chitchat"
        self.flight_intent: Optional[dict] = None
        self.chat_reply: Optional[str] = None

        # TTS audio.
        self.audio_data: Optional[np.ndarray] = None  # int16 numpy array

        # Performance metrics.
        self.metrics: dict = {}

        # Raw message as last passed to parse_dialog_result.
        self.raw_data: Optional[dict] = None

    def parse_dialog_result(self, data: dict):
        """Populate routing fields from a ``dialog_result`` message.

        Both payload fields are reset first, so neither a re-parse nor an
        unrecognised ``routing`` value can leave a stale intent or reply.

        Args:
            data: The decoded message; ``routing`` selects which of
                ``flight_intent`` / ``chat_reply`` is read.
        """
        self.raw_data = data
        self.routing = data.get("routing")

        self.flight_intent = None
        self.chat_reply = None

        if self.routing == "flight_intent":
            self.flight_intent = data.get("flight_intent")
        elif self.routing == "chitchat":
            self.chat_reply = data.get("chat_reply")

    def is_flight_intent(self) -> bool:
        """Return True when the turn was routed as a flight intent."""
        return self.routing == "flight_intent"

    def get_tts_text(self) -> str:
        """Return the text to speak for this turn.

        This is the flight intent's ``summary`` or the chat reply; the default
        acknowledgement is returned when that text is missing, empty or null.
        """
        if self.is_flight_intent() and self.flight_intent:
            # ``or`` also covers a present-but-null/empty summary.
            return self.flight_intent.get("summary") or self._DEFAULT_ACK
        if self.routing == "chitchat" and self.chat_reply:
            return self.chat_reply
        return self._DEFAULT_ACK

    def get_flight_actions(self) -> list:
        """Return the flight intent's ``actions``, or ``[]`` if unavailable."""
        if self.is_flight_intent() and self.flight_intent:
            # ``or []`` keeps the return type a list when "actions" is null.
            return self.flight_intent.get("actions") or []
        return []

    def __repr__(self):
        sample_count = len(self.audio_data) if self.audio_data is not None else 0
        return (
            f"CloudVoiceResult(routing={self.routing}, "
            f"audio={sample_count} samples)"
        )
|
||
|
||
|
||
# ==================== Usage example ====================
|
||
async def example_usage():
    """Usage example: run one chit-chat turn and one flight-command turn.

    Connects to a hard-coded server, prints the routing data of each turn and
    plays the returned audio when there is any.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    # NOTE(review): the URL and token below are example values; real
    # deployments should presumably load them from configuration.
    client = CloudVoiceClient(
        server_url="ws://192.168.1.100:8765/v1/voice/session",
        auth_token="drone-voice-cloud-token-2024",
        device_id="drone-001",
    )

    try:
        # ``client.connect()`` is a coroutine and cannot be used with
        # ``async with``; the client object itself is the context manager.
        async with client:

            # Chit-chat test.
            print("\n=== 闲聊测试 ===")
            result = await client.send_text("你好,今天天气怎么样?")
            print(f"路由: {result.routing}")
            print(f"回复: {result.chat_reply}")
            if result.audio_data is not None:
                await client.play_audio(result.audio_data)

            # Flight-command test.
            print("\n=== 飞控指令测试 ===")
            result = await client.send_text("起飞然后在前方十米悬停")
            print(f"路由: {result.routing}")
            # flight_intent is None unless the turn was routed as one.
            intent = result.flight_intent or {}
            print(f"摘要: {intent.get('summary')}")
            print(f"动作: {result.get_flight_actions()}")
            if result.audio_data is not None:
                await client.play_audio(result.audio_data)

            # The flight actions can be forwarded to the flight controller
            # here, e.g.:
            #     for action in result.get_flight_actions():
            #         send_to_flight_controller(action)

    except Exception as e:
        logger.error(f"测试失败: {e}")
        raise
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the demo coroutine on a fresh event loop.
    asyncio.run(example_usage())
|