131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
"""
|
||
协议验证模块
|
||
"""
|
||
from typing import Tuple, Optional
|
||
from app.protocols.models import *
|
||
from app.config import settings
|
||
|
||
|
||
def validate_session_start(data: dict) -> Tuple[bool, Optional[SessionStartMessage], Optional[str]]:
    """
    Validate a ``session.start`` message.

    Checks, in order:
      1. ``proto_version`` equals ``PROTO_VERSION``;
      2. ``transport_profile`` is ``TRANSPORT_PROFILE`` or
         ``PCM_ASR_TRANSPORT_PROFILE``;
      3. the payload parses into ``SessionStartMessage``;
      4. when ``settings.BEARER_TOKEN`` is set, ``auth_token`` matches it.

    Args:
        data: Decoded message payload.

    Returns:
        ``(is_valid, parsed_message, error_message)``. On success the parsed
        message is returned with ``None`` as the error; on failure the message
        is ``None`` and the error string describes the problem. Any exception
        raised during validation is reported as a format error, not propagated.
    """
    import hmac

    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"

        profile = data.get("transport_profile")
        if profile not in (TRANSPORT_PROFILE, PCM_ASR_TRANSPORT_PROFILE):
            return False, None, f"不支持的传输配置: {profile}"

        msg = SessionStartMessage(**data)

        expected = settings.BEARER_TOKEN
        if expected:
            supplied = msg.auth_token
            # Constant-time comparison: a plain `!=` short-circuits on the first
            # differing character and leaks timing information about the secret.
            # Comparing UTF-8 bytes lets compare_digest accept non-ASCII tokens
            # (it rejects non-ASCII `str`). A missing / non-string token never
            # matches, exactly as with the previous `!=` against a str.
            # NOTE(review): assumes settings.BEARER_TOKEN is a str — confirm.
            if not isinstance(supplied, str) or not hmac.compare_digest(
                supplied.encode("utf-8"), str(expected).encode("utf-8")
            ):
                return False, None, "鉴权失败: token 无效"

        return True, msg, None

    except Exception as e:
        return False, None, f"消息格式错误: {e!s}"
|
||
|
||
|
||
def validate_turn_text(data: dict) -> Tuple[bool, Optional[TurnTextMessage], Optional[str]]:
    """
    Validate a ``turn.text`` message.

    The payload must declare the supported ``proto_version``, parse into
    ``TurnTextMessage``, and carry ``text`` that is not empty or blank.

    Returns:
        ``(is_valid, parsed_message, error_message)``. Any exception raised
        while validating is reported as a format error rather than propagated.
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"

        parsed = TurnTextMessage(**data)
        content = parsed.text
        # Reject None / "" as well as whitespace-only text.
        if not (content and content.strip()):
            return False, None, "文本不能为空"
    except Exception as exc:
        return False, None, f"消息格式错误: {exc!s}"

    return True, parsed, None
|
||
|
||
|
||
def validate_tts_synthesize(
    data: dict,
) -> Tuple[bool, Optional[TtsSynthesizeMessage], Optional[str]]:
    """
    Validate a ``tts.synthesize`` message.

    Per the original design note, this message is correlated with turns
    through the shared ``turn_id`` / ``tts_audio_chunk``. The payload must
    declare the supported ``proto_version``, parse into
    ``TtsSynthesizeMessage``, and carry non-blank ``text``.

    Returns:
        ``(is_valid, parsed_message, error_message)``. Any exception raised
        while validating is reported as a format error rather than propagated.
    """
    try:
        declared = data.get("proto_version")
        if declared != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {declared}"

        request = TtsSynthesizeMessage(**data)
        body = request.text
        has_content = bool(body and body.strip())
        if not has_content:
            return False, None, "文本不能为空"
    except Exception as err:
        return False, None, f"消息格式错误: {err!s}"

    return True, request, None
|
||
|
||
|
||
def validate_not_audio_profile(data: dict) -> Tuple[bool, Optional[str]]:
    """
    Detect legacy audio message types that are forbidden under ``text_uplink``.

    Only the legacy types ``turn.audio_chunk`` and ``turn.audio_end`` are
    matched. When ``pcm_asr_uplink`` has not been negotiated, the business
    layer handles them.

    Note the polarity: unlike the other validators in this module, a ``True``
    first element here means the message IS a forbidden audio type.

    Returns:
        ``(True, ErrorCode.INVALID_MESSAGE)`` for a forbidden type,
        otherwise ``(False, None)``.
    """
    is_legacy_audio = data.get("type") in ("turn.audio_chunk", "turn.audio_end")
    if not is_legacy_audio:
        return False, None
    return True, ErrorCode.INVALID_MESSAGE
|
||
|
||
|
||
def validate_turn_audio_start(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioStartMessage], Optional[str]]:
    """
    Validate a ``turn.audio.start`` message.

    Requires the supported ``proto_version`` and
    ``transport_profile == PCM_ASR_TRANSPORT_PROFILE``, then parses the payload
    into ``TurnAudioStartMessage``.

    Returns:
        ``(is_valid, parsed_message, error_message)``. Any exception raised
        while validating is reported as a format error rather than propagated.
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"
        if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"
        started = TurnAudioStartMessage(**data)
    except Exception as exc:
        return False, None, f"消息格式错误: {exc!s}"

    return True, started, None
|
||
|
||
|
||
def validate_turn_audio_chunk(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioChunkMessage], Optional[str]]:
    """
    Validate a ``turn.audio.chunk`` message.

    Requires the supported ``proto_version`` and
    ``transport_profile == PCM_ASR_TRANSPORT_PROFILE``. The payload must parse
    into ``TurnAudioChunkMessage`` with a non-blank ``pcm_base64`` field. Only
    emptiness is checked here; the base64 content itself is not decoded.

    Returns:
        ``(is_valid, parsed_message, error_message)``. Any exception raised
        while validating is reported as a format error rather than propagated.
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"
        if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"

        chunk = TurnAudioChunkMessage(**data)
        payload = chunk.pcm_base64
        # Reject None / "" as well as whitespace-only payloads.
        if not (payload and payload.strip()):
            return False, None, "pcm_base64 不能为空"
    except Exception as exc:
        return False, None, f"消息格式错误: {exc!s}"

    return True, chunk, None
|
||
|
||
|
||
def validate_turn_audio_end(
    data: dict,
) -> Tuple[bool, Optional[TurnAudioEndMessage], Optional[str]]:
    """
    Validate a ``turn.audio.end`` message.

    Requires the supported ``proto_version`` and
    ``transport_profile == PCM_ASR_TRANSPORT_PROFILE``, then parses the payload
    into ``TurnAudioEndMessage``.

    Returns:
        ``(is_valid, parsed_message, error_message)``. Any exception raised
        while validating is reported as a format error rather than propagated.
    """
    try:
        version = data.get("proto_version")
        if version != PROTO_VERSION:
            return False, None, f"不支持的协议版本: {version}"
        if data.get("transport_profile") != PCM_ASR_TRANSPORT_PROFILE:
            return False, None, "turn.audio.* 必须使用 transport_profile=pcm_asr_uplink"
        ended = TurnAudioEndMessage(**data)
    except Exception as exc:
        return False, None, f"消息格式错误: {exc!s}"

    return True, ended, None
|