"""
Protocol model definitions - Cloud Voice Protocol v1.0 (text_uplink).

All messages follow the specification in section F of the protocol document.
"""
|
||
from __future__ import annotations
|
||
|
||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||
from typing import Optional, List, Dict, Any
|
||
import uuid
|
||
|
||
from app.config import settings
|
||
from enum import Enum
|
||
from datetime import datetime
|
||
|
||
|
||
# ==================== Protocol constants ====================
# Protocol version string; used as the default `proto_version` of the messages below.
PROTO_VERSION = "1.0"
# Transport profile name for the text-uplink flow.
TRANSPORT_PROFILE = "text_uplink"
# Transport profile name for the PCM audio uplink flow (cloud ASR).
PCM_ASR_TRANSPORT_PROFILE = "pcm_asr_uplink"
# Identifier of the cloud_voice_dialog_v1 dialog_result shape.
DIALOG_RESULT_PROTOCOL_V1 = "cloud_voice_dialog_v1"
|
||
|
||
|
||
# ==================== Enum definitions ====================
|
||
class RoutingType(str, Enum):
    """Routing type of a dialog result.

    Members are also plain ``str`` values, so they compare equal to their
    wire strings.
    """

    # Routed as a flight-control intent.
    FLIGHT_INTENT = "flight_intent"
    # Routed as ordinary conversation.
    CHITCHAT = "chitchat"
    # Routed as an error.
    ERROR = "error"
|
||
|
||
|
||
class FlightActionType(str, Enum):
    """Flight-control action types (consistent with FLIGHT_INTENT_SCHEMA v1 §3).

    Members are also plain ``str`` values equal to their wire strings.
    """

    TAKEOFF = "takeoff"
    LAND = "land"
    RETURN_HOME = "return_home"
    HOVER = "hover"
    HOLD = "hold"
    GOTO = "goto"
    WAIT = "wait"
|
||
|
||
|
||
class SourceType(str, Enum):
    """Origin of a text input.

    Members are also plain ``str`` values equal to their wire strings.
    """

    DEVICE_STT = "device_stt"
    DEBUG_KEYBOARD = "debug_keyboard"
    TEXT_ONLY = "text_only"
    CLOUD_FUN_ASR = "cloud_fun_asr"
|
||
|
||
|
||
class ErrorCode(str, Enum):
    """Error codes carried by error messages.

    Each member's value is identical to its name.
    """

    UNAUTHORIZED = "UNAUTHORIZED"
    INVALID_MESSAGE = "INVALID_MESSAGE"
    LLM_FAILED = "LLM_FAILED"
    LLM_TIMEOUT = "LLM_TIMEOUT"
    TTS_FAILED = "TTS_FAILED"
    ASR_FAILED = "ASR_FAILED"
    RATE_LIMIT = "RATE_LIMIT"
    INTERNAL = "INTERNAL"
|
||
|
||
|
||
# ==================== Client -> server messages ====================
|
||
class ClientCapabilities(BaseModel):
    """Client capabilities."""

    # Playback sample rate in Hz; defaults to 24000.
    playback_sample_rate_hz: int = Field(default=24000)
    # Preferred TTS codec name; defaults to "pcm_s16le".
    prefer_tts_codec: str = Field(default="pcm_s16le")
|
||
|
||
|
||
class ClientProtocolDecl(BaseModel):
    """session.start.client.protocol - see CLOUD_VOICE_DIALOG_v1.md §8."""

    # Keys that are not declared on this model are rejected.
    model_config = ConfigDict(extra="forbid")

    # Requested dialog_result shape; None when the client declares nothing.
    dialog_result: Optional[str] = Field(default=None)
