2026-04-14 10:08:41 +08:00

131 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
协议验证模块
"""
from typing import Tuple, Optional
from app.protocols.models import *
from app.config import settings
def validate_session_start(data: dict) -> Tuple[bool, Optional[SessionStartMessage], Optional[str]]:
    """
    Validate a ``session.start`` message.

    Checks run in this order: protocol version, transport profile, message
    construction, and finally the bearer token (only when
    ``settings.BEARER_TOKEN`` is set to a truthy value).

    Args:
        data: Raw message payload.

    Returns:
        ``(is_valid, parsed_message, error_message)``. On success the error
        message is ``None``; on failure the parsed message is ``None``.
    """
    import hmac  # function-scope import: only needed for the token comparison

    try:
        # Required-field checks happen before the message model is built.
        proto_version = data.get("proto_version")
        if proto_version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {proto_version}"

        profile = data.get("transport_profile")
        if profile not in (TRANSPORT_PROFILE, PCM_ASR_TRANSPORT_PROFILE):
            return False, None, f"不支持的传输配置: {profile}"

        # Parse the message; constructor failures are reported by the
        # handler at the bottom of this function.
        msg = SessionStartMessage(**data)

        # Authentication check, skipped when no bearer token is configured.
        expected = settings.BEARER_TOKEN
        if expected:
            supplied = msg.auth_token
            if isinstance(supplied, str) and isinstance(expected, str):
                # Compare in constant time so response timing does not leak
                # how much of the token matched. Encoding to bytes lets
                # compare_digest accept non-ASCII tokens; "surrogatepass"
                # keeps lone surrogates from raising during the encode.
                token_mismatch = not hmac.compare_digest(
                    supplied.encode("utf-8", "surrogatepass"),
                    expected.encode("utf-8", "surrogatepass"),
                )
            else:
                # Non-string values (e.g. a missing token) keep the plain
                # inequality check.
                token_mismatch = supplied != expected
            if token_mismatch:
                return False, None, "鉴权失败: token 无效"

        return True, msg, None
    except Exception as e:
        return False, None, f"消息格式错误: {str(e)}"
def validate_turn_text(data: dict) -> Tuple[bool, Optional[TurnTextMessage], Optional[str]]:
    """
    Validate a ``turn.text`` message.

    The protocol version is checked first, then the message is built and
    its ``text`` must contain something other than whitespace.

    Returns:
        ``(is_valid, parsed_message, error_message)``
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"

        parsed = TurnTextMessage(**data)
        # Empty and whitespace-only text are both rejected.
        if not (parsed.text and parsed.text.strip()):
            return False, None, "文本不能为空"

        return True, parsed, None
    except Exception as exc:
        return False, None, f"消息格式错误: {str(exc)}"
def validate_tts_synthesize(
    data: dict,
) -> Tuple[bool, Optional[TtsSynthesizeMessage], Optional[str]]:
    """
    Validate a ``tts.synthesize`` message (associated with a turn through the
    shared turn_id / tts_audio_chunk).

    Returns:
        ``(is_valid, parsed_message, error_message)``
    """
    try:
        received_version = data.get("proto_version")
        if received_version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {received_version}"

        request = TtsSynthesizeMessage(**data)
        text = request.text
        # Nothing to synthesize when the text is empty or only whitespace.
        if not text or not request.text.strip():
            return False, None, "文本不能为空"

        return True, request, None
    except Exception as err:
        return False, None, f"消息格式错误: {str(err)}"
def validate_not_audio_profile(data: dict) -> Tuple[bool, Optional[str]]:
    """
    Detect legacy audio message types that are forbidden under text_uplink
    (when pcm_asr_uplink was not negotiated, the business layer handles it).

    Returns:
        ``(True, ErrorCode.INVALID_MESSAGE)`` when ``data["type"]`` is
        ``turn.audio_chunk`` or ``turn.audio_end``; ``(False, None)`` otherwise.
    """
    # A tuple (not a set) so unhashable "type" values still compare cleanly.
    legacy_audio_types = ("turn.audio_chunk", "turn.audio_end")
    if data.get("type") not in legacy_audio_types:
        return False, None
    return True, ErrorCode.INVALID_MESSAGE
def validate_turn_audio_start(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioStartMessage], Optional[str]]:
    """
    Validate a turn-audio start message.

    The payload must carry the supported protocol version and the
    pcm_asr_uplink transport profile before the message model is built.

    Returns:
        ``(is_valid, parsed_message, error_message)``
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"

        uses_pcm_profile = data.get("transport_profile") == PCM_ASR_TRANSPORT_PROFILE
        if not uses_pcm_profile:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"

        return True, TurnAudioStartMessage(**data), None
    except Exception as exc:
        return False, None, f"消息格式错误: {str(exc)}"
def validate_turn_audio_chunk(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioChunkMessage], Optional[str]]:
    """
    Validate a turn-audio chunk message.

    Checks, in order: protocol version, pcm_asr_uplink transport profile,
    message construction, and a non-blank ``pcm_base64`` payload.

    Returns:
        ``(is_valid, parsed_message, error_message)``
    """
    try:
        received_version = data.get("proto_version")
        if received_version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {received_version}"

        profile = data.get("transport_profile")
        if profile != PCM_ASR_TRANSPORT_PROFILE:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"

        chunk = TurnAudioChunkMessage(**data)
        # Only presence is checked here; the payload is not decoded.
        if not (chunk.pcm_base64 and chunk.pcm_base64.strip()):
            return False, None, "pcm_base64 不能为空"

        return True, chunk, None
    except Exception as err:
        return False, None, f"消息格式错误: {str(err)}"
def validate_turn_audio_end(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioEndMessage], Optional[str]]:
    """
    Validate a turn-audio end message.

    Requires the supported protocol version and the pcm_asr_uplink
    transport profile, then builds the message model from the payload.

    Returns:
        ``(is_valid, parsed_message, error_message)``
    """
    try:
        proto = data.get("proto_version")
        if proto != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {proto}"

        if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"

        end_message = TurnAudioEndMessage(**data)
    except Exception as failure:
        return False, None, f"消息格式错误: {str(failure)}"
    else:
        return True, end_message, None