169 lines
5.1 KiB
Python
169 lines
5.1 KiB
Python
"""
WebSocket session management (WebSocket 会话管理).
"""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import WebSocket
from loguru import logger

from app.protocols.models import *
from app.config import settings


@dataclass
class SessionContext:
    """Per-session state for a single WebSocket connection."""

    session_id: str
    websocket: WebSocket
    device_id: str = ""
    client_info: Optional[ClientInfo] = None

    # Multi-turn chat history, stored as {"role": ..., "content": ...} dicts.
    chat_history: List[Dict[str, str]] = field(default_factory=list)
    turn_count: int = 0

    # Status flags
    is_ready: bool = False
    # Wall-clock timestamp (time.time()) of the most recent activity.
    last_activity: float = field(default_factory=time.time)
    transport_profile: str = TRANSPORT_PROFILE
    # The in-flight Fun-ASR recognition round (pcm_asr_uplink only); while it
    # is set, another round should not be started.
    active_fun_asr: Optional[Any] = None
    # Mutual exclusion between turn.text and tts.synthesize, so TTS / LLM work
    # never interleaves.
    pipeline_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # When client.protocol.dialog_result == "cloud_voice_dialog_v1", the
    # dialog_result v1 flow is used (see CLOUD_VOICE_DIALOG_v1.md).
    dialog_protocol: str = ""

    def update_activity(self):
        """Record the current wall-clock time as the last activity timestamp."""
        self.last_activity = time.time()

    def add_to_history(self, role: str, content: str):
        """Append a message to the chat history and trim it to the configured window.

        At most ``settings.LLM_CONTEXT_TURNS * 2`` messages are kept (presumably
        one user + one assistant message per turn); the oldest are dropped
        first. A window of 0 or less keeps no history at all.

        Args:
            role: Message role stored under the "role" key.
            content: Message text stored under the "content" key.
        """
        self.chat_history.append({"role": role, "content": content})

        # Limit history length. Clamp at 0 so a zero/negative setting empties the
        # history instead of misbehaving: the previous ``lst[-max_len:]`` slice
        # returned the *whole* list for max_len == 0 (``lst[-0:] == lst``).
        max_len = max(settings.LLM_CONTEXT_TURNS * 2, 0)
        overflow = len(self.chat_history) - max_len
        if overflow > 0:
            # Trim in place so any existing reference to the list stays valid.
            del self.chat_history[:overflow]

    def clear_history(self):
        """Drop all chat history and reset the turn counter."""
        self.chat_history.clear()
        self.turn_count = 0


