2026-04-14 10:08:41 +08:00

169 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
WebSocket 会话管理
"""
from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import WebSocket
from loguru import logger

from app.protocols.models import *
from app.config import settings
@dataclass
class SessionContext:
    """Session context: the state kept for a single WebSocket session.

    Bundles the socket, client metadata, the rolling multi-turn chat
    history and the flags/locks that coordinate the voice pipeline.
    """

    session_id: str
    websocket: WebSocket
    device_id: str = ""
    client_info: Optional[ClientInfo] = None

    # Multi-turn conversation history; entries are {"role": ..., "content": ...}.
    chat_history: List[Dict[str, str]] = field(default_factory=list)
    turn_count: int = 0

    # State flags.
    is_ready: bool = False
    last_activity: float = field(default_factory=time.time)
    transport_profile: str = TRANSPORT_PROFILE

    # One Fun-ASR recognition round (pcm_asr_uplink only); while this is set,
    # another round should not be started.
    active_fun_asr: Optional[Any] = None

    # turn.text and tts.synthesize are mutually exclusive, so TTS / LLM work
    # does not interleave.
    pipeline_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    # When client.protocol.dialog_result === cloud_voice_dialog_v1, the
    # dialog_result v1 flow is used; see CLOUD_VOICE_DIALOG_v1.md.
    dialog_protocol: str = ""

    def update_activity(self):
        """Record the current time as the last-activity timestamp."""
        self.last_activity = time.time()

    def add_to_history(self, role: str, content: str):
        """Append a message to the history and trim it to the configured size.

        At most ``settings.LLM_CONTEXT_TURNS * 2`` entries are kept; the
        oldest entries are dropped first. A limit of zero (or a negative
        value, which is treated as zero) keeps no history at all.

        Args:
            role: Stored under the ``"role"`` key of the new entry.
            content: Stored under the ``"content"`` key of the new entry.
        """
        self.chat_history.append({"role": role, "content": content})

        max_len = max(0, settings.LLM_CONTEXT_TURNS * 2)
        overflow = len(self.chat_history) - max_len
        if overflow > 0:
            # Drop by count from the front. A ``[-max_len:]`` slice would keep
            # the whole list when max_len is 0, because ``[-0:]`` == ``[0:]``.
            self.chat_history = self.chat_history[overflow:]

    def clear_history(self):
        """Empty the chat history and reset the turn counter to zero."""
        self.chat_history.clear()
        self.turn_count = 0
class SessionManager:
    """Session manager: the registry of all active WebSocket connections."""

    def __init__(self):
        self._sessions: Dict[str, SessionContext] = {}
        # Serialises changes to _sessions. asyncio.Lock is NOT reentrant:
        # code that already holds it must call _close_session_locked()
        # instead of close_session().
        self._lock = asyncio.Lock()

    async def create_session(
        self,
        session_id: str,
        websocket: WebSocket,
    ) -> SessionContext:
        """Create and register a new session.

        If a session with the same ID is already registered it is closed and
        replaced. A replacement does not increase the number of sessions, so
        the concurrency limit is only enforced for new IDs.

        Args:
            session_id: Session ID.
            websocket: WebSocket connection backing the session.

        Returns:
            The newly created session context.

        Raises:
            RuntimeError: If ``settings.MAX_CONCURRENT_SESSIONS`` sessions are
                already registered and ``session_id`` is not one of them.
        """
        async with self._lock:
            replacing = session_id in self._sessions

            # Enforce the concurrency limit (replacements are exempt).
            if not replacing and len(self._sessions) >= settings.MAX_CONCURRENT_SESSIONS:
                raise RuntimeError(
                    f"达到最大并发会话数限制: {settings.MAX_CONCURRENT_SESSIONS}"
                )

            if replacing:
                logger.warning(f"会话已存在: {session_id}, 将替换")
                # The lock is already held here; calling close_session()
                # would wait on it forever.
                await self._close_session_locked(session_id)

            ctx = SessionContext(
                session_id=session_id,
                websocket=websocket,
            )
            self._sessions[session_id] = ctx
            logger.info(f"创建会话: {session_id[:8]}..., 当前活跃: {len(self._sessions)}")
            return ctx

    async def close_session(self, session_id: str):
        """Unregister a session and close its WebSocket.

        Does nothing when the session ID is not registered.

        Args:
            session_id: Session ID.
        """
        async with self._lock:
            await self._close_session_locked(session_id)

    async def _close_session_locked(self, session_id: str):
        """Remove and close a session; the caller must already hold ``_lock``."""
        ctx = self._sessions.pop(session_id, None)
        if ctx is None:
            return
        try:
            await ctx.websocket.close()
        except Exception as exc:
            # Best effort: a failed close is only logged. Cancellation and
            # other BaseExceptions are deliberately not swallowed.
            logger.debug(f"关闭 WebSocket 时忽略异常: {exc}")
        logger.info(f"关闭会话: {session_id[:8]}...")

    async def get_session(self, session_id: str) -> Optional[SessionContext]:
        """Return the context registered for ``session_id``, or ``None``."""
        return self._sessions.get(session_id)

    async def send_json(self, session_id: str, data: dict):
        """Send a JSON message to the client of a session.

        A missing session is logged as a warning and otherwise ignored. If
        sending fails, the error is logged and the session is closed.

        Args:
            session_id: Session ID.
            data: Message payload; its ``type`` entry is used for debug logging.
        """
        ctx = await self.get_session(session_id)
        if not ctx:
            logger.warning(f"会话不存在: {session_id}")
            return
        try:
            await ctx.websocket.send_json(data)
            logger.debug(f"[WS 发送] {data.get('type')} -> {session_id[:8]}")
        except Exception as e:
            logger.error(f"发送 JSON 失败: {e}")
            await self.close_session(session_id)

    async def send_binary(self, session_id: str, data: bytes):
        """Send binary data to the client of a session.

        A missing session is silently ignored. If sending fails, the error is
        logged and the session is closed.

        Args:
            session_id: Session ID.
            data: Binary payload.
        """
        ctx = await self.get_session(session_id)
        if not ctx:
            return
        try:
            await ctx.websocket.send_bytes(data)
        except Exception as e:
            logger.error(f"发送二进制数据失败: {e}")
            await self.close_session(session_id)

    def active_count(self) -> int:
        """Return the number of currently registered sessions."""
        return len(self._sessions)
# Global session manager instance, created once when this module is imported.
session_manager = SessionManager()