2026-04-14 10:08:41 +08:00

481 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
协议模型定义 - Cloud Voice Protocol v1.0 (text_uplink)
所有消息都遵循协议文档 F 节规范
"""
from __future__ import annotations
from pydantic import BaseModel, Field, ConfigDict, model_validator
from typing import Optional, List, Dict, Any
import uuid
from app.config import settings
from enum import Enum
from datetime import datetime
# ==================== Protocol constants ====================
# Protocol version stamped into the ``proto_version`` field of every message.
PROTO_VERSION: str = "1.0"
# Default transport profile: the client uploads already-recognized text.
TRANSPORT_PROFILE: str = "text_uplink"
# Transport profile where the client streams PCM audio for cloud-side ASR.
PCM_ASR_TRANSPORT_PROFILE: str = "pcm_asr_uplink"
# ``client.protocol.dialog_result`` value that opts a client into the
# cloud_voice_dialog_v1 ``dialog_result`` shape (DialogResultCloudV1).
DIALOG_RESULT_PROTOCOL_V1: str = "cloud_voice_dialog_v1"
# ==================== Enumerations ====================
class RoutingType(str, Enum):
    """Routing decision for a dialog turn (the ``routing`` field of dialog_result)."""
    # The turn produced a flight-control intent.
    FLIGHT_INTENT = "flight_intent"
    # Conversational turn; DialogResultCloudV1 forbids flight_intent/confirm here.
    CHITCHAT = "chitchat"
    # Error routing; DialogResultCloudV1 rejects this value.
    ERROR = "error"
class FlightActionType(str, Enum):
    """Flight-control action types (consistent with FLIGHT_INTENT_SCHEMA v1 §3).

    NOTE(review): FlightIntentAction.type is declared as a plain ``str``, so
    these values are not enforced by that model.
    """
    TAKEOFF = "takeoff"
    LAND = "land"
    RETURN_HOME = "return_home"
    HOVER = "hover"
    HOLD = "hold"
    GOTO = "goto"
    WAIT = "wait"
class SourceType(str, Enum):
    """Origin of the user text carried by a turn (``TurnTextMessage.source``)."""
    # On-device speech-to-text; the default for TurnTextMessage.source.
    DEVICE_STT = "device_stt"
    # Text typed through a debug keyboard.
    DEBUG_KEYBOARD = "debug_keyboard"
    # Presumably plain text input with no speech recognition involved — confirm.
    TEXT_ONLY = "text_only"
    # Recognized by cloud Fun-ASR (presumably the pcm_asr_uplink path — confirm).
    CLOUD_FUN_ASR = "cloud_fun_asr"
class ErrorCode(str, Enum):
    """Error codes carried in the ``code`` field of an ``error`` message (ErrorMessage)."""
    # Each value is spelled exactly like its member name.
    UNAUTHORIZED = "UNAUTHORIZED"
    INVALID_MESSAGE = "INVALID_MESSAGE"
    LLM_FAILED = "LLM_FAILED"
    LLM_TIMEOUT = "LLM_TIMEOUT"
    TTS_FAILED = "TTS_FAILED"
    ASR_FAILED = "ASR_FAILED"
    RATE_LIMIT = "RATE_LIMIT"
    INTERNAL = "INTERNAL"
# ==================== Client -> server messages ====================
class ClientCapabilities(BaseModel):
    """Client playback capabilities, sent as ``session.start.client.capabilities``."""
    # Playback sample rate in Hz (per the field name).
    playback_sample_rate_hz: int = 24000
    # Codec the client prefers for TTS audio.
    prefer_tts_codec: str = "pcm_s16le"
class ClientProtocolDecl(BaseModel):
    """Value of ``session.start.client.protocol`` — see CLOUD_VOICE_DIALOG_v1.md §8."""
    # Unknown keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")
    # Requested dialog_result shape; "cloud_voice_dialog_v1" selects the v1 shape.
    # Presumably None means the legacy shape — confirm against the session handler.
    dialog_result: Optional[str] = None
class Px4VehicleProfile(BaseModel):
    """
    Onboard PX4 context: reported by the client in session.start so the cloud
    LLM can disambiguate against the real airframe.
    Every field is optional; when absent, the model treats the vehicle as
    ``unknown`` and conservatively states its assumptions in ``summary``.
    """
    # Coarse classification aligned with PX4 airframe / MAV_TYPE
    # (the client may send either one or both).
    vehicle_class: str = Field(
        default="unknown",
        description=(
            "multicopter | fixed_wing | vtol_standard | vtol_tailsitter | "
            "rover | boat | submarine | other | unknown"
        ),
    )
    # MAVLink HEARTBEAT.mav_type (0-255), an enum shared with ArduPilot.
    # Not range-checked here.
    mav_type: Optional[int] = Field(
        default=None,
        description="MAVLink HEARTBEAT.mav_type0255与 ArduPilot 共用枚举值",
    )
    # PX4 firmware version, e.g. 1.14.3.
    px4_version: Optional[str] = Field(
        default=None, description="PX4 固件版本,如 1.14.3"
    )
    # PX4 airframe id / airframe name shown in QGC; useful for logs.
    airframe_id: Optional[str] = Field(
        default=None, description="PX4 机型 id / QGC 显示的 airframe 名,便于日志"
    )
    # Control and mission capabilities (recommended when the vehicle side knows them).
    # Default frame for relative displacements: local_ned | body_ned, etc.
    default_setpoint_frame: str = Field(
        default="local_ned",
        description="默认相对位移 framelocal_ned | body_ned 等",
    )
    # Whether the companion computer can reliably enter and hold Offboard.
    offboard_capable: bool = Field(
        default=False, description="Companion 能否稳定进入并维持 Offboard"
    )
    # Whether Mission / waypoint tasks can be executed.
    mission_capable: bool = Field(default=True, description="是否可执行 Mission / 航点任务")
    # Whether RTL / return-home is available (including Home parameters configured).
    rtl_available: bool = Field(default=True, description="RTL / 返航是否可用(含 Home 参数已配置)")
    # Whether Home has been recorded by the flight controller (affects RTL semantics).
    home_position_valid: bool = Field(
        default=False, description="Home 是否已记入飞控(影响 RTL 语义)"
    )
    # Current NAV_STATE / human-readable mode name
    # (a string mappable from MAVROS / px4_msgs).
    current_nav_state: Optional[str] = Field(
        default=None, description="如 POSITION_MODE、OFFBOARD、MISSION、AUTO.LAND"
    )
    # Operating envelope: a spoken "a bit higher" can be scaled against this
    # default altitude (unit: meters, AGL). Without it, only the user's words count.
    cruise_alt_m_agl: Optional[float] = Field(
        default=None, description="典型巡航相对高度_m缺省时仅靠用户口述"
    )
    # Extension key/values (e.g. estimator, GPS status), passed verbatim into
    # the LLM context. Each instance gets its own dict.
    extras: Dict[str, Any] = Field(
        default_factory=dict,
        description="扩展键值(如 estimator、GPS 状态),原样进入 LLM 上下文",
    )
class ClientInfo(BaseModel):
    """Client description carried as ``session.start.client``."""
    # Required device identifier.
    device_id: str
    locale: str = "zh-CN"
    # A fresh ClientCapabilities with defaults is built for each ClientInfo.
    capabilities: ClientCapabilities = Field(default_factory=ClientCapabilities)
    # Optional PX4 vehicle/capability context; vehicle_class is strongly recommended.
    px4: Optional[Px4VehicleProfile] = Field(
        default=None, description="PX4 载具与能力上下文;强烈建议填写 vehicle_class"
    )
    # Optional declaration of the dialog_result shape;
    # dialog_result == "cloud_voice_dialog_v1" enables §3 (the v1 shape).
    protocol: Optional[ClientProtocolDecl] = Field(
        default=None,
        description='声明 dialog_result 形状dialog_result=="cloud_voice_dialog_v1" 启用 §3',
    )
class SessionStartMessage(BaseModel):
    """session.start — sent by the client to open a session."""
    type: str = "session.start"
    proto_version: str = PROTO_VERSION
    # Required; expected values: text_uplink | pcm_asr_uplink
    transport_profile: str  # text_uplink | pcm_asr_uplink
    session_id: str
    # Optional authentication token.
    auth_token: Optional[str] = None
    client: ClientInfo
class TurnTextMessage(BaseModel):
    """turn.text — the client sends the text of one user turn."""
    type: str = "turn.text"
    proto_version: str = PROTO_VERSION
    # Kept for older clients; if it disagrees with the session's profile, the session wins.
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    text: str
    # Presumably False for interim (non-final) recognition text — confirm.
    is_final: bool = True
    source: SourceType = SourceType.DEVICE_STT
class TurnAudioStartMessage(BaseModel):
    """turn.audio.start — begins one turn of cloud-ASR audio uplink (pcm_asr_uplink)."""
    type: str = "turn.audio.start"
    proto_version: str = PROTO_VERSION
    transport_profile: str = PCM_ASR_TRANSPORT_PROFILE
    turn_id: str
    # Must match what Fun-ASR is configured for; defaults to 16000.
    sample_rate_hz: int = Field(default=16000, description="须与 Fun-ASR 约定一致,默认 16000")
    # Little-endian mono int16 PCM.
    format: str = Field(default="pcm_s16le", description="小端 mono int16")
class TurnAudioChunkMessage(BaseModel):
    """turn.audio.chunk — one base64-encoded PCM fragment of the current turn."""
    type: str = "turn.audio.chunk"
    proto_version: str = PROTO_VERSION
    transport_profile: str = PCM_ASR_TRANSPORT_PROFILE
    turn_id: str
    # Required: base64 of the raw PCM s16le byte sequence.
    pcm_base64: str = Field(..., description="raw PCM s16le 字节序列的 base64")
class TurnAudioEndMessage(BaseModel):
    """turn.audio.end — the client has finished streaming microphone audio for
    this turn. Recognition is finalized and the result is queued for the LLM."""
    type: str = "turn.audio.end"
    proto_version: str = PROTO_VERSION
    transport_profile: str = PCM_ASR_TRANSPORT_PROFILE
    turn_id: str
class TtsSynthesizeMessage(BaseModel):
    """tts.synthesize — TTS playback only: no LLM, no dialog_result, no history."""
    type: str = "tts.synthesize"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    # Text to synthesize.
    text: str
class SessionEndMessage(BaseModel):
    """session.end — sent by the client to close a session."""
    type: str = "session.end"
    proto_version: str = PROTO_VERSION
    session_id: str
# ==================== Server -> client messages ====================
class ServerCapabilities(BaseModel):
    """Server capabilities, sent as ``session.ready.server_caps``."""
    accepts_audio_uplink: bool = False
    llm: bool = True
    # TTS codecs the server can produce.
    # NOTE(review): list-literal default; this relies on pydantic copying
    # mutable defaults per instance. Field(default_factory=...) would make it
    # explicit — confirm against the pinned pydantic version.
    tts_codecs: List[str] = ["pcm_s16le"]
    # Presumably how many prior turns are kept as LLM context — confirm.
    llm_context_turns: int = 4
    # Whether the server consumes ``client.px4`` (Px4VehicleProfile).
    accepts_px4_vehicle_profile: bool = True
class SessionReadyMessage(BaseModel):
    """session.ready — the server's acknowledgement of session.start."""
    type: str = "session.ready"
    proto_version: str = PROTO_VERSION
    transport_profile: str  # same value as in session.start
    session_id: str
    server_caps: ServerCapabilities = Field(default_factory=ServerCapabilities)
