300 lines
11 KiB
Python
300 lines
11 KiB
Python
"""
|
||
FastAPI 主应用入口
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
from contextlib import asynccontextmanager
|
||
from typing import Optional
|
||
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
|
||
from loguru import logger
|
||
from starlette.websockets import WebSocketState
|
||
|
||
from app.config import settings
|
||
from app.websocket.session import session_manager
|
||
from app.websocket.handler import MessageHandler
|
||
from app.services.llm_service import LLMServiceInterface
|
||
from app.services.tts_service import TTSServiceInterface
|
||
from app.utils.logger import setup_logger
|
||
|
||
|
||
# Process-wide service singletons. They stay ``None`` until ``lifespan`` assigns
# them during application startup (it declares them ``global``); the WebSocket
# endpoint reads ``message_handler`` to dispatch incoming messages.
message_handler: Optional[MessageHandler] = None
llm_service: Optional[LLMServiceInterface] = None
tts_service: Optional[TTSServiceInterface] = None
|
||
|
||
|
||
async def _shutdown_services(services: list) -> None:
    """Await ``shutdown()`` on every service in *services*, in order.

    A failure in one service is logged with its traceback and does not prevent
    the remaining services from being shut down.
    """
    for service in services:
        try:
            await service.shutdown()
        except Exception:
            logger.exception(f"服务关闭失败: {type(service).__name__}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialize services on startup, release them on shutdown.

    Startup order: logger setup -> LLM service -> DashScope Fun-ASR WebSocket URL
    -> TTS service -> ``MessageHandler``. The module-level ``llm_service``,
    ``tts_service`` and ``message_handler`` globals are assigned along the way.

    If any startup step fails, the services whose ``initialize()`` already
    succeeded are shut down before the exception propagates, so a failed boot
    does not leak them. On shutdown every started service is shut down even if
    an earlier one raises.

    Raises:
        ValueError: ``settings.LLM_PROVIDER`` or ``settings.TTS_PROVIDER`` is not supported.
        RuntimeError: a service's ``initialize()`` returned a falsy value.
    """
    global llm_service, tts_service, message_handler

    # ========== Startup ==========
    logger.info("=" * 60)
    logger.info("云端无人机语音服务启动中...")
    logger.info("=" * 60)

    setup_logger()

    # Services whose initialize() succeeded, in start-up order; exactly these are shut down.
    started: list = []
    try:
        logger.info(f"初始化 LLM 服务: {settings.LLM_PROVIDER}")
        if settings.LLM_PROVIDER == "dashscope":
            from app.providers.dashscope_llm import DashScopeLLMService

            llm_service = DashScopeLLMService()
        else:
            raise ValueError(f"不支持的 LLM 提供者: {settings.LLM_PROVIDER}")

        if not await llm_service.initialize():
            logger.error("LLM 服务初始化失败")
            raise RuntimeError("LLM 服务初始化失败")
        started.append(llm_service)

        # Fun-ASR shares DASHSCOPE_API_KEY with the LLM; real-time speech recognition
        # goes through a separate WebSocket endpoint.
        import dashscope

        dashscope.base_websocket_api_url = settings.DASHSCOPE_WEBSOCKET_URL
        logger.info(f"DashScope Fun-ASR WebSocket: {settings.DASHSCOPE_WEBSOCKET_URL}")

        logger.info(f"初始化 TTS 服务: {settings.TTS_PROVIDER}")
        if settings.TTS_PROVIDER == "kokoro":
            from app.providers.kokoro_tts import KokoroTTSService

            tts_service = KokoroTTSService()
        elif settings.TTS_PROVIDER == "piper":
            from app.providers.piper_tts import PiperTTSService

            tts_service = PiperTTSService()
        else:
            raise ValueError(f"不支持的 TTS 提供者: {settings.TTS_PROVIDER}")

        if not await tts_service.initialize():
            logger.error("TTS 服务初始化失败")
            raise RuntimeError("TTS 服务初始化失败")
        started.append(tts_service)

        message_handler = MessageHandler(
            llm_service=llm_service,
            tts_service=tts_service,
        )
    except BaseException:
        # Code after `yield` never runs when startup fails, so release anything that
        # already started here (e.g. the LLM client when TTS init fails). BaseException
        # also covers cancellation during startup; the error is always re-raised.
        await _shutdown_services(started)
        raise

    logger.info("所有服务初始化完成")
    logger.info(f"监听地址: ws://{settings.WS_HOST}:{settings.WS_PORT}{settings.WS_PATH}")
    logger.info(f"最大并发会话数: {settings.MAX_CONCURRENT_SESSIONS}")

    try:
        yield
    finally:
        # ========== Shutdown ==========
        logger.info("服务关闭中...")
        await _shutdown_services(started)
        logger.info("服务已关闭")
