381 lines
11 KiB
Python
381 lines
11 KiB
Python
"""
|
||
意图识别服务 - 解析 LLM 回复,判断飞控意图或闲聊
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import math
|
||
import re
|
||
from typing import Tuple, Optional, Dict, Any
|
||
from loguru import logger
|
||
|
||
from app.config import settings
|
||
|
||
# Strong cut points for streaming TTS: a segment may end right after any of
# these characters. Both the full-width (CJK) and the ASCII forms of sentence
# punctuation are listed; the full-width forms are written as escapes so they
# cannot be silently folded to ASCII by width-normalising tools.
_SPEECH_END_CHARS = frozenset(
    "\u3002\uff01\uff1f\uff1b"  # ideographic full stop, full-width ! ? ;
    "!?;"  # ASCII equivalents
    "\n\r"
)
# Weak cut points, used when the text grows long without any sentence end:
# full-width comma, ASCII comma, ideographic enumeration comma.
_SPEECH_WEAK_CHARS = frozenset("\uff0c,\u3001")

# Mirrors FLIGHT_INTENT_SCHEMA_v1.md §3.7 (keep in sync with the docs).
_FLIGHT_INTENT_TOP_LEVEL = frozenset(
    {"is_flight_intent", "version", "actions", "summary", "trace_id"}
)
_ACTION_TYPES = frozenset(
    {"takeoff", "land", "return_home", "hover", "hold", "goto", "wait"}
)
_GOTO_FRAMES = frozenset({"local_ned", "body_ned"})
# Upper bounds applied by the per-action validators.
_REL_ALT_MAX_M = 500.0
_GOTO_DISPLACEMENT_MAX_M = 10_000.0
_WAIT_SECONDS_MAX = 3600.0
|
||
|
||
|
||
def _is_json_number(v: Any) -> bool:
    """Return True when *v* is a finite JSON number.

    ``bool`` is rejected even though it subclasses ``int``. Non-finite floats
    are rejected as well: ``json.loads`` accepts the literals ``NaN`` and
    ``Infinity`` by default, and NaN makes every range comparison evaluate to
    False, so an upper-bound check written as ``value > limit`` would let it
    through.
    """
    if isinstance(v, bool):
        return False
    if isinstance(v, int):
        # Python ints are always finite; do not route them through float(),
        # which raises OverflowError for very large values.
        return True
    return isinstance(v, float) and math.isfinite(v)
|
||
|
||
|
||
def parse_flight_intent_reply(raw: str) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Classify an LLM reply as a flight-control intent or as chitchat.

    Args:
        raw: The raw reply text produced by the LLM.

    Returns:
        A ``(routing, intent)`` pair. ``routing`` is ``"flight_intent"`` when
        a JSON object could be extracted, decoded, carries
        ``is_flight_intent`` set to ``True`` and passes
        ``_validate_flight_intent``; ``intent`` is then the decoded dict.
        In every other case the result is ``("chitchat", None)``.
    """
    candidate = _extract_json(raw.strip())
    if not candidate:
        return "chitchat", None

    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError as e:
        # Not decodable: treat the reply as ordinary conversation.
        logger.debug(f"JSON 解析失败: {e}")
        return "chitchat", None

    if not isinstance(decoded, dict):
        return "chitchat", None
    if decoded.get("is_flight_intent") is not True:
        return "chitchat", None
    if not _validate_flight_intent(decoded):
        return "chitchat", None

    return "flight_intent", decoded
|
||
|
||
|
||
def user_text_suggests_flight_control(text: str) -> bool:
    """Tell whether the user's utterance looks like a flight-control command.

    Used to disable streaming chitchat TTS so that the complete model output
    is parsed first. Matching is plain substring search: a handful of ASCII
    terms are matched case-insensitively, the Chinese keywords are matched
    as-is. Empty, blank or ``None`` input yields ``False``.
    """
    utterance = (text or "").strip()
    if not utterance:
        return False

    # ASCII technical terms, compared against the lower-cased utterance.
    lowered = utterance.lower()
    for term in ("px4", "mavros", "offboard", "rtl", "mission"):
        if term in lowered:
            return True

    chinese_cues = (
        "返航",
        "起飞",
        "降落",
        "悬停",
        "航线",
        "航点",
        "高度",
        "速度",
        "前进",
        "后退",
        "往前",
        "往后",
        "向左",
        "向右",
        "上升",
        "下降",
        "升高",
        "降低",
        "定点",
        "盘旋",
        "米",
        "飞",
    )
    for cue in chinese_cues:
        if cue in utterance:
            return True
    return False