class FlightIntentAction(BaseModel):
    """One flight action; each item may carry only ``type`` + ``args`` (matches Schema v1)."""
    # Any key other than type/args is rejected.
    model_config = ConfigDict(extra="forbid")
    # Action name; plain str, so FlightActionType values are not enforced here.
    type: str
    # Action-specific arguments; empty dict by default.
    args: Dict[str, Any] = Field(default_factory=dict)
class FlightIntentPayload(BaseModel):
    """Flight-intent payload, aligned with FLIGHT_INTENT_SCHEMA_v1.md."""
    # Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")
    is_flight_intent: bool = True
    version: int = 1
    actions: List[FlightIntentAction]
    # Required; create_dialog_result_cloud_v1 also copies it into confirm.summary_for_user.
    summary: str
    # Optional trace id, at most 128 characters.
    trace_id: Optional[str] = Field(default=None, max_length=128)
class UserInput(BaseModel):
    """User input object embedded in the legacy dialog_result (``user_input``)."""
    text: str
    language: str = "zh"
    is_final: bool = True
    # Free-form string (not SourceType); the default equals SourceType.DEVICE_STT's value.
    source: str = "device_stt"
class TTSHint(BaseModel):
    """TTS hint attached to the legacy dialog_result."""
    # Whether the summary / chat reply should be spoken.
    speak_summary_or_reply: bool = True
    voice_id: str = "default"