|
||
|
||
|
||
# The ASGI application. Service start-up and shutdown are delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="云端无人机语音服务",
    description="Cloud Voice Protocol v1.0 (text_uplink + pcm_asr_uplink)",
)
|
||
|
||
# 配置日志系统 - 确保 loguru 和 uvicorn 兼容
|
||
import logging
|
||
import sys
|
||
from loguru import logger
|
||
|
||
# Bridge stdlib `logging` records (e.g. uvicorn's loggers) into loguru.
class InterceptHandler(logging.Handler):
    """``logging.Handler`` that re-emits every stdlib log record through loguru.

    The record's level name is mapped to the loguru level of the same name when
    one exists; otherwise the numeric level is passed through. The loguru call is
    issued with a stack depth that points at the code that originally called
    ``logging``, so loguru's ``{name}:{function}:{line}`` fields show the real
    caller instead of this handler or the logging internals.
    """

    def emit(self, record: logging.LogRecord) -> None:
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            # No loguru level with this name (custom stdlib level): use the number.
            level = record.levelno

        # Walk up from this frame, skipping frames that belong to the `logging`
        # module (and frozen importlib bootstrap frames), to the original caller.
        # Starting from our own frame avoids depending on `logging.currentframe()`,
        # whose frame offset differs between Python versions, and the `None` check
        # stops safely if the stack runs out.
        frame, depth = sys._getframe(), 0
        while frame is not None:
            filename = frame.f_code.co_filename
            is_logging = filename == logging.__file__
            is_frozen = "importlib" in filename and "_bootstrap" in filename
            if depth > 0 and not (is_logging or is_frozen):
                break
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )
|
||
|
||
def _intercept_uvicorn_loggers() -> None:
    """Make uvicorn's stdlib loggers forward everything to loguru exactly once.

    Each of the "uvicorn", "uvicorn.access" and "uvicorn.error" loggers gets a
    single ``InterceptHandler`` as its only handler. Propagation is switched off
    because "uvicorn.error" is a child of "uvicorn" and (with uvicorn's default
    logging config) propagates to it; since both now carry an ``InterceptHandler``,
    every "uvicorn.error" record would otherwise reach loguru twice.
    """
    for name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
        std_logger = logging.getLogger(name)
        std_logger.handlers = [InterceptHandler()]
        std_logger.propagate = False


_intercept_uvicorn_loggers()

