""" 协议验证模块 """ from typing import Tuple, Optional from app.protocols.models import * from app.config import settings def validate_session_start(data: dict) -> Tuple[bool, Optional[SessionStartMessage], Optional[str]]: """ 验证 session.start 消息 Returns: (是否有效, 解析后的消息, 错误信息) """ try: # 检查必需字段 if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" profile = data.get("transport_profile") if profile not in (TRANSPORT_PROFILE, PCM_ASR_TRANSPORT_PROFILE): return False, None, f"不支持的传输配置: {profile}" # 解析消息 msg = SessionStartMessage(**data) # 鉴权检查 if settings.BEARER_TOKEN and msg.auth_token != settings.BEARER_TOKEN: return False, None, "鉴权失败: token 无效" return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}" def validate_turn_text(data: dict) -> Tuple[bool, Optional[TurnTextMessage], Optional[str]]: """ 验证 turn.text 消息 Returns: (是否有效, 解析后的消息, 错误信息) """ try: if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" msg = TurnTextMessage(**data) if not msg.text or not msg.text.strip(): return False, None, "文本不能为空" return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}" def validate_tts_synthesize( data: dict, ) -> Tuple[bool, Optional[TtsSynthesizeMessage], Optional[str]]: """验证 tts.synthesize(与 turn 共用 turn_id / tts_audio_chunk 关联)。""" try: if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" msg = TtsSynthesizeMessage(**data) if not msg.text or not msg.text.strip(): return False, None, "文本不能为空" return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}" def validate_not_audio_profile(data: dict) -> Tuple[bool, Optional[str]]: """ text_uplink 下禁止的旧版音频类型(未使用 pcm_asr_uplink 协商时由业务层处理)。 """ msg_type = data.get("type") if msg_type in ["turn.audio_chunk", "turn.audio_end"]: return True, ErrorCode.INVALID_MESSAGE return False, None def validate_turn_audio_start( data: dict, ) -> Tuple[bool, Optional[TurnAudioStartMessage], Optional[str]]: try: if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE: return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink" msg = TurnAudioStartMessage(**data) return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}" def validate_turn_audio_chunk( data: dict, ) -> Tuple[bool, Optional[TurnAudioChunkMessage], Optional[str]]: try: if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE: return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink" msg = TurnAudioChunkMessage(**data) if not msg.pcm_base64 or not msg.pcm_base64.strip(): return False, None, "pcm_base64 不能为空" return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}" def validate_turn_audio_end( data: dict, ) -> Tuple[bool, Optional[TurnAudioEndMessage], Optional[str]]: try: if data.get("proto_version") != PROTO_VERSION: return False, None, f"不支持的协议版本: {data.get('proto_version')}" if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE: return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink" msg = TurnAudioEndMessage(**data) return True, msg, None except Exception as e: return False, None, f"消息格式错误: {str(e)}"