|
||
|
||
|
||
def _assistant_stream_unsafe_for_tts(buf: str) -> bool:
    """Tell whether streamed model output is clearly not speakable chitchat.

    Flags text that starts with ``//`` and text that contains no character in
    U+4E00..U+9FFF while being dominated by slashes, digits and whitespace.
    Intended to keep the TTS engine from reading garbage aloud and wasting
    inference time.

    Buffers whose stripped length is under 10 are never flagged.
    """
    if not buf:
        return False
    if len(buf.strip()) < 10:
        return False

    body = buf.lstrip()
    if body.startswith("//"):
        return True

    contains_cjk = any("\u4e00" <= ch <= "\u9fff" for ch in body)
    if contains_cjk or len(body) < 12:
        return False

    # Share of "noise" characters in the left-stripped text.
    noise_alphabet = "/0123456789\n\r\t "
    noise_count = len([ch for ch in body if ch in noise_alphabet])
    return noise_count / len(body) > 0.48
|
||
|
||
|
||
def allows_incremental_tts(
    assistant_buffer: str,
    *,
    user_utterance: str = "",
) -> bool:
    """Decide whether sentence-by-sentence TTS may run while output streams.

    Incremental TTS is refused when any of the following holds:

    - the assistant output starts with ``{`` (flight-intent JSON; avoids
      reading half a JSON document aloud);
    - the user's utterance looks like a flight-control command (wait for the
      whole output so stray numbers are not spoken as chitchat);
    - the streamed content looks like an abnormal token stream.
    """
    starts_like_json = assistant_buffer.lstrip().startswith("{")
    blocked = (
        starts_like_json
        or user_text_suggests_flight_control(user_utterance)
        or _assistant_stream_unsafe_for_tts(assistant_buffer)
    )
    return not blocked
|
||
|
||
|
||
def should_recover_failed_flight_output(user_text: str, llm_reply: str) -> bool:
    """Detect a failed flight-control reply that must not be sent to TTS.

    Returns ``True`` when the user clearly asked for flight control, the
    reply contains nothing that ``_extract_json`` can pull out, and the reply
    either looks like a garbage stream or is longer than 600 characters after
    stripping. The caller is expected to speak a short notice instead of the
    model output.
    """
    if not user_text_suggests_flight_control(user_text):
        return False

    # Anything that extracts as a JSON object is left to parse_flight_intent_reply.
    if _extract_json(llm_reply):
        return False

    looks_like_garbage = _assistant_stream_unsafe_for_tts(llm_reply)
    return looks_like_garbage or len(llm_reply.strip()) > 600
|
||
|
||
|
||
def take_next_speech_segment(
    carry: str,
    min_chars: int = 2,
    soft_flush_len: int | None = None,
) -> Tuple[Optional[str], str]:
    """Take the next speakable segment off the accumulated text.

    Three passes are tried in order:

    1. cut after the first sentence-end character that yields a segment of at
       least ``min_chars`` characters (after stripping);
    2. optionally (setting ``TTS_STREAM_EARLY_WEAK_CUT``) cut early at a weak
       character near the start, to get the first audio out sooner;
    3. when the text has reached the soft flush length, cut at the first weak
       character at index 8 or later inside that window, else at the window
       end.

    Args:
        carry: Text accumulated so far.
        min_chars: Minimum stripped length for a segment from passes 1 and 2.
        soft_flush_len: Soft flush length; when ``None`` the setting
            ``TTS_STREAM_SOFT_FLUSH_LEN`` (default 40) is used.

    Returns:
        ``(segment, rest)``; ``segment`` is ``None`` and ``rest`` is ``carry``
        unchanged when nothing can be cut yet.
    """
    if not carry:
        return None, carry

    if soft_flush_len is None:
        flush_at = int(getattr(settings, "TTS_STREAM_SOFT_FLUSH_LEN", 40))
    else:
        flush_at = soft_flush_len

    # Pass 1: strong cut at sentence-end punctuation.
    for pos, char in enumerate(carry):
        if char not in _SPEECH_END_CHARS:
            continue
        head = carry[: pos + 1].strip()
        if len(head) >= min_chars:
            return head, carry[pos + 1 :]

    # Pass 2: early weak cut inside a capped scan window.
    if getattr(settings, "TTS_STREAM_EARLY_WEAK_CUT", True):
        scan_cap = int(getattr(settings, "TTS_STREAM_EARLY_WEAK_SCAN_CAP", 48))
        window = carry[: min(len(carry), scan_cap)]
        shortest = int(getattr(settings, "TTS_STREAM_EARLY_WEAK_MIN_SEGMENT", 6))
        for pos, char in enumerate(window):
            if pos + 1 < shortest or char not in _SPEECH_WEAK_CHARS:
                continue
            head = carry[: pos + 1].strip()
            if len(head) >= min_chars:
                return head, carry[pos + 1 :]

    # Pass 3: soft flush once the text is long enough.
    if len(carry) >= flush_at:
        cut_at = next(
            (
                pos + 1
                for pos, char in enumerate(carry[:flush_at])
                if pos >= 8 and char in _SPEECH_WEAK_CHARS
            ),
            flush_at,
        )
        head = carry[:cut_at].strip()
        if head:
            return head, carry[cut_at:]

    return None, carry