class FlightConfirmSpec(BaseModel):
    """Confirmation spec, mandatory when routing=flight_intent (CLOUD_VOICE_DIALOG_v1.md §3.4)."""
    # Unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")
    # Required. true: the first turn must not execute; a confirmation window is needed.
    required: bool = Field(description="true首轮禁止执行需确认窗")
    # Confirmation window length in seconds, validated to lie in [1, 600].
    timeout_sec: float = Field(ge=1, le=600)
    # Phrases that confirm / cancel; each list must contain at least one entry.
    confirm_phrases: List[str] = Field(min_length=1)
    cancel_phrases: List[str] = Field(min_length=1)
    # Non-empty id of the pending action awaiting confirmation.
    pending_id: str = Field(min_length=1)
    # Optional user-facing summary of the pending action.
    summary_for_user: Optional[str] = None
class DialogResultCloudV1(BaseModel):
    """dialog_result in the cloud_voice_dialog_v1 shape (CLOUD_VOICE_DIALOG_v1.md).

    Compared with the legacy DialogResultMessage:

    - ``user_input`` is a plain string.
    - ``protocol`` names the shape.
    - Flight intents must carry a ``confirm`` spec.

    Only CHITCHAT and FLIGHT_INTENT routing are accepted.
    """

    model_config = ConfigDict(extra="forbid")

    type: str = "dialog_result"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    protocol: str = DIALOG_RESULT_PROTOCOL_V1
    user_input: str
    routing: RoutingType
    flight_intent: Optional[FlightIntentPayload] = None
    confirm: Optional[FlightConfirmSpec] = None
    chat_reply: Optional[str] = None

    @model_validator(mode="after")
    def _routing_shape(self):
        """Check that flight_intent/confirm presence matches ``routing``.

        FLIGHT_INTENT requires both, CHITCHAT must have neither, and every
        other routing value is rejected.
        """
        carries_intent = self.flight_intent is not None
        carries_confirm = self.confirm is not None

        if self.routing == RoutingType.FLIGHT_INTENT:
            if not (carries_intent and carries_confirm):
                raise ValueError("flight_intent 路由必须同时携带 flight_intent 与 confirm")
            return self

        if self.routing == RoutingType.CHITCHAT:
            if carries_intent or carries_confirm:
                raise ValueError("chitchat 不得携带 flight_intent / confirm")
            return self

        raise ValueError("DialogResultCloudV1 仅支持 chitchat / flight_intent")