|
||
|
||
|
||
class Px4VehicleProfile(BaseModel):
    """
    Onboard PX4 context, reported by the client in session.start so that the
    cloud LLM can disambiguate according to the real vehicle type.

    Every field is optional; when nothing is provided the model treats the
    vehicle as unknown and conservatively states its assumptions in the summary.
    """

    # Coarse classification aligned with PX4 airframe / MAV_TYPE (the client
    # may provide either one or both).
    vehicle_class: str = Field(
        default="unknown",
        description=(
            "multicopter | fixed_wing | vtol_standard | vtol_tailsitter | "
            "rover | boat | submarine | other | unknown"
        ),
    )
    mav_type: Optional[int] = Field(
        default=None,
        description="MAVLink HEARTBEAT.mav_type(0–255),与 ArduPilot 共用枚举值",
    )
    px4_version: Optional[str] = Field(
        default=None,
        description="PX4 固件版本,如 1.14.3",
    )
    airframe_id: Optional[str] = Field(
        default=None,
        description="PX4 机型 id / QGC 显示的 airframe 名,便于日志",
    )

    # Control and mission capabilities (recommended whenever the vehicle side
    # knows them).
    default_setpoint_frame: str = Field(
        default="local_ned",
        description="默认相对位移 frame:local_ned | body_ned 等",
    )
    offboard_capable: bool = Field(
        default=False,
        description="Companion 能否稳定进入并维持 Offboard",
    )
    mission_capable: bool = Field(
        default=True,
        description="是否可执行 Mission / 航点任务",
    )
    rtl_available: bool = Field(
        default=True,
        description="RTL / 返航是否可用(含 Home 参数已配置)",
    )
    home_position_valid: bool = Field(
        default=False,
        description="Home 是否已记入飞控(影响 RTL 语义)",
    )

    # Current NAV_STATE / human-readable mode name (a string that MAVROS /
    # px4_msgs can map).
    current_nav_state: Optional[str] = Field(
        default=None,
        description="如 POSITION_MODE、OFFBOARD、MISSION、AUTO.LAND",
    )

    # Operating envelope (a spoken "a bit higher" can be combined with the
    # default altitude scale; unit: metres).
    cruise_alt_m_agl: Optional[float] = Field(
        default=None,
        description="典型巡航相对高度_m,缺省时仅靠用户口述",
    )

    # Free-form extension mapping; a fresh dict is created per instance.
    extras: Dict[str, Any] = Field(
        default_factory=dict,
        description="扩展键值(如 estimator、GPS 状态),原样进入 LLM 上下文",
    )
|
||
|
||
|
||
class ClientInfo(BaseModel):
    """Client information sent inside session.start."""

    # Required device identifier.
    device_id: str
    # Locale tag; defaults to "zh-CN".
    locale: str = Field(default="zh-CN")
    # A fresh ClientCapabilities (all defaults) is built when omitted.
    capabilities: ClientCapabilities = Field(default_factory=ClientCapabilities)
    # Optional PX4 vehicle profile.
    px4: Optional[Px4VehicleProfile] = Field(
        default=None,
        description="PX4 载具与能力上下文;强烈建议填写 vehicle_class",
    )
    # Optional declaration of the dialog_result shape the client expects.
    protocol: Optional[ClientProtocolDecl] = Field(
        default=None,
        description='声明 dialog_result 形状;dialog_result=="cloud_voice_dialog_v1" 启用 §3',
    )
|
||
|
||
|
||
class SessionStartMessage(BaseModel):
    """session.start - the client opens a session."""

    type: str = Field(default="session.start")
    proto_version: str = Field(default=PROTO_VERSION)
    # Required, no default. Expected values: text_uplink | pcm_asr_uplink.
    transport_profile: str
    session_id: str
    # Optional authentication token.
    auth_token: Optional[str] = Field(default=None)
    client: ClientInfo
|
||
|
||
|
||
class TurnTextMessage(BaseModel):
    """turn.text - the client sends text."""

    type: str = Field(default="turn.text")
    proto_version: str = Field(default=PROTO_VERSION)
    # Kept for compatibility with older devices; if it differs from the
    # session's profile, the session's value takes precedence.
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    text: str
    is_final: bool = Field(default=True)
    # Where the text came from; defaults to on-device speech-to-text.
    source: SourceType = Field(default=SourceType.DEVICE_STT)