|
||
|
||
|
||
def _extract_json(text: str) -> Optional[str]:
    """Extract the first balanced JSON object from *text*.

    Handles:

    1. plain JSON;
    2. JSON wrapped in a Markdown code fence (only when the whole text is one
       fenced block);
    3. a JSON object embedded in surrounding text.

    Braces are counted only outside double-quoted strings, and backslash
    escapes inside strings are honoured, so a ``{`` or ``}`` inside a string
    value does not end the object early.

    Returns:
        The substring from the first ``{`` to its matching ``}``, or ``None``
        when there is no ``{`` or the object is never closed (including the
        case of an unterminated string). The result is not guaranteed to be
        valid JSON; callers still have to decode it.
    """
    # Unwrap a Markdown code fence when the text consists of exactly one.
    fenced = re.match(r"^```(?:json)?\s*\n?(.*)\n?```\s*$", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1).strip()

    start = text.find("{")
    if start < 0:
        return None

    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                # Character after a backslash is literal, whatever it is.
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]

    return None
|
||
|
||
|
||
def _validate_flight_action(atype: str, args: Any) -> bool:
    """L2/L3 checks for a single action: allowed arg keys and value ranges.

    Args:
        atype: The action type string.
        args: The action's ``args`` value; anything but a dict is rejected.

    Returns:
        ``True`` when ``args`` is acceptable for ``atype``. Unknown action
        types return ``False``.

    Numeric values must be finite JSON numbers. Range checks compare the
    value directly instead of converting through ``float()`` first, so a very
    large JSON integer is rejected rather than raising ``OverflowError``.
    """
    if not isinstance(args, dict):
        return False

    def _only_keys(d: Dict[str, Any], allowed: frozenset) -> bool:
        return frozenset(d.keys()) <= allowed

    def _finite_number(v: Any) -> bool:
        # NaN would slip through "norm > max" style checks because every
        # comparison with NaN is False; reject non-finite floats up front.
        if not _is_json_number(v):
            return False
        return not isinstance(v, float) or math.isfinite(v)

    def _positive_up_to(v: Any, upper: float) -> bool:
        return _finite_number(v) and 0 < v <= upper

    if atype == "takeoff":
        if not _only_keys(args, frozenset({"relative_altitude_m"})):
            return False
        # The altitude is optional, but when present it must be in range.
        if "relative_altitude_m" in args:
            return _positive_up_to(args["relative_altitude_m"], _REL_ALT_MAX_M)
        return True

    if atype in ("land", "return_home", "hover", "hold"):
        return args == {}

    if atype == "goto":
        if "frame" not in args:
            return False
        if not _only_keys(args, frozenset({"frame", "x", "y", "z"})):
            return False
        frame = args["frame"]
        # Check the type first: an unhashable value (list/dict) would raise
        # TypeError on the frozenset membership test.
        if not isinstance(frame, str) or frame not in _GOTO_FRAMES:
            return False

        components = []
        for axis in ("x", "y", "z"):
            value = args.get(axis)
            if value is None:
                # Absent or explicit null: the axis is simply not constrained.
                continue
            if not _finite_number(value):
                return False
            if abs(value) > _GOTO_DISPLACEMENT_MAX_M:
                # One axis alone already exceeds the limit; rejecting here
                # also keeps float() below from overflowing on huge ints.
                return False
            components.append(float(value))
        if components:
            if math.sqrt(sum(c * c for c in components)) > _GOTO_DISPLACEMENT_MAX_M:
                return False
        return True

    if atype == "wait":
        if not _only_keys(args, frozenset({"seconds"})):
            return False
        if "seconds" not in args:
            return False
        return _positive_up_to(args["seconds"], _WAIT_SECONDS_MAX)

    return False