class DialogResultMessage(BaseModel):
    """Legacy dialog_result for older clients: no ``protocol`` / ``confirm``;
    ``user_input`` is a nested UserInput object."""
    type: str = "dialog_result"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    user_input: UserInput
    routing: RoutingType
    # NOTE(review): unlike DialogResultCloudV1, no validator ties the presence
    # of flight_intent to ``routing`` in this legacy shape.
    flight_intent: Optional[FlightIntentPayload] = None
    chat_reply: Optional[str] = None
    tts_hint: TTSHint = Field(default_factory=TTSHint)
class TTSAudioChunkMessage(BaseModel):
    """tts_audio_chunk — metadata for one TTS audio chunk (the audio bytes
    themselves are not a field of this model)."""
    type: str = "tts_audio_chunk"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    # Presumably the chunk's sequence number within the turn — confirm.
    seq: int = 0
    codec: str = "pcm_s16le"
    sample_rate_hz: int = 24000
    # Presumably True on the turn's last chunk — confirm.
    is_final: bool = False
class Metrics(BaseModel):
    """Performance metrics reported in turn.complete; both default to None."""
    # LLM duration in milliseconds (per the field name).
    llm_ms: Optional[int] = None
    # Time to the first TTS byte in milliseconds (per the field name).
    tts_first_byte_ms: Optional[int] = None
