"""
WebSocket message handler.

Routes every protocol message type (session.*, turn.audio.*, turn.text,
tts.synthesize) and drives the per-turn sequence
ASR -> LLM -> dialog_result -> TTS -> turn.complete.
"""
from __future__ import annotations

import asyncio
import base64
import time
from typing import Optional

import numpy as np
from loguru import logger

from app.protocols.models import *
from app.protocols.validators import *
from app.websocket.session import SessionContext, session_manager
from app.services.llm_service import LLMServiceInterface
from app.services.tts_service import TTSServiceInterface, TTSAudioChunk
from app.services.intent_service import (
    parse_flight_intent_reply,
    get_tts_text,
    allows_incremental_tts,
    take_next_speech_segment,
    should_recover_failed_flight_output,
)
from app.config import settings
from app.utils.audio import tts_chunk_to_pcm_s16le
from app.providers.dashscope_fun_asr_turn import (
    ActiveFunAsrTurn,
    FunAsrTurnState,
    build_fun_asr_recognition,
)


class MessageHandler:
    """WebSocket message handler.

    Dispatches inbound JSON messages by their ``type`` field and replies via
    ``session_manager`` (JSON frames, plus binary PCM frames for TTS audio).
    """

    def __init__(
        self,
        llm_service: LLMServiceInterface,
        tts_service: TTSServiceInterface,
    ):
        self.llm_service = llm_service
        self.tts_service = tts_service

    async def handle_message(self, session_id: str, data: dict):
        """Route one parsed message to its handler.

        Args:
            session_id: Session ID.
            data: Parsed JSON payload; its ``type`` key selects the handler.

        Unknown types are answered with an error frame: the code returned by
        ``validate_not_audio_profile`` when the payload is a legacy audio
        message, otherwise ``INVALID_MESSAGE``.
        """
        msg_type = data.get("type")

        if msg_type == "session.start":
            await self._handle_session_start(session_id, data)
        elif msg_type == "session.end":
            await self._handle_session_end(session_id, data)
        elif msg_type == "turn.audio.start":
            await self._handle_turn_audio_start(session_id, data)
        elif msg_type == "turn.audio.chunk":
            await self._handle_turn_audio_chunk(session_id, data)
        elif msg_type == "turn.audio.end":
            await self._handle_turn_audio_end(session_id, data)
        elif msg_type == "turn.text":
            await self._handle_turn_text(session_id, data)
        elif msg_type == "tts.synthesize":
            await self._handle_tts_synthesize(session_id, data)
        else:
            # Legacy turn.audio_chunk / turn.audio_end are rejected under text_uplink.
            is_audio, error_code = validate_not_audio_profile(data)
            if is_audio:
                await self._send_error(
                    session_id,
                    code=error_code,
                    message=(
                        "text_uplink 不支持该音频上行格式"
                        "(请使用 session.start transport_profile=pcm_asr_uplink 与 turn.audio.*)"
                    ),
                    turn_id=data.get("turn_id"),
                    retryable=False,
                )
            else:
                await self._send_error(
                    session_id,
                    code=ErrorCode.INVALID_MESSAGE,
                    message=f"未知的消息类型: {msg_type}",
                    turn_id=data.get("turn_id"),
                    retryable=False,
                )

    async def _handle_session_start(self, session_id: str, data: dict):
        """Handle ``session.start``: validate, fill the session context, reply ``session.ready``.

        A repeated ``session.start`` resets the context; a Fun-ASR recognizer
        still running from before is stopped instead of being silently dropped.
        """
        valid, msg, error = validate_session_start(data)
        if not valid:
            logger.warning(f"session.start 验证失败: {error}")
            error_text = error or "无效请求"
            # Validation errors mentioning "鉴权" (authentication) map to UNAUTHORIZED.
            code = (
                ErrorCode.UNAUTHORIZED
                if "鉴权" in error_text
                else ErrorCode.INVALID_MESSAGE
            )
            await self._send_error(
                session_id,
                code=code,
                message=error_text,
                retryable=False,
            )
            return

        # The session must already be registered; this handler does not create one.
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            logger.error(f"会话不存在: {session_id}")
            return

        stale = ctx.active_fun_asr
        ctx.active_fun_asr = None
        if stale:
            await self._stop_recognition(
                stale.recognition, "session.start 重置 Fun-ASR 时"
            )

        ctx.device_id = msg.client.device_id
        ctx.client_info = msg.client
        ctx.is_ready = True
        ctx.transport_profile = msg.transport_profile
        ctx.update_activity()

        ctx.dialog_protocol = ""
        protocol = msg.client.protocol
        if protocol and protocol.dialog_result:
            ctx.dialog_protocol = protocol.dialog_result.strip()

        px4_vc = msg.client.px4.vehicle_class if msg.client.px4 else None
        logger.info(
            f"会话已建立: {session_id[:8]}..., "
            f"device={msg.client.device_id}, "
            f"profile={data.get('transport_profile')}"
            + (f", px4_vehicle_class={px4_vc}" if px4_vc else "")
        )

        # Audio uplink is only advertised for the PCM/ASR transport profile.
        accepts_audio = ctx.transport_profile == PCM_ASR_TRANSPORT_PROFILE
        ready_msg = SessionReadyMessage(
            session_id=session_id,
            transport_profile=ctx.transport_profile,
            server_caps=ServerCapabilities(
                llm_context_turns=settings.LLM_CONTEXT_TURNS,
                accepts_audio_uplink=accepts_audio,
            ),
        )
        await session_manager.send_json(
            session_id, ready_msg.model_dump(exclude_none=True)
        )

    async def _handle_session_end(self, session_id: str, data: dict):
        """Handle ``session.end``: stop any running Fun-ASR recognizer, then close the session."""
        logger.info(f"会话结束: {session_id[:8]}...")
        ctx = await session_manager.get_session(session_id)
        if ctx and ctx.active_fun_asr:
            active = ctx.active_fun_asr
            ctx.active_fun_asr = None
            await self._stop_recognition(
                active.recognition, "session.end 关闭 Fun-ASR 时"
            )
        await session_manager.close_session(session_id)

    async def _handle_turn_audio_start(self, session_id: str, data: dict):
        """pcm_asr_uplink: start one Alibaba Cloud Fun-ASR real-time recognition turn.

        Rejects the request if the session is not ready, is not on the PCM/ASR
        profile, already has a recognition running, fails validation, or uses
        a sample rate other than ``settings.ASR_AUDIO_SAMPLE_RATE``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.transport_profile != PCM_ASR_TRANSPORT_PROFILE:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="当前会话不是 pcm_asr_uplink,不可使用 turn.audio.*",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="已有进行中的语音识别,请先 turn.audio.end",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_audio_start(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if msg.sample_rate_hz != settings.ASR_AUDIO_SAMPLE_RATE:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=(
                    f"sample_rate_hz 须为 {settings.ASR_AUDIO_SAMPLE_RATE} "
                    f"(Fun-ASR 与本服务约定)"
                ),
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        loop = asyncio.get_running_loop()
        st = FunAsrTurnState(turn_id=msg.turn_id)
        profile = ctx.transport_profile

        async def _partial(text: str, is_final: bool) -> None:
            # Forward each non-blank recognizer hypothesis as an ASR partial frame.
            if not text.strip():
                return
            await session_manager.send_json(
                session_id,
                create_asr_partial(
                    turn_id=msg.turn_id,
                    text=text,
                    is_final=is_final,
                    transport_profile=profile,
                ),
            )

        async def _asr_err(m: str) -> None:
            c = await session_manager.get_session(session_id)
            # Release the slot only if it still belongs to this turn (matched by
            # its state object), so a late error cannot drop a newer turn.
            if c and c.active_fun_asr is not None and c.active_fun_asr.state is st:
                c.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=m,
                turn_id=msg.turn_id,
                retryable=True,
            )

        rec = build_fun_asr_recognition(
            loop=loop,
            state=st,
            on_partial=_partial,
            on_error_msg=_asr_err,
        )
        try:
            # The recognizer's start() is a blocking call; run it off the event loop.
            await asyncio.to_thread(rec.start)
        except Exception as e:
            logger.error(f"Fun-ASR start 失败: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"语音识别启动失败: {str(e)}",
                turn_id=msg.turn_id,
                retryable=True,
            )
            return

        ctx.active_fun_asr = ActiveFunAsrTurn(
            turn_id=msg.turn_id,
            recognition=rec,
            state=st,
        )
        logger.info(
            f"[ASR] Fun-ASR 开始 turn={msg.turn_id[:8]} "
            f"model={settings.DASHSCOPE_ASR_MODEL} sr={settings.ASR_AUDIO_SAMPLE_RATE}"
        )

    async def _handle_turn_audio_chunk(self, session_id: str, data: dict):
        """pcm_asr_uplink: forward one base64-encoded PCM chunk to the active recognizer.

        Chunks for an unknown or not-ready session are dropped silently. If the
        upload fails, the recognizer is released (and stopped best-effort) and
        an ``ASR_FAILED`` error is reported.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_chunk(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start 或未开始识别",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        try:
            pcm = base64.b64decode(msg.pcm_base64)
        except (ValueError, TypeError):
            # binascii.Error (bad padding/characters) is a ValueError subclass.
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="pcm_base64 解码失败",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return
        if not pcm:
            return

        try:
            await asyncio.to_thread(active.recognition.send_audio_frame, pcm)
        except Exception as e:
            logger.error(f"Fun-ASR send_audio_frame: {e}", exc_info=True)
            if ctx.active_fun_asr is active:
                ctx.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"上行音频失败: {str(e)}",
                turn_id=msg.turn_id,
                retryable=True,
            )
            # Release the failed recognizer so its stream is not left open.
            await self._stop_recognition(active.recognition, "Fun-ASR stop")

    async def _handle_turn_audio_end(self, session_id: str, data: dict):
        """pcm_asr_uplink: stop recognition and run the recognized text through the text pipeline.

        The final recognized text (concatenated final segments, else the last
        hypothesis) becomes a ``TurnTextMessage`` processed under
        ``ctx.pipeline_lock``; an empty result is reported as ``ASR_FAILED``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_end(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        rec = active.recognition
        st = active.state
        ctx.active_fun_asr = None
        await self._stop_recognition(rec, "Fun-ASR stop")

        # Snapshot the recognizer results under the state's lock.
        with st.lock:
            err = st.error
            finals = list(st.final_texts)
            last = (st.last_text or "").strip()

        if err:
            # NOTE(review): returns without a reply; assumes the error was already
            # reported through the on_error_msg callback — confirm.
            return

        user_text = "".join(finals).strip() or last
        if not user_text:
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message="未识别到有效语音内容",
                turn_id=msg.turn_id,
                retryable=True,
            )
            return

        turn_msg = TurnTextMessage(
            turn_id=msg.turn_id,
            text=user_text,
            transport_profile=ctx.transport_profile,
            is_final=True,
            source=SourceType.CLOUD_FUN_ASR,
        )
        logger.info(
            f"[ASR→LLM] session={session_id[:8]} turn={msg.turn_id[:8]} "
            f"识别文本: {user_text[:80]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, turn_msg)

    async def _handle_turn_text(self, session_id: str, data: dict):
        """Handle ``turn.text`` — the core business flow.

        Sequence:
            1. Validate the message.
            2. Stream the LLM reply (llm.text_delta), speaking chit-chat
               sentence by sentence via streaming TTS.
            3. Parse the intent and send dialog_result.
            4. Speak the remaining TTS (fixed flight-control phrase / chit-chat tail).
            5. Send the TTS end frame and turn.complete.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,请先完成 turn.audio.end",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_text(data)
        if not valid:
            logger.warning(f"turn.text 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        logger.info(
            f"[回合] session={session_id[:8]} turn={msg.turn_id[:8]} "
            f"收到文本: {msg.text[:50]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, msg)

    async def _run_turn_text_locked(
        self,
        session_id: str,
        ctx: SessionContext,
        msg: TurnTextMessage,
    ) -> None:
        """Run one full text turn; the caller must hold ``ctx.pipeline_lock``.

        Streams the LLM reply, speaks it, records chat history, sends
        dialog_result, the TTS end frame and turn.complete. Any exception is
        logged and reported to the client as a retryable ``INTERNAL`` error.
        """
        turn_id = msg.turn_id
        user_text = msg.text
        try:
            turn_t0 = time.time()

            px4_ctx = None
            if ctx.client_info and ctx.client_info.px4:
                px4_ctx = ctx.client_info.px4.model_dump(
                    mode="json",
                    exclude_none=True,
                )
            messages = await self.llm_service.build_messages(
                user_text,
                ctx.chat_history,
                px4=px4_ctx,
                enable_tools=settings.LLM_TOOLS_ENABLED,
            )

            llm_reply_parts: list[str] = []
            tts_carry = ""  # streamed text not yet cut into a speakable segment
            tts_seq = 0  # next TTS chunk seq for this turn
            llm_t0 = time.time()
            # One-element cell that _stream_tts fills with the first-audio timestamp.
            tts_first_at: list[Optional[float]] = [None]

            # ---- 1. Streamed LLM (incl. tool loop when enabled) + per-sentence chit-chat TTS ----
            async for delta in self.llm_service.chat_stream_with_tools(
                messages,
                session_id=session_id,
                turn_id=turn_id,
            ):
                llm_reply_parts.append(delta)
                full_text = "".join(llm_reply_parts)
                await session_manager.send_json(
                    session_id,
                    create_llm_text_delta(turn_id, delta, done=False),
                )
                if allows_incremental_tts(full_text, user_utterance=user_text):
                    tts_carry += delta
                    while True:
                        seg, tts_carry = take_next_speech_segment(tts_carry)
                        if not seg:
                            break
                        tts_seq = await self._stream_tts(
                            session_id,
                            turn_id,
                            self._clip_tts_text(seg),
                            seq=tts_seq,
                            send_stream_end=False,
                            tts_first_at=tts_first_at,
                        )

            await session_manager.send_json(
                session_id,
                create_llm_text_delta(turn_id, "", done=True),
            )
            llm_reply_raw = "".join(llm_reply_parts)
            llm_time = time.time() - llm_t0

            # ---- 2/3. Intent parsing and dialog_result ----
            routing, flight_intent = parse_flight_intent_reply(llm_reply_raw)
            if should_recover_failed_flight_output(user_text, llm_reply_raw):
                logger.warning(
                    f"飞控类用户话但模型不可解析为 JSON,跳过异常 TTS;"
                    f"raw_len={len(llm_reply_raw)} preview={llm_reply_raw[:200]!r}"
                )
                routing = "chitchat"
                flight_intent = None
                llm_reply = "指令未能解析,请再说简短一些。"
            else:
                llm_reply = llm_reply_raw
            chat_reply = llm_reply if routing == "chitchat" else None
            logger.info(f"[意图] routing={routing}")

            ctx.add_to_history("user", user_text)
            ctx.add_to_history("assistant", llm_reply)
            ctx.turn_count += 1

            build_dialog = (
                create_dialog_result_cloud_v1
                if ctx.dialog_protocol == DIALOG_RESULT_PROTOCOL_V1
                else create_dialog_result
            )
            dialog_msg = build_dialog(
                turn_id=turn_id,
                user_text=user_text,
                routing=RoutingType(routing),
                flight_intent=flight_intent,
                chat_reply=chat_reply,
            )
            await session_manager.send_json(session_id, dialog_msg)

            # ---- 4. Remaining TTS ----
            tts_start = time.time()
            if routing == "flight_intent":
                tts_line = self._flight_tts_line(ctx.dialog_protocol, flight_intent)
                tts_seq = await self._stream_tts(
                    session_id,
                    turn_id,
                    tts_line,
                    seq=tts_seq,
                    send_stream_end=False,
                    tts_first_at=tts_first_at,
                )
            else:
                rem = tts_carry.strip()
                if not rem and tts_seq == 0:
                    # Nothing was spoken during streaming: use the routing's default text.
                    rem = get_tts_text(routing, flight_intent, chat_reply)
                if rem:
                    tts_seq = await self._stream_tts(
                        session_id,
                        turn_id,
                        self._clip_tts_text(rem),
                        seq=tts_seq,
                        send_stream_end=False,
                        tts_first_at=tts_first_at,
                    )

            await self._tts_finalize(session_id, turn_id, tts_seq)
            tts_time = time.time() - tts_start
            if tts_first_at[0] is not None:
                tts_first_byte_ms = int((tts_first_at[0] - turn_t0) * 1000)
            else:
                tts_first_byte_ms = int(tts_time * 1000)

            # ---- 5. turn.complete ----
            complete_msg = TurnCompleteMessage(
                turn_id=turn_id,
                metrics=Metrics(
                    llm_ms=int(llm_time * 1000),
                    tts_first_byte_ms=tts_first_byte_ms,
                ),
            )
            await session_manager.send_json(
                session_id, complete_msg.model_dump(exclude_none=True)
            )
            logger.info(
                f"[回合完成] turn={turn_id[:8]} "
                f"LLM={llm_time:.2f}s, TTS={tts_time:.2f}s"
            )
        except Exception as e:
            logger.error(f"处理 turn.text 异常: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.INTERNAL,
                message=f"处理失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )

    async def _handle_tts_synthesize(self, session_id: str, data: dict):
        """TTS only: no LLM call, no chat_history write, no dialog_result.

        Speaks ``msg.text`` (stripped and clipped to ``settings.TTS_MAX_CHARS``),
        sends the TTS end frame and a turn.complete with ``llm_ms=0``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,不可插入 tts.synthesize",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_tts_synthesize(data)
        if not valid:
            logger.warning(f"tts.synthesize 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        turn_id = msg.turn_id
        text = self._clip_tts_text(msg.text.strip())
        logger.info(
            f"[TTS-only] session={session_id[:8]} turn={turn_id[:8]} "
            f"chars={len(text)}"
        )

        async with ctx.pipeline_lock:
            try:
                turn_t0 = time.time()
                tts_first_at: list[Optional[float]] = [None]
                tts_seq = await self._stream_tts(
                    session_id,
                    turn_id,
                    text,
                    seq=0,
                    send_stream_end=False,
                    tts_first_at=tts_first_at,
                )
                await self._tts_finalize(session_id, turn_id, tts_seq)
                tts_time = time.time() - turn_t0
                if tts_first_at[0] is not None:
                    tts_first_byte_ms = int((tts_first_at[0] - turn_t0) * 1000)
                else:
                    tts_first_byte_ms = int(tts_time * 1000)

                complete_msg = TurnCompleteMessage(
                    turn_id=turn_id,
                    metrics=Metrics(
                        llm_ms=0,
                        tts_first_byte_ms=tts_first_byte_ms,
                    ),
                )
                await session_manager.send_json(
                    session_id,
                    complete_msg.model_dump(exclude_none=True),
                )
                logger.info(
                    f"[TTS-only 完成] turn={turn_id[:8]} TTS={tts_time:.2f}s"
                )
            except Exception as e:
                logger.error(f"处理 tts.synthesize 异常: {e}", exc_info=True)
                await self._send_error(
                    session_id,
                    code=ErrorCode.INTERNAL,
                    message=f"处理失败: {str(e)}",
                    turn_id=turn_id,
                    retryable=True,
                )

    async def _stream_tts(
        self,
        session_id: str,
        turn_id: str,
        text: str,
        *,
        seq: int = 0,
        send_stream_end: bool = True,
        tts_first_at: Optional[list] = None,
    ) -> int:
        """Stream TTS audio for one text segment; return the next seq.

        Each chunk is sent as a JSON metadata frame followed by a binary PCM
        s16le frame. With ``send_stream_end=False`` no ``is_final`` frame is
        sent, so several segments of the same turn can play back to back.
        If ``tts_first_at`` is given and its slot is still ``None``, the time
        of the first audio chunk is stored there.

        On failure a retryable ``TTS_FAILED`` error is sent and the seq reached
        so far is returned; the exception does not propagate.
        """
        if not (text and text.strip()):
            return seq

        total_bytes = 0
        local_seq = seq
        logger.info(f"[TTS] 合成片段: text='{text[:50]}...'")
        try:
            # NOTE(review): synthesize() is iterated synchronously on the event
            # loop; if it blocks on I/O between chunks, other sessions stall —
            # confirm against TTSServiceInterface.
            for audio_chunk in self.tts_service.synthesize(
                text.strip(),
                sample_rate=settings.TTS_SAMPLE_RATE,
            ):
                pcm_data = tts_chunk_to_pcm_s16le(audio_chunk)
                total_bytes += len(pcm_data)
                if tts_first_at is not None and tts_first_at[0] is None:
                    tts_first_at[0] = time.time()
                if local_seq == seq:
                    # First chunk of this segment: log its format for diagnostics.
                    logger.info(
                        f"[TTS] 首块音频: shape={audio_chunk.shape}, "
                        f"dtype={audio_chunk.dtype}, "
                        f"max={np.max(np.abs(audio_chunk)):.4f}, "
                        f"bytes={len(pcm_data)}"
                    )
                chunk_meta = TTSAudioChunk(
                    data=pcm_data,
                    turn_id=turn_id,
                    seq=local_seq,
                    is_final=False,
                )
                await session_manager.send_json(
                    session_id,
                    chunk_meta.to_metadata_dict(),
                )
                await session_manager.send_binary(session_id, pcm_data)
                local_seq += 1

            logger.info(
                f"[TTS] 片段完成: seq {seq}-{local_seq - 1}, "
                f"{total_bytes} bytes"
            )
            if send_stream_end:
                await self._tts_finalize(session_id, turn_id, local_seq)
            return local_seq
        except Exception as e:
            logger.error(f"TTS 流式发送失败: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.TTS_FAILED,
                message=f"TTS 失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )
            return local_seq

    async def _tts_finalize(
        self,
        session_id: str,
        turn_id: str,
        seq: int,
    ) -> None:
        """Send the turn-level TTS end frame (empty payload, ``is_final=True``) at ``seq``."""
        if seq == 0:
            logger.warning("[TTS] 未生成任何音频数据")
        end_meta = TTSAudioChunk(
            data=b"",
            turn_id=turn_id,
            seq=seq,
            is_final=True,
        )
        await session_manager.send_json(
            session_id,
            end_meta.to_metadata_dict(),
        )
        logger.debug(f"[TTS] 流结束 seq={seq}")

    @staticmethod
    async def _send_error(session_id: str, **fields) -> None:
        """Send one error frame built by ``create_error_message(**fields)``.

        Keyword arguments are passed through unchanged, so fields the caller
        omits (e.g. ``turn_id``) stay omitted.
        """
        await session_manager.send_json(session_id, create_error_message(**fields))

    @staticmethod
    async def _stop_recognition(recognition, context: str) -> None:
        """Best-effort ``recognition.stop()`` in a worker thread.

        Failures are only logged as a warning ``"{context}: {e}"``.
        """
        try:
            await asyncio.to_thread(recognition.stop)
        except Exception as e:
            logger.warning(f"{context}: {e}")

    @staticmethod
    def _clip_tts_text(text: str) -> str:
        """Clip ``text`` to ``settings.TTS_MAX_CHARS`` characters, appending "..." when cut."""
        limit = settings.TTS_MAX_CHARS
        if len(text) > limit:
            return text[:limit] + "..."
        return text

    @staticmethod
    def _flight_tts_line(dialog_protocol: str, flight_intent: Optional[dict]) -> str:
        """Build the spoken announcement for a flight-control intent.

        Under dialog_result v1 the intent's ``summary`` is spoken (with a
        confirm/cancel prompt when ``settings.FLIGHT_CONFIRM_REQUIRED``);
        otherwise a fixed phrase is used.
        """
        if dialog_protocol == DIALOG_RESULT_PROTOCOL_V1:
            summary = (
                (flight_intent.get("summary") if flight_intent else "") or ""
            ).strip() or "飞控指令"
            if settings.FLIGHT_CONFIRM_REQUIRED:
                tts_line = f"我将执行:{summary}。请回复确认或取消。"
            else:
                tts_line = f"{summary}。"
            logger.info(f"[TTS] 飞控(dialog v1),播报: {tts_line[:80]}…")
        else:
            tts_line = "识别到飞控指令,正在下发指令"
            logger.info(f"[TTS] 飞控指令,使用固定播报文案: {tts_line}")
        return tts_line