DroneMind/voice_drone/core/cloud_voice_client.py
2026-04-14 09:54:26 +08:00

1000 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
云端语音 WebSocket 客户端:会话 `session.start.transport_profile` 固定为 pcm_asr_uplink。
- 主路径:`turn.audio.start` → 若干 `turn.audio.chunk`(每条仅文本 JSON,含 `pcm_base64`)→ `turn.audio.end`;**禁止**用 WebSocket binary 上发 PCM(与 Starlette receive 语义一致)。
- 辅助:`run_turn` 发 `turn.text`(如同句快路径仅有文本);`run_tts_synthesize` 仅 TTS。
- `asr.partial` 仅调试展示,不参与机端状态机。
文档:`docs/CLOUD_VOICE_SESSION_SCHEME_v1.md`、`docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md`。
"""
from __future__ import annotations

import base64
import errno
import json
import os
import threading
import time
import uuid
from collections.abc import Callable
from typing import Any

import numpy as np

from voice_drone.core.cloud_dialog_v1 import CLOUD_VOICE_DIALOG_V1
from voice_drone.logging_ import get_logger
# Module-level logger for the cloud voice client.
logger = get_logger("voice_drone.cloud_voice")
# Value sent as `proto_version` in every client message (session.start/end, turns, tts.synthesize).
_CLOUD_PROTO = "1.0"
# Transport profile declared in session.start and repeated on each turn message.
TRANSPORT_PCM_ASR_UPLINK = "pcm_asr_uplink"
def _merge_session_client(
    device_id: str,
    *,
    session_client_extensions: dict[str, Any] | None,
) -> dict[str, Any]:
    """Build the ``client`` object sent in ``session.start``.

    Starts from the device identity and fixed playback capabilities, then copies
    in optional extension keys (e.g. PX4 info) in their given order. Reserved
    keys (``device_id``, ``locale``, ``capabilities``, ``protocol``), ``None``
    values and an empty ``extras`` dict are skipped. ``protocol`` is always
    written last, so an extension can never override it.
    """
    reserved = ("device_id", "locale", "capabilities", "protocol")

    def _keep(key: str, value: Any) -> bool:
        if value is None or key in reserved:
            return False
        # An empty "extras" dict carries no information; leave it out.
        return not (key == "extras" and isinstance(value, dict) and not value)

    accepted = {
        key: value
        for key, value in (session_client_extensions or {}).items()
        if _keep(key, value)
    }
    return {
        "device_id": device_id,
        "locale": "zh-CN",
        "capabilities": {
            "playback_sample_rate_hz": 24000,
            "prefer_tts_codec": "pcm_s16le",
        },
        **accepted,
        "protocol": {"dialog_result": CLOUD_VOICE_DIALOG_V1},
    }
def _transient_ws_exc(exc: BaseException) -> bool:
"""可通过对端已关、网络抖动等通过重连重发 turn 恢复的异常。"""
import websocket as _websocket # noqa: PLC0415
if isinstance(
exc,
(
BrokenPipeError,
ConnectionResetError,
ConnectionAbortedError,
),
):
return True
if isinstance(
exc,
(
_websocket.WebSocketConnectionClosedException,
_websocket.WebSocketTimeoutException,
),
):
return True
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (
32,
104,
110,
): # EPIPE, ECONNRESET, ETIMEDOUT
return True
return False
def _merge_tts_pcm_chunks(
chunk_entries: list[tuple[int | None, int, bytes]],
) -> bytes:
"""按 seq 升序拼接;无 seq 时按到达顺序。chunk_entries: (seq|None, arrival, pcm)。"""
if not chunk_entries:
return b""
if all(s is not None for s, _, _ in chunk_entries):
ordered = sorted(chunk_entries, key=lambda x: (x[0], x[1]))
seqs = [x[0] for x in ordered]
for a, b in zip(seqs, seqs[1:]):
if b != a + 1:
logger.warning("TTS seq 不连续(仍按序拼接): %s%s", a, b)
break
return b"".join(x[2] for x in ordered)
return b"".join(x[2] for x in sorted(chunk_entries, key=lambda x: x[1]))
class CloudVoiceError(RuntimeError):
    """Raised when the cloud sends an ``error`` frame or the protocol is violated.

    Attributes:
        code: machine-readable error code, or None when not known.
        retryable: whether reconnecting and retrying may succeed.
    """

    def __init__(self, message: str, *, code: str | None = None, retryable: bool = False):
        super().__init__(message)
        self.code, self.retryable = code, retryable
def _env_flag(name: str) -> bool:
    """Return True when env var *name* is "1", "true" or "yes" (case-insensitive)."""
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes")


def _turn_retry_policy() -> tuple[int, float]:
    """Return ``(attempts, delay_sec)`` for one run_* call.

    Read from ROCKET_CLOUD_TURN_RETRIES (default 3, at least 1) and
    ROCKET_CLOUD_TURN_RETRY_DELAY_SEC (default 0.35, at least 0). Unparsable
    values fall back to the defaults.
    """
    try:
        attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
    except ValueError:
        attempts = 3
    try:
        delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
    except ValueError:
        delay = 0.35
    return max(1, attempts), max(0.0, delay)


def _uplink_chunk_bytes() -> int:
    """Raw PCM bytes per ``turn.audio.chunk`` (ROCKET_CLOUD_AUDIO_CHUNK_BYTES, default 8192).

    The value is clamped to [2048, 256 KiB] and rounded down to an even count,
    so no int16 sample is split across two chunks.
    """
    try:
        size = int(os.environ.get("ROCKET_CLOUD_AUDIO_CHUNK_BYTES", "8192"))
    except ValueError:
        size = 8192
    size = max(2048, min(256 * 1024, size))
    return size - size % 2


def _pcm_abs_max(pcm: np.ndarray) -> int:
    """Return the peak absolute sample value of int16 PCM (0 for an empty array).

    The samples are widened to int32 first. In int16, ``abs(-32768)`` wraps back
    to -32768, which made a full-scale negative sample look like a tiny peak.
    """
    if pcm.size == 0:
        return 0
    return int(np.max(np.abs(pcm.astype(np.int32))))


class CloudVoiceClient:
    """WebSocket client for ``ws://…/v1/voice/session`` (session profile pcm_asr_uplink).

    Supports three kinds of round trip:

    - run_turn_audio: microphone PCM uplink.
    - run_turn: a text turn.
    - run_tts_synthesize: TTS only.

    An instance owns at most one connection/session at a time. The public
    methods are serialized by an internal ``threading.Lock``; methods whose names
    end in ``_nolock`` expect that lock to be held already.

    Args:
        server_url: WebSocket URL of the voice session endpoint.
        auth_token: bearer token, sent as an HTTP header and inside session.start.
        device_id: device identifier for session.start ("drone-001" when empty).
        recv_timeout: socket timeout in seconds for connect and every recv.
        session_client_extensions: extra keys merged into the session.start ``client``.
    """

    def __init__(
        self,
        *,
        server_url: str,
        auth_token: str,
        device_id: str,
        recv_timeout: float = 120.0,
        session_client_extensions: dict[str, Any] | None = None,
    ) -> None:
        self.server_url = server_url.strip()
        self.auth_token = auth_token.strip()
        self.device_id = (device_id or "drone-001").strip()
        self.recv_timeout = float(recv_timeout)
        self._session_client_extensions: dict[str, Any] = dict(
            session_client_extensions or {}
        )
        self._transport_profile: str = TRANSPORT_PCM_ASR_UPLINK
        self._ws: Any = None
        self._session_id: str | None = None
        self._lock = threading.Lock()

    # ------------------------------------------------------------------ connection

    @property
    def connected(self) -> bool:
        """True while a WebSocket object is held (it may still have been closed by the peer)."""
        with self._lock:
            return self._ws is not None

    def close(self) -> None:
        """End the session (best effort) and close the socket."""
        with self._lock:
            self._close_nolock()

    def _close_nolock(self) -> None:
        """Send ``session.end`` if a session exists, then close the socket (lock held).

        Both steps are best effort. The local state is cleared regardless of errors.
        """
        ws, session_id = self._ws, self._session_id
        self._ws = None
        self._session_id = None
        if ws is None:
            return
        try:
            if session_id:
                try:
                    ws.send(
                        json.dumps(
                            {
                                "type": "session.end",
                                "proto_version": _CLOUD_PROTO,
                                "session_id": session_id,
                            },
                            ensure_ascii=False,
                        )
                    )
                except Exception as e:  # noqa: BLE001 - peer may already be gone
                    logger.debug("session.end 发送失败(忽略): %s", e)
        finally:
            try:
                ws.close()
            except Exception as e:  # noqa: BLE001
                logger.debug("WebSocket close 失败(忽略): %s", e)

    def connect(self) -> None:
        """Open the WebSocket, send session.start and wait for session.ready."""
        with self._lock:
            self._connect_nolock()

    def _connect_nolock(self) -> None:
        """Open a fresh connection and perform the session handshake (lock held).

        Any existing connection is closed first. If the handshake fails, the new
        connection is closed again and the exception propagates. A server
        ``error`` frame in place of session.ready is raised with the server's code.
        """
        import websocket  # websocket-client

        self._close_nolock()
        headers = [f"Authorization: Bearer {self.auth_token}"]
        try:
            self._ws = websocket.create_connection(
                self.server_url,
                header=headers,
                timeout=self.recv_timeout,
            )
            self._ws.settimeout(self.recv_timeout)
            self._session_id = str(uuid.uuid4())
            client_payload = _merge_session_client(
                self.device_id,
                session_client_extensions=self._session_client_extensions,
            )
            if self._session_client_extensions:
                logger.info(
                    "session.start 已附加 client 扩展键: %s",
                    sorted(self._session_client_extensions.keys()),
                )
            start = {
                "type": "session.start",
                "proto_version": _CLOUD_PROTO,
                "transport_profile": self._transport_profile,
                "session_id": self._session_id,
                "auth_token": self.auth_token,
                "client": client_payload,
            }
            self._ws.send(json.dumps(start, ensure_ascii=False))
            raw = self._ws.recv()
            if isinstance(raw, bytes):
                raise CloudVoiceError("session.ready 期望 JSON 文本帧,收到二进制")
            data = self._parse_text_frame(raw) or {}
            if data.get("type") == "error":
                raise self._server_error(data)
            if data.get("type") != "session.ready":
                raise CloudVoiceError(
                    f"期望 session.ready收到: {data.get('type')!r}",
                    code="INVALID_MESSAGE",
                )
            logger.info("云端会话已就绪 session_id=%s", self._session_id)
        except Exception:
            self._close_nolock()
            raise

    def ensure_connected(self) -> None:
        """Connect only if no connection is currently held."""
        with self._lock:
            if self._ws is None:
                self._connect_nolock()

    # ------------------------------------------------------------ frame helpers

    @staticmethod
    def _send_json(ws: Any, payload: dict[str, Any], *, what: str) -> None:
        """Send *payload* as one JSON text frame.

        Transient transport errors propagate unchanged, so the retry loop can
        reconnect. Any other failure becomes ``CloudVoiceError(code="INTERNAL")``.
        """
        try:
            ws.send(json.dumps(payload, ensure_ascii=False))
        except Exception as e:
            if _transient_ws_exc(e):
                raise
            raise CloudVoiceError(f"发送 {what} 失败: {e}", code="INTERNAL") from e

    @staticmethod
    def _parse_text_frame(msg: Any) -> dict[str, Any] | None:
        """Decode a non-binary frame into a JSON object; return None for blank frames.

        Raises:
            CloudVoiceError: INVALID_MESSAGE if the frame is not a str, is not valid
                JSON, or is JSON but not an object.
        """
        if not isinstance(msg, str):
            raise CloudVoiceError(
                f"期望文本帧为 str实际为 {type(msg).__name__}",
                code="INVALID_MESSAGE",
            )
        text_frame = msg.strip()
        if not text_frame:
            logger.debug("跳过空 WebSocket 文本帧")
            return None
        try:
            data = json.loads(text_frame)
        except json.JSONDecodeError as e:
            head = text_frame[:200].replace("\n", "\\n")
            raise CloudVoiceError(
                f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
                code="INVALID_MESSAGE",
            ) from e
        if not isinstance(data, dict):
            # e.g. a bare JSON array/string would otherwise crash on .get()
            raise CloudVoiceError(
                f"服务端文本帧不是 JSON 对象: {type(data).__name__}",
                code="INVALID_MESSAGE",
            )
        return data

    @staticmethod
    def _server_error(data: dict[str, Any]) -> CloudVoiceError:
        """Build the exception for a server ``error`` frame.

        The exception is tagged ``_server_reported``. An in-band error leaves the
        frame stream in sync, so _run_with_retries keeps the session open for it.
        """
        code = str(data.get("code") or "INTERNAL")
        err = CloudVoiceError(
            data.get("message") or code,
            code=code,
            retryable=bool(data.get("retryable")),
        )
        err._server_reported = True  # noqa: SLF001
        return err

    @staticmethod
    def _assemble_tts_pcm(
        pcm_entries: list[tuple[int | None, int, bytes]], *, expect_dialog: bool
    ) -> np.ndarray:
        """Merge received TTS payloads into an int16 array, with sanity warnings.

        The payloads are decoded as little-endian s16 (the ``pcm_s16le`` codec
        requested in session.start). A trailing odd byte cannot form a sample; it
        is dropped with a warning instead of making ``np.frombuffer`` raise after
        the turn has already completed.
        """
        full_pcm = _merge_tts_pcm_chunks(pcm_entries)
        if len(full_pcm) % 2:
            logger.warning(
                "云端 TTS PCM 总长 %s 字节不是 s16 整数倍,丢弃末尾 1 字节",
                len(full_pcm),
            )
            full_pcm = full_pcm[:-1]
        if not full_pcm:
            return np.array([], dtype=np.int16)
        pcm = np.frombuffer(full_pcm, dtype="<i2").astype(np.int16)
        # np.any avoids the int16 abs(-32768) wraparound of the old max(abs()) test.
        if not np.any(pcm):
            if expect_dialog:
                logger.warning(
                    "云端 TTS 已收齐二进制总长 %s 字节(≈%s 个 s16 采样),但全为 0x00"
                    "属于服务端发出的静音占位或未写入合成结果;机端无法通过重采样/扬声器修复。"
                    "请在服务端对同一次 synthesize 写 WAV 核对非零采样,并确认 WS 先发 tts_audio_chunk JSON、"
                    "再发 raw PCM 帧、且未把 JSON/base64 误当 binary 发出。",
                    len(full_pcm),
                    pcm.size,
                )
            else:
                logger.warning(
                    "tts.synthesize 收齐 PCM 但全零(服务端静音占位);总长 %s 字节",
                    len(full_pcm),
                )
        if _env_flag("ROCKET_CLOUD_PCM_HEX"):
            head = full_pcm[:64]
            logger.warning(
                "ROCKET_CLOUD_PCM_HEX: 前 %s 字节 hex=%s",
                len(head),
                head.hex(),
            )
        return pcm

    # ------------------------------------------------------------- turn reading

    def _collect_turn_nolock(
        self, ws: Any, turn_id: str, *, expect_dialog: bool
    ) -> dict[str, Any]:
        """Read server frames for *turn_id* until its ``turn.complete`` (lock held).

        Frame handling:

        - Binary frames are TTS PCM. Each one is tagged with the ``seq`` from the
          ``tts_audio_chunk`` JSON header that preceded it.
        - ``asr.partial`` frames are only logged.
        - With ``expect_dialog=True`` (turn.text / turn.audio), ``llm.text_delta``
          frames are accumulated, and a matching ``dialog_result`` is required.
        - With ``expect_dialog=False`` (tts.synthesize), ``llm.text_delta`` and
          ``dialog_result`` frames are ignored.

        Returns:
            Dialog turns: protocol, routing, flight_intent, confirm, chat_reply,
            user_input, pcm, sample_rate_hz, metrics, llm_stream_text.
            tts.synthesize: pcm, sample_rate_hz, metrics.

        Raises:
            CloudVoiceError: DISCONNECTED (retryable) when the peer closed the socket;
                INVALID_MESSAGE on protocol violations; the server's own code for an
                ``error`` frame. Other transport errors (e.g. recv timeout) propagate.
        """
        import websocket  # websocket-client

        print_stream = expect_dialog and _env_flag("ROCKET_PRINT_LLM_STREAM")
        expecting_binary = False
        pending_seq: int | None = None
        pcm_entries: list[tuple[int | None, int, bytes]] = []
        llm_stream_parts: list[str] = []
        dialog: dict[str, Any] | None = None
        metrics: dict[str, Any] = {}
        sample_rate_hz = 24000

        while True:
            try:
                msg = ws.recv()
            except websocket.WebSocketConnectionClosedException as e:
                raise CloudVoiceError(
                    f"连接已断开: {e}",
                    code="DISCONNECTED",
                    retryable=True,
                ) from e

            if isinstance(msg, bytes):
                if expecting_binary:
                    expecting_binary = False
                else:
                    logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
                # The arrival index equals the number of payloads already stored.
                pcm_entries.append((pending_seq, len(pcm_entries), msg))
                pending_seq = None
                continue

            data = self._parse_text_frame(msg)
            if data is None:
                continue
            mtype = data.get("type")
            same_turn = data.get("turn_id") == turn_id

            if mtype == "asr.partial":
                if expect_dialog:
                    logger.debug("← asr.partial机端不参与状态跳转)")
                else:
                    logger.debug("← asr.partialtts 轮次,忽略)")
                continue

            if mtype == "llm.text_delta":
                if not same_turn:
                    if expect_dialog:
                        logger.debug(
                            "llm.text_delta turn_id 与当前不一致,忽略 type=%s",
                            mtype,
                        )
                    else:
                        logger.debug("llm.text_delta turn_id 与当前 tts 不一致,忽略")
                    continue
                if not expect_dialog:
                    logger.debug("忽略服务端消息 type=%s", mtype)
                    continue
                raw_d = data.get("delta")
                delta = "" if raw_d is None else str(raw_d)
                llm_stream_parts.append(delta)
                if print_stream:
                    print(delta, end="", flush=True)
                    if data.get("done"):
                        print(flush=True)
                logger.debug(
                    "← llm.text_delta done=%s delta_chars=%s",
                    data.get("done"),
                    len(delta),
                )
                continue

            if mtype == "tts_audio_chunk":
                # JSON header for the next binary PCM frame.
                pending_seq = None
                if not same_turn:
                    if expect_dialog:
                        logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制")
                    else:
                        logger.warning(
                            "tts_audio_chunk turn_id 与 tts.synthesize 不一致,仍消费后续二进制",
                        )
                else:
                    try:
                        sample_rate_hz = int(data.get("sample_rate_hz") or sample_rate_hz)
                    except (TypeError, ValueError):
                        pass
                seq_raw = data.get("seq")
                if seq_raw is not None:
                    try:
                        pending_seq = int(seq_raw)
                    except (TypeError, ValueError):
                        pending_seq = None
                if data.get("is_final"):
                    logger.debug("← tts_audio_chunk is_final=true seq=%s", seq_raw)
                expecting_binary = True
                continue

            if mtype == "dialog_result":
                if not expect_dialog:
                    logger.debug("tts.synthesize 收到 dialog_result非预期忽略")
                    continue
                if not same_turn:
                    raise CloudVoiceError(
                        "dialog_result turn_id 不匹配", code="INVALID_MESSAGE"
                    )
                dialog = data
                logger.info("← dialog_result routing=%s", data.get("routing"))
                continue

            if mtype == "turn.complete":
                if not same_turn:
                    raise CloudVoiceError(
                        "turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
                    )
                metrics = data.get("metrics") or {}
                break

            if mtype == "error":
                raise self._server_error(data)

            logger.debug("忽略服务端消息 type=%s", mtype)

        if expect_dialog and dialog is None:
            raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE")
        pcm = self._assemble_tts_pcm(pcm_entries, expect_dialog=expect_dialog)
        if not expect_dialog or dialog is None:
            return {
                "pcm": pcm,
                "sample_rate_hz": sample_rate_hz,
                "metrics": metrics,
            }
        return {
            "protocol": dialog.get("protocol"),
            "routing": dialog.get("routing"),
            "flight_intent": dialog.get("flight_intent"),
            "confirm": dialog.get("confirm"),
            "chat_reply": dialog.get("chat_reply"),
            "user_input": dialog.get("user_input"),
            "pcm": pcm,
            "sample_rate_hz": sample_rate_hz,
            "metrics": metrics,
            "llm_stream_text": "".join(llm_stream_parts),
        }

    # ---------------------------------------------------------- turn execution

    def _execute_turn_nolock(self, t: str) -> dict[str, Any]:
        """Send one ``turn.text`` and collect the full reply (lock held, connected)."""
        ws = self._ws
        if ws is None:
            raise CloudVoiceError("WebSocket 未连接")
        turn_id = str(uuid.uuid4())
        self._send_json(
            ws,
            {
                "type": "turn.text",
                "proto_version": _CLOUD_PROTO,
                "transport_profile": self._transport_profile,
                "turn_id": turn_id,
                "text": t,
                "is_final": True,
                "source": "device_stt",
            },
            what="turn.text",
        )
        logger.debug("→ turn.text turn_id=%s", turn_id)
        return self._collect_turn_nolock(ws, turn_id, expect_dialog=True)

    @staticmethod
    def _prepare_uplink_pcm(pcm_int16: np.ndarray, sample_rate_hz: int) -> np.ndarray:
        """Flatten the uplink audio to 1-D int16 and log its level statistics.

        Raises:
            CloudVoiceError: if the audio is empty.
        """
        pcm = np.asarray(pcm_int16, dtype=np.int16).reshape(-1)
        if pcm.size == 0:
            raise CloudVoiceError("turn.audio PCM 为空")
        sr = int(sample_rate_hz)
        peak = _pcm_abs_max(pcm)
        rms = float(np.sqrt(np.mean(pcm.astype(np.float64) ** 2)))
        logger.info(
            "turn.audio 上行: samples=%s sr_hz=%s dur≈%.2fs abs_max=%s rms=%.1f dtype=int16",
            pcm.size,
            sr,
            float(pcm.size) / max(1, sr),
            peak,
            rms,
        )
        if peak == 0:
            logger.warning(
                "turn.audio 上行波形全零,云端 ASR 通常会判无有效语音(请查麦/切段/VAD 是否误交静音)"
            )
        elif peak < 200:
            logger.warning(
                "turn.audio 上行幅值极小 abs_max=%s(仍发送);若云端反复无识别请调 AGC/VAD/麦增益",
                peak,
            )
        return pcm

    def _execute_turn_audio_nolock(
        self, pcm_int16: np.ndarray, sample_rate_hz: int
    ) -> dict[str, Any]:
        """Upload one utterance and collect the full reply (lock held, connected).

        Sends turn.audio.start, then N × turn.audio.chunk, then turn.audio.end.
        PCM always travels base64-encoded inside JSON text frames
        (``pcm_base64``); binary WebSocket frames are never used for uplink.
        """
        ws = self._ws
        if ws is None:
            raise CloudVoiceError("WebSocket 未连接")
        pcm = np.asarray(pcm_int16, dtype=np.int16).reshape(-1)
        if pcm.size == 0:
            raise CloudVoiceError("turn.audio PCM 为空")
        turn_id = str(uuid.uuid4())
        envelope = {
            "proto_version": _CLOUD_PROTO,
            "transport_profile": self._transport_profile,
            "turn_id": turn_id,
        }
        self._send_json(
            ws,
            {
                "type": "turn.audio.start",
                **envelope,
                "sample_rate_hz": int(sample_rate_hz),
                "codec": "pcm_s16le",
                "channels": 1,
            },
            what="turn.audio",
        )
        # Explicit little-endian to match the declared pcm_s16le codec.
        raw = pcm.astype("<i2", copy=False).tobytes()
        chunk_bytes = _uplink_chunk_bytes()
        n_chunks = 0
        for offset in range(0, len(raw), chunk_bytes):
            piece = raw[offset : offset + chunk_bytes]
            self._send_json(
                ws,
                {
                    "type": "turn.audio.chunk",
                    **envelope,
                    "pcm_base64": base64.b64encode(piece).decode("ascii"),
                },
                what="turn.audio",
            )
            n_chunks += 1
        self._send_json(ws, {"type": "turn.audio.end", **envelope}, what="turn.audio")
        logger.debug(
            "→ turn.audio start/%s chunk(s)/end turn_id=%s samples=%s",
            n_chunks,
            turn_id,
            pcm.size,
        )
        return self._collect_turn_nolock(ws, turn_id, expect_dialog=True)

    def _execute_tts_synthesize_nolock(self, text: str) -> dict[str, Any]:
        """Send ``tts.synthesize`` and collect TTS chunks and ``turn.complete`` (lock held).

        No dialog_result is expected for this message type.
        """
        ws = self._ws
        if ws is None:
            raise CloudVoiceError("WebSocket 未连接")
        turn_id = str(uuid.uuid4())
        self._send_json(
            ws,
            {
                "type": "tts.synthesize",
                "proto_version": _CLOUD_PROTO,
                "transport_profile": self._transport_profile,
                "turn_id": turn_id,
                "text": text,
            },
            what="tts.synthesize",
        )
        logger.debug("→ tts.synthesize turn_id=%s", turn_id)
        return self._collect_turn_nolock(ws, turn_id, expect_dialog=False)

    # ------------------------------------------------------------ public turns

    def _run_with_retries(
        self, label: str, op: Callable[[], dict[str, Any]]
    ) -> dict[str, Any]:
        """Run *op* under the lock, reconnecting and retrying on recoverable failures.

        A failure is recoverable when it is a CloudVoiceError flagged retryable or
        with code DISCONNECTED, or a transient transport error. For these, the
        connection is closed, the loop sleeps (outside the lock), and the next
        attempt reconnects.

        When the error is final, the connection is closed too, unless the server
        reported it in-band. A turn aborted mid-stream (e.g. recv timeout, protocol
        violation) may still have frames in flight; the next turn would read them
        and fail the turn_id checks.
        """
        attempts, delay = _turn_retry_policy()
        for attempt in range(1, attempts + 1):
            with self._lock:
                try:
                    if self._ws is None:
                        self._connect_nolock()
                    return op()
                except Exception as e:
                    if isinstance(e, CloudVoiceError):
                        recoverable = bool(e.retryable) or e.code == "DISCONNECTED"
                    else:
                        recoverable = _transient_ws_exc(e)
                    if not recoverable or attempt >= attempts:
                        if not getattr(e, "_server_reported", False):
                            self._close_nolock()
                        raise
                    logger.warning(
                        "%s 可恢复错误,将重连并重试 (%s/%s): %s",
                        label,
                        attempt,
                        attempts,
                        e,
                    )
                    self._close_nolock()
            if delay:
                time.sleep(delay)
        raise CloudVoiceError(f"{label} 未执行", code="INTERNAL")

    def run_turn_audio(
        self, pcm_int16: np.ndarray, sample_rate_hz: int
    ) -> dict[str, Any]:
        """Upload one utterance of microphone PCM and collect the full reply.

        The audio is sent as JSON text chunks carrying ``pcm_base64``. The call
        then waits for dialog_result, the TTS audio and turn.complete. Retries
        follow ROCKET_CLOUD_TURN_RETRIES / ROCKET_CLOUD_TURN_RETRY_DELAY_SEC.

        Returns:
            Same keys as run_turn.

        Raises:
            CloudVoiceError: empty audio (before any connection is made), protocol or
                server errors, or retries exhausted.
        """
        pcm = self._prepare_uplink_pcm(pcm_int16, sample_rate_hz)
        return self._run_with_retries(
            "turn.audio",
            lambda: self._execute_turn_audio_nolock(pcm, sample_rate_hz),
        )

    def run_tts_synthesize(self, text: str) -> dict[str, Any]:
        """Request TTS only: send ``tts.synthesize`` and collect audio and ``turn.complete``.

        Shares the connection with run_turn; calls are serialized by this
        client's lock. Retry policy is the same as run_turn.

        Returns:
            dict with ``pcm`` (int16 ndarray), ``sample_rate_hz`` and ``metrics``.

        Raises:
            CloudVoiceError: empty text, protocol or server errors, or retries exhausted.
        """
        t = (text or "").strip()
        if not t:
            raise CloudVoiceError("tts.synthesize text 不能为空")
        return self._run_with_retries(
            "tts.synthesize", lambda: self._execute_tts_synthesize_nolock(t)
        )

    def run_turn(self, text: str) -> dict[str, Any]:
        """Send one user text turn and collect dialog_result, TTS audio and turn.complete.

        Streaming downlink is supported: tts_audio_chunk/PCM frames and
        llm.text_delta frames may arrive before dialog_result. Flight control and
        the final reply text are still taken from dialog_result.

        If the connection drops mid-turn (peer closed TCP, broken pipe, timeout),
        the session is re-established and the turn is resent. The attempt count
        comes from ROCKET_CLOUD_TURN_RETRIES (default 3).

        Returns:
            dict: protocol, routing, flight_intent, confirm, chat_reply, user_input,
            pcm, sample_rate_hz, metrics, and llm_stream_text (the concatenated
            llm.text_delta text, for optional debugging/UI).

        Raises:
            CloudVoiceError: empty text, protocol or server errors, or retries exhausted.
        """
        t = (text or "").strip()
        if not t:
            raise CloudVoiceError("turn.text 不能为空")
        return self._run_with_retries("turn.text", lambda: self._execute_turn_nolock(t))