"""
WebSocket 消息处理器

处理协议定义的所有消息类型和时序
"""

from __future__ import annotations

import asyncio
import base64
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
from loguru import logger

from app.protocols.models import *
from app.protocols.validators import *
from app.websocket.session import SessionContext, session_manager
from app.services.llm_service import LLMServiceInterface
from app.services.tts_service import TTSServiceInterface, TTSAudioChunk
from app.services.intent_service import (
    parse_flight_intent_reply,
    get_tts_text,
    allows_incremental_tts,
    take_next_speech_segment,
    should_recover_failed_flight_output,
)
from app.config import settings
from app.utils.audio import tts_chunk_to_pcm_s16le
from app.providers.dashscope_fun_asr_turn import (
    ActiveFunAsrTurn,
    FunAsrTurnState,
    build_fun_asr_recognition,
)


class MessageHandler:
    """WebSocket message handler.

    Routes every client message type of the protocol and drives the per-turn
    timeline: session lifecycle (``session.start`` / ``session.end``), PCM
    uplink recognized by DashScope Fun-ASR (``turn.audio.*``), text turns
    (``turn.text``: LLM -> ``dialog_result`` -> TTS) and TTS-only requests
    (``tts.synthesize``).
    """

    def __init__(
        self,
        llm_service: LLMServiceInterface,
        tts_service: TTSServiceInterface,
    ):
        self.llm_service = llm_service
        self.tts_service = tts_service
        # tts_service.synthesize() is consumed as a plain (synchronous)
        # iterator. Iterating it inside a coroutine would block the event loop
        # for the whole synthesis, so synthesize() and every next() run on
        # this single worker thread instead. One worker keeps the calls
        # serialized (as they were when they ran on the loop thread) and always
        # resumes a given iterator on the same thread.
        self._tts_executor = ThreadPoolExecutor(
            max_workers=1,
            thread_name_prefix="tts-synth",
        )

    # ------------------------------------------------------------------
    # small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _send_error(session_id: str, **fields) -> None:
        """Send an error frame built by ``create_error_message(**fields)``."""
        await session_manager.send_json(session_id, create_error_message(**fields))

    @staticmethod
    async def _stop_fun_asr(active: ActiveFunAsrTurn, where: str) -> None:
        """Stop a Fun-ASR recognition off the event loop; failures are only logged."""
        try:
            await asyncio.to_thread(active.recognition.stop)
        except Exception as e:
            logger.warning(f"{where} 关闭 Fun-ASR 时: {e}")

    @staticmethod
    def _truncate_for_tts(text: str) -> str:
        """Clip ``text`` to ``settings.TTS_MAX_CHARS`` chars, appending "..." when cut."""
        limit = settings.TTS_MAX_CHARS
        return text[:limit] + "..." if len(text) > limit else text

    # ------------------------------------------------------------------
    # routing
    # ------------------------------------------------------------------

    async def handle_message(self, session_id: str, data: dict):
        """Route a parsed message to the matching handler.

        Args:
            session_id: Session ID.
            data: Parsed JSON payload; its ``"type"`` field selects the handler.

        Unrecognized types are answered with an error frame: payloads that
        ``validate_not_audio_profile`` flags as legacy audio uplink get that
        validator's error code, anything else gets ``INVALID_MESSAGE``.
        """
        msg_type = data.get("type")
        routes = {
            "session.start": self._handle_session_start,
            "session.end": self._handle_session_end,
            "turn.audio.start": self._handle_turn_audio_start,
            "turn.audio.chunk": self._handle_turn_audio_chunk,
            "turn.audio.end": self._handle_turn_audio_end,
            "turn.text": self._handle_turn_text,
            "tts.synthesize": self._handle_tts_synthesize,
        }
        # Only strings can match a route; this also keeps an unhashable
        # "type" value from raising inside dict.get().
        handler = routes.get(msg_type) if isinstance(msg_type, str) else None
        if handler is not None:
            await handler(session_id, data)
            return

        # Legacy turn.audio_chunk / turn.audio_end are rejected under text_uplink.
        is_audio, error_code = validate_not_audio_profile(data)
        if is_audio:
            await self._send_error(
                session_id,
                code=error_code,
                message=(
                    "text_uplink 不支持该音频上行格式"
                    "(请使用 session.start transport_profile=pcm_asr_uplink 与 turn.audio.*)"
                ),
                turn_id=data.get("turn_id"),
                retryable=False,
            )
        else:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=f"未知的消息类型: {msg_type}",
                turn_id=data.get("turn_id"),
                retryable=False,
            )

    # ------------------------------------------------------------------
    # session lifecycle
    # ------------------------------------------------------------------

    async def _handle_session_start(self, session_id: str, data: dict):
        """Handle ``session.start``: validate, bind client info, reply ``session.ready``."""
        valid, msg, error = validate_session_start(data)
        if not valid:
            logger.warning(f"session.start 验证失败: {error}")
            # The validator may not supply a reason; don't let `in` hit None.
            is_auth_error = "鉴权" in (error or "")
            await self._send_error(
                session_id,
                code=ErrorCode.UNAUTHORIZED if is_auth_error else ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                retryable=False,
            )
            return

        ctx = await session_manager.get_session(session_id)
        if not ctx:
            logger.error(f"会话不存在: {session_id}")
            return

        # A repeated session.start can arrive while a recognition is still
        # open. Stop it instead of only dropping the reference, which would
        # leave the recognizer running with nothing left to stop it.
        stale_asr = ctx.active_fun_asr
        ctx.active_fun_asr = None
        if stale_asr is not None:
            await self._stop_fun_asr(stale_asr, "session.start")

        client = msg.client
        ctx.device_id = client.device_id
        ctx.client_info = client
        ctx.is_ready = True
        ctx.transport_profile = msg.transport_profile
        ctx.update_activity()
        ctx.dialog_protocol = ""
        if client.protocol and client.protocol.dialog_result:
            ctx.dialog_protocol = client.protocol.dialog_result.strip()

        px4_vc = client.px4.vehicle_class if client.px4 else None
        logger.info(
            f"会话已建立: {session_id[:8]}..., "
            f"device={client.device_id}, "
            f"profile={data.get('transport_profile')}"
            + (f", px4_vehicle_class={px4_vc}" if px4_vc else "")
        )

        # Audio uplink is only accepted on the pcm_asr_uplink profile.
        ready_msg = SessionReadyMessage(
            session_id=session_id,
            transport_profile=ctx.transport_profile,
            server_caps=ServerCapabilities(
                llm_context_turns=settings.LLM_CONTEXT_TURNS,
                accepts_audio_uplink=ctx.transport_profile == PCM_ASR_TRANSPORT_PROFILE,
            ),
        )
        await session_manager.send_json(
            session_id,
            ready_msg.model_dump(exclude_none=True),
        )

    async def _handle_session_end(self, session_id: str, data: dict):
        """Handle ``session.end``: stop any open recognition, then close the session."""
        logger.info(f"会话结束: {session_id[:8]}...")
        ctx = await session_manager.get_session(session_id)
        if ctx and ctx.active_fun_asr:
            active = ctx.active_fun_asr
            ctx.active_fun_asr = None
            await self._stop_fun_asr(active, "session.end")
        await session_manager.close_session(session_id)

    # ------------------------------------------------------------------
    # pcm_asr_uplink: turn.audio.*
    # ------------------------------------------------------------------

    async def _handle_turn_audio_start(self, session_id: str, data: dict):
        """pcm_asr_uplink: open one turn's Alibaba Cloud Fun-ASR real-time recognition.

        Rejected (with an error frame) unless the session is ready, uses the
        pcm_asr_uplink profile, has no recognition already open, and the
        message validates with the configured sample rate.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.transport_profile != PCM_ASR_TRANSPORT_PROFILE:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="当前会话不是 pcm_asr_uplink,不可使用 turn.audio.*",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="已有进行中的语音识别,请先 turn.audio.end",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_audio_start(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if msg.sample_rate_hz != settings.ASR_AUDIO_SAMPLE_RATE:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=(
                    f"sample_rate_hz 须为 {settings.ASR_AUDIO_SAMPLE_RATE} "
                    f"(Fun-ASR 与本服务约定)"
                ),
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        loop = asyncio.get_running_loop()
        st = FunAsrTurnState(turn_id=msg.turn_id)
        profile = ctx.transport_profile

        async def _partial(text: str, is_final: bool) -> None:
            # Forward non-blank recognition results as asr partial frames.
            if not text.strip():
                return
            await session_manager.send_json(
                session_id,
                create_asr_partial(
                    turn_id=msg.turn_id,
                    text=text,
                    is_final=is_final,
                    transport_profile=profile,
                ),
            )

        async def _asr_err(m: str) -> None:
            c = await session_manager.get_session(session_id)
            # Only detach the recognition if it is still *this* turn's: by the
            # time the error arrives the turn may have ended and a later
            # turn.audio.start may have installed another recognition.
            if c and c.active_fun_asr is not None and c.active_fun_asr.state is st:
                c.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=m,
                turn_id=msg.turn_id,
                retryable=True,
            )

        rec = build_fun_asr_recognition(
            loop=loop,
            state=st,
            on_partial=_partial,
            on_error_msg=_asr_err,
        )
        try:
            await asyncio.to_thread(rec.start)
        except Exception as e:
            logger.error(f"Fun-ASR start 失败: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"语音识别启动失败: {str(e)}",
                turn_id=msg.turn_id,
                retryable=True,
            )
            return

        ctx.active_fun_asr = ActiveFunAsrTurn(
            turn_id=msg.turn_id,
            recognition=rec,
            state=st,
        )
        logger.info(
            f"[ASR] Fun-ASR 开始 turn={msg.turn_id[:8]} "
            f"model={settings.DASHSCOPE_ASR_MODEL} sr={settings.ASR_AUDIO_SAMPLE_RATE}"
        )

    async def _handle_turn_audio_chunk(self, session_id: str, data: dict):
        """Decode one base64 PCM chunk and feed it to the turn's open recognition.

        Silently ignored when the session is missing/not ready or the chunk
        decodes to zero bytes. If sending the frame fails, the recognition is
        detached from the session and an ``ASR_FAILED`` frame is sent.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_chunk(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start 或未开始识别",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        try:
            pcm = base64.b64decode(msg.pcm_base64)
        except Exception:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="pcm_base64 解码失败",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return
        if not pcm:
            return

        try:
            await asyncio.to_thread(active.recognition.send_audio_frame, pcm)
        except Exception as e:
            logger.error(f"Fun-ASR send_audio_frame: {e}", exc_info=True)
            ctx.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"上行音频失败: {str(e)}",
                turn_id=msg.turn_id,
                retryable=True,
            )

    async def _handle_turn_audio_end(self, session_id: str, data: dict):
        """Stop the turn's recognition and run the recognized text as a text turn.

        The user text is the concatenated final results, falling back to the
        last (partial) text. Empty text yields an ``ASR_FAILED`` frame.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_end(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        rec = active.recognition
        st = active.state
        ctx.active_fun_asr = None
        try:
            await asyncio.to_thread(rec.stop)
        except Exception as e:
            logger.warning(f"Fun-ASR stop: {e}")

        with st.lock:
            err = st.error
            finals = list(st.final_texts)
            last = (st.last_text or "").strip()
        if err:
            # Presumably already reported through on_error_msg (_asr_err);
            # confirm against dashscope_fun_asr_turn.
            return

        user_text = "".join(finals).strip() or last
        if not user_text:
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message="未识别到有效语音内容",
                turn_id=msg.turn_id,
                retryable=True,
            )
            return

        turn_msg = TurnTextMessage(
            turn_id=msg.turn_id,
            text=user_text,
            transport_profile=ctx.transport_profile,
            is_final=True,
            source=SourceType.CLOUD_FUN_ASR,
        )
        logger.info(
            f"[ASR→LLM] session={session_id[:8]} turn={msg.turn_id[:8]} "
            f"识别文本: {user_text[:80]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, turn_msg)

    # ------------------------------------------------------------------
    # text turns
    # ------------------------------------------------------------------

    async def _handle_turn_text(self, session_id: str, data: dict):
        """Handle ``turn.text`` - the core business flow.

        Timeline:
            1. Validate the message.
            2. Stream the LLM (``llm.text_delta``); chit-chat is spoken
               sentence by sentence while streaming.
            3. Parse the intent and send ``dialog_result``.
            4. Speak the remaining TTS (fixed flight-control phrasing /
               chit-chat tail, etc.).
            5. Send the TTS end frame and ``turn.complete``.

        Rejected while a Fun-ASR recognition is still open on the session.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,请先完成 turn.audio.end",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_text(data)
        if not valid:
            logger.warning(f"turn.text 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        turn_id = msg.turn_id
        user_text = msg.text
        logger.info(
            f"[回合] session={session_id[:8]} turn={turn_id[:8]} "
            f"收到文本: {user_text[:50]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, msg)

    async def _run_turn_text_locked(
        self,
        session_id: str,
        ctx: SessionContext,
        msg: TurnTextMessage,
    ) -> None:
        """Run one text turn end to end; the caller holds ``ctx.pipeline_lock``.

        Any exception is logged and reported as an ``INTERNAL`` error frame
        for this turn instead of propagating.
        """
        turn_id = msg.turn_id
        user_text = msg.text
        try:
            turn_t0 = time.time()
            px4_ctx = None
            if ctx.client_info and ctx.client_info.px4:
                px4_ctx = ctx.client_info.px4.model_dump(
                    mode="json",
                    exclude_none=True,
                )
            messages = await self.llm_service.build_messages(
                user_text,
                ctx.chat_history,
                px4=px4_ctx,
                enable_tools=settings.LLM_TOOLS_ENABLED,
            )

            full_text = ""  # accumulated LLM reply so far
            tts_carry = ""  # streamed text not yet cut into a speakable segment
            tts_seq = 0
            llm_t0 = time.time()
            tts_first_at: list[Optional[float]] = [None]

            # ===== 1. Streaming LLM (+ tools loop when enabled); chit-chat is
            # spoken incrementally, one speech segment at a time. =====
            async for delta in self.llm_service.chat_stream_with_tools(
                messages,
                session_id=session_id,
                turn_id=turn_id,
            ):
                full_text += delta
                await session_manager.send_json(
                    session_id,
                    create_llm_text_delta(turn_id, delta, done=False),
                )
                # NOTE(review): deltas that arrive while this check is False are
                # never added to tts_carry; if it can flip False -> True
                # mid-stream, that prefix is not spoken. Confirm against
                # intent_service.allows_incremental_tts.
                if allows_incremental_tts(full_text, user_utterance=user_text):
                    tts_carry += delta
                    while True:
                        seg, tts_carry = take_next_speech_segment(tts_carry)
                        if not seg:
                            break
                        tts_seq = await self._stream_tts(
                            session_id,
                            turn_id,
                            self._truncate_for_tts(seg),
                            seq=tts_seq,
                            send_stream_end=False,
                            tts_first_at=tts_first_at,
                        )

            await session_manager.send_json(
                session_id,
                create_llm_text_delta(turn_id, "", done=True),
            )
            llm_reply_raw = full_text
            llm_time = time.time() - llm_t0

            # ===== 2/3. Intent parsing, history, dialog_result =====
            routing, flight_intent = parse_flight_intent_reply(llm_reply_raw)
            if should_recover_failed_flight_output(user_text, llm_reply_raw):
                logger.warning(
                    f"飞控类用户话但模型不可解析为 JSON,跳过异常 TTS;"
                    f"raw_len={len(llm_reply_raw)} preview={llm_reply_raw[:200]!r}"
                )
                routing = "chitchat"
                flight_intent = None
                llm_reply = "指令未能解析,请再说简短一些。"
            else:
                llm_reply = llm_reply_raw

            chat_reply = llm_reply if routing == "chitchat" else None
            logger.info(f"[意图] routing={routing}")

            ctx.add_to_history("user", user_text)
            ctx.add_to_history("assistant", llm_reply)
            ctx.turn_count += 1

            uses_dialog_v1 = ctx.dialog_protocol == DIALOG_RESULT_PROTOCOL_V1
            build_dialog = (
                create_dialog_result_cloud_v1 if uses_dialog_v1 else create_dialog_result
            )
            dialog_msg = build_dialog(
                turn_id=turn_id,
                user_text=user_text,
                routing=RoutingType(routing),
                flight_intent=flight_intent,
                chat_reply=chat_reply,
            )
            await session_manager.send_json(session_id, dialog_msg)

            # ===== 4. Remaining TTS =====
            tts_start = time.time()
            if routing == "flight_intent":
                if uses_dialog_v1:
                    sm = (
                        (flight_intent.get("summary") if flight_intent else "")
                        or ""
                    ).strip() or "飞控指令"
                    if settings.FLIGHT_CONFIRM_REQUIRED:
                        tts_line = f"我将执行:{sm}。请回复确认或取消。"
                    else:
                        tts_line = f"{sm}。"
                    logger.info(f"[TTS] 飞控(dialog v1),播报: {tts_line[:80]}…")
                else:
                    tts_line = "识别到飞控指令,正在下发指令"
                    logger.info(f"[TTS] 飞控指令,使用固定播报文案: {tts_line}")
                tts_seq = await self._stream_tts(
                    session_id,
                    turn_id,
                    tts_line,
                    seq=tts_seq,
                    send_stream_end=False,
                    tts_first_at=tts_first_at,
                )
            else:
                rem = tts_carry.strip()
                # Nothing was spoken while streaming: fall back to the full reply text.
                if not rem and tts_seq == 0:
                    rem = get_tts_text(routing, flight_intent, chat_reply)
                if rem:
                    tts_seq = await self._stream_tts(
                        session_id,
                        turn_id,
                        self._truncate_for_tts(rem),
                        seq=tts_seq,
                        send_stream_end=False,
                        tts_first_at=tts_first_at,
                    )

            await self._tts_finalize(session_id, turn_id, tts_seq)
            tts_time = time.time() - tts_start
            if tts_first_at[0] is not None:
                tts_first_byte_ms = int((tts_first_at[0] - turn_t0) * 1000)
            else:
                tts_first_byte_ms = int(tts_time * 1000)

            # ===== 5. turn.complete =====
            complete_msg = TurnCompleteMessage(
                turn_id=turn_id,
                metrics=Metrics(
                    llm_ms=int(llm_time * 1000),
                    tts_first_byte_ms=tts_first_byte_ms,
                ),
            )
            await session_manager.send_json(
                session_id,
                complete_msg.model_dump(exclude_none=True),
            )
            logger.info(
                f"[回合完成] turn={turn_id[:8]} "
                f"LLM={llm_time:.2f}s, TTS={tts_time:.2f}s"
            )

        except Exception as e:
            logger.error(f"处理 turn.text 异常: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.INTERNAL,
                message=f"处理失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )

    # ------------------------------------------------------------------
    # TTS
    # ------------------------------------------------------------------

    async def _handle_tts_synthesize(self, session_id: str, data: dict):
        """TTS only: no LLM call, no chat_history write, no dialog_result."""
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,不可插入 tts.synthesize",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_tts_synthesize(data)
        if not valid:
            logger.warning(f"tts.synthesize 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        turn_id = msg.turn_id
        text = self._truncate_for_tts(msg.text.strip())
        logger.info(
            f"[TTS-only] session={session_id[:8]} turn={turn_id[:8]} "
            f"chars={len(text)}"
        )

        async with ctx.pipeline_lock:
            try:
                turn_t0 = time.time()
                tts_first_at: list[Optional[float]] = [None]
                tts_seq = await self._stream_tts(
                    session_id,
                    turn_id,
                    text,
                    seq=0,
                    send_stream_end=False,
                    tts_first_at=tts_first_at,
                )
                await self._tts_finalize(session_id, turn_id, tts_seq)
                tts_time = time.time() - turn_t0
                if tts_first_at[0] is not None:
                    tts_first_byte_ms = int((tts_first_at[0] - turn_t0) * 1000)
                else:
                    tts_first_byte_ms = int(tts_time * 1000)

                complete_msg = TurnCompleteMessage(
                    turn_id=turn_id,
                    metrics=Metrics(
                        llm_ms=0,
                        tts_first_byte_ms=tts_first_byte_ms,
                    ),
                )
                await session_manager.send_json(
                    session_id,
                    complete_msg.model_dump(exclude_none=True),
                )
                logger.info(
                    f"[TTS-only 完成] turn={turn_id[:8]} TTS={tts_time:.2f}s"
                )
            except Exception as e:
                logger.error(f"处理 tts.synthesize 异常: {e}", exc_info=True)
                await self._send_error(
                    session_id,
                    code=ErrorCode.INTERNAL,
                    message=f"处理失败: {str(e)}",
                    turn_id=turn_id,
                    retryable=True,
                )

    async def _stream_tts(
        self,
        session_id: str,
        turn_id: str,
        text: str,
        *,
        seq: int = 0,
        send_stream_end: bool = True,
        tts_first_at: Optional[list] = None,
    ) -> int:
        """Synthesize ``text`` and stream it to the client; return the next seq.

        Each chunk goes out as a JSON metadata frame followed by the PCM s16le
        bytes. Blank ``text`` is a no-op.

        Args:
            session_id: Target session.
            turn_id: Turn the audio belongs to.
            text: Text to speak (stripped before synthesis).
            seq: Sequence number of this segment's first chunk.
            send_stream_end: When False no ``is_final`` frame is sent, so
                several segments can be played back-to-back in one turn.
            tts_first_at: Optional one-element list; slot 0 is set to the
                wall-clock time of the first chunk if it is still ``None``.

        Returns:
            The next sequence number. On synthesis/send failure a
            ``TTS_FAILED`` frame is sent and the sequence reached so far is
            returned; no exception propagates.
        """
        if not (text and text.strip()):
            return seq

        total_bytes = 0
        local_seq = seq
        logger.info(f"[TTS] 合成片段: text='{text[:50]}...'")

        loop = asyncio.get_running_loop()
        exhausted = object()  # next() default once the iterator is done
        try:
            # synthesize() and each next() run on the TTS worker thread so the
            # event loop keeps serving other sessions during synthesis.
            chunks = await loop.run_in_executor(
                self._tts_executor,
                lambda: iter(
                    self.tts_service.synthesize(
                        text.strip(),
                        sample_rate=settings.TTS_SAMPLE_RATE,
                    )
                ),
            )
            while True:
                audio_chunk = await loop.run_in_executor(
                    self._tts_executor, next, chunks, exhausted
                )
                if audio_chunk is exhausted:
                    break

                pcm_data = tts_chunk_to_pcm_s16le(audio_chunk)
                total_bytes += len(pcm_data)

                if tts_first_at is not None and tts_first_at[0] is None:
                    tts_first_at[0] = time.time()

                if local_seq == seq:
                    logger.info(
                        f"[TTS] 首块音频: shape={audio_chunk.shape}, "
                        f"dtype={audio_chunk.dtype}, "
                        f"max={np.max(np.abs(audio_chunk)):.4f}, "
                        f"bytes={len(pcm_data)}"
                    )

                chunk_meta = TTSAudioChunk(
                    data=pcm_data,
                    turn_id=turn_id,
                    seq=local_seq,
                    is_final=False,
                )
                await session_manager.send_json(
                    session_id,
                    chunk_meta.to_metadata_dict(),
                )
                await session_manager.send_binary(session_id, pcm_data)
                local_seq += 1

            logger.info(
                f"[TTS] 片段完成: seq {seq}-{local_seq - 1}, "
                f"{total_bytes} bytes"
            )
            if send_stream_end:
                await self._tts_finalize(session_id, turn_id, local_seq)
            return local_seq

        except Exception as e:
            logger.error(f"TTS 流式发送失败: {e}")
            await self._send_error(
                session_id,
                code=ErrorCode.TTS_FAILED,
                message=f"TTS 失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )
            return local_seq

    async def _tts_finalize(
        self,
        session_id: str,
        turn_id: str,
        seq: int,
    ) -> None:
        """Send the turn-level TTS end frame (empty data, ``is_final=True``) at ``seq``."""
        if seq == 0:
            logger.warning("[TTS] 未生成任何音频数据")
        end_meta = TTSAudioChunk(
            data=b"",
            turn_id=turn_id,
            seq=seq,
            is_final=True,
        )
        await session_manager.send_json(
            session_id,
            end_meta.to_metadata_dict(),
        )
        logger.debug(f"[TTS] 流结束 seq={seq}")