class SessionManager:
    """Registry of all active WebSocket sessions, keyed by session id.

    ``_sessions`` is only mutated while ``_lock`` is held. WebSocket ``close()``
    calls are awaited *after* the lock is released, so a slow or dead peer
    cannot stall other sessions from being created or closed.
    """

    def __init__(self):
        self._sessions: Dict[str, SessionContext] = {}
        # Guards _sessions. asyncio.Lock is NOT reentrant: while holding it,
        # never await another method that acquires it again.
        self._lock = asyncio.Lock()

    async def create_session(
        self,
        session_id: str,
        websocket: WebSocket,
    ) -> SessionContext:
        """
        Create a new session, replacing any existing one with the same id.

        A replaced session is unregistered and its WebSocket is closed. The
        replacement does not count against the concurrency limit, because it
        frees the slot it takes over.

        Args:
            session_id: Session ID.
            websocket: WebSocket connection for the session.

        Returns:
            The newly registered session context.

        Raises:
            RuntimeError: If ``settings.MAX_CONCURRENT_SESSIONS`` other sessions
                are already active.
        """
        async with self._lock:
            replaced = self._sessions.get(session_id)
            # Sessions that stay registered alongside the new one.
            remaining = len(self._sessions) - (1 if replaced is not None else 0)
            if remaining >= settings.MAX_CONCURRENT_SESSIONS:
                raise RuntimeError(
                    f"达到最大并发会话数限制: {settings.MAX_CONCURRENT_SESSIONS}"
                )

            if replaced is not None:
                logger.warning(f"会话已存在: {session_id}, 将替换")

            ctx = SessionContext(session_id=session_id, websocket=websocket)
            # Overwriting the key also unregisters any replaced session.
            self._sessions[session_id] = ctx
            active = len(self._sessions)

        # Close the replaced connection outside the lock. Calling close_session()
        # while still holding the non-reentrant lock deadlocked forever.
        if replaced is not None:
            await self._close_websocket(replaced)
            logger.info(f"关闭会话: {session_id[:8]}...")

        logger.info(f"创建会话: {session_id[:8]}..., 当前活跃: {active}")
        return ctx

    async def close_session(self, session_id: str):
        """
        Close a session: unregister it and close its WebSocket.

        Unknown session ids are silently ignored.

        Args:
            session_id: Session ID.
        """
        await self._discard(session_id)

    async def _discard(
        self,
        session_id: str,
        expected: Optional[SessionContext] = None,
    ):
        """Unregister ``session_id`` and close its WebSocket.

        When ``expected`` is given, nothing happens unless ``session_id`` still
        maps to that exact context. The id may have been re-registered by a newer
        connection in the meantime, and that session must not be torn down.
        """
        async with self._lock:
            ctx = self._sessions.get(session_id)
            if ctx is None or (expected is not None and ctx is not expected):
                return
            del self._sessions[session_id]

        await self._close_websocket(ctx)
        logger.info(f"关闭会话: {session_id[:8]}...")

    @staticmethod
    async def _close_websocket(ctx: SessionContext):
        """Best-effort close of ``ctx.websocket``; ordinary errors are logged, not raised."""
        try:
            await ctx.websocket.close()
        except Exception as e:
            # The peer may already be gone. Catching Exception (not a bare except)
            # lets asyncio.CancelledError, a BaseException since 3.8, propagate.
            logger.debug(f"关闭 WebSocket 失败(已忽略): {e}")

    async def get_session(self, session_id: str) -> Optional[SessionContext]:
        """Return the session context for ``session_id``, or None if unknown."""
        return self._sessions.get(session_id)

    async def send_json(self, session_id: str, data: dict):
        """
        Send a JSON message to the client.

        If sending fails, the session the message was sent on is closed. A
        newer session registered under the same id is left untouched.

        Args:
            session_id: Session ID.
            data: Message payload; its "type" key is logged on success.
        """
        ctx = await self.get_session(session_id)
        if ctx is None:
            logger.warning(f"会话不存在: {session_id}")
            return

        try:
            await ctx.websocket.send_json(data)
        except Exception as e:
            logger.error(f"发送 JSON 失败: {e}")
            await self._discard(session_id, expected=ctx)
        else:
            logger.debug(f"[WS 发送] {data.get('type')} -> {session_id[:8]}")

    async def send_binary(self, session_id: str, data: bytes):
        """
        Send binary data to the client.

        Unknown session ids are ignored. If sending fails, the session the data
        was sent on is closed. A newer session under the same id is left untouched.

        Args:
            session_id: Session ID.
            data: Raw bytes to send.
        """
        ctx = await self.get_session(session_id)
        if ctx is None:
            return

        try:
            await ctx.websocket.send_bytes(data)
        except Exception as e:
            logger.error(f"发送二进制数据失败: {e}")
            await self._discard(session_id, expected=ctx)

    def active_count(self) -> int:
        """Return the number of currently registered sessions."""
        return len(self._sessions)


# Process-wide session manager, created once at import time and shared by every
# module that imports it.
# NOTE(review): SessionManager.__init__ creates an asyncio.Lock here, at import
# time. On Python < 3.10 that lock binds to the loop current at construction;
# confirm the server runs on that same loop if older interpreters are supported.
session_manager: SessionManager = SessionManager()