class TurnCompleteMessage(BaseModel):
    """turn.complete — the server signals that a turn has finished."""
    type: str = "turn.complete"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    turn_id: str
    # A fresh, all-None Metrics is used when none is supplied.
    metrics: Metrics = Field(default_factory=Metrics)
class ErrorMessage(BaseModel):
    """error — an error message sent from the server to the client."""
    type: str = "error"
    proto_version: str = PROTO_VERSION
    transport_profile: str = TRANSPORT_PROFILE
    # Turn the error belongs to; presumably None for session-level errors — confirm.
    turn_id: Optional[str] = None
    code: ErrorCode
    message: str
    # Whether the client may retry.
    retryable: bool = False
# ==================== Helper functions ====================
def create_error_message(
    code: ErrorCode,
    message: str,
    turn_id: Optional[str] = None,
    retryable: bool = False,
    *,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build a serialized ``error`` message.

    Args:
        code: Error code reported to the client.
        message: Human-readable error description.
        turn_id: Turn the error belongs to; dropped from the output when None.
        retryable: Whether the client may retry the failed operation.
        transport_profile: Transport profile stamped on the message. When None
            (the default), TRANSPORT_PROFILE is used as before. Pass the
            session's profile (e.g. PCM_ASR_TRANSPORT_PROFILE for an ASR_FAILED
            error in a pcm_asr_uplink session) so the error matches its session.

    Returns:
        The message as a plain dict, with ``None`` fields excluded.
    """
    msg = ErrorMessage(
        # Keep the historical default unless the caller names the session's profile.
        transport_profile=(
            TRANSPORT_PROFILE if transport_profile is None else transport_profile
        ),
        turn_id=turn_id,
        code=code,
        message=message,
        retryable=retryable,
    )
    return msg.model_dump(exclude_none=True)
def create_asr_partial(
    *,
    turn_id: str,
    text: str,
    is_final: bool,
    transport_profile: str,
) -> dict:
    """Build an ``asr.partial`` message from a cloud Fun-ASR result.

    The result may be interim or sentence-final; the device uses it to show
    live dictation or status. All arguments are keyword-only, and
    ``transport_profile`` is echoed verbatim.
    """
    return dict(
        type="asr.partial",
        proto_version=PROTO_VERSION,
        transport_profile=transport_profile,
        turn_id=turn_id,
        text=text,
        is_final=is_final,
    )
def create_llm_text_delta(
    turn_id: str,
    delta: str,
    *,
    done: bool = False,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build an ``llm.text_delta`` message carrying streamed LLM text.

    Deltas may be sent before dialog_result so the device can render a
    typing effect.

    Args:
        turn_id: Turn the delta belongs to.
        delta: Incremental text fragment.
        done: Whether this is the last delta of the stream.
        transport_profile: Transport profile stamped on the message. When None
            (the default), TRANSPORT_PROFILE is used as before. Pass the
            session's profile for pcm_asr_uplink sessions, as
            create_asr_partial already requires.

    Returns:
        The message as a plain dict.
    """
    return {
        "type": "llm.text_delta",
        "proto_version": PROTO_VERSION,
        "transport_profile": (
            TRANSPORT_PROFILE if transport_profile is None else transport_profile
        ),
        "turn_id": turn_id,
        "delta": delta,
        "done": done,
    }
