2026-04-14 10:08:41 +08:00

872 lines
31 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
WebSocket 消息处理器
处理协议定义的所有消息类型和时序
"""
from __future__ import annotations
import asyncio
import base64
import time
from typing import Optional
import numpy as np
from loguru import logger
from app.protocols.models import *
from app.protocols.validators import *
from app.websocket.session import SessionContext, session_manager
from app.services.llm_service import LLMServiceInterface
from app.services.tts_service import TTSServiceInterface, TTSAudioChunk
from app.services.intent_service import (
parse_flight_intent_reply,
get_tts_text,
allows_incremental_tts,
take_next_speech_segment,
should_recover_failed_flight_output,
)
from app.config import settings
from app.utils.audio import tts_chunk_to_pcm_s16le
from app.providers.dashscope_fun_asr_turn import (
ActiveFunAsrTurn,
FunAsrTurnState,
build_fun_asr_recognition,
)
class MessageHandler:
    """Route and process the WebSocket messages defined by the protocol.

    The handler keeps no per-connection state of its own: everything
    session-specific lives in the ``SessionContext`` returned by
    ``session_manager.get_session``. Only the injected LLM and TTS services
    are stored on the instance.
    """

    def __init__(
        self,
        llm_service: LLMServiceInterface,
        tts_service: TTSServiceInterface,
    ):
        self.llm_service = llm_service
        self.tts_service = tts_service

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _send_error(session_id: str, **fields) -> None:
        """Send ``create_error_message(**fields)`` to *session_id*."""
        await session_manager.send_json(session_id, create_error_message(**fields))

    @staticmethod
    async def _stop_recognition(recognition, log_prefix: str) -> None:
        """Call ``recognition.stop()`` in a worker thread, best effort.

        The stop call is blocking, so it runs via ``asyncio.to_thread``.
        Failures are logged as ``"{log_prefix}: {exc}"`` and swallowed,
        because callers are already tearing the turn down.
        """
        try:
            await asyncio.to_thread(recognition.stop)
        except Exception as e:
            logger.warning(f"{log_prefix}: {e}")

    @staticmethod
    def _clip_tts_text(text: str, limit: Optional[int] = None) -> str:
        """Truncate *text* for TTS.

        Args:
            text: Text to speak. Empty or ``None`` input is returned unchanged.
            limit: Maximum number of characters kept; defaults to
                ``settings.TTS_MAX_CHARS``.

        Returns:
            *text* unchanged if it fits, otherwise its first *limit*
            characters followed by ``"..."``.
        """
        max_chars = settings.TTS_MAX_CHARS if limit is None else limit
        if text and len(text) > max_chars:
            return text[:max_chars] + "..."
        return text

    @staticmethod
    def _flight_tts_line(ctx: SessionContext, flight_intent) -> str:
        """Build the spoken announcement for a ``flight_intent`` turn.

        With dialog protocol v1, the intent's ``summary`` is spoken (wrapped in
        a confirm/cancel prompt when ``settings.FLIGHT_CONFIRM_REQUIRED``).
        Otherwise a fixed phrase is used.
        """
        if ctx.dialog_protocol == DIALOG_RESULT_PROTOCOL_V1:
            summary = (
                (flight_intent.get("summary") if flight_intent else "")
                or ""
            ).strip() or "飞控指令"
            if settings.FLIGHT_CONFIRM_REQUIRED:
                line = f"我将执行:{summary}。请回复确认或取消。"
            else:
                line = summary
            logger.info(f"[TTS] 飞控(dialog v1),播报: {line[:80]}")
        else:
            line = "识别到飞控指令,正在下发指令"
            logger.info(f"[TTS] 飞控指令,使用固定播报文案: {line}")
        return line

    @staticmethod
    async def _send_turn_complete(
        session_id: str,
        turn_id: str,
        *,
        llm_ms: int,
        turn_t0: float,
        first_audio_at: Optional[float],
        tts_elapsed: float,
    ) -> None:
        """Send ``turn.complete`` with latency metrics.

        ``tts_first_byte_ms`` is measured from *turn_t0* to the first audio
        chunk. When no audio was produced, it falls back to the whole TTS
        phase duration (*tts_elapsed*). Both timestamps are
        ``time.perf_counter()`` values.
        """
        if first_audio_at is not None:
            tts_first_byte_ms = int((first_audio_at - turn_t0) * 1000)
        else:
            tts_first_byte_ms = int(tts_elapsed * 1000)
        complete_msg = TurnCompleteMessage(
            turn_id=turn_id,
            metrics=Metrics(
                llm_ms=llm_ms,
                tts_first_byte_ms=tts_first_byte_ms,
            ),
        )
        await session_manager.send_json(
            session_id,
            complete_msg.model_dump(exclude_none=True),
        )

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    async def handle_message(self, session_id: str, data: dict):
        """Route one decoded message to the matching handler.

        Args:
            session_id: Session ID.
            data: Parsed JSON payload; ``data["type"]`` selects the handler.
        """
        msg_type = data.get("type")
        handlers = {
            "session.start": self._handle_session_start,
            "session.end": self._handle_session_end,
            "turn.audio.start": self._handle_turn_audio_start,
            "turn.audio.chunk": self._handle_turn_audio_chunk,
            "turn.audio.end": self._handle_turn_audio_end,
            "turn.text": self._handle_turn_text,
            "tts.synthesize": self._handle_tts_synthesize,
        }
        # "type" comes from client JSON and may be unhashable (list/dict);
        # only a string can name a handler.
        handler = handlers.get(msg_type) if isinstance(msg_type, str) else None
        if handler is not None:
            await handler(session_id, data)
            return

        # Legacy turn.audio_chunk / turn.audio_end frames (forbidden under
        # text_uplink) get a dedicated error code instead of "unknown type".
        is_audio, error_code = validate_not_audio_profile(data)
        if is_audio:
            await self._send_error(
                session_id,
                code=error_code,
                message=(
                    "text_uplink 不支持该音频上行格式"
                    "(请使用 session.start transport_profile=pcm_asr_uplink 与 turn.audio.*)"
                ),
                turn_id=data.get("turn_id"),
                retryable=False,
            )
        else:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=f"未知的消息类型: {msg_type}",
                turn_id=data.get("turn_id"),
                retryable=False,
            )

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    async def _handle_session_start(self, session_id: str, data: dict):
        """Handle ``session.start``: bind client info, reply ``session.ready``.

        If a Fun-ASR turn is still active (the client re-sent session.start),
        its recognizer is stopped before the context is reset.
        """
        valid, msg, error = validate_session_start(data)
        if not valid:
            logger.warning(f"session.start 验证失败: {error}")
            # A validator may fail without a reason; "鉴权" in None would raise.
            reason = error or "无效请求"
            await self._send_error(
                session_id,
                code=(
                    ErrorCode.UNAUTHORIZED
                    if "鉴权" in reason
                    else ErrorCode.INVALID_MESSAGE
                ),
                message=reason,
                retryable=False,
            )
            return

        ctx = await session_manager.get_session(session_id)
        if not ctx:
            logger.error(f"会话不存在: {session_id}")
            return

        # Do not silently drop a live recognizer: stop it first.
        previous_asr = ctx.active_fun_asr
        ctx.active_fun_asr = None
        if previous_asr is not None:
            await self._stop_recognition(
                previous_asr.recognition, "session.start 关闭 Fun-ASR 时"
            )

        client = msg.client
        ctx.device_id = client.device_id
        ctx.client_info = client
        ctx.is_ready = True
        ctx.transport_profile = msg.transport_profile
        ctx.update_activity()

        protocol = client.protocol
        ctx.dialog_protocol = (
            protocol.dialog_result.strip()
            if protocol and protocol.dialog_result
            else ""
        )

        px4_vc = client.px4.vehicle_class if client.px4 else None
        logger.info(
            f"会话已建立: {session_id[:8]}..., "
            f"device={client.device_id}, "
            f"profile={ctx.transport_profile}"
            + (f", px4_vehicle_class={px4_vc}" if px4_vc else "")
        )

        ready_msg = SessionReadyMessage(
            session_id=session_id,
            transport_profile=ctx.transport_profile,
            server_caps=ServerCapabilities(
                llm_context_turns=settings.LLM_CONTEXT_TURNS,
                accepts_audio_uplink=(
                    ctx.transport_profile == PCM_ASR_TRANSPORT_PROFILE
                ),
            ),
        )
        await session_manager.send_json(
            session_id,
            ready_msg.model_dump(exclude_none=True),
        )

    async def _handle_session_end(self, session_id: str, data: dict):
        """Handle ``session.end``: stop any active Fun-ASR turn, then close."""
        logger.info(f"会话结束: {session_id[:8]}...")
        ctx = await session_manager.get_session(session_id)
        if ctx and ctx.active_fun_asr:
            await self._stop_recognition(
                ctx.active_fun_asr.recognition, "session.end 关闭 Fun-ASR 时"
            )
            ctx.active_fun_asr = None
        await session_manager.close_session(session_id)

    # ------------------------------------------------------------------
    # pcm_asr_uplink: turn.audio.start / chunk / end
    # ------------------------------------------------------------------

    async def _handle_turn_audio_start(self, session_id: str, data: dict):
        """Under pcm_asr_uplink, start one Alibaba Cloud Fun-ASR real-time recognition turn.

        Preconditions, each answered with an ``INVALID_MESSAGE`` error:
        - the session is ready;
        - the session uses ``PCM_ASR_TRANSPORT_PROFILE``;
        - no recognition is already active;
        - the payload validates;
        - ``sample_rate_hz`` equals ``settings.ASR_AUDIO_SAMPLE_RATE``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return

        precondition_error = None
        if not ctx.is_ready:
            precondition_error = "请先发送 session.start"
        elif ctx.transport_profile != PCM_ASR_TRANSPORT_PROFILE:
            precondition_error = "当前会话不是 pcm_asr_uplink不可使用 turn.audio.*"
        elif ctx.active_fun_asr:
            precondition_error = "已有进行中的语音识别,请先 turn.audio.end"
        if precondition_error is not None:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=precondition_error,
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_audio_start(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        turn_id = msg.turn_id
        if msg.sample_rate_hz != settings.ASR_AUDIO_SAMPLE_RATE:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=(
                    f"sample_rate_hz 须为 {settings.ASR_AUDIO_SAMPLE_RATE} "
                    f"(Fun-ASR 与本服务约定)"
                ),
                turn_id=turn_id,
                retryable=False,
            )
            return

        loop = asyncio.get_running_loop()
        st = FunAsrTurnState(turn_id=turn_id)
        profile = ctx.transport_profile

        async def _partial(text: str, is_final: bool) -> None:
            # Forward non-blank recognition results as ASR partials.
            if not text.strip():
                return
            await session_manager.send_json(
                session_id,
                create_asr_partial(
                    turn_id=turn_id,
                    text=text,
                    is_final=is_final,
                    transport_profile=profile,
                ),
            )

        async def _asr_err(m: str) -> None:
            c = await session_manager.get_session(session_id)
            active = c.active_fun_asr if c else None
            # Only clear the slot if it still belongs to this turn: a late
            # error from an earlier recognizer must not kill a newer turn.
            if active is not None and active.turn_id == turn_id:
                c.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=m,
                turn_id=turn_id,
                retryable=True,
            )

        rec = build_fun_asr_recognition(
            loop=loop,
            state=st,
            on_partial=_partial,
            on_error_msg=_asr_err,
        )
        try:
            await asyncio.to_thread(rec.start)
        except Exception as e:
            logger.error(f"Fun-ASR start 失败: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"语音识别启动失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )
            return

        ctx.active_fun_asr = ActiveFunAsrTurn(
            turn_id=turn_id,
            recognition=rec,
            state=st,
        )
        logger.info(
            f"[ASR] Fun-ASR 开始 turn={turn_id[:8]} "
            f"model={settings.DASHSCOPE_ASR_MODEL} sr={settings.ASR_AUDIO_SAMPLE_RATE}"
        )

    async def _handle_turn_audio_chunk(self, session_id: str, data: dict):
        """Forward one base64 PCM chunk to the active Fun-ASR recognizer.

        Chunks for an unready session are dropped silently. If sending fails,
        the turn is abandoned and its recognizer is stopped (best effort), so
        it is not leaked.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_chunk(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start 或未开始识别",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        try:
            pcm = base64.b64decode(msg.pcm_base64)
        except Exception:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="pcm_base64 解码失败",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return
        if not pcm:
            return

        try:
            await asyncio.to_thread(active.recognition.send_audio_frame, pcm)
        except Exception as e:
            logger.error(f"Fun-ASR send_audio_frame: {e}", exc_info=True)
            ctx.active_fun_asr = None
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message=f"上行音频失败: {str(e)}",
                turn_id=msg.turn_id,
                retryable=True,
            )
            # The turn is abandoned; release its recognizer instead of leaking it.
            await self._stop_recognition(active.recognition, "Fun-ASR stop")

    async def _handle_turn_audio_end(self, session_id: str, data: dict):
        """Finish the Fun-ASR turn and feed the recognized text into the dialog pipeline.

        If the recognizer reported an error, nothing more is sent here (the
        error callback already notified the client). An empty transcript
        yields ``ASR_FAILED``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx or not ctx.is_ready:
            return

        valid, msg, error = validate_turn_audio_end(data)
        if not valid:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        active = ctx.active_fun_asr
        if not active or active.turn_id != msg.turn_id:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="无匹配的 turn.audio.start",
                turn_id=msg.turn_id,
                retryable=False,
            )
            return

        # Release the slot first so new turns are accepted while we finish.
        ctx.active_fun_asr = None
        await self._stop_recognition(active.recognition, "Fun-ASR stop")

        st = active.state
        with st.lock:
            err = st.error
            finals = list(st.final_texts)
            last = (st.last_text or "").strip()
        if err:
            return

        # Prefer concatenated final sentences; fall back to the last partial.
        user_text = "".join(finals).strip() or last
        if not user_text:
            await self._send_error(
                session_id,
                code=ErrorCode.ASR_FAILED,
                message="未识别到有效语音内容",
                turn_id=msg.turn_id,
                retryable=True,
            )
            return

        turn_msg = TurnTextMessage(
            turn_id=msg.turn_id,
            text=user_text,
            transport_profile=ctx.transport_profile,
            is_final=True,
            source=SourceType.CLOUD_FUN_ASR,
        )
        logger.info(
            f"[ASR→LLM] session={session_id[:8]} turn={msg.turn_id[:8]} "
            f"识别文本: {user_text[:80]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, turn_msg)

    # ------------------------------------------------------------------
    # Text turns
    # ------------------------------------------------------------------

    async def _handle_turn_text(self, session_id: str, data: dict):
        """Handle ``turn.text``, the core business flow.

        Sequence:
        1. Validate the message.
        2. Stream the LLM reply (``llm.text_delta``), with sentence-level
           streaming TTS while the reply is chit-chat.
        3. Parse the intent and send ``dialog_result``.
        4. Speak the remaining TTS (fixed flight-control phrase, chit-chat tail, ...).
        5. Send the TTS end frame and ``turn.complete``.
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,请先完成 turn.audio.end",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_turn_text(data)
        if not valid:
            logger.warning(f"turn.text 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        logger.info(
            f"[回合] session={session_id[:8]} turn={msg.turn_id[:8]} "
            f"收到文本: {msg.text[:50]}"
        )
        async with ctx.pipeline_lock:
            await self._run_turn_text_locked(session_id, ctx, msg)

    async def _run_turn_text_locked(
        self,
        session_id: str,
        ctx: SessionContext,
        msg: TurnTextMessage,
    ) -> None:
        """Run one complete dialog turn. The caller must hold ``ctx.pipeline_lock``.

        The turn proceeds in order:
        - stream the LLM reply (including tool round-trips when enabled);
        - parse routing/intent and record the exchange in chat history;
        - send ``dialog_result``;
        - speak any remaining text;
        - send the TTS end frame and ``turn.complete``.

        Any exception is reported to the client as ``ErrorCode.INTERNAL``.
        """
        turn_id = msg.turn_id
        user_text = msg.text
        try:
            turn_t0 = time.perf_counter()
            px4_ctx = None
            if ctx.client_info and ctx.client_info.px4:
                px4_ctx = ctx.client_info.px4.model_dump(
                    mode="json",
                    exclude_none=True,
                )
            messages = await self.llm_service.build_messages(
                user_text,
                ctx.chat_history,
                px4=px4_ctx,
                enable_tools=settings.LLM_TOOLS_ENABLED,
            )

            llm_reply_raw = ""
            tts_carry = ""  # text not yet handed to TTS
            tts_seq = 0
            llm_t0 = time.perf_counter()
            # Mutable cell so _stream_tts can record the first-audio timestamp.
            tts_first_at: list[Optional[float]] = [None]

            # 1. Stream the LLM reply; for chit-chat, speak complete sentences early.
            async for delta in self.llm_service.chat_stream_with_tools(
                messages,
                session_id=session_id,
                turn_id=turn_id,
            ):
                llm_reply_raw += delta
                await session_manager.send_json(
                    session_id,
                    create_llm_text_delta(turn_id, delta, done=False),
                )
                if not allows_incremental_tts(llm_reply_raw, user_utterance=user_text):
                    continue
                tts_carry += delta
                while True:
                    seg, tts_carry = take_next_speech_segment(tts_carry)
                    if not seg:
                        break
                    tts_seq = await self._stream_tts(
                        session_id,
                        turn_id,
                        self._clip_tts_text(seg),
                        seq=tts_seq,
                        send_stream_end=False,
                        tts_first_at=tts_first_at,
                    )
            await session_manager.send_json(
                session_id,
                create_llm_text_delta(turn_id, "", done=True),
            )
            llm_time = time.perf_counter() - llm_t0

            # 2. Intent parsing, with recovery for unparseable flight-control output.
            routing, flight_intent = parse_flight_intent_reply(llm_reply_raw)
            if should_recover_failed_flight_output(user_text, llm_reply_raw):
                logger.warning(
                    f"飞控类用户话但模型不可解析为 JSON,跳过异常 TTS; "
                    f"raw_len={len(llm_reply_raw)} preview={llm_reply_raw[:200]!r}"
                )
                routing = "chitchat"
                flight_intent = None
                llm_reply = "指令未能解析,请再说简短一些。"
            else:
                llm_reply = llm_reply_raw
            chat_reply = llm_reply if routing == "chitchat" else None
            logger.info(f"[意图] routing={routing}")

            ctx.add_to_history("user", user_text)
            ctx.add_to_history("assistant", llm_reply)
            ctx.turn_count += 1

            # 3. dialog_result, in the protocol version the client asked for.
            build_dialog_result = (
                create_dialog_result_cloud_v1
                if ctx.dialog_protocol == DIALOG_RESULT_PROTOCOL_V1
                else create_dialog_result
            )
            dialog_msg = build_dialog_result(
                turn_id=turn_id,
                user_text=user_text,
                routing=RoutingType(routing),
                flight_intent=flight_intent,
                chat_reply=chat_reply,
            )
            await session_manager.send_json(session_id, dialog_msg)

            # 4. Remaining TTS.
            tts_start = time.perf_counter()
            if routing == "flight_intent":
                tts_line = self._flight_tts_line(ctx, flight_intent)
            else:
                tts_line = tts_carry.strip()
                if not tts_line and tts_seq == 0:
                    # Nothing was spoken incrementally: speak the full reply.
                    tts_line = get_tts_text(routing, flight_intent, chat_reply)
            # _stream_tts returns tts_seq unchanged for empty/blank text.
            tts_seq = await self._stream_tts(
                session_id,
                turn_id,
                self._clip_tts_text(tts_line),
                seq=tts_seq,
                send_stream_end=False,
                tts_first_at=tts_first_at,
            )
            await self._tts_finalize(session_id, turn_id, tts_seq)
            tts_time = time.perf_counter() - tts_start

            # 5. turn.complete.
            await self._send_turn_complete(
                session_id,
                turn_id,
                llm_ms=int(llm_time * 1000),
                turn_t0=turn_t0,
                first_audio_at=tts_first_at[0],
                tts_elapsed=tts_time,
            )
            logger.info(
                f"[回合完成] turn={turn_id[:8]} "
                f"LLM={llm_time:.2f}s, TTS={tts_time:.2f}s"
            )
        except Exception as e:
            logger.error(f"处理 turn.text 异常: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.INTERNAL,
                message=f"处理失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )

    # ------------------------------------------------------------------
    # TTS
    # ------------------------------------------------------------------

    async def _handle_tts_synthesize(self, session_id: str, data: dict):
        """TTS only: no LLM call, no chat_history write, no dialog_result.

        Still serialized with dialog turns via ``ctx.pipeline_lock``. It ends
        with a TTS end frame and ``turn.complete`` (``llm_ms=0``).
        """
        ctx = await session_manager.get_session(session_id)
        if not ctx:
            return
        if not ctx.is_ready:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="请先发送 session.start",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return
        if ctx.active_fun_asr:
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message="语音识别进行中,不可插入 tts.synthesize",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        valid, msg, error = validate_tts_synthesize(data)
        if not valid:
            logger.warning(f"tts.synthesize 验证失败: {error}")
            await self._send_error(
                session_id,
                code=ErrorCode.INVALID_MESSAGE,
                message=error or "无效请求",
                turn_id=data.get("turn_id"),
                retryable=False,
            )
            return

        turn_id = msg.turn_id
        text = self._clip_tts_text(msg.text.strip())
        logger.info(
            f"[TTS-only] session={session_id[:8]} turn={turn_id[:8]} "
            f"chars={len(text)}"
        )
        async with ctx.pipeline_lock:
            try:
                turn_t0 = time.perf_counter()
                tts_first_at: list[Optional[float]] = [None]
                tts_seq = await self._stream_tts(
                    session_id,
                    turn_id,
                    text,
                    seq=0,
                    send_stream_end=False,
                    tts_first_at=tts_first_at,
                )
                await self._tts_finalize(session_id, turn_id, tts_seq)
                tts_time = time.perf_counter() - turn_t0
                await self._send_turn_complete(
                    session_id,
                    turn_id,
                    llm_ms=0,
                    turn_t0=turn_t0,
                    first_audio_at=tts_first_at[0],
                    tts_elapsed=tts_time,
                )
                logger.info(
                    f"[TTS-only 完成] turn={turn_id[:8]} TTS={tts_time:.2f}s"
                )
            except Exception as e:
                logger.error(f"处理 tts.synthesize 异常: {e}", exc_info=True)
                await self._send_error(
                    session_id,
                    code=ErrorCode.INTERNAL,
                    message=f"处理失败: {str(e)}",
                    turn_id=turn_id,
                    retryable=True,
                )

    async def _stream_tts(
        self,
        session_id: str,
        turn_id: str,
        text: str,
        *,
        seq: int = 0,
        send_stream_end: bool = True,
        tts_first_at: Optional[list[Optional[float]]] = None,
    ) -> int:
        """Stream one text segment as TTS audio.

        Each chunk is sent as a JSON metadata frame followed by a binary
        PCM s16le frame.

        Args:
            session_id: Target session.
            turn_id: Turn the audio belongs to.
            text: Text to synthesize; empty/blank text is a no-op.
            seq: Sequence number of the first chunk sent here.
            send_stream_end: When False, no ``is_final`` frame is sent for
                this segment, so several segments can play back-to-back within
                one turn.
            tts_first_at: Optional one-element cell. If it holds ``None``, the
                ``time.perf_counter()`` value of the first chunk is stored.

        Returns:
            The next sequence number. On TTS failure, a ``TTS_FAILED`` error is
            sent and the sequence reached so far is returned.
        """
        if not (text and text.strip()):
            return seq

        total_bytes = 0
        local_seq = seq
        logger.info(f"[TTS] 合成片段: text='{text[:50]}...'")
        try:
            # NOTE(review): synthesize() is iterated synchronously on the event
            # loop; if it blocks per chunk, consider offloading it to a thread.
            for audio_chunk in self.tts_service.synthesize(
                text.strip(),
                sample_rate=settings.TTS_SAMPLE_RATE,
            ):
                pcm_data = tts_chunk_to_pcm_s16le(audio_chunk)
                total_bytes += len(pcm_data)
                if tts_first_at is not None and tts_first_at[0] is None:
                    tts_first_at[0] = time.perf_counter()
                if local_seq == seq:
                    # Diagnostics only: an empty chunk must not abort TTS via np.max.
                    peak = (
                        float(np.max(np.abs(audio_chunk)))
                        if np.size(audio_chunk)
                        else 0.0
                    )
                    logger.info(
                        f"[TTS] 首块音频: shape={audio_chunk.shape}, "
                        f"dtype={audio_chunk.dtype}, "
                        f"max={peak:.4f}, "
                        f"bytes={len(pcm_data)}"
                    )
                chunk_meta = TTSAudioChunk(
                    data=pcm_data,
                    turn_id=turn_id,
                    seq=local_seq,
                    is_final=False,
                )
                await session_manager.send_json(
                    session_id,
                    chunk_meta.to_metadata_dict(),
                )
                await session_manager.send_binary(session_id, pcm_data)
                local_seq += 1

            logger.info(
                f"[TTS] 片段完成: seq {seq}-{local_seq - 1}, "
                f"{total_bytes} bytes"
            )
            if send_stream_end:
                await self._tts_finalize(session_id, turn_id, local_seq)
            return local_seq
        except Exception as e:
            logger.error(f"TTS 流式发送失败: {e}", exc_info=True)
            await self._send_error(
                session_id,
                code=ErrorCode.TTS_FAILED,
                message=f"TTS 失败: {str(e)}",
                turn_id=turn_id,
                retryable=True,
            )
            return local_seq

    async def _tts_finalize(
        self,
        session_id: str,
        turn_id: str,
        seq: int,
    ) -> None:
        """Send the turn-level TTS end frame (empty payload, ``is_final=True``).

        A warning is logged when *seq* is 0, i.e. no audio chunk was sent.
        """
        if seq == 0:
            logger.warning("[TTS] 未生成任何音频数据")
        end_meta = TTSAudioChunk(
            data=b"",
            turn_id=turn_id,
            seq=seq,
            is_final=True,
        )
        await session_manager.send_json(
            session_id,
            end_meta.to_metadata_dict(),
        )
        logger.debug(f"[TTS] 流结束 seq={seq}")