|
||
|
||
|
||
class TurnAudioStartMessage(BaseModel):
    """turn.audio.start - begin one turn of cloud ASR uplink (pcm_asr_uplink)."""

    type: str = Field(default="turn.audio.start")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=PCM_ASR_TRANSPORT_PROFILE)
    turn_id: str
    # Sample rate of the PCM stream in Hz.
    sample_rate_hz: int = Field(
        default=16000,
        description="须与 Fun-ASR 约定一致,默认 16000",
    )
    # Sample format name of the PCM stream.
    format: str = Field(
        default="pcm_s16le",
        description="小端 mono int16",
    )
|
||
|
||
|
||
class TurnAudioChunkMessage(BaseModel):
    """turn.audio.chunk - one PCM fragment (base64)."""

    type: str = Field(default="turn.audio.chunk")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=PCM_ASR_TRANSPORT_PROFILE)
    turn_id: str
    # Required: base64 text of the raw PCM bytes for this fragment.
    pcm_base64: str = Field(
        ...,
        description="raw PCM s16le 字节序列的 base64",
    )
|
||
|
||
|
||
class TurnAudioEndMessage(BaseModel):
    """turn.audio.end - microphone streaming for this turn is over.

    Per the protocol, this triggers recognition finalisation and queues the
    result for the LLM.
    """

    type: str = Field(default="turn.audio.end")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=PCM_ASR_TRANSPORT_PROFILE)
    turn_id: str
|
||
|
||
|
||
class TtsSynthesizeMessage(BaseModel):
    """tts.synthesize - TTS playback only, without LLM / dialog_result / history."""

    type: str = Field(default="tts.synthesize")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    # Text to be synthesised.
    text: str
|
||
|
||
|
||
class SessionEndMessage(BaseModel):
    """session.end - the client closes the session."""

    type: str = Field(default="session.end")
    proto_version: str = Field(default=PROTO_VERSION)
    session_id: str
|
||
|
||
|
||
# ==================== Server -> client messages ====================
|
||
class ServerCapabilities(BaseModel):
    """Server capabilities advertised to the client."""

    accepts_audio_uplink: bool = Field(default=False)
    llm: bool = Field(default=True)
    # TTS codec names; defaults to a single entry, "pcm_s16le".
    tts_codecs: List[str] = Field(default=["pcm_s16le"])
    llm_context_turns: int = Field(default=4)
    accepts_px4_vehicle_profile: bool = Field(default=True)
|
||
|
||
|
||
class SessionReadyMessage(BaseModel):
    """session.ready - the server acknowledges the session."""

    type: str = Field(default="session.ready")
    proto_version: str = Field(default=PROTO_VERSION)
    # Required; should match the value sent in session.start.
    transport_profile: str
    session_id: str
    # A fresh ServerCapabilities (all defaults) is built when omitted.
    server_caps: ServerCapabilities = Field(default_factory=ServerCapabilities)
|
||
|
||
|
||
class FlightIntentAction(BaseModel):
    """Flight-control action - each item allows only type + args (per Schema v1)."""

    # Any key other than `type` / `args` is rejected.
    model_config = ConfigDict(extra="forbid")

    # Action name, kept as a plain string.
    type: str
    # Action arguments; a fresh empty dict per instance when omitted.
    args: Dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
class FlightIntentPayload(BaseModel):
    """Flight-intent payload - aligned with FLIGHT_INTENT_SCHEMA_v1.md."""

    # Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    is_flight_intent: bool = Field(default=True)
    version: int = Field(default=1)
    # Required list of actions.
    actions: List[FlightIntentAction]
    # Required summary text.
    summary: str
    # Optional trace identifier, at most 128 characters.
    trace_id: Optional[str] = Field(default=None, max_length=128)
|
||
|
||
|
||
class UserInput(BaseModel):
    """User input information (nested object of the legacy dialog_result)."""

    text: str
    language: str = Field(default="zh")
    is_final: bool = Field(default=True)
    # Plain string here (not SourceType); defaults to "device_stt".
    source: str = Field(default="device_stt")