def create_dialog_result(
    turn_id: str,
    user_text: str,
    routing: RoutingType,
    flight_intent: Optional[dict] = None,
    chat_reply: Optional[str] = None,
    *,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build a legacy ``dialog_result``.

    The legacy shape has no ``protocol`` or ``confirm``, and ``user_input`` is
    a nested object.

    Args:
        turn_id: Turn this result answers.
        user_text: User text. ``None`` is normalized to ``""``, matching
            create_dialog_result_cloud_v1 instead of failing UserInput validation.
        routing: Routing decision for the turn.
        flight_intent: Field dict for FlightIntentPayload. Falsy values
            (``None`` or ``{}``) mean "no flight intent".
        chat_reply: Conversational reply text, if any.
        transport_profile: Transport profile stamped on the message. When None
            (the default), TRANSPORT_PROFILE is used as before.

    Returns:
        The message as a plain dict, with ``None`` fields excluded.

    Raises:
        pydantic.ValidationError: ``flight_intent`` does not match FlightIntentPayload.
    """
    intent_payload = FlightIntentPayload(**flight_intent) if flight_intent else None
    msg = DialogResultMessage(
        transport_profile=(
            TRANSPORT_PROFILE if transport_profile is None else transport_profile
        ),
        turn_id=turn_id,
        user_input=UserInput(text=user_text or ""),
        routing=routing,
        flight_intent=intent_payload,
        chat_reply=chat_reply,
    )
    return msg.model_dump(exclude_none=True)
def create_dialog_result_cloud_v1(
    turn_id: str,
    user_text: str,
    routing: RoutingType,
    flight_intent: Optional[dict] = None,
    chat_reply: Optional[str] = None,
    *,
    confirm_required: Optional[bool] = None,
    confirm_timeout_sec: Optional[float] = None,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build a cloud_voice_dialog_v1 ``dialog_result`` (CLOUD_VOICE_DIALOG_v1.md).

    On the flight_intent route, a FlightIntentPayload and a FlightConfirmSpec
    are filled in automatically. Other routes carry neither.

    Args:
        turn_id: Turn this result answers.
        user_text: User text; ``None`` is normalized to ``""``.
        routing: Routing decision; only CHITCHAT and FLIGHT_INTENT validate.
        flight_intent: Field dict for FlightIntentPayload. Required (non-empty)
            for FLIGHT_INTENT and ignored for other routes.
        chat_reply: Conversational reply text, if any.
        confirm_required: Overrides settings.FLIGHT_CONFIRM_REQUIRED when not None.
        confirm_timeout_sec: Overrides settings.FLIGHT_CONFIRM_TIMEOUT_SEC when
            not None. Must end up within FlightConfirmSpec's [1, 600] bounds.
        transport_profile: Transport profile stamped on the message. When None
            (the default), TRANSPORT_PROFILE is used as before.

    Returns:
        The message as a plain dict, with ``None`` fields excluded.

    Raises:
        ValueError: routing is FLIGHT_INTENT but ``flight_intent`` is empty/None.
        pydantic.ValidationError: the payload, confirm spec or message fails
            validation (e.g. routing is ERROR, or the timeout is out of range).
    """
    intent_payload = None
    confirm_payload = None
    if routing == RoutingType.FLIGHT_INTENT:
        if not flight_intent:
            raise ValueError("flight_intent 路由缺少 flight_intent dict")
        intent_payload = FlightIntentPayload(**flight_intent)
        # Confirm settings are needed only on this branch; explicit args win.
        required = (
            settings.FLIGHT_CONFIRM_REQUIRED
            if confirm_required is None
            else confirm_required
        )
        timeout = (
            settings.FLIGHT_CONFIRM_TIMEOUT_SEC
            if confirm_timeout_sec is None
            else confirm_timeout_sec
        )
        # Blank summaries fall back to a generic user-facing label.
        summary = (intent_payload.summary or "").strip() or "飞控指令"
        confirm_payload = FlightConfirmSpec(
            required=required,
            timeout_sec=float(timeout),
            confirm_phrases=["确认"],
            cancel_phrases=["取消"],
            # A new random id for every pending confirmation.
            pending_id=str(uuid.uuid4()),
            summary_for_user=summary,
        )
    msg = DialogResultCloudV1(
        transport_profile=(
            TRANSPORT_PROFILE if transport_profile is None else transport_profile
        ),
        turn_id=turn_id,
        user_input=user_text or "",
        routing=routing,
        flight_intent=intent_payload,
        confirm=confirm_payload,
        chat_reply=chat_reply,
    )
    return msg.model_dump(exclude_none=True)