|
||
|
||
|
||
def _validate_flight_intent(obj: Dict[str, Any]) -> bool:
    """Validate a flight_intent object (L1-L3) per FLIGHT_INTENT_SCHEMA_v1.md.

    Checks, in order: dict type, no unknown top-level keys,
    ``is_flight_intent`` is ``True``, ``version`` equals 1 (a JSON boolean is
    not accepted), ``summary`` is a non-blank string, optional ``trace_id``
    is a string of at most 128 characters, ``actions`` is a non-empty list,
    and every action is a dict with exactly the keys ``type`` and ``args``
    that passes ``_validate_flight_action``.

    Returns:
        ``True`` when every check passes. Any exception raised during
        validation is logged and reported as ``False``.
    """
    try:
        if not isinstance(obj, dict):
            return False

        if frozenset(obj.keys()) - _FLIGHT_INTENT_TOP_LEVEL:
            logger.debug("flight_intent: 存在非法顶层字段")
            return False

        if obj.get("is_flight_intent") is not True:
            return False

        version = obj.get("version")
        # bool subclasses int and True == 1, so JSON `true` must be excluded
        # explicitly before the equality test.
        if isinstance(version, bool) or version != 1:
            logger.debug("flight_intent: version 须为 1")
            return False

        summary = obj.get("summary")
        if not isinstance(summary, str) or not summary.strip():
            logger.debug("flight_intent: summary 须为非空字符串")
            return False

        tid = obj.get("trace_id")
        if tid is not None:
            if not isinstance(tid, str) or len(tid) > 128:
                logger.debug("flight_intent: trace_id 非法")
                return False

        actions = obj.get("actions")
        if not isinstance(actions, list) or len(actions) < 1:
            logger.debug("flight_intent: actions 须为非空数组")
            return False

        for i, action in enumerate(actions):
            if not isinstance(action, dict):
                return False
            if set(action.keys()) != {"type", "args"}:
                logger.debug(f"flight_intent: action[{i}] 仅允许 type、args(见 Schema §2)")
                return False
            atype = action["type"]
            if not isinstance(atype, str) or atype not in _ACTION_TYPES:
                logger.debug(f"flight_intent: 非法 action.type 索引 {i}={atype!r}")
                return False
            if not _validate_flight_action(atype, action["args"]):
                logger.debug(f"flight_intent: action[{i}] args 校验失败 type={atype}")
                return False

        return True

    except Exception as e:
        # Validation boundary: never let a malformed payload raise to callers.
        logger.error(f"飞控意图验证异常: {e}")
        return False
|
||
|
||
|
||
def get_tts_text(
    routing: str,
    flight_intent: Optional[Dict[str, Any]],
    chat_reply: Optional[str],
) -> str:
    """Pick the text that should be spoken by TTS.

    Args:
        routing: Routing type, ``"flight_intent"`` or ``"chitchat"``.
        flight_intent: The flight-intent dict, if any.
        chat_reply: The chitchat reply, if any.

    Returns:
        For a flight intent, its ``summary`` when that is a non-blank string;
        for chitchat, ``chat_reply`` when non-empty; otherwise the fixed
        acknowledgement ``"收到"``. The result is always a non-empty string,
        also when ``summary`` is present but null or blank.
    """
    acknowledgement = "收到"

    if routing == "flight_intent" and flight_intent:
        summary = flight_intent.get("summary")
        # A present-but-unusable summary must not reach TTS as None or "".
        if isinstance(summary, str) and summary.strip():
            return summary
        return acknowledgement

    if routing == "chitchat" and chat_reply:
        return chat_reply

    return acknowledgement
|