|
||
|
||
|
||
class TTSHint(BaseModel):
    """TTS hint."""

    speak_summary_or_reply: bool = Field(default=True)
    voice_id: str = Field(default="default")
|
||
|
||
|
||
class FlightConfirmSpec(BaseModel):
    """Required when routing=flight_intent - CLOUD_VOICE_DIALOG_v1.md §3.4."""

    # Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # Required flag (no default).
    required: bool = Field(description="true:首轮禁止执行,需确认窗")
    # Confirmation window length; must lie within [1, 600].
    timeout_sec: float = Field(ge=1, le=600)
    # Phrases for confirming / cancelling; each list needs at least one entry.
    confirm_phrases: List[str] = Field(min_length=1)
    cancel_phrases: List[str] = Field(min_length=1)
    # Non-empty identifier of the pending confirmation.
    pending_id: str = Field(min_length=1)
    # Optional user-facing summary.
    summary_for_user: Optional[str] = Field(default=None)
|
||
|
||
|
||
class DialogResultCloudV1(BaseModel):
    """dialog_result (cloud_voice_dialog_v1) - CLOUD_VOICE_DIALOG_v1.md."""

    # Undeclared keys are rejected.
    model_config = ConfigDict(extra="forbid")

    type: str = Field(default="dialog_result")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    protocol: str = Field(default=DIALOG_RESULT_PROTOCOL_V1)
    # In this shape the user input is a plain string, not a nested object.
    user_input: str
    routing: RoutingType
    flight_intent: Optional[FlightIntentPayload] = Field(default=None)
    confirm: Optional[FlightConfirmSpec] = Field(default=None)
    chat_reply: Optional[str] = Field(default=None)

    @model_validator(mode="after")
    def _routing_shape(self):
        """Check that the optional payloads match ``routing``.

        * chitchat: neither ``flight_intent`` nor ``confirm`` may be set.
        * flight_intent: both ``flight_intent`` and ``confirm`` must be set.
        * any other routing value is rejected.

        Raises:
            ValueError: when one of the rules above is violated.
        """
        has_intent = self.flight_intent is not None
        has_confirm = self.confirm is not None

        if self.routing == RoutingType.CHITCHAT:
            if has_intent or has_confirm:
                raise ValueError("chitchat 不得携带 flight_intent / confirm")
            return self

        if self.routing == RoutingType.FLIGHT_INTENT:
            if not (has_intent and has_confirm):
                raise ValueError("flight_intent 路由必须同时携带 flight_intent 与 confirm")
            return self

        raise ValueError("DialogResultCloudV1 仅支持 chitchat / flight_intent")
|
||
|
||
|
||
class DialogResultMessage(BaseModel):
    """dialog_result (legacy-device compatible).

    This shape has no protocol / confirm fields, and user_input is an object.
    """

    type: str = Field(default="dialog_result")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    user_input: UserInput
    routing: RoutingType
    flight_intent: Optional[FlightIntentPayload] = Field(default=None)
    chat_reply: Optional[str] = Field(default=None)
    # A fresh TTSHint (all defaults) is built when omitted.
    tts_hint: TTSHint = Field(default_factory=TTSHint)
|
||
|
||
|
||
class TTSAudioChunkMessage(BaseModel):
    """tts_audio_chunk - metadata of one TTS audio chunk."""

    type: str = Field(default="tts_audio_chunk")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    # Sequence number of the chunk; defaults to 0.
    seq: int = Field(default=0)
    codec: str = Field(default="pcm_s16le")
    sample_rate_hz: int = Field(default=24000)
    is_final: bool = Field(default=False)
|
||
|
||
|
||
class Metrics(BaseModel):
    """Performance metrics; both values are optional."""

    llm_ms: Optional[int] = Field(default=None)
    tts_first_byte_ms: Optional[int] = Field(default=None)
|
||
|
||
|
||
class TurnCompleteMessage(BaseModel):
    """turn.complete - the turn has finished."""

    type: str = Field(default="turn.complete")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    turn_id: str
    # A fresh Metrics (all None) is built when omitted.
    metrics: Metrics = Field(default_factory=Metrics)
