""" WebSocket 会话管理 """ from __future__ import annotations import asyncio import time from typing import Dict, Optional, List from dataclasses import dataclass, field from loguru import logger from fastapi import WebSocket from app.protocols.models import * from app.config import settings @dataclass class SessionContext: """会话上下文 - 维护单个 session 的状态""" session_id: str websocket: WebSocket device_id: str = "" client_info: Optional[ClientInfo] = None # 多轮对话历史 chat_history: List[Dict[str, str]] = field(default_factory=list) turn_count: int = 0 # 状态标记 is_ready: bool = False last_activity: float = field(default_factory=time.time) transport_profile: str = TRANSPORT_PROFILE # Fun-ASR 一轮识别(仅 pcm_asr_uplink);存在时不应再 start 另一轮 active_fun_asr: Optional[Any] = None # turn.text 与 tts.synthesize 互斥,避免交错 TTS / LLM pipeline_lock: asyncio.Lock = field(default_factory=asyncio.Lock) # client.protocol.dialog_result === cloud_voice_dialog_v1 时走 dialog_result v1(见 CLOUD_VOICE_DIALOG_v1.md) dialog_protocol: str = "" def update_activity(self): """更新最后活动时间""" self.last_activity = time.time() def add_to_history(self, role: str, content: str): """添加消息到历史""" self.chat_history.append({"role": role, "content": content}) # 限制历史长度 max_len = settings.LLM_CONTEXT_TURNS * 2 if len(self.chat_history) > max_len: self.chat_history = self.chat_history[-max_len:] def clear_history(self): """清空历史""" self.chat_history.clear() self.turn_count = 0 class SessionManager: """会话管理器 - 管理所有活跃的 WebSocket 连接""" def __init__(self): self._sessions: Dict[str, SessionContext] = {} self._lock = asyncio.Lock() async def create_session( self, session_id: str, websocket: WebSocket, ) -> SessionContext: """ 创建新会话 Args: session_id: 会话 ID websocket: WebSocket 连接 Returns: 会话上下文 """ async with self._lock: # 检查并发数限制 if len(self._sessions) >= settings.MAX_CONCURRENT_SESSIONS: raise Exception( f"达到最大并发会话数限制: {settings.MAX_CONCURRENT_SESSIONS}" ) # 检查是否已存在 if session_id in self._sessions: logger.warning(f"会话已存在: {session_id}, 将替换") await self.close_session(session_id) # 创建新会话 ctx = SessionContext( session_id=session_id, websocket=websocket, ) self._sessions[session_id] = ctx logger.info(f"创建会话: {session_id[:8]}..., 当前活跃: {len(self._sessions)}") return ctx async def close_session(self, session_id: str): """ 关闭会话 Args: session_id: 会话 ID """ async with self._lock: if session_id in self._sessions: ctx = self._sessions.pop(session_id) try: await ctx.websocket.close() except: pass logger.info(f"关闭会话: {session_id[:8]}...") async def get_session(self, session_id: str) -> Optional[SessionContext]: """获取会话上下文""" return self._sessions.get(session_id) async def send_json(self, session_id: str, data: dict): """ 发送 JSON 消息到客户端 Args: session_id: 会话 ID data: 消息数据 """ ctx = await self.get_session(session_id) if not ctx: logger.warning(f"会话不存在: {session_id}") return try: await ctx.websocket.send_json(data) logger.debug(f"[WS 发送] {data.get('type')} -> {session_id[:8]}") except Exception as e: logger.error(f"发送 JSON 失败: {e}") await self.close_session(session_id) async def send_binary(self, session_id: str, data: bytes): """ 发送二进制数据到客户端 Args: session_id: 会话 ID data: 二进制数据 """ ctx = await self.get_session(session_id) if not ctx: return try: await ctx.websocket.send_bytes(data) except Exception as e: logger.error(f"发送二进制数据失败: {e}") await self.close_session(session_id) def active_count(self) -> int: """获取活跃会话数""" return len(self._sessions) # 全局会话管理器实例 session_manager = SessionManager()