# Replace loguru's existing sinks with a single colorized stderr sink at DEBUG level.
logger.remove()
logger.add(
    sys.stderr,
    level="DEBUG",
    format=(
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
        "<level>{message}</level>"
    ),
    colorize=True,
)
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Returns a JSON object with a constant ``status`` of "ok", the value of
    ``session_manager.active_count()``, and the configured LLM/TTS providers,
    DashScope ASR model and ASR sample rate taken from ``settings``.
    """
    active = session_manager.active_count()
    return dict(
        status="ok",
        active_sessions=active,
        llm_provider=settings.LLM_PROVIDER,
        tts_provider=settings.TTS_PROVIDER,
        dashscope_asr_model=settings.DASHSCOPE_ASR_MODEL,
        asr_sample_rate_hz=settings.ASR_AUDIO_SAMPLE_RATE,
    )
|
||
|
||
|
||
def _short_id(session_id: Optional[str]) -> str:
    """Return the first 8 characters of *session_id* for log lines, or ``'unknown'``."""
    return session_id[:8] if session_id else "unknown"


def _error_payload(message: str) -> dict:
    """Build a non-retryable ``INVALID_MESSAGE`` error frame (proto 1.0, text_uplink)."""
    return {
        "type": "error",
        "proto_version": "1.0",
        "transport_profile": "text_uplink",
        "code": "INVALID_MESSAGE",
        "message": message,
        "retryable": False,
    }


async def _send_error(websocket: WebSocket, session_id: Optional[str], message: str) -> None:
    """Send an ``INVALID_MESSAGE`` error frame to the client.

    Once a session exists the frame is sent through ``session_manager.send_json``;
    before ``session.start`` it is written to the socket directly.
    """
    payload = _error_payload(message)
    if session_id:
        await session_manager.send_json(session_id, payload)
    else:
        await websocket.send_json(payload)


@app.websocket(settings.WS_PATH)
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint handling voice-dialogue requests.

    Path: ``settings.WS_PATH`` (/v1/voice/session).

    - Only text frames carrying a JSON object are processed. Binary frames,
      invalid JSON and non-object JSON get an ``INVALID_MESSAGE`` error reply and
      are skipped.
    - ``session.start`` must arrive first with a non-empty string ``session_id``;
      it registers the session via ``session_manager.create_session``.
    - Every accepted message (including ``session.start``) is then passed to
      ``message_handler.handle_message``.
    - On exit, the current session (if any) is closed via
      ``session_manager.close_session``.
    """
    await websocket.accept()

    # The connection is accepted first; session_id only becomes known once the
    # client sends session.start.
    session_id: Optional[str] = None

    try:
        while True:
            # close_session (e.g. after a TTS send failure) may already have closed the
            # socket; don't call receive() on a disconnected socket.
            if websocket.client_state != WebSocketState.CONNECTED:
                logger.info(f"WebSocket 已不可用,结束接收循环: {_short_id(session_id)}")
                break

            # Must use receive(): if the client mistakenly sends binary PCM, receive_text()
            # raises KeyError and drops the connection (the drone-side websocket-client may
            # then see recv() return an empty str).
            incoming = await websocket.receive()
            if incoming["type"] == "websocket.disconnect":
                break
            if incoming["type"] != "websocket.receive":
                continue

            # ASGI frames carry either bytes or text. Test for *non-empty* bytes rather than
            # key presence, because some servers include "bytes": None on text frames.
            payload_bytes = incoming.get("bytes")
            payload_text = incoming.get("text")
            if payload_bytes:
                n = len(payload_bytes)
                logger.warning(
                    f"收到二进制上行帧 ({n} bytes):本服务只接受文本 JSON。"
                    f"音频请用 turn.audio.chunk.pcm_base64,勿直接发 WebSocket binary。"
                )
                await _send_error(
                    websocket,
                    session_id,
                    "上行请使用文本帧 JSON:turn.audio.chunk 内字段 pcm_base64(UTF-8 文本帧);"
                    "勿对原始 PCM 发送 WebSocket binary,否则服务端会断开。",
                )
                continue

            if payload_text is None:
                # loguru formats with str.format-style "{}" placeholders, not "%s".
                logger.debug(
                    "websocket.receive 无 text 且无有效 bytes,跳过: keys={}",
                    list(incoming.keys()),
                )
                continue

            raw_data = payload_text.strip()
            if not raw_data:
                logger.debug("收到空文本帧,忽略")
                continue

            try:
                data = json.loads(raw_data)
            except json.JSONDecodeError:
                logger.warning("收到无效的 JSON 数据")
                # Reply even before session.start, like the binary-frame branch does.
                await _send_error(websocket, session_id, "无效的 JSON 格式")
                continue

            # Valid JSON that is not an object (array, number, string, null) would make
            # data.get() raise AttributeError and tear down the whole connection.
            if not isinstance(data, dict):
                logger.warning("收到非 JSON 对象的消息")
                await _send_error(websocket, session_id, "消息必须是 JSON 对象")
                continue

            if data.get("type") == "session.start":
                new_session_id = data.get("session_id")
                if not isinstance(new_session_id, str) or not new_session_id:
                    # Keep any already-established session_id so `finally` still closes it.
                    await _send_error(websocket, session_id, "缺少 session_id")
                    continue
                # NOTE(review): a repeated session.start with a *different* id replaces
                # session_id without closing the earlier session here; whether
                # create_session/close_session handle that is not visible from this file.
                session_id = new_session_id
                await session_manager.create_session(session_id, websocket)

            # Anything other than session.start is rejected until a session exists.
            if not session_id:
                await _send_error(websocket, None, "请先发送 session.start")
                continue

            if message_handler:
                await message_handler.handle_message(session_id, data)
            else:
                logger.error("消息处理器未初始化")

    except WebSocketDisconnect:
        logger.info(f"WebSocket 断开连接: {_short_id(session_id)}")
    except RuntimeError as e:
        # Starlette raises RuntimeError (not WebSocketDisconnect) when receive/send is
        # called on a socket that is closed or not in the CONNECTED state.
        if "WebSocket is not connected" in str(e):
            logger.info(f"WebSocket 已关闭: {_short_id(session_id)}")
        else:
            # loguru has no `exc_info` option; logger.exception logs at ERROR level
            # with the traceback attached.
            logger.exception(f"WebSocket 处理异常: {e}")
    except Exception as e:
        logger.exception(f"WebSocket 处理异常: {e}")
    finally:
        if session_id:
            await session_manager.close_session(session_id)
|