1000 lines
37 KiB
Python
1000 lines
37 KiB
Python
"""
|
||
云端语音 WebSocket 客户端:会话 `session.start.transport_profile` 固定为 pcm_asr_uplink。
|
||
|
||
- 主路径:`turn.audio.start` → 若干 `turn.audio.chunk`(每条仅文本 JSON,含 `pcm_base64`)→ `turn.audio.end`;**禁止**用 WebSocket binary 上发 PCM(与 Starlette receive 语义一致)。
|
||
- 辅助:`run_turn` 发 `turn.text`(如同句快路径仅有文本);`run_tts_synthesize` 仅 TTS。
|
||
- `asr.partial` 仅调试展示,不参与机端状态机。
|
||
|
||
文档:`docs/CLOUD_VOICE_SESSION_SCHEME_v1.md`,`docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md`。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import json
|
||
import os
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
|
||
from voice_drone.core.cloud_dialog_v1 import CLOUD_VOICE_DIALOG_V1
|
||
from voice_drone.logging_ import get_logger
|
||
|
||
logger = get_logger("voice_drone.cloud_voice")

# Value of the ``proto_version`` field stamped on every outgoing message.
_CLOUD_PROTO: str = "1.0"
# Transport profile announced in ``session.start`` and repeated on every turn
# message; this client only implements the PCM-ASR-uplink profile.
TRANSPORT_PCM_ASR_UPLINK: str = "pcm_asr_uplink"
|
||
|
||
|
||
def _merge_session_client(
    device_id: str,
    *,
    session_client_extensions: dict[str, Any] | None,
) -> dict[str, Any]:
    """Build the ``client`` object sent in ``session.start``.

    The fixed part holds the device id, locale and playback capabilities.
    Optional extension keys (e.g. PX4 state) are appended after it, in their
    original order. Extensions can never override ``device_id``, ``locale``,
    ``capabilities`` or ``protocol``. ``None`` values are dropped, and so is an
    empty ``extras`` dict. ``protocol`` is always written last.

    Args:
        device_id: identifier reported to the server.
        session_client_extensions: extra top-level keys for ``client``; may be None.

    Returns:
        A new dict; the input mapping is not modified.
    """
    reserved_keys = ("device_id", "locale", "capabilities", "protocol")
    extensions = {
        key: value
        for key, value in (session_client_extensions or {}).items()
        if value is not None
        and key not in reserved_keys
        and not (key == "extras" and isinstance(value, dict) and len(value) == 0)
    }
    return {
        "device_id": device_id,
        "locale": "zh-CN",
        "capabilities": {
            "playback_sample_rate_hz": 24000,
            "prefer_tts_codec": "pcm_s16le",
        },
        **extensions,
        "protocol": {"dialog_result": CLOUD_VOICE_DIALOG_V1},
    }
|
||
|
||
|
||
def _transient_ws_exc(exc: BaseException) -> bool:
|
||
"""可通过对端已关、网络抖动等通过重连重发 turn 恢复的异常。"""
|
||
import websocket as _websocket # noqa: PLC0415
|
||
|
||
if isinstance(
|
||
exc,
|
||
(
|
||
BrokenPipeError,
|
||
ConnectionResetError,
|
||
ConnectionAbortedError,
|
||
),
|
||
):
|
||
return True
|
||
if isinstance(
|
||
exc,
|
||
(
|
||
_websocket.WebSocketConnectionClosedException,
|
||
_websocket.WebSocketTimeoutException,
|
||
),
|
||
):
|
||
return True
|
||
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (
|
||
32,
|
||
104,
|
||
110,
|
||
): # EPIPE, ECONNRESET, ETIMEDOUT
|
||
return True
|
||
return False
|
||
|
||
|
||
def _merge_tts_pcm_chunks(
|
||
chunk_entries: list[tuple[int | None, int, bytes]],
|
||
) -> bytes:
|
||
"""按 seq 升序拼接;无 seq 时按到达顺序。chunk_entries: (seq|None, arrival, pcm)。"""
|
||
if not chunk_entries:
|
||
return b""
|
||
if all(s is not None for s, _, _ in chunk_entries):
|
||
ordered = sorted(chunk_entries, key=lambda x: (x[0], x[1]))
|
||
seqs = [x[0] for x in ordered]
|
||
for a, b in zip(seqs, seqs[1:]):
|
||
if b != a + 1:
|
||
logger.warning("TTS seq 不连续(仍按序拼接): %s → %s", a, b)
|
||
break
|
||
return b"".join(x[2] for x in ordered)
|
||
return b"".join(x[2] for x in sorted(chunk_entries, key=lambda x: x[1]))
|
||
|
||
|
||
class CloudVoiceError(RuntimeError):
|
||
"""云端返回 error 消息或协议不符合预期。"""
|
||
|
||
def __init__(self, message: str, *, code: str | None = None, retryable: bool = False):
|
||
super().__init__(message)
|
||
self.code = code
|
||
self.retryable = retryable
|
||
|
||
|
||
class CloudVoiceClient:
|
||
"""连接 ws://…/v1/voice/session;session 为 pcm_asr_uplink,含 run_turn_audio / run_turn / tts.synthesize。"""
|
||
|
||
def __init__(
|
||
self,
|
||
*,
|
||
server_url: str,
|
||
auth_token: str,
|
||
device_id: str,
|
||
recv_timeout: float = 120.0,
|
||
session_client_extensions: dict[str, Any] | None = None,
|
||
) -> None:
|
||
self.server_url = server_url.strip()
|
||
self.auth_token = auth_token.strip()
|
||
self.device_id = (device_id or "drone-001").strip()
|
||
self.recv_timeout = float(recv_timeout)
|
||
self._session_client_extensions: dict[str, Any] = dict(
|
||
session_client_extensions or {}
|
||
)
|
||
self._transport_profile: str = TRANSPORT_PCM_ASR_UPLINK
|
||
self._ws: Any = None
|
||
self._session_id: str | None = None
|
||
self._lock = threading.Lock()
|
||
|
||
@property
|
||
def connected(self) -> bool:
|
||
with self._lock:
|
||
return self._ws is not None
|
||
|
||
def close(self) -> None:
|
||
with self._lock:
|
||
self._close_nolock()
|
||
|
||
def _close_nolock(self) -> None:
|
||
if self._ws is None:
|
||
self._session_id = None
|
||
return
|
||
try:
|
||
if self._session_id:
|
||
try:
|
||
self._ws.send(
|
||
json.dumps(
|
||
{
|
||
"type": "session.end",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"session_id": self._session_id,
|
||
},
|
||
ensure_ascii=False,
|
||
)
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
pass
|
||
finally:
|
||
try:
|
||
self._ws.close()
|
||
except Exception: # noqa: BLE001
|
||
pass
|
||
self._ws = None
|
||
self._session_id = None
|
||
|
||
def connect(self) -> None:
|
||
"""建立 WSS,发送 session.start,等待 session.ready。"""
|
||
with self._lock:
|
||
self._connect_nolock()
|
||
|
||
def _connect_nolock(self) -> None:
|
||
import websocket # websocket-client
|
||
|
||
self._close_nolock()
|
||
hdr = [f"Authorization: Bearer {self.auth_token}"]
|
||
try:
|
||
self._ws = websocket.create_connection(
|
||
self.server_url,
|
||
header=hdr,
|
||
timeout=self.recv_timeout,
|
||
)
|
||
self._ws.settimeout(self.recv_timeout)
|
||
self._session_id = str(uuid.uuid4())
|
||
client_payload = _merge_session_client(
|
||
self.device_id,
|
||
session_client_extensions=self._session_client_extensions,
|
||
)
|
||
if self._session_client_extensions:
|
||
logger.info(
|
||
"session.start 已附加 client 扩展键: %s",
|
||
sorted(self._session_client_extensions.keys()),
|
||
)
|
||
start = {
|
||
"type": "session.start",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"session_id": self._session_id,
|
||
"auth_token": self.auth_token,
|
||
"client": client_payload,
|
||
}
|
||
self._ws.send(json.dumps(start, ensure_ascii=False))
|
||
raw = self._ws.recv()
|
||
if isinstance(raw, bytes):
|
||
raise CloudVoiceError("session.ready 期望 JSON 文本帧,收到二进制")
|
||
data = json.loads(raw)
|
||
if data.get("type") != "session.ready":
|
||
raise CloudVoiceError(
|
||
f"期望 session.ready,收到: {data.get('type')!r}",
|
||
code="INVALID_MESSAGE",
|
||
)
|
||
logger.info("云端会话已就绪 session_id=%s", self._session_id)
|
||
except Exception:
|
||
self._close_nolock()
|
||
raise
|
||
|
||
def ensure_connected(self) -> None:
|
||
with self._lock:
|
||
if self._ws is None:
|
||
self._connect_nolock()
|
||
|
||
def _execute_turn_nolock(self, t: str) -> dict[str, Any]:
|
||
"""已持锁且 _ws 已连接:发送 turn.text 并收齐本轮帧。"""
|
||
import websocket # websocket-client
|
||
|
||
ws = self._ws
|
||
if ws is None:
|
||
raise CloudVoiceError("WebSocket 未连接")
|
||
|
||
turn_id = str(uuid.uuid4())
|
||
turn_msg = {
|
||
"type": "turn.text",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"turn_id": turn_id,
|
||
"text": t,
|
||
"is_final": True,
|
||
"source": "device_stt",
|
||
}
|
||
try:
|
||
ws.send(json.dumps(turn_msg, ensure_ascii=False))
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise CloudVoiceError(f"发送 turn.text 失败: {e}", code="INTERNAL") from e
|
||
logger.debug("→ turn.text turn_id=%s", turn_id)
|
||
|
||
expecting_binary = False
|
||
_pending_tts_seq: int | None = None
|
||
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||
_pcm_arrival = 0
|
||
llm_stream_parts: list[str] = []
|
||
dialog: dict[str, Any] | None = None
|
||
metrics: dict[str, Any] = {}
|
||
sample_rate_hz = 24000
|
||
|
||
while True:
|
||
try:
|
||
msg = ws.recv()
|
||
except websocket.WebSocketConnectionClosedException as e:
|
||
raise CloudVoiceError(
|
||
f"连接已断开: {e}",
|
||
code="DISCONNECTED",
|
||
retryable=True,
|
||
) from e
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise
|
||
|
||
if isinstance(msg, bytes):
|
||
if expecting_binary:
|
||
expecting_binary = False
|
||
else:
|
||
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||
_pcm_arrival += 1
|
||
_pending_tts_seq = None
|
||
continue
|
||
|
||
if not isinstance(msg, str):
|
||
raise CloudVoiceError(
|
||
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||
code="INVALID_MESSAGE",
|
||
)
|
||
text_frame = msg.strip()
|
||
if not text_frame:
|
||
logger.debug("跳过空 WebSocket 文本帧")
|
||
continue
|
||
try:
|
||
data = json.loads(text_frame)
|
||
except json.JSONDecodeError as e:
|
||
head = text_frame[:200].replace("\n", "\\n")
|
||
raise CloudVoiceError(
|
||
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||
code="INVALID_MESSAGE",
|
||
) from e
|
||
mtype = data.get("type")
|
||
|
||
if mtype == "asr.partial":
|
||
logger.debug("← asr.partial(机端不参与状态跳转)")
|
||
continue
|
||
|
||
if mtype == "llm.text_delta":
|
||
if data.get("turn_id") != turn_id:
|
||
logger.debug(
|
||
"llm.text_delta turn_id 与当前不一致,忽略 type=%s",
|
||
mtype,
|
||
)
|
||
continue
|
||
raw_d = data.get("delta")
|
||
delta = "" if raw_d is None else str(raw_d)
|
||
llm_stream_parts.append(delta)
|
||
_print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in (
|
||
"1",
|
||
"true",
|
||
"yes",
|
||
)
|
||
if _print_stream:
|
||
print(delta, end="", flush=True)
|
||
if data.get("done"):
|
||
print(flush=True)
|
||
logger.debug(
|
||
"← llm.text_delta done=%s delta_chars=%s",
|
||
data.get("done"),
|
||
len(delta),
|
||
)
|
||
continue
|
||
|
||
if mtype == "tts_audio_chunk":
|
||
_pending_tts_seq = None
|
||
if data.get("turn_id") != turn_id:
|
||
logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制")
|
||
else:
|
||
try:
|
||
sample_rate_hz = int(
|
||
data.get("sample_rate_hz") or sample_rate_hz
|
||
)
|
||
except (TypeError, ValueError):
|
||
pass
|
||
_s = data.get("seq")
|
||
try:
|
||
if _s is not None:
|
||
_pending_tts_seq = int(_s)
|
||
except (TypeError, ValueError):
|
||
_pending_tts_seq = None
|
||
if data.get("is_final"):
|
||
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||
expecting_binary = True
|
||
continue
|
||
|
||
if mtype == "dialog_result":
|
||
if data.get("turn_id") != turn_id:
|
||
raise CloudVoiceError(
|
||
"dialog_result turn_id 不匹配", code="INVALID_MESSAGE"
|
||
)
|
||
dialog = data
|
||
logger.info(
|
||
"← dialog_result routing=%s", data.get("routing")
|
||
)
|
||
continue
|
||
|
||
if mtype == "turn.complete":
|
||
if data.get("turn_id") != turn_id:
|
||
raise CloudVoiceError(
|
||
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||
)
|
||
metrics = data.get("metrics") or {}
|
||
break
|
||
|
||
if mtype == "error":
|
||
code = str(data.get("code") or "INTERNAL")
|
||
raise CloudVoiceError(
|
||
data.get("message") or code,
|
||
code=code,
|
||
retryable=bool(data.get("retryable")),
|
||
)
|
||
|
||
logger.debug("忽略服务端消息 type=%s", mtype)
|
||
|
||
if dialog is None:
|
||
raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE")
|
||
|
||
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||
pcm = (
|
||
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||
if full_pcm
|
||
else np.array([], dtype=np.int16)
|
||
)
|
||
if pcm.size > 0:
|
||
mx = int(np.max(np.abs(pcm)))
|
||
if mx == 0:
|
||
logger.warning(
|
||
"云端 TTS 已收齐二进制总长 %s 字节(≈%s 个 s16 采样),但全为 0x00,"
|
||
"属于服务端发出的静音占位或未写入合成结果;机端无法通过重采样/扬声器修复。"
|
||
"请在服务端对同一次 synthesize 写 WAV 核对非零采样,并确认 WS 先发 tts_audio_chunk JSON、"
|
||
"再发 raw PCM 帧、且未把 JSON/base64 误当 binary 发出。",
|
||
len(full_pcm),
|
||
pcm.size,
|
||
)
|
||
if os.environ.get("ROCKET_CLOUD_PCM_HEX", "").strip().lower() in (
|
||
"1",
|
||
"true",
|
||
"yes",
|
||
):
|
||
head = full_pcm[:64]
|
||
logger.warning(
|
||
"ROCKET_CLOUD_PCM_HEX: 前 %s 字节 hex=%s",
|
||
len(head),
|
||
head.hex(),
|
||
)
|
||
|
||
llm_stream_text = "".join(llm_stream_parts)
|
||
return {
|
||
"protocol": dialog.get("protocol"),
|
||
"routing": dialog.get("routing"),
|
||
"flight_intent": dialog.get("flight_intent"),
|
||
"confirm": dialog.get("confirm"),
|
||
"chat_reply": dialog.get("chat_reply"),
|
||
"user_input": dialog.get("user_input"),
|
||
"pcm": pcm,
|
||
"sample_rate_hz": sample_rate_hz,
|
||
"metrics": metrics,
|
||
"llm_stream_text": llm_stream_text,
|
||
}
|
||
|
||
def _execute_turn_audio_nolock(
|
||
self, pcm_int16: np.ndarray, sample_rate_hz: int
|
||
) -> dict[str, Any]:
|
||
"""发送 turn.audio.start → 多条 turn.audio.chunk(pcm_base64 文本帧)→ turn.audio.end;禁止 binary 上发 PCM。"""
|
||
import websocket # websocket-client
|
||
|
||
ws = self._ws
|
||
if ws is None:
|
||
raise CloudVoiceError("WebSocket 未连接")
|
||
|
||
pcm_int16 = np.asarray(pcm_int16, dtype=np.int16).reshape(-1)
|
||
if pcm_int16.size == 0:
|
||
raise CloudVoiceError("turn.audio PCM 为空")
|
||
|
||
pcm_mx = int(np.max(np.abs(pcm_int16)))
|
||
pcm_rms = float(np.sqrt(np.mean(pcm_int16.astype(np.float64) ** 2)))
|
||
dur_sec = float(pcm_int16.size) / max(1, int(sample_rate_hz))
|
||
logger.info(
|
||
"turn.audio 上行: samples=%s sr_hz=%s dur≈%.2fs abs_max=%s rms=%.1f dtype=int16",
|
||
pcm_int16.size,
|
||
int(sample_rate_hz),
|
||
dur_sec,
|
||
pcm_mx,
|
||
pcm_rms,
|
||
)
|
||
if pcm_mx == 0:
|
||
logger.warning(
|
||
"turn.audio 上行波形全零,云端 ASR 通常会判无有效语音(请查麦/切段/VAD 是否误交静音)"
|
||
)
|
||
elif pcm_mx < 200:
|
||
logger.warning(
|
||
"turn.audio 上行幅值极小 abs_max=%s(仍发送);若云端反复无识别请调 AGC/VAD/麦增益",
|
||
pcm_mx,
|
||
)
|
||
|
||
turn_id = str(uuid.uuid4())
|
||
start = {
|
||
"type": "turn.audio.start",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"turn_id": turn_id,
|
||
"sample_rate_hz": int(sample_rate_hz),
|
||
"codec": "pcm_s16le",
|
||
"channels": 1,
|
||
}
|
||
raw = pcm_int16.tobytes()
|
||
try:
|
||
ws.send(json.dumps(start, ensure_ascii=False))
|
||
try:
|
||
raw_chunk = int(os.environ.get("ROCKET_CLOUD_AUDIO_CHUNK_BYTES", "8192"))
|
||
except ValueError:
|
||
raw_chunk = 8192
|
||
raw_chunk = max(2048, min(256 * 1024, raw_chunk))
|
||
n_chunks = 0
|
||
for i in range(0, len(raw), raw_chunk):
|
||
piece = raw[i : i + raw_chunk]
|
||
chunk_msg = {
|
||
"type": "turn.audio.chunk",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"turn_id": turn_id,
|
||
"pcm_base64": base64.b64encode(piece).decode("ascii"),
|
||
}
|
||
ws.send(json.dumps(chunk_msg, ensure_ascii=False))
|
||
n_chunks += 1
|
||
end = {
|
||
"type": "turn.audio.end",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"turn_id": turn_id,
|
||
}
|
||
ws.send(json.dumps(end, ensure_ascii=False))
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise CloudVoiceError(f"发送 turn.audio 失败: {e}", code="INTERNAL") from e
|
||
logger.debug(
|
||
"→ turn.audio start/%s chunk(s)/end turn_id=%s samples=%s",
|
||
n_chunks,
|
||
turn_id,
|
||
pcm_int16.size,
|
||
)
|
||
|
||
expecting_binary = False
|
||
_pending_tts_seq: int | None = None
|
||
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||
_pcm_arrival = 0
|
||
llm_stream_parts: list[str] = []
|
||
dialog: dict[str, Any] | None = None
|
||
metrics: dict[str, Any] = {}
|
||
out_sr = 24000
|
||
|
||
while True:
|
||
try:
|
||
msg = ws.recv()
|
||
except websocket.WebSocketConnectionClosedException as e:
|
||
raise CloudVoiceError(
|
||
f"连接已断开: {e}",
|
||
code="DISCONNECTED",
|
||
retryable=True,
|
||
) from e
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise
|
||
|
||
if isinstance(msg, bytes):
|
||
if expecting_binary:
|
||
expecting_binary = False
|
||
else:
|
||
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||
_pcm_arrival += 1
|
||
_pending_tts_seq = None
|
||
continue
|
||
|
||
if not isinstance(msg, str):
|
||
raise CloudVoiceError(
|
||
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||
code="INVALID_MESSAGE",
|
||
)
|
||
text_frame = msg.strip()
|
||
if not text_frame:
|
||
logger.debug("跳过空 WebSocket 文本帧")
|
||
continue
|
||
try:
|
||
data = json.loads(text_frame)
|
||
except json.JSONDecodeError as e:
|
||
head = text_frame[:200].replace("\n", "\\n")
|
||
raise CloudVoiceError(
|
||
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||
code="INVALID_MESSAGE",
|
||
) from e
|
||
mtype = data.get("type")
|
||
|
||
if mtype == "asr.partial":
|
||
logger.debug("← asr.partial(机端不参与状态跳转)")
|
||
continue
|
||
|
||
if mtype == "llm.text_delta":
|
||
if data.get("turn_id") != turn_id:
|
||
logger.debug(
|
||
"llm.text_delta turn_id 与当前不一致,忽略 type=%s",
|
||
mtype,
|
||
)
|
||
continue
|
||
raw_d = data.get("delta")
|
||
delta = "" if raw_d is None else str(raw_d)
|
||
llm_stream_parts.append(delta)
|
||
_print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in (
|
||
"1",
|
||
"true",
|
||
"yes",
|
||
)
|
||
if _print_stream:
|
||
print(delta, end="", flush=True)
|
||
if data.get("done"):
|
||
print(flush=True)
|
||
logger.debug(
|
||
"← llm.text_delta done=%s delta_chars=%s",
|
||
data.get("done"),
|
||
len(delta),
|
||
)
|
||
continue
|
||
|
||
if mtype == "tts_audio_chunk":
|
||
_pending_tts_seq = None
|
||
if data.get("turn_id") != turn_id:
|
||
logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制")
|
||
else:
|
||
try:
|
||
out_sr = int(data.get("sample_rate_hz") or out_sr)
|
||
except (TypeError, ValueError):
|
||
pass
|
||
_s = data.get("seq")
|
||
try:
|
||
if _s is not None:
|
||
_pending_tts_seq = int(_s)
|
||
except (TypeError, ValueError):
|
||
_pending_tts_seq = None
|
||
if data.get("is_final"):
|
||
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||
expecting_binary = True
|
||
continue
|
||
|
||
if mtype == "dialog_result":
|
||
if data.get("turn_id") != turn_id:
|
||
raise CloudVoiceError(
|
||
"dialog_result turn_id 不匹配", code="INVALID_MESSAGE"
|
||
)
|
||
dialog = data
|
||
logger.info(
|
||
"← dialog_result routing=%s", data.get("routing")
|
||
)
|
||
continue
|
||
|
||
if mtype == "turn.complete":
|
||
if data.get("turn_id") != turn_id:
|
||
raise CloudVoiceError(
|
||
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||
)
|
||
metrics = data.get("metrics") or {}
|
||
break
|
||
|
||
if mtype == "error":
|
||
code = str(data.get("code") or "INTERNAL")
|
||
raise CloudVoiceError(
|
||
data.get("message") or code,
|
||
code=code,
|
||
retryable=bool(data.get("retryable")),
|
||
)
|
||
|
||
logger.debug("忽略服务端消息 type=%s", mtype)
|
||
|
||
if dialog is None:
|
||
raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE")
|
||
|
||
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||
out_pcm = (
|
||
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||
if full_pcm
|
||
else np.array([], dtype=np.int16)
|
||
)
|
||
if out_pcm.size > 0:
|
||
mx = int(np.max(np.abs(out_pcm)))
|
||
if mx == 0:
|
||
logger.warning(
|
||
"云端 TTS 已收齐但全零采样,请核对服务端。",
|
||
)
|
||
|
||
llm_stream_text = "".join(llm_stream_parts)
|
||
return {
|
||
"protocol": dialog.get("protocol"),
|
||
"routing": dialog.get("routing"),
|
||
"flight_intent": dialog.get("flight_intent"),
|
||
"confirm": dialog.get("confirm"),
|
||
"chat_reply": dialog.get("chat_reply"),
|
||
"user_input": dialog.get("user_input"),
|
||
"pcm": out_pcm,
|
||
"sample_rate_hz": out_sr,
|
||
"metrics": metrics,
|
||
"llm_stream_text": llm_stream_text,
|
||
}
|
||
|
||
def run_turn_audio(
|
||
self, pcm_int16: np.ndarray, sample_rate_hz: int
|
||
) -> dict[str, Any]:
|
||
"""上行一轮麦克风 PCM:chunk 均为含 pcm_base64 的文本 JSON;收齐 dialog_result + TTS + turn.complete。"""
|
||
try:
|
||
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||
except ValueError:
|
||
raw_attempts = 3
|
||
attempts = max(1, raw_attempts)
|
||
try:
|
||
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||
except ValueError:
|
||
delay = 0.35
|
||
delay = max(0.0, delay)
|
||
|
||
for attempt in range(attempts):
|
||
with self._lock:
|
||
try:
|
||
if self._ws is None:
|
||
self._connect_nolock()
|
||
return self._execute_turn_audio_nolock(pcm_int16, sample_rate_hz)
|
||
except CloudVoiceError as e:
|
||
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||
if retry and attempt < attempts - 1:
|
||
logger.warning(
|
||
"turn.audio 可恢复错误,重连重试 (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
except Exception as e:
|
||
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||
logger.warning(
|
||
"turn.audio WebSocket 瞬断,重连重试 (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
|
||
raise CloudVoiceError("run_turn_audio 未执行", code="INTERNAL")
|
||
|
||
def _execute_tts_synthesize_nolock(self, text: str) -> dict[str, Any]:
|
||
"""已持锁且 _ws 已连接:发送 tts.synthesize,仅收 tts_audio_chunk* 与 turn.complete(无 dialog_result)。"""
|
||
import websocket # websocket-client
|
||
|
||
ws = self._ws
|
||
if ws is None:
|
||
raise CloudVoiceError("WebSocket 未连接")
|
||
|
||
turn_id = str(uuid.uuid4())
|
||
synth_msg = {
|
||
"type": "tts.synthesize",
|
||
"proto_version": _CLOUD_PROTO,
|
||
"transport_profile": self._transport_profile,
|
||
"turn_id": turn_id,
|
||
"text": text,
|
||
}
|
||
try:
|
||
ws.send(json.dumps(synth_msg, ensure_ascii=False))
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise CloudVoiceError(f"发送 tts.synthesize 失败: {e}", code="INTERNAL") from e
|
||
logger.debug("→ tts.synthesize turn_id=%s", turn_id)
|
||
|
||
expecting_binary = False
|
||
_pending_tts_seq: int | None = None
|
||
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||
_pcm_arrival = 0
|
||
metrics: dict[str, Any] = {}
|
||
sample_rate_hz = 24000
|
||
|
||
while True:
|
||
try:
|
||
msg = ws.recv()
|
||
except websocket.WebSocketConnectionClosedException as e:
|
||
raise CloudVoiceError(
|
||
f"连接已断开: {e}",
|
||
code="DISCONNECTED",
|
||
retryable=True,
|
||
) from e
|
||
except Exception as e:
|
||
if _transient_ws_exc(e):
|
||
raise
|
||
raise
|
||
|
||
if isinstance(msg, bytes):
|
||
if expecting_binary:
|
||
expecting_binary = False
|
||
else:
|
||
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||
_pcm_arrival += 1
|
||
_pending_tts_seq = None
|
||
continue
|
||
|
||
if not isinstance(msg, str):
|
||
raise CloudVoiceError(
|
||
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||
code="INVALID_MESSAGE",
|
||
)
|
||
text_frame = msg.strip()
|
||
if not text_frame:
|
||
logger.debug("跳过空 WebSocket 文本帧")
|
||
continue
|
||
try:
|
||
data = json.loads(text_frame)
|
||
except json.JSONDecodeError as e:
|
||
head = text_frame[:200].replace("\n", "\\n")
|
||
raise CloudVoiceError(
|
||
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||
code="INVALID_MESSAGE",
|
||
) from e
|
||
mtype = data.get("type")
|
||
|
||
if mtype == "asr.partial":
|
||
logger.debug("← asr.partial(tts 轮次,忽略)")
|
||
continue
|
||
|
||
if mtype == "llm.text_delta":
|
||
if data.get("turn_id") != turn_id:
|
||
logger.debug(
|
||
"llm.text_delta turn_id 与当前 tts 不一致,忽略",
|
||
)
|
||
continue
|
||
|
||
if mtype == "tts_audio_chunk":
|
||
_pending_tts_seq = None
|
||
if data.get("turn_id") != turn_id:
|
||
logger.warning(
|
||
"tts_audio_chunk turn_id 与 tts.synthesize 不一致,仍消费后续二进制",
|
||
)
|
||
else:
|
||
try:
|
||
sample_rate_hz = int(
|
||
data.get("sample_rate_hz") or sample_rate_hz
|
||
)
|
||
except (TypeError, ValueError):
|
||
pass
|
||
_s = data.get("seq")
|
||
try:
|
||
if _s is not None:
|
||
_pending_tts_seq = int(_s)
|
||
except (TypeError, ValueError):
|
||
_pending_tts_seq = None
|
||
if data.get("is_final"):
|
||
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||
expecting_binary = True
|
||
continue
|
||
|
||
if mtype == "dialog_result":
|
||
logger.debug("tts.synthesize 收到 dialog_result(非预期),忽略")
|
||
continue
|
||
|
||
if mtype == "turn.complete":
|
||
if data.get("turn_id") != turn_id:
|
||
raise CloudVoiceError(
|
||
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||
)
|
||
metrics = data.get("metrics") or {}
|
||
break
|
||
|
||
if mtype == "error":
|
||
code = str(data.get("code") or "INTERNAL")
|
||
raise CloudVoiceError(
|
||
data.get("message") or code,
|
||
code=code,
|
||
retryable=bool(data.get("retryable")),
|
||
)
|
||
|
||
logger.debug("忽略服务端消息 type=%s", mtype)
|
||
|
||
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||
pcm = (
|
||
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||
if full_pcm
|
||
else np.array([], dtype=np.int16)
|
||
)
|
||
if pcm.size > 0:
|
||
mx = int(np.max(np.abs(pcm)))
|
||
if mx == 0:
|
||
logger.warning(
|
||
"tts.synthesize 收齐 PCM 但全零(服务端静音占位);总长 %s 字节",
|
||
len(full_pcm),
|
||
)
|
||
|
||
return {
|
||
"pcm": pcm,
|
||
"sample_rate_hz": sample_rate_hz,
|
||
"metrics": metrics,
|
||
}
|
||
|
||
def run_tts_synthesize(self, text: str) -> dict[str, Any]:
|
||
"""
|
||
发送 tts.synthesize,收齐 TTS 块与 turn.complete(无 dialog_result)。
|
||
与 run_turn 共用连接,互斥由服务端排队;重试策略同 ROCKET_CLOUD_TURN_RETRIES。
|
||
"""
|
||
t = (text or "").strip()
|
||
if not t:
|
||
raise CloudVoiceError("tts.synthesize text 不能为空")
|
||
|
||
try:
|
||
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||
except ValueError:
|
||
raw_attempts = 3
|
||
attempts = max(1, raw_attempts)
|
||
try:
|
||
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||
except ValueError:
|
||
delay = 0.35
|
||
delay = max(0.0, delay)
|
||
|
||
for attempt in range(attempts):
|
||
with self._lock:
|
||
try:
|
||
if self._ws is None:
|
||
self._connect_nolock()
|
||
return self._execute_tts_synthesize_nolock(t)
|
||
except CloudVoiceError as e:
|
||
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||
if retry and attempt < attempts - 1:
|
||
logger.warning(
|
||
"tts.synthesize 可恢复错误,将重连并重试 (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
except Exception as e:
|
||
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||
logger.warning(
|
||
"tts.synthesize WebSocket 瞬断,重连并重试 (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
|
||
raise CloudVoiceError("run_tts_synthesize 未执行", code="INTERNAL")
|
||
|
||
def run_turn(self, text: str) -> dict[str, Any]:
|
||
"""
|
||
发送一轮用户文本,收齐 dialog_result、TTS 块、turn.complete。
|
||
|
||
支持流式下行:可先于 dialog_result 收到 tts_audio_chunk+PCM 与 llm.text_delta;
|
||
飞控与最终文案仍以 dialog_result 为准。
|
||
|
||
若中间因对端已关 TCP、ping/pong Broken pipe 等断开,会自动关连接、
|
||
重连 session 并重发本轮(次数由 ROCKET_CLOUD_TURN_RETRIES 控制,默认 3)。
|
||
|
||
Returns:
|
||
dict: routing, flight_intent, chat_reply, user_input, pcm, sample_rate_hz,
|
||
metrics, llm_stream_text(llm.text_delta 拼接,可选调试/UI)
|
||
"""
|
||
t = (text or "").strip()
|
||
if not t:
|
||
raise CloudVoiceError("turn.text 不能为空")
|
||
|
||
try:
|
||
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||
except ValueError:
|
||
raw_attempts = 3
|
||
attempts = max(1, raw_attempts)
|
||
try:
|
||
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||
except ValueError:
|
||
delay = 0.35
|
||
delay = max(0.0, delay)
|
||
|
||
for attempt in range(attempts):
|
||
with self._lock:
|
||
try:
|
||
if self._ws is None:
|
||
self._connect_nolock()
|
||
return self._execute_turn_nolock(t)
|
||
except CloudVoiceError as e:
|
||
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||
if retry and attempt < attempts - 1:
|
||
logger.warning(
|
||
"云端回合可恢复错误,将重连并重试 (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
except Exception as e:
|
||
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||
logger.warning(
|
||
"云端 WebSocket 瞬断(如对端先关、PONG 写失败),"
|
||
"重连并重发 turn (%s/%s): %s",
|
||
attempt + 1,
|
||
attempts,
|
||
e,
|
||
)
|
||
self._close_nolock()
|
||
if delay:
|
||
time.sleep(delay)
|
||
continue
|
||
raise
|
||
|
||
raise CloudVoiceError("run_turn 未执行", code="INTERNAL")
|