|
||
|
||
|
||
class ErrorMessage(BaseModel):
    """error - error message."""

    type: str = Field(default="error")
    proto_version: str = Field(default=PROTO_VERSION)
    transport_profile: str = Field(default=TRANSPORT_PROFILE)
    # Optional: the turn the error relates to, if any.
    turn_id: Optional[str] = Field(default=None)
    code: ErrorCode
    message: str
    retryable: bool = Field(default=False)
|
||
|
||
|
||
# ==================== Helper functions ====================
|
||
def create_error_message(
    code: ErrorCode,
    message: str,
    turn_id: Optional[str] = None,
    retryable: bool = False,
    *,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build an ``error`` message as a plain dict.

    Args:
        code: Error code to report.
        message: Error description.
        turn_id: Turn the error relates to; omitted from the result when None.
        retryable: Value of the ``retryable`` flag in the message.
        transport_profile: Transport profile to stamp on the message. When
            None (the default), the ErrorMessage model default is used, which
            keeps the previous behaviour. Pass the session's profile so that
            errors raised in a pcm_asr_uplink session are not labelled with
            the text-uplink profile.

    Returns:
        The validated ErrorMessage dumped with ``exclude_none=True``.
    """
    fields = {
        "turn_id": turn_id,
        "code": code,
        "message": message,
        "retryable": retryable,
    }
    # Only override the model default when the caller supplied a profile.
    if transport_profile is not None:
        fields["transport_profile"] = transport_profile
    return ErrorMessage(**fields).model_dump(exclude_none=True)
|
||
|
||
|
||
def create_asr_partial(
    *,
    turn_id: str,
    text: str,
    is_final: bool,
    transport_profile: str,
) -> dict:
    """Build an ``asr.partial`` message as a plain dict.

    Carries an intermediate / per-sentence result from the cloud Fun-ASR so
    the device can show dictation text or status. All arguments are
    keyword-only and are copied into the result unchanged.

    Returns:
        Dict with keys type, proto_version, transport_profile, turn_id,
        text and is_final.
    """
    payload = {
        "type": "asr.partial",
        "proto_version": PROTO_VERSION,
        "transport_profile": transport_profile,
    }
    payload["turn_id"] = turn_id
    payload["text"] = text
    payload["is_final"] = is_final
    return payload
|
||
|
||
|
||
def create_llm_text_delta(
    turn_id: str,
    delta: str,
    *,
    done: bool = False,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build an ``llm.text_delta`` message as a plain dict.

    Streaming LLM incremental text; it may be sent before dialog_result so
    the device can show a typing effect.

    Args:
        turn_id: Turn the delta belongs to.
        delta: Incremental text fragment.
        done: Value of the ``done`` flag in the message.
        transport_profile: Transport profile to stamp on the message. When
            None (the default), TRANSPORT_PROFILE is used, which keeps the
            previous behaviour. Pass the session's profile for sessions that
            use another profile (e.g. pcm_asr_uplink).

    Returns:
        Dict with keys type, proto_version, transport_profile, turn_id,
        delta and done.
    """
    profile = TRANSPORT_PROFILE if transport_profile is None else transport_profile
    return {
        "type": "llm.text_delta",
        "proto_version": PROTO_VERSION,
        "transport_profile": profile,
        "turn_id": turn_id,
        "delta": delta,
        "done": done,
    }
|
||
|
||
|
||
def create_dialog_result(
    turn_id: str,
    user_text: str,
    routing: RoutingType,
    flight_intent: Optional[dict] = None,
    chat_reply: Optional[str] = None,
    *,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build a legacy dialog_result (no protocol / confirm, nested user_input).

    Args:
        turn_id: Turn the result belongs to.
        user_text: User text, wrapped into a UserInput object.
        routing: Routing type of the result.
        flight_intent: Raw flight-intent dict; validated through
            FlightIntentPayload. A None or empty dict is treated as absent.
        chat_reply: Optional chat reply text.
        transport_profile: Transport profile to stamp on the message. When
            None (the default), the DialogResultMessage model default is used,
            which keeps the previous behaviour. Pass the session's profile for
            sessions that use another profile (e.g. pcm_asr_uplink).

    Returns:
        The validated DialogResultMessage dumped with ``exclude_none=True``.
    """
    # Falsy input (None or {}) means "no flight intent", as before.
    intent_payload = FlightIntentPayload(**flight_intent) if flight_intent else None

    fields = {
        "turn_id": turn_id,
        "user_input": UserInput(text=user_text),
        "routing": routing,
        "flight_intent": intent_payload,
        "chat_reply": chat_reply,
    }
    # Only override the model default when the caller supplied a profile.
    if transport_profile is not None:
        fields["transport_profile"] = transport_profile
    return DialogResultMessage(**fields).model_dump(exclude_none=True)
|
||
|
||
|
||
def create_dialog_result_cloud_v1(
    turn_id: str,
    user_text: str,
    routing: RoutingType,
    flight_intent: Optional[dict] = None,
    chat_reply: Optional[str] = None,
    *,
    confirm_required: Optional[bool] = None,
    confirm_timeout_sec: Optional[float] = None,
    transport_profile: Optional[str] = None,
) -> dict:
    """Build a cloud_voice_dialog_v1 dialog_result - CLOUD_VOICE_DIALOG_v1.md.

    For the flight_intent routing, the FlightIntentPayload and a
    FlightConfirmSpec are filled in automatically.

    Args:
        turn_id: Turn the result belongs to.
        user_text: User text; a falsy value is sent as an empty string.
        routing: Routing type of the result.
        flight_intent: Raw flight-intent dict; only used (and then required)
            when ``routing`` is flight_intent.
        chat_reply: Optional chat reply text.
        confirm_required: Overrides settings.FLIGHT_CONFIRM_REQUIRED when not
            None.
        confirm_timeout_sec: Overrides settings.FLIGHT_CONFIRM_TIMEOUT_SEC
            when not None.
        transport_profile: Transport profile to stamp on the message. When
            None (the default), the DialogResultCloudV1 model default is used,
            which keeps the previous behaviour. Pass the session's profile for
            sessions that use another profile (e.g. pcm_asr_uplink).

    Returns:
        The validated DialogResultCloudV1 dumped with ``exclude_none=True``.

    Raises:
        ValueError: routing is flight_intent but ``flight_intent`` is None or
            empty. Model validation failures (for example a routing value the
            model does not support) surface as the validation error raised by
            pydantic.
    """
    # Resolve confirmation settings first, exactly as before.
    req = settings.FLIGHT_CONFIRM_REQUIRED if confirm_required is None else confirm_required
    timeout = (
        settings.FLIGHT_CONFIRM_TIMEOUT_SEC
        if confirm_timeout_sec is None
        else confirm_timeout_sec
    )

    intent_payload = None
    confirm_payload = None
    if routing == RoutingType.FLIGHT_INTENT:
        if not flight_intent:
            raise ValueError("flight_intent 路由缺少 flight_intent dict")
        intent_payload = FlightIntentPayload(**flight_intent)
        # Fall back to a generic label when the summary is blank.
        summary = (intent_payload.summary or "").strip() or "飞控指令"
        confirm_payload = FlightConfirmSpec(
            required=req,
            timeout_sec=float(timeout),
            confirm_phrases=["确认"],
            cancel_phrases=["取消"],
            # Fresh random id for each pending confirmation.
            pending_id=str(uuid.uuid4()),
            summary_for_user=summary,
        )

    fields = {
        "turn_id": turn_id,
        "user_input": user_text or "",
        "routing": routing,
        "flight_intent": intent_payload,
        "confirm": confirm_payload,
        "chat_reply": chat_reply,
    }
    # Only override the model default when the caller supplied a profile.
    if transport_profile is not None:
        fields["transport_profile"] = transport_profile
    return DialogResultCloudV1(**fields).model_dump(exclude_none=True)
|