commit 157a34fe877f66674698326686de7440952fdc29 Author: LKTX Date: Tue Apr 14 09:54:26 2026 +0800 Initial commit: voice drone assistant Made-with: Cursor diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9e1f1de --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +.Python +.venv/ +venv/ +.env +*.egg-info/ +.eggs/ +dist/ +build/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# 大模型与缓存(见 models/README.txt,请单独拷贝或从原仓库同步) +models/ +cache/ + +# OS +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..7bb5043 --- /dev/null +++ b/README.md @@ -0,0 +1,94 @@ +# voice_drone_assistant + +从原仓库抽离的**独立可运行**子工程:麦克风采集 → VAD 切段 → **SenseVoice STT** → **唤醒词** →(关键词起飞 / **Qwen + Kokoro 对话播报**)。 + +**部署与外场启动(推荐先读):[docs/DEPLOYMENT_AND_OPERATIONS.md](docs/DEPLOYMENT_AND_OPERATIONS.md)** +**日常配置索引:[docs/PROJECT_GUIDE.md](docs/PROJECT_GUIDE.md)** · 云端协议:[docs/llmcon.md](docs/llmcon.md) + +## 目录结构 + +| 路径 | 说明 | +|------|------| +| `main.py` | 启动入口 | +| `with_system_alsa.sh` | Conda 下建议包一层启动,修正 ALSA/PortAudio | +| `voice_drone/core/` | 音频、VAD、STT、TTS、预处理、唤醒、配置、识别器主流程 | +| `voice_drone/main_app.py` | 唤醒流程 + LLM 流式 + 起飞脚本联动(原 `rocket_drone_audio.py`) | +| `voice_drone/config/` | `system.yaml`、`wake_word.yaml`、`keywords.yaml`、`command_.yaml` | +| `voice_drone/logging_/` | 彩色日志 | +| `voice_drone/tools/` | YAML 加载等 | +| `scripts/` | PX4 offboard、`generate_wake_greeting_wav.py` | +| `assets/tts_cache/` | 唤醒问候 WAV 缓存 | +| `models/` | **需自备或软链**,见 `models/README.txt` | + +## 环境准备 + +1. Python 3.10+(与原项目一致即可),安装依赖: + + ```bash + pip install -r requirements.txt + ``` + +2. 模型:将 STT / TTS /(可选)Silero VAD 放到 `models/`,或按 `models/README.txt` 从原仓库 `src/models` 创建符号链接。 + +3. 
大模型:默认查找 `cache/qwen25-1.5b-gguf/qwen2.5-1.5b-instruct-q4_k_m.gguf`,或通过环境变量 `ROCKET_LLM_GGUF` 指定 GGUF 路径。 + +## 运行 + +在 **`voice_drone_assistant` 根目录** 执行: + +```bash +bash with_system_alsa.sh python main.py +``` + +常用参数与环境变量与原 `rocket_drone_audio.py` 相同(如 `ROCKET_LLM_STREAM`、`ROCKET_INPUT_DEVICE_INDEX`、`--input-index`、`ROCKET_ENERGY_VAD` 等),说明见 `voice_drone/main_app.py` 文件头注释。 + +也可直接跑模块: + +```bash +bash with_system_alsa.sh python -m voice_drone.main_app +``` + +## 为什么不默认带上原仓库的 models? + +- **ONNX / GGUF 体积大**(动辄数百 MB~数 GB),放进 Git 或重复拷贝会加重仓库和同步成本。 +- 抽离时只保证 **代码与配置自给**;权重文件用 **本机拷贝 / U 盘 / 另一台预先 `bundle`** 更灵活。 + +若你本机仍摆着原仓库 `rocket_drone_audio`,且 `voice_drone_assistant` 在其子目录下,代码里有个**临时便利**:`models/...` 找不到时会尝试 **上一级 `src/models/...`**,所以在开发机上可以不改目录也能跑。 +**这只在「子目录 + 上层仍有原仓库」时有效**,把 `voice_drone_assistant` **单独拷到另一台香橙派后,上层没有原仓库,必须在本目录自备 `models/`(和可选 `cache/`)**。 + +## 拷到另一台香橙派要做什么? + +1. **整目录复制**(建议先在本机执行下面脚本打全模型,再打包 `voice_drone_assistant`): + + ```bash + cd /path/to/voice_drone_assistant + bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio + ``` + + 会把 `SenseVoiceSmall`、`Kokoro-82M-v1.1-zh-ONNX`(及存在的 `SileroVad`)复制到本目录 `models/`;可按提示选择是否复制 Qwen GGUF。 + +2. **新机器上 Python 依赖**:另一台是**全新系统**时,需要再装一次(或整体迁移同一个 conda/env): + + ```bash + cd voice_drone_assistant + pip install -r requirements.txt + ``` + + 二进制/系统库层面若仍用 conda + PortAudio,建议继续 **`bash with_system_alsa.sh python main.py`**。 + +3. 
**大模型路径**:若未打包 `cache/`,在新机器设环境变量或放入默认路径,例如: + + ```bash + export ROCKET_LLM_GGUF=/path/to/qwen2.5-1.5b-instruct-q4_k_m.gguf + ``` + +综上:**工程可独立**,但必须带上 **`models/` + 已装依赖 +(可选)GGUF**;**`pip install` 每台新环境通常要做一次**,除非你把整个 conda env 目录一起迁移。 + +## 与原仓库关系 + +- 本目录为**代码与配置的复制 + 包名调整**(`src.*` → `voice_drone.*`),默认不把大体积 `models/`、`cache/` 放进版本库。 +- 原仓库 `rocket_drone_audio` 仍可继续使用;开发阶段两者可并存,部署到单机时只带走 `voice_drone_assistant`(+ `bundle` 后的模型)即可。 + +## 未纳入本工程的模块 + +PX4 电机演示、独立录音脚本、Socket 试飞控协议服务端、ChatTTS 转换脚本等均留在原仓库,以减小篇幅;本工程仍通过 `SocketClient` 预留配置项(`TakeoffPrintRecognizer` 使用 `auto_connect_socket=False`,不依赖外置试飞控 Socket)。 diff --git a/assets/tts_cache/wake_greeting.wav b/assets/tts_cache/wake_greeting.wav new file mode 100644 index 0000000..e28f224 Binary files /dev/null and b/assets/tts_cache/wake_greeting.wav differ diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md b/docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md new file mode 100644 index 0000000..04b70e1 --- /dev/null +++ b/docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md @@ -0,0 +1,179 @@ +# 云端语音 · `dialog_result` 与飞控二次确认(v1) + +供 **云端服务** 与 **机端 voice_drone_assistant** 同步实现。**尚无线上存量**:本文即 **`dialog_result` 的飞机位约定**,服务端可按 v1 直接改结构,无需迁就旧字段。 + +--- + +## 1. 目标 + +1. **`routing=chitchat`**:只走闲聊与对应 TTS,**不**下发可执行飞控负载。 +2. **`routing=flight_intent`**:携 **`flight_intent`(v1)** + **`confirm`**;机端是否立刻执行仅由 **`confirm.required`** 决定,并支持 **确认 / 取消 / 超时** 交互。 +3. **ASR**:飞控句是否改用云端识别见 **附录 A**;与 `confirm` 独立。 + +--- + +## 2. 术语 + +| 术语 | 含义 | +|------|------| +| **首轮** | 用户说一句;本轮 WS 收到 `dialog_result` 为止。 | +| **确认窗** | `confirm.required=true` 时,机端播完本轮 PCM 后 **仅收口令** 的时段,时长 **`confirm.timeout_sec`**。 | +| **`flight_intent`** | 见 `FLIGHT_INTENT_SCHEMA_v1.md`。 | + +--- + +## 3. 
`dialog_result` 形状(云端 → 机端) + +### 3.1 公共顶层(每轮必带) + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `turn_id` | string | 是 | 与现有一致,关联本 turn。 | +| **`protocol`** | string | 是 | 固定 **`cloud_voice_dialog_v1`**,便于机端强校验、排障。 | +| `routing` | string | 是 | **`chitchat`** \| **`flight_intent`** | +| `user_input` | string | 建议 | 本回合用于生成回复的用户文本(可为云端 STT 结果)。 | + +### 3.2 `routing=chitchat` + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `chat_reply` | string | 是 | 闲聊文本(与 TTS 语义一致或由服务端定义)。 | +| `flight_intent` | — | **禁止** | 不得出现。 | +| `confirm` | — | **禁止** | 不得出现。 | + +### 3.3 `routing=flight_intent` + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `flight_intent` | object | 是 | v1:`is_flight_intent`、`version`、`actions`、`summary` 等。 | +| **`confirm`** | object | 是 | 见 §3.4;**每轮飞控必带**,机端拒收缺字段报文。 | + +### 3.4 `confirm` 对象(`routing=flight_intent` 时必填) + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| **`required`** | bool | 是 | `true`:进入确认窗,**首轮禁止**执行飞控;`false`:首轮允许按机端执行开关立即执行(调试/免确认策略)。 | +| **`timeout_sec`** | number | 是 | 确认窗秒数;建议默认 **10**。 | +| **`confirm_phrases`** | string[] | 是 | 非空;与口播一致,推荐 **`["确认"]`**。 | +| **`cancel_phrases`** | string[] | 是 | 非空;推荐 **`["取消"]`**。 | +| **`pending_id`** | string | 是 | 本轮待定意图 ID(建议 UUID);日志、可选第二轮遥测(附录 B)。 | +| **`summary_for_user`** | string | 建议 | 与口播语义一致,供日志/本地 TTS 兜底;**最终以本轮 PCM 为准**。 | + +--- + +## 4. 播报(理解与提示) + +- **TTS**:仍用 **`tts_audio_chunk` + PCM**;内容示例:复述理解 + **「请回复确认或取消」**;服务端在 `confirm_*_phrases` 中与口播保持一致(推荐 **`确认` / `取消`**)。 +- 机端 **须** 在 **本轮 PCM 播放结束**(或播放管线给出「可收听下一句」)后再进入确认窗,避免抢话。 + +--- + +## 5. 机端短语匹配(确认窗内) + +对用户 **一句** STT 规范化后,与 `confirm_phrases` / `cancel_phrases` 比对(机端实现见 `match_phrase_list`): + +1. **取消优先**:若命中 `cancel_phrases` 任一 → 取消本轮。 +2. **确认**:否则若命中 `confirm_phrases` 任一 → 执行 **`flight_intent`**。 +3. **规则要点**:**全等**(去尾标点)算命中;或对 **很短** 的句子(长度 ≤ 短语长+2)允许 **子串** 命中,以便「好的确认」类说法;**整句复述**云端长提示(如「请回复确认或取消」)不会因同时含「确认」「取消」子串而误匹配。 +4. 
**未命中**:可静候超时(v1 建议确认窗内 **可多句** 直至超时,由机端实现决定)。 +5. **超时 / 取消** 固定中文播报见下表(机端本地 TTS,降低时延): + +| 事件 | 文案 | +|------|------| +| 超时 | `未收到确认指令,请重新下发指令` | +| 取消 | `已取消指令,请重新唤醒后下发指令` | +| 确认并执行 | `开始执行飞控指令` | + +若产品强制云端音色,见 **附录 C**。 + +--- + +## 6. 机端执行条件(归纳) + +| 条件 | 行为 | +|------|------| +| `routing=chitchat` | 不执行飞控。 | +| `routing=flight_intent` 且 `confirm.required=false` 且机端已开执行开关 | 首轮校验通过后 **可立即** 执行。 | +| `routing=flight_intent` 且 `confirm.required=true` | **仅**在确认窗内命中确认短语后执行;**首轮绝不**执行。 | + +--- + +## 7. 机端状态机(摘要) + +```mermaid +stateDiagram-v2 + [*] --> Idle + Idle --> Chitchat: routing=chitchat + Idle --> ExecNow: routing=flight_intent 且 confirm.required=false + Idle --> ConfirmWin: routing=flight_intent 且 confirm.required=true + + ConfirmWin --> ExecIntent: 命中 confirm_phrases + ConfirmWin --> SayCancel: 命中 cancel_phrases + ConfirmWin --> SayTimeout: timeout_sec + + ExecNow --> Idle + ExecIntent --> Idle + SayCancel --> Idle + SayTimeout --> Idle + Chitchat --> Idle +``` + +--- + +## 8. 会话握手 + +**`session.start`**(或等价)的 `client` **须** 带: + +```json +{ + "protocol": { + "dialog_result": "cloud_voice_dialog_v1" + } +} +``` + +服务端仅对声明该协议的客户端下发 §3 结构;机端若未声明,服务端可拒绝或返回显式错误码(由服务端定义)。 + +--- + +## 9. 
安全说明 + +二次确认减轻 **错词误飞**,不替代 **急停、遥控介入、场地规范**。 +TTS 若为「请回复确认或取消」,服务端请在 `confirm_phrases` / `cancel_phrases` 中下发 **`确认`**、**`取消`**(与口播一致);**听与判均在机端**,云端无需再收一轮确认消息。 + +--- + +## 附录 A:云端 ASR(可选) + +服务端可将飞控相关 utterance 改为 **云端 STT** 结果填入 `user_input`,与 `flight_intent` 解析同源;**执行仍以 `flight_intent` + `confirm` 为准**。 + +--- + +## 附录 B:第二轮 `turn`(可选遥测) + +用户确认后机端可再发一轮文本(ASR 原文),payload 可带 `pending_id`、`phase: confirm_ack`;**执行成功与否不依赖**该轮响应。 + +--- + +## 附录 C:超时/取消走云端 TTS(可选) + +若 `confirm.play_server_tts_on_timeout` 为真(服务端与机端扩展字段),则由云端推 PCM;**易增延迟**,v1 默认 **关**,以 §5 本地播报为准。 + +--- + +## 文档关系 + +| 文档 | 关系 | +|------|------| +| `FLIGHT_INTENT_SCHEMA_v1.md` | `flight_intent` 体 | +| `DEPLOYMENT_AND_OPERATIONS.md` | 部署 | + +**版本**:`cloud_voice_dialog_v1`(本文);后续 breaking 变更递增 `cloud_voice_dialog_v2` 等。 + +--- + +## 机端实现状态(voice_drone_assistant) + +- **`CloudVoiceClient`**:`session.start.client` 已带 `protocol.dialog_result: cloud_voice_dialog_v1`;`run_turn` 返回含 `protocol`、`confirm`。 +- **`main_app.TakeoffPrintRecognizer`**:解析 `confirm`;`required=true` 且已开 `ROCKET_CLOUD_EXECUTE_FLIGHT` 时,播完本轮 PCM 后进入 **`FLIGHT_CONFIRM_LISTEN`**,本地匹配短语 / 超时文案见 **`voice_drone.core.cloud_dialog_v1`**。 +- **服务端未升级前**:若缺 `protocol` 或 `confirm`,机端 **不执行** 飞控(仍播 TTS)。 diff --git a/docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md b/docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md new file mode 100644 index 0000000..3f8c12d --- /dev/null +++ b/docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md @@ -0,0 +1,55 @@ +# PCM ASR 上行协议 v1(机端实现摘要) + +与 `CloudVoiceClient`(`voice_drone/core/cloud_voice_client.py`)及 voicellmcloud 的 `pcm_asr_uplink` **session.start.transport_profile** 对齐。 + +## 上行:仅文本 WebSocket 帧 + +**禁止**用 WebSocket **binary** 发送用户 PCM。对端(Starlette)`receive_text()` 与 `receive_bytes()` 分流;binary 上发会导致对端异常,客户端可能表现为空文本等异常。 + +用户音频只出现在 **`turn.audio.chunk` 的 JSON 字段 `pcm_base64`** 中(标准 Base64,内容为 little-endian **pcm_s16le** 原始字节)。 + +## session.start + +- `transport_profile`: **`pcm_asr_uplink`** +- 
其余与会话通用字段相同(`client`、`auth_token`、`session_id` 等)。 + +## 单轮上行(一个 `turn_id`) + +1. **文本 JSON**:`turn.audio.start` + - `type`: `"turn.audio.start"` + - `proto_version`: `"1.0"` + - `transport_profile`: `"pcm_asr_uplink"` + - `turn_id`: UUID 字符串 + - `sample_rate_hz`: 整数(机端一般为 **16000**,与采集一致) + - `codec`: `"pcm_s16le"` + - `channels`: **1** + +2. **文本 JSON**(可多条):`turn.audio.chunk` + - `type`: `"turn.audio.chunk"` + - `proto_version`、`transport_profile`、`turn_id` 与 start 一致 + - `pcm_base64`: 本段 PCM 原始字节的 Base64(不传 WebSocket binary) + + 每段原始字节长度由环境变量 **`ROCKET_CLOUD_AUDIO_CHUNK_BYTES`** 控制(默认 8192,对 **解码前** 的 PCM 字节数做钳制)。 + +3. **文本 JSON**:`turn.audio.end` + - `type`: `"turn.audio.end"` + - `proto_version`、`transport_profile`、`turn_id` 与 start 一致。 + +**并发**:同一 WebSocket 会话内,**勿**在收到上一轮的 `turn.complete` 之前再发新一轮 `turn.audio.start`。 + +## 下行(与 turn.text 同形态) + +- 可选:`asr.partial` — 机端仅日志/UI,**不参与状态机**。 +- `llm.text_delta`(可选) +- `tts_audio_chunk`(JSON)后随 **binary PCM**(TTS 下行仍可为 binary,与上行约定无关) +- `dialog_result` +- `turn.complete` + +机端对 **空文本帧** 会忽略并继续读(与云端「空文本忽略」一致)。 + +机端须 **收齐 `turn.complete` 且按序拼完该轮 TTS 二进制** 后再视为播报结束,再按产品规则分支(闲聊再滴声 / 飞控确认窗等)。 + +## 参考 + +- 会话产品流:[`CLOUD_VOICE_SESSION_SCHEME_v1.md`](./CLOUD_VOICE_SESSION_SCHEME_v1.md) +- 飞控确认:`CLOUD_VOICE_FLIGHT_CONFIRM_v1.md` diff --git a/docs/CLOUD_VOICE_SESSION_SCHEME_v1.md b/docs/CLOUD_VOICE_SESSION_SCHEME_v1.md new file mode 100644 index 0000000..f5aa5a9 --- /dev/null +++ b/docs/CLOUD_VOICE_SESSION_SCHEME_v1.md @@ -0,0 +1,163 @@ +# 语音助手会话方案 v1(服务端 + 机端对齐) + +本文档描述 **「唤醒 → 问候 → 滴声开录 → 断句上云 → 提示音 → 云端理解与 TTS → 分支循环/待机」** 的端到端方案,供 **服务端(voicellmcloud)** 与 **机端(本仓库 voice_drone_assistant)** 分工落地。 + +**v1 明确不做**:播报中途 **抢话 / 打断 TTS(barge-in)**;播放 TTS 时机端 **关麦或不处理用户语音**。 + +**关联协议**: + +- 音频上行与 Fun-ASR:[CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md](./CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md) +- 未唤醒不上云:由服务端/产品文档 `CLOUD_VOICE_CLIENT_WAKE_GATE_v1` 约定(机端须本地门禁后再建联上云) +- 总接口:以 voicellmcloud 仓库 `API_SPECIFICATION` 为准 +- 
飞控确认窗:[CLOUD_VOICE_FLIGHT_CONFIRM_v1.md](./CLOUD_VOICE_FLIGHT_CONFIRM_v1.md) + +--- + +## 1. 产品流程(用户视角) + +1. 用户说唤醒词(如「无人机」,由机端配置),**仅本地处理,不上云 ASR**。 +2. 机端播放问候语(如「你好,有什么事儿吗」)— 可用本地 TTS 或 `tts.synthesize`。 +3. 机端 **滴一声**,表示 **开始收音**;同时启动 **5 秒静默超时** 计时(见 §4)。 +4. 用户说话;机端 **VAD/端点检测** 得到 **一整句** 后: + - 播放 **极短断句提示音**(表示「已截句、将上云」); + - **提示音播放期间闭麦或做回声隔离**,提示音结束后 **短消抖**(建议 **150~300 ms**)再恢复采集逻辑,避免把提示音当成用户语音。 +5. 将该句 **PCM** 以 `turn.audio.*` 发云端;云端 **Fun-ASR → LLM → `dialog_result` → TTS**;机端播完 **全部** `tts_audio_chunk` 及收到 **`turn.complete`** 后,视为本轮播报结束(见 §3 服务端)。 +6. **分支**: + - **`routing === flight_intent`**:进入 **飞控子状态机**(口头确认/取消/超时),**不使用** §4 的「闲聊滴声后 5s」规则覆盖确认窗;超时以 **`dialog_result.confirm.timeout_sec`** 及 [CLOUD_VOICE_FLIGHT_CONFIRM_v1.md](./CLOUD_VOICE_FLIGHT_CONFIRM_v1.md) 为准。 + - **`routing === chitchat`**:本轮结束后 **再滴一声**,进入下一轮 **步骤 4**(同一 WebSocket 会话内 **新 `turn_id`** 再起一轮 `turn.audio.*`)。 +7. 若在 **步骤 3 的滴声之后** 连续 **5 s** 内未检测到有效语音(见 §4),机端播 **超时提示音**,**不再收音**,回到 **待机**(仅唤醒)。 + +--- + +## 2. 机端状态机(规范性) + +| 状态 | 含义 | 开麦 | 上云 ASR | 备注 | +|------|------|------|----------|------| +| `STANDBY` | 仅监听唤醒 | 按现网 VAD | **禁止** `turn.audio.*` | 本地 STT + 唤醒词 | +| `GREETING` | 播问候 | 可关麦或忽略输入 | 否 | 避免问候进识别 | +| `PROMPT_LISTEN` | 已滴「开始录」,等用户一句 | 开 | 否 | **5s 超时**在此状态监控 | +| `SEGMENT_END` | 已断句,播短提示音 | **闭麦/屏蔽** | 否 | 消抖后再转 `UPLOADING` | +| `UPLOADING` | 发送 `turn.audio.*` | 否 | **是** | 一轮一个 `turn_id` | +| `PLAYING_CLOUD_TTS` | 播云端 TTS | **关麦**(v1 无抢话) | 否 | 至 `turn.complete` + PCM 播完 | +| `FLIGHT_CONFIRM` | 飞控确认窗 | 按飞控文档 | **可** `turn.text` 或按产品另定 | **独立超时**,不共用 5s | +| `CHITCHAT_TAIL` | 闲聊结束,将再滴一声 | — | 否 | 回到 `PROMPT_LISTEN` | + +**并发**:同一时刻仅允许 **一路** `turn.audio.start`~`end`;须等 `turn.complete` 后再开下一轮(与现有 `pipeline_lock` 一致)。 + +**唤醒前**:须满足未唤醒不上云的产品/协议约定。 + +--- + +## 3. 服务端(voicellmcloud)职责 — v1 **无新消息类型** + +### 3.1 单轮行为(不变) + +对每个完整 `turn.audio.*` 或 `turn.text`: + +1. Fun-ASR(仅 `pcm_asr_uplink` + 音频轮)→ 文本 +2. 
LLM 流式 → `dialog_result`(`routing` / `flight_intent` / `chat_reply` 等) +3. `tts_audio_chunk*` → `turn.complete` + +服务端 **不** 下发「请再滴一声」「进入待机」类机端 UX 信令;这些由机端根据 **`routing` + `turn.complete`** **固定规则** 驱动。 + +### 3.2 机端判定「播报完成」 + +须同时满足: + +- 收到该轮 **`turn.complete`** +- 已按序播完该轮关联的 **binary PCM**(`tts_audio_chunk` 与现实现一致) + +然后机端再执行 §1 步骤 6 的分支。 + +### 3.3 可选下行 + +- **`asr.partial`**:机端 **不得** 用于驱动状态跳转;仅可 UI 展示。 +- **错误**:`error` / `ASR_FAILED` 等 → 机端播简短失败提示后,建议 **回 `STANDBY` 或回到 `PROMPT_LISTEN`**(产品定)。 + +--- + +## 4. 5 秒静默超时(闲聊路径) + +| 项 | 约定 | +|----|------| +| **起算点** | 「**开始收音**」的 **滴声播放结束** 时刻(或滴声后固定 **50~100 ms** 偏移,避免与滴声能量重叠)。 | +| **「无说话」** | 麦克 **RMS / VAD** 低于阈值,持续累计 ≥ **5 s**(建议可配置,默认 5)。 | +| **期间若开始说话** | 清零超时;**断句上云**后本超时在下一轮「滴声」后重新起算。 | +| **触发动作** | 播 **超时提示音** → 进入 **`STANDBY`**(不再滴声、不上云)。 | +| **不适用** | **`FLIGHT_CONFIRM`** 整段;确认窗用 **服务端给的 `timeout_sec`**。 | + +**机端配置**(`system.yaml` `cloud_voice`):`listen_silence_timeout_sec`、`post_cue_mic_mute_ms`、`segment_cue_duration_ms`;环境变量见 `main_app.py` 头部说明。 + +--- + +## 5. 断句后提示音(工程) + +| 项 | 约定 | +|----|------| +| 目的 | 用户感知「已截句,可等待播报」 | +| 实现 | 机端本地短 WAV / 蜂鸣;时长建议 **≤ 200 ms** | +| 回声 | **SEGMENT_END** 阶段闭麦或硬件 AEC;结束后 **≥ 150 ms** 再进入 `UPLOADING` | +| 与云端 | **无需** 上传该提示音 | + +--- + +## 6. 时序简图(闲聊多轮) + +```mermaid +sequenceDiagram + participant U as 用户 + participant D as 机端 + participant S as 服务端 + + U->>D: 唤醒词(本地) + D->>D: GREETING 播问候 + D->>D: 滴声 → PROMPT_LISTEN(起 5s 定时) + U->>D: 一句语音 + D->>D: VAD 断句 → 短提示音 → UPLOADING + D->>S: turn.audio.start/chunk/end + S->>D: asr.partial(可选) + S->>D: dialog_result + TTS + turn.complete + D->>D: PLAYING_CLOUD_TTS(关麦) + alt chitchat + D->>D: 再滴声 → PROMPT_LISTEN + else flight_intent + D->>D: FLIGHT_CONFIRM(独立超时) + end +``` + +--- + +## 7. 
配置建议(机端) + +| 键 | 默认值 | 说明 | +|----|--------|------| +| `listen_silence_timeout_sec` | `5` | 滴声后起算 | +| `post_cue_mic_mute_ms` | `150`~`300` | 断句提示音后再采集 | +| `cue_tone_duration_ms` / `segment_cue_duration_ms` | `≤200` | 断句提示 | +| `flight_confirm_handling` | 遵循飞控文档 | 禁用闲聊 5s 覆盖 | + +--- + +## 8. 机端开发自检 + +- [ ] `STANDBY` 下无 `turn.audio.start`。 +- [ ] `PLAYING_CLOUD_TTS` 与 `SEGMENT_END` 提示音阶段 **不开麦**(v1)。 +- [ ] 每轮新 `turn_id`;不并行两轮音频上行。 +- [ ] `flight_intent` 后进入 `FLIGHT_CONFIRM`,**不**误用 5s 闲聊超时。 +- [ ] `chitchat` 在 TTS 完成后 **再滴** 再 `PROMPT_LISTEN`。 + +--- + +## 9. 非目标(v1) + +- 播报中抢话、打断 TTS、实时 re-prompt。 +- 服务端驱动「滴声/待机」(均由机端规则实现)。 +- 连续免唤醒「直接说指令」跨多轮(若需另开 v2)。 + +--- + +## 10. 修订记录 + +| 版本 | 日期 | 说明 | +|------|------|------| +| v1 | 2026-04-07 | 首版:小爱类会话 + 双端分工;不含 barge-in | diff --git a/docs/DEPLOYMENT_AND_OPERATIONS.md b/docs/DEPLOYMENT_AND_OPERATIONS.md new file mode 100644 index 0000000..1879e5a --- /dev/null +++ b/docs/DEPLOYMENT_AND_OPERATIONS.md @@ -0,0 +1,288 @@ +# 部署与运维手册(项目总结) + +本文面向 **生产/外场部署**:说明 **voice_drone_assistant** 是什么、与 **云端** / **ROS 伴飞桥** / **PX4** 如何衔接,以及 **推荐启动顺序**、**环境变量**与**常见问题**。 +协议细节见 [`llmcon.md`](llmcon.md),通用配置索引见 [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md),伴飞桥行为见 [`FLIGHT_BRIDGE_ROS1.md`](FLIGHT_BRIDGE_ROS1.md),`flight_intent` 字段见 [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md)。 + +--- + +## 1. 
项目总结 + +### 1.1 定位 + +**voice_drone_assistant** 是板端 **语音无人机助手**:麦克风 → 降噪/VAD → **SenseVoice STT** → **唤醒词** → 用户一句指令 → **云端 WebSocket**(LLM + TTS)或 **本地 Qwen + Kokoro** → 若服务端返回 **`flight_intent`**,可在本机 **校验后执行**(TCP Socket 旧路径,或 **ROS 伴飞桥** 推荐路径)。 + +### 1.2 推荐数据流(方案一:云 → 语音程序 → ROS 桥) + +```mermaid +flowchart LR + subgraph cloud [云端] + WS[WebSocket LLM+TTS] + end + subgraph board [机载 香橙派等] + MIC[麦克风] + MAIN[main.py TakeoffPrintRecognizer] + ROSPUB[子进程 publish JSON] + BRIDGE[flight_bridge ros1_node] + MAV[MAVROS] + end + FCU[PX4 飞控] + + MIC --> MAIN + MAIN <-->|WSS pcm_asr_uplink + flight_intent| WS + MAIN -->|ROCKET_FLIGHT_INTENT_ROS_BRIDGE| ROSPUB + ROSPUB -->|std_msgs/String /input| BRIDGE + BRIDGE --> MAV --> FCU +``` + +- **不**把 ROS 直接暴露给公网:云端只连板子的 **WSS/WS**;飞控由 **本机 MAVROS + 伴飞桥** 执行。 +- **TCP Socket**(`system.yaml` → `socket_server`)是另一条试飞控通道,与云端 **无关**;未起 Socket 服务端时仅会重连日志,不影响 ROS 方案。 + +### 1.3 目录与核心入口(仓库根 = `voice_drone_assistant/`) + +| 路径 | 说明 | +|------|------| +| `main.py` | 语音助手入口 | +| `with_system_alsa.sh` | 建议包装启动,修正 Conda 与系统 ALSA | +| `voice_drone/main_app.py` | 唤醒、云端/本地 LLM、TTS、`flight_intent` 执行策略 | +| `voice_drone/flight_bridge/ros1_node.py` | ROS1 订阅 `/input`,执行 `flight_intent` | +| `voice_drone/flight_bridge/ros1_mavros_executor.py` | MAVROS:offboard / AUTO.LAND / RTL | +| `voice_drone/tools/publish_flight_intent_ros_once.py` | 单次向 ROS 发布 JSON(主程序 ROS 桥会子进程调用) | +| `scripts/run_flight_bridge_with_mavros.sh` | 一键:roscore(可选)+ MAVROS + 伴飞桥 | +| `scripts/run_flight_intent_bridge_ros1.sh` | 仅伴飞桥(须已有 roscore + MAVROS) | +| `voice_drone/config/system.yaml` | 音频、STT、TTS、云端、`assistant` 等 | +| `requirements.txt` | Python 依赖;**rospy** 来自 `apt` 的 ROS Noetic,见文件内注释 | + +--- + +## 2. 
环境与依赖 + +### 2.1 硬件与系统(典型) + +- ARM64 板卡(如 RK3588)、ES8388 等音频编解码器、USB/内置麦克风。 +- Ubuntu 20.04 + **ROS Noetic**(伴飞桥 / MAVROS 路径);同机运行语音进程与 `ros1_node`。 +- 飞控串口(如 `/dev/ttyACM0`)与 MAVROS `fcu_url` 一致。 + +### 2.2 Python + +- Python 3.10+(与原仓库一致即可)。 +- 在 **`voice_drone_assistant`** 根目录: + + ```bash + pip install -r requirements.txt + ``` + +### 2.3 ROS / MAVROS(伴飞桥方案必选) + +```bash +sudo apt install ros-noetic-ros-base ros-noetic-mavros ros-noetic-mavros-extras +# 按官方文档执行 mavros 地理库安装(如有) +``` + +- 语音主程序的 **ROS 桥**子进程会 `source /opt/ros/noetic/setup.bash` 并 **prepend** `PYTHONPATH`,**不要**在未 source ROS 的 shell 里把 `PYTHONPATH` 设成「只有工程根」,否则会找不到 `rospy`(参见 `main_app` 中 `_publish_flight_intent_to_ros_bridge`)。 + +### 2.4 模型与权重 + +- STT / TTS /(可选)VAD 放入 `models/`,或 `bash scripts/bundle_for_device.sh` 从原仓库打包。 +- 本地 LLM:GGUF 默认路径或 `ROCKET_LLM_GGUF`;**纯云端对话**时可弱化本地模型,但回退/混合模式仍需。 + +--- + +## 3. 部署拓扑 + +### 3.1 单机一体化(常见) + +同一台香橙派上同时运行: + +1. **roscore**(若尚无 master,由 `run_flight_bridge_with_mavros.sh` 拉起)。 +2. **MAVROS**(`px4.launch`,串口连 PX4)。 +3. **伴飞桥** `python3 -m voice_drone.flight_bridge.ros1_node`(订阅 **`/input`**)。 +4. **语音** `bash with_system_alsa.sh python main.py`。 + +`ROS_MASTER_URI` / `ROS_HOSTNAME`:一键脚本内默认 `http://127.0.0.1:11311` 与 `127.0.0.1`;**新开调试终端** 执行 `rostopic`/`rosservice` 前须自行 `source /opt/ros/noetic/setup.bash` 并 export **同一** `ROS_MASTER_URI`(见下文「常见问题」)。 + +### 3.2 网络 + +- 板子能访问 **云端 WebSocket**(`ROCKET_CLOUD_WS_URL`)。 +- PX4 + 遥控 + 安全开关等按外场规范配置;本文不替代安全检校清单。 + +--- + +## 4. 
启动顺序(推荐) + +### 4.1 终端 A:飞控栈 + 伴飞桥 + +在 **`voice_drone_assistant`** 根目录: + +```bash +cd /path/to/voice_drone_assistant +bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600 +``` + +脚本会: + +- 设置 `ROS_MASTER_URI`、`ROS_HOSTNAME`(未预设时默认为本机 master); +- 如无 master 则启动 **roscore**; +- 启动 **MAVROS** 并等待 `/mavros/state` **connected**; +- 前台启动伴飞桥,日志中应出现:`flight_intent_bridge 就绪:订阅 /input`。 + +**仅桥(已有 MAVROS 时)**: + +```bash +source /opt/ros/noetic/setup.bash +export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}" +export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}" +bash scripts/run_flight_intent_bridge_ros1.sh +``` + +### 4.2 终端 B:语音助手 + 云端 + 执行飞控 + +```bash +cd /path/to/voice_drone_assistant + +export ROCKET_CLOUD_VOICE=1 +export ROCKET_CLOUD_WS_URL='ws://<云主机>:8766/v1/voice/session' +export ROCKET_CLOUD_AUTH_TOKEN='' +export ROCKET_CLOUD_DEVICE_ID='drone-001' # 可选 + +# 云端返回 flight_intent 时是否在机端执行 +export ROCKET_CLOUD_EXECUTE_FLIGHT=1 +# 走 ROS 伴飞桥(与 Socket/offboard 序列互斥,勿双开重复执行) +export ROCKET_FLIGHT_INTENT_ROS_BRIDGE=1 +# 可选:ROCKET_FLIGHT_BRIDGE_TOPIC=/input ROCKET_FLIGHT_BRIDGE_WAIT_SUB=2 + +# 默认关闭本地「起飞演示」口令直起 offboard;需要时再设为 1 +# export ROCKET_LOCAL_KEYWORD_TAKEOFF=1 + +bash with_system_alsa.sh python main.py +``` + +成功时日志类似:`[飞控-ROS桥] 已发布至 /input`;伴飞桥端出现 `执行 flight_intent:steps=...`。 + +### 4.3 配置写进 YAML(可选) + +- 云端:`system.yaml` → `cloud_voice`(`enabled`、`server_url`、`auth_token` 等)。 +- 本地口令起飞:`assistant.local_keyword_takeoff_enabled`(默认 `false`);环境变量 `ROCKET_LOCAL_KEYWORD_TAKEOFF` **非空时优先生效**。 + +--- + +## 5. 
环境变量速查(飞控与云端) + +| 变量 | 含义 | +|------|------| +| `ROCKET_CLOUD_VOICE` | `1`:对话走云端 WebSocket | +| `ROCKET_CLOUD_WS_URL` | 云端会话地址 | +| `ROCKET_CLOUD_AUTH_TOKEN` | WS 鉴权 | +| `ROCKET_CLOUD_DEVICE_ID` | 设备 ID(可选) | +| `ROCKET_CLOUD_EXECUTE_FLIGHT` | `1`:云端 `flight_intent` 在机端执行 | +| `ROCKET_FLIGHT_INTENT_ROS_BRIDGE` | `1`:执行方式为 **发布到 ROS `/input`**,不跑机内 Socket+offboard 序列 | +| `ROCKET_FLIGHT_BRIDGE_TOPIC` | 默认 `/input` | +| `ROCKET_FLIGHT_BRIDGE_SETUP` | 子进程内 source ROS 的命令,默认 `source /opt/ros/noetic/setup.bash` | +| `ROCKET_FLIGHT_BRIDGE_WAIT_SUB` | 发布前等待订阅者的秒数,默认 `2`;`0` 即尽可能快发 | +| `ROCKET_LOCAL_KEYWORD_TAKEOFF` | 非空时:`1/true/yes` 开启 **`keywords.yaml` takeoff → 本地 offboard** | +| `ROCKET_CLOUD_PX4_CONTEXT_FILE` | 覆盖 `cloud_voice.px4_context_file`,合并进 session.start | + +更多调试变量见 **`voice_drone/main_app.py` 文件头注释** 与 [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) 第 5 节。 + +--- + +## 6. 联调与自测 + +### 6.1 仅测 ROS 链(无语音) + +终端已 `source /opt/ros/noetic/setup.bash` 且与 master 一致: + +```bash +rostopic pub -1 /input std_msgs/String \ + "data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"测\"}'" +``` + +注意:`std_msgs/String` 在命令行里只能写 **`data: '...json...'`**,不能把 JSON 放在消息顶层。 + +### 6.2 确认话题与 master + +```bash +source /opt/ros/noetic/setup.bash +export ROS_MASTER_URI=http://127.0.0.1:11311 +rosnode list +rostopic info /input +rosservice list | grep set_mode +``` + +若 `Unable to communicate with master!`:当前 shell 未连上正在运行的 **roscore**(或未 export 正确 `ROS_MASTER_URI`)。 + +--- + +## 7. 
常见问题(摘录) + +| 现象 | 可能原因 | 处理 | +|------|----------|------| +| `ModuleNotFoundError: rospy` | 子进程未继承 ROS 的 `PYTHONPATH` | 已修复为 `PYTHONPATH=<根>:$PYTHONPATH`;确保 `ROCKET_FLIGHT_BRIDGE_SETUP` 能 source Noetic | +| 语音端「已发布」但桥无日志 | 曾用相对 `input`,与全局 `/input` 不一致 | 伴飞桥默认已改为订阅 **`/input`**;重启桥 | +| `set_mode unavailable` / land 失败 | OFFBOARD 断流、MAVROS 异常等 | 伴飞桥降落逻辑已带持续 setpoint + 重连 proxy;仍失败则查 `rosservice`、`/mavros/state`、链路 | +| takeoff 超时 | 未进 OFFBOARD、未解锁、定位未就绪 | 查地面站、`/mavros/state`、适当增大 `~takeoff_timeout_sec`(ROS 私有参数) | +| ALSA underrun | 播放与采集竞争 | 板端常见;可调缓冲区/设备或 `recognizer.ack_pause_mic_for_playback` | + +--- + +## 8. 安全与运维建议 + +- 外场前在 **SITL 或系留** 环境验证完整 **`flight_intent`** 序列。 +- 云端 token、WS URL 勿提交到公开仓库;用环境变量或本机 **overlay** 配置注入。 +- 升级伴飞桥或 MAVROS 后清日志重试一遍 **`/input`** 手发 JSON。 + +--- + +## 9. 迁移到另一台香橙派:是否只拷贝 `voice_drone_assistant` 即可? + +**结论:目录是「代码 + 配置」的核心载体,但仅靠「整文件夹 scp 过去」通常不够;新板必须再装系统级依赖、模型与(可选)ROS,并按现场改配置。** + +### 9.1 拷贝目录本身会带上什么 + +| 已包含 | 说明 | +|--------|------| +| 全部 Python 源码、`voice_drone/config/*.yaml` 默认配置 | 可直接改 YAML / 环境变量适配新环境 | +| `scripts/`、`with_system_alsa.sh`、`docs/` | 启动与说明在包内 | + +### 9.2 新板必须单独准备(不随目录自动存在) + +| 项 | 说明 | +|----|------| +| **Ubuntu + 音频/ALSA** | 与当前开发板同代或自行适配;录音设备索引可能变化,需重选或设 `ROCKET_INPUT_DEVICE_INDEX` | +| **`pip install -r requirements.txt`** | 每台新 Python 环境执行一次(或整体迁移同一 conda 目录) | +| **`models/`** | STT/TTS/VAD 体积大,**务必**在本机先 `bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio` 或手工拷入,见 `models/README.txt` | +| **`cache/` GGUF** | 纯云端可不强依赖;若需本地 Qwen 回退,拷贝或设 `ROCKET_LLM_GGUF` | +| **ROS Noetic + MAVROS** | **apt** 安装;伴飞桥方案 **必选**;`rospy` **不要**指望只靠 pip | +| **云端连通** | 新板 IP/防火墙能访问 `ROCKET_CLOUD_WS_URL`;token 用环境变量注入 | +| **`dialout` 等权限** | 访问 `/dev/ttyACM0` 的用户加入 `dialout`,否则 MAVROS 无串口 | +| **`system.yaml` 现场差异** | `socket_server` IP、可选 `tts.output_device`、若麦索引固定可写 `audio.input_device_index` | + +### 9.3 推荐迁移流程(简表) + +1. 在旧机或 CI:**bundle 模型** → 打包整个 `voice_drone_assistant`(含 `models/`,按需含 `cache/`)。 +2. 
新香橙派:解压到任意路径,安装 **`requirements.txt`**、**ROS+MAVROS**、系统音频工具。 +3. 用 **`with_system_alsa.sh python main.py`** 试麦与 STT;再按本文 **§4** 双终端起 **桥 + 语音**。 +4. 首次外场前做一次 **`rostopic pub /input`** 手发 JSON(见 **§6**)。 + +### 9.4 常见误区 + +- **只拉 Git、不拷 `models/`**:STT/TTS 启动即失败。 +- **新板 Noetic 未装却开 `ROCKET_FLIGHT_INTENT_ROS_BRIDGE`**:发布子进程仍可能报错。 +- **假设麦克风设备号一定相同**:Orange Pi 刷机或换内核后常变,以首次启动日志为准。 + +--- + +## 10. 文档索引 + +| 文档 | 用途 | +|------|------| +| [`README.md`](../README.md) | 仓库简介、模型、`bundle` | +| [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) | 配置项与日常用法索引 | +| **本文** | 部署拓扑、启动顺序、环境变量、联调、**§9 迁移清单** | +| [`FLIGHT_BRIDGE_ROS1.md`](FLIGHT_BRIDGE_ROS1.md) | 伴飞桥参数、PX4 行为、`rostopic pub` 注意 | +| [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) | JSON 协议 | +| [`llmcon.md`](llmcon.md) | 云端协议 | +| [`CLOUD_VOICE_FLIGHT_CONFIRM_v1.md`](CLOUD_VOICE_FLIGHT_CONFIRM_v1.md) | **飞控口头二次确认**(闲聊不变、确认/取消/超时)云端与机端字段约定 | + +--- + +*文档版本与仓库同步;若行为与代码不一致,以当前 `main_app.py`、`flight_bridge/*.py` 为准。* diff --git a/docs/FLIGHT_BRIDGE_ROS1.md b/docs/FLIGHT_BRIDGE_ROS1.md new file mode 100644 index 0000000..0d9b615 --- /dev/null +++ b/docs/FLIGHT_BRIDGE_ROS1.md @@ -0,0 +1,88 @@ +# Flight Intent 伴飞桥(ROS 1 + MAVROS) + +本目录代码与语音助手 **`main.py` 独立进程**:在 **MAVROS 已连接 PX4** 的前提下,订阅一条 JSON,按 **`FLIGHT_INTENT_SCHEMA_v1.md`** 顺序执行。 + +## 依赖 + +- Ubuntu / 设备上已装 **ROS Noetic**、`mavros`(与 `scripts/run_px4_offboard_one_terminal.sh` 一致) +- Python 能 import:`rospy`、`std_msgs`、`geometry_msgs`、`mavros_msgs` +- 本仓库根目录 **`voice_drone_assistant`** 需在 `PYTHONPATH`(启动脚本已设置) + +## 启动 + +**推荐(不会单独敲 roslaunch 时用)**:一键拉起 roscore(若尚无)→ MAVROS → 伴飞桥: + +```bash +cd voice_drone_assistant +bash scripts/run_flight_bridge_with_mavros.sh +bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600 +``` + +**已有 MAVROS 时**只启桥: + +```bash +cd voice_drone_assistant +bash scripts/run_flight_intent_bridge_ros1.sh +``` + +默认节点名(含 anonymous 后缀):`/flight_intent_mavros_bridge_<...>`,默认订阅 **全局 `/input`**(`std_msgs/String`,内容为 JSON)。私有参数 
`~input_topic` 可改(例如专用名时填入完整话题)。桥启动日志会打印实际订阅名。 + +## JSON 格式 + +- **完整** `flight_intent`(与云端相同顶层字段),或 +- **最小**:`{"actions":[...], "summary":"任意非空"}`(节点内会补 `is_flight_intent/version`) + +校验失败会打 `rospy.logerr`,不执行。 + +## 行为映射(首版) + +| `type` | 行为(简述) | +|--------|----------------| +| `takeoff` | Offboard 位姿:当前点预热 setpoint → `OFFBOARD` + arm → `z0 + Δz`(Δz 来自 `relative_altitude_m` 或参数 `~default_takeoff_relative_m`,与 `px4_ctrl_offboard_demo` 同号约定) | +| `hover` / `hold` | 当前位姿 hold 约 1s(持续发 setpoint) | +| `wait` | `rospy.Duration`,Offboard 时顺带维持当前点 setpoint | +| `goto` | `local_ned` / `body_ned` 增量 → 目标 NED 点,到位容差见 `~goto_position_tolerance` | +| `land` | `AUTO.LAND` | +| `return_home` | `AUTO.RTL` | + +**注意**:真机前请 SITL 验证;不同 PX4/机型的 `custom_mode` 字符串若不一致需在 `ros1_mavros_executor.py` 中调整。 + +## 参数(私有命名空间) + +| 参数 | 默认 | 含义 | +|------|------|------| +| `~input_topic` | `/input` | 订阅话题(建议绝对路径;勿再用相对名 `input`,否则与 `/input` 对不上) | +| `~default_takeoff_relative_m` | `0.5` | `takeoff` 无 `relative_altitude_m` 时 | +| `~takeoff_timeout_sec` | `15` | | +| `~goto_position_tolerance` | `0.15` | m | +| `~goto_timeout_sec` | `60` | | +| `~land_timeout_sec` | `45` | land/rtl 等待 disarm 超时 | +| `~offboard_pre_stream_count` | `80` | 与 demo 类似 | + +## 与语音程序的关系 + +- **`main.py`**:仍可用 Socket / 本地 offboard **演示脚本**(产品过渡期)。 +- **桥**:适合作为 **长期** MAVROS 执行端;后续可把语音侧改为 **向本节点 `input` 发布 JSON**(`rospy` 或 `roslibpy` 等),而不再直接起 bash demo。 + +## 与语音侧联调:`rostopic pub`(注意 `data:`) + +`std_msgs/String` 的 YAML 只有字段 **`data`**,JSON 必须写在 **`data: '...'`** 里;不能把 JSON 直接当消息顶层(否则会 `ERROR: No field name [is_flight_intent]` / `Args are: [data]`)。 + +终端 A:已跑 `run_flight_bridge_with_mavros.sh`(或 MAVROS + 桥)。 +终端 B: + +```bash +source /opt/ros/noetic/setup.bash +# 订阅名以桥日志为准,常见为全局 /input +rostopic pub -1 /input std_msgs/String \ + "data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"联调降落\"}'" +``` + +**悬停 2 秒再降**(须已在空中时再试): + +```bash +rostopic pub -1 /input 
std_msgs/String \ + "data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"hover\",\"args\":{}},{\"type\":\"wait\",\"args\":{\"seconds\":2}},{\"type\":\"land\",\"args\":{}}],\"summary\":\"测\"}'" +``` + +在语音进程内集成时,建议 **不要在 asyncio 里直接调 rospy**,用 **独立桥进程 + topic** 或 **UNIX socket 转发到桥**。 diff --git a/docs/FLIGHT_INTENT_IMPLEMENTATION_PLAN.md b/docs/FLIGHT_INTENT_IMPLEMENTATION_PLAN.md new file mode 100644 index 0000000..eba4320 --- /dev/null +++ b/docs/FLIGHT_INTENT_IMPLEMENTATION_PLAN.md @@ -0,0 +1,113 @@ +# Flight Intent v1 + 伴飞桥 — 实施计划 + +本文档与 [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) 配套,描述从协议闭环到 ROS/PX4 可控的**分阶段交付**。顺序建议按阶段 0→4;各阶段内任务可并行处已标注。 + +--- + +## 目标与验收标准 + +| 维度 | 验收标准 | +|------|-----------| +| **协议** | 云端下发的 `flight_intent` 满足 v1:含 `wait`、`takeoff` 可选高度、`trace_id`;L1–L3 校验可自动化 | +| **语音客户端** | 能解析并记录完整 `actions`;在 `ROCKET_CLOUD_EXECUTE_FLIGHT=1` 时通过 Socket/桥 执行或与桥约定本地执行 `wait` | +| **桥** | 顺序执行 `actions`,每步有超时/失败策略;可对接 MAVROS(或既定 ROS 2 栈)驱动 PX4 | +| **安全** | 执行前 L4 门禁、执行中可中断、急停路径明确 | +| **回归** | SITL 或台架可重复跑通「起飞 → 悬停 → wait → 降落」等示例 | + +--- + +## 阶段 0:对齐与基线(约 0.5~1 天) + +- [ ] 全员精读 `FLIGHT_INTENT_SCHEMA_v1.md`,冻结 **v1 白名单**(`type` / `args` 键)。 +- [ ] 确认伴飞侧技术选型:**ROS 2 + MAVROS**(或 `px4_ros_com`)与默认 **AUTO vs Offboard** 策略(写入桥 YAML,不写进 JSON)。 +- [ ] 盘点现有 **Socket 服务**:是否即「桥」或仅转发;是否需新进程 `flight_intent_bridge`。 +- [ ] 建立 **trace_id** 在日志中的格式(云端 / 语音 / 桥统一)。 + +**产出**:架构一页纸(谁消费 WebSocket、谁连 PX4)、桥配置模板路径约定。 + +--- + +## 阶段 1:协议与云端(可与阶段 2 并行,约 2~4 天) + +- [ ] **Schema 校验**:服务端对 `flight_intent` 做 L1–L3(必要时 L4 占位);非法则 `routing=error` 或产品协议兜底。 +- [ ] **LLM 提示词**:只允许 §3.7 中 `type` 与允许键;强调 **时长必须用 `wait`**,禁止用 `summary` 控机。 +- [ ] **示例与回归用例**:固定 JSON golden(§7.1~§7.3 + 边界:首步 `wait`、`seconds` 超界、多余 `args` 键)。 +- [ ] **可选 `trace_id`**:服务端生成或在 bundle 层透传。 + +**产出**:校验测试集、提示词 MR、发布说明(对客户端可见的字段变更)。 + +--- + +## 阶段 2:语音客户端(`voice_drone_assistant`)(约 3~5 天) + +可与阶段 1、3 部分并行。 + +- [x] 
**Pydantic**:`voice_drone/core/flight_intent.py`(v2)按 v1 文档收紧动作与 `args`。 +- [x] **`parse_flight_intent_dict`**:等价 L1–L3 + 首步禁止 `wait`;白名单、`goto.frame`、`wait.seconds`、`takeoff.relative_altitude_m`。 +- [x] **`main_app`**:`ROCKET_CLOUD_EXECUTE_FLIGHT=1` 时在后台线程 **`_run_cloud_flight_intent_sequence`** 顺序执行;`wait` 用 `time.sleep`;`goto` **单轴** 映射 Socket `Command`;`return_home` 已入 `Command`;**含 `takeoff` 的序列**在 offboard 完成后继续后续步(不再丢失)。 +- [x] **日志**:序列开始时打印 `trace_id`;`takeoff` 打相对高度提示(offboard 是否消费须自行接参数)。 +- [x] **单测**:`tests/test_flight_intent.py`(无完整依赖时 goto 用例自动 skip)。 + +**产出**:MR 合并后,本地无 PX4 也能跑通解析与 mock 执行。 + +--- + +## 阶段 3:伴飞桥 + ROS/PX4(约 5~10 天,视现网复用程度) + +- [x] **进程边界(首版)**:独立 ROS1 节点,订阅 `std_msgs/String` JSON;见 **`docs/FLIGHT_BRIDGE_ROS1.md`**、`scripts/run_flight_intent_bridge_ros1.sh`。 +- [x] **执行器(首版)**:`voice_drone/flight_bridge/ros1_mavros_executor.py` 单线程顺序执行;`takeoff/goto` 带超时;`land/rtl` 等待 disarm 超时。 +- [x] **翻译实现(首版 / MAVROS)**: + - `takeoff` / `hover` / `wait` / `goto`:`/mavros/setpoint_raw/local`(Offboard)+ `set_mode` / `arming`。 + - `land` / `return_home`:`AUTO.LAND` / `AUTO.RTL`。 +- [ ] **安全**:L4(电量、围栏、急停 topic);`wait` 中异常策略。 +- [ ] **回执**:result topic / 与 `main.py` 的 topic 串联。 +- [ ] **ROS2 / 仅 TCP 无 ROS**:按需另起接口。 + +**产出(当前)**:ROS1 桥可 `rostopic pub` 联调;**待** launch、与语音侧发布 JSON、SITL CI。 + +--- + +## 阶段 4:联调、硬化与发布(约 3~7 天) + +- [ ] **端到端**:真机或 SITL:语音 → 云 → 客户端 → 桥 → PX4,带 `trace_id` 串 log。 +- [ ] **压测与失败注入**:断 WebSocket、桥崩溃重启、Offboard 丢失等(预期行为写进运维文档)。 +- [ ] **配置与门禁**:默认关闭实飞执行;仅生产镜像打开;参数与围栏双人复核。 +- [ ] **文档**:更新 `PROJECT_GUIDE.md` 中「飞控路径」链接到本文与 SCHEMA。 + +**产出**:发布 checklist、已知限制列表(如某机型仅支持 AUTO 等)。 + +--- + +## 依赖与风险 + +| 风险 | 缓解 | +|------|------| +| Socket 协议与 `Command` 无法表达多步 | **推荐**由桥消费**完整** `flight_intent` JSON,客户端只负责下发一份;少步经 Socket 逐条 | +| Offboard 与 AUTO 混用冲突 | 桥配置单一「主策略」;`goto` 仅在 Offboard 就绪时接受 | +| LLM 仍产出非法 JSON | L2 硬拒绝 + 提示词回归 + golden 测试 | +| 排期膨胀 | 先交付 **AUTO 模式族 + wait + land**,再迭代复杂 `goto` | + +--- + +## 建议里程碑(日历为估算) + +| 里程碑 | 
内容 | +|--------|------| +| **M1** | 阶段 0–1 完成:云校验 + 提示词 + golden | +| **M2** | 阶段 2 完成:客户端 strict 模型 + `wait` + 执行路径单一数据源 | +| **M3** | 阶段 3 完成:桥 + SITL 跑通 §7.2 | +| **M4** | 阶段 4:联调签字 + 生产策略 | + +--- + +## 文档索引 + +| 文档 | 用途 | +|------|------| +| [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) | 字段、校验、桥分层、ROS 参考 | +| [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) | 仓库总览与运行方式 | +| 本文 | 任务拆解、顺序、验收 | + +--- + +**版本**:2026-04-07;随 SCHEMA v1 修订同步更新本计划中的阶段勾选与工期估算。 diff --git a/docs/FLIGHT_INTENT_SCHEMA_v1.md b/docs/FLIGHT_INTENT_SCHEMA_v1.md new file mode 100644 index 0000000..9a0a9ae --- /dev/null +++ b/docs/FLIGHT_INTENT_SCHEMA_v1.md @@ -0,0 +1,372 @@ +# 云端高层飞控意图 JSON 规范 v1(完整版) + +> **定位**:定义 WebSocket `dialog_result.flight_intent` 中的**语义对象**与**伴飞侧执行约定**,不是 MAVLink 二进制帧。 +> **目标**: +> +> 1. **协议**:客户端与云端可 **100% 按字段表 strict 解析**;**禁止**用自然语言或 `summary` 驱动机控。 +> 2. **桥(companion)**:按本文执行 **有序 `actions`**、校验、排队与安全门,再译为 PX4 可接受的模式 / 指令 / Offboard setpoint。 +> 3. **ROS**:为 MAVROS(或等价 ROS 2 封装)提供**参考映射表**;具体 topic/service 名称以装机软件栈为准。 + +**协议(bundle)**:`proto_version: "1.0"`,会话 `transport_profile: "pcm_asr_uplink"`(与机端 CloudVoiceClient 一致)。 + +--- + +## 1. 顶层对象 `flight_intent` + +当 `routing === "flight_intent"` 时,`flight_intent` **必须非 null**,且为下表 JSON 对象(键顺序任意)。 + +| 字段 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `is_flight_intent` | `boolean` | 是 | **必须为 `true`** | +| `version` | `integer` | 是 | **Schema 版本,本文档固定为 `1`** | +| `actions` | `array` | 是 | **有序**动作列表,按**时间先后**执行(见 §5、§10) | +| `summary` | `string` | 是 | 一句人类可读中文摘要(播报/日志);**不参与机控解析** | +| `trace_id` | `string` | 否 | **端到端追踪 ID**(建议 UUID 或雪花 ID);用于桥、ROS 节点与日志关联;长度建议 ≤ 128 | + +**禁止字段**:除上表外,顶层不得出现其它键(便于 `strict` 解析)。扩展须 **递增 `version`**(见 §8)。 + +**字符编码**:UTF-8。 + +--- + +## 2. `actions[]` 通用形式 + +每个元素: + +```json +{ "type": "", "args": { ... } } +``` + +| 字段 | 类型 | 必填 | +|------|------|------| +| `type` | `string` | 是,取值限于 §3 枚举 | +| `args` | `object` | 是,允许为空对象 `{}`;**仅允许**各小节表中列出的键 | + +--- + +## 3. 
`ActionType` 与白名单 `args` + +以下 **`type` 仅允许小写**。未列出的值在 v1 **非法**。 + +### 3.1 `takeoff` + +| args 键 | 类型 | 必填 | 说明 | +|---------|------|------|------| +| `relative_altitude_m` | `number` | 否 | 相对起飞点(或飞控定义的 TAKEOFF 参考)的**目标高度**,单位 **米**,须 **> 0**。省略则 **完全由机端/飞控缺省参数** 决定(与旧版「空 args」语义一致) | + +```json +{ "type": "takeoff", "args": {} } +``` + +```json +{ "type": "takeoff", "args": { "relative_altitude_m": 5 } } +``` + +**桥 / ROS 映射提示**:PX4 `TAKEOFF` 模式、`MAV_CMD_NAV_TAKEOFF`;相对高度常与 `MIS_TAKEOFF_ALT` 或命令参数结合,**以装机参数为准**。 + +--- + +### 3.2 `land` + +| args 键 | 类型 | 必填 | 说明 | +|---------|------|------|------| +| — | — | — | v1 无参;降落行为由飞控 `AUTO.LAND` 等策略决定 | + +```json +{ "type": "land", "args": {} } +``` + +**桥 / ROS 映射提示**:`AUTO.LAND`、`MAV_CMD_NAV_LAND`;固定翼 / 多旋翼路径不同,由机端处理。 + +--- + +### 3.3 `return_home` + +```json +{ "type": "return_home", "args": {} } +``` + +语义:**返航至 Home 并按飞控策略降落或盘旋**(与 PX4 **RTL** 概义一致)。 + +--- + +### 3.4 `hover` 与 `hold` + +二者在 v1 **语义等价**:**在当前位置附近保持**(多旋翼常见为位置保持;固定翼可能映射为 Loiter,由机端按机型解释)。 + +| args 键 | 类型 | 必填 | 说明 | +|---------|------|------|------| +| — | — | — | 仅 `{}` 合法;表示进入保持,**不隐含**持续时间(时长用 §3.6 `wait`) | + +```json +{ "type": "hover", "args": {} } +``` + +```json +{ "type": "hold", "args": {} } +``` + +**约定**:同一 `actions` 序列建议只选 `hover` 或 `hold` 一种命名;解析端可映射到同一 PX4 行为。 + +**互操作(非规范首选)**:个别上游可能错误输出 `"args": { "duration": 3 }`(秒)。**伴飞客户端**(如本仓库 `flight_intent.py`)可在校验时将其**折叠**为:`hover`(无 `duration`)+ `wait`,与上表典型组合等价;**新开发的上游仍应只产 `wait`**。 + +**典型组合**(「悬停 3 秒后降落」): + +```json +[ + { "type": "takeoff", "args": {} }, + { "type": "hover", "args": {} }, + { "type": "wait", "args": { "seconds": 3 } }, + { "type": "land", "args": {} } +] +``` + +--- + +### 3.5 `goto` — 相对/局部位移 + +| args 键 | 类型 | 必填 | 说明 | +|---------|------|------|------| +| `frame` | `string` | 是 | 坐标系,取值见 §4 | +| `x` | `number` \| `null` | 否 | **米**;`null` 或 **省略** 表示该轴**无位移意图**(机端保持) | +| `y` | `number` \| `null` | 否 | 同上 | +| `z` | `number` \| `null` | 否 | 同上 | + +```json +{ + 
"type": "goto", + "args": { + "frame": "local_ned", + "x": 100, + "y": 0, + "z": 0 + } +} +``` + +**语义**:在 `frame` 下相对当前位置的**增量**。v1 **不**定义绝对经纬度航点;若未来需要,应 **v2** 增加 `goto_global` / `waypoint`。 + +**口语映射示例**:「向前飞 10 米」可建模为 `frame: "body_ned"`, `x: 10`(前为 x+,与 §4 一致)。 + +--- + +### 3.6 `wait` — 纯时间等待 + +**不包含**模式切换;桥在**本地计时**后继续下一步。用于「悬停多久」「停顿再执行」等。 + +| args 键 | 类型 | 必填 | 说明 | +|---------|------|------|------| +| `seconds` | `number` | 是 | 等待秒数,须满足 **0 < seconds ≤ 3600**(上限可防止 LLM 写极大值;产品可改小) | + +```json +{ "type": "wait", "args": { "seconds": 3 } } +``` + +**安全**:等待期间桥须持续监测遥测(失联、低电量、姿态异常等),**可中断**序列并转入 `RTL` / `LAND` / `HOLD`(策略见 §10.4)。 + +--- + +### 3.7 v1 动作类型一览 + +| `type` | 必填 `args` 键 | 备注 | +|--------|----------------|------| +| `takeoff` | 无 | 可选 `relative_altitude_m` | +| `land` | 无 | | +| `return_home` | 无 | | +| `hover` | 无 | | +| `hold` | 无 | 与 `hover` 等价 | +| `goto` | `frame` | 可选 x/y/z | +| `wait` | `seconds` | | + +--- + +## 4. `frame`(仅 `goto`) + +| 取值 | 含义 | +|------|------| +| `local_ned` | **局部 NED**:北(x)-东(y)-地(z),单位 m;**向下为 z 正**(与 PX4 `LOCAL_NED` 常见用法一致) | +| `body_ned` | **机体系**:**前(x)-右(y)-下(z)**,单位 m;桥或 ROS 侧需转换到 NED / setpoint | + +**v1 仅此两值**;其它字符串 **L2 非法**。 + +--- + +## 5. 序列语义与组合规则 + +- **`actions` 有序**:严格对应口语**时间顺序**(先执行索引 0,再 1,…)。 +- **空列表**:**不允许**(至少 1 个元素)。 +- **`wait`**:不改变飞控模式;若需「边悬停边等」,应先 `hover`/`hold` 再 `wait`(或在上一步已进入位置模式的假定下仅 `wait`,由机端策略定义;**推荐**显式 `hover` 再 `wait`)。 +- **首步**:首元素为 `wait` **不推荐**(飞机未起飞则等待无控飞意义);服务端可做 **L4 警告或拒绝**。 +- **`takeoff` 后出现 `goto`**:桥应确保已有位置估计/GPS 等前置条件,否则拒绝并回报原因。 +- **重复动作**:不禁止连续多个 `goto` / `wait`;机端可合并或排队。 + +--- + +## 6. 
校验分级(服务端 + 桥建议共用) + +| 级别 | 内容 | +|------|------| +| **L1 结构** | JSON 可解析;`is_flight_intent===true`;`version===1`;`actions` 为非空数组;`summary` 为非空字符串;`trace_id` 若存在则为 string。 | +| **L2 枚举** | 每个 `action.type` ∈ §3.7;`goto` 含合法 `frame`;各 `args` **仅含**该 `type` 允许的键(无多余键)。 | +| **L3 数值** | `relative_altitude_m` 若存在则 **> 0** 且建议 capped(如 ≤ 500);`wait.seconds` 在 **(0, 3600]**;`goto` 的 x/y/z 为有限 number 或 null;位移模长可设上限(如 10e3 m)。 | +| **L4 语义** | 结合 `session.start.client.px4`(机型、是否支持 Offboard、地理围栏等);禁止不合法序列(如无定位时 `goto`);不通过则 `error` 或带工程约定 `warnings`(v2 可标准化 `warnings` 数组)。 | + +--- + +## 7. 完整示例 + +### 7.1 起飞 → 北飞 100m → 悬停 + +```json +{ + "is_flight_intent": true, + "version": 1, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + "actions": [ + { "type": "takeoff", "args": {} }, + { + "type": "goto", + "args": { "frame": "local_ned", "x": 100, "y": 0, "z": 0 } + }, + { "type": "hover", "args": {} } + ], + "summary": "起飞后向北飞约100米并悬停" +} +``` + +### 7.2 起飞 → 悬停 3 秒 → 降落 + +```json +{ + "is_flight_intent": true, + "version": 1, + "actions": [ + { "type": "takeoff", "args": { "relative_altitude_m": 3 } }, + { "type": "hover", "args": {} }, + { "type": "wait", "args": { "seconds": 3 } }, + { "type": "land", "args": {} } + ], + "summary": "起飞至约3米高,悬停3秒后降落" +} +``` + +### 7.3 返航 + +```json +{ + "is_flight_intent": true, + "version": 1, + "actions": [ + { "type": "return_home", "args": {} } + ], + "summary": "返航至 Home" +} +``` + +--- + +## 8. 演进与扩展 + +- **新增 `type`、新 `frame`、顶层字段**:须 **递增 `version`**(如 `2`)并附迁移说明。 +- **严禁**:在 `flight_intent` 内增加自由文本字段用于机动解释(仅 `summary` 可读)。 +- **调试**:可在外层 bundle(非 `flight_intent` 体内)附加 `schema: "cloud_voice.flight_intent@1"`,由工程约定。 + +--- + +## 9. 
JSON → PX4 责任边界(摘要) + +| JSON `type` | 机端典型职责(PX4 侧,非规范强制) | +|-------------|--------------------------------------| +| `return_home` | RTL / `MAV_CMD_NAV_RETURN_TO_LAUNCH` 等 | +| `takeoff` | TAKEOFF / `MAV_CMD_NAV_TAKEOFF`,高度来自 args 或参数 | +| `land` | LAND 模式 / `MAV_CMD_NAV_LAND` | +| `goto` | Offboard 轨迹、外部跟踪或 Mission 航点(**桥根据策略选一**) | +| `hover` / `hold` | LOITER / HOLD / 位置保持 setpoint | +| `wait` | 仅伴飞计时;**不发**模式切换 MAV 命令(除非实现为「保持当前模式下的阻塞」) | + +**不重样规定**:MAVLink messageId、发送频率、Offboard 心跳、EKF 就绪条件由 **companion + PX4 装机** 保证。 + +--- + +## 10. 伴飞桥(Bridge)设计要点 + +本节约定:**桥** = 运行在伴飞计算机上的进程(可与语音同源或独立),负责消费 `flight_intent`(或等价 JSON),**绝不**把原始 LLM 文本直接发给 PX4。 + +### 10.1 逻辑分层 + +1. **接入**:WebSocket 回调 → 解析 `flight_intent`;或订阅 ROS Topic `flight_intent/json`;或 TCP 接收与本文相同的 JSON。 +2. **校验**:至少 L1–L3;有 px4 上下文时做 L4。 +3. **执行器**:对 `actions` **单线程顺序**执行;内部每步调用 **翻译器**(见 §10.3)。 +4. **遥测与安全**:每步前置检查(模式、解锁、定位、电量、围栏);执行中 watchdog;可打断队列。 +5. **回执(建议)**:ROS 发布 `flight_intent/result` 或写日志/Socket:success / rejected / aborted + `trace_id` + 步号 + 原因码。 + +### 10.2 与语音客户端的关系(本仓库) + +- 语音侧可将 **`flight_intent` 映射** 为现有 `Command`(`command` + `params` + `sequence_id` + `timestamp`)经 **Socket** 发到桥;或由桥 **直接订阅云端结果**(二选一,避免双源)。 +- **`wait`**:若 Socket 协议暂无对应 `Command`,桥在本地对「已解析的 `actions` 列表」执行 `wait`,**不必**经 Socket 转发计时。 +- **扩展 `Command`**:若希望所有步骤可经 Socket 观测,可增加 `command: "noop"` + `params.duration` 仅作日志,但 **推荐** 桥本地处理 `wait`。 + +### 10.3 翻译器(`type` → 行为) + +实现为代码表 + 机型分支,示例: + +| `type` | 桥内典型步骤(抽象) | +|--------|----------------------| +| `takeoff` | 检查 arming 策略 → 发送起飞命令/切 TAKEOFF → 等待「达到 hover 可接受高度」或超时 | +| `land` | 切 LAND 或发 NAV_LAND → 监测直到 disarm 或超时 | +| `return_home` | 切 RTL | +| `hover`/`hold` | 切 AUTO.LOITER 或发位置保持 setpoint(Offboard 路径则发零速/当前位 setpoint) | +| `goto` | 按 `frame` 解算目标 → Offboard 轨迹或上传迷你 mission → 等待到达容差或超时 | +| `wait` | `sleep(seconds)` + 可中断环形检查遥测 | + +每步应定义 **超时** 与 **失败策略**(中止整段序列 / 仅跳过一步)。 + +### 10.4 安全与中断 + +- **急停 / 人机优先级**:本地硬件或 ROS `/emergency_hold` 等应能 
**清空队列** 并进入安全模式。 +- **云断连**:不要求中断已在执行的序列(产品可配置「断连即 RTL」)。 +- **`wait` 期间**:持续判据;触发阈值则 **中止等待** 并执行安全动作。 + +--- + +## 11. ROS / MAVROS 实施参考 + +以下为方便对接 **ROS 2 + MAVROS**(或 `px4_ros_com`)的**参考映射**;实际包名、话题名、QoS 以你方 `mavros` 版本与 launch 为准。 + +### 11.1 常用接口类型 + +| 目的 | 常见 ROS 2 形态 | 说明 | +|------|------------------|------| +| 模式切换 | `mavros_msgs/srv/VehicleCmd` 或 SetMode 等价服务 | 切 `AUTO.TAKEOFF`, `AUTO.LAND`, `AUTO.LOITER`, `AUTO.RTL` 等 | +| 解锁/上锁 | `cmd/arming` 服务或 VehicleCommand | 桥策略决定是否自动 arm | +| Offboard 轨迹 | `trajectory_setpoint`、`offboard_control_mode`(PX4 官方 ROS 2 示例) | 用于 `goto` / `hover` 的 setpoint 路径 | +| 状态反馈 | `vehicle_status`、`local_position`、电池 topic | L4 与每步完成判定 | +| 长航指令 | `Mission`、`CMD` 接口 | 复杂航迹可选用 mission 上传 | + +### 11.2 JSON → ROS 责任划分建议 + +- **桥节点**订阅或接收 `flight_intent`,执行 §10.3,并调用 **MAVROS / px4_ros_com** 客户端。 +- **飞控仿真**:同一套 `flight_intent` 可在 SITL 上回放,便于 CI。 +- **单飞控单 writer**:同一时刻建议只有一个节点向 Offboard 端口写 setpoint,避免竞争。 + +### 11.3 与 PX4 模式的关系(概念) + +- **AUTO 模式族**(TAKEOFF / LAND / LOITER / RTL):适合 `takeoff`、`land`、`return_home`、部分 `hover`。 +- **Offboard**:适合连续 `goto`、精细悬停;桥需负责 **先切 Offboard 再发 setpoint**,并满足 PX4 Offboard 丢包监测。 +- 具体选 AUTO 还是 Offboard 由 **桥配置**(YAML)决定,**不写入** `flight_intent` JSON(保持云侧与机型解耦)。 + +--- + +## 12. 
与当前仓库实现的对齐清单 + +| 项 | 建议 | +|----|------| +| Pydantic | `FlightIntentPayload` / `FlightIntentAction`:收紧 `type` Literal;`args` 按 §3 分类型或 discriminated union | +| 云端校验 | `_validate_flight_intent`:L2 白名单 + `goto.frame` + `wait.seconds` + `takeoff.relative_altitude_m` | +| LLM 提示词 | 仅允许 §3.7 中 `type` 与各 `args` 键;**必须**用 `wait` 表达明确停顿时长 | +| `main_app` | `land`/`hover` 已有 Socket 映射;`goto`/`return_home`/`takeoff`/`wait` 需在桥或 Socket 侧补全 | +| `Command` | 可扩展 `Literal` 与 `CommandParams`,或与桥约定「语音只发 Socket,复杂序列由桥执行」 | + +--- + +**文档版本**:2026-04-07(修订:增加 `wait`、`takeoff` 可选高度、`trace_id`、桥与 ROS 章节)。与 **`flight_intent.version === 1`** 对应。 diff --git a/docs/PROJECT_GUIDE.md b/docs/PROJECT_GUIDE.md new file mode 100644 index 0000000..7aa68fc --- /dev/null +++ b/docs/PROJECT_GUIDE.md @@ -0,0 +1,148 @@ +# voice_drone_assistant — 项目说明与配置指南 + +面向部署与二次开发:**目录结构**、**配置文件用法**、**启动与日常操作**、**与云端/飞控的关系**。**外场统一部署与双终端启动顺序**见 **`docs/DEPLOYMENT_AND_OPERATIONS.md`**;协议细节以 `docs/llmcon.md` 为准。 + +--- + +## 1. 项目做什么 + +- **麦克风** → 预处理(降噪/AGC)→ **VAD 切段** → **SenseVoice STT** → **唤醒词** +- **关键词起飞(offboard 演示,默认关闭)**:`system.yaml` → **`assistant.local_keyword_takeoff_enabled`** 或 **`ROCKET_LOCAL_KEYWORD_TAKEOFF=1`** 开启后,`keywords.yaml` 里 **`takeoff` 词表**(如「起飞演示」)→ 提示音 + offboard 脚本;飞控主路径推荐 **云端 `flight_intent` + ROS 伴飞桥** +- **其它语音**:本地 **Qwen + Kokoro**,或 **云端 WebSocket**(LLM + TTS 上云,见 `cloud_voice`) +- 可选通过 **TCP Socket** 下发结构化飞控命令(`VoiceCommandRecognizer` 路径;`TakeoffPrintRecognizer` 默认不在启动时连 Socket,飞控多为云端 JSON + 可选 `ROCKET_CLOUD_EXECUTE_FLIGHT`) + +--- + +## 2. 
目录结构(仓库根 = `voice_drone_assistant/`) + +| 路径 | 说明 | +|------|------| +| `main.py` | 程序入口(会 `chdir` 到本目录并跑 `voice_drone.main_app`) | +| `with_system_alsa.sh` | 在 Conda/残缺 ALSA 环境下修正 `LD_LIBRARY_PATH`,建议始终包一层启动 | +| `requirements.txt` | Python 依赖(含 `websocket-client` 等) | +| **`voice_drone/main_app.py`** | 主流程:唤醒、问候/快路径、关麦、LLM/云端、TTS、offboard | +| **`voice_drone/core/`** | 音频采集、预处理、VAD、STT、TTS、Socket、云端 WS、唤醒、命令、文本预处理、配置加载 | +| **`voice_drone/flight_bridge/`** | 伴飞桥(ROS1+MAVROS):`flight_intent` → 飞控;说明见 **`docs/FLIGHT_BRIDGE_ROS1.md`** | +| **`voice_drone/config/`** | 各类 YAML,见下文「配置文件」 | +| **`voice_drone/logging_/`** | 日志与彩色输出 | +| **`voice_drone/tools/`** | `config_loader` 等工具 | +| **`docs/`** | `PROJECT_GUIDE.md`(本文)、`llmcon.md`(云端协议)、`clientguide.md`(联调与示例) | +| **`scripts/`** | `run_px4_offboard_one_terminal.sh`;**伴飞桥** `run_flight_bridge_with_mavros.sh`(含 MAVROS)、`run_flight_intent_bridge_ros1.sh`(仅桥);另有 `generate_wake_greeting_wav.py`、`bundle_for_device.sh` | +| **`assets/tts_cache/`** | 唤醒问候等预生成 WAV(可自动生成) | +| **`models/`** | STT/TTS/VAD ONNX 等(需自备或 bundle,见 `models/README.txt`) | + +--- + +## 3. 配置文件一览 + +配置由 `voice_drone/core/configuration.py` 在进程启动时读入;主文件为 **`voice_drone/config/system.yaml`**(路径相对 **`voice_drone_assistant` 根目录**)。 + +| 文件 | 作用 | +|------|------| +| **`system.yaml`** | **总控**:`audio`(采样、设备、AGC、降噪)、`vad`、`stt`、`tts`、`cloud_voice`、`socket_server`、`text_preprocessor`、`recognizer`(VAD 能量门槛、尾静音、问候/TTS 关麦等) | +| **`wake_word.yaml`** | 唤醒词主词、变体、模糊/部分匹配策略 | +| **`keywords.yaml`** | 命令关键词与同义词(供文本预处理映射到 `Command`) | +| **`command_.yaml`** | 各飞行动作默认 `distance/speed/duration`(与 `Command` 联动) | +| **`cloud_voice_px4_context.yaml`** | 云端 **`session.start.client` 扩展**:`vehicle_class`、`mav_type`、`default_setpoint_frame`、`extras` 等,供服务端 LLM 生成 PX4 相关指令;路径在 `system.yaml` → `cloud_voice.px4_context_file`,也可用环境变量 **`ROCKET_CLOUD_PX4_CONTEXT_FILE`** 覆盖 | + +修改 YAML 后需**重启** `main.py` 生效(`SYSTEM_CLOUD_VOICE_PX4_CONTEXT` 等在 import 时加载一次)。 + +--- + +## 4. 
`system.yaml` 常用区块(索引) + +- **`audio`**:采样率、`frame_size`、`input_device_index`(`null` 则枚举设备)、`prefer_stereo_capture`(ES8388 等)、`noise_reduce`、`agc*`、`agc_release_alpha` +- **`vad`**:Silero 用阈值、`end_frame` 等(能量 VAD 时部分由 `recognizer` 覆盖) +- **`stt`**:SenseVoice 模型路径、ORT 线程等 +- **`tts`**:Kokoro 目录、音色 `voice`、`speed`、`output_device`、`playback_*` +- **`cloud_voice`**:`enabled`、`server_url`、`auth_token`、`device_id`、`timeout`、`fallback_to_local`、`px4_context_file` +- **`socket_server`**:试飞控 TCP 地址、`reconnect_interval`、`max_retries`(`-1` 为断线持续重连直至成功) +- **`recognizer`**:`trailing_silence_seconds`、`vad_backend`(`energy`/`silero`)、`energy_vad_*`、`energy_vad_utt_peak_decay`、`energy_vad_end_peak_ratio`、`pre_speech_max_seconds`、`ack_pause_mic_for_playback`、应答 TTS 等 + +更细的参数含义以各 YAML 内注释为准。 + +--- + +## 5. 系统使用方式 + +### 5.1 推荐启动命令 + +在 **`voice_drone_assistant` 根目录**: + +```bash +bash with_system_alsa.sh python main.py +``` + +或使用模块方式: + +```bash +bash with_system_alsa.sh python -m voice_drone.main_app +``` + +录音设备:**首次**可交互选择;非交互时可设 `ROCKET_INPUT_DEVICE_INDEX` 或使用 `main.py --input-index N` / `--non-interactive`(详见 `main_app` 内 `argparse` 与文件头注释)。 + +### 5.2 典型工作流(默认 `TakeoffPrintRecognizer`) + +1. 说唤醒词(如「无人机」);若**同句带指令**,会**跳过问候与滴声**,直接关麦处理:命中 `keywords.yaml` 的 **takeoff** 则 offboard,否则走 LLM/云端。 +2. 若**只唤醒**,则问候(或缓存 WAV)+ 可选滴声 → 再说**一句**指令。 +3. 
云端模式:指令以文本上云,TTS 多为服务端 PCM;本地模式:Qwen 推理 + Kokoro 播报。 + +### 5.3 云端语音(可选) + +- `system.yaml` 里 `cloud_voice.enabled: true`,或环境变量 **`ROCKET_CLOUD_VOICE=1`** +- **`ROCKET_CLOUD_WS_URL`**、`ROCKET_CLOUD_AUTH_TOKEN`、可选 **`ROCKET_CLOUD_DEVICE_ID`**(可覆盖 yaml) +- PX4 语境:见 `cloud_voice_px4_context.yaml` / **`ROCKET_CLOUD_PX4_CONTEXT_FILE`** +- 协议与消息类型: **`docs/llmcon.md`** +- 飞控 JSON 是否机端执行: **`ROCKET_CLOUD_EXECUTE_FLIGHT=1`**;走 ROS 伴飞桥时再设 **`ROCKET_FLIGHT_INTENT_ROS_BRIDGE=1`**(详见 **`docs/DEPLOYMENT_AND_OPERATIONS.md`**) + +### 5.4 本地大模型与 TTS + +- GGUF:`cache/` 默认路径或 **`ROCKET_LLM_GGUF`** +- 关闭对话:**`ROCKET_LLM_DISABLE=1`** +- 流式输出:**`ROCKET_LLM_STREAM=0`** 可改为整段生成后再播(调试) +- 详细列表见 **`voice_drone/main_app.py` 文件头部注释**。 + +### 5.5 其它实用环境变量(摘录) + +| 变量 | 说明 | +|------|------| +| `ROCKET_ENERGY_VAD` | `1` 时使用能量 VAD(板载麦常见) | +| `ROCKET_PRINT_STT` / `ROCKET_PRINT_VAD` | 终端打印 STT/VAD 诊断 | +| `ROCKET_CLOUD_TURN_RETRIES` | 云端 WS 单轮失败重连重试次数(默认 3) | +| `ROCKET_PRINT_LLM_STREAM` | 云端流式字 `llm.text_delta` 打印到终端 | +| `ROCKET_WAKE_PROMPT_BEEP` | `0` 关闭问候后滴声 | +| `ROCKET_MIC_RESTART_SETTLE_MS` | 播完 TTS 恢复麦克风后的等待毫秒 | + +--- + +## 6. 
相关文档与代码入口 + +| 文档 | 内容 | +|------|------| +| **`README.md`** | 简版说明、bundle 到香橙派、与原仓库关系 | +| **`docs/DEPLOYMENT_AND_OPERATIONS.md`** | **部署与外场启动**:拓扑、`ROS_MASTER_URI`、双终端启动顺序、环境变量速查、联调 | +| **`docs/PROJECT_GUIDE.md`** | 本文:目录、配置、使用方式总览 | +| **`docs/FLIGHT_BRIDGE_ROS1.md`** | ROS1 伴飞桥、MAVROS、`/input`、`rostopic pub` | +| **`docs/llmcon.md`** | 云端 WebSocket 消息类型与客户端约定 | +| **`docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md`** | 云端 **`dialog_result` v1**(`protocol=cloud_voice_dialog_v1`,闲聊/飞控分流 + `confirm`) | + +| 能力 | 主要代码 | +|------|-----------| +| 音频采集/AGC | `voice_drone/core/audio.py` | +| 能量/Silero VAD | `voice_drone/core/recognizer.py`、`voice_drone/core/vad.py` | +| STT | `voice_drone/core/stt.py` | +| 本地 LLM 提示词 | `voice_drone/core/qwen_intent_chat.py`(`FLIGHT_INTENT_CHAT_SYSTEM`) | +| 云端会话 | `voice_drone/core/cloud_voice_client.py` | +| 主流程 | `voice_drone/main_app.py` | +| 配置聚合 | `voice_drone/core/configuration.py` | + +--- + +## 7. 版本与维护 + +- 配置项会随功能迭代增加;若与运行日志或 `llmcon` 不一致,以**当前仓库 YAML + 代码**为准。 +- 新增仅与云端相关的字段时,请同时通知服务端解析 **`session.start.client`**(含 PX4 扩展块)。 + +--- + +*文档版本:与仓库同步维护;更新日期见 Git 提交。* diff --git a/main.py b/main.py new file mode 100644 index 0000000..0b9327f --- /dev/null +++ b/main.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +"""入口:请在工程根目录 voice_drone_assistant 下运行。 + + bash with_system_alsa.sh python main.py + +或使用包方式:python -m voice_drone.main_app(需先 cd 到本目录)。""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) +try: + os.chdir(ROOT) +except OSError: + pass + +from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio + +fix_ld_path_for_portaudio() + +if __name__ == "__main__": + from voice_drone.main_app import main + + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..99b5925 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,25 @@ +# 语音助手最小依赖(从大工程 
requirements 精简,若 import 报错再补包) +numpy>=1.21.0 +scipy>=1.10.0 +onnx>=1.14.0 +onnxruntime>=1.16.0 +librosa>=0.10.0 +soundfile>=0.12.0 +pyaudio>=0.2.14 +noisereduce>=2.0.0 +sounddevice>=0.4.6 +pyyaml>=5.4.0 +jieba>=0.42.1 +pypinyin>=0.50.0 +opencc-python-reimplemented>=0.1.7 +cn2an>=0.5.0 +misaki[zh]>=0.8.2 +pydantic>=2.4.0 +# 大模型(Qwen GGUF) +llama-cpp-python>=0.2.0 +# 云端语音 WebSocket(websocket-client) +websocket-client>=1.6.0 + +# ROS1:publish_flight_intent_ros_once / flight_bridge 需 rospy(std_msgs)。Noetic 一般 apt 安装,勿用 pip 覆盖: +# sudo apt install ros-noetic-ros-base # 或桌面版,须含 rospy +# 使用 conda/venv 时请在启动前 source /opt/ros/noetic/setup.bash,且保持 PYTHONPATH 含上述 dist-packages(主程序 ROS 桥子进程已把工程根 prepend 到 $PYTHONPATH)。 diff --git a/scripts/bundle_for_device.sh b/scripts/bundle_for_device.sh new file mode 100644 index 0000000..15c7c54 --- /dev/null +++ b/scripts/bundle_for_device.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# 在本机把「原仓库」里的 models(及可选 GGUF 缓存)拷进本子工程,便于整包 scp/rsync 到另一台香橙派。 +# +# 用法(在 voice_drone_assistant 根目录): +# bash scripts/bundle_for_device.sh +# bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio +# +# 默认上一级目录为原仓库(即 voice_drone_assistant 仍放在 rocket_drone_audio 子目录时的布局)。 +set -euo pipefail +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +SRC="${1:-$ROOT/..}" +M="$ROOT/models" + +if [[ ! -d "$SRC/src/models" ]]; then + echo "未找到 $SRC/src/models ,请传入正确的原仓库根路径:" >&2 + echo " bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio" >&2 + exit 1 +fi + +mkdir -p "$M" +echo "源目录: $SRC" +echo "目标 models: $M" + +copy_dir() { + local name="$1" + if [[ -d "$SRC/src/models/$name" ]]; then + echo " 复制 models/$name ..." 
+ rm -rf "$M/$name" + cp -a "$SRC/src/models/$name" "$M/" + else + echo " 跳过(不存在): src/models/$name" + fi +} + +copy_dir "SenseVoiceSmall" +copy_dir "Kokoro-82M-v1.1-zh-ONNX" +copy_dir "SileroVad" + +# 可选:大模型 GGUF(体积大,按需) +if [[ -d "$SRC/cache/qwen25-1.5b-gguf" ]]; then + read -r -p "是否复制 Qwen GGUF 到 cache/?(可能数百 MB~数 GB)[y/N] " ans + if [[ "${ans:-}" =~ ^[yY] ]]; then + mkdir -p "$ROOT/cache/qwen25-1.5b-gguf" + cp -a "$SRC/cache/qwen25-1.5b-gguf/"* "$ROOT/cache/qwen25-1.5b-gguf/" 2>/dev/null || true + echo " 已复制 cache/qwen25-1.5b-gguf" + fi +else + echo " 未找到 $SRC/cache/qwen25-1.5b-gguf ,大模型请在新机器上再下载或单独拷贝" +fi + +echo +echo "完成。可将整个目录打包拷贝到另一台设备:" +echo " $ROOT" +echo "新设备上请执行: pip install -r requirements.txt(或使用相同 conda 环境)" diff --git a/scripts/generate_wake_greeting_wav.py b/scripts/generate_wake_greeting_wav.py new file mode 100644 index 0000000..64bca7b --- /dev/null +++ b/scripts/generate_wake_greeting_wav.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +"""生成 assets/tts_cache/wake_greeting.wav。须与 voice_drone.main_app._WAKE_GREETING 一致。 + +用法(在 voice_drone_assistant 根目录): + python scripts/generate_wake_greeting_wav.py +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +_ROOT = Path(__file__).resolve().parents[1] +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) +try: + os.chdir(_ROOT) +except OSError: + pass + +from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio + +fix_ld_path_for_portaudio() + +_WAKE_TEXT = "你好,我在呢" + + +def main() -> None: + out_dir = _ROOT / "assets" / "tts_cache" + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "wake_greeting.wav" + + from voice_drone.core.tts import KokoroOnnxTTS + + tts = KokoroOnnxTTS() + tts.synthesize_to_file(_WAKE_TEXT, str(out_path)) + print(f"已写入: {out_path}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_flight_bridge_with_mavros.sh b/scripts/run_flight_bridge_with_mavros.sh new file mode 
100644 index 0000000..c20a1c6 --- /dev/null +++ b/scripts/run_flight_bridge_with_mavros.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +# 一键:roscore(若还没有)→ MAVROS(串口连 PX4)→ flight_intent 伴飞桥(前台,直到 Ctrl+C) +# +# 适合「不想自己敲 roslaunch」时直接跑;用法与 run_px4_offboard_one_terminal.sh 前两参一致。 +# +# 在 voice_drone_assistant 根目录: +# bash scripts/run_flight_bridge_with_mavros.sh +# bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600 +# +# 飞控连上后,另开终端发 JSON(std_msgs/String 必须用 YAML 字段 data,见 docs/FLIGHT_BRIDGE_ROS1.md): +# source /opt/ros/noetic/setup.bash +# rostopic pub -1 /input std_msgs/String \ +# "data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}'" +# +# 环境变量(可选): +# ROS_MASTER_URI 默认 http://127.0.0.1:11311 +# ROS_HOSTNAME 默认 127.0.0.1 +# BRIDGE_PYTHON 若不设则 python3(与 mavros 同机即可,不必 conda) +# OFFBOARD_PYTHON 未设 BRIDGE_PYTHON 时也试 yanshi 环境(与 offboard 脚本习惯一致) +# +# Ctrl+C:结束伴飞桥,并停止本脚本拉起的 MAVROS;若 roscore 由本脚本启动则一并结束。 + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +cd "$ROOT" + +if [[ ! 
-f /opt/ros/noetic/setup.bash ]]; then + echo "未找到 /opt/ros/noetic/setup.bash(伴飞桥当前仅支持 ROS1 Noetic)" >&2 + exit 2 +fi + +# shellcheck source=/dev/null +source /opt/ros/noetic/setup.bash +export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}" +export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}" + +DEV="${1:-/dev/ttyACM0}" +BAUD="${2:-921600}" +FCU_URL="${DEV}:${BAUD}" + +ROSCORE_PID="" +MAVROS_PID="" +WE_STARTED_ROSCORE=0 + +master_ok() { + timeout 3 rosnode list &>/dev/null +} + +stop_mavros() { + pkill -f '/opt/ros/noetic/lib/mavros/mavros_node' 2>/dev/null || true +} + +kill_children() { + if [[ -n "${MAVROS_PID:-}" ]] && kill -0 "$MAVROS_PID" 2>/dev/null; then + kill "$MAVROS_PID" 2>/dev/null || true + wait "$MAVROS_PID" 2>/dev/null || true + fi + if [[ "${WE_STARTED_ROSCORE}" -eq 1 ]] && [[ -n "${ROSCORE_PID:-}" ]] && kill -0 "$ROSCORE_PID" 2>/dev/null; then + kill "$ROSCORE_PID" 2>/dev/null || true + wait "$ROSCORE_PID" 2>/dev/null || true + fi +} + +trap 'kill_children; exit 130' INT +trap 'kill_children; exit 143' TERM + +resolve_python() { + if [[ -n "${BRIDGE_PYTHON:-}" ]] && [[ -x "$BRIDGE_PYTHON" ]]; then + echo "$BRIDGE_PYTHON" + return + fi + if [[ -n "${OFFBOARD_PYTHON:-}" ]] && [[ -x "$OFFBOARD_PYTHON" ]]; then + echo "$OFFBOARD_PYTHON" + return + fi + for cand in \ + "${HOME}/miniconda3/envs/yanshi/bin/python" \ + "${HOME}/anaconda3/envs/yanshi/bin/python" \ + "${HOME}/mambaforge/envs/yanshi/bin/python"; do + if [[ -x "$cand" ]]; then + echo "$cand" + return + fi + done + command -v python3 +} + +echo "===== flight_intent 伴飞桥(含 MAVROS)fcu_url=${FCU_URL} =====" + +if ! master_ok; then + echo "[1/3] 启动 roscore …" + roscore > /tmp/roscore_flight_bridge.log 2>&1 & + ROSCORE_PID=$! + WE_STARTED_ROSCORE=1 + for _ in $(seq 1 50); do + master_ok && break + sleep 0.2 + done + if ! 
master_ok; then + echo "roscore 未起来: tail -40 /tmp/roscore_flight_bridge.log" >&2 + kill_children + exit 1 + fi + echo "roscore 已就绪 (pid=$ROSCORE_PID)" +else + echo "[1/3] 已有 ROS master,跳过 roscore" +fi + +echo "[2/3] 启动 MAVROS …" +stop_mavros +sleep 1 +roslaunch mavros px4.launch fcu_url:="$FCU_URL" > /tmp/mavros_flight_bridge.log 2>&1 & +MAVROS_PID=$! + +echo "等待飞控 connected=true(超时 60s)…" +connected=0 +for _ in $(seq 1 60); do + if timeout 4 rostopic echo /mavros/state -n 1 2>/dev/null | grep -qE 'connected: [Tt]rue'; then + connected=1 + break + fi + sleep 1 +done +if [[ "$connected" -ne 1 ]]; then + echo "仍未连上飞控。检查串口/波特率: tail -60 /tmp/mavros_flight_bridge.log" >&2 + echo "可重试: bash $0 ${DEV} 57600" >&2 + kill_children + exit 1 +fi +echo "MAVROS 已连接飞控" + +PY="$(resolve_python)" +export PYTHONPATH="${PYTHONPATH:-}:${ROOT}" + +echo "[3/3] 启动伴飞桥(Python: $PY)…" +echo " Ctrl+C 退出并清理本脚本启动的 MAVROS$([[ "$WE_STARTED_ROSCORE" -eq 1 ]] && echo +roscore)。" +echo " 发意图: 另开终端 source noetic 后 rostopic pub ... 见脚本头注释。" +set +e +"$PY" -m voice_drone.flight_bridge.ros1_node +BRIDGE_EXIT=$? 
+set -e + +kill_children +exit "$BRIDGE_EXIT" diff --git a/scripts/run_flight_intent_bridge_ros1.sh b/scripts/run_flight_intent_bridge_ros1.sh new file mode 100644 index 0000000..89d442f --- /dev/null +++ b/scripts/run_flight_intent_bridge_ros1.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# 在已有 ROS master + MAVROS(已连飞控)的前提下启动 flight_intent 伴飞桥节点。 +# 若希望一键连 MAVROS:用同目录 run_flight_bridge_with_mavros.sh +# +# 用法(在 voice_drone_assistant 根目录): +# bash scripts/run_flight_intent_bridge_ros1.sh +# bash scripts/run_flight_intent_bridge_ros1.sh my_bridge # 节点名前缀(anonymous 仍会加后缀) +# +# 另开终端发意图(示例降落,默认订阅全局 /input): +# rostopic pub -1 /input std_msgs/String \ +# "{data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}'}" +# +# 若改了 ~input_topic:rosnode info <节点名> 查看订阅话题 +# +# 环境变量(与 run_flight_bridge_with_mavros.sh 一致,未设置时给默认值): +# ROS_MASTER_URI 默认 http://127.0.0.1:11311 +# ROS_HOSTNAME 默认 127.0.0.1 +# 注意:这些只在「本脚本进程」里生效;另开终端调试 rostopic/rosservice 时须自行 source noetic 并 export 相同 URI,或与跑 roscore 的机器一致。 + +set -euo pipefail +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT" + +if [[ ! 
-f /opt/ros/noetic/setup.bash ]]; then + echo "未找到 /opt/ros/noetic/setup.bash(当前桥仅支持 ROS1 Noetic)" >&2 + exit 2 +fi + +# shellcheck source=/dev/null +source /opt/ros/noetic/setup.bash + +export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}" +export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}" + +export PYTHONPATH="${PYTHONPATH}:${ROOT}" + +exec python3 -m voice_drone.flight_bridge.ros1_node "$@" diff --git a/scripts/run_px4_offboard_one_terminal.sh b/scripts/run_px4_offboard_one_terminal.sh new file mode 100644 index 0000000..611ef06 --- /dev/null +++ b/scripts/run_px4_offboard_one_terminal.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash +# 单终端一键:按需 roscore → MAVROS 串口(MAVLink) → px4_ctrl_offboard_demo.py +# +# 用法(在仓库根目录): +# bash scripts/run_px4_offboard_one_terminal.sh +# bash scripts/run_px4_offboard_one_terminal.sh /dev/ttyACM0 921600 +# bash scripts/run_px4_offboard_one_terminal.sh /dev/ttyACM0 921600 120 +# 第三参 DEMO_MAX_SEC:仅限制「px4_ctrl_offboard_demo.py」运行时长(秒),到时发 SIGTERM 结束 +# demo,随后脚本清理 MAVROS / 可选 roscore 并退出;0 或不写表示不限制。 +# 也可用环境变量 OFFBOARD_DEMO_MAX_SEC(第三参优先)。 +# +# 环境变量(可选): +# ROS_MASTER_URI 默认 http://127.0.0.1:11311 +# ROS_HOSTNAME 默认 127.0.0.1 +# OFFBOARD_PYTHON 默认优先用 conda env「yanshi」里的 python,否则 python3 +# OFFBOARD_DEMO_MAX_SEC 同第三参;整条脚本从头限时请用: timeout 600 bash 本脚本 ... +# ROCKET_OFFBOARD_DEMO_PY 显式指定 px4_ctrl_offboard_demo.py 路径(缺省时先 $ROOT/src/ 再 $ROOT/../src/) +# +# 退出:Ctrl+C 会结束 demo,并停止本脚本拉起的 MAVROS;若 roscore 由本脚本启动,会一并结束。 + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "$SCRIPT_DIR/.." 
&& pwd)" +cd "$ROOT" + +resolve_demo_py() { + if [[ -n "${ROCKET_OFFBOARD_DEMO_PY:-}" ]] && [[ -f "${ROCKET_OFFBOARD_DEMO_PY}" ]]; then + echo "$(cd "$(dirname "${ROCKET_OFFBOARD_DEMO_PY}")" && pwd)/$(basename "${ROCKET_OFFBOARD_DEMO_PY}")" + return + fi + local c1="$ROOT/src/px4_ctrl_offboard_demo.py" + if [[ -f "$c1" ]]; then + echo "$c1" + return + fi + local c2="$ROOT/../src/px4_ctrl_offboard_demo.py" + if [[ -f "$c2" ]]; then + echo "$(cd "$(dirname "$c2")" && pwd)/px4_ctrl_offboard_demo.py" + return + fi + echo "错误: 找不到 px4_ctrl_offboard_demo.py。已尝试:" >&2 + echo " ROCKET_OFFBOARD_DEMO_PY=${ROCKET_OFFBOARD_DEMO_PY:-(未设置)}" >&2 + echo " $c1" >&2 + echo " $c2" >&2 + exit 3 +} + +DEMO_PY="$(resolve_demo_py)" + +source /opt/ros/noetic/setup.bash +export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}" +export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}" + +DEV="${1:-/dev/ttyACM0}" +BAUD="${2:-921600}" +DEMO_MAX_SEC="${3:-${OFFBOARD_DEMO_MAX_SEC:-0}}" +FCU_URL="${DEV}:${BAUD}" + +if ! 
[[ "$DEMO_MAX_SEC" =~ ^[0-9]+$ ]]; then + echo "错误: 第三参(demo 最长运行秒数)须为非负整数,当前=${DEMO_MAX_SEC}" + exit 2 +fi + +ROSCORE_PID="" +MAVROS_PID="" +WE_STARTED_ROSCORE=0 + +master_ok() { + timeout 3 rosnode list &>/dev/null +} + +stop_mavros() { + pkill -f '/opt/ros/noetic/lib/mavros/mavros_node' 2>/dev/null || true +} + +kill_children() { + if [[ -n "${MAVROS_PID:-}" ]] && kill -0 "$MAVROS_PID" 2>/dev/null; then + kill "$MAVROS_PID" 2>/dev/null || true + wait "$MAVROS_PID" 2>/dev/null || true + fi + if [[ "${WE_STARTED_ROSCORE}" -eq 1 ]] && [[ -n "${ROSCORE_PID:-}" ]] && kill -0 "$ROSCORE_PID" 2>/dev/null; then + kill "$ROSCORE_PID" 2>/dev/null || true + wait "$ROSCORE_PID" 2>/dev/null || true + fi +} + +trap 'kill_children; exit 130' INT +trap 'kill_children; exit 143' TERM + +resolve_python() { + if [[ -n "${OFFBOARD_PYTHON:-}" ]] && [[ -x "$OFFBOARD_PYTHON" ]]; then + echo "$OFFBOARD_PYTHON" + return + fi + for cand in \ + "${HOME}/miniconda3/envs/yanshi/bin/python" \ + "${HOME}/anaconda3/envs/yanshi/bin/python" \ + "${HOME}/mambaforge/envs/yanshi/bin/python"; do + if [[ -x "$cand" ]]; then + echo "$cand" + return + fi + done + command -v python3 +} + +if ! master_ok; then + echo "[1/3] 启动 roscore ..." + roscore > /tmp/roscore_offboard_one_terminal.log 2>&1 & + ROSCORE_PID=$! + WE_STARTED_ROSCORE=1 + for _ in $(seq 1 50); do + master_ok && break + sleep 0.2 + done + if ! master_ok; then + echo "roscore 未起来,请查看: tail -40 /tmp/roscore_offboard_one_terminal.log" + kill_children + exit 1 + fi + echo "roscore 已就绪 (pid=$ROSCORE_PID)" +else + echo "[1/3] 已有 ROS master,跳过 roscore" +fi + +echo "[2/3] 启动 MAVROS fcu_url=${FCU_URL}" +stop_mavros +sleep 1 +roslaunch mavros px4.launch fcu_url:="$FCU_URL" > /tmp/mavros_offboard_one_terminal.log 2>&1 & +MAVROS_PID=$! + +echo "等待飞控连接 (connected=true),超时 60s ..." 
+connected=0 +for _ in $(seq 1 60); do + if timeout 4 rostopic echo /mavros/state -n 1 2>/dev/null | grep -qE 'connected: [Tt]rue'; then + connected=1 + break + fi + sleep 1 +done +if [[ "$connected" -ne 1 ]]; then + echo "仍未连上飞控。请检查串口与波特率,日志: tail -60 /tmp/mavros_offboard_one_terminal.log" + echo "可重试: bash $0 ${DEV} 57600" + kill_children + exit 1 +fi +echo "MAVROS 已连接飞控" + +PY="$(resolve_python)" +echo "[3/3] 使用 Python: $PY" +if [[ "$DEMO_MAX_SEC" -gt 0 ]]; then + echo "运行 $DEMO_PY,${DEMO_MAX_SEC}s 后自动结束 demo 并退出本脚本(随后清理子进程)" +else + echo "运行 $DEMO_PY (Ctrl+C 结束并将停止本脚本启动的 MAVROS)" +fi +set +e +if [[ "$DEMO_MAX_SEC" -gt 0 ]]; then + timeout --signal=TERM --kill-after=8 "$DEMO_MAX_SEC" \ + "$PY" "$DEMO_PY" + DEMO_EXIT=$? + if [[ "$DEMO_EXIT" -eq 124 ]]; then + echo "[timeout] 已到 ${DEMO_MAX_SEC}s,已终止 px4_ctrl_offboard_demo.py(退出码 124)" + fi +else + "$PY" "$DEMO_PY" + DEMO_EXIT=$? +fi +set -e + +kill_children +exit "$DEMO_EXIT" diff --git a/tests/test_cloud_dialog_v1.py b/tests/test_cloud_dialog_v1.py new file mode 100644 index 0000000..2fa0198 --- /dev/null +++ b/tests/test_cloud_dialog_v1.py @@ -0,0 +1,44 @@ +"""cloud_voice_dialog_v1 短语与 confirm 解析。""" +from voice_drone.core.cloud_dialog_v1 import ( + match_phrase_list, + normalize_phrase_text, + parse_confirm_dict, +) + + +def test_normalize(): + assert normalize_phrase_text(" a b ") == "a b" + + +def test_match_phrases(): + assert match_phrase_list(normalize_phrase_text("确认"), ["确认"]) + assert match_phrase_list(normalize_phrase_text("取消"), ["取消"]) + assert match_phrase_list(normalize_phrase_text("好的确认"), ["确认"]) + # 复述长提示:不应仅因子串「取消」命中 + long_prompt = normalize_phrase_text("请回复确认或取消") + assert match_phrase_list(long_prompt, ["取消"]) is False + assert match_phrase_list(long_prompt, ["确认"]) is False + + +def test_parse_confirm_ok(): + d = parse_confirm_dict( + { + "required": True, + "timeout_sec": 10, + "confirm_phrases": ["确认"], + "cancel_phrases": ["取消"], + "pending_id": "p1", + } + ) + assert d is not None + 
assert d["required"] is True + assert d["timeout_sec"] == 10.0 + assert "确认" in d["confirm_phrases"] + + +def test_parse_confirm_reject_bad_required(): + assert parse_confirm_dict({"required": "yes"}) is None + + +def test_cancel_priority_concept(): + assert match_phrase_list(normalize_phrase_text("取消"), ["取消"]) diff --git a/tests/test_flight_intent.py b/tests/test_flight_intent.py new file mode 100644 index 0000000..dfc9b7d --- /dev/null +++ b/tests/test_flight_intent.py @@ -0,0 +1,160 @@ +"""flight_intent v1 校验与 goto→Command 映射。""" + +from __future__ import annotations + +import pytest + +from voice_drone.core.flight_intent import ( + goto_action_to_command, + parse_flight_intent_dict, +) + + +def _command_stack_available() -> bool: + try: + import yaml # noqa: F401 + + from voice_drone.core.command import Command # noqa: F401 + + return True + except Exception: + return False + + +needs_command_stack = pytest.mark.skipif( + not _command_stack_available(), + reason="需要 pyyaml 与工程配置以加载 Command", +) + + +def test_parse_minimal_ok(): + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [{"type": "land", "args": {}}], + "summary": "降落", + } + ) + assert err == [] + assert v is not None + assert v.actions[0].type == "land" + + +def test_hover_duration_cloud_legacy_expands_to_wait(): + """云端误将停顿时长写在 hover.args.duration 时,客户端规范化为 hover + wait。""" + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [ + {"type": "takeoff", "args": {}}, + {"type": "hover", "args": {"duration": 3}}, + {"type": "land", "args": {}}, + ], + "summary": "test", + } + ) + assert err == [] + assert v is not None + assert len(v.actions) == 4 + assert v.actions[0].type == "takeoff" + assert v.actions[1].type == "hover" + assert v.actions[2].type == "wait" + assert float(v.actions[2].args.seconds) == 3.0 + assert v.actions[3].type == "land" + + +def test_wait_after_hover_ok(): + v, err = 
parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [ + {"type": "takeoff", "args": {}}, + {"type": "hover", "args": {}}, + {"type": "wait", "args": {"seconds": 2.5}}, + {"type": "land", "args": {}}, + ], + "summary": "test", + } + ) + assert err == [] + assert v is not None + assert len(v.actions) == 4 + assert v.actions[2].type == "wait" + assert v.trace_id is None + + +def test_first_wait_rejected(): + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [ + {"type": "wait", "args": {"seconds": 1}}, + {"type": "land", "args": {}}, + ], + "summary": "x", + } + ) + assert v is None + assert err + + +def test_extra_top_key_rejected(): + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [{"type": "land", "args": {}}], + "summary": "x", + "foo": 1, + } + ) + assert v is None + assert err + + +@needs_command_stack +def test_goto_body_forward(): + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [ + { + "type": "goto", + "args": {"frame": "body_ned", "x": 3}, + } + ], + "summary": "s", + } + ) + assert v is not None and not err + cmd, reason = goto_action_to_command(v.actions[0], sequence_id=7) + assert reason is None + assert cmd is not None + assert cmd.command == "forward" + assert cmd.sequence_id == 7 + + +@needs_command_stack +def test_goto_multi_axis_no_command(): + v, err = parse_flight_intent_dict( + { + "is_flight_intent": True, + "version": 1, + "actions": [ + { + "type": "goto", + "args": {"frame": "body_ned", "x": 1, "y": 1}, + } + ], + "summary": "s", + } + ) + assert v is not None and not err + cmd, reason = goto_action_to_command(v.actions[0], 1) + assert cmd is None + assert reason and "multi-axis" in reason diff --git a/voice_drone/__init__.py b/voice_drone/__init__.py new file mode 100644 index 0000000..a1e3446 --- /dev/null +++ b/voice_drone/__init__.py @@ -0,0 +1,3 @@ 
+"""语音无人机助手:采集 → VAD → STT → 唤醒 → LLM/起飞 → Kokoro 播报。""" + +__version__ = "0.1.0" diff --git a/voice_drone/config/cloud_voice_px4_context.yaml b/voice_drone/config/cloud_voice_px4_context.yaml new file mode 100644 index 0000000..bd56c4a --- /dev/null +++ b/voice_drone/config/cloud_voice_px4_context.yaml @@ -0,0 +1,25 @@ +# 云端 session.start → client 扩展字段(PX4 / MAVLink 语境) +# 与服务端约定:原样进入 LLM;vehicle_class 与 mav_type 应一致,冲突时以 mav_type 为准。 +# +# 修改后重启 main.py;可用环境变量 ROCKET_CLOUD_PX4_CONTEXT_FILE 覆盖本文件路径。 + +# 机体类别(自由文本,与 mav_type 语义对齐即可,例如 multicopter / fixed_wing) +vehicle_class: multicopter + +# MAV_TYPE 枚举整数值,参见 https://mavlink.io/en/messages/common.html#MAV_TYPE +mav_type: 2 + +# 口语「前/右/上」未说明坐标系时,云端生成 goto 的默认系: +# local_ned — 北东地(与常见 Offboard 位置设定一致) +# body_ned / body — 机体系前右下(仅当机端按体轴解析相对位移时填写) +default_setpoint_frame: local_ned + +# 以下为可选运行时状态(可由 MAVROS / ROS2 写入同文件,或由机载进程覆写后再连云端) +# home_position_valid: true +# offboard_capable: true +# current_nav_state: "OFFBOARD" + +# 任意短键名 JSON,进入 LLM;适合「电池低」「室内无 GPS」「仅限 mission」等 +extras: + platform: companion_orangepi + # indoor_no_gps: true diff --git a/voice_drone/config/command_.yaml b/voice_drone/config/command_.yaml new file mode 100644 index 0000000..fb2e841 --- /dev/null +++ b/voice_drone/config/command_.yaml @@ -0,0 +1,59 @@ +# 命令配置文件 用于命令生成和填充默认值 +control_params: + # 起飞默认参数 + takeoff: + distance: 0.5 + speed: 0.5 + duration: 1 + # 降落默认参数 + land: + distance: 0 + speed: 0 + duration: 0 + # 跟随默认参数 + follow: + distance: 0.5 + speed: 0.5 + duration: 2 + # 向前默认参数 + forward: + distance: 0.5 + speed: 0.5 + duration: 1 + # 向后默认参数 + backward: + distance: 0.5 + speed: 0.5 + duration: 1 + # 向左默认参数 + left: + distance: 0.5 + speed: 0.5 + duration: 1 + # 向右默认参数 + right: + distance: 0.5 + speed: 0.5 + duration: 1 + # 向上默认参数 + up: + distance: 0.5 + speed: 0.5 + duration: 1 + # 向下默认参数 + down: + distance: 0.5 + speed: 0.5 + duration: 1 + # 悬停默认参数 + hover: + distance: 0 + speed: 0 + duration: 5 + # 返航(Socket 协议与 land/hover 同类占位参数) + 
return_home: + distance: 0 + speed: 0 + duration: 0 + + diff --git a/voice_drone/config/keywords.yaml b/voice_drone/config/keywords.yaml new file mode 100644 index 0000000..7e0b61d --- /dev/null +++ b/voice_drone/config/keywords.yaml @@ -0,0 +1,71 @@ +keywords: + # takeoff 仅用于「一键 offboard 演示」唤醒路径;用「起飞演示」避免句子里单独出现「起飞」误触(如「起飞,悬停再降落」) + takeoff: + - "起飞演示" + - "演示起飞" + + land: + - "立刻降落" + - "紧急降落" + - "降落" + - "落地" + - "着陆" + + follow: + - "跟随" + - "跟着我" + - "跟我飞" + - "跟随模式" + + hover: + - "马上停下" + - "立刻停下" + - "悬停" + - "停下" + - "停止" + - "停" + + forward: + - "向前飞" + - "往前飞" + - "向前" + - "往前" + - "前面飞" + - "前进" + + backward: + - "向后飞" + - "往后飞" + - "向后" + - "往后" + - "后退" + + left: + - "向左飞" + - "往左飞" + - "向左" + - "往左" + - "左移" + + right: + - "向右飞" + - "往右飞" + - "向右" + - "往右" + - "右移" + + up: + - "向上飞" + - "往上飞" + - "向上" + - "往上" + - "上升" + - "升高" + + down: + - "向下飞" + - "往下飞" + - "向下" + - "往下" + - "下降" + - "降低" \ No newline at end of file diff --git a/voice_drone/config/system.yaml b/voice_drone/config/system.yaml new file mode 100644 index 0000000..a72719f --- /dev/null +++ b/voice_drone/config/system.yaml @@ -0,0 +1,246 @@ +# ***********音频采集配置 ************* +audio: + sample_rate: 16000 + channels: 1 + sample_width: 2 + frame_size: 1024 + audio_format: "wav" + # 麦克风:仅使用 input_device_index(PyAudio 整数)。null = 自动尝试默认输入 + 所有 in>0 设备。 + # 运行 python src/rocket_drone_audio.py 会默认打印 arecord -l、PyAudio 列表与 hw 映射并交互选择。 + # 自动化可加 --non-interactive 并在此写入整数索引;或传 --input-index / ROCKET_INPUT_DEVICE_INDEX。 + input_device_index: null + # 以下字段已不再参与选麦逻辑(可删或保留作备忘) + input_strict_selection: false + input_hw_card_device: null + input_device_name_match: null + # 设备报告双声道、但 channels 为 1 时,用立体声打开再下混为 mono(Orange Pi ES8388 等需开启) + prefer_stereo_capture: true + # ES8388 常需 48k 才能打开;打开后会在采集里重采样回 sample_rate + audio_open_try_rates: [16000, 48000, 44100] + # conda 下 PyAudio 常链到残缺 ALSA 插件:请在 shell 用 bash with_system_alsa.sh python … 启动 + # ES8388 大声说话仍只有 RMS≈30:多为 ALSA「Left/Right Channel」采集=0%,先执行 
scripts/es8388_capture_up.sh 再 sudo alsactl store + + # 高性能优化配置 + high_performance_mode: true # 启用高性能模式(多线程+异步处理) + use_callback_mode: true # 使用回调模式(非阻塞) + buffer_queue_size: 10 # 音频缓冲队列大小 + processing_threads: 2 # 处理线程数(采集和处理并行) + batch_processing: true # 启用批处理优化 + + # 降噪配置 + noise_reduce: true + noise_reduction_method: "lightweight" # "lightweight" (轻量级) 或 "noisereduce" (完整版) + noise_sample_duration_ms: 500 # 噪声样本收集时长(毫秒) + noise_reduction_cutoff_hz: 80.0 # 轻量级降噪高通滤波截止频率(Hz) + + # 自动增益控制配置 + agc: true + agc_method: "incremental" # "incremental" (增量) 或 "standard" (标准) + agc_target_db: -20.0 # AGC 目标音量(dB) + # 过小(如 0.1)时在短时强噪声后易把波形压到 int16 近 0,能量 VAD 长期收不到音;建议 0.25~0.5 + agc_gain_min: 0.25 # AGC 最小增益倍数 + agc_gain_max: 10.0 # AGC 最大增益倍数 + agc_rms_threshold: 1e-6 # AGC RMS 阈值(避免除零) + agc_smoothing_alpha: 0.1 # 压低增益时的平滑系数(0-1,越小越慢) + agc_release_alpha: 0.45 # 需要抬增益时(巨响/小声后恢复)用更大系数,由 audio.py 读取 + +# ***********语音活动检测配置 ************* +vad: + # 略降低门槛,避免板载麦/经 AGC 后仍达不到 0.65 + threshold: 0.45 + start_frame: 2 + # 句尾连续静音块数(每块时长≈audio.frame_size/sample_rate);recognizer.trailing_silence_seconds 优先覆盖此项与 energy_vad_end_chunks + end_frame: 10 + min_silence_duration_s: 0.5 + max_silence_duration_s: 30 + model_path: "models/SileroVad/silero_vad.onnx" + +# ***********语音识别配置 ************* +stt: + # 模型路径配置 + model_dir: "models/SenseVoiceSmall" # 模型目录 + model_path: "models/SenseVoiceSmall/model.int8.onnx" # 直接指定模型路径(优先级高于 model_dir) + prefer_int8: true # 是否优先使用 INT8 量化模型 + warmup_file: "" # 预热音频文件 + + # 音频预处理配置 + sample_rate: 16000 # 采样率(与 audio 配置保持一致) + n_mels: 80 # Mel 滤波器数量 + frame_length_ms: 25 # 帧长度(毫秒) + frame_shift_ms: 10 # 帧移(毫秒) + log_eps: 1e-10 # log 计算时的极小值(避免 log(0)) + + # ARM 设备优化配置 + arm_optimization: + enabled: true # 是否启用 ARM 优化 + max_threads: 4 # 最大线程数(RK3588 使用 4 个大核) + + # CTC 解码配置 + ctc_decode: + blank_id: 0 # 空白 token ID + + # 语言和文本规范化配置(默认值,实际从模型元数据读取) + language: + zh_id: 3 # 中文语言ID(默认值) + text_norm: + with_itn_id: 14 # 使用 ITN 的 ID(默认值) + without_itn_id: 15 # 
不使用 ITN 的 ID(默认值) + + # 后处理配置 + postprocess: + special_tokens: # 需要移除的特殊 token + - "<|zh|>" + - "<|NEUTRAL|>" + - "<|Speech|>" + - "<|woitn|>" + - "<|withitn|>" + +# ***********文本转语音配置 ************* +tts: + # Kokoro ONNX 中文模型目录 + model_dir: "models/Kokoro-82M-v1.1-zh-ONNX" + + # ONNX 子目录中的模型文件名 + # 清晰度优先: model_fp16.onnx / model.onnx > model_q4f16.onnx > int8/uint8 + # 速度(板端 CPU):可试 model_q4f16.onnx;uint8 有时更慢(取决于 ORT/CPU) + # 若仓库中仅有量化版,可改回 model_uint8.onnx 或 model_q4f16.onnx + model_name: "model_uint8.onnx" + + # 语音风格(对应 voices 目录下的 *.bin, 这里不写扩展名) + # 女声常用 zf_001 / zf_002;可换 zm_* 男声。以本机 voices 目录实际文件为准。 + voice: "zm_009" + + # 语速系数(1.0 最自然; >1 易显赶、含糊; <1 更稳但略慢) + speed: 1.15 + # 输出采样率(与 Kokoro 模型保持一致, 官方为 24000Hz) + sample_rate: 24000 + + # sounddevice 播放输出设备(命令应答语音走扬声器时请指定,避免播到虚拟声卡/耳机) + # null:使用系统默认输出设备 + # 整数:设备索引(启动日志会列出「sounddevice 输出设备列表」供对照) + # 字符串:设备名称子串匹配(不区分大小写),例如 "扬声器"、"Speakers"、"Realtek" + # 香橙派走 HDMI 出声时 PortAudio 名常含 rockchip-hdmi0,可设 output_device: "rockchip-hdmi0"(与 ROCKET_TTS_DEVICE 一致) + output_device: null + + # 播放前重采样到该输出设备的 default_samplerate(Windows/WASAPI 下 24000Hz 常无声,强烈建议 true) + playback_resample_to_device_native: true + + # 播放前将峰值压到约 0.92,减轻削波导致的爆音/杂音 + playback_peak_normalize: true + # 播放音量增益(波形幅度乘法,1.0 不变;1.3~1.8 更响,过大可能削波失真) + playback_gain: 2.5 + # 首尾淡入淡出(毫秒),减轻驱动/缓冲区切换时的爆音与「咔哒」声 + playback_edge_fade_ms: 8 + # sounddevice OutputStream 延迟: low / medium / high(high 易积缓冲,部分机器听感发闷或拖尾) + playback_output_latency: low + +# ***********云端语音(LLM + TTS 上云,见 clientguide.md)************* +cloud_voice: + enabled: false + server_url: "ws://192.168.0.186:8766/v1/voice/session" + auth_token: "drone-voice-cloud-token-2024" + device_id: "drone-001" + timeout: 120 + # PROMPT_LISTEN:麦克 RMS 持续低于 recognizer.energy_vad_rms_low 的累计秒数 ≥ 此值则超时(非滴声后固定墙上时钟);消抖/提示音见下 + listen_silence_timeout_sec: 5 + post_cue_mic_mute_ms: 200 + segment_cue_duration_ms: 120 + # 问候语 / 本地 LLM 文案 / 飞控确认超时等字符串是否走 WebSocket tts.synthesize(见 docs/API.md §3.3);失败回退 Kokoro + 
remote_tts_for_local: true + # 云端失败时是否回退本地 Qwen + Kokoro(需本地模型) + fallback_to_local: true + # PX4/MAV 语境:合并进 WebSocket session.start 的 client,供服务端 LLM;也可用 ROCKET_CLOUD_PX4_CONTEXT_FILE 覆盖路径 + px4_context_file: "voice_drone/config/cloud_voice_px4_context.yaml" + +# ***********socket服务器配置 ************* +socket_server: + # deployed + host: "192.168.43.200" + port: 6666 + + #local + # host: "127.0.0.1" + # port: 8888 + connection_timeout: 5.0 + send_timeout: 2.0 + reconnect_interval: 3.0 + # -1:断线后持续重连并发送直到成功(仅打 warning,不当作一次性致命错误);正整数:最多尝试次数 + max_retries: -1 + +# ***********文本预处理配置 ************* +text_preprocessor: + # 功能开关 + enable_traditional_to_simplified: true # 启用繁简转换 + enable_segmentation: true # 启用分词 + enable_correction: true # 启用纠错 + enable_number_extraction: true # 启用数字提取 + enable_keyword_detection: true # 启用关键词检测 + + # 性能配置 + lru_cache_size: 512 # LRU缓存大小(分词结果缓存) + +# ***********识别器流程配置 ************* +recognizer: + # 句尾连续静音达到该秒数后才切段送 STT,减少句中停顿被切开、识别半句。按 audio.frame_size 与 sample_rate 换算块数, + # 并同时设置 Silero 的 vad.end_frame 与 energy 的 energy_vad_end_chunks。不配置则分别用 yaml 中上述两项。 + trailing_silence_seconds: 1.5 + # VAD 后端:energy(默认,无需 Silero ONNX)或 silero(需 models/SileroVad/silero_vad.onnx) + vad_backend: energy + energy_vad_rms_high: 8000 # int16 块 RMS,连续达到 start_chunks 块判为开始说话 + energy_vad_rms_low: 5000 # 连续 end_chunks 块低于此判为结束(高噪底时单独不够) + # 相对峰值判停:首字很响时若仍用 0.88,句中略轻易误判「已说完」。可改为 0.75~0.82;设 0 则关闭相对判据(只靠 rms_low + 尾静音) + energy_vad_end_peak_ratio: 0.80 + # 每音频块后对句内峰值衰减系数(0.95~0.999),与 end_peak_ratio 配合减少「长句说到一半被切」 + energy_vad_utt_peak_decay: 0.988 + energy_vad_start_chunks: 4 + energy_vad_end_chunks: 15 + # 预缓冲:在检测到语音开始时,把开始前的一小段音频也拼进语音段 + # 用于避免 VAD 起始判定稍慢导致“丢字/丢开头” + pre_speech_max_seconds: 0.8 + # Socket 命令发送成功后,是否用 TTS 语音回复(需 Kokoro 模型与 sounddevice) + ack_tts_enabled: true + # 若配置了 ack_tts_phrases 且非空:仅下列命令会播报,且每次从对应列表随机选一句。 + # 若未配置或为空 {}:回退为全局 ack_tts_text,所有成功命令均播报同一句(并可预缓存波形)。 + # 阻塞预加载时会对所有「不重复」的备选句逐条合成并缓存;句数越多启动阶段越久,但播报为低延迟播放缓存。 + ack_tts_phrases: + 
takeoff: + - "收到,正在控制无人机起飞" + - "明白,正在准备起飞" + - "懂你意思, 这就开始起飞" + land: + - "收到,正在控制无人机降落" + - "明白,马上降落" + - "懂你意思, 这就开始降落" + follow: + - "好的,我将持续跟随你,但请你不要移动太快" + - "主人,我已经开始跟随模式,请不要突然离开我的视线" + - "我已经开启了跟随模式" + hover: + - "我已悬停" + - "我已经停下脚步了" + - "放心,我现在开始不会乱动了" + # 未使用 ack_tts_phrases 时的全局固定应答(旧行为) + ack_tts_text: "收到!执行命令!" + # 应答波形磁盘缓存:文案与 tts 配置未变时从 cache/ack_tts_pcm/ 读取,跳过后续启动时的逐条合成(可明显加快二次启动) + ack_tts_disk_cache: true + # 启动时预加载 Kokoro(首次加载约需十余秒) + ack_tts_prewarm: true + # true:启动阶段阻塞到 TTS 就绪后再进入监听(命令成功后可马上播报);false:后台加载(首条命令可能仍要等十余秒) + ack_tts_prewarm_blocking: true + # 播报应答前暂时停止 PyAudio 麦克风(Windows 上输入/输出同时占用时常导致扬声器无声;单独跑 tts 脚本时无此问题) + # 香橙派 ES8388:暂停/恢复采集后出现「VAD RMS≈0、像没电平」时,可改为 false(避免 stop/start 采集),或加大 ROCKET_MIC_RESTART_SETTLE_MS。 + ack_pause_mic_for_playback: true + +# ***********主程序 main_app(TakeoffPrintRecognizer)************* +assistant: + # 「keywords.yaml 里 takeoff 词 → 本地 offboard + WAV」捷径;默认关,飞控走云端 flight_intent / ROS 桥即可。 + # 若需恢复口令起飞:此处改为 true,或启动前 export ROCKET_LOCAL_KEYWORD_TAKEOFF=1(非空环境变量优先于本项)。 + local_keyword_takeoff_enabled: false + +# ***********日志配置 ************* +logging: + level: "INFO" + debug: false + + + diff --git a/voice_drone/config/wake_word.yaml b/voice_drone/config/wake_word.yaml new file mode 100644 index 0000000..5fff531 --- /dev/null +++ b/voice_drone/config/wake_word.yaml @@ -0,0 +1,71 @@ +# 唤醒词配置 +wake_word: + # 主唤醒词(标准形式) + primary: "无人机" + + # 唤醒词变体映射(支持同音字、拼音、错别字等) + variants: + # 标准形式 + - "无人机" + - "无 人 机" # 带空格 + - "无人机," + - "无人机 。" + - "无人机。" + - "喂 无人机" + - "喂,无人机" + - "嗨 无人机" + - "嘿 无人机" + - "哈喽 无人机" + + # 常见误识别 / 近音 + - "五人机" + - "吾人机" + - "无人鸡" + - "无认机" + - "无任机" + + # 拼音变体(小写) + - "wu ren ji" + - "wurenj" + - "wu renji" + - "wuren ji" + - "wu ren, ji" + - "wu ren ji." 
+ - "hey wu ren ji" + - "hi wu ren ji" + - "hello wu ren ji" + - "ok wu ren ji" + + # 拼音变体(大小写混合) + - "Wu Ren Ji" + - "WU REN JI" + - "Wu ren ji" + - "Hey Wu Ren Ji" + - "Hi Wu Ren Ji" + + # 短说 / 部分匹配(易与日常用语冲突时可删去「无人」单独一项) + - "无人" + # 勿单独使用「人机」:易在「牛人机学庭」等误识别里子串命中,导致假唤醒 + - "hey 无人" + - "嗨 无人" + - "喂 无人" + + # 匹配模式配置 + matching: + # 是否启用模糊匹配(同音字、拼音) + enable_fuzzy: true + + # 是否启用部分匹配(代码侧:主词长度足够时取前半段;短词依赖上面 variants) + enable_partial: true + + # 是否忽略大小写 + ignore_case: true + + # 是否忽略空格 + ignore_spaces: true + + # 最小匹配长度(字符数,用于部分匹配) + min_match_length: 2 + + # 相似度阈值(0-1,用于模糊匹配) + similarity_threshold: 0.7 diff --git a/voice_drone/core/audio.py b/voice_drone/core/audio.py new file mode 100644 index 0000000..3f6d3ba --- /dev/null +++ b/voice_drone/core/audio.py @@ -0,0 +1,714 @@ +""" +音频采集模块 - 优化版本 + +输入设备(麦克风)选择(已简化): +- 若 system.yaml 中 audio.input_device_index 为整数:只尝试该 PyAudio 索引(无则启动失败并列设备)。 +- 若为 null:依次尝试系统默认输入、所有 maxInputChannels>0 的设备。 + rocket_drone_audio 启动时可交互选择并写入 input_device_index(见 src.core.mic_device_select)。 +""" +from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio + +fix_ld_path_for_portaudio() + +import re +import pyaudio +import numpy as np +import queue +import threading +from typing import List, Optional, Tuple +from voice_drone.core.configuration import SYSTEM_AUDIO_CONFIG +from voice_drone.logging_ import get_logger + +logger = get_logger("audio.capture.optimized") + + +class AudioCaptureOptimized: + """ + 优化版音频采集器 + + 使用回调模式 + 队列,实现非阻塞音频采集 + """ + + def __init__(self): + """ + 初始化音频采集器 + """ + # 确保数值类型正确(从 YAML 读取可能是字符串) + self.sample_rate = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)) + self.channels = int(SYSTEM_AUDIO_CONFIG.get("channels", 1)) + self.chunk_size = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024)) + self.sample_width = int(SYSTEM_AUDIO_CONFIG.get("sample_width", 2)) + + # 高性能模式配置 + self.buffer_queue_size = int(SYSTEM_AUDIO_CONFIG.get("buffer_queue_size", 10)) + self._prefer_stereo_capture = bool( + 
SYSTEM_AUDIO_CONFIG.get("prefer_stereo_capture", True) + ) + raw_idx = SYSTEM_AUDIO_CONFIG.get("input_device_index", None) + self._input_device_index_cfg: Optional[int] = ( + int(raw_idx) if raw_idx is not None and str(raw_idx).strip() != "" else None + ) + + tr = SYSTEM_AUDIO_CONFIG.get("audio_open_try_rates") + if tr: + raw_rates: List[int] = [int(x) for x in tr if x is not None] + else: + raw_rates = [self.sample_rate, 48000, 44100, 32000] + seen_r: set[int] = set() + self._open_try_rates: List[int] = [] + for r in raw_rates: + if r not in seen_r: + seen_r.add(r) + self._open_try_rates.append(r) + + # 逻辑通道(送给 VAD/STT 的 mono);_pa_channels 为 PortAudio 实际打开的通道数 + self._pa_channels = self.channels + self._stereo_downmix = False + self._pa_open_sample_rate: int = self.sample_rate + + self.audio = pyaudio.PyAudio() + self.format = self.audio.get_format_from_width(self.sample_width) + + # 使用队列缓冲音频数据(非阻塞) + self.audio_queue = queue.Queue(maxsize=self.buffer_queue_size) + self.stream: Optional[pyaudio.Stream] = None + + logger.info( + f"优化版音频采集器初始化成功: " + f"采样率={self.sample_rate}Hz, " + f"块大小={self.chunk_size}, " + f"使用回调模式+队列缓冲" + ) + + def _device_hw_tuple_in_name(self, dev_name: str) -> Optional[Tuple[int, int]]: + m = re.search(r"\(hw:(\d+),\s*(\d+)\)", dev_name) + if not m: + return None + return int(m.group(1)), int(m.group(2)) + + def _ordered_input_candidates(self) -> Tuple[List[int], List[int]]: + preferred: List[int] = [] + seen: set[int] = set() + + def add(idx: Optional[int]) -> None: + if idx is None: + return + ii = int(idx) + if ii in seen: + return + seen.add(ii) + preferred.append(ii) + + # 配置了整数索引:只打开该设备(与交互选择 / CLI 写入一致) + if self._input_device_index_cfg is not None: + add(self._input_device_index_cfg) + return preferred, [] + + try: + add(int(self.audio.get_default_input_device_info()["index"])) + except Exception: + pass + for i in range(self.audio.get_device_count()): + try: + inf = self.audio.get_device_info_by_index(i) + if 
int(inf.get("maxInputChannels", 0)) > 0: + add(i) + except Exception: + continue + + fallback: List[int] = [] + for i in range(self.audio.get_device_count()): + if i in seen: + continue + try: + self.audio.get_device_info_by_index(i) + except Exception: + continue + fallback.append(i) + return preferred, fallback + + def _channel_plan(self, max_in: int, dev_name: str) -> List[Tuple[int, bool]]: + ch = self.channels + pref = self._prefer_stereo_capture + if ch != 1: + return [(ch, False)] + if max_in <= 0: + logger.warning( + "设备 %s 报告 maxInputChannels=%s,将尝试 mono / stereo", + dev_name or "?", + max_in, + ) + return [(1, False), (2, True)] + if max_in == 1: + return [(1, False)] + if pref: + return [(2, True), (1, False)] + return [(1, False)] + + def _frames_per_buffer_for_rate(self, pa_rate: int) -> int: + if pa_rate <= 0: + pa_rate = self.sample_rate + return max(128, int(round(self.chunk_size * pa_rate / self.sample_rate))) + + @staticmethod + def _resample_linear_int16( + x: np.ndarray, sr_in: int, sr_out: int + ) -> np.ndarray: + if sr_in == sr_out or x.size == 0: + return x + n_out = max(1, int(round(x.size * (sr_out / sr_in)))) + t_in = np.arange(x.size, dtype=np.float64) + t_out = np.linspace(0.0, float(x.size - 1), n_out, dtype=np.float64) + y = np.interp(t_out, t_in, x.astype(np.float32)) + return np.clip(np.round(y), -32768, 32767).astype(np.int16) + + def _try_open_on_device(self, input_device_index: int) -> bool: + try: + dev = self.audio.get_device_info_by_index(input_device_index) + except Exception: + return False + max_in = int(dev.get("maxInputChannels", 0)) + dev_name = str(dev.get("name", "")) + if ( + max_in <= 0 + and self._input_device_index_cfg is not None + and int(input_device_index) == int(self._input_device_index_cfg) + and self._prefer_stereo_capture + and self.channels == 1 + ): + max_in = 2 + logger.warning( + "设备 %s 上报 maxInputChannels=0,假定 2 通道以尝试 ES8388 立体声采集", + input_device_index, + ) + plan = self._channel_plan(max_in, dev_name) 
+ hw_t = self._device_hw_tuple_in_name(dev_name) + + for pa_ch, stereo_dm in plan: + self._pa_channels = pa_ch + self._stereo_downmix = stereo_dm + if stereo_dm and pa_ch == 2: + logger.info( + "输入按立体声打开并下混 mono(%s index=%s)", + dev_name, + input_device_index, + ) + for rate in self._open_try_rates: + fpb = self._frames_per_buffer_for_rate(int(rate)) + try: + self.stream = self.audio.open( + format=self.format, + channels=self._pa_channels, + rate=int(rate), + input=True, + input_device_index=input_device_index, + frames_per_buffer=fpb, + stream_callback=self._audio_callback, + start=False, + ) + self.stream.start_stream() + self._pa_open_sample_rate = int(rate) + extra = ( + f" hw=card{hw_t[0]}dev{hw_t[1]}" if hw_t else "" + ) + if self._pa_open_sample_rate != self.sample_rate: + logger.warning( + "输入实际 %s Hz,将重采样为 %s Hz 供 VAD/STT", + self._pa_open_sample_rate, + self.sample_rate, + ) + logger.info( + "音频流启动成功 index=%s name=%r PA_ch=%s PA_rate=%s 逻辑rate=%s%s", + input_device_index, + dev_name, + self._pa_channels, + self._pa_open_sample_rate, + self.sample_rate, + extra, + ) + return True + except Exception as e: + if self.stream is not None: + try: + self.stream.close() + except Exception: + pass + self.stream = None + logger.warning( + "打开失败 index=%s ch=%s rate=%s: %s", + input_device_index, + pa_ch, + rate, + e, + ) + return False + + def _audio_callback(self, in_data, frame_count, time_info, status): + """ + 音频回调函数(非阻塞) + """ + if status: + logger.warning(f"音频流状态: {status}") + + # 将数据放入队列(非阻塞) + try: + self.audio_queue.put(in_data, block=False) + except queue.Full: + logger.warning("音频队列已满,丢弃数据块") + + return (None, pyaudio.paContinue) + + def _log_input_devices_for_user(self) -> None: + """列出 PortAudio 全部设备(含 in_ch=0),便于选 --input-index / 核对子串。""" + n_dev = self.audio.get_device_count() + if n_dev <= 0: + print( + "[audio] PyAudio get_device_count()=0,多为 ALSA/PortAudio 未初始化;" + "请用 bash with_system_alsa.sh python … 启动。", + flush=True, + ) + 
logger.error("PyAudio 枚举不到任何设备") + return + lines: List[str] = [] + for i in range(n_dev): + try: + inf = self.audio.get_device_info_by_index(i) + mic = int(inf.get("maxInputChannels", 0)) + outc = int(inf.get("maxOutputChannels", 0)) + name = str(inf.get("name", "?")) + mark = " <- 可录音" if mic > 0 else "" + lines.append(f" [{i}] in={mic} out={outc} {name}{mark}") + except Exception: + continue + msg = "\n".join(lines) + logger.error("PortAudio 设备列表:\n%s", msg) + print( + "[audio] PortAudio 设备列表(in>0 才可作输入;若板载显示 in=0 仍可用 probe 试采):\n" + + msg, + flush=True, + ) + + def start_stream(self) -> None: + """启动音频流(回调模式)""" + if self.stream is not None: + return + + preferred, fallback = self._ordered_input_candidates() + to_try: List[int] = preferred + fallback + if not to_try: + print( + "[audio] 无任何输入候选。请检查 PortAudio/ALSA(建议:bash with_system_alsa.sh python …)。", + flush=True, + ) + if self._input_device_index_cfg is not None: + logger.error( + "已配置 input_device_index=%s 但无效或不可打开", + self._input_device_index_cfg, + ) + print( + f"[audio] 当前配置的 PyAudio 索引 {self._input_device_index_cfg} 不可用," + "请改 system.yaml 或重新运行交互选设备。", + flush=True, + ) + self._log_input_devices_for_user() + raise OSError("未找到任何 PyAudio 输入候选设备") + + for input_device_index in to_try: + if self._try_open_on_device(input_device_index): + return + + logger.error("启动音频流失败:全部候选设备无法打开") + self._log_input_devices_for_user() + raise OSError("启动音频流失败:全部候选设备无法打开") + + def stop_stream(self) -> None: + """停止音频流""" + if self.stream is None: + return + + try: + self.stream.stop_stream() + self.stream.close() + self.stream = None + self._pa_open_sample_rate = self.sample_rate + # 清空队列 + while not self.audio_queue.empty(): + try: + self.audio_queue.get_nowait() + except queue.Empty: + break + logger.info("音频流已停止") + except Exception as e: + logger.error(f"停止音频流失败: {e}") + + def read_chunk(self, timeout: float = 0.1) -> Optional[bytes]: + """ + 读取一个音频块(非阻塞) + + Args: + timeout: 超时时间(秒) + + Returns: + 
音频数据(bytes),如果超时则返回 None + """ + if self.stream is None: + return None + + try: + return self.audio_queue.get(timeout=timeout) + except queue.Empty: + return None + + def read_chunk_numpy(self, timeout: float = 0.1) -> Optional[np.ndarray]: + """读取一个音频块并转换为 numpy 数组(非阻塞)""" + data = self.read_chunk(timeout) + if data is None: + return None + + sample_size = self._pa_channels * self.sample_width + if len(data) % sample_size != 0: + aligned_len = (len(data) // sample_size) * sample_size + if aligned_len == 0: + return None + data = data[:aligned_len] + + mono = np.frombuffer(data, dtype=" float: + """ + 更新 RMS 值 + + Args: + sample: 新的采样值 + + Returns: + 当前 RMS 值 + """ + if self.count < self.window_size: + # 填充阶段 + self.buffer[self.count] = sample + self.sum_sq += sample * sample + self.count += 1 + if self.count == 0: + return 0.0 + return np.sqrt(self.sum_sq / self.count) + else: + # 滑动窗口阶段 + old_sq = self.buffer[self.idx] * self.buffer[self.idx] + self.sum_sq = self.sum_sq - old_sq + sample * sample + self.buffer[self.idx] = sample + self.idx = (self.idx + 1) % self.window_size + return np.sqrt(self.sum_sq / self.window_size) + + def update_batch(self, samples: np.ndarray) -> float: + """ + 批量更新 RMS 值 + + Args: + samples: 采样数组 + + Returns: + 当前 RMS 值 + """ + for sample in samples: + self.update(sample) + return np.sqrt(self.sum_sq / min(self.count, self.window_size)) + + def reset(self): + """重置计算器""" + self.buffer.fill(0.0) + self.sum_sq = 0.0 + self.idx = 0 + self.count = 0 + + +class LightweightNoiseReduction: + """ + 轻量级降噪算法 + + 使用简单的高通滤波 + 谱减法,性能比 noisereduce 快 10-20 倍 + """ + + def __init__(self, sample_rate: int = 16000, cutoff: float = 80.0): + """ + Args: + sample_rate: 采样率 + cutoff: 高通滤波截止频率(Hz) + """ + self.sample_rate = sample_rate + self.cutoff = cutoff + + # 简单的 IIR 高通滤波器系数(一阶 Butterworth) + # H(z) = (1 - z^-1) / (1 - 0.99*z^-1) + self.alpha = np.exp(-2.0 * np.pi * cutoff / sample_rate) + self.prev_input = 0.0 + self.prev_output = 0.0 + + def 
process(self, audio: np.ndarray) -> np.ndarray: + """ + 处理音频(高通滤波) + + Args: + audio: 音频数据(float32,范围 [-1, 1]) + + Returns: + 处理后的音频 + """ + if audio.dtype != np.float32: + audio = audio.astype(np.float32) + + # 简单的一阶高通滤波 + output = np.zeros_like(audio) + for i in range(len(audio)): + output[i] = self.alpha * (self.prev_output + audio[i] - self.prev_input) + self.prev_input = audio[i] + self.prev_output = output[i] + + return output + + def reset(self): + """重置滤波器状态""" + self.prev_input = 0.0 + self.prev_output = 0.0 + + +class AudioPreprocessorOptimized: + """ + 优化版音频预处理器 + + 性能优化: + 1. 轻量级降噪(替代 noisereduce) + 2. 增量 AGC 计算 + 3. 减少类型转换 + """ + + def __init__(self, enable_noise_reduction: Optional[bool] = None, + enable_agc: Optional[bool] = None): + """ + 初始化音频预处理器 + """ + # 从配置读取 + if enable_noise_reduction is None: + enable_noise_reduction = SYSTEM_AUDIO_CONFIG.get("noise_reduce", True) + if enable_agc is None: + enable_agc = SYSTEM_AUDIO_CONFIG.get("agc", True) + + self.enable_noise_reduction = enable_noise_reduction + self.enable_agc = enable_agc + self.sample_rate = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)) + + # AGC 参数(确保类型正确,从 YAML 读取可能是字符串) + self.agc_target_db = float(SYSTEM_AUDIO_CONFIG.get("agc_target_db", -20.0)) + self.agc_gain_min = float(SYSTEM_AUDIO_CONFIG.get("agc_gain_min", 0.1)) + self.agc_gain_max = float(SYSTEM_AUDIO_CONFIG.get("agc_gain_max", 10.0)) + self.agc_rms_threshold = float(SYSTEM_AUDIO_CONFIG.get("agc_rms_threshold", 1e-6)) + self._agc_alpha = float(SYSTEM_AUDIO_CONFIG.get("agc_smoothing_alpha", 0.1)) + self._agc_alpha = max(0.02, min(0.95, self._agc_alpha)) + # 当需要抬增益(小声/巨响过后)时用更大系数,避免长时间压在 agc_gain_min + self._agc_release_alpha = float( + SYSTEM_AUDIO_CONFIG.get("agc_release_alpha", 0.45) + ) + self._agc_release_alpha = max(self._agc_alpha, min(0.95, self._agc_release_alpha)) + + # 初始化组件 + if enable_noise_reduction: + self.noise_reducer = LightweightNoiseReduction( + sample_rate=self.sample_rate, + cutoff=80.0 # 可配置 + ) + 
else: + self.noise_reducer = None + + if enable_agc: + # 使用增量 RMS 计算器 + window_size = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024)) + self.rms_calculator = IncrementalRMS(window_size=window_size) + self.current_gain = 1.0 # 缓存当前增益 + else: + self.rms_calculator = None + + logger.info( + f"优化版音频预处理器初始化完成: " + f"降噪={'启用(轻量级)' if enable_noise_reduction else '禁用'}, " + f"自动增益控制={'启用(增量)' if enable_agc else '禁用'}" + ) + + def reset(self) -> None: + """ + 重置高通滤波与 AGC 状态。应在「暂停采集再重新 start_stream」之后调用, + 避免停麦/播 TTS 期间的状态带到新流上(否则易出现恢复后长时间 RMS≈0 或电平怪异)。 + """ + if self.noise_reducer is not None: + self.noise_reducer.reset() + if self.rms_calculator is not None: + self.rms_calculator.reset() + if self.enable_agc: + self.current_gain = 1.0 + + def reset_agc_state(self) -> None: + """ + 每段语音结束或需Recovery时调用:清空 RMS 滑窗并将增益重置为 1。 + 避免短时强噪声把 current_gain 压在 agc_gain_min、滑窗仍含高能量导致后续 RMS≈0。 + """ + if not self.enable_agc or self.rms_calculator is None: + return + self.rms_calculator.reset() + self.current_gain = 1.0 + + def reduce_noise(self, audio_data: np.ndarray) -> np.ndarray: + """ + 轻量级降噪处理 + + Args: + audio_data: 音频数据(int16 或 float32) + + Returns: + 降噪后的音频数据 + """ + if not self.enable_noise_reduction or self.noise_reducer is None: + return audio_data + + # 转换为 float32 + if audio_data.dtype == np.int16: + audio_float = audio_data.astype(np.float32) / 32768.0 + is_int16 = True + else: + audio_float = audio_data.astype(np.float32) + is_int16 = False + + # 轻量级降噪 + reduced = self.noise_reducer.process(audio_float) + + # 转换回原始格式 + if is_int16: + reduced = (reduced * 32768.0).astype(np.int16) + + return reduced + + def automatic_gain_control(self, audio_data: np.ndarray) -> np.ndarray: + """ + 自动增益控制(使用增量 RMS) + + Args: + audio_data: 音频数据(int16 或 float32) + + Returns: + 增益调整后的音频数据 + """ + if not self.enable_agc or self.rms_calculator is None: + return audio_data + + # 转换为 float32 + if audio_data.dtype == np.int16: + audio_float = audio_data.astype(np.float32) / 32768.0 + is_int16 = 
True + else: + audio_float = audio_data.astype(np.float32) + is_int16 = False + + # 使用增量 RMS 计算 + rms = self.rms_calculator.update_batch(audio_float) + + if rms < self.agc_rms_threshold: + return audio_data + + # 计算增益(可以进一步优化:使用滑动平均) + current_db = 20 * np.log10(rms) + gain_db = self.agc_target_db - current_db + gain_linear = 10 ** (gain_db / 20.0) + gain_linear = np.clip(gain_linear, self.agc_gain_min, self.agc_gain_max) + + # 压低增益用较小 alpha;需要恢复(gain_linear 明显高于当前)时用 release alpha + if gain_linear > self.current_gain * 1.08: + alpha = self._agc_release_alpha + else: + alpha = self._agc_alpha + self.current_gain = alpha * gain_linear + (1 - alpha) * self.current_gain + + # 应用增益 + adjusted = audio_float * self.current_gain + adjusted = np.clip(adjusted, -1.0, 1.0) + + # 转换回原始格式 + if is_int16: + adjusted = (adjusted * 32768.0).astype(np.int16) + + return adjusted + + def process(self, audio_data: np.ndarray) -> np.ndarray: + """ + 完整的预处理流程(优化版) + + Args: + audio_data: 音频数据(numpy array) + + Returns: + 预处理后的音频数据 + """ + processed = audio_data.copy() + + # 降噪 + if self.enable_noise_reduction: + processed = self.reduce_noise(processed) + + # 自动增益控制 + if self.enable_agc: + processed = self.automatic_gain_control(processed) + + return processed + + +# 向后兼容别名(保持API一致性) +AudioCapture = AudioCaptureOptimized +AudioPreprocessor = AudioPreprocessorOptimized + + +# 使用示例 +if __name__ == "__main__": + # 优化版使用 + with AudioCapture() as capture: + preprocessor = AudioPreprocessor() + + for i in range(10): + chunk = capture.read_chunk_numpy(timeout=0.1) + if chunk is not None: + processed = preprocessor.process(chunk) + print(f"处理了 {len(processed)} 个采样点") diff --git a/voice_drone/core/cloud_dialog_v1.py b/voice_drone/core/cloud_dialog_v1.py new file mode 100644 index 0000000..267246c --- /dev/null +++ b/voice_drone/core/cloud_dialog_v1.py @@ -0,0 +1,78 @@ +"""cloud_voice_dialog_v1:dialog_result 约定(见 docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md)。""" + +from __future__ import annotations + 
+from typing import Any + +CLOUD_VOICE_DIALOG_V1 = "cloud_voice_dialog_v1" + +MSG_CONFIRM_TIMEOUT = "未收到确认指令,请重新下发指令。" +MSG_CANCELLED = "已取消指令,请重新唤醒后下发指令。" +MSG_CONFIRM_EXECUTING = "开始执行飞控指令。" +MSG_PROMPT_LISTEN_TIMEOUT = "未检测到语音,请重新唤醒后再说。" + + +def normalize_phrase_text(s: str) -> str: + """去首尾空白、合并连续空白。""" + return " ".join((s or "").strip().split()) + + +def _strip_tail_punct(s: str) -> str: + return s.rstrip("。!!??,, \t") + + +def match_phrase_list(norm: str, phrases: Any) -> bool: + """ + 命中规则(适配「请回复确认或取消」类长提示 + 只说「确认」「取消」): + - 去尾标点后 **全等** 短语;或 + - 短语为 **子串** 且整句长度 <= len(短语)+2,避免用户复述整段提示时同时含「确认」「取消」而误触。 + """ + if not isinstance(phrases, list) or not norm: + return False + base = _strip_tail_punct(normalize_phrase_text(norm)) + if not base: + return False + for p in phrases: + raw = _strip_tail_punct((p or "").strip()) + if not raw: + continue + if base == raw: + return True + if raw in base and len(base) <= len(raw) + 2: + return True + return False + + +def parse_confirm_dict(raw: Any) -> dict[str, Any] | None: + if not isinstance(raw, dict): + return None + required = raw.get("required") + if not isinstance(required, bool): + return None + try: + timeout_sec = float(raw.get("timeout_sec", 10)) + except (TypeError, ValueError): + timeout_sec = 10.0 + timeout_sec = max(1.0, min(120.0, timeout_sec)) + cp = raw.get("confirm_phrases") + kp = raw.get("cancel_phrases") + if not isinstance(cp, list) or not cp: + return None + if not isinstance(kp, list) or not kp: + return None + pending = raw.get("pending_id") + if pending is not None and not isinstance(pending, str): + pending = str(pending) + cplist = [str(x) for x in cp if str(x).strip()] + kplist = [str(x) for x in kp if str(x).strip()] + if not cplist or not kplist: + return None + return { + "required": required, + "timeout_sec": timeout_sec, + "confirm_phrases": cplist, + "cancel_phrases": kplist, + "pending_id": pending, + "summary_for_user": raw.get("summary_for_user"), + "raw": raw, + } diff --git 
a/voice_drone/core/cloud_voice_client.py b/voice_drone/core/cloud_voice_client.py new file mode 100644 index 0000000..c7e4fdb --- /dev/null +++ b/voice_drone/core/cloud_voice_client.py @@ -0,0 +1,999 @@ +""" +云端语音 WebSocket 客户端:会话 `session.start.transport_profile` 固定为 pcm_asr_uplink。 + +- 主路径:`turn.audio.start` → 若干 `turn.audio.chunk`(每条仅文本 JSON,含 `pcm_base64`)→ `turn.audio.end`;**禁止**用 WebSocket binary 上发 PCM(与 Starlette receive 语义一致)。 +- 辅助:`run_turn` 发 `turn.text`(如同句快路径仅有文本);`run_tts_synthesize` 仅 TTS。 +- `asr.partial` 仅调试展示,不参与机端状态机。 + +文档:`docs/CLOUD_VOICE_SESSION_SCHEME_v1.md`,`docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md`。 +""" + +from __future__ import annotations + +import base64 +import json +import os +import threading +import time +import uuid +from typing import Any + +import numpy as np + +from voice_drone.core.cloud_dialog_v1 import CLOUD_VOICE_DIALOG_V1 +from voice_drone.logging_ import get_logger + +logger = get_logger("voice_drone.cloud_voice") + +_CLOUD_PROTO = "1.0" +TRANSPORT_PCM_ASR_UPLINK = "pcm_asr_uplink" + + +def _merge_session_client( + device_id: str, + *, + session_client_extensions: dict[str, Any] | None, +) -> dict[str, Any]: + """session.start 的 client:capabilities 与设备信息 + 可选 PX4 等扩展(不覆盖 device_id/locale)。""" + client: dict[str, Any] = { + "device_id": device_id, + "locale": "zh-CN", + "capabilities": { + "playback_sample_rate_hz": 24000, + "prefer_tts_codec": "pcm_s16le", + }, + } + ext = session_client_extensions or {} + for k, v in ext.items(): + if v is None or k in ("device_id", "locale", "capabilities", "protocol"): + continue + if k == "extras" and isinstance(v, dict) and len(v) == 0: + continue + client[k] = v + client["protocol"] = {"dialog_result": CLOUD_VOICE_DIALOG_V1} + return client + + +def _transient_ws_exc(exc: BaseException) -> bool: + """可通过对端已关、网络抖动等通过重连重发 turn 恢复的异常。""" + import websocket as _websocket # noqa: PLC0415 + + if isinstance( + exc, + ( + BrokenPipeError, + ConnectionResetError, + 
ConnectionAbortedError, + ), + ): + return True + if isinstance( + exc, + ( + _websocket.WebSocketConnectionClosedException, + _websocket.WebSocketTimeoutException, + ), + ): + return True + if isinstance(exc, OSError) and getattr(exc, "errno", None) in ( + 32, + 104, + 110, + ): # EPIPE, ECONNRESET, ETIMEDOUT + return True + return False + + +def _merge_tts_pcm_chunks( + chunk_entries: list[tuple[int | None, int, bytes]], +) -> bytes: + """按 seq 升序拼接;无 seq 时按到达顺序。chunk_entries: (seq|None, arrival, pcm)。""" + if not chunk_entries: + return b"" + if all(s is not None for s, _, _ in chunk_entries): + ordered = sorted(chunk_entries, key=lambda x: (x[0], x[1])) + seqs = [x[0] for x in ordered] + for a, b in zip(seqs, seqs[1:]): + if b != a + 1: + logger.warning("TTS seq 不连续(仍按序拼接): %s → %s", a, b) + break + return b"".join(x[2] for x in ordered) + return b"".join(x[2] for x in sorted(chunk_entries, key=lambda x: x[1])) + + +class CloudVoiceError(RuntimeError): + """云端返回 error 消息或协议不符合预期。""" + + def __init__(self, message: str, *, code: str | None = None, retryable: bool = False): + super().__init__(message) + self.code = code + self.retryable = retryable + + +class CloudVoiceClient: + """连接 ws://…/v1/voice/session;session 为 pcm_asr_uplink,含 run_turn_audio / run_turn / tts.synthesize。""" + + def __init__( + self, + *, + server_url: str, + auth_token: str, + device_id: str, + recv_timeout: float = 120.0, + session_client_extensions: dict[str, Any] | None = None, + ) -> None: + self.server_url = server_url.strip() + self.auth_token = auth_token.strip() + self.device_id = (device_id or "drone-001").strip() + self.recv_timeout = float(recv_timeout) + self._session_client_extensions: dict[str, Any] = dict( + session_client_extensions or {} + ) + self._transport_profile: str = TRANSPORT_PCM_ASR_UPLINK + self._ws: Any = None + self._session_id: str | None = None + self._lock = threading.Lock() + + @property + def connected(self) -> bool: + with self._lock: + return self._ws is 
not None + + def close(self) -> None: + with self._lock: + self._close_nolock() + + def _close_nolock(self) -> None: + if self._ws is None: + self._session_id = None + return + try: + if self._session_id: + try: + self._ws.send( + json.dumps( + { + "type": "session.end", + "proto_version": _CLOUD_PROTO, + "session_id": self._session_id, + }, + ensure_ascii=False, + ) + ) + except Exception: # noqa: BLE001 + pass + finally: + try: + self._ws.close() + except Exception: # noqa: BLE001 + pass + self._ws = None + self._session_id = None + + def connect(self) -> None: + """建立 WSS,发送 session.start,等待 session.ready。""" + with self._lock: + self._connect_nolock() + + def _connect_nolock(self) -> None: + import websocket # websocket-client + + self._close_nolock() + hdr = [f"Authorization: Bearer {self.auth_token}"] + try: + self._ws = websocket.create_connection( + self.server_url, + header=hdr, + timeout=self.recv_timeout, + ) + self._ws.settimeout(self.recv_timeout) + self._session_id = str(uuid.uuid4()) + client_payload = _merge_session_client( + self.device_id, + session_client_extensions=self._session_client_extensions, + ) + if self._session_client_extensions: + logger.info( + "session.start 已附加 client 扩展键: %s", + sorted(self._session_client_extensions.keys()), + ) + start = { + "type": "session.start", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "session_id": self._session_id, + "auth_token": self.auth_token, + "client": client_payload, + } + self._ws.send(json.dumps(start, ensure_ascii=False)) + raw = self._ws.recv() + if isinstance(raw, bytes): + raise CloudVoiceError("session.ready 期望 JSON 文本帧,收到二进制") + data = json.loads(raw) + if data.get("type") != "session.ready": + raise CloudVoiceError( + f"期望 session.ready,收到: {data.get('type')!r}", + code="INVALID_MESSAGE", + ) + logger.info("云端会话已就绪 session_id=%s", self._session_id) + except Exception: + self._close_nolock() + raise + + def ensure_connected(self) -> None: + with 
self._lock: + if self._ws is None: + self._connect_nolock() + + def _execute_turn_nolock(self, t: str) -> dict[str, Any]: + """已持锁且 _ws 已连接:发送 turn.text 并收齐本轮帧。""" + import websocket # websocket-client + + ws = self._ws + if ws is None: + raise CloudVoiceError("WebSocket 未连接") + + turn_id = str(uuid.uuid4()) + turn_msg = { + "type": "turn.text", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "turn_id": turn_id, + "text": t, + "is_final": True, + "source": "device_stt", + } + try: + ws.send(json.dumps(turn_msg, ensure_ascii=False)) + except Exception as e: + if _transient_ws_exc(e): + raise + raise CloudVoiceError(f"发送 turn.text 失败: {e}", code="INTERNAL") from e + logger.debug("→ turn.text turn_id=%s", turn_id) + + expecting_binary = False + _pending_tts_seq: int | None = None + pcm_entries: list[tuple[int | None, int, bytes]] = [] + _pcm_arrival = 0 + llm_stream_parts: list[str] = [] + dialog: dict[str, Any] | None = None + metrics: dict[str, Any] = {} + sample_rate_hz = 24000 + + while True: + try: + msg = ws.recv() + except websocket.WebSocketConnectionClosedException as e: + raise CloudVoiceError( + f"连接已断开: {e}", + code="DISCONNECTED", + retryable=True, + ) from e + except Exception as e: + if _transient_ws_exc(e): + raise + raise + + if isinstance(msg, bytes): + if expecting_binary: + expecting_binary = False + else: + logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理") + pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg)) + _pcm_arrival += 1 + _pending_tts_seq = None + continue + + if not isinstance(msg, str): + raise CloudVoiceError( + f"期望文本帧为 str,实际为 {type(msg).__name__}", + code="INVALID_MESSAGE", + ) + text_frame = msg.strip() + if not text_frame: + logger.debug("跳过空 WebSocket 文本帧") + continue + try: + data = json.loads(text_frame) + except json.JSONDecodeError as e: + head = text_frame[:200].replace("\n", "\\n") + raise CloudVoiceError( + f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}", + code="INVALID_MESSAGE", + ) 
from e + mtype = data.get("type") + + if mtype == "asr.partial": + logger.debug("← asr.partial(机端不参与状态跳转)") + continue + + if mtype == "llm.text_delta": + if data.get("turn_id") != turn_id: + logger.debug( + "llm.text_delta turn_id 与当前不一致,忽略 type=%s", + mtype, + ) + continue + raw_d = data.get("delta") + delta = "" if raw_d is None else str(raw_d) + llm_stream_parts.append(delta) + _print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in ( + "1", + "true", + "yes", + ) + if _print_stream: + print(delta, end="", flush=True) + if data.get("done"): + print(flush=True) + logger.debug( + "← llm.text_delta done=%s delta_chars=%s", + data.get("done"), + len(delta), + ) + continue + + if mtype == "tts_audio_chunk": + _pending_tts_seq = None + if data.get("turn_id") != turn_id: + logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制") + else: + try: + sample_rate_hz = int( + data.get("sample_rate_hz") or sample_rate_hz + ) + except (TypeError, ValueError): + pass + _s = data.get("seq") + try: + if _s is not None: + _pending_tts_seq = int(_s) + except (TypeError, ValueError): + _pending_tts_seq = None + if data.get("is_final"): + logger.debug("← tts_audio_chunk is_final=true seq=%s", _s) + expecting_binary = True + continue + + if mtype == "dialog_result": + if data.get("turn_id") != turn_id: + raise CloudVoiceError( + "dialog_result turn_id 不匹配", code="INVALID_MESSAGE" + ) + dialog = data + logger.info( + "← dialog_result routing=%s", data.get("routing") + ) + continue + + if mtype == "turn.complete": + if data.get("turn_id") != turn_id: + raise CloudVoiceError( + "turn.complete turn_id 不匹配", code="INVALID_MESSAGE" + ) + metrics = data.get("metrics") or {} + break + + if mtype == "error": + code = str(data.get("code") or "INTERNAL") + raise CloudVoiceError( + data.get("message") or code, + code=code, + retryable=bool(data.get("retryable")), + ) + + logger.debug("忽略服务端消息 type=%s", mtype) + + if dialog is None: + raise CloudVoiceError("未收到 dialog_result", 
code="INVALID_MESSAGE") + + full_pcm = _merge_tts_pcm_chunks(pcm_entries) + pcm = ( + np.frombuffer(full_pcm, dtype=np.int16).copy() + if full_pcm + else np.array([], dtype=np.int16) + ) + if pcm.size > 0: + mx = int(np.max(np.abs(pcm))) + if mx == 0: + logger.warning( + "云端 TTS 已收齐二进制总长 %s 字节(≈%s 个 s16 采样),但全为 0x00," + "属于服务端发出的静音占位或未写入合成结果;机端无法通过重采样/扬声器修复。" + "请在服务端对同一次 synthesize 写 WAV 核对非零采样,并确认 WS 先发 tts_audio_chunk JSON、" + "再发 raw PCM 帧、且未把 JSON/base64 误当 binary 发出。", + len(full_pcm), + pcm.size, + ) + if os.environ.get("ROCKET_CLOUD_PCM_HEX", "").strip().lower() in ( + "1", + "true", + "yes", + ): + head = full_pcm[:64] + logger.warning( + "ROCKET_CLOUD_PCM_HEX: 前 %s 字节 hex=%s", + len(head), + head.hex(), + ) + + llm_stream_text = "".join(llm_stream_parts) + return { + "protocol": dialog.get("protocol"), + "routing": dialog.get("routing"), + "flight_intent": dialog.get("flight_intent"), + "confirm": dialog.get("confirm"), + "chat_reply": dialog.get("chat_reply"), + "user_input": dialog.get("user_input"), + "pcm": pcm, + "sample_rate_hz": sample_rate_hz, + "metrics": metrics, + "llm_stream_text": llm_stream_text, + } + + def _execute_turn_audio_nolock( + self, pcm_int16: np.ndarray, sample_rate_hz: int + ) -> dict[str, Any]: + """发送 turn.audio.start → 多条 turn.audio.chunk(pcm_base64 文本帧)→ turn.audio.end;禁止 binary 上发 PCM。""" + import websocket # websocket-client + + ws = self._ws + if ws is None: + raise CloudVoiceError("WebSocket 未连接") + + pcm_int16 = np.asarray(pcm_int16, dtype=np.int16).reshape(-1) + if pcm_int16.size == 0: + raise CloudVoiceError("turn.audio PCM 为空") + + pcm_mx = int(np.max(np.abs(pcm_int16))) + pcm_rms = float(np.sqrt(np.mean(pcm_int16.astype(np.float64) ** 2))) + dur_sec = float(pcm_int16.size) / max(1, int(sample_rate_hz)) + logger.info( + "turn.audio 上行: samples=%s sr_hz=%s dur≈%.2fs abs_max=%s rms=%.1f dtype=int16", + pcm_int16.size, + int(sample_rate_hz), + dur_sec, + pcm_mx, + pcm_rms, + ) + if pcm_mx == 0: + logger.warning( + 
"turn.audio 上行波形全零,云端 ASR 通常会判无有效语音(请查麦/切段/VAD 是否误交静音)" + ) + elif pcm_mx < 200: + logger.warning( + "turn.audio 上行幅值极小 abs_max=%s(仍发送);若云端反复无识别请调 AGC/VAD/麦增益", + pcm_mx, + ) + + turn_id = str(uuid.uuid4()) + start = { + "type": "turn.audio.start", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "turn_id": turn_id, + "sample_rate_hz": int(sample_rate_hz), + "codec": "pcm_s16le", + "channels": 1, + } + raw = pcm_int16.tobytes() + try: + ws.send(json.dumps(start, ensure_ascii=False)) + try: + raw_chunk = int(os.environ.get("ROCKET_CLOUD_AUDIO_CHUNK_BYTES", "8192")) + except ValueError: + raw_chunk = 8192 + raw_chunk = max(2048, min(256 * 1024, raw_chunk)) + n_chunks = 0 + for i in range(0, len(raw), raw_chunk): + piece = raw[i : i + raw_chunk] + chunk_msg = { + "type": "turn.audio.chunk", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "turn_id": turn_id, + "pcm_base64": base64.b64encode(piece).decode("ascii"), + } + ws.send(json.dumps(chunk_msg, ensure_ascii=False)) + n_chunks += 1 + end = { + "type": "turn.audio.end", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "turn_id": turn_id, + } + ws.send(json.dumps(end, ensure_ascii=False)) + except Exception as e: + if _transient_ws_exc(e): + raise + raise CloudVoiceError(f"发送 turn.audio 失败: {e}", code="INTERNAL") from e + logger.debug( + "→ turn.audio start/%s chunk(s)/end turn_id=%s samples=%s", + n_chunks, + turn_id, + pcm_int16.size, + ) + + expecting_binary = False + _pending_tts_seq: int | None = None + pcm_entries: list[tuple[int | None, int, bytes]] = [] + _pcm_arrival = 0 + llm_stream_parts: list[str] = [] + dialog: dict[str, Any] | None = None + metrics: dict[str, Any] = {} + out_sr = 24000 + + while True: + try: + msg = ws.recv() + except websocket.WebSocketConnectionClosedException as e: + raise CloudVoiceError( + f"连接已断开: {e}", + code="DISCONNECTED", + retryable=True, + ) from e + except Exception as e: 
+ if _transient_ws_exc(e): + raise + raise + + if isinstance(msg, bytes): + if expecting_binary: + expecting_binary = False + else: + logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理") + pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg)) + _pcm_arrival += 1 + _pending_tts_seq = None + continue + + if not isinstance(msg, str): + raise CloudVoiceError( + f"期望文本帧为 str,实际为 {type(msg).__name__}", + code="INVALID_MESSAGE", + ) + text_frame = msg.strip() + if not text_frame: + logger.debug("跳过空 WebSocket 文本帧") + continue + try: + data = json.loads(text_frame) + except json.JSONDecodeError as e: + head = text_frame[:200].replace("\n", "\\n") + raise CloudVoiceError( + f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}", + code="INVALID_MESSAGE", + ) from e + mtype = data.get("type") + + if mtype == "asr.partial": + logger.debug("← asr.partial(机端不参与状态跳转)") + continue + + if mtype == "llm.text_delta": + if data.get("turn_id") != turn_id: + logger.debug( + "llm.text_delta turn_id 与当前不一致,忽略 type=%s", + mtype, + ) + continue + raw_d = data.get("delta") + delta = "" if raw_d is None else str(raw_d) + llm_stream_parts.append(delta) + _print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in ( + "1", + "true", + "yes", + ) + if _print_stream: + print(delta, end="", flush=True) + if data.get("done"): + print(flush=True) + logger.debug( + "← llm.text_delta done=%s delta_chars=%s", + data.get("done"), + len(delta), + ) + continue + + if mtype == "tts_audio_chunk": + _pending_tts_seq = None + if data.get("turn_id") != turn_id: + logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制") + else: + try: + out_sr = int(data.get("sample_rate_hz") or out_sr) + except (TypeError, ValueError): + pass + _s = data.get("seq") + try: + if _s is not None: + _pending_tts_seq = int(_s) + except (TypeError, ValueError): + _pending_tts_seq = None + if data.get("is_final"): + logger.debug("← tts_audio_chunk is_final=true seq=%s", _s) + expecting_binary = True + continue + + if mtype == 
"dialog_result": + if data.get("turn_id") != turn_id: + raise CloudVoiceError( + "dialog_result turn_id 不匹配", code="INVALID_MESSAGE" + ) + dialog = data + logger.info( + "← dialog_result routing=%s", data.get("routing") + ) + continue + + if mtype == "turn.complete": + if data.get("turn_id") != turn_id: + raise CloudVoiceError( + "turn.complete turn_id 不匹配", code="INVALID_MESSAGE" + ) + metrics = data.get("metrics") or {} + break + + if mtype == "error": + code = str(data.get("code") or "INTERNAL") + raise CloudVoiceError( + data.get("message") or code, + code=code, + retryable=bool(data.get("retryable")), + ) + + logger.debug("忽略服务端消息 type=%s", mtype) + + if dialog is None: + raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE") + + full_pcm = _merge_tts_pcm_chunks(pcm_entries) + out_pcm = ( + np.frombuffer(full_pcm, dtype=np.int16).copy() + if full_pcm + else np.array([], dtype=np.int16) + ) + if out_pcm.size > 0: + mx = int(np.max(np.abs(out_pcm))) + if mx == 0: + logger.warning( + "云端 TTS 已收齐但全零采样,请核对服务端。", + ) + + llm_stream_text = "".join(llm_stream_parts) + return { + "protocol": dialog.get("protocol"), + "routing": dialog.get("routing"), + "flight_intent": dialog.get("flight_intent"), + "confirm": dialog.get("confirm"), + "chat_reply": dialog.get("chat_reply"), + "user_input": dialog.get("user_input"), + "pcm": out_pcm, + "sample_rate_hz": out_sr, + "metrics": metrics, + "llm_stream_text": llm_stream_text, + } + + def run_turn_audio( + self, pcm_int16: np.ndarray, sample_rate_hz: int + ) -> dict[str, Any]: + """上行一轮麦克风 PCM:chunk 均为含 pcm_base64 的文本 JSON;收齐 dialog_result + TTS + turn.complete。""" + try: + raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3")) + except ValueError: + raw_attempts = 3 + attempts = max(1, raw_attempts) + try: + delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35")) + except ValueError: + delay = 0.35 + delay = max(0.0, delay) + + for attempt in range(attempts): + with self._lock: + 
try: + if self._ws is None: + self._connect_nolock() + return self._execute_turn_audio_nolock(pcm_int16, sample_rate_hz) + except CloudVoiceError as e: + retry = bool(e.retryable) or e.code == "DISCONNECTED" + if retry and attempt < attempts - 1: + logger.warning( + "turn.audio 可恢复错误,重连重试 (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + except Exception as e: + if _transient_ws_exc(e) and attempt < attempts - 1: + logger.warning( + "turn.audio WebSocket 瞬断,重连重试 (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + + raise CloudVoiceError("run_turn_audio 未执行", code="INTERNAL") + + def _execute_tts_synthesize_nolock(self, text: str) -> dict[str, Any]: + """已持锁且 _ws 已连接:发送 tts.synthesize,仅收 tts_audio_chunk* 与 turn.complete(无 dialog_result)。""" + import websocket # websocket-client + + ws = self._ws + if ws is None: + raise CloudVoiceError("WebSocket 未连接") + + turn_id = str(uuid.uuid4()) + synth_msg = { + "type": "tts.synthesize", + "proto_version": _CLOUD_PROTO, + "transport_profile": self._transport_profile, + "turn_id": turn_id, + "text": text, + } + try: + ws.send(json.dumps(synth_msg, ensure_ascii=False)) + except Exception as e: + if _transient_ws_exc(e): + raise + raise CloudVoiceError(f"发送 tts.synthesize 失败: {e}", code="INTERNAL") from e + logger.debug("→ tts.synthesize turn_id=%s", turn_id) + + expecting_binary = False + _pending_tts_seq: int | None = None + pcm_entries: list[tuple[int | None, int, bytes]] = [] + _pcm_arrival = 0 + metrics: dict[str, Any] = {} + sample_rate_hz = 24000 + + while True: + try: + msg = ws.recv() + except websocket.WebSocketConnectionClosedException as e: + raise CloudVoiceError( + f"连接已断开: {e}", + code="DISCONNECTED", + retryable=True, + ) from e + except Exception as e: + if _transient_ws_exc(e): + raise + raise + + if isinstance(msg, bytes): + if expecting_binary: + expecting_binary 
= False + else: + logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理") + pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg)) + _pcm_arrival += 1 + _pending_tts_seq = None + continue + + if not isinstance(msg, str): + raise CloudVoiceError( + f"期望文本帧为 str,实际为 {type(msg).__name__}", + code="INVALID_MESSAGE", + ) + text_frame = msg.strip() + if not text_frame: + logger.debug("跳过空 WebSocket 文本帧") + continue + try: + data = json.loads(text_frame) + except json.JSONDecodeError as e: + head = text_frame[:200].replace("\n", "\\n") + raise CloudVoiceError( + f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}", + code="INVALID_MESSAGE", + ) from e + mtype = data.get("type") + + if mtype == "asr.partial": + logger.debug("← asr.partial(tts 轮次,忽略)") + continue + + if mtype == "llm.text_delta": + if data.get("turn_id") != turn_id: + logger.debug( + "llm.text_delta turn_id 与当前 tts 不一致,忽略", + ) + continue + + if mtype == "tts_audio_chunk": + _pending_tts_seq = None + if data.get("turn_id") != turn_id: + logger.warning( + "tts_audio_chunk turn_id 与 tts.synthesize 不一致,仍消费后续二进制", + ) + else: + try: + sample_rate_hz = int( + data.get("sample_rate_hz") or sample_rate_hz + ) + except (TypeError, ValueError): + pass + _s = data.get("seq") + try: + if _s is not None: + _pending_tts_seq = int(_s) + except (TypeError, ValueError): + _pending_tts_seq = None + if data.get("is_final"): + logger.debug("← tts_audio_chunk is_final=true seq=%s", _s) + expecting_binary = True + continue + + if mtype == "dialog_result": + logger.debug("tts.synthesize 收到 dialog_result(非预期),忽略") + continue + + if mtype == "turn.complete": + if data.get("turn_id") != turn_id: + raise CloudVoiceError( + "turn.complete turn_id 不匹配", code="INVALID_MESSAGE" + ) + metrics = data.get("metrics") or {} + break + + if mtype == "error": + code = str(data.get("code") or "INTERNAL") + raise CloudVoiceError( + data.get("message") or code, + code=code, + retryable=bool(data.get("retryable")), + ) + + logger.debug("忽略服务端消息 type=%s", 
mtype) + + full_pcm = _merge_tts_pcm_chunks(pcm_entries) + pcm = ( + np.frombuffer(full_pcm, dtype=np.int16).copy() + if full_pcm + else np.array([], dtype=np.int16) + ) + if pcm.size > 0: + mx = int(np.max(np.abs(pcm))) + if mx == 0: + logger.warning( + "tts.synthesize 收齐 PCM 但全零(服务端静音占位);总长 %s 字节", + len(full_pcm), + ) + + return { + "pcm": pcm, + "sample_rate_hz": sample_rate_hz, + "metrics": metrics, + } + + def run_tts_synthesize(self, text: str) -> dict[str, Any]: + """ + 发送 tts.synthesize,收齐 TTS 块与 turn.complete(无 dialog_result)。 + 与 run_turn 共用连接,互斥由服务端排队;重试策略同 ROCKET_CLOUD_TURN_RETRIES。 + """ + t = (text or "").strip() + if not t: + raise CloudVoiceError("tts.synthesize text 不能为空") + + try: + raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3")) + except ValueError: + raw_attempts = 3 + attempts = max(1, raw_attempts) + try: + delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35")) + except ValueError: + delay = 0.35 + delay = max(0.0, delay) + + for attempt in range(attempts): + with self._lock: + try: + if self._ws is None: + self._connect_nolock() + return self._execute_tts_synthesize_nolock(t) + except CloudVoiceError as e: + retry = bool(e.retryable) or e.code == "DISCONNECTED" + if retry and attempt < attempts - 1: + logger.warning( + "tts.synthesize 可恢复错误,将重连并重试 (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + except Exception as e: + if _transient_ws_exc(e) and attempt < attempts - 1: + logger.warning( + "tts.synthesize WebSocket 瞬断,重连并重试 (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + + raise CloudVoiceError("run_tts_synthesize 未执行", code="INTERNAL") + + def run_turn(self, text: str) -> dict[str, Any]: + """ + 发送一轮用户文本,收齐 dialog_result、TTS 块、turn.complete。 + + 支持流式下行:可先于 dialog_result 收到 tts_audio_chunk+PCM 与 llm.text_delta; + 飞控与最终文案仍以 dialog_result 为准。 + + 
若中间因对端已关 TCP、ping/pong Broken pipe 等断开,会自动关连接、 + 重连 session 并重发本轮(次数由 ROCKET_CLOUD_TURN_RETRIES 控制,默认 3)。 + + Returns: + dict: routing, flight_intent, chat_reply, user_input, pcm, sample_rate_hz, + metrics, llm_stream_text(llm.text_delta 拼接,可选调试/UI) + """ + t = (text or "").strip() + if not t: + raise CloudVoiceError("turn.text 不能为空") + + try: + raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3")) + except ValueError: + raw_attempts = 3 + attempts = max(1, raw_attempts) + try: + delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35")) + except ValueError: + delay = 0.35 + delay = max(0.0, delay) + + for attempt in range(attempts): + with self._lock: + try: + if self._ws is None: + self._connect_nolock() + return self._execute_turn_nolock(t) + except CloudVoiceError as e: + retry = bool(e.retryable) or e.code == "DISCONNECTED" + if retry and attempt < attempts - 1: + logger.warning( + "云端回合可恢复错误,将重连并重试 (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + except Exception as e: + if _transient_ws_exc(e) and attempt < attempts - 1: + logger.warning( + "云端 WebSocket 瞬断(如对端先关、PONG 写失败)," + "重连并重发 turn (%s/%s): %s", + attempt + 1, + attempts, + e, + ) + self._close_nolock() + if delay: + time.sleep(delay) + continue + raise + + raise CloudVoiceError("run_turn 未执行", code="INTERNAL") diff --git a/voice_drone/core/command.py b/voice_drone/core/command.py new file mode 100644 index 0000000..613b86a --- /dev/null +++ b/voice_drone/core/command.py @@ -0,0 +1,205 @@ +from pydantic import BaseModel, Field +from typing import Dict, Any, Optional, Literal +from datetime import datetime +from voice_drone.core.configuration import ( + TAKEOFF_CONFIG, + LAND_CONFIG, + FOLLOW_CONFIG, + FORWARD_CONFIG, + BACKWARD_CONFIG, + LEFT_CONFIG, + RIGHT_CONFIG, + UP_CONFIG, + DOWN_CONFIG, + HOVER_CONFIG, + RETURN_HOME_CONFIG, +) +import warnings +warnings.filterwarnings("ignore") + + +class 
CommandParams(BaseModel): + """ + 命令参数 + """ + distance: Optional[float] = Field( + None, + description="飞行距离,单位:米(m),必须大于等于0(land/hover 可以为0)", + ge=0 + ) + speed: Optional[float] = Field( + None, + description="飞行速度,单位:米每秒(m/s),必须大于等于0(land/hover 可以为0)", + ge=0 + ) + duration: Optional[float] = Field( + None, + description="飞行持续时间,单位:秒(s),必须大于0", + gt=0 + ) + +class Command(BaseModel): + """ + 无人机控制命令 + """ + command: Literal[ + "takeoff", + "follow", + "forward", + "backward", + "left", + "right", + "up", + "down", + "hover", + "land", + "return_home", + ] = Field( + ..., + description="无人机控制动作: takeoff(起飞), follow(跟随), forward(向前), backward(向后), left(向左), right(向右), up(向上), down(向下), hover(悬停), land(降落), return_home(返航)", + ) + params: CommandParams = Field(..., description="命令参数") + timestamp: str = Field(..., description="命令生成时间戳,ISO 8601 格式(如:2024-01-01T12:00:00.000Z)") + sequence_id: int = Field(..., description="命令序列号,用于保证命令顺序和去重") + + # 命令配置映射字典 + _CONFIG_MAP = { + "takeoff": TAKEOFF_CONFIG, + "follow": FOLLOW_CONFIG, + "land": LAND_CONFIG, + "forward": FORWARD_CONFIG, + "backward": BACKWARD_CONFIG, + "left": LEFT_CONFIG, + "right": RIGHT_CONFIG, + "up": UP_CONFIG, + "down": DOWN_CONFIG, + "hover": HOVER_CONFIG, + "return_home": RETURN_HOME_CONFIG, + } + + # 创建命令 + @classmethod + def create( + cls, + command: str, + sequence_id: int, + distance: Optional[float] = None, + speed: Optional[float] = None, + duration: Optional[float] = None + ) -> "Command": + return cls( + command=command, + params=CommandParams(distance=distance, speed=speed, duration=duration), + timestamp=datetime.utcnow().isoformat() + "Z", + sequence_id=sequence_id, + ) + + def _get_default_config(self): + """获取当前命令的默认配置""" + return self._CONFIG_MAP.get(self.command) + + # 填充默认值 + def fill_defaults(self) -> None: + """填充缺失的参数值""" + # 如果所有参数都已提供,直接返回 + if (self.params.distance is not None and + self.params.speed is not None and + self.params.duration is not None): + return + + # 
如果有缺失的参数,调用智能填充方法 + self._fill_smart_params() + + def _fill_smart_params(self): + """智能填充缺失的参数值""" + default = self._get_default_config() + if default is None: + # 若命令未知,则直接返回不填充 + return + + d = self.params.distance + s = self.params.speed + t = self.params.duration + + # 统计 None 个数 + none_cnt = sum(x is None for x in [d, s, t]) + + # 三个都为None,直接填默认值 + if none_cnt == 3: + self.params.distance = default["distance"] + self.params.speed = default["speed"] + self.params.duration = default["duration"] + return + + # 只有一个参数有值的情况 + if none_cnt == 2: + if s is not None and d is None and t is None: + # 仅速度:使用默认持续时间,计算距离 + self.params.duration = default["duration"] + self.params.distance = s * self.params.duration + return + + if t is not None and d is None and s is None: + # 仅持续时间:使用默认速度,计算距离 + self.params.speed = default["speed"] + self.params.distance = self.params.speed * t + return + + if d is not None and s is None and t is None: + # 仅距离:使用默认速度,计算持续时间 + self.params.speed = default["speed"] + # 防止除以0 + if self.params.speed == 0: + self.params.duration = 0 + else: + self.params.duration = d / self.params.speed + return + + # 两个参数有值,一个None,自动计算缺失的参数 + if none_cnt == 1: + if d is None and s is not None and t is not None: + # 缺失距离:distance = speed * duration + self.params.distance = s * t + return + + if s is None and d is not None and t is not None: + # 缺失速度:speed = distance / duration + if t == 0: + self.params.speed = 0 + else: + self.params.speed = d / t + return + + if t is None and d is not None and s is not None: + # 缺失持续时间:duration = distance / speed + if s == 0: + self.params.duration = 0 + else: + self.params.duration = d / s + return + + + # 转换为字典 + def to_dict(self) -> dict: + + result = { + "command": self.command, + "params": {}, + "timestamp": self.timestamp, + "sequence_id": self.sequence_id + } + + if self.params.distance is None or self.params.speed is None or self.params.duration is None: + self.fill_defaults() + + result["params"]["distance"] = 
import os
from pathlib import Path

from voice_drone.tools.config_loader import load_config
from voice_drone.logging_ import get_logger

_cfg_log = get_logger("voice_drone.configuration")

# voice_drone/core/configuration.py -> project root voice_drone_assistant is parents[2]
_PROJECT_ROOT = Path(__file__).resolve().parents[2]


def _abs_config_path(relative: str) -> str:
    """Resolve *relative* against the project root; absolute paths pass through unchanged."""
    p = Path(relative)
    return str(p) if p.is_absolute() else str(_PROJECT_ROOT / p)


class SystemConfigLoader:
    """Loads system.yaml and exposes one accessor per configuration section."""

    def __init__(self, config_path: str = "voice_drone/config/system.yaml"):
        self.config_path = _abs_config_path(config_path)
        self.config = load_config(self.config_path)

    # Required sections — a malformed YAML fails fast with KeyError.
    def get_audio_config(self):
        """Audio capture configuration."""
        return self.config["audio"]

    def get_vad_config(self):
        """Voice-activity-detection configuration."""
        return self.config["vad"]

    def get_socket_server_config(self):
        """Socket server configuration."""
        return self.config["socket_server"]

    def get_logging_config(self):
        """Logging configuration."""
        return self.config["logging"]

    def get_stt_config(self):
        """Speech-to-text configuration."""
        return self.config["stt"]

    # Optional sections — default to {} so callers can .get() safely.
    def get_tts_config(self):
        """
        Text-to-speech (TTS) configuration.

        Example structure:
          tts:
            model_dir: "src/models/Kokoro-82M-v1.1-zh-ONNX"
            model_name: "model_q4.onnx"  # also: model.onnx, model_fp16.onnx, model_quantized.onnx, ...
            voice: "zf_001"              # voice style file name (without extension)
            speed: 1.0                   # speech-rate factor
            sample_rate: 24000           # output sample rate
        """
        return self.config.get("tts", {})

    def get_text_preprocessor_config(self):
        """Text preprocessing configuration."""
        return self.config.get("text_preprocessor", {})

    def get_recognizer_config(self):
        """Recognizer pipeline configuration."""
        return self.config.get("recognizer", {})

    def get_cloud_voice_config(self):
        """Cloud voice configuration (WebSocket / pcm_asr_uplink session)."""
        return self.config.get("cloud_voice", {})

    def get_assistant_config(self):
        """Main-app TakeoffPrintRecognizer (main_app) configuration."""
        return self.config.get("assistant", {})


class CommandConfigLoader:
    """Loads command_.yaml and exposes the control parameters for each command."""

    def __init__(self, config_path: str = "voice_drone/config/command_.yaml"):
        self.config_path = _abs_config_path(config_path)
        self.config = load_config(self.config_path)["control_params"]

    def _get(self, name: str):
        # Single lookup point shared by every per-command getter below.
        return self.config[name]

    def get_takeoff_config(self):
        return self._get("takeoff")

    def get_land_config(self):
        return self._get("land")

    def get_follow_config(self):
        return self._get("follow")

    def get_forward_config(self):
        return self._get("forward")

    def get_backward_config(self):
        return self._get("backward")

    def get_left_config(self):
        return self._get("left")

    def get_right_config(self):
        return self._get("right")

    def get_up_config(self):
        return self._get("up")

    def get_down_config(self):
        return self._get("down")

    def get_hover_config(self):
        return self._get("hover")

    def get_return_home_config(self):
        return self._get("return_home")


class KeywordsConfigLoader:
    """Loads keywords.yaml and exposes the keyword table."""

    def __init__(self, config_path: str = "voice_drone/config/keywords.yaml"):
        self.config_path = _abs_config_path(config_path)
        self.config = load_config(self.config_path)["keywords"]

    def get_keywords(self):
        return self.config


class WakeWordConfigLoader:
    """Loads wake_word.yaml: primary wake word, accepted variants, matching options."""

    def __init__(self, config_path: str = "voice_drone/config/wake_word.yaml"):
        self.config_path = _abs_config_path(config_path)
        self.config = load_config(self.config_path)["wake_word"]

    def get_primary(self):
        return self.config.get("primary", "")

    def get_variants(self):
        return self.config.get("variants", [])

    def get_matching_config(self):
        return self.config.get("matching", {})


# System configuration constants (loaded once at import time).
system_config = SystemConfigLoader()

SYSTEM_AUDIO_CONFIG = system_config.get_audio_config()
SYSTEM_VAD_CONFIG = system_config.get_vad_config()
SYSTEM_SOCKET_SERVER_CONFIG = system_config.get_socket_server_config()
SYSTEM_LOGGING_CONFIG = system_config.get_logging_config()
SYSTEM_STT_CONFIG = system_config.get_stt_config()
SYSTEM_TTS_CONFIG = system_config.get_tts_config()
SYSTEM_TEXT_PREPROCESSOR_CONFIG = system_config.get_text_preprocessor_config()
SYSTEM_RECOGNIZER_CONFIG = system_config.get_recognizer_config()
SYSTEM_CLOUD_VOICE_CONFIG = system_config.get_cloud_voice_config()
SYSTEM_ASSISTANT_CONFIG = system_config.get_assistant_config()


def load_cloud_voice_px4_context() -> dict:
    """
    Load the PX4/MAV extension fields that are merged into the cloud
    session.start.client payload.

    Path resolution: env var ROCKET_CLOUD_PX4_CONTEXT_FILE first, otherwise
    cloud_voice.px4_context_file (relative paths resolve against the project
    root).  Returns {} when unset, missing, unreadable, or not a mapping —
    this is deliberately best-effort and never raises.
    """
    cv = SYSTEM_CLOUD_VOICE_CONFIG if isinstance(SYSTEM_CLOUD_VOICE_CONFIG, dict) else {}
    raw = (os.environ.get("ROCKET_CLOUD_PX4_CONTEXT_FILE") or "").strip()
    rel = raw or str(cv.get("px4_context_file") or "").strip()
    if not rel:
        return {}
    p = Path(rel)
    if not p.is_absolute():
        p = _PROJECT_ROOT / rel
    if not p.is_file():
        _cfg_log.warning("cloud_voice PX4 上下文文件不存在,已跳过: %s", p)
        return {}
    try:
        data = load_config(str(p))
    except Exception as e:  # noqa: BLE001
        _cfg_log.warning("读取 PX4 上下文 YAML 失败: %s — %s", p, e)
        return {}
    return data if isinstance(data, dict) else {}


SYSTEM_CLOUD_VOICE_PX4_CONTEXT = load_cloud_voice_px4_context()

# Command configuration constants.
command_config = CommandConfigLoader()
TAKEOFF_CONFIG = command_config.get_takeoff_config()
LAND_CONFIG = command_config.get_land_config()
FOLLOW_CONFIG = command_config.get_follow_config()
FORWARD_CONFIG = command_config.get_forward_config()
BACKWARD_CONFIG = command_config.get_backward_config()
LEFT_CONFIG = command_config.get_left_config()
RIGHT_CONFIG = command_config.get_right_config()
UP_CONFIG = command_config.get_up_config()
DOWN_CONFIG = command_config.get_down_config()
HOVER_CONFIG = command_config.get_hover_config()
RETURN_HOME_CONFIG = command_config.get_return_home_config()

# Keyword configuration constants.
keywords_config = KeywordsConfigLoader()
KEYWORDS_CONFIG = keywords_config.get_keywords()

# Wake-word configuration constants.
wake_word_config = WakeWordConfigLoader()
WAKE_WORD_PRIMARY = wake_word_config.get_primary()
WAKE_WORD_VARIANTS = wake_word_config.get_variants()
WAKE_WORD_MATCHING_CONFIG = wake_word_config.get_matching_config()


if __name__ == "__main__":
    # Smoke test: dump a few loaded sections.
    print(TAKEOFF_CONFIG)
    print(LAND_CONFIG)
    print(FORWARD_CONFIG)
    print(BACKWARD_CONFIG)
    print(LEFT_CONFIG)
    print(RIGHT_CONFIG)
    print(UP_CONFIG)
    print(DOWN_CONFIG)
    print(HOVER_CONFIG)
    print(KEYWORDS_CONFIG)


"""
flight_intent v1 validation and helpers (aligned with docs/FLIGHT_INTENT_SCHEMA_v1.md).

Compatible with Pydantic v2 (field_validator / ConfigDict).

Interop: some cloud backends encode "hover N seconds" as hover.args.duration;
the spec prefers hover + wait.seconds.  Parsing folds it into hover (no
duration) + wait(seconds), matching on-vehicle execution.
"""

import math
from dataclasses import dataclass
from typing import Any, List, Literal, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

# Hard cap on |coordinate| accepted from the cloud, in meters.
_COORD_CAP = 10_000.0


def _check_coord(name: str, v: Optional[float]) -> Optional[float]:
    """Validate one coordinate: None passes; otherwise a finite number with |v| <= _COORD_CAP."""
    if v is None:
        return None
    if not isinstance(v, (int, float)) or not math.isfinite(float(v)):
        raise ValueError(f"{name} must be a finite number")
    fv = float(v)
    if abs(fv) > _COORD_CAP:
        raise ValueError(f"{name} out of range (|.| <= {_COORD_CAP})")
    return fv
class TakeoffArgs(BaseModel):
    """Arguments for `takeoff`; altitude is optional but must be positive when given."""

    model_config = ConfigDict(extra="forbid")

    relative_altitude_m: Optional[float] = None

    @field_validator("relative_altitude_m")
    @classmethod
    def _positive_altitude(cls, v: Optional[float]) -> Optional[float]:
        if v is None:
            return None
        if v <= 0:
            raise ValueError("relative_altitude_m must be > 0 when set")
        return v


class EmptyArgs(BaseModel):
    """Actions that carry no arguments (land / return_home)."""

    model_config = ConfigDict(extra="forbid")


class HoverHoldArgs(BaseModel):
    """hover / hold: spec says {}; a cloud-supplied duration (seconds) is tolerated
    and later expanded into a standard wait action."""

    model_config = ConfigDict(extra="forbid")

    duration: Optional[float] = None

    @field_validator("duration")
    @classmethod
    def _bounded_duration(cls, v: Optional[float]) -> Optional[float]:
        if v is None:
            return None
        seconds = float(v)
        if not (0 < seconds <= 3600):
            raise ValueError("duration must satisfy 0 < duration <= 3600 when set")
        return seconds


class WaitArgs(BaseModel):
    """wait: mandatory positive number of seconds, capped at one hour."""

    model_config = ConfigDict(extra="forbid")

    seconds: float

    @field_validator("seconds")
    @classmethod
    def _bounded_seconds(cls, v: float) -> float:
        if not (0 < v <= 3600):
            raise ValueError("seconds must satisfy 0 < seconds <= 3600")
        return v


class GotoArgs(BaseModel):
    """goto: reference frame plus up to three optional coordinates."""

    model_config = ConfigDict(extra="forbid")

    frame: str
    x: Optional[float] = None
    y: Optional[float] = None
    z: Optional[float] = None

    @field_validator("frame")
    @classmethod
    def _known_frame(cls, v: str) -> str:
        if v not in ("local_ned", "body_ned"):
            raise ValueError('frame must be "local_ned" or "body_ned"')
        return v

    @field_validator("x", "y", "z")
    @classmethod
    def _finite_coord(cls, v: Optional[float], info: ValidationInfo) -> Optional[float]:
        # Range/finiteness checking is shared with the module-level helper.
        return _check_coord(str(info.field_name), v)


# --- actions --------------------------------------------------------------


class ActionTakeoff(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["takeoff"] = "takeoff"
    args: TakeoffArgs


class ActionLand(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["land"] = "land"
    args: EmptyArgs


class ActionReturnHome(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["return_home"] = "return_home"
    args: EmptyArgs


class ActionHover(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["hover"] = "hover"
    args: HoverHoldArgs


class ActionHold(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["hold"] = "hold"
    args: HoverHoldArgs


class ActionGoto(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["goto"] = "goto"
    args: GotoArgs


class ActionWait(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal["wait"] = "wait"
    args: WaitArgs


# Closed union of every supported action model.
FlightAction = Union[
    ActionTakeoff,
    ActionLand,
    ActionReturnHome,
    ActionHover,
    ActionHold,
    ActionGoto,
    ActionWait,
]


class FlightIntentPayload(BaseModel):
    """Top-level envelope checks; per-action parsing happens in a second pass."""

    model_config = ConfigDict(extra="forbid")

    is_flight_intent: bool
    version: int
    actions: List[Any]
    summary: str
    trace_id: Optional[str] = None

    @field_validator("is_flight_intent")
    @classmethod
    def _must_be_true(cls, v: bool) -> bool:
        if v is not True:
            raise ValueError("is_flight_intent must be true")
        return v

    @field_validator("version")
    @classmethod
    def _only_v1(cls, v: int) -> int:
        if v != 1:
            raise ValueError("version must be 1")
        return v

    @field_validator("summary")
    @classmethod
    def _nonempty_summary(cls, v: str) -> str:
        if not (isinstance(v, str) and v.strip()):
            raise ValueError("summary must be non-empty")
        return v

    @field_validator("trace_id")
    @classmethod
    def _short_trace_id(cls, v: Optional[str]) -> Optional[str]:
        if v is not None and len(v) > 128:
            raise ValueError("trace_id length must be <= 128")
        return v

    @field_validator("actions")
    @classmethod
    def _nonempty_actions(cls, v: List[Any]) -> List[Any]:
        if not isinstance(v, list) or len(v) == 0:
            raise ValueError("actions must be a non-empty array")
        return v


@dataclass
class ValidatedFlightIntent:
    """Result of a successful L1–L3 validation pass."""

    summary: str
    trace_id: Optional[str]
    actions: List[FlightAction]
# Registry mapping action type -> (action model, args model); keeps
# _parse_one_action table-driven instead of a seven-branch if-chain.
_ACTION_MODELS = {
    "takeoff": (ActionTakeoff, TakeoffArgs),
    "land": (ActionLand, EmptyArgs),
    "return_home": (ActionReturnHome, EmptyArgs),
    "hover": (ActionHover, HoverHoldArgs),
    "hold": (ActionHold, HoverHoldArgs),
    "goto": (ActionGoto, GotoArgs),
    "wait": (ActionWait, WaitArgs),
}


def _parse_one_action(raw: dict) -> FlightAction:
    """Validate one raw action dict into its typed model.

    Raises ValueError on structural problems (missing/invalid type or args,
    unknown type); pydantic raises on schema violations inside args.
    """
    t = raw.get("type")
    if not isinstance(t, str):
        raise ValueError("action.type must be a string")
    args = raw.get("args")
    if not isinstance(args, dict):
        raise ValueError("action.args must be an object")
    entry = _ACTION_MODELS.get(t)
    if entry is None:
        raise ValueError(f"unknown action.type: {t!r}")
    action_cls, args_cls = entry
    return action_cls(args=args_cls.model_validate(args))


def _expand_hover_duration(actions: List[FlightAction]) -> List[FlightAction]:
    """Fold a hover/hold carrying `duration` into the canonical hover/hold + wait(seconds)."""
    out: List[FlightAction] = []
    for a in actions:
        if isinstance(a, (ActionHover, ActionHold)):
            d = a.args.duration
            if d is not None:
                # Re-emit the same action type without duration, then an explicit wait.
                out.append(type(a)(args=HoverHoldArgs()))
                out.append(ActionWait(args=WaitArgs(seconds=float(d))))
                continue
        out.append(a)
    return out


def parse_flight_intent_dict(data: dict) -> Tuple[Optional[ValidatedFlightIntent], List[str]]:
    """
    L1–L3 validation.  On success returns (ValidatedFlightIntent, []);
    on failure returns (None, [error message, ...]).
    """
    errors: List[str] = []
    try:
        top = FlightIntentPayload.model_validate(data)
    except Exception as e:  # noqa: BLE001
        return None, [str(e)]

    parsed_actions: List[FlightAction] = []
    for i, item in enumerate(top.actions):
        if not isinstance(item, dict):
            errors.append(f"actions[{i}] must be an object")
            continue
        try:
            parsed_actions.append(_parse_one_action(item))
        except Exception as e:  # noqa: BLE001
            errors.append(f"actions[{i}]: {e}")

    if errors:
        return None, errors

    parsed_actions = _expand_hover_duration(parsed_actions)

    # A leading wait is meaningless — there is nothing being controlled yet.
    if isinstance(parsed_actions[0], ActionWait):
        return None, ["first action must not be wait (nothing to control yet)"]

    return (
        ValidatedFlightIntent(
            summary=top.summary.strip(),
            trace_id=top.trace_id,
            actions=parsed_actions,
        ),
        [],
    )


def goto_action_to_command(action: ActionGoto, sequence_id: int) -> Tuple[Optional[Any], Optional[str]]:
    """
    Map a single-axis goto onto the existing Socket Command.
    Returns (Command | None, error_reason | None).
    """
    from voice_drone.core.command import Command

    a = action.args
    # An axis counts as "active" only when present and non-zero.
    active = [
        (axis, val)
        for axis, val in (("x", a.x), ("y", a.y), ("z", a.z))
        if val is not None and val != 0
    ]
    if len(active) == 0:
        return None, "goto: all axes omit or zero (no-op)"
    if len(active) > 1:
        return (
            None,
            f"goto: multi-axis ({', '.join(n for n, _ in active)}) not sent via Socket "
            "(use bridge or decompose)",
        )

    axis, val = active[0]
    # Body-frame/NED sign conventions: +x forward, +y right, +z down.
    positive, negative = {
        "x": ("forward", "backward"),
        "y": ("right", "left"),
        "z": ("down", "up"),
    }[axis]
    cmd = Command.create(positive if val > 0 else negative, sequence_id, distance=abs(float(val)))
    cmd.fill_defaults()
    return cmd, None


"""Startup helper: list `arecord -l` and PyAudio input devices, and map ALSA
card/device pairs to PyAudio indices for interactive selection."""

import re
import subprocess
from typing import List, Optional, Tuple, Any

from voice_drone.logging_ import get_logger

logger = get_logger("mic_device_select")
def run_arecord_l() -> str:
    """Run `arecord -l` and return its combined stdout/stderr text.

    Never raises: a missing binary or any failure is reported as a
    parenthesised message string instead.
    """
    try:
        proc = subprocess.run(
            ["arecord", "-l"],
            capture_output=True,
            text=True,
            timeout=5,
            check=False,
        )
        stdout = (proc.stdout or "").rstrip()
        stderr = (proc.stderr or "").strip()
        combined = stdout + (f"\n{stderr}" if stderr else "")
        return combined.strip() if combined.strip() else "(arecord 无输出)"
    except FileNotFoundError:
        return "(未找到 arecord,可安装 alsa-utils)"
    except Exception as e:  # noqa: BLE001
        return f"(执行 arecord -l 失败: {e})"


def parse_arecord_capture_lines(text: str) -> List[Tuple[int, int, str]]:
    """Extract (card, device, stripped-line) triples from `arecord -l` output."""
    pattern = re.compile(r"card\s+(\d+):.+?,\s*device\s+(\d+):", re.IGNORECASE)
    return [
        (int(m.group(1)), int(m.group(2)), line.strip())
        for line in text.splitlines()
        if (m := pattern.search(line)) is not None
    ]


def _pyaudio_input_devices() -> List[Tuple[int, Any]]:
    """Return (index, device_info) for every PyAudio device with input channels."""
    from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio

    fix_ld_path_for_portaudio()
    import pyaudio

    pa = pyaudio.PyAudio()
    devices: List[Tuple[int, Any]] = []
    try:
        for idx in range(pa.get_device_count()):
            try:
                info = pa.get_device_info_by_index(idx)
                if int(info.get("maxInputChannels", 0)) <= 0:
                    continue
            except Exception:
                continue
            devices.append((idx, info))
        return devices
    finally:
        pa.terminate()


def match_alsa_hw_to_pyaudio_index(
    card: int,
    dev: int,
    pa_items: List[Tuple[int, Any]],
) -> Optional[int]:
    """Find the PyAudio index whose device name embeds `(hw:card,dev)`."""
    wanted = (f"(hw:{card},{dev})", f"(hw:{card}, {dev})")
    for idx, info in pa_items:
        name = str(info.get("name", ""))
        if any(w in name for w in wanted):
            return idx
    return None


def print_mic_device_menu() -> List[int]:
    """
    Print the arecord view, the PyAudio view, and the mapping table.
    Returns the PyAudio indices listed in this menu (same order as [1], [2], ...).
    """
    arecord_text = run_arecord_l()
    input_devs = _pyaudio_input_devices()

    print("\n" + "=" * 72, flush=True)
    print("录音设备(先 ALSA,再 PortAudio;请记下要用的 PyAudio 索引)", flush=True)
    print("=" * 72, flush=True)
    print("\n--- arecord -l(系统硬件视角)---\n", flush=True)
    print(arecord_text, flush=True)

    print("\n--- PyAudio 可录音设备(maxInputChannels > 0)---\n", flush=True)
    menu_indices: List[int] = []
    if not input_devs:
        print("(无)\n", flush=True)
    for pos, (idx, info) in enumerate(input_devs, start=1):
        menu_indices.append(idx)
        in_ch = int(info.get("maxInputChannels", 0))
        out_ch = int(info.get("maxOutputChannels", 0))
        dev_name = str(info.get("name", "?"))
        print(
            f" [{pos}] PyAudio_index={idx} in={in_ch} out={out_ch} {dev_name}",
            flush=True,
        )

    alsa_rows = parse_arecord_capture_lines(arecord_text)
    print(
        "\n--- 映射:arecord 的 card / device → PyAudio 索引(匹配设备名中的 hw:X,Y)---\n",
        flush=True,
    )
    if not alsa_rows:
        print(" (未解析到 card/device 行,请直接用上一表的 PyAudio_index)", flush=True)
    for card, dev, line in alsa_rows:
        mapped = match_alsa_hw_to_pyaudio_index(card, dev, input_devs)
        short = line if len(line) <= 76 else line[:73] + "..."
        if mapped is not None:
            print(
                f" card {card}, device {dev} → PyAudio 索引 {mapped}\n {short}",
                flush=True,
            )
        else:
            print(
                f" card {card}, device {dev} → (无 in>0 设备名含 hw:{card},{dev})\n {short}",
                flush=True,
            )

    print(
        "\n说明:程序只会用「一个」PyAudio 索引打开麦克风;"
        "HDMI 等若 in=0 不会出现在可录音列表。\n"
        + "=" * 72
        + "\n",
        flush=True,
    )
    return menu_indices


def prompt_for_input_device_index() -> int:
    """Interactive selection; returns the PyAudio index written to audio.input_device_index."""
    menu = print_mic_device_menu()
    if not menu:
        print("错误:没有发现可录音的 PyAudio 设备。", flush=True)
        raise SystemExit(2)

    valid = set(menu)
    print(
        "请输入菜单序号 [1-"
        f"{len(menu)}](推荐),或直接输入 PyAudio_index 数字;q 退出。",
        flush=True,
    )
    while True:
        try:
            answer = input("录音设备> ").strip()
        except EOFError:
            raise SystemExit(1) from None
        if not answer:
            continue
        if answer.lower() in ("q", "quit", "exit"):
            raise SystemExit(0)
        if not answer.isdigit():
            print("请输入正整数或 q。", flush=True)
            continue
        n = int(answer)
        if 1 <= n <= len(menu):
            # Menu rank takes precedence over a literal PyAudio index.
            chosen = menu[n - 1]
            print(f"已选择:菜单 [{n}] → PyAudio 索引 {chosen}\n", flush=True)
            logger.info("交互选择录音设备 PyAudio index=%s", chosen)
            return chosen
        if n in valid:
            print(f"已选择:PyAudio 索引 {n}\n", flush=True)
            logger.info("交互选择录音设备 PyAudio index=%s", n)
            return n
        print(
            f"无效:{n} 不在可选列表。可选序号为 1~{len(menu)},"
            f"或索引之一 {sorted(valid)}。",
            flush=True,
        )


def apply_input_device_index_only(index: int) -> None:
    """Write runtime config: select by index only; hw/name matching from YAML is disabled."""
    from voice_drone.core.configuration import SYSTEM_AUDIO_CONFIG

    SYSTEM_AUDIO_CONFIG["input_device_index"] = int(index)
    SYSTEM_AUDIO_CONFIG["input_hw_card_device"] = None
    SYSTEM_AUDIO_CONFIG["input_device_name_match"] = None
    SYSTEM_AUDIO_CONFIG["input_strict_selection"] = False


"""Adjust the dynamic-library search path before PortAudio/PyAudio starts.

Under conda, libasound from `.../envs/xxx/lib` looks for plugins in the same
prefix's alsa-lib subdirectory, which usually lacks libasound_module_*.so —
causing log spam and sometimes broken capture levels.

Fix: 1) drop conda paths that only hold those plugins; 2) prepend the system
/usr/lib/ directories to LD_LIBRARY_PATH so the system libasound wins."""

import os
import platform


def strip_conda_alsa_from_ld_library_path() -> None:
    """Remove conda alsa-lib plugin directories from LD_LIBRARY_PATH."""
    ld = os.environ.get("LD_LIBRARY_PATH", "")
    if not ld:
        return
    kept: list = []
    for entry in ld.split(":"):
        if not entry:
            continue
        lowered = entry.lower()
        is_conda = "conda" in lowered or "miniconda" in lowered or "mamba" in lowered
        if is_conda and ("alsa-lib" in lowered or "alsa_lib" in lowered):
            continue
        kept.append(entry)
    if kept:
        os.environ["LD_LIBRARY_PATH"] = ":".join(kept)
    else:
        os.environ.pop("LD_LIBRARY_PATH", None)


def prepend_system_lib_dirs_for_alsa() -> None:
    """Linux only: put the system lib directories at the front of LD_LIBRARY_PATH."""
    if platform.system() != "Linux":
        return
    arch_dirs = {
        "aarch64": ("/usr/lib/aarch64-linux-gnu", "/lib/aarch64-linux-gnu"),
        "x86_64": ("/usr/lib/x86_64-linux-gnu", "/lib/x86_64-linux-gnu"),
        "amd64": ("/usr/lib/x86_64-linux-gnu", "/lib/x86_64-linux-gnu"),
    }.get(platform.machine().lower())
    if not arch_dirs:
        return
    front = [d for d in arch_dirs if os.path.isdir(d)]
    if not front:
        return
    tail = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(":") if p]
    deduped: list = []
    for p in front + tail:
        if p not in deduped:
            deduped.append(p)
    os.environ["LD_LIBRARY_PATH"] = ":".join(deduped)


def fix_ld_path_for_portaudio() -> None:
    """Apply both LD_LIBRARY_PATH fixes: system dirs first, then strip conda alsa."""
    prepend_system_lib_dirs_for_alsa()
    strip_conda_alsa_from_ld_library_path()
"""Aligned with scripts/qwen_flight_intent_sim.py: decide flight-intent JSON
vs. small talk, for use inside the voice main program."""

import json
import os
import re
from pathlib import Path
from typing import Any, Optional, Tuple

# Kept identical to qwen_flight_intent_sim._SYSTEM.
FLIGHT_INTENT_CHAT_SYSTEM = """你是无人机飞控意图助手,只做两件事(必须二选一):

【规则 A — 飞控相关】当用户话里包含对无人机的飞行任务、航线、起降、返航、悬停、等待、速度高度、坐标点、offboard、PX4/MAVROS 等操作意图时:
只输出一行 JSON,且不要有任何其它字符、不要 Markdown、不要代码块。
JSON Schema 含义(见仓库 docs/FLIGHT_INTENT_SCHEMA_v1.md):
{
 "is_flight_intent": true,
 "version": 1,
 "actions": [ // 按时间顺序排列
 {"type": "takeoff", "args": {}},
 {"type": "takeoff", "args": {"relative_altitude_m": number}},
 {"type": "goto", "args": {"frame": "local_ned"|"body_ned", "x": number|null, "y": number|null, "z": number|null}},
 {"type": "land" | "return_home" | "hover" | "hold", "args": {}},
 {"type": "wait", "args": {"seconds": number}}
 ],
 "summary": "一句话概括",
 "trace_id": "可选,简短追踪ID"
}
- 停多久、延迟多久必须用 wait,例如「悬停 3 秒再降落」应为 hover → wait(3) → land;不要把秒数写进 summary 代替 wait。
- 坐标缺省 frame 时用 "local_ned";无法确定的数字可省略字段或用 null。
- 返程/返航映射为 {"type":"return_home","args":{}}。
- 仅允许小写 type;args 只含规范允许键,禁止多余键。

【规则 B — 非飞控】若只是日常聊天、与无人机任务无关:用正常的自然中文回复,不要输出 JSON,不要用花括号开头。"""


def _strip_fenced_json(text: str) -> str:
    """Remove a surrounding ``` / ```json code fence, if one wraps the whole text."""
    stripped = text.strip()
    fence = re.match(
        r"^```(?:json)?\s*\n?(.*)\n?```\s*$", stripped, re.DOTALL | re.IGNORECASE
    )
    return fence.group(1).strip() if fence else stripped


def _first_balanced_json_object(text: str) -> Optional[str]:
    """Return the first brace-balanced {...} substring, or None when absent/unbalanced."""
    body = _strip_fenced_json(text)
    start = body.find("{")
    if start < 0:
        return None
    depth = 0
    for i, ch in enumerate(body[start:], start=start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return body[start : i + 1]
    return None


def parse_flight_intent_reply(raw: str) -> Tuple[str, Optional[dict[str, Any]]]:
    """Return (mode label, dict when the reply is a flight intent, else None)."""
    chunk = _first_balanced_json_object(raw)
    if not chunk:
        return "闲聊", None
    try:
        obj = json.loads(chunk)
    except json.JSONDecodeError:
        return "闲聊", None
    if isinstance(obj, dict) and obj.get("is_flight_intent") is True:
        return "飞控意图JSON", obj
    return "闲聊", None


def default_qwen_gguf_path(project_root: Path) -> Path:
    """Prefer this project's cache/; fall back to the sibling rocket_drone_audio/cache/.

    When neither file exists, the primary path is returned anyway so callers
    get a deterministic location to report.
    """
    name = "qwen2.5-1.5b-instruct-q4_k_m.gguf"
    candidates = (
        project_root / "cache" / "qwen25-1.5b-gguf" / name,
        project_root.parent / "cache" / "qwen25-1.5b-gguf" / name,
    )
    for candidate in candidates:
        if candidate.is_file():
            return candidate
    return candidates[0]


def load_llama_qwen(
    model_path: Path,
    n_ctx: int = 4096,
):
    """
    llama-cpp-python wrapper.  Optional env vars (see the rocket_drone_audio
    file header): ROCKET_LLM_N_THREADS, ROCKET_LLM_N_GPU_LAYERS, ROCKET_LLM_N_BATCH.

    Returns None when the model file is missing or llama_cpp is not installed.
    """
    if not model_path.is_file():
        return None
    try:
        from llama_cpp import Llama
    except ImportError:
        return None
    opts: dict = {
        "model_path": str(model_path),
        "n_ctx": int(n_ctx),
        "verbose": False,
    }
    threads = os.environ.get("ROCKET_LLM_N_THREADS", "").strip()
    if threads.isdigit() or (threads.startswith("-") and threads[1:].isdigit()):
        # Negative values are clamped to 1 thread.
        opts["n_threads"] = max(1, int(threads))
    gpu_layers = os.environ.get("ROCKET_LLM_N_GPU_LAYERS", "").strip()
    if gpu_layers.isdigit():
        opts["n_gpu_layers"] = int(gpu_layers)
    batch = os.environ.get("ROCKET_LLM_N_BATCH", "").strip()
    if batch.isdigit():
        opts["n_batch"] = max(1, int(batch))
    return Llama(**opts)
"").strip() + if nb.isdigit(): + opts["n_batch"] = max(1, int(nb)) + return Llama(**opts) diff --git a/voice_drone/core/recognizer.py b/voice_drone/core/recognizer.py new file mode 100644 index 0000000..6f1536b --- /dev/null +++ b/voice_drone/core/recognizer.py @@ -0,0 +1,969 @@ +""" +高性能实时语音识别与命令生成系统 + +整合所有模块,实现从语音检测到命令发送的完整流程: +1. 音频采集(高性能模式) +2. 音频预处理(降噪+AGC) +3. VAD语音活动检测 +4. STT语音识别 +5. 文本预处理(纠错+参数提取) +6. 命令生成 +7. Socket发送 + +性能优化: +- 多线程异步处理 +- 非阻塞音频采集 +- LRU缓存优化 +- 低延迟设计 +""" + +import math +import numpy as np +import os +import random +import threading +import queue +import time +from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from voice_drone.core.audio import AudioCapture, AudioPreprocessor +from voice_drone.core.vad import VAD +from voice_drone.core.stt import STT +from voice_drone.core.text_preprocessor import TextPreprocessor, get_preprocessor +from voice_drone.core.command import Command +from voice_drone.core.scoket_client import SocketClient +from voice_drone.core.configuration import ( + SYSTEM_AUDIO_CONFIG, + SYSTEM_RECOGNIZER_CONFIG, + SYSTEM_SOCKET_SERVER_CONFIG, +) +from voice_drone.core.tts_ack_cache import ( + compute_ack_pcm_fingerprint, + load_cached_phrases, + persist_phrases, +) +from voice_drone.core.wake_word import WakeWordDetector, get_wake_word_detector +from voice_drone.logging_ import get_logger + +if TYPE_CHECKING: + from voice_drone.core.tts import KokoroOnnxTTS + +logger = get_logger("recognizer") + + +class VoiceCommandRecognizer: + """ + 高性能实时语音命令识别器 + + 完整的语音转命令系统,包括: + - 音频采集和预处理 + - 语音活动检测 + - 语音识别 + - 文本预处理和参数提取 + - 命令生成 + - Socket发送 + """ + + def __init__(self, auto_connect_socket: bool = True): + """ + 初始化语音命令识别器 + + Args: + auto_connect_socket: 是否自动连接Socket服务器 + """ + logger.info("初始化语音命令识别系统...") + + # 初始化各模块 + self.audio_capture = AudioCapture() + self.audio_preprocessor = AudioPreprocessor() + self.vad = VAD() + self.stt = STT() + self.text_preprocessor = get_preprocessor() # 使用全局单例 + 
self.wake_word_detector = get_wake_word_detector() # 使用全局单例 + + # Socket客户端 + self.socket_client = SocketClient(SYSTEM_SOCKET_SERVER_CONFIG) + self.auto_connect_socket = auto_connect_socket + if self.auto_connect_socket: + if not self.socket_client.connect(): + logger.warning("Socket连接失败,将在发送命令时重试") + + # 语音段缓冲区 + self.speech_buffer: list = [] # 存储语音音频块 + self.speech_buffer_lock = threading.Lock() + + # 预缓冲区:保存语音检测前一小段音频,避免丢失开头 + # 例如:pre_speech_max_seconds = 0.8 表示保留最近约 0.8 秒音频 + self.pre_speech_buffer: list = [] # 存储最近的静音/背景音块 + # 从系统配置读取(确保类型正确:YAML 可能把数值当字符串) + self.pre_speech_max_seconds: float = float( + SYSTEM_RECOGNIZER_CONFIG.get("pre_speech_max_seconds", 0.8) + ) + self.pre_speech_max_chunks: Optional[int] = None # 根据采样率和chunk大小动态计算 + + # 命令发送成功后的 TTS 反馈(懒加载 Kokoro,避免拖慢启动) + self.ack_tts_enabled = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_enabled", True)) + self.ack_tts_text = str(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_text", "好的收到")).strip() + self.ack_tts_phrases: Dict[str, List[str]] = self._normalize_ack_tts_phrases( + SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_phrases") + ) + # True:仅 ack_tts_phrases 中出现的命令会播报,且每次随机一句;False:全局 ack_tts_text(所有成功命令同一应答) + self._ack_mode_phrases: bool = bool(self.ack_tts_phrases) + self.ack_tts_prewarm = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm", True)) + self.ack_tts_prewarm_blocking = bool( + SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm_blocking", True) + ) + self.ack_pause_mic_for_playback = bool( + SYSTEM_RECOGNIZER_CONFIG.get("ack_pause_mic_for_playback", True) + ) + self.ack_tts_disk_cache = bool( + SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_disk_cache", True) + ) + self._tts_engine: Optional["KokoroOnnxTTS"] = None + # 阻塞预加载时缓存波形:全局单句 _tts_ack_pcm,或按命令随机模式下的 _tts_phrase_pcm_cache(每句一条) + self._tts_ack_pcm: Optional[Tuple[np.ndarray, int]] = None + self._tts_phrase_pcm_cache: Dict[str, Tuple[np.ndarray, int]] = {} + self._tts_lock = threading.Lock() + # 命令线程只入队,主线程 process_audio_stream 中统一播放(避免 Windows 下后台线程 
sd.play 无声) + self._ack_playback_queue: queue.Queue = queue.Queue(maxsize=8) + + # STT识别线程和队列 + self.stt_queue = queue.Queue(maxsize=5) # STT识别队列 + self.stt_thread: Optional[threading.Thread] = None + + # 命令处理线程和队列 + self.command_queue = queue.Queue(maxsize=10) # 命令处理队列 + self.command_thread: Optional[threading.Thread] = None + + # 运行状态 + self.running = False + + # 命令序列号(用于去重和顺序保证) + self.sequence_id = 0 + self.sequence_lock = threading.Lock() + + logger.info( + f"应答TTS配置: enabled={self.ack_tts_enabled}, " + f"mode={'按命令随机短语' if self._ack_mode_phrases else '全局固定文案'}, " + f"prewarm_blocking={self.ack_tts_prewarm_blocking}, " + f"pause_mic={self.ack_pause_mic_for_playback}, " + f"disk_cache={self.ack_tts_disk_cache}" + ) + if self._ack_mode_phrases: + logger.info(f" 仅播报命令: {list(self.ack_tts_phrases.keys())}") + + # VAD 后端:silero(默认)或 energy(按块 RMS,Silero 在部分板载麦上长期无段时使用) + _ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in ( + "1", + "true", + "yes", + ) + _yaml_backend = str( + SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero") + ).lower() + self._use_energy_vad: bool = _ev_env or _yaml_backend == "energy" + self._energy_rms_high: float = float( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_high", 280) + ) + self._energy_rms_low: float = float( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_low", 150) + ) + self._energy_start_chunks: int = int( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_start_chunks", 4) + ) + self._energy_end_chunks: int = int( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_chunks", 15) + ) + # 高噪底/AGC 下 RMS 几乎不低于 energy_vad_rms_low 时,用「相对本段峰值」辅助判停 + self._energy_end_peak_ratio: float = float( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_peak_ratio", 0.88) + ) + # 说话过程中对 utt 峰值每块乘衰减再与当前 rms 取 max,避免前几个字特响导致后半句一直被判「相对衰减」而误切段 + self._energy_utt_peak_decay: float = float( + SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_utt_peak_decay", 0.988) + ) + self._energy_utt_peak_decay = max(0.95, min(0.9999, self._energy_utt_peak_decay)) + 
self._ev_speaking: bool = False + self._ev_high_run: int = 0 + self._ev_low_run: int = 0 + self._ev_rms_peak: float = 0.0 + self._ev_last_diag_time: float = 0.0 + self._ev_utt_peak: float = 0.0 + # 可选:能量 VAD 刚进入「正在说话」时回调(用于机端 PROMPT_LISTEN 计时清零等) + self._vad_speech_start_hook: Optional[Callable[[], None]] = None + + _trail_raw = SYSTEM_RECOGNIZER_CONFIG.get("trailing_silence_seconds") + if _trail_raw is not None: + _trail = float(_trail_raw) + if _trail > 0: + fs = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024)) + sr = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)) + if fs > 0 and sr > 0: + n_end = max(1, int(math.ceil(_trail * sr / fs))) + self._energy_end_chunks = n_end + self.vad.silence_end_frames = n_end + logger.info( + "VAD 句尾切段:trailing_silence_seconds=%.2f → 连续静音块数=%d " + "(每块≈%.0fms,Silero 与 energy 共用)", + _trail, + n_end, + 1000.0 * fs / sr, + ) + + if self._use_energy_vad: + logger.info( + "VAD 后端: energy(RMS)" + f" high={self._energy_rms_high} low={self._energy_rms_low} " + f"start_chunks={self._energy_start_chunks} end_chunks={self._energy_end_chunks}" + f" end_peak_ratio={self._energy_end_peak_ratio}" + f" utt_peak_decay={self._energy_utt_peak_decay}" + ) + + logger.info("语音命令识别系统初始化完成") + + @staticmethod + def _normalize_ack_tts_phrases(raw) -> Dict[str, List[str]]: + """YAML: ack_tts_phrases: { takeoff: [\"...\", ...], ... 
}""" + result: Dict[str, List[str]] = {} + if not isinstance(raw, dict): + return result + for k, v in raw.items(): + key = str(k).strip() + if not key: + continue + if isinstance(v, list): + phrases = [str(x).strip() for x in v if str(x).strip()] + elif isinstance(v, str) and v.strip(): + phrases = [v.strip()] + else: + phrases = [] + if phrases: + result[key] = phrases + return result + + def _has_ack_tts_content(self) -> bool: + if self._ack_mode_phrases: + return any(bool(v) for v in self.ack_tts_phrases.values()) + return bool(self.ack_tts_text) + + def _pick_ack_phrase(self, command_name: str) -> Optional[str]: + if self._ack_mode_phrases: + phrases = self.ack_tts_phrases.get(command_name) + if not phrases: + return None + return random.choice(phrases) + return self.ack_tts_text or None + + def _get_cached_pcm_for_phrase(self, phrase: str) -> Optional[Tuple[np.ndarray, int]]: + """若启动阶段已预合成该句,则返回缓存,播报时不再跑 ONNX(低延迟)。""" + if self._ack_mode_phrases: + return self._tts_phrase_pcm_cache.get(phrase) + if self._tts_ack_pcm is not None: + return self._tts_ack_pcm + return None + + def _ensure_tts_engine(self) -> "KokoroOnnxTTS": + """懒加载 Kokoro(双检锁,避免多线程重复加载)。""" + from voice_drone.core.tts import KokoroOnnxTTS + + if self._tts_engine is not None: + return self._tts_engine + with self._tts_lock: + if self._tts_engine is None: + logger.info("TTS: 正在加载 Kokoro 模型(首次约需十余秒)…") + self._tts_engine = KokoroOnnxTTS() + logger.info("TTS: Kokoro 模型加载完成") + assert self._tts_engine is not None + return self._tts_engine + + def _enqueue_ack_playback(self, command_name: str) -> None: + """ + 命令已成功发出后,将待播音频交给主线程队列。 + + 不在此线程直接调用 sounddevice:Windows 上后台线程常出现播放完全无声。 + """ + if not self.ack_tts_enabled: + return + phrase = self._pick_ack_phrase(command_name) + if not phrase: + return + try: + cached = self._get_cached_pcm_for_phrase(phrase) + if cached is not None: + audio, sr = cached + self._ack_playback_queue.put(("pcm", audio.copy(), sr), block=False) + logger.info( + 
    def _drain_ack_playback_queue(self, recover_mic: bool = True) -> None:
        """Play all queued acks on the main thread (same thread as mic capture).

        Playback must not happen on a worker thread: on Windows, sd.play from a
        background thread is often completely silent (see _enqueue_ack_playback).

        Args:
            recover_mic: whether to restart the microphone after playback.
                Must be False during shutdown, to avoid racing the stream
                close performed in stop().
        """
        from voice_drone.core.tts import play_tts_audio, speak_text

        # Drain the whole queue first, then play; new items enqueued while we
        # play will be picked up on the next main-loop iteration.
        items: list = []
        while True:
            try:
                items.append(self._ack_playback_queue.get_nowait())
            except queue.Empty:
                break
        if not items:
            return

        mic_stopped = False
        if self.ack_pause_mic_for_playback:
            # Pause capture while the speaker plays: on Windows, having input
            # and output streams open simultaneously can make playback silent.
            try:
                logger.info(
                    "TTS: 已暂停麦克风采集以便扬声器播放(避免 Windows 下输入/输出同时开无声)"
                )
                self.audio_capture.stop_stream()
                mic_stopped = True
            except Exception as e:
                logger.warning(f"暂停麦克风失败,将尝试直接播放: {e}")

        try:
            for item in items:
                try:
                    kind = item[0]
                    if kind == "pcm":
                        # Pre-synthesized waveform: low-latency path, no ONNX run.
                        _, audio, sr = item
                        logger.info("TTS: 主线程播放应答(预缓存波形)")
                        play_tts_audio(audio, sr)
                        logger.info("TTS: 播放完成")
                    elif kind == "synth":
                        # No cache hit: synthesize now (may take seconds on first use).
                        logger.info("TTS: 主线程合成并播放应答(无预缓存)")
                        tts = self._ensure_tts_engine()
                        text = item[1] if len(item) >= 2 else (self.ack_tts_text or "")
                        speak_text(text, tts=tts)
                except Exception as e:
                    # One failed item must not block the rest of the queue.
                    logger.warning(f"应答语音播放失败: {e}", exc_info=True)
        finally:
            if mic_stopped and recover_mic:
                try:
                    self.audio_capture.start_stream()
                    try:
                        self.audio_preprocessor.reset()
                    except Exception as e:  # noqa: BLE001
                        logger.debug("audio_preprocessor.reset: %s", e)
                    # If the trailing-silence frames never accumulated while TTS
                    # had the mic paused, the VAD stays is_speaking=True; after
                    # resume, detect_speech_start would then always bail out —
                    # capture works but nothing is ever recognized again.
                    self.vad.reset()
                    with self.speech_buffer_lock:
                        self.speech_buffer.clear()
                        self.pre_speech_buffer.clear()
                    logger.info("TTS: 麦克风采集已恢复(已重置 VAD 与语音缓冲)")
                except Exception as e:
                    logger.error(f"麦克风采集恢复失败,请重启程序: {e}", exc_info=True)
= self._tts_engine.synthesize(phrase) + self._tts_phrase_pcm_cache[phrase] = (audio, sr) + self._tts_ack_pcm = None + if use_disk: + persist_phrases(fingerprint, dict(self._tts_phrase_pcm_cache)) + logger.info( + "TTS: 语音反馈已就绪(随机应答已缓存,播报低延迟)" + ) + print("语音反馈引擎已就绪,可以开始说话下指令。") + else: + text = (self.ack_tts_text or "").strip() + if not text: + return + fingerprint = compute_ack_pcm_fingerprint( + [], global_text=text, mode_phrases=False + ) + missing = [text] + if use_disk: + loaded, missing = load_cached_phrases([text], fingerprint) + if text in loaded: + self._tts_ack_pcm = loaded[text] + + if not missing: + logger.info( + "TTS: 已从磁盘加载全局应答波形,跳过 Kokoro 加载与合成" + ) + print("语音反馈已就绪(本地缓存),可以开始说话下指令。") + return + + self._ensure_tts_engine() + assert self._tts_engine is not None + audio, sr = self._tts_engine.synthesize(text) + self._tts_ack_pcm = (audio, sr) + if use_disk: + persist_phrases(fingerprint, {text: self._tts_ack_pcm}) + logger.info( + "TTS: 语音反馈引擎已就绪;已缓存应答语音,命令成功后将快速播报" + ) + print("语音反馈引擎已就绪,可以开始说话下指令。") + except Exception as e: + logger.warning( + f"TTS: 启动阶段预加载失败,命令成功后可能延迟或无语音反馈: {e}", + exc_info=True, + ) + + @staticmethod + def _init_sounddevice_output_probe() -> None: + """在主线程探测默认输出设备;应答播报必须在主线程调用 sd.play。""" + try: + from voice_drone.core.tts import log_sounddevice_output_devices + + log_sounddevice_output_devices() + import sounddevice as sd # type: ignore + + from voice_drone.core.tts import _sounddevice_default_output_index + + out_idx = _sounddevice_default_output_index() + if out_idx is not None and int(out_idx) >= 0: + info = sd.query_devices(int(out_idx)) + logger.info( + f"sounddevice 默认输出设备: {info.get('name', '?')} (index={out_idx})" + ) + sd.check_output_settings(samplerate=24000, channels=1, dtype="float32") + # 预解析 tts.output_device,启动日志中可见实际用于播放的设备 + from voice_drone.core.tts import get_playback_output_device_id + + get_playback_output_device_id() + except Exception as e: + logger.warning(f"sounddevice 输出设备探测失败,可能导致无法播音: {e}") + + def 
    def _energy_vad_on_chunk(self, processed_chunk: np.ndarray) -> None:
        """Energy (RMS) VAD state machine, fed one preprocessed int16 chunk at a time.

        Hysteresis-based segmentation:
          * idle -> speaking after ``_energy_start_chunks`` consecutive chunks
            with RMS >= ``_energy_rms_high``;
          * speaking -> idle after ``_energy_end_chunks`` consecutive "quiet"
            chunks, where quiet means RMS <= ``_energy_rms_low`` (absolute) OR
            RMS <= utterance-peak * ``_energy_end_peak_ratio`` (relative,
            handles loud environments where the floor never drops below the
            absolute threshold).

        On segment end the buffered audio is submitted to the STT queue.
        """
        rms = self._int16_chunk_rms(processed_chunk)
        # Optional diagnostics: print the rolling 3 s RMS peak when ROCKET_PRINT_VAD is set.
        _vad_diag = os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
            "1",
            "true",
            "yes",
        )
        if _vad_diag:
            self._ev_rms_peak = max(self._ev_rms_peak, rms)
            now = time.monotonic()
            if now - self._ev_last_diag_time >= 3.0:
                print(
                    f"[VAD] energy 诊断:近 3s 块 RMS 峰值≈{self._ev_rms_peak:.0f} "
                    f"(high={self._energy_rms_high} low={self._energy_rms_low})",
                    flush=True,
                )
                self._ev_rms_peak = 0.0
                self._ev_last_diag_time = now

        if not self._ev_speaking:
            # Idle: count consecutive loud chunks; any quiet chunk resets the run.
            if rms >= self._energy_rms_high:
                self._ev_high_run += 1
            else:
                self._ev_high_run = 0
            if self._ev_high_run >= self._energy_start_chunks:
                # Transition idle -> speaking.
                self._ev_speaking = True
                self._ev_high_run = 0
                self._ev_low_run = 0
                self._ev_utt_peak = rms
                hook = self._vad_speech_start_hook
                if hook is not None:
                    try:
                        hook()
                    except Exception as e:  # noqa: BLE001
                        logger.debug("vad_speech_start_hook: %s", e, exc_info=True)
                with self.speech_buffer_lock:
                    # Seed the segment with the pre-speech ring buffer so the
                    # first syllables before the trigger are not lost.
                    if self.pre_speech_buffer:
                        self.speech_buffer = list(self.pre_speech_buffer)
                    else:
                        self.speech_buffer.clear()
                    self.speech_buffer.append(processed_chunk)
                logger.debug(
                    "energy VAD: 开始收集语音段(含预缓冲约 %.2f s)",
                    self.pre_speech_max_seconds,
                )
            return

        # Speaking: keep accumulating audio for the current segment.
        with self.speech_buffer_lock:
            self.speech_buffer.append(processed_chunk)

        # Decaying utterance peak, used for the relative end-of-speech test.
        self._ev_utt_peak = max(rms, self._ev_utt_peak * self._energy_utt_peak_decay)
        below_abs = rms <= self._energy_rms_low
        below_rel = (
            self._energy_end_peak_ratio > 0
            and self._ev_utt_peak >= self._energy_rms_high
            and rms <= self._ev_utt_peak * self._energy_end_peak_ratio
        )
        if below_abs or below_rel:
            self._ev_low_run += 1
        else:
            self._ev_low_run = 0

        if self._ev_low_run >= self._energy_end_chunks:
            # Transition speaking -> idle: flush the segment to STT.
            self._ev_speaking = False
            self._ev_low_run = 0
            self._ev_utt_peak = 0.0
            with self.speech_buffer_lock:
                self._submit_concatenated_speech_to_stt()
            self._reset_agc_after_utterance_end()
            logger.debug("energy VAD: 语音段结束,已提交")
    def _command_worker_thread(self):
        """Command-processing worker: wake word -> text preprocessing -> command -> socket.

        Consumes recognized text from ``command_queue``; a ``None`` item is the
        shutdown sentinel (see stop()). Every taken item is acknowledged with
        ``task_done()`` in the outer finally, including the sentinel.
        """
        logger.info("命令处理线程已启动")
        while self.running:
            try:
                text = self.command_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"命令处理线程错误: {e}", exc_info=True)
                continue

            try:
                if text is None:
                    break

                try:
                    # 1. Wake-word detection: anything without the wake word is dropped.
                    is_wake, matched_wake_word = self.wake_word_detector.detect(text)

                    if not is_wake:
                        logger.debug(f"未检测到唤醒词,忽略文本: {text}")
                        continue

                    logger.info(f"🔔 检测到唤醒词: {matched_wake_word}")

                    # 2. Extract the command text (strip the wake word).
                    command_text = self.wake_word_detector.extract_command_text(text)
                    if not command_text or not command_text.strip():
                        logger.warning(f"唤醒词后无命令内容: {text}")
                        continue

                    logger.debug(f"提取的命令文本: {command_text}")

                    # 3. Text preprocessing (fast mode, no word segmentation).
                    normalized_text, params = self.text_preprocessor.preprocess_fast(command_text)

                    logger.debug(f"文本预处理结果:")
                    logger.debug(f"  规范化文本: {normalized_text}")
                    logger.debug(f"  命令关键词: {params.command_keyword}")
                    logger.debug(f"  距离: {params.distance} 米")
                    logger.debug(f"  速度: {params.speed} 米/秒")
                    logger.debug(f"  时间: {params.duration} 秒")

                    # 4. Require a recognized command keyword.
                    if not params.command_keyword:
                        logger.warning(f"未识别到有效命令关键词: {normalized_text}")
                        continue

                    # 5. Build the command with a fresh sequence id.
                    sequence_id = self._get_next_sequence_id()
                    command = Command.create(
                        command=params.command_keyword,
                        sequence_id=sequence_id,
                        distance=params.distance,
                        speed=params.speed,
                        duration=params.duration
                    )

                    logger.info(f"📝 生成命令: {command.command}")
                    logger.debug(f"命令详情: {command.to_dict()}")

                    # 6. Send to the socket server; only a delivered command
                    #    triggers the spoken acknowledgement.
                    if self.socket_client.send_command_with_retry(command):
                        logger.info(f"✅ 命令已发送: {command.command} (序列号: {sequence_id})")
                        self._enqueue_ack_playback(command.command)
                    else:
                        logger.warning(
                            "命令未送达(已达 max_retries): %s (序列号: %s)",
                            command.command,
                            sequence_id,
                        )

                except Exception as e:
                    logger.error(f"命令处理失败: {e}", exc_info=True)

            finally:
                self.command_queue.task_done()

        logger.info("命令处理线程已停止")
    def stop(self):
        """Stop the voice-command system: workers first, then pending acks, then streams."""
        if not self.running:
            return

        self.running = False

        # Signal the workers with the None sentinel and join them BEFORE
        # draining the ack queue: on Ctrl+C the main loop won't run another
        # iteration, so pending acks would otherwise never be played.
        if self.stt_thread is not None:
            self.stt_queue.put(None)
        if self.command_thread is not None:
            self.command_queue.put(None)
        if self.stt_thread is not None:
            self.stt_thread.join(timeout=2.0)
        if self.command_thread is not None:
            self.command_thread.join(timeout=2.0)

        if self.ack_tts_enabled:
            try:
                # recover_mic=False: the capture stream is closed right below,
                # restarting it here would race that close.
                self._drain_ack_playback_queue(recover_mic=False)
            except Exception as e:
                logger.warning(f"退出前播放应答失败: {e}", exc_info=True)

        self.audio_capture.stop_stream()

        # Disconnect from the socket server.
        if self.socket_client.connected:
            self.socket_client.disconnect()

        logger.info("语音命令识别系统已停止")
        print("\n语音命令识别系统已停止")
音频预处理(降噪+AGC) + processed_chunk = self.audio_preprocessor.process(chunk) + + # 初始化预缓冲区的最大块数(只需计算一次) + if self.pre_speech_max_chunks is None: + # 每个chunk包含的采样点数 + samples_per_chunk = processed_chunk.shape[0] + if samples_per_chunk > 0: + # 0.8 秒需要的chunk数量 = 预缓冲秒数 * 采样率 / 每块采样数 + chunks = int( + self.pre_speech_max_seconds * self.audio_capture.sample_rate + / samples_per_chunk + ) + # 至少保留 1 块,避免被算成 0 + self.pre_speech_max_chunks = max(chunks, 1) + else: + self.pre_speech_max_chunks = 1 + + # 将当前块加入预缓冲区(环形缓冲) + # 注意:预缓冲区保存的是“最近的一段音频”,无论当下是否在说话 + self.pre_speech_buffer.append(processed_chunk) + if ( + self.pre_speech_max_chunks is not None + and len(self.pre_speech_buffer) > self.pre_speech_max_chunks + ): + # 超出最大长度时,丢弃最早的块 + self.pre_speech_buffer.pop(0) + + # 3. VAD:Silero 或能量(RMS)分段 + if self._use_energy_vad: + self._energy_vad_on_chunk(processed_chunk) + else: + chunk_bytes = processed_chunk.tobytes() + + if self.vad.detect_speech_start(chunk_bytes): + hook = self._vad_speech_start_hook + if hook is not None: + try: + hook() + except Exception as e: # noqa: BLE001 + logger.debug( + "vad_speech_start_hook: %s", e, exc_info=True + ) + with self.speech_buffer_lock: + if self.pre_speech_buffer: + self.speech_buffer = list(self.pre_speech_buffer) + else: + self.speech_buffer.clear() + self.speech_buffer.append(processed_chunk) + logger.debug( + "检测到语音开始,使用预缓冲音频(约 %.2f 秒)作为前缀,开始收集语音段", + self.pre_speech_max_seconds, + ) + + elif self.vad.is_speaking: + with self.speech_buffer_lock: + self.speech_buffer.append(processed_chunk) + + if self.vad.detect_speech_end(chunk_bytes): + with self.speech_buffer_lock: + self._submit_concatenated_speech_to_stt() + self._reset_agc_after_utterance_end() + logger.debug("检测到语音结束,提交识别") + + hook = getattr(self, "_after_processed_audio_chunk", None) + if hook is not None: + try: + hook(processed_chunk) + except Exception as e: # noqa: BLE001 + logger.debug( + "after_processed_audio_chunk: %s", e, exc_info=True + ) + + except 
KeyboardInterrupt: + logger.info("用户中断") + except Exception as e: + logger.error(f"处理音频流时发生错误: {e}", exc_info=True) + raise + + def run(self): + """运行语音命令识别系统(完整流程)""" + try: + self.start() + self.process_audio_stream() + finally: + self.stop() + + +if __name__ == "__main__": + # 测试代码 + recognizer = VoiceCommandRecognizer() + recognizer.run() diff --git a/voice_drone/core/rule.py b/voice_drone/core/rule.py new file mode 100644 index 0000000..e69de29 diff --git a/voice_drone/core/scoket_client.py b/voice_drone/core/scoket_client.py new file mode 100644 index 0000000..fd23259 --- /dev/null +++ b/voice_drone/core/scoket_client.py @@ -0,0 +1,239 @@ +""" +Socket客户端 +""" +import socket +import json +import time +from voice_drone.core.configuration import SYSTEM_SOCKET_SERVER_CONFIG +from voice_drone.core.command import Command +from voice_drone.logging_ import get_logger + +logger = get_logger("socket.client") + + +class SocketClient: + # 初始化Socket客户端 + def __init__(self, config: dict): + self.host = config.get("host") + self.port = config.get("port") + self.connect_timeout = config.get("connect_timeout") + self.send_timeout = config.get("send_timeout") + self.reconnect_interval = float(config.get("reconnect_interval") or 3.0) + # max_retries:-1 表示断线后持续重连并发送,直到成功(不视为致命错误) + _mr = config.get("max_retries", -1) + try: + self.max_reconnect_attempts = int(_mr) + except (TypeError, ValueError): + self.max_reconnect_attempts = -1 + + self.sock = None + self.connected = False + + # 连接到socket服务器 + def connect(self) -> bool: + if self.connected and self.sock is not None: + return True + try: + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(self.connect_timeout) + self.sock.connect((self.host, self.port)) + self.sock.settimeout(self.send_timeout) + self.connected = True + print( + f"[SocketClient] 连接成功: {self.host}:{self.port}", + flush=True, + ) + logger.info("Socket 已连接 %s:%s", self.host, self.port) + return True + + except socket.timeout: + 
logger.warning( + "Socket 连接超时 host=%s port=%s timeout=%s", + self.host, + self.port, + self.connect_timeout, + ) + print( + "[SocketClient] connect: 连接超时 " + f"(host={self.host!r}, port={self.port!r}, timeout={self.connect_timeout!r})", + flush=True, + ) + self._cleanup() + return False + except ConnectionRefusedError as e: + logger.warning("Socket 连接被拒绝: %s", e) + print( + f"[SocketClient] connect: 连接被拒绝: {e!r}", + flush=True, + ) + self._cleanup() + return False + except OSError as e: + logger.warning("Socket connect OSError (%s): %s", type(e).__name__, e) + print( + f"[SocketClient] connect: OSError ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + except Exception as e: + print( + f"[SocketClient] connect: 未预期异常 ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + + # 断开与socket服务器的连接 + def disconnect(self) -> None: + self._cleanup() + + # 清理资源 + def _cleanup(self) -> None: + if self.sock is not None: + try: + self.sock.close() + except Exception: + pass + self.sock = None + self.connected = False + + # 确保连接已建立 + def _ensure_connected(self) -> bool: + if self.connected and self.sock is not None: + return True + + return self.connect() + + # 发送命令 + def send_command(self, command) -> bool: + print("[SocketClient] 正在发送命令…", flush=True) + + if not self._ensure_connected(): + logger.warning( + "Socket 未连接且 connect 失败,跳过本次发送 host=%s port=%s", + self.host, + self.port, + ) + print( + "[SocketClient] 未连接或 connect 失败,跳过发送", + flush=True, + ) + return False + + try: + command_dict = command.to_dict() + json_str = json.dumps(command_dict, ensure_ascii=False) + + # 添加换行符(根据 JSON格式说明.md,命令以换行符分隔) + message = json_str + "\n" + + # 发送数据 + self.sock.sendall(message.encode("utf-8")) + print("[SocketClient] sendall 成功", flush=True) + return True + + except socket.timeout: + logger.warning("Socket send 超时,将断开以便重连") + print("[SocketClient] send_command: socket 超时", flush=True) + self._cleanup() + return False + except 
ConnectionResetError as e: + logger.warning("Socket 连接被重置(将重连): %s", e) + print( + f"[SocketClient] send_command: 连接被重置 ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + except BrokenPipeError as e: + logger.warning("Socket 管道破裂(将重连): %s", e) + print( + f"[SocketClient] send_command: 管道破裂 ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + except OSError as e: + # 断网、对端关闭等:可恢复,不当作未捕获致命错误 + logger.warning("Socket send OSError (%s): %s(将重连)", type(e).__name__, e) + print( + f"[SocketClient] send_command: OSError ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + except Exception as e: + logger.warning( + "Socket send 异常 (%s): %s(将重连)", type(e).__name__, e + ) + print( + f"[SocketClient] send_command: 异常 ({type(e).__name__}): {e!r}", + flush=True, + ) + self._cleanup() + return False + + # 发送命令并重试 + def send_command_with_retry(self, command) -> bool: + """失败后清理连接并按 reconnect_interval 重试;max_retries=-1 时直到发送成功。""" + unlimited = self.max_reconnect_attempts < 0 + cap = max(1, self.max_reconnect_attempts) if not unlimited else None + attempt = 0 + while True: + attempt += 1 + self._cleanup() + if self.send_command(command): + if attempt > 1: + print( + f"[SocketClient] 重试后发送成功(第 {attempt} 次)", + flush=True, + ) + logger.info("Socket 重连后命令已发送(第 %s 次尝试)", attempt) + return True + + if not unlimited and cap is not None and attempt >= cap: + logger.warning( + "Socket 已达 max_retries=%s,本次命令未送达,稍后可再试", + self.max_reconnect_attempts, + ) + print( + "[SocketClient] 已达最大重试次数,本次命令未送达(可稍后重试)", + flush=True, + ) + return False + + # 无限重试时每 10 次打一条日志,避免刷屏 + if unlimited and attempt % 10 == 1: + logger.warning( + "Socket 发送失败,%ss 后第 %s 次重连重试…", + self.reconnect_interval, + attempt, + ) + print( + f"[SocketClient] 发送失败,{self.reconnect_interval}s 后重试 " + f"(第 {attempt} 次)", + flush=True, + ) + time.sleep(self.reconnect_interval) + + # 上下文管理器入口 + def __enter__(self): + self.connect() + return 
"""Streaming chit-chat helpers: split LLM output into TTS-sized sentence chunks.

Flight-control JSON goes through whole-response inference on the caller's side;
these helpers only serve the conversational path.
"""

from __future__ import annotations

# Sentence-final punctuation: a completed segment ends on one of these.
_SENTENCE_END = frozenset("。!?;\n")
# Preferred break characters when a chunk grows too long without a sentence end.
_SOFT_BREAK = ",、,"


def take_completed_sentences(buffer: str) -> tuple[list[str], str]:
    """Pop every complete sentence (ending in ``_SENTENCE_END``) off *buffer*.

    Returns ``(segments, remainder)``; each segment is stripped and non-empty,
    the remainder is the trailing text after the last sentence-final character.
    """
    segments: list[str] = []
    consumed = 0
    start = 0
    for pos, ch in enumerate(buffer):
        if ch in _SENTENCE_END:
            piece = buffer[start : pos + 1].strip()
            if piece:
                segments.append(piece)
            start = pos + 1
            consumed = start
    return segments, buffer[consumed:]


def force_soft_split(remainder: str, max_chars: int) -> tuple[list[str], str]:
    """Force one chunk off *remainder* when it reached *max_chars* with no sentence end.

    Cuts just after the right-most ``_SOFT_BREAK`` character within the first
    *max_chars* characters; with no such break (or one at position 0), cuts
    after position *max_chars*. Returns ``([chunk] or [], rest)``.
    """
    if max_chars <= 0 or len(remainder) < max_chars:
        return [], remainder
    head = remainder[:max_chars]
    cut_at = max(head.rfind(sep) for sep in _SOFT_BREAK)
    if cut_at <= 0:
        cut_at = max_chars
    chunk = remainder[: cut_at + 1].strip()
    rest = remainder[cut_at + 1 :]
    return ([chunk] if chunk else []), rest
CMVN + LFR)与解码均手写实现。 +""" +import platform +import os +import multiprocessing + +import numpy as np +from pathlib import Path +from typing import List, Dict, Any, Optional + +import onnx +import onnxruntime as ort +from voice_drone.logging_ import get_logger +from voice_drone.tools.wrapper import time_cost +import scipy.special +from voice_drone.core.configuration import SYSTEM_STT_CONFIG, SYSTEM_AUDIO_CONFIG + +# voice_drone/core/stt.py -> 工程根(含 voice_drone_assistant 与本仓库根两种布局) +_STT_PROJECT_ROOT = Path(__file__).resolve().parents[2] + + +def _stt_path_candidates(path: Path) -> List[Path]: + """相对配置路径的候选绝对路径:优先工程目录,其次嵌套在上一级仓库时的 src/models/。""" + if path.is_absolute(): + return [path] + out: List[Path] = [_STT_PROJECT_ROOT / path] + if path.parts and path.parts[0] == "models": + out.append(_STT_PROJECT_ROOT.parent / "src" / path) + return out + + +class STT: + """ + 语音识别(Speech-to-Text)类 + 使用 ONNX Runtime 进行最优性能推理 + 针对 RK3588 等 ARM 设备进行了深度优化 + """ + + def __init__(self): + """ + 初始化 STT 模型 + """ + stt_conf = SYSTEM_STT_CONFIG + self.logger = get_logger("stt.onnx") + + # 从配置读取参数 + self.model_dir = stt_conf.get("model_dir") + self.model_path = stt_conf.get("model_path") + self.prefer_int8 = stt_conf.get("prefer_int8", True) + _wf = stt_conf.get("warmup_file") + self.warmup_file: Optional[str] = None + if _wf: + wf_path = Path(_wf) + if wf_path.is_absolute() and wf_path.is_file(): + self.warmup_file = str(wf_path) + else: + for c in _stt_path_candidates(wf_path): + if c.is_file(): + self.warmup_file = str(c) + break + + # 音频预处理参数(确保数值类型正确) + self.sample_rate = int(stt_conf.get("sample_rate", SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000))) + self.n_mels = int(stt_conf.get("n_mels", 80)) + self.frame_length_ms = float(stt_conf.get("frame_length_ms", 25)) + self.frame_shift_ms = float(stt_conf.get("frame_shift_ms", 10)) + self.log_eps = float(stt_conf.get("log_eps", 1e-10)) + + # ARM 优化配置 + arm_conf = stt_conf.get("arm_optimization", {}) + self.arm_enabled = 
arm_conf.get("enabled", True) + self.arm_max_threads = arm_conf.get("max_threads", 4) + + # CTC 解码配置 + ctc_conf = stt_conf.get("ctc_decode", {}) + self.blank_id = ctc_conf.get("blank_id", 0) + + # 语言和文本规范化配置(默认值) + lang_conf = stt_conf.get("language", {}) + text_norm_conf = stt_conf.get("text_norm", {}) + self.lang_zh_default = lang_conf.get("zh_id", 3) + self.with_itn_default = text_norm_conf.get("with_itn_id", 14) + self.without_itn_default = text_norm_conf.get("without_itn_id", 15) + + # 后处理配置 + postprocess_conf = stt_conf.get("postprocess", {}) + self.special_tokens = postprocess_conf.get("special_tokens", [ + "<|zh|>", "<|NEUTRAL|>", "<|Speech|>", "<|woitn|>", "<|withitn|>" + ]) + + # 检测是否为 RK3588 或 ARM 设备 + ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64') + RK3588 = 'rk3588' in platform.platform().lower() or os.path.exists('/proc/device-tree/compatible') + + # ARM 设备性能优化配置 + if self.arm_enabled and (ARM or RK3588): + cpu_count = multiprocessing.cpu_count() + optimal_threads = min(self.arm_max_threads, cpu_count) + + # 设置 OpenMP 线程数 + os.environ['OMP_NUM_THREADS'] = str(optimal_threads) + os.environ['MKL_NUM_THREADS'] = str(optimal_threads) + os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0' + os.environ['OMP_DYNAMIC'] = 'FALSE' + os.environ['MKL_DYNAMIC'] = 'FALSE' + + self.logger.info("ARM/RK3588 优化已启用") + self.logger.info(f" CPU 核心数: {cpu_count}") + self.logger.info(f" 优化线程数: {optimal_threads}") + + # 确定模型路径 + onnx_model_path = self._resolve_model_path() + + # 保存模型目录路径(用于加载 tokens.txt) + self.onnx_model_dir = onnx_model_path.parent + + self.logger.info(f"加载 ONNX 模型: {onnx_model_path}") + self._load_onnx_model(str(onnx_model_path)) + + # 模型预热 + if self.warmup_file and os.path.exists(self.warmup_file): + try: + self.logger.info(f"正在预热模型(使用: {self.warmup_file})...") + _ = self.invoke(self.warmup_file) + self.logger.info("模型预热完成") + except Exception as e: + self.logger.warning(f"预热失败(可忽略): {e}") + elif 
    def _resolve_existing_model_file(self, raw: Optional[str]) -> Optional[Path]:
        # Return the first candidate location where *raw* exists as a file, else None.
        if not raw:
            return None
        p = Path(raw)
        for c in _stt_path_candidates(p):
            if c.is_file():
                return c
        return None

    def _resolve_existing_model_dir(self, raw: Optional[str]) -> Optional[Path]:
        # Return the first candidate location where *raw* exists as a directory, else None.
        if not raw:
            return None
        p = Path(raw)
        for c in _stt_path_candidates(p):
            if c.is_dir():
                return c
        return None

    def _resolve_model_path(self) -> Path:
        """Resolve the ONNX model file to load.

        Resolution order:
          1. explicit ``model_path`` from config, if it exists;
          2. ``model_dir`` + ``model.int8.onnx`` when ``prefer_int8`` is set;
          3. ``model_dir`` + ``model.onnx``.

        Returns:
            Path of the model file.

        Raises:
            ValueError: neither ``model_path`` nor ``model_dir`` is configured.
            FileNotFoundError: the configured directory or model files do not exist.
        """
        if self.model_path:
            hit = self._resolve_existing_model_file(self.model_path)
            if hit is not None:
                return hit

        if not self.model_dir:
            raise ValueError("配置中必须指定 model_path 或 model_dir")

        model_dir = self._resolve_existing_model_dir(self.model_dir)
        if model_dir is None:
            # The error message lists every candidate location actually tried.
            tried = ", ".join(str(x) for x in _stt_path_candidates(Path(self.model_dir)))
            raise FileNotFoundError(
                f"ONNX 模型目录不存在。config model_dir={self.model_dir!r},已尝试: {tried}。"
                f"请将 SenseVoice 放入 {_STT_PROJECT_ROOT / 'models'},或 ln -s ../src/models "
                f"{_STT_PROJECT_ROOT / 'models'}(见 models/README.txt)。"
            )

        # Prefer the INT8-quantized model when enabled (faster on ARM).
        if self.prefer_int8:
            int8_path = model_dir / "model.int8.onnx"
            if int8_path.exists():
                return int8_path

        # Fall back to the full-precision model.
        onnx_path = model_dir / "model.onnx"
        if onnx_path.exists():
            return onnx_path

        raise FileNotFoundError(f"ONNX 模型文件不存在: 在 {model_dir} 中未找到 model.int8.onnx 或 model.onnx")
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + # 创建推理会话 + self.onnx_session = ort.InferenceSession( + onnx_model_path, + sess_options=sess_options, + providers=['CPUExecutionProvider'] + ) + + # 获取模型元数据 + onnx_model = onnx.load(onnx_model_path) + self.model_metadata = {prop.key: prop.value for prop in onnx_model.metadata_props} + + # 解析元数据 + self.lfr_window_size = int(self.model_metadata.get('lfr_window_size', 7)) + self.lfr_window_shift = int(self.model_metadata.get('lfr_window_shift', 6)) + self.vocab_size = int(self.model_metadata.get('vocab_size', 25055)) + + # 解析 CMVN 参数 + neg_mean_str = self.model_metadata.get('neg_mean', '') + inv_stddev_str = self.model_metadata.get('inv_stddev', '') + self.neg_mean = np.array([float(x) for x in neg_mean_str.split(',')]) if neg_mean_str else None + self.inv_stddev = np.array([float(x) for x in inv_stddev_str.split(',')]) if inv_stddev_str else None + + # 语言和文本规范化 ID(从元数据获取,如果没有则使用配置默认值) + self.lang_zh = int(self.model_metadata.get('lang_zh', self.lang_zh_default)) + self.with_itn = int(self.model_metadata.get('with_itn', self.with_itn_default)) + self.without_itn = int(self.model_metadata.get('without_itn', self.without_itn_default)) + + self.logger.info("ONNX 模型加载完成") + self.logger.info(f" LFR窗口大小: {self.lfr_window_size}") + self.logger.info(f" LFR窗口偏移: {self.lfr_window_shift}") + self.logger.info(f" 词汇表大小: {self.vocab_size}") + + def _load_tokens(self): + """加载 tokens 映射""" + tokens_file = self.onnx_model_dir / "tokens.txt" + if tokens_file.exists(): + self.tokens = {} + with open(tokens_file, 'r', encoding='utf-8') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 2: + token = parts[0] + token_id = int(parts[-1]) + self.tokens[token_id] = token + return True + return False + + def _preprocess_audio_array(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> tuple: + """ + 预处理音频数组:提取特征并转换为 ONNX 模型输入格式(纯 numpy 实现) + + 支持实时音频流处理(numpy数组输入) + + 流程: 
def _preprocess_audio_array(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> tuple:
    """Convert a raw waveform into the ONNX model's input features.

    Pure numpy/librosa pipeline:
      1. int16 -> float32 in [-1, 1], downmix to mono
      2. 80-dim log-mel fbank
      3. CMVN using neg_mean / inv_stddev from the ONNX metadata
      4. LFR frame stacking -> 560-dim features

    Args:
        audio_array: waveform samples, int16 or float32.
        sample_rate: sampling rate; falls back to ``self.sample_rate``.

    Returns:
        (features, lengths): float32 features shaped (1, T_lfr, 560) and an
        int32 length array of shape (1,).
    """
    import librosa

    sr = self.sample_rate if sample_rate is None else sample_rate

    # 1. Normalize dtype and channel count.
    wav = audio_array.astype(np.float32)
    if audio_array.dtype == np.int16:
        wav /= 32768.0
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if wav.size == 0:
        raise ValueError("音频数组为空")

    # 2. Log-mel fbank (power=1.0 -> linear magnitude before the log).
    win = int(self.frame_length_ms / 1000.0 * sr)
    hop = int(self.frame_shift_ms / 1000.0 * sr)
    mel = librosa.feature.melspectrogram(
        y=wav,
        sr=sr,
        n_fft=win,
        hop_length=hop,
        n_mels=self.n_mels,
        window="hann",
        center=True,
        power=1.0,
    )
    feats = np.log(np.maximum(mel, self.log_eps)).T  # (T, n_mels)

    # 3. CMVN from the exported metadata (skipped when dims mismatch).
    if (self.neg_mean is not None and self.inv_stddev is not None
            and self.neg_mean.shape[0] == feats.shape[1]):
        feats = (feats + self.neg_mean) * self.inv_stddev

    # 4. LFR stacking: window of m frames, shifted by n.
    win_m = self.lfr_window_size
    shift_n = self.lfr_window_shift
    num_frames = feats.shape[0]
    if num_frames < win_m:
        # Not enough frames for one window: pad by repeating the last frame.
        feats = np.vstack([feats, np.tile(feats[-1], (win_m - num_frames, 1))])
        num_frames = win_m

    out_frames = 1 + (num_frames - win_m) // shift_n
    stacked = np.stack(
        [feats[i * shift_n:i * shift_n + win_m, :].reshape(-1) for i in range(out_frames)],
        axis=0,
    )  # (T_lfr, m * n_mels = 560)

    features = stacked[np.newaxis, :, :].astype(np.float32)  # add batch dim
    lengths = np.array([features.shape[1]], dtype=np.int32)
    return features, lengths
def _preprocess_audio(self, audio_path: str) -> tuple:
    """Load a wav file and turn it into ONNX model input features.

    Reads the file at the configured rate (mono), then delegates the
    fbank -> CMVN -> LFR feature extraction to ``_preprocess_audio_array``.
    """
    import librosa

    # librosa resamples to self.sample_rate and forces mono on load.
    waveform, rate = librosa.load(audio_path, sr=self.sample_rate, mono=True)
    if waveform.size == 0:
        raise ValueError(f"音频为空: {audio_path}")
    return self._preprocess_audio_array(waveform, rate)
def _ctc_decode(self, logits: np.ndarray, length: np.ndarray) -> str:
    """Greedy CTC decode of sample 0: logits (N, T, vocab) -> text.

    Collapses repeated tokens, drops blanks, maps ids through
    ``self.tokens`` and strips SentencePiece / special markers.
    """
    # Lazily load the id -> token table on first use.
    if not hasattr(self, 'tokens') or len(self.tokens) == 0:
        self._load_tokens()

    # Best token per frame (softmax is monotonic; kept for parity with
    # the original implementation).
    frame_probs = scipy.special.softmax(logits[0][:length[0]], axis=-1)
    best_ids = np.argmax(frame_probs, axis=-1)

    # Collapse repeats and drop blanks.
    # NOTE(review): the collapsed source is ambiguous about whether the
    # previous-token tracker is updated on every frame or only on emission;
    # the canonical CTC greedy rule (every frame) is used here — confirm.
    kept = []
    prev = -1
    for tid in best_ids:
        if tid != self.blank_id and tid != prev:
            kept.append(tid)
        prev = tid

    # Map ids to text, handling SentencePiece word-boundary markers.
    pieces = []
    for tid in kept:
        if tid not in self.tokens:
            continue
        tok = self.tokens[tid]
        if tok.startswith('▁'):
            if pieces:  # word boundary -> space, except before the first word
                pieces.append(' ')
            pieces.append(tok[1:])
        elif not tok.startswith('<|'):  # skip special markers like <|zh|>
            pieces.append(tok)

    text = ''.join(pieces)
    # Strip any residual special tokens configured for this model.
    for special in self.special_tokens:
        text = text.replace(special, '')
    return text.strip()


@time_cost("STT-语音识别推理耗时")
def invoke(self, audio_path: str) -> List[Dict[str, Any]]:
    """Run speech recognition on an audio file.

    Args:
        audio_path: path to the wav file.

    Returns:
        A one-element list: [{"text": "<recognized text>"}].
    """
    feats, feat_len = self._preprocess_audio(audio_path)
    return [{"text": self._inference(feats, feat_len)}]


def invoke_numpy(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> str:
    """Run speech recognition on an in-memory waveform (real-time path).

    Args:
        audio_array: int16 or float32 samples.
        sample_rate: sampling rate; falls back to the configured value.

    Returns:
        The recognized text.
    """
    feats, feat_len = self._preprocess_audio_array(audio_array, sample_rate)
    return self._inference(feats, feat_len)


def _inference(self, features: np.ndarray, features_length: np.ndarray) -> str:
    """Run the ONNX session on prepared features and CTC-decode the result.

    Inputs follow the exported SenseVoice signature:
    x, x_length, language, text_norm.
    """
    batch, frames, dim = features.shape  # validates the expected 3-D layout

    feed = {
        'x': features.astype(np.float32),
        'x_length': features_length.astype(np.int32),
        'language': np.array([self.lang_zh], dtype=np.int32),
        'text_norm': np.array([self.with_itn], dtype=np.int32),
    }
    ctc_logits = self.onnx_session.run(None, feed)[0]  # (N, T, vocab_size)
    return self._ctc_decode(ctc_logits, features_length)
@dataclass
class ExtractedParams:
    """Numeric/keyword parameters extracted from a command utterance."""
    distance: Optional[float] = None         # distance in meters
    speed: Optional[float] = None            # speed in meters per second
    duration: Optional[float] = None         # duration in seconds
    command_keyword: Optional[str] = None    # detected command keyword


@dataclass
class PreprocessedText:
    """Full result of the text-preprocessing pipeline."""
    cleaned_text: str           # text after noise/character cleanup
    normalized_text: str        # after correction + traditional->simplified
    words: List[str]            # jieba segmentation result
    params: ExtractedParams     # extracted numeric parameters / keyword
    original_text: str          # the untouched input text
def __init__(self,
             enable_traditional_to_simplified: Optional[bool] = None,
             enable_segmentation: Optional[bool] = None,
             enable_correction: Optional[bool] = None,
             enable_number_extraction: Optional[bool] = None,
             enable_keyword_detection: Optional[bool] = None,
             lru_cache_size: Optional[int] = None):
    """Build the text preprocessor.

    Each switch falls back to SYSTEM_TEXT_PREPROCESSOR_CONFIG when the
    corresponding argument is None; the opencc/jieba-backed features are
    additionally forced off when the optional dependency is missing.

    Args:
        enable_traditional_to_simplified: traditional->simplified conversion.
        enable_segmentation: jieba word segmentation.
        enable_correction: homophone/typo correction.
        enable_number_extraction: distance/speed/duration extraction.
        enable_keyword_detection: command keyword detection.
        lru_cache_size: size of the per-instance LRU caches.
    """
    config = SYSTEM_TEXT_PREPROCESSOR_CONFIG or {}

    def pick(explicit, key, default=True):
        # Explicit constructor argument wins; otherwise use the config file.
        return explicit if explicit is not None else config.get(key, default)

    self.enable_traditional_to_simplified = (
        pick(enable_traditional_to_simplified, "enable_traditional_to_simplified")
        and OPENCC_AVAILABLE
    )
    self.enable_segmentation = (
        pick(enable_segmentation, "enable_segmentation") and JIEBA_AVAILABLE
    )
    self.enable_correction = pick(enable_correction, "enable_correction")
    self.enable_number_extraction = pick(enable_number_extraction, "enable_number_extraction")
    self.enable_keyword_detection = pick(enable_keyword_detection, "enable_keyword_detection")
    cache_size = pick(lru_cache_size, "lru_cache_size", 512)

    # Traditional -> simplified converter, only when opencc is usable.
    self.opencc = OpenCC('t2s') if self.enable_traditional_to_simplified else None

    # Keyword map, pre-compiled regexes and the correction dictionary.
    self._load_keyword_mapping()
    self._compile_regex_patterns()
    self._load_correction_dict()

    # Per-instance LRU caches for the hot paths (segmentation, Chinese
    # numeral parsing, and the two preprocessing entry points).
    self._cache_size = cache_size
    self._segment_text_cached = lru_cache(maxsize=cache_size)(self._segment_text_impl)
    self._parse_chinese_number_cached = lru_cache(maxsize=128)(self._parse_chinese_number_impl)
    self._preprocess_cached = lru_cache(maxsize=cache_size)(self._preprocess_impl)
    self._preprocess_fast_cached = lru_cache(maxsize=cache_size)(self._preprocess_fast_impl)

    logger.info(f"文本预处理器初始化完成")
    logger.info(f"  繁简转换: {'启用' if self.enable_traditional_to_simplified else '禁用'}")
    logger.info(f"  分词: {'启用' if self.enable_segmentation else '禁用'}")
    logger.info(f"  纠错: {'启用' if self.enable_correction else '禁用'}")
    logger.info(f"  数字提取: {'启用' if self.enable_number_extraction else '禁用'}")
    logger.info(f"  关键词检测: {'启用' if self.enable_keyword_detection else '禁用'}")
def _load_keyword_mapping(self):
    """Build the keyword -> command-type lookup from KEYWORDS_CONFIG.

    A config entry maps one command type to either a single keyword (str)
    or a list of keywords; both forms are flattened into one dict.  Also
    precomputes ``self.sorted_keywords`` (longest first) so that substring
    matching prefers more specific phrases.
    """
    mapping: Dict[str, str] = {}
    for command_type, keywords in (KEYWORDS_CONFIG or {}).items():
        if isinstance(keywords, str):
            mapping[keywords] = command_type
        elif isinstance(keywords, list):
            for kw in keywords:
                mapping[kw] = command_type
    self.keyword_to_command = mapping

    # Longest keyword first: substring scans then hit specific phrases
    # before their shorter prefixes.
    self.sorted_keywords = sorted(mapping, key=len, reverse=True)

    logger.debug(f"加载了 {len(self.keyword_to_command)} 个关键词映射")
def _compile_regex_patterns(self):
    """Pre-compile every regex used on the hot path (performance)."""
    # Cleanup: strip anything that is not CJK, alnum, whitespace, a unit
    # character or a dot; then collapse whitespace runs.
    # NOTE(review): '秒' appears twice in the class below — looks like a
    # harmless typo in the original pattern; kept verbatim.
    self.pattern_clean_special = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s米每秒秒分小时\.]')
    self.pattern_clean_spaces = re.compile(r'\s+')

    # Distance: number + meter unit, with a lookahead excluding "米/秒".
    self.pattern_distance = re.compile(
        r'(\d+\.?\d*)\s*(?:米|m|M|公尺|meter|meters)(?!\s*[/每]?\s*秒)',
        re.IGNORECASE
    )

    # Speed: number (Arabic or Chinese numeral) + m/s-style unit, e.g.
    # "三米每秒", "5米/秒", "2.5米每秒".
    self.pattern_speed = re.compile(
        r'(?:速度\s*[::]?\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:米\s*[/每]?\s*秒|m\s*/\s*s|mps|MPS)',
        re.IGNORECASE
    )

    # Duration: number + seconds/minutes/hours unit, e.g. "持续10秒".
    # NOTE(review): the two alternatives of the optional prefix are
    # identical ("持续|持续") — looks like a typo; kept verbatim.
    self.pattern_duration = re.compile(
        r'(?:持续\s*|持续\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:秒|s|S|分钟|分|min|小时|时|h|H)',
        re.IGNORECASE
    )

    # Chinese numeral -> value table (used by _parse_chinese_number_impl).
    self.chinese_numbers = {
        '零': 0, '一': 1, '二': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
        '壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5,
        '陆': 6, '柒': 7, '捌': 8, '玖': 9, '拾': 10,
        '百': 100, '千': 1000, '万': 10000
    }

    # Chinese-numeral spans with a unit, e.g. "十米", "二十米", "三米每秒".
    self.pattern_chinese_number = re.compile(
        r'([零一二三四五六七八九十壹贰叁肆伍陆柒捌玖拾百千万]+)\s*(?:米|秒|分|小时|米\s*[/每]?\s*秒)'
    )


def _load_correction_dict(self):
    """Build the homophone-correction table and its union regex."""
    # Only pairs that actually showed up as STT mistakes are listed.
    correction_pairs = [
        ('起非', '起飞'),
        ('降洛', '降落'),
        ('悬廷', '悬停'),
        ('停只', '停止'),
    ]

    if correction_pairs:
        # Longer wrong-forms first so the union regex prefers them.
        sorted_pairs = sorted(correction_pairs, key=lambda x: len(x[0]), reverse=True)
        self.correction_replacements = {wrong: correct for wrong, correct in sorted_pairs}
        # One alternation pattern -> a single sub() pass at runtime.
        self.correction_pattern = re.compile(
            '|'.join(re.escape(wrong) for wrong, _ in sorted_pairs)
        )
    else:
        self.correction_pattern = None
        self.correction_replacements = {}

    logger.debug(f"加载了 {len(correction_pairs)} 个纠错规则")


def clean_text(self, text: str) -> str:
    """Remove junk characters and normalize whitespace.

    Args:
        text: raw text.

    Returns:
        Cleaned text (empty string for falsy input).
    """
    if not text:
        return ""
    # Drop special characters, collapse whitespace runs, trim the edges.
    cleaned = self.pattern_clean_special.sub('', text)
    cleaned = self.pattern_clean_spaces.sub(' ', cleaned)
    return cleaned.strip()
def correct_text(self, text: str) -> str:
    """Fix homophones/common STT errors in one combined regex pass.

    Args:
        text: text to correct.

    Returns:
        Corrected text (input returned untouched when disabled/empty).
    """
    if not self.enable_correction or not text or not self.correction_pattern:
        return text
    # One sub() over the union pattern; each hit is swapped through the
    # replacement table (O(1) dict lookup per match).
    return self.correction_pattern.sub(
        lambda m: self.correction_replacements.get(m.group(0), m.group(0)),
        text,
    )


def traditional_to_simplified(self, text: str) -> str:
    """Convert traditional Chinese to simplified (no-op when disabled).

    Args:
        text: text to convert.

    Returns:
        Converted text, or the input when conversion is off or fails.
    """
    if not (self.enable_traditional_to_simplified and self.opencc and text):
        return text
    try:
        return self.opencc.convert(text)
    except Exception as e:
        logger.warning(f"繁简转换失败: {e}")
        return text


def _segment_text_impl(self, text: str) -> List[str]:
    """Segment with jieba (uncached worker behind ``segment_text``).

    Returns the text as a single-element list when segmentation is
    disabled or fails; empty list for empty input.
    """
    if not self.enable_segmentation or not text:
        return [text] if text else []
    try:
        pieces = jieba.cut(text, cut_all=False)
        # Trim each token and drop whitespace-only results.
        return [w for w in (p.strip() for p in pieces) if w]
    except Exception as e:
        logger.warning(f"分词失败: {e}")
        return [text] if text else []


def segment_text(self, text: str) -> List[str]:
    """LRU-cached word segmentation (see ``_segment_text_impl``)."""
    return self._segment_text_cached(text)
def extract_numbers(self, text: str) -> ExtractedParams:
    """Extract distance (m), speed (m/s) and duration (s) from text.

    Matching order matters: speed is matched first so "5米每秒" is not
    mis-read as a 5 m distance (the distance pattern additionally excludes
    "米/秒" via a negative lookahead).  Chinese-numeral spans are scanned
    last and only fill slots the Arabic-numeral patterns left empty.

    Args:
        text: normalized command text.

    Returns:
        ExtractedParams with any recognized values filled in.
    """
    params = ExtractedParams()

    if not self.enable_number_extraction or not text:
        return params

    # Speed first, so a "X米每秒" span cannot be claimed as a distance.
    speed_match = self.pattern_speed.search(text)
    if speed_match:
        try:
            speed_str = speed_match.group(1)
            if speed_str.isdigit() or '.' in speed_str:
                params.speed = float(speed_str)
            else:
                # Chinese numeral -> cached parser.
                chinese_speed = self._parse_chinese_number(speed_str)
                if chinese_speed is not None:
                    params.speed = float(chinese_speed)
        except (ValueError, AttributeError):
            pass

    # Distance in meters (speed units excluded by the pattern itself).
    distance_match = self.pattern_distance.search(text)
    if distance_match:
        try:
            params.distance = float(distance_match.group(1))
        except ValueError:
            pass

    # Duration: scan all matches, keep the first one that parses.
    for duration_match in self.pattern_duration.finditer(text):
        try:
            duration_str = duration_match.group(1)
            # BUGFIX: the unit inside pattern_duration is a NON-capturing
            # group, so group(2) never exists; the old guard therefore
            # always fell back to '秒' and "5分钟"/"2小时" were never
            # scaled to seconds.  Derive the unit from the whole matched
            # span instead.
            matched_span = duration_match.group(0).lower()

            if duration_str.isdigit() or '.' in duration_str:
                duration_value = float(duration_str)
            else:
                chinese_duration = self._parse_chinese_number(duration_str)
                if chinese_duration is None:
                    continue
                duration_value = float(chinese_duration)

            # Normalize to seconds; first successful match wins.
            if '小时' in matched_span or 'h' in matched_span:
                params.duration = duration_value * 3600
            elif '分' in matched_span or 'min' in matched_span:
                params.duration = duration_value * 60
            else:
                params.duration = duration_value
            break
        except (ValueError, IndexError, AttributeError):
            continue

    # Chinese-numeral spans like "十米" / "五秒" / "三米每秒": only fill
    # slots still empty after the Arabic-numeral passes above.
    for chinese_match in self.pattern_chinese_number.finditer(text):
        try:
            chinese_num_str = chinese_match.group(1)
            full_match = chinese_match.group(0)

            num_value = self._parse_chinese_number(chinese_num_str)
            if num_value is None:
                continue

            if '米每秒' in full_match or '米/秒' in full_match or '米每' in full_match:
                # Speed unit.
                if params.speed is None:
                    params.speed = float(num_value)
            elif '米' in full_match and '秒' not in full_match:
                # Distance (meters, no seconds component).
                if params.distance is None:
                    params.distance = float(num_value)
            elif '秒' in full_match and '米' not in full_match:
                # Duration in seconds.
                if params.duration is None:
                    params.duration = float(num_value)
            elif '分' in full_match and '米' not in full_match:
                # Duration in minutes.
                if params.duration is None:
                    params.duration = float(num_value) * 60
        except (ValueError, IndexError, AttributeError):
            continue

    return params
+ if '十' in chinese_num or '拾' in chinese_num: + parts = chinese_num.replace('拾', '十').split('十') + if len(parts) == 2: + tens_part = parts[0] if parts[0] else '一' # "十五" -> parts[0]为空 + ones_part = parts[1] if parts[1] else '' + + tens = self.chinese_numbers.get(tens_part, 1) if tens_part else 1 + ones = self.chinese_numbers.get(ones_part, 0) if ones_part else 0 + + return tens * 10 + ones + + return None + + def _parse_chinese_number(self, chinese_num: str) -> Optional[int]: + """ + 解析中文数字(支持"十"、"二十"、"三十"、"五"等,带缓存) + + Args: + chinese_num: 中文数字字符串 + + Returns: + 对应的阿拉伯数字,解析失败返回None + """ + return self._parse_chinese_number_cached(chinese_num) + + def detect_keyword(self, text: str, words: Optional[List[str]] = None) -> Optional[str]: + """ + 检测命令关键词(使用缓存的排序结果) + + Args: + text: 待检测文本 + words: 分词结果(如果已分词,可传入以提高性能) + + Returns: + 检测到的命令类型(如"takeoff"、"forward"等),未检测到返回None + """ + if not self.enable_keyword_detection or not text: + return None + + # 如果已分词,优先使用分词结果匹配 + if words: + for word in words: + if word in self.keyword_to_command: + return self.keyword_to_command[word] + + # 使用缓存的排序结果(按长度降序,优先匹配长关键词) + for keyword in self.sorted_keywords: + if keyword in text: + return self.keyword_to_command[keyword] + + return None + + def preprocess(self, text: str) -> PreprocessedText: + """ + 完整的文本预处理流程(带缓存) + + Args: + text: 原始文本 + + Returns: + PreprocessedText对象,包含所有预处理结果 + """ + return self._preprocess_cached(text) + + def _preprocess_impl(self, text: str) -> PreprocessedText: + """ + 完整的文本预处理流程实现(内部方法,不带缓存) + + Args: + text: 原始文本 + + Returns: + PreprocessedText对象,包含所有预处理结果 + """ + if not text: + return PreprocessedText( + cleaned_text="", + normalized_text="", + words=[], + params=ExtractedParams(), + original_text=text + ) + + original_text = text + + # 1. 清理文本 + cleaned_text = self.clean_text(text) + + # 2. 纠错 + corrected_text = self.correct_text(cleaned_text) + + # 3. 繁简转换 + normalized_text = self.traditional_to_simplified(corrected_text) + + # 4. 
分词 + words = self.segment_text(normalized_text) + + # 5. 提取数字和单位 + params = self.extract_numbers(normalized_text) + + # 6. 检测关键词 + command_keyword = self.detect_keyword(normalized_text, words) + params.command_keyword = command_keyword + + return PreprocessedText( + cleaned_text=cleaned_text, + normalized_text=normalized_text, + words=words, + params=params, + original_text=original_text + ) + + def preprocess_fast(self, text: str) -> Tuple[str, ExtractedParams]: + """ + 快速预处理(仅返回规范化文本和参数,不进行分词,带缓存) + + 适用于实时场景,性能优先 + + Args: + text: 原始文本 + + Returns: + (规范化文本, 提取的参数) + """ + return self._preprocess_fast_cached(text) + + def _preprocess_fast_impl(self, text: str) -> Tuple[str, ExtractedParams]: + """ + 快速预处理实现(内部方法,不带缓存) + + Args: + text: 原始文本 + + Returns: + (规范化文本, 提取的参数) + """ + if not text: + return "", ExtractedParams() + + # 1. 清理 + cleaned = self.clean_text(text) + + # 2. 纠错 + corrected = self.correct_text(cleaned) + + # 3. 繁简转换 + normalized = self.traditional_to_simplified(corrected) + + # 4. 提取参数(不进行分词,提高性能) + params = self.extract_numbers(normalized) + + # 5. 
def _preprocess_fast_impl(self, text: str) -> Tuple[str, ExtractedParams]:
    """Uncached fast pipeline: no segmentation, for latency-critical use.

    Order: clean -> correct -> traditional-to-simplified -> extract
    numbers -> keyword scan over the whole string.
    """
    if not text:
        return "", ExtractedParams()

    normalized = self.traditional_to_simplified(
        self.correct_text(self.clean_text(text))
    )
    params = self.extract_numbers(normalized)
    # Without a word list, detect_keyword scans the full string.
    params.command_keyword = self.detect_keyword(normalized, words=None)
    return normalized, params


# Lazily created process-wide instance (see get_preprocessor()).
_global_preprocessor: Optional[TextPreprocessor] = None


def get_preprocessor() -> TextPreprocessor:
    """Return the shared TextPreprocessor, creating it on first use."""
    global _global_preprocessor
    if _global_preprocessor is None:
        _global_preprocessor = TextPreprocessor()
    return _global_preprocessor
def _tts_model_dir_candidates(rel: Path) -> List[Path]:
    """Candidate locations for a (possibly relative) model directory.

    Absolute paths are used as-is; relative paths resolve against the
    project root, with an extra fallback into the parent repository's
    ``src/`` tree when the path starts with ``models``.
    """
    if rel.is_absolute():
        return [rel]
    candidates: List[Path] = [_PROJECT_ROOT / rel]
    if rel.parts and rel.parts[0] == "models":
        candidates.append(_PROJECT_ROOT.parent / "src" / rel)
    return candidates


def _resolve_kokoro_model_dir(raw: str | Path) -> Path:
    """Resolve the Kokoro directory containing ``tokenizer.json``.

    Supports the sub-project layout falling back to the parent repo's
    ``src/models/``; as a last resort returns the project-root-joined
    path even when nothing exists there.
    """
    wanted = Path(raw)
    candidates = _tts_model_dir_candidates(wanted)

    # Preferred: a candidate that actually holds tokenizer.json.
    for cand in candidates:
        if (cand / "tokenizer.json").is_file():
            return cand.resolve()

    # Next best: an existing directory, even without tokenizer.json.
    for cand in candidates:
        if cand.is_dir():
            logger.warning(
                "Kokoro 目录存在但未找到 tokenizer.json: %s(将仍使用该路径,后续可能报错)",
                cand,
            )
            return cand.resolve()

    return (_PROJECT_ROOT / wanted).resolve()
def __init__(self, config: Optional[dict] = None) -> None:
    """Resolve all model paths from config and load every component.

    Args:
        config: optional override dict; defaults to SYSTEM_TTS_CONFIG.
    """
    cfg = config or SYSTEM_TTS_CONFIG or {}
    self.config = cfg

    # Model root: contains onnx/, tokenizer.json and voices/.
    root = _resolve_kokoro_model_dir(
        cfg.get("model_dir", "models/Kokoro-82M-v1.1-zh-ONNX")
    )
    self.model_dir = root

    # ONNX weights: prefer <root>/onnx/<name>, fall back to <root>/<name>.
    self.model_name = cfg.get("model_name", "model_q4.onnx")
    primary = root / "onnx" / self.model_name
    self.onnx_path = primary
    if not primary.is_file():
        fallback = root / self.model_name
        if fallback.is_file():
            self.onnx_path = fallback

    # Voice style vectors live in voices/<name>.bin (no extension in cfg).
    self.voice_name = cfg.get("voice", "zf_001")
    self.voice_path = root / "voices" / f"{self.voice_name}.bin"

    # Playback parameters.
    self.speed = float(cfg.get("speed", 1.0))
    self.sample_rate = int(cfg.get("sample_rate", 24000))

    # tokenizer.json ships alongside the ONNX export.
    self.tokenizer_path = root / "tokenizer.json"

    # Components populated by _load_all(); _g2p stays None when
    # misaki[zh] is unavailable.
    self._session: Optional[ort.InferenceSession] = None
    self._vocab: Optional[dict] = None
    self._voices: Optional[np.ndarray] = None
    self._g2p = None

    self._load_all()
def synthesize(self, text: str) -> Tuple[np.ndarray, int]:
    """Synthesize speech for *text*.

    Returns:
        (audio, sample_rate): float32 waveform roughly in [-1, 1] and the
        configured output rate (default 24000).

    Raises:
        ValueError: empty input, or no phoneme maps to a known token.
    """
    if not text or not text.strip():
        raise ValueError("TTS 输入文本不能为空")

    ids = self._phonemes_to_token_ids(self._text_to_phonemes(text))
    if len(ids) == 0:
        raise ValueError(f"TTS: 文本在当前 vocab 下无法映射到任何 token, text={text!r}")

    # Kokoro-ONNX convention: at most 510 payload tokens, with a pad
    # token 0 added at both ends.
    if len(ids) > 510:
        logger.warning(f"TTS: token 长度 {len(ids)} > 510, 将被截断为 510")
        ids = ids[:510]
    input_ids = np.array([[0, *ids, 0]], dtype=np.int64)  # (1, <=512)

    # Style vector is selected by payload token count.
    assert self._voices is not None, "TTS: voices 尚未初始化"
    style_bank = self._voices  # (N, 1, 256)
    style = style_bank[min(len(ids), style_bank.shape[0] - 1)]  # (1, 256)

    assert self._session is not None, "TTS: ONNX Session 尚未初始化"
    raw = self._session.run(
        None,
        {
            "input_ids": input_ids,
            "style": style,
            "speed": np.array([self.speed], dtype=np.float32),
        },
    )[0]

    # Exports differ in output shape: (1, N), (N,), or extra singleton dims.
    wave = raw.astype(np.float32)
    if wave.ndim == 2 and wave.shape[0] == 1:
        wave = wave[0]
    elif wave.ndim > 2:
        wave = np.squeeze(wave)
    return wave, self.sample_rate


def synthesize_to_file(self, text: str, wav_path: str) -> str:
    """Synthesize *text* and write a 16-bit PCM wav to *wav_path*.

    Requires scipy (optional dependency: ``pip install scipy``).

    Returns:
        The *wav_path* that was written.
    """
    try:
        import scipy.io.wavfile as wavfile  # type: ignore
    except Exception as e:  # pragma: no cover - only without scipy
        raise RuntimeError("保存到 wav 需要安装 scipy, 请先执行: pip install scipy") from e

    audio, sr = self.synthesize(text)

    # Peak-normalize (guarding against all-zero audio), then quantize.
    peak = float(np.max(np.abs(audio)) or 1.0)
    pcm = (np.clip(audio / peak, -1.0, 1.0) * 32767.0).astype(np.int16)

    # Some SciPy versions are picky about 0-d/1-d arrays: force (N, 1) mono.
    if pcm.ndim <= 1:
        pcm = pcm.reshape(-1, 1)

    wavfile.write(wav_path, sr, pcm)
    return wav_path


def _load_all(self) -> None:
    """Load tokenizer vocab, voice bank, ONNX session and the G2P front-end."""
    self._load_tokenizer_vocab()
    self._load_voices()
    self._load_onnx_session()
    self._init_g2p()
def _load_tokenizer_vocab(self) -> None:
    """Read ``tokenizer.json`` and build the token(str) -> id(int) vocab.

    Raises:
        FileNotFoundError: tokenizer.json missing.
        ValueError: the file has no ``model.vocab`` dict.
    """
    if not self.tokenizer_path.exists():
        raise FileNotFoundError(f"TTS: 未找到 tokenizer.json: {self.tokenizer_path}")

    with open(self.tokenizer_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    vocab = (data.get("model") or {}).get("vocab")
    if not isinstance(vocab, dict):
        raise ValueError("TTS: tokenizer.json 格式不正确, 未找到 model.vocab 字段")

    # Stored as: character -> integer id.
    self._vocab = {k: int(v) for k, v in vocab.items()}
    logger.info(f"TTS: tokenizer vocab 加载完成, 词表大小: {len(self._vocab)}")


def _load_voices(self) -> None:
    """Load the style-vector bank from voices/<name>.bin.

    The file is raw float32 data reshaped to (N, 1, 256).

    Raises:
        FileNotFoundError: the .bin file is missing.
        ValueError: size not divisible into (N, 1, 256).
    """
    if not self.voice_path.exists():
        raise FileNotFoundError(f"TTS: 未找到语音风格文件: {self.voice_path}")

    bank = np.fromfile(self.voice_path, dtype=np.float32)
    try:
        bank = bank.reshape(-1, 1, 256)
    except ValueError as e:
        raise ValueError(
            f"TTS: 语音风格文件形状不符合预期, 无法 reshape 为 (-1,1,256): {self.voice_path}"
        ) from e

    self._voices = bank
    logger.info(
        f"TTS: 语音风格文件加载完成: {self.voice_name}, 可用 style 数量: {bank.shape[0]}"
    )


def _load_onnx_session(self) -> None:
    """Create the ONNX Runtime CPU session for the Kokoro model.

    Thread counts can be pinned via ROCKET_TTS_ORT_INTRA_OP_THREADS /
    ROCKET_TTS_ORT_INTER_OP_THREADS (useful on RK3588-class multi-core
    CPUs); unset or 0 leaves ORT's defaults.

    Raises:
        FileNotFoundError: the .onnx file is missing.
    """
    if not self.onnx_path.exists():
        raise FileNotFoundError(f"TTS: 未找到 ONNX 模型文件: {self.onnx_path}")

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    def thread_env(name):
        # Only positive integer values are honoured; anything else -> None.
        raw = os.environ.get(name, "").strip()
        return int(raw) if raw.isdigit() and int(raw) > 0 else None

    intra = thread_env("ROCKET_TTS_ORT_INTRA_OP_THREADS")
    if intra is not None:
        sess_options.intra_op_num_threads = intra
    inter = thread_env("ROCKET_TTS_ORT_INTER_OP_THREADS")
    if inter is not None:
        sess_options.inter_op_num_threads = inter

    # Plain CPU inference (extend providers here if GPU is ever needed).
    self._session = ort.InferenceSession(
        str(self.onnx_path),
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
    )

    logger.info(f"TTS: Kokoro ONNX 模型加载完成: {self.onnx_path}")
def _init_g2p(self) -> None:
    """Try to set up the misaki[zh] grapheme-to-phoneme front-end.

    When misaki is missing or fails to import, ``self._g2p`` stays None
    and synthesis falls back to mapping raw text characters directly
    (noticeably lower quality).
    """
    try:
        from misaki import zh  # type: ignore

        # Newer misaki accepts a version kwarg; older versions do not.
        try:
            self._g2p = zh.ZHG2P(version="1.1")  # type: ignore[call-arg]
        except TypeError:
            self._g2p = zh.ZHG2P()  # type: ignore[call-arg]

        logger.info("TTS: 已启用 misaki[zh] G2P, 将使用音素级别映射")
    except Exception as e:
        self._g2p = None
        logger.warning(
            "TTS: 未安装或无法导入 misaki[zh], 将直接基于原始文本字符做 token 映射, "
            "合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
        )
        logger.debug(f"TTS: G2P 初始化失败原因: {e!r}")


def _text_to_phonemes(self, text: str) -> str:
    """Text -> phoneme string (or the stripped raw text when no G2P)."""
    if self._g2p is None:
        return text.strip()

    # misaki versions differ: some return (phonemes, tokens), some just
    # the phoneme string.
    out = self._g2p(text)
    phonemes = out[0] if isinstance(out, (tuple, list)) else out

    if not phonemes:
        # Fall back to the raw text when G2P yields nothing.
        logger.warning("TTS: G2P 结果为空, 回退为原始文本")
        return text.strip()
    return phonemes


def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
    """Map each phoneme character to its vocab id.

    Newlines are skipped; characters absent from the vocab are dropped
    (and reported at debug level).  A space is itself a vocab entry.
    """
    assert self._vocab is not None, "TTS: vocab 尚未初始化"
    lookup = self._vocab

    ids: List[int] = []
    missing = set()
    for ch in phonemes:
        if ch == "\n":
            continue
        mapped = lookup.get(ch)
        if mapped is None:
            missing.add(ch)
        else:
            ids.append(int(mapped))

    if missing:
        logger.debug(f"TTS: 存在无法映射到 vocab 的字符: {missing}")

    return ids
return raw if raw >= 0 else None + s = str(raw).strip() + if not s or s.lower() in ("null", "none", "default", ""): + return None + if s.isdigit(): + return int(s) + needle = s.lower() + devices = sd.query_devices() + matches: List[int] = [] + for i, d in enumerate(devices): + if int(d.get("max_output_channels", 0) or 0) <= 0: + continue + name = (d.get("name") or "").lower() + if needle in name: + matches.append(i) + if not matches: + logger.warning( + f"TTS: 未找到名称包含 {s!r} 的输出设备,将使用系统默认输出。" + "请检查 system.yaml 中 tts.output_device 或查看启动日志中的设备列表。" + ) + return None + if len(matches) > 1: + logger.info( + f"TTS: 名称 {s!r} 匹配到多个输出设备索引 {matches},使用第一个 {matches[0]}" + ) + return matches[0] + + +_playback_dev_cache: Optional[int] = None +_playback_dev_cache_key: Optional[str] = None +_playback_dev_cache_ready: bool = False + + +def get_playback_output_device_id() -> Optional[int]: + """从 SYSTEM_TTS_CONFIG 解析并缓存播放设备索引(None=默认输出)。""" + global _playback_dev_cache, _playback_dev_cache_key, _playback_dev_cache_ready + + cfg = SYSTEM_TTS_CONFIG or {} + raw = cfg.get("output_device") + key = repr(raw) + if _playback_dev_cache_ready and _playback_dev_cache_key == key: + return _playback_dev_cache + dev_id = _resolve_output_device_id(raw) + _playback_dev_cache = dev_id + _playback_dev_cache_key = key + _playback_dev_cache_ready = True + if dev_id is not None: + import sounddevice as sd # type: ignore + + info = sd.query_devices(dev_id) + logger.info( + f"TTS: 播放将使用输出设备 index={dev_id} name={info.get('name', '?')!r}" + ) + else: + logger.info("TTS: 播放使用系统默认输出设备(未指定或无法匹配 tts.output_device)") + return dev_id + + +def _sounddevice_default_output_index(): + """sounddevice 0.5+ 的 default.device 可能是 _InputOutputPair,需取 [1] 为输出索引。""" + import sounddevice as sd # type: ignore + + default = sd.default.device + if isinstance(default, (list, tuple)): + return int(default[1]) + if hasattr(default, "__getitem__"): + try: + return int(default[1]) + except (IndexError, TypeError, ValueError): + 
pass + try: + return int(default) + except (TypeError, ValueError): + return None + + +def log_sounddevice_output_devices() -> None: + """列出所有可用输出设备及当前默认输出,便于配置 tts.output_device。""" + try: + import sounddevice as sd # type: ignore + + out_idx = _sounddevice_default_output_index() + logger.info("sounddevice 输出设备列表(用于配置 tts.output_device 索引或名称子串):") + for i, d in enumerate(sd.query_devices()): + if int(d.get("max_output_channels", 0) or 0) <= 0: + continue + mark = " <- 当前默认输出" if out_idx is not None and int(out_idx) == i else "" + logger.info(f" [{i}] {d.get('name', '?')}{mark}") + except Exception as e: + logger.warning(f"无法枚举 sounddevice 输出设备: {e}") + + +def _resample_playback_audio(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray: + """将波形重采样到设备采样率;优先 librosa(kaiser_best),失败则回退 scipy 多相。""" + from math import gcd + + from scipy.signal import resample # type: ignore + from scipy.signal import resample_poly # type: ignore + + if abs(sr_in - sr_out) < 1: + return np.asarray(audio, dtype=np.float32) + try: + import librosa # type: ignore + + return librosa.resample( + np.asarray(audio, dtype=np.float32), + orig_sr=sr_in, + target_sr=sr_out, + res_type="kaiser_best", + ).astype(np.float32) + except Exception as e: + logger.debug(f"TTS: librosa 重采样不可用,使用多相重采样: {e!r}") + try: + g = gcd(int(sr_in), int(sr_out)) + if g > 0: + up = int(sr_out) // g + down = int(sr_in) // g + return resample_poly(audio, up, down).astype(np.float32) + except Exception as e2: + logger.debug(f"TTS: resample_poly 失败,回退 FFT resample: {e2!r}") + num = max(1, int(len(audio) * float(sr_out) / float(sr_in))) + return resample(audio, num).astype(np.float32) + + +def _fade_playback_edges(audio: np.ndarray, sample_rate: int, fade_ms: float) -> np.ndarray: + """极短线性淡入淡出,减轻扬声器/驱动在段首段尾的爆音与杂音感。""" + if fade_ms <= 0 or audio.size < 16: + return audio + n = int(float(sample_rate) * fade_ms / 1000.0) + n = min(n, len(audio) // 4) + if n <= 0: + return audio + out = np.asarray(audio, 
dtype=np.float32, order="C").copy() + fade = np.linspace(0.0, 1.0, n, dtype=np.float32) + out[:n] *= fade + out[-n:] *= fade[::-1] + return out + + +def play_tts_audio( + audio: np.ndarray, + sample_rate: int, + *, + output_device: Optional[object] = None, +) -> None: + """ + 使用 sounddevice 播放单声道 float32 音频(阻塞至播放结束)。 + + 在 Windows 上 PortAudio/sounddevice 从非主线程调用时经常出现「无声音、无报错」, + 因此本项目中应答播报应在主线程(采集循环所在线程)调用本函数。 + + 另:多数 Realtek/WASAPI 设备对 24000Hz 播放会「完全无声」且不报错,需重采样到设备 + default_samplerate(常见 48000/44100),并用 OutputStream 写出。 + + Args: + output_device: 若指定,覆盖 system.yaml 的 tts.output_device(设备索引或名称子串)。 + """ + import sounddevice as sd # type: ignore + + cfg = SYSTEM_TTS_CONFIG or {} + force_native = bool(cfg.get("playback_resample_to_device_native", True)) + do_normalize = bool(cfg.get("playback_peak_normalize", True)) + gain = float(cfg.get("playback_gain", 1.0)) + if gain <= 0: + gain = 1.0 + fade_ms = float(cfg.get("playback_edge_fade_ms", 8.0)) + latency = cfg.get("playback_output_latency", "low") + if latency not in ("low", "medium", "high"): + latency = "low" + + audio = np.asarray(audio, dtype=np.float32).squeeze() + if audio.ndim > 1: + audio = np.squeeze(audio) + if audio.size == 0: + logger.warning("TTS: 播放跳过,音频长度为 0") + return + + if output_device is not None: + dev = _resolve_output_device_id(output_device) + else: + dev = get_playback_output_device_id() + if dev is None: + dev = _sounddevice_default_output_index() + if dev is None: + logger.warning("TTS: 无法解析输出设备索引,使用 sounddevice 默认输出") + else: + dev = int(dev) + + info = sd.query_devices(dev) if dev is not None else sd.query_devices(kind="output") + native_sr = int(float(info.get("default_samplerate", 48000))) + sr_out = int(sample_rate) + if force_native and native_sr > 0 and abs(native_sr - sr_out) > 1: + audio = _resample_playback_audio(audio, sr_out, native_sr) + sr_out = native_sr + logger.info( + f"TTS: 播放重采样 {sample_rate}Hz -> {sr_out}Hz(匹配设备 default_samplerate,避免 Windows 无声)" + ) + + peak_before 
= float(np.max(np.abs(audio))) + if do_normalize and peak_before > 1e-8 and peak_before > 0.95: + audio = (audio / peak_before * 0.92).astype(np.float32, copy=False) + + if gain != 1.0: + audio = (audio * np.float32(gain)).astype(np.float32, copy=False) + + audio = _fade_playback_edges(audio, sr_out, fade_ms) + + peak = float(np.max(np.abs(audio))) + rms = float(np.sqrt(np.mean(np.square(audio)))) + dname = info.get("name", "?") if isinstance(info, dict) else "?" + logger.info( + f"TTS: 播放 峰值={peak:.5f} RMS={rms:.5f} sr={sr_out}Hz 设备={dev!r} ({dname!r})" + ) + if peak < 1e-8: + logger.warning("TTS: 波形接近静音,请检查合成是否异常") + + audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False) + block = audio.reshape(-1, 1) + + try: + with sd.OutputStream( + device=dev, + channels=1, + samplerate=sr_out, + dtype="float32", + latency=latency, + ) as stream: + stream.write(block) + except Exception as e: + logger.warning(f"TTS: OutputStream 失败,回退 sd.play: {e}", exc_info=True) + sd.play(block, samplerate=sr_out, device=dev) + sd.wait() + + +def play_wav_path( + path: str | Path, + *, + output_device: Optional[object] = None, +) -> None: + """ + 播放 16-bit PCM WAV(单声道或立体声下混为单声道),走与 synthesize + play_tts_audio + 相同的 sounddevice 输出路径(含 ROCKET_TTS_DEVICE / yaml 设备解析)。 + """ + import wave + + p = Path(path) + with wave.open(str(p), "rb") as wf: + ch = int(wf.getnchannels()) + sw = int(wf.getsampwidth()) + sr = int(wf.getframerate()) + nframes = int(wf.getnframes()) + if sw != 2: + raise ValueError(f"仅支持 16-bit PCM: {p}") + raw = wf.readframes(nframes) + mono = np.frombuffer(raw, dtype=" None: + """ + 合成并播放一段语音;失败时仅打日志,不向外抛异常(适合命令成功后的反馈)。 + """ + if not text or not str(text).strip(): + return + try: + engine = tts or KokoroOnnxTTS() + t = str(text).strip() + logger.info(f"TTS: 开始合成并播放: {t!r}") + audio, sr = engine.synthesize(t) + play_tts_audio(audio, sr, output_device=output_device) + logger.info("TTS: 播放完成") + except Exception as e: + logger.warning(f"TTS 播放失败: {e}", exc_info=True) 
+ + +__all__ = [ + "KokoroOnnxTTS", + "play_tts_audio", + "play_wav_path", + "speak_text", + "get_playback_output_device_id", + "log_sounddevice_output_devices", +] + +if __name__ == "__main__": + # 与主程序一致:使用 play_tts_audio(含重采样到设备 native 采样率) + tts = KokoroOnnxTTS() + + text = "任务执行完成,开始返航降落" + + print(f"正在合成语音: {text}") + audio, sr = tts.synthesize(text) + + print("正在播放(与主程序相同 play_tts_audio 路径)...") + try: + play_tts_audio(audio, sr) + print("播放结束。") + except Exception as e: + print(f"播放失败: {e}") + + # === 保存为 WAV 文件(可选)=== + try: + output_path = "任务执行完成,开始返航降落.wav" + tts.synthesize_to_file(text, output_path) + print(f"音频已保存至: {output_path}") + except RuntimeError as e: + print(f"保存失败(可能缺少 scipy): {e}") + diff --git a/voice_drone/core/tts_ack_cache.py b/voice_drone/core/tts_ack_cache.py new file mode 100644 index 0000000..1abce65 --- /dev/null +++ b/voice_drone/core/tts_ack_cache.py @@ -0,0 +1,152 @@ +""" +应答 TTS 波形磁盘缓存:文案与 TTS 配置未变时跳过逐条合成,加快启动。 + +缓存目录:项目根下 cache/ack_tts_pcm/ +""" + +from __future__ import annotations + +import hashlib +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from voice_drone.core.configuration import SYSTEM_TTS_CONFIG + +# 与 src/core/configuration.py 一致:src/core/tts_ack_cache.py -> parents[2] +_PROJECT_ROOT = Path(__file__).resolve().parents[2] + +ACK_PCM_CACHE_DIR = _PROJECT_ROOT / "cache" / "ack_tts_pcm" +MANIFEST_NAME = "manifest.json" +CACHE_FORMAT = 1 + + +def _tts_signature() -> dict: + tts = SYSTEM_TTS_CONFIG or {} + return { + "model_dir": str(tts.get("model_dir", "")), + "model_name": str(tts.get("model_name", "")), + "voice": str(tts.get("voice", "")), + "speed": round(float(tts.get("speed", 1.0)), 6), + "sample_rate": int(tts.get("sample_rate", 24000)), + } + + +def compute_ack_pcm_fingerprint( + unique_phrases: List[str], + *, + global_text: Optional[str] = None, + mode_phrases: bool = True, +) -> str: + """文案 + TTS 签名变化则指纹变,磁盘缓存失效。""" + payload = { + 
"cache_format": CACHE_FORMAT, + "tts": _tts_signature(), + "mode_phrases": mode_phrases, + } + if mode_phrases: + payload["phrases"] = sorted(unique_phrases) + else: + payload["global_text"] = (global_text or "").strip() + raw = json.dumps(payload, sort_keys=True, ensure_ascii=False) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + +def _phrase_file_stem(fingerprint: str, phrase: str) -> str: + h = hashlib.sha256(fingerprint.encode("utf-8")) + h.update(b"\0") + h.update(phrase.encode("utf-8")) + return h.hexdigest()[:40] + + +def _load_one_npz(path: Path) -> Optional[Tuple[np.ndarray, int]]: + try: + z = np.load(path, allow_pickle=False) + audio = np.asarray(z["audio"], dtype=np.float32).squeeze() + sr = int(np.asarray(z["sr"]).reshape(-1)[0]) + if audio.size == 0 or sr <= 0: + return None + return (audio, sr) + except Exception: + return None + + +def load_cached_phrases( + unique_phrases: List[str], + fingerprint: str, +) -> Tuple[Dict[str, Tuple[np.ndarray, int]], List[str]]: + """ + 从磁盘加载与 fingerprint 匹配的缓存。 + + Returns: + (已加载的 phrase -> (audio, sr), 仍需合成的 phrase 列表) + """ + out: Dict[str, Tuple[np.ndarray, int]] = {} + if not unique_phrases: + return {}, [] + + cache_dir = ACK_PCM_CACHE_DIR + manifest_path = cache_dir / MANIFEST_NAME + if not manifest_path.is_file(): + return {}, list(unique_phrases) + + try: + with open(manifest_path, "r", encoding="utf-8") as f: + manifest = json.load(f) + except Exception: + return {}, list(unique_phrases) + + if int(manifest.get("format", 0)) != CACHE_FORMAT: + return {}, list(unique_phrases) + if manifest.get("fingerprint") != fingerprint: + return {}, list(unique_phrases) + + files: Dict[str, str] = manifest.get("files") or {} + missing: List[str] = [] + + for phrase in unique_phrases: + fname = files.get(phrase) + if not fname: + missing.append(phrase) + continue + path = cache_dir / fname + if not path.is_file(): + missing.append(phrase) + continue + loaded = _load_one_npz(path) + if loaded is None: + 
missing.append(phrase) + continue + out[phrase] = loaded + + return out, missing + + +def persist_phrases(fingerprint: str, phrase_pcm: Dict[str, Tuple[np.ndarray, int]]) -> None: + """写入/更新整包 manifest 与各句 npz(覆盖同名 manifest)。""" + if not phrase_pcm: + return + + cache_dir = ACK_PCM_CACHE_DIR + cache_dir.mkdir(parents=True, exist_ok=True) + + files: Dict[str, str] = {} + for phrase, (audio, sr) in phrase_pcm.items(): + stem = _phrase_file_stem(fingerprint, phrase) + fname = f"{stem}.npz" + path = cache_dir / fname + audio = np.asarray(audio, dtype=np.float32).squeeze() + np.savez_compressed(path, audio=audio, sr=np.array([int(sr)], dtype=np.int32)) + files[phrase] = fname + + manifest = { + "format": CACHE_FORMAT, + "fingerprint": fingerprint, + "files": files, + } + tmp = cache_dir / (MANIFEST_NAME + ".tmp") + with open(tmp, "w", encoding="utf-8") as f: + json.dump(manifest, f, ensure_ascii=False, indent=0) + tmp.replace(cache_dir / MANIFEST_NAME) diff --git a/voice_drone/core/vad.py b/voice_drone/core/vad.py new file mode 100644 index 0000000..c99cc3a --- /dev/null +++ b/voice_drone/core/vad.py @@ -0,0 +1,429 @@ +""" +语音活动检测(VAD)模块 - 纯 ONNX Runtime 版 Silero VAD + +使用 Silero VAD 的 ONNX 模型检测语音活动,识别语音的开始和结束。 +不依赖 PyTorch/silero_vad 包,只依赖 onnxruntime + numpy。 +""" + +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import onnxruntime as ort +import multiprocessing +from voice_drone.core.configuration import ( + SYSTEM_AUDIO_CONFIG, + SYSTEM_RECOGNIZER_CONFIG, + SYSTEM_VAD_CONFIG, +) +from voice_drone.logging_ import get_logger +from voice_drone.tools.wrapper import time_cost +logger = get_logger("vad.silero_onnx") + + +class VAD: + """ + 语音活动检测器 + + 使用 ONNX Runtime + """ + + def __init__(self): + """ + 初始化 VAD 检测器 + + Args: + config: 可选的配置字典,用于覆盖默认配置 + """ + # ---- 从 system.yaml 的 vad 部分读取默认配置 ---- + # system.yaml: + # vad: + # threshold: 0.65 + # start_frame: 3 + # end_frame: 10 + # min_silence_duration_s: 0.5 + # 
max_silence_duration_s: 30 + # model_path: "src/models/silero_vad.onnx" + vad_conf = SYSTEM_VAD_CONFIG + + # 语音概率阈值(YAML 可能是字符串) + self.speech_threshold = float(vad_conf.get("threshold", 0.5)) + + # 连续多少帧检测到语音才认为“开始说话” + self.speech_start_frames = int(vad_conf.get("start_frame", 3)) + + # 连续多少帧检测到静音才认为“结束说话” + self.silence_end_frames = int(vad_conf.get("end_frame", 10)) + + # 可选: 最短/最长语音段时长(秒),可以在上层按需使用 + self.min_speech_duration = vad_conf.get("min_silence_duration_s") + self.max_speech_duration = vad_conf.get("max_silence_duration_s") + + # 采样率来自 audio 配置 + self.sample_rate = SYSTEM_AUDIO_CONFIG.get("sample_rate") + + # 与 recognizer 一致:能量 VAD 时不加载 Silero(避免无模型文件仍强加载) + _ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in ( + "1", + "true", + "yes", + ) + _yaml_backend = str( + SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero") + ).lower() + if _ev_env or _yaml_backend == "energy": + self.onnx_session = None + self.vad_model_path = None + self.window_size = 512 if int(self.sample_rate or 16000) == 16000 else 256 + self.input_name = None + self.sr_input_name = None + self.state_input_name = None + self.output_name = None + self.state = None + self.speech_frame_count = 0 + self.silence_frame_count = 0 + self.is_speaking = False + logger.info( + "VAD:能量(RMS)分段模式,跳过 Silero ONNX(与 ROCKET_ENERGY_VAD / vad_backend 一致)" + ) + return + + # ---- 加载 Silero VAD ONNX 模型 ---- + raw_mp = SYSTEM_VAD_CONFIG.get("model_path") + if not raw_mp: + raise FileNotFoundError( + "vad.model_path 未配置。若只用能量 VAD,请在 system.yaml 中设 " + "recognizer.vad_backend: energy 并设置 ROCKET_ENERGY_VAD=1" + ) + mp = Path(raw_mp) + if not mp.is_absolute(): + mp = Path(__file__).resolve().parents[2] / mp + self.vad_model_path = str(mp) + if not mp.is_file(): + raise FileNotFoundError( + f"Silero VAD 模型不存在: {self.vad_model_path}。请下载 silero_vad.onnx 到该路径," + "或改用能量 VAD:recognizer.vad_backend: energy 且 ROCKET_ENERGY_VAD=1" + ) + + try: + logger.info(f"正在加载 Silero VAD ONNX 模型: {self.vad_model_path}") + 
sess_options = ort.SessionOptions() + cpu_count = multiprocessing.cpu_count() + optimal_threads = min(4, cpu_count) + sess_options.intra_op_num_threads = optimal_threads + sess_options.inter_op_num_threads = optimal_threads + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + + self.onnx_session = ort.InferenceSession( + str(mp), + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) + + inputs = self.onnx_session.get_inputs() + outputs = self.onnx_session.get_outputs() + if not inputs: + raise RuntimeError("VAD ONNX 模型没有输入节点") + + # ---- 解析输入 / 输出 ---- + # 典型的 Silero VAD ONNX 会有: + # - 输入: audio(input), 采样率(sr), 状态(state 或 h/c) + # - 输出: 语音概率 + 可选的新状态 + self.input_name = None + self.sr_input_name: Optional[str] = None + self.state_input_name: Optional[str] = None + + for inp in inputs: + name = inp.name + if self.input_name is None: + # 优先匹配常见名称,否则退回第一个 + if name in ("input", "audio", "waveform"): + self.input_name = name + else: + self.input_name = name + if name == "sr": + self.sr_input_name = name + if name in ("state", "h", "c", "hidden"): + self.state_input_name = name + + # 如果依然没有确定 input_name,兜底使用第一个 + if self.input_name is None: + self.input_name = inputs[0].name + + self.output_name = outputs[0].name if outputs else None + + # 预分配状态向量(如果模型需要) + self.state: Optional[np.ndarray] = None + if self.state_input_name is not None: + state_inp = next(i for i in inputs if i.name == self.state_input_name) + # state 的 shape 通常是 [1, N] 或 [N], 这里用 0 初始化 + state_shape = [ + int(d) if isinstance(d, int) and d > 0 else 1 + for d in (state_inp.shape or [1]) + ] + self.state = np.zeros(state_shape, dtype=np.float32) + + # 从输入 shape 推断窗口大小: (batch, samples) 或 (samples,) + input_shape = inputs[0].shape + win_size = None + if isinstance(input_shape, (list, tuple)) and len(input_shape) >= 1: + last_dim = input_shape[-1] + if isinstance(last_dim, int): + win_size = last_dim + if win_size is None: + win_size = 512 if 
self.sample_rate == 16000 else 256 + self.window_size = int(win_size) + + logger.info( + f"Silero VAD ONNX 模型加载完成: 输入={self.input_name}, 输出={self.output_name}, " + f"window_size={self.window_size}, sample_rate={self.sample_rate}" + ) + except Exception as e: + logger.error(f"Silero VAD ONNX 模型加载失败: {e}") + raise RuntimeError( + f"无法加载 Silero VAD: {e}。若无需 Silero,请设 ROCKET_ENERGY_VAD=1 且 " + "recognizer.vad_backend: energy" + ) from e + + # State tracking + self.speech_frame_count = 0 # Consecutive speech frame count + self.silence_frame_count = 0 # Consecutive silence frame count + self.is_speaking = False # Currently speaking + + logger.info( + "VADDetector 初始化完成: " + f"speech_threshold={self.speech_threshold}, " + f"speech_start_frames={self.speech_start_frames}, " + f"silence_end_frames={self.silence_end_frames}, " + f"sample_rate={self.sample_rate}Hz" + ) + + # @time_cost("VAD-语音检测耗时") + def is_speech(self, audio_chunk: bytes) -> bool: + """ + 检测音频块是否包含语音 + + Args: + audio_chunk: 音频数据(bytes),必须是 16kHz, 16-bit, 单声道 PCM + + Returns: + True 表示检测到语音,False 表示静音 + """ + try: + if self.onnx_session is None: + return False + # 将 bytes 转换为 numpy array(int16),确保 little-endian 字节序 + audio_array = np.frombuffer(audio_chunk, dtype=" required_samples: + num_chunks = len(audio_float) // required_samples + speech_probs = [] + + for i in range(num_chunks): + start_idx = i * required_samples + end_idx = start_idx + required_samples + chunk = audio_float[start_idx:end_idx] + + # 模型通常期望输入形状为 (1, samples) + input_data = chunk[np.newaxis, :].astype(np.float32) + ort_inputs = {self.input_name: input_data} + + # 如果模型需要 sr,state 等附加输入,一并提供 + if getattr(self, "sr_input_name", None) is not None: + # Silero VAD 一般期望 int64 采样率 + ort_inputs[self.sr_input_name] = np.array( + [self.sample_rate], dtype=np.int64 + ) + if getattr(self, "state_input_name", None) is not None and self.state is not None: + ort_inputs[self.state_input_name] = self.state + + outputs = self.onnx_session.run(None, 
ort_inputs) + + # 如果模型返回新的 state,更新内部状态 + if ( + getattr(self, "state_input_name", None) is not None + and len(outputs) > 1 + ): + self.state = outputs[1] + + prob = float(outputs[0].reshape(-1)[0]) + speech_probs.append(prob) + + speech_prob = float(np.mean(speech_probs)) + else: + input_data = audio_float[:required_samples][np.newaxis, :].astype(np.float32) + ort_inputs = {self.input_name: input_data} + + if getattr(self, "sr_input_name", None) is not None: + ort_inputs[self.sr_input_name] = np.array( + [self.sample_rate], dtype=np.int64 + ) + if getattr(self, "state_input_name", None) is not None and self.state is not None: + ort_inputs[self.state_input_name] = self.state + + outputs = self.onnx_session.run(None, ort_inputs) + + if ( + getattr(self, "state_input_name", None) is not None + and len(outputs) > 1 + ): + self.state = outputs[1] + + speech_prob = float(outputs[0].reshape(-1)[0]) + + return speech_prob >= self.speech_threshold + except Exception as e: + logger.error(f"VAD detection failed: {e}") + return False + + def is_speech_numpy(self, audio_array: np.ndarray) -> bool: + """ + 检测音频数组是否包含语音 + + Args: + audio_array: 音频数据(numpy array,dtype=int16) + + Returns: + True 表示检测到语音,False 表示静音 + """ + # 转换为 bytes + audio_bytes = audio_array.tobytes() + return self.is_speech(audio_bytes) + + def detect_speech_start(self, audio_chunk: bytes) -> bool: + """ + 检测语音开始 + + 需要连续检测到多帧语音才认为语音开始 + + Args: + audio_chunk: 音频数据块 + + Returns: + True 表示检测到语音开始 + """ + if self.is_speaking: + return False + + if self.is_speech(audio_chunk): + self.speech_frame_count += 1 + self.silence_frame_count = 0 + + if self.speech_frame_count >= self.speech_start_frames: + self.is_speaking = True + self.speech_frame_count = 0 + logger.info("Speech start detected") + return True + else: + self.speech_frame_count = 0 + + return False + + def detect_speech_end(self, audio_chunk: bytes) -> bool: + """ + 检测语音结束 + + 需要连续检测到多帧静音才认为语音结束 + + Args: + audio_chunk: 音频数据块 + + Returns: + True 
表示检测到语音结束 + """ + if not self.is_speaking: + return False + + if not self.is_speech(audio_chunk): + self.silence_frame_count += 1 + self.speech_frame_count = 0 + + if self.silence_frame_count >= self.silence_end_frames: + self.is_speaking = False + self.silence_frame_count = 0 + logger.info("Speech end detected") + return True + else: + self.silence_frame_count = 0 + + return False + + def reset(self) -> None: + """ + 重置检测器状态 + + 清除帧计数、是否在说话标记,以及 Silero 的 RNN 状态(长间隔后应清零,避免与后续音频错位)。 + """ + self.speech_frame_count = 0 + self.silence_frame_count = 0 + self.is_speaking = False + if self.state is not None: + self.state.fill(0) + logger.debug("VAD detector state reset") + + +if __name__ == "__main__": + """ + 使用测试音频按帧扫描,统计语音帧比例,更直观地验证 VAD 是否工作正常。 + """ + import wave + + vad = VAD() + audio_file = "test/测试音频.wav" + + # 1. 读取 wav + with wave.open(audio_file, "rb") as wf: + n_channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + framerate = wf.getframerate() + n_frames = wf.getnframes() + raw = wf.readframes(n_frames) + + # 2. 转成 int16 数组 + audio = np.frombuffer(raw, dtype="<i2") + + # 3. 立体声 -> 单声道 + if n_channels == 2: + audio = audio.reshape(-1, 2) + audio = audio.mean(axis=1).astype(np.int16) + + # 4. 重采样到 VAD 所需采样率(通常 16k) + target_sr = vad.sample_rate + if framerate != target_sr: + x_old = np.linspace(0, 1, num=len(audio), endpoint=False) + x_new = np.linspace(0, 1, num=int(len(audio) * target_sr / framerate), endpoint=False) + audio = np.interp(x_new, x_old, audio).astype(np.int16) + + print("wav info:", n_channels, "ch,", framerate, "Hz") + print("audio len (samples):", len(audio), " target_sr:", target_sr) + + # 5. 
按 VAD 窗口大小逐帧扫描 + frame_samples = vad.window_size + frame_bytes = frame_samples * 2 # int16 -> 2 字节 + audio_bytes = audio.tobytes() + num_frames = len(audio_bytes) // frame_bytes + + speech_frames = 0 + for i in range(num_frames): + chunk = audio_bytes[i * frame_bytes : (i + 1) * frame_bytes] + if vad.is_speech(chunk): + speech_frames += 1 + + speech_ratio = speech_frames / num_frames if num_frames > 0 else 0.0 + print("total frames:", num_frames) + print("speech frames:", speech_frames) + print("speech ratio:", speech_ratio) + print("has_any_speech:", speech_frames > 0) \ No newline at end of file diff --git a/voice_drone/core/wake_word.py b/voice_drone/core/wake_word.py new file mode 100644 index 0000000..54963d2 --- /dev/null +++ b/voice_drone/core/wake_word.py @@ -0,0 +1,375 @@ +""" +唤醒词检测模块 - 高性能实时唤醒词识别 + +支持: +- 精确匹配 +- 模糊匹配(同音字、拼音) +- 部分匹配 +- 配置化变体映射 + +性能优化: +- 预编译正则表达式 +- LRU缓存匹配结果 +- 优化字符串操作 +""" + +import re +from typing import Optional, List, Tuple +from functools import lru_cache +from voice_drone.logging_ import get_logger +from voice_drone.core.configuration import ( + WAKE_WORD_PRIMARY, + WAKE_WORD_VARIANTS, + WAKE_WORD_MATCHING_CONFIG +) + +logger = get_logger("wake_word") + +# 延迟加载可选依赖 +try: + from pypinyin import lazy_pinyin, Style + PYPINYIN_AVAILABLE = True +except ImportError: + PYPINYIN_AVAILABLE = False + logger.warning("pypinyin 未安装,拼音匹配功能将受限") + + +class WakeWordDetector: + """ + 唤醒词检测器 + + 支持多种匹配模式: + - 精确匹配:完全匹配唤醒词 + - 模糊匹配:同音字、拼音变体 + - 部分匹配:只匹配部分唤醒词 + """ + + def __init__(self): + """初始化唤醒词检测器""" + logger.info("初始化唤醒词检测器...") + + # 从配置加载 + self.primary = WAKE_WORD_PRIMARY + self.variants = WAKE_WORD_VARIANTS or [] + self.matching_config = WAKE_WORD_MATCHING_CONFIG or {} + + # 匹配配置 + self.enable_fuzzy = self.matching_config.get("enable_fuzzy", True) + self.enable_partial = self.matching_config.get("enable_partial", True) + self.ignore_case = self.matching_config.get("ignore_case", True) + self.ignore_spaces = 
self.matching_config.get("ignore_spaces", True) + self.min_match_length = self.matching_config.get("min_match_length", 2) + self.similarity_threshold = self.matching_config.get("similarity_threshold", 0.7) + + # 构建匹配模式 + self._build_patterns() + + logger.info(f"唤醒词检测器初始化完成") + logger.info(f" 主唤醒词: {self.primary}") + logger.info(f" 变体数量: {len(self.variants)}") + logger.info(f" 模糊匹配: {'启用' if self.enable_fuzzy else '禁用'}") + logger.info(f" 部分匹配: {'启用' if self.enable_partial else '禁用'}") + + def _build_patterns(self): + """构建匹配模式(预编译正则表达式)""" + # 标准化所有变体(去除空格、转小写等) + self.normalized_variants = [] + + for variant in self.variants: + normalized = self._normalize_text(variant) + if normalized: + self.normalized_variants.append(normalized) + + # 去重 + self.normalized_variants = list(set(self.normalized_variants)) + + # 按长度降序排序(优先匹配长模式) + self.normalized_variants.sort(key=len, reverse=True) + + # 构建正则表达式模式 + patterns = [] + for variant in self.normalized_variants: + # 转义特殊字符 + escaped = re.escape(variant) + patterns.append(escaped) + + # 编译单一正则表达式 + if patterns: + self.pattern = re.compile('|'.join(patterns), re.IGNORECASE if self.ignore_case else 0) + else: + self.pattern = None + + logger.debug(f"构建了 {len(self.normalized_variants)} 个匹配模式") + + def _normalize_text(self, text: str) -> str: + """ + 标准化文本(用于匹配) + + Args: + text: 原始文本 + + Returns: + 标准化后的文本 + """ + if not text: + return "" + + normalized = text + + # 忽略大小写 + if self.ignore_case: + normalized = normalized.lower() + + # 忽略空格 + if self.ignore_spaces: + normalized = normalized.replace(' ', '').replace('\t', '').replace('\n', '') + + return normalized.strip() + + def _get_pinyin(self, text: str) -> str: + """ + 获取文本的拼音(用于拼音匹配) + + Args: + text: 中文文本 + + Returns: + 拼音字符串(小写,无空格) + """ + if not PYPINYIN_AVAILABLE: + return "" + + try: + pinyin_list = lazy_pinyin(text, style=Style.NORMAL) + return ''.join(pinyin_list).lower() + except Exception as e: + logger.debug(f"拼音转换失败: {e}") + return "" + + def 
_fuzzy_match(self, text: str, variant: str) -> bool: + """ + 模糊匹配(同音字、拼音) + + Args: + text: 待匹配文本 + variant: 变体文本 + + Returns: + 是否匹配 + """ + # 1. 精确匹配(已标准化) + normalized_text = self._normalize_text(text) + normalized_variant = self._normalize_text(variant) + + if normalized_text == normalized_variant: + return True + + # 2. 拼音匹配 + if PYPINYIN_AVAILABLE: + text_pinyin = self._get_pinyin(text) + variant_pinyin = self._get_pinyin(variant) + + if text_pinyin and variant_pinyin: + # 完全匹配拼音 + if text_pinyin == variant_pinyin: + return True + + # 部分匹配拼音(至少匹配一半) + if len(variant_pinyin) >= 2: + # 检查是否包含变体的拼音 + if variant_pinyin in text_pinyin or text_pinyin in variant_pinyin: + # 计算相似度 + similarity = min(len(variant_pinyin), len(text_pinyin)) / max(len(variant_pinyin), len(text_pinyin)) + if similarity >= self.similarity_threshold: + return True + + # 3. 字符级相似度匹配(简单实现) + if len(normalized_text) >= self.min_match_length and len(normalized_variant) >= self.min_match_length: + # 检查是否包含变体 + if normalized_variant in normalized_text or normalized_text in normalized_variant: + return True + + return False + + def _partial_match(self, text: str) -> bool: + """ + 部分匹配(只匹配部分唤醒词,如主词较长时取前半段;短词请在配置 variants 中列出) + + Args: + text: 待匹配文本 + + Returns: + 是否匹配 + """ + if not self.enable_partial: + return False + + normalized_text = self._normalize_text(text) + + # 检查是否包含主唤醒词的一部分 + if self.primary: + # 提取主唤醒词的前半部分(如四字词可拆成前两字) + primary_normalized = self._normalize_text(self.primary) + if len(primary_normalized) >= self.min_match_length * 2: + half_length = len(primary_normalized) // 2 + half_wake_word = primary_normalized[:half_length] + + if len(half_wake_word) >= self.min_match_length: + if half_wake_word in normalized_text: + return True + + return False + + @lru_cache(maxsize=256) + def detect(self, text: str) -> Tuple[bool, Optional[str]]: + """ + 检测文本中是否包含唤醒词 + + Args: + text: 待检测文本 + + Returns: + (是否匹配, 匹配到的唤醒词) + """ + if not text or not self.pattern: + return False, None + + 
normalized_text = self._normalize_text(text) + + # 1. 精确匹配(使用正则表达式) + if self.pattern: + match = self.pattern.search(normalized_text) + if match: + matched_text = match.group(0) + logger.debug(f"精确匹配到唤醒词: {matched_text}") + return True, matched_text + + # 2. 模糊匹配(同音字、拼音) + if self.enable_fuzzy: + for variant in self.normalized_variants: + if self._fuzzy_match(normalized_text, variant): + logger.debug(f"模糊匹配到唤醒词变体: {variant}") + return True, variant + + # 3. 部分匹配 + if self._partial_match(normalized_text): + logger.debug(f"部分匹配到唤醒词") + return True, self.primary[:len(self.primary)//2] if self.primary else None + + return False, None + + def extract_command_text(self, text: str) -> Optional[str]: + """ + 从文本中提取命令部分(移除唤醒词) + + Args: + text: 包含唤醒词的完整文本 + + Returns: + 提取的命令文本,如果未检测到唤醒词返回None + """ + is_wake, matched_wake_word = self.detect(text) + + if not is_wake: + return None + + # 标准化文本用于查找 + normalized_text = self._normalize_text(text) + normalized_wake = self._normalize_text(matched_wake_word) if matched_wake_word else "" + + if not normalized_wake or normalized_wake not in normalized_text: + return None + + # 找到唤醒词在标准化文本中的位置 + idx = normalized_text.find(normalized_wake) + if idx < 0: + return None + + # 方法1:尝试在原始文本中精确查找匹配的变体 + original_text = text + text_lower = original_text.lower() + + # 查找所有可能的变体在原始文本中的位置 + best_match_idx = -1 + best_match_length = 0 + + # 检查配置中的所有变体 + for variant in self.variants: + variant_normalized = self._normalize_text(variant) + if variant_normalized == normalized_wake: + # 这个变体匹配到了,尝试在原始文本中找到它 + variant_lower = variant.lower() + if variant_lower in text_lower: + variant_idx = text_lower.find(variant_lower) + if variant_idx >= 0: + # 选择最长的匹配(更准确) + if len(variant) > best_match_length: + best_match_idx = variant_idx + best_match_length = len(variant) + + # 如果找到了匹配的变体 + if best_match_idx >= 0: + command_start = best_match_idx + best_match_length + command_text = original_text[command_start:].strip() + # 移除开头的标点符号 + command_text = 
command_text.lstrip(',。、,.').strip() + return command_text if command_text else None + + # 方法2:回退方案 - 使用字符计数近似定位 + # 计算标准化文本中唤醒词结束位置对应的原始文本位置 + wake_end_in_normalized = idx + len(normalized_wake) + + # 计算原始文本中对应的字符位置 + char_count = 0 + for i, char in enumerate(original_text): + normalized_char = self._normalize_text(char) + if normalized_char: + if char_count >= wake_end_in_normalized: + command_text = original_text[i:].strip() + command_text = command_text.lstrip(',。、,.').strip() + return command_text if command_text else None + char_count += 1 + + return None + + +# 全局单例 +_global_detector: Optional[WakeWordDetector] = None + + +def get_wake_word_detector() -> WakeWordDetector: + """获取全局唤醒词检测器实例(单例模式)""" + global _global_detector + if _global_detector is None: + _global_detector = WakeWordDetector() + return _global_detector + + +if __name__ == "__main__": + # 测试代码 + detector = WakeWordDetector() + + test_cases = [ + ("无人机,现在起飞", True), + ("wu ren ji 现在起飞", True), + ("Wu Ren Ji 现在起飞", True), + ("五人机,现在起飞", True), + ("现在起飞", False), + ("无人,现在起飞", True), # 变体列表中的短说 + ("人机 前进", False), # 已移除单独「人机」变体,避免子串误唤醒 + ("无人机 前进", True), + ] + + print("=" * 60) + print("唤醒词检测测试") + print("=" * 60) + + for text, expected in test_cases: + is_wake, matched = detector.detect(text) + command_text = detector.extract_command_text(text) + status = "OK" if is_wake == expected else "FAIL" + print(f"{status} 文本: {text}") + print(f" 匹配: {is_wake} (期望: {expected})") + print(f" 匹配词: {matched}") + print(f" 提取命令: {command_text}") + print() diff --git a/voice_drone/core/任务执行完成,开始返航降落.wav b/voice_drone/core/任务执行完成,开始返航降落.wav new file mode 100644 index 0000000..1dd0d13 Binary files /dev/null and b/voice_drone/core/任务执行完成,开始返航降落.wav differ diff --git a/voice_drone/core/好的收到,开始起飞.wav b/voice_drone/core/好的收到,开始起飞.wav new file mode 100644 index 0000000..5e61583 Binary files /dev/null and b/voice_drone/core/好的收到,开始起飞.wav differ diff --git a/voice_drone/flight_bridge/__init__.py 
b/voice_drone/flight_bridge/__init__.py new file mode 100644 index 0000000..ad54987 --- /dev/null +++ b/voice_drone/flight_bridge/__init__.py @@ -0,0 +1 @@ +"""伴飞桥:将 flight_intent v1 译为 MAVROS/PX4 行为(ROS 1 Noetic 首版)。""" diff --git a/voice_drone/flight_bridge/ros1_mavros_executor.py b/voice_drone/flight_bridge/ros1_mavros_executor.py new file mode 100644 index 0000000..a61efe3 --- /dev/null +++ b/voice_drone/flight_bridge/ros1_mavros_executor.py @@ -0,0 +1,330 @@ +""" +ROS1 + MAVROS:按 ValidatedFlightIntent 顺序执行(Offboard 位姿 / AUTO.LAND / AUTO.RTL)。 + +需在已连接飞控的 MAVROS 环境中运行(与 src/px4_ctrl_offboard_demo.py 相同前提)。 +""" + +from __future__ import annotations + +import math +from typing import Optional, Tuple + +import rospy +from geometry_msgs.msg import PoseStamped +from mavros_msgs.msg import PositionTarget, State +from mavros_msgs.srv import CommandBool, SetMode + +from voice_drone.core.flight_intent import ( + ActionGoto, + ActionHold, + ActionHover, + ActionLand, + ActionReturnHome, + ActionTakeoff, + ActionWait, + FlightAction, + ValidatedFlightIntent, +) + + +def _yaw_from_quaternion(x: float, y: float, z: float, w: float) -> float: + siny_cosp = 2.0 * (w * z + x * y) + cosy_cosp = 1.0 - 2.0 * (y * y + z * z) + return math.atan2(siny_cosp, cosy_cosp) + + +class MavrosFlightExecutor: + """单次连接 MAVROS;对外提供 execute(intent)。""" + + def __init__(self) -> None: + rospy.init_node("flight_intent_mavros_bridge", anonymous=True) + + self.state = State() + self.pose = PoseStamped() + self.has_pose = False + + rospy.Subscriber("/mavros/state", State, self._state_cb, queue_size=10) + rospy.Subscriber( + "/mavros/local_position/pose", + PoseStamped, + self._pose_cb, + queue_size=10, + ) + self.sp_pub = rospy.Publisher( + "/mavros/setpoint_raw/local", + PositionTarget, + queue_size=20, + ) + + rospy.wait_for_service("/mavros/cmd/arming", timeout=60.0) + rospy.wait_for_service("/mavros/set_mode", timeout=60.0) + self._arming_name = "/mavros/cmd/arming" + self._set_mode_name = 
"/mavros/set_mode" + self.arm_srv = rospy.ServiceProxy(self._arming_name, CommandBool) + self.mode_srv = rospy.ServiceProxy(self._set_mode_name, SetMode) + + self.rate = rospy.Rate(20) + + self.default_takeoff_relative_m = rospy.get_param( + "~default_takeoff_relative_m", + 0.5, + ) + self.takeoff_timeout_sec = rospy.get_param("~takeoff_timeout_sec", 15.0) + self.goto_tol = rospy.get_param("~goto_position_tolerance", 0.15) + self.goto_timeout_sec = rospy.get_param("~goto_timeout_sec", 60.0) + self.land_timeout_sec = rospy.get_param("~land_timeout_sec", 45.0) + self.pre_stream_count = int(rospy.get_param("~offboard_pre_stream_count", 80)) + + def _state_cb(self, msg: State) -> None: + self.state = msg + + def _pose_cb(self, msg: PoseStamped) -> None: + self.pose = msg + self.has_pose = True + + def _wait_connected_and_pose(self) -> None: + rospy.loginfo("等待 FCU 与本地位置 …") + while not rospy.is_shutdown() and not self.state.connected: + self.rate.sleep() + while not rospy.is_shutdown() and not self.has_pose: + self.rate.sleep() + + @staticmethod + def _position_sp(x: float, y: float, z: float) -> PositionTarget: + sp = PositionTarget() + sp.header.stamp = rospy.Time.now() + sp.coordinate_frame = PositionTarget.FRAME_LOCAL_NED + sp.type_mask = ( + PositionTarget.IGNORE_VX + | PositionTarget.IGNORE_VY + | PositionTarget.IGNORE_VZ + | PositionTarget.IGNORE_AFX + | PositionTarget.IGNORE_AFY + | PositionTarget.IGNORE_AFZ + | PositionTarget.IGNORE_YAW + | PositionTarget.IGNORE_YAW_RATE + ) + sp.position.x = x + sp.position.y = y + sp.position.z = z + return sp + + def _publish_sp(self, sp: PositionTarget) -> None: + sp.header.stamp = rospy.Time.now() + self.sp_pub.publish(sp) + + def _stream_init_setpoint(self, x: float, y: float, z: float) -> None: + sp = self._position_sp(x, y, z) + for _ in range(self.pre_stream_count): + if rospy.is_shutdown(): + return + self._publish_sp(sp) + self.rate.sleep() + + def _refresh_mode_arm_proxies(self) -> None: + """MAVROS 重启或链路抖事后,旧 
    def _do_takeoff(self, step: ActionTakeoff) -> None:
        """Climb to (current position + relative altitude) via Offboard setpoints.

        Uses ``step.args.relative_altitude_m`` when given, otherwise the
        ``~default_takeoff_relative_m`` ROS parameter. Streams the hold
        setpoint first (PX4 requires a setpoint stream before accepting
        OFFBOARD), then keeps publishing the climb target while retrying the
        mode switch and arming. On timeout it logs a warning and lets the
        sequence continue rather than aborting.
        """
        alt = step.args.relative_altitude_m
        dz = float(alt) if alt is not None else float(self.default_takeoff_relative_m)
        self._wait_connected_and_pose()
        x0, y0, z0 = self._current_xyz()
        # Same convention as px4_ctrl_offboard_demo.py: z_tgt = z0 + dz.
        z_tgt = z0 + dz
        rospy.loginfo(
            "takeoff: z0=%.2f Δ=%.2f m -> z_target=%.2f(与 demo 相同约定)",
            z0,
            dz,
            z_tgt,
        )

        # Pre-stream the current-position setpoint so PX4 will accept OFFBOARD.
        self._stream_init_setpoint(x0, y0, z0)
        t0 = rospy.Time.now()
        deadline = t0 + rospy.Duration(self.takeoff_timeout_sec)
        while not rospy.is_shutdown() and rospy.Time.now() < deadline:
            self._try_set_mode_arm_offboard()
            sp = self._position_sp(x0, y0, z_tgt)
            self._publish_sp(sp)
            z_now = self.pose.pose.position.z
            # Done once within 8 cm of the target AND actually armed in OFFBOARD.
            if (
                abs(z_now - z_tgt) < 0.08
                and self.state.mode == "OFFBOARD"
                and self.state.armed
            ):
                rospy.loginfo("takeoff reached")
                break
            self.rate.sleep()
        else:
            # while-else: the loop exhausted without break → timed out.
            rospy.logwarn("takeoff timeout,继续后续步骤")
    def _ned_delta_from_goto(self, step: ActionGoto) -> Optional[Tuple[float, float, float]]:
        """Convert a goto action into a (north, east, down) position delta.

        Missing axis values default to 0. For frame ``local_ned`` the args are
        used as-is; any other frame is treated as body frame (x forward,
        y right, z down) and the horizontal delta is rotated by the current
        yaw into NED. NOTE(review): despite the Optional return annotation,
        every path here returns a tuple; the caller's None check guards a
        future unsupported-frame case — confirm intended.
        """
        a = step.args
        dx = 0.0 if a.x is None else float(a.x)
        dy = 0.0 if a.y is None else float(a.y)
        dz = 0.0 if a.z is None else float(a.z)
        if a.frame == "local_ned":
            return (dx, dy, dz)
        # body_ned: x forward, y right, z down → rotate horizontal delta into NED.
        q = self.pose.pose.orientation
        yaw = _yaw_from_quaternion(q.x, q.y, q.z, q.w)
        north = math.cos(yaw) * dx - math.sin(yaw) * dy
        east = math.sin(yaw) * dx + math.cos(yaw) * dy
        return (north, east, dz)
    def _do_land(self) -> None:
        """Switch to AUTO.LAND and wait (up to ``land_timeout_sec``) for disarm.

        While still armed in OFFBOARD, keeps streaming the current position
        setpoint — if the stream stops, PX4 drops out of OFFBOARD quickly and
        the MAVROS set_mode service can stay unavailable for a long time.
        Service proxies are rebuilt on the first failure and then every 15th
        failure, because stale proxies may report unavailable forever after a
        MAVROS restart.
        """
        rospy.loginfo("land: AUTO.LAND")
        t0 = rospy.Time.now()
        deadline = t0 + rospy.Duration(self.land_timeout_sec)
        fails = 0
        while not rospy.is_shutdown() and rospy.Time.now() < deadline:
            x, y, z = self._current_xyz()
            # Keep the OFFBOARD setpoint stream alive until the mode switch takes.
            if self.state.armed and self.state.mode == "OFFBOARD":
                self._publish_sp(self._position_sp(x, y, z))
            try:
                if self.state.mode != "AUTO.LAND":
                    self.mode_srv(base_mode=0, custom_mode="AUTO.LAND")
            except (rospy.ServiceException, rospy.ROSException) as exc:
                fails += 1
                rospy.logwarn_throttle(2.0, "AUTO.LAND: %s", exc)
                # Rebuild the proxies on the first failure, then periodically.
                if fails == 1 or fails % 15 == 0:
                    self._refresh_mode_arm_proxies()
            if not self.state.armed:
                # Disarm is the landing-complete signal.
                rospy.loginfo("land: disarmed")
                return
            self.rate.sleep()
        rospy.logwarn("land: timeout")
len(intent.actions), type(act).__name__) + self._dispatch(act) + + rospy.loginfo("flight_intent 序列结束") + + def _dispatch(self, act: FlightAction) -> None: + if isinstance(act, ActionTakeoff): + self._do_takeoff(act) + elif isinstance(act, (ActionHover, ActionHold)): + self._do_hold_position("hover" if isinstance(act, ActionHover) else "hold") + elif isinstance(act, ActionWait): + self._do_wait(act) + elif isinstance(act, ActionGoto): + self._do_goto(act) + elif isinstance(act, ActionLand): + self._do_land() + elif isinstance(act, ActionReturnHome): + self._do_rtl() + else: + rospy.logwarn("未支持的动作: %r", act) diff --git a/voice_drone/flight_bridge/ros1_node.py b/voice_drone/flight_bridge/ros1_node.py new file mode 100644 index 0000000..28e9ff2 --- /dev/null +++ b/voice_drone/flight_bridge/ros1_node.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +""" +ROS1 伴飞桥节点:订阅 JSON(std_msgs/String),校验 flight_intent v1 后调用 MAVROS 执行。 + +话题(默认): + 全局 /input(std_msgs/String),与 rostopic pub、语音端 ROCKET_FLIGHT_BRIDGE_TOPIC 一致。 + 可通过私有参数 ~input_topic 覆盖(须带前导 / 才是全局名)。 + +示例: + rostopic pub -1 /input std_msgs/String \\ + '{data: "{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}"}' + +前提是:已 roslaunch mavros px4.launch …,且 /mavros/state connected。 +""" + +from __future__ import annotations + +import json +import threading + +import rospy +from std_msgs.msg import String + +from voice_drone.core.flight_intent import parse_flight_intent_dict +from voice_drone.flight_bridge.ros1_mavros_executor import MavrosFlightExecutor + + +def _coerce_flight_intent_dict(raw: dict) -> dict: + """允许仅传 {actions, summary?},补全顶层字段。""" + if raw.get("is_flight_intent") is True and raw.get("version") == 1: + return raw + actions = raw.get("actions") + if isinstance(actions, list) and actions: + summary = str(raw.get("summary") or "bridge").strip() or "bridge" + return { + "is_flight_intent": True, + "version": 1, + "actions": actions, + "summary": summary, + } 
+ raise ValueError("JSON 须为完整 flight_intent 或含 actions 数组") + + +class FlightIntentBridgeNode: + def __init__(self) -> None: + self._exec = MavrosFlightExecutor() + self._busy = threading.Lock() + # 默认用绝对名 /input:若用相对名 "input",anonymous 节点下会变成 /flight_intent_mavros_bridge_*/input,与 rostopic pub /input 不一致。 + topic = rospy.get_param("~input_topic", "/input") + self._sub = rospy.Subscriber( + topic, + String, + self._on_input, + queue_size=1, + ) + rospy.loginfo("flight_intent_bridge 就绪:订阅 %s", topic) + + def _on_input(self, msg: String) -> None: + data = (msg.data or "").strip() + if not data: + return + if not self._busy.acquire(blocking=False): + rospy.logwarn("上一段 flight_intent 仍在执行,忽略本条") + return + + def _run() -> None: + try: + try: + raw = json.loads(data) + except json.JSONDecodeError as e: + rospy.logerr("JSON 解析失败: %s", e) + return + if not isinstance(raw, dict): + rospy.logerr("顶层须为 JSON object") + return + raw = _coerce_flight_intent_dict(raw) + parsed, errors = parse_flight_intent_dict(raw) + if errors or parsed is None: + rospy.logerr("flight_intent 校验失败: %s", errors) + return + self._exec.execute(parsed) + finally: + self._busy.release() + + threading.Thread(target=_run, daemon=True, name="flight-intent-exec").start() + + +def main() -> None: + FlightIntentBridgeNode() + rospy.spin() + + +if __name__ == "__main__": + main() diff --git a/voice_drone/logging_/__init__.py b/voice_drone/logging_/__init__.py new file mode 100644 index 0000000..95dc916 --- /dev/null +++ b/voice_drone/logging_/__init__.py @@ -0,0 +1,14 @@ +""" +日志系统入口 + +提供 get_logger 接口,返回带颜色控制台输出的 logger。 + +注意:文件夹名已改为 logging_ 以避免与标准库的 logging 模块冲突。 +""" + +# 由于文件夹名已改为 logging_,不再与标准库的 logging 冲突 +# 直接导入标准库的 logging 模块 +import logging + +# 现在可以安全导入 color_logger +from .color_logger import get_logger diff --git a/voice_drone/logging_/color_logger.py b/voice_drone/logging_/color_logger.py new file mode 100644 index 0000000..26b6c8f --- /dev/null +++ b/voice_drone/logging_/color_logger.py @@ 
-0,0 +1,107 @@ +# 直接导入标准库的 logging(不再有命名冲突) +import logging + +from typing import Optional + +try: + # 读取系统日志配置: level / debug 等 + from voice_drone.core.configuration import SYSTEM_LOGGING_CONFIG +except Exception: + SYSTEM_LOGGING_CONFIG = {"level": "INFO", "debug": False} + + +class ColorFormatter(logging.Formatter): + """简单彩色日志格式化器.""" + + COLORS = { + "DEBUG": "\033[36m", # 青色 + "INFO": "\033[32m", # 绿色 + "WARNING": "\033[33m", # 黄色 + "ERROR": "\033[31m", # 红色 + "CRITICAL": "\033[41m", # 红底 + } + RESET = "\033[0m" + + def format(self, record: logging.LogRecord) -> str: + level = record.levelname + color = self.COLORS.get(level, "") + msg = super().format(record) + return f"{color}{msg}{self.RESET}" if color else msg + + +def _resolve_level(config_level: Optional[str], debug_flag: bool) -> int: + """根据配置字符串和 debug 标志,解析出 logging 等级.""" + if config_level is None: + config_level = "INFO" + level_str = str(config_level).upper() + + # 特殊值: 关闭日志 + if level_str in ("OFF", "NONE", "DISABLE", "DISABLED"): + return logging.CRITICAL + 1 # 实际等同于全关 + + if debug_flag: + return logging.DEBUG + + mapping = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "WARN": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + return mapping.get(level_str, logging.INFO) + + +def get_logger(name: str = "app") -> logging.Logger: + """ + 获取一个带颜色控制台输出的 logger。 + + - name: logger 名称,按模块/子系统区分即可,例如 "stt.onnx" + 日志级别与启用/禁用由配置文件 `system.yaml` 决定: + + ```yaml + logging: + level: "INFO" # DEBUG/INFO/WARNING/ERROR/CRITICAL/OFF + debug: false # true 时强制 DEBUG + ``` + + 使用示例: + + ```python + from voice_drone.logging_ import get_logger + + logger = get_logger("stt.onnx") + logger.info("模型加载完成") + logger.warning("预热失败,将继续执行") + logger.error("推理出错") + ``` + """ + logger = logging.getLogger(name) + if logger.handlers: + # 已经配置过,直接复用,但更新 level + level = _resolve_level( + SYSTEM_LOGGING_CONFIG.get("level") if isinstance(SYSTEM_LOGGING_CONFIG, 
dict) else None, + SYSTEM_LOGGING_CONFIG.get("debug", False) if isinstance(SYSTEM_LOGGING_CONFIG, dict) else False, + ) + logger.setLevel(level) + return logger + + # 解析配置的日志级别 + if isinstance(SYSTEM_LOGGING_CONFIG, dict): + level = _resolve_level(SYSTEM_LOGGING_CONFIG.get("level"), SYSTEM_LOGGING_CONFIG.get("debug", False)) + else: + level = logging.INFO + + logger.setLevel(level) + + handler = logging.StreamHandler() + fmt = "[%(asctime)s] [%(levelname)s] %(message)s" + datefmt = "%H:%M:%S" + handler.setFormatter(ColorFormatter(fmt=fmt, datefmt=datefmt)) + logger.addHandler(handler) + logger.propagate = False + + # 如果 level 被解析为 "关闭",则仍然返回 logger,但不会输出普通日志 + return logger + diff --git a/voice_drone/main_app.py b/voice_drone/main_app.py new file mode 100644 index 0000000..9ced2ca --- /dev/null +++ b/voice_drone/main_app.py @@ -0,0 +1,2271 @@ +# 实时检测语音:用「无人机」唤醒 → TTS「你好,我在呢」→ 收音一句指令(关麦)→ 大模型 Kokoro 播报答句 → 再仅听唤醒词。 +# 可选:assistant.local_keyword_takeoff_enabled 或 ROCKET_LOCAL_KEYWORD_TAKEOFF=1 时,「无人机 + keywords.yaml 里 takeoff 词」走本地 offboard + WAV(默认关闭)。 +# 其它指令走云端/本地 LLM → flight_intent 等(设 ROCKET_CLOUD_EXECUTE_FLIGHT=1 才执行机端序列)。 +# 环境变量:ROCKET_LLM_GGUF、ROCKET_LLM_MAX_TOKENS(默认 256)、ROCKET_LLM_CTX(默认 4096,可试 2048 省显存/略提速)、 +# ROCKET_LLM_N_THREADS(llama.cpp 线程数,如 RK3588 可试 6~8)、ROCKET_LLM_N_GPU_LAYERS(有 CUDA/Vulkan 时>0)、ROCKET_LLM_N_BATCH、 +# ROCKET_TTS_ORT_INTRA_OP_THREADS / ROCKET_TTS_ORT_INTER_OP_THREADS(Kokoro ONNXRuntime 线程), +# ROCKET_CHAT_IDLE_SEC(历史占位,每轮重置上下文)、ROCKET_TTS_DEVICE(同 qwen15b_chat --tts-device)、 +# ROCKET_INPUT_HW=2,0 对应 arecord -l 的 card,device;ROCKET_INPUT_DEVICE_INDEX、ROCKET_INPUT_DEVICE_NAME; +# 录音:默认交互列出 arecord -l + PyAudio 并选择;--input-index / ROCKET_INPUT_DEVICE_INDEX 跳过交互;--non-interactive 用 yaml 的 input_device_index(可为 null 自动探测)。 +# ROCKET_LLM_DISABLE=1 关闭对话。 +# ROCKET_LLM_STREAM=0 关闭流式输出(整段推理后再单次 TTS,便于对照调试)。 +# ROCKET_STREAM_TTS_CHUNK_CHARS 流式闲聊时、无句末标点则按此长度强制切段(默认 64,过小会听感碎)。 +# 云端语音(见 
def _play_wav_blocking(path: Path) -> None:
    """Play a 16-bit PCM WAV file synchronously (same as src/play_wav.py).

    Args:
        path: WAV file to play; must be 16-bit PCM (any channel count / rate).

    Raises:
        ValueError: if the file's sample width is not 2 bytes (16-bit).
    """
    # Imported lazily so callers that never play audio don't require PyAudio.
    import pyaudio

    with wave.open(str(path), "rb") as wf:
        ch = wf.getnchannels()
        sw = wf.getsampwidth()
        sr = wf.getframerate()
        nframes = wf.getnframes()
        if sw != 2:
            raise ValueError(f"仅支持 16-bit PCM: {path}")
        # Read the whole file into memory before opening the output stream.
        pcm = wf.readframes(nframes)

    p = pyaudio.PyAudio()
    try:
        fmt = p.get_format_from_width(sw)
        chunk = 1024
        stream = p.open(
            format=fmt,
            channels=ch,
            rate=sr,
            output=True,
            frames_per_buffer=chunk,
        )
        stream.start_stream()
        try:
            # Write frame-aligned chunks; stream.write() blocks until consumed.
            step = chunk * sw * ch
            for i in range(0, len(pcm), step):
                stream.write(pcm[i : i + step])
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        p.terminate()
def _terminate_process_group(proc: subprocess.Popen) -> None:
    """Terminate *proc*'s whole process group: SIGTERM, wait 10 s, then SIGKILL.

    Uses ``os.killpg(proc.pid, …)``, which assumes the child was started as
    its own process-group leader (pgid == pid, e.g. via start_new_session)
    — TODO confirm at the spawn site. Best-effort: failures are logged,
    never raised.
    """
    if proc.poll() is not None:
        # Process already exited; nothing to clean up.
        return
    try:
        os.killpg(proc.pid, signal.SIGTERM)
    except ProcessLookupError:
        # Group vanished between poll() and killpg(): treat as finished.
        return
    except Exception as e:  # noqa: BLE001
        logger.warning("SIGTERM offboard 进程组失败: %s", e)
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        # Graceful stop did not finish within 10 s: escalate to SIGKILL.
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except Exception as e:  # noqa: BLE001
            logger.warning("SIGKILL offboard 进程组失败: %s", e)
).lower() not in ("0", "false", "no") + self._stream_tts_chunk_chars = max( + 16, + int(os.environ.get("ROCKET_STREAM_TTS_CHUNK_CHARS", "64")), + ) + self._llm_disabled = os.environ.get("ROCKET_LLM_DISABLE", "").lower() in ( + "1", + "true", + "yes", + ) + _kw_raw = os.environ.get("ROCKET_LOCAL_KEYWORD_TAKEOFF", "").strip() + if _kw_raw: + self._local_keyword_takeoff_enabled = _kw_raw.lower() in ( + "1", + "true", + "yes", + ) + else: + _ac = ( + SYSTEM_ASSISTANT_CONFIG + if isinstance(SYSTEM_ASSISTANT_CONFIG, dict) + else {} + ) + self._local_keyword_takeoff_enabled = bool( + _ac.get("local_keyword_takeoff_enabled", False) + ) + self._skip_model_preload = skip_model_preload or os.environ.get( + "ROCKET_SKIP_MODEL_PRELOAD", "" + ).lower() in ("1", "true", "yes") + + cv = SYSTEM_CLOUD_VOICE_CONFIG if isinstance(SYSTEM_CLOUD_VOICE_CONFIG, dict) else {} + env_cloud = os.environ.get("ROCKET_CLOUD_VOICE", "").lower() in ( + "1", + "true", + "yes", + ) + self._cloud_voice_enabled = bool(env_cloud or cv.get("enabled")) + self._cloud_fallback_local = os.environ.get( + "ROCKET_CLOUD_FALLBACK_LOCAL", "" + ).lower() not in ("0", "false", "no") and bool( + cv.get("fallback_to_local", True) + ) + # 唤醒词仅在 IDLE 由命令线程强制;ONE_SHOT_LISTEN 整句直接上行或处理,不要求句内唤醒词。 + try: + self._listen_silence_timeout_sec = max( + 0.5, + float( + os.environ.get("ROCKET_PROMPT_LISTEN_TIMEOUT_SEC") + or cv.get("listen_silence_timeout_sec") + or 5.0 + ), + ) + except ValueError: + self._listen_silence_timeout_sec = 5.0 + try: + self._post_cue_mic_mute_ms = float( + os.environ.get("ROCKET_POST_CUE_MIC_MUTE_MS") + or cv.get("post_cue_mic_mute_ms") + or 200.0 + ) + except ValueError: + self._post_cue_mic_mute_ms = 200.0 + self._post_cue_mic_mute_ms = max(0.0, min(2000.0, self._post_cue_mic_mute_ms)) + try: + self._segment_cue_duration_ms = float( + os.environ.get("ROCKET_SEGMENT_CUE_DURATION_MS") + or cv.get("segment_cue_duration_ms") + or 120.0 + ) + except ValueError: + self._segment_cue_duration_ms = 120.0 + 
self._segment_cue_duration_ms = max(20.0, min(500.0, self._segment_cue_duration_ms)) + ws_url = (os.environ.get("ROCKET_CLOUD_WS_URL") or cv.get("server_url") or "").strip() + auth_tok = ( + os.environ.get("ROCKET_CLOUD_AUTH_TOKEN") or cv.get("auth_token") or "" + ).strip() + dev_id = ( + os.environ.get("ROCKET_CLOUD_DEVICE_ID") or cv.get("device_id") or "drone-001" + ).strip() + self._cloud_client = None + self._cloud_remote_tts_for_local = False + if self._cloud_voice_enabled: + if ws_url and auth_tok: + from voice_drone.core.cloud_voice_client import CloudVoiceClient + + self._cloud_client = CloudVoiceClient( + server_url=ws_url, + auth_token=auth_tok, + device_id=dev_id, + recv_timeout=float(cv.get("timeout") or 120), + session_client_extensions=dict(SYSTEM_CLOUD_VOICE_PX4_CONTEXT) + if SYSTEM_CLOUD_VOICE_PX4_CONTEXT + else None, + ) + _env_rt = os.environ.get("ROCKET_CLOUD_REMOTE_TTS", "").strip().lower() + if _env_rt in ("0", "false", "no"): + self._cloud_remote_tts_for_local = False + elif _env_rt in ("1", "true", "yes"): + self._cloud_remote_tts_for_local = True + else: + self._cloud_remote_tts_for_local = bool( + cv.get("remote_tts_for_local", True) + ) + print( + f"[云端] 已启用 WebSocket 对话: {ws_url} device_id={dev_id}", + flush=True, + ) + if self._cloud_remote_tts_for_local: + print( + "[云端] 本地文案播报将走 tts.synthesize(失败回退 Kokoro)。", + flush=True, + ) + print( + f"[云端] Fun-ASR 上行 turn.audio.*;仅待机时说唤醒词;" + f"滴声后累计静默 {self._listen_silence_timeout_sec:.1f}s(低于 yaml energy_vad_rms_low 才计);" + f"断句提示 {self._segment_cue_duration_ms:.0f}ms、消抖 {self._post_cue_mic_mute_ms:.0f}ms。", + flush=True, + ) + else: + logger.warning("cloud_voice 已启用但缺少 server_url/auth_token,将使用本地 LLM") + self._cloud_voice_enabled = False + + self._wake_flow_lock = threading.Lock() + self._wake_phase: int = int(_WakeFlowPhase.IDLE) + self._greeting_done = threading.Event() + self._playback_batch_is_greeting = False + self._pending_finish_wake_cycle_after_tts = False + 
    def _cancel_prompt_listen_timer(self) -> None:
        """Stop the post-beep silence accumulation (used on flight-control
        handoff, wake-cycle end, or before starting a PCM uplink)."""
        self._prompt_listen_watch_armed = False
        self._prompt_silence_accum_sec = 0.0
self._enqueue_llm_speak(MSG_PROMPT_LISTEN_TIMEOUT) + self._pending_finish_wake_cycle_after_tts = True + + def _tick_prompt_listen_silence_accum(self, processed_chunk: np.ndarray) -> None: + if not self._prompt_listen_watch_armed or self._cloud_client is None: + return + with self._wake_flow_lock: + if self._wake_phase != int(_WakeFlowPhase.ONE_SHOT_LISTEN): + return + rms = self._int16_chunk_rms(processed_chunk) + dt = float(len(processed_chunk)) / float(self.audio_capture.sample_rate) + speaking = ( + self._ev_speaking + if self._use_energy_vad + else self.vad.is_speaking + ) + if speaking or rms >= self._energy_rms_low: + self._prompt_silence_accum_sec = 0.0 + return + self._prompt_silence_accum_sec += dt + if self._prompt_silence_accum_sec >= self._listen_silence_timeout_sec: + try: + self._on_prompt_listen_timeout() + except Exception as e: # noqa: BLE001 + logger.error("PROMPT_LISTEN 静默超时处理异常: %s", e, exc_info=True) + + def _on_vad_speech_start_prompt_listen(self) -> None: + """VAD 判「开始说话」时清零静默累计(v1 §4,与 RMS≥rms_low 并行)。""" + if self._cloud_client is None: + return + with self._wake_flow_lock: + if self._wake_phase != int(_WakeFlowPhase.ONE_SHOT_LISTEN): + return + self._prompt_silence_accum_sec = 0.0 + + def _submit_concatenated_speech_to_stt(self) -> None: + """在唤醒/一问一答流程中节流 VAD:避免问候或云端推理时继续向 STT 积压整句。""" + allow_greeting_stt = os.environ.get( + "ROCKET_VAD_STT_DURING_GREETING", "" + ).lower() in ("1", "true", "yes") + with self._wake_flow_lock: + phase = self._wake_phase + if phase == int(_WakeFlowPhase.GREETING_WAIT) and not allow_greeting_stt: + with self.speech_buffer_lock: + self.speech_buffer.clear() + if os.environ.get("ROCKET_PRINT_VAD", "").lower() in ( + "1", + "true", + "yes", + ): + print( + "[VAD] 问候播放中,本段不送 STT(说完问候后再说指令;" + "若需在问候同时识别请设 ROCKET_VAD_STT_DURING_GREETING=1", + flush=True, + ) + return + if phase == int(_WakeFlowPhase.LLM_BUSY): + with self.speech_buffer_lock: + self.speech_buffer.clear() + if os.environ.get("ROCKET_PRINT_VAD", 
"").lower() in ( + "1", + "true", + "yes", + ): + print( + "[VAD] 大模型/云端处理中,本段不送 STT(请等本轮播报结束后再说)", + flush=True, + ) + return + if ( + self._cloud_client is not None + and phase == int(_WakeFlowPhase.ONE_SHOT_LISTEN) + ): + if len(self.speech_buffer) == 0: + return + speech_audio = np.concatenate(self.speech_buffer) + self.speech_buffer.clear() + min_samples = int(self.audio_capture.sample_rate * 0.5) + if len(speech_audio) >= min_samples: + try: + self.command_queue.put( + ( + _PCM_TURN_MARKER, + speech_audio.copy(), + int(self.audio_capture.sample_rate), + ), + block=False, + ) + if os.environ.get("ROCKET_PRINT_VAD", "").lower() in ( + "1", + "true", + "yes", + ): + print( + f"[VAD] turn.audio 已排队,{len(speech_audio)} 采样点" + f"(≈{len(speech_audio) / float(self.audio_capture.sample_rate):.2f}s)", + flush=True, + ) + except queue.Full: + logger.warning("命令队列已满,跳过 PCM 上行") + elif os.environ.get("ROCKET_PRINT_VAD", "").lower() in ( + "1", + "true", + "yes", + ): + print( + f"[VAD] 语音段太短已丢弃({len(speech_audio)} < {min_samples} 采样)", + flush=True, + ) + return + super()._submit_concatenated_speech_to_stt() + + def _llm_tts_output_device(self) -> str | int | None: + raw = os.environ.get("ROCKET_TTS_DEVICE", "").strip() + if raw.isdigit(): + return int(raw) + if raw: + return raw + return None + + def _before_audio_iteration(self) -> None: + self._drain_mic_ops() + super()._before_audio_iteration() + self._drain_llm_playback_queue() + + def _drain_mic_ops(self) -> None: + """主线程:执行命令线程请求的麦克风流 stop/start。""" + while True: + try: + op = self._mic_op_queue.get_nowait() + except queue.Empty: + break + try: + if op == "stop": + if self.audio_capture.stream is not None: + self.audio_capture.stop_stream() + elif op == "start" and self.running: + if self.audio_capture.stream is None: + self.audio_capture.start_stream() + self.vad.reset() + with self.speech_buffer_lock: + self.speech_buffer.clear() + self.pre_speech_buffer.clear() + except Exception as e: # noqa: BLE001 + 
logger.warning("麦克风流控制失败 (%r): %s", op, e) + + def _finish_wake_cycle(self) -> None: + self._cancel_prompt_listen_timer() + self._cancel_flight_confirm_timer() + with self._flight_confirm_timer_lock: + self._pending_flight_confirm = None + self._pending_flight_confirm_after_tts = False + self._pending_finish_wake_cycle_after_tts = False + with self._wake_flow_lock: + self._wake_phase = int(_WakeFlowPhase.IDLE) + self._reset_llm_history() + print("[唤醒] 本轮结束。请说「无人机」再次唤醒。", flush=True) + + def _reset_llm_history(self) -> None: + with self._chat_session_lock: + self._llm_messages.clear() + self._chat_session_until = 0.0 + + def _flush_llm_playback_queue_silent(self) -> None: + """丢弃 LLM 播报队列(无日志);新一轮唤醒前清空,避免与问候语或上一轮残段叠播。""" + while True: + try: + self._llm_playback_queue.get_nowait() + except queue.Empty: + break + + def _prepare_wake_session_resources(self) -> None: + """新一轮唤醒:清空对话状态、播报队列与待 STT 段(问候/快路径共用)。""" + self._reset_llm_history() + self._flush_llm_playback_queue_silent() + self.discard_pending_stt_segments() + + def _recover_from_cloud_failure( + self, + user_msg: str, + *, + finish_wake_after_tts: bool, + idle_speak: str, + ) -> None: + """云端 run_turn 失败后:按需回退本地 LLM 或播一句占位。""" + if self._cloud_fallback_local: + print("[云端] 回退本地 LLM…", flush=True) + self._handle_llm_turn_local(user_msg, finish_wake_after_tts=finish_wake_after_tts) + return + self._enqueue_llm_speak(idle_speak) + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + + def _begin_wake_cycle(self, staged_followup: str | None) -> None: + """命中唤醒后:排队问候语,并在主线程播完后由 _after_greeting_pipeline 继续。""" + with self._wake_flow_lock: + if self._wake_phase != int(_WakeFlowPhase.IDLE): + logger.info( + "唤醒忽略:当前非 IDLE(phase=%s),不重复排队问候", + _WakeFlowPhase(self._wake_phase).name, + ) + return + self._wake_phase = int(_WakeFlowPhase.GREETING_WAIT) + self._prepare_wake_session_resources() + s = (staged_followup or "").strip() + self._staged_one_shot_after_greeting = s if s else None + 
def _wake_fast_path_process_follow(self, follow: str) -> bool:
    """Wake word and command arrived in one utterance: skip greeting and the
    prompt beep, clear the queue and process the command directly.

    Returns True when the fast path handled the utterance.
    """
    follow = (follow or "").strip()
    if not follow:
        return False
    with self._wake_flow_lock:
        if self._wake_phase != int(_WakeFlowPhase.IDLE):
            logger.info(
                "唤醒连带指令忽略:当前非 IDLE(phase=%s)",
                _WakeFlowPhase(self._wake_phase).name,
            )
            return False
        self._wake_phase = int(_WakeFlowPhase.LLM_BUSY)
        self._prepare_wake_session_resources()
        self._staged_one_shot_after_greeting = None
        self._enqueue_wake_word_ack_beep()
    logger.info("唤醒含指令,跳过问候与提示音,直接处理: %s", follow[:120])
    self._process_one_shot_command(follow)
    return True

def _after_greeting_pipeline(self) -> None:
    """Background thread: wait for the greeting playback, then either run the
    staged fast-path command or open the one-shot listen window."""
    if not self._greeting_done.wait(timeout=120):
        logger.error("问候语播放超时,回到 IDLE")
        self._finish_wake_cycle()
        return
    self._greeting_done.clear()
    with self._wake_flow_lock:
        staged = self._staged_one_shot_after_greeting
        self._staged_one_shot_after_greeting = None
    if staged is not None:
        with self._wake_flow_lock:
            self._wake_phase = int(_WakeFlowPhase.LLM_BUSY)
        self._process_one_shot_command(staged)
        return
    with self._wake_flow_lock:
        self._wake_phase = int(_WakeFlowPhase.ONE_SHOT_LISTEN)
    print("[唤醒] 请说您的指令(一句)。", flush=True)
    self._arm_prompt_listen_timeout()
return + print(f"[指令] {user_msg}", flush=True) + try: + self._mic_op_queue.put_nowait("stop") + except queue.Full: + pass + time.sleep(0.12) + + _, params = self.text_preprocessor.preprocess_fast(user_msg) + if ( + self._local_keyword_takeoff_enabled + and params.command_keyword == "takeoff" + ): + threading.Thread( + target=self._run_takeoff_offboard_and_wavs, + daemon=True, + ).start() + self._finish_wake_cycle() + try: + self._mic_op_queue.put_nowait("start") + except queue.Full: + pass + return + + if self._llm_disabled and not self._cloud_voice_enabled: + print("[LLM] 已禁用(ROCKET_LLM_DISABLE)。", flush=True) + self._finish_wake_cycle() + try: + self._mic_op_queue.put_nowait("start") + except queue.Full: + pass + return + + self._handle_llm_turn( + user_msg, finish_wake_after_tts=(self._cloud_client is None) + ) + + @staticmethod + def _flight_payload_requests_takeoff(payload: dict) -> bool: + for a in payload.get("actions") or []: + if isinstance(a, dict) and a.get("type") == "takeoff": + return True + return False + + def _enqueue_llm_speak(self, line: str) -> None: + t = (line or "").strip() + if not t: + return + try: + self._llm_playback_queue.put(t, block=False) + except queue.Full: + logger.warning("LLM 播报队列已满,跳过: %s…", t[:40]) + + def _ensure_llm(self): + if self._llm is not None: + return self._llm + with self._model_warm_lock: + if self._llm is not None: + return self._llm + if not self._llm_model_path.is_file(): + logger.error("未找到 GGUF: %s", self._llm_model_path) + return None + logger.info("正在加载 LLM: %s", self._llm_model_path) + print("[LLM] 正在加载 Qwen(GGUF)…", flush=True) + self._llm = load_llama_qwen(self._llm_model_path, n_ctx=self._llm_ctx) + if self._llm is None: + logger.error("llama-cpp-python 未安装或加载失败") + else: + print("[LLM] Qwen 已载入。", flush=True) + return self._llm + + def _ensure_llm_tts(self): + if self._llm_tts_engine is not None: + return self._llm_tts_engine + with self._model_warm_lock: + if self._llm_tts_engine is not None: + return 
def _preload_llm_and_tts_if_enabled(self) -> None:
    """Preload models right after startup so the first dialog turn / TTS
    playback does not stall for many seconds."""
    if self._cloud_voice_enabled:
        # Cloud mode: dialog TTS arrives as cloud PCM, so the local Qwen is skipped.
        print(
            "[云端] 跳过本地 Qwen 预加载;对话 TTS 以云端 PCM 为主。",
            flush=True,
        )
        try:
            p = _resolve_wake_greeting_wav()
            if not p.is_file():
                if (
                    not self._llm_disabled
                    and not self._cloud_remote_tts_for_local
                ):
                    self._ensure_wake_greeting_wav_on_disk()
        except Exception as e:  # noqa: BLE001
            logger.debug("云端模式下预热问候 WAV 跳过: %s", e)
        if self._cloud_remote_tts_for_local:
            # Local string prompts go through cloud tts.synthesize; Kokoro is
            # only loaded on demand if that path fails.
            print(
                "[云端] 本地字符串播报由 tts.synthesize 提供,跳过 Kokoro 预加载"
                "(失败时会临场加载 Kokoro)。",
                flush=True,
            )
            return
        # Confirm-timeout / cancel / cloud-fallback prompts still use local
        # Kokoro; load it once at startup so those prompts don't cold-start
        # the model (several seconds of stall).
        if self._skip_model_preload:
            print(
                "[云端] 已跳过 Kokoro 预加载(--no-preload / ROCKET_SKIP_MODEL_PRELOAD);"
                "首次本地提示时再加载。",
                flush=True,
            )
        else:
            t0 = time.monotonic()
            try:
                print(
                    "[LLM] 云端模式:预加载 Kokoro(确认超时/取消等本地语音)…",
                    flush=True,
                )
                self._ensure_llm_tts()
            except Exception as e:  # noqa: BLE001
                logger.warning(
                    "云端模式 Kokoro 预加载失败(将在首次本地播报时重试): %s",
                    e,
                    exc_info=True,
                )
                print(f"[LLM] Kokoro 预加载失败: {e}", flush=True)
            else:
                dt = time.monotonic() - t0
                print(f"[LLM] Kokoro 预加载完成(约 {dt:.1f}s)。", flush=True)
        return

    if self._llm_disabled or self._skip_model_preload:
        if self._skip_model_preload and not self._llm_disabled:
            print(
                "[LLM] 已跳过预加载(--no-preload 或 ROCKET_SKIP_MODEL_PRELOAD),将在首次使用时加载。",
                flush=True,
            )
        return
    if not self._llm_model_path.is_file():
        print(
            f"[LLM] 未找到 GGUF,跳过预加载: {self._llm_model_path}",
            flush=True,
        )
        return
    print(
        "[LLM] 预加载 Qwen + Kokoro(数十秒属正常,完成后的首轮对话会快很多)…",
        flush=True,
    )
    t0 = time.monotonic()
    try:
        if self._ensure_llm() is None:
            return
        self._ensure_llm_tts()
        self._ensure_wake_greeting_wav_on_disk()
    except Exception as e:  # noqa: BLE001
        logger.warning("预加载模型失败(将在首次使用时重试): %s", e, exc_info=True)
        print(f"[LLM] 预加载失败: {e}", flush=True)
        return
    dt = time.monotonic() - t0
    print(f"[LLM] 预加载完成(耗时约 {dt:.1f}s)。", flush=True)
def _ensure_wake_greeting_wav_on_disk(self) -> Path:
    """Synthesize and cache the greeting WAV once; later wake rounds only
    play the cached file via play_wav_path."""
    cached = _resolve_wake_greeting_wav()
    if cached.is_file():
        return cached
    try:
        cached.parent.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.warning("无法创建问候缓存目录 %s: %s", cached.parent, e)
        return cached
    try:
        engine = self._ensure_llm_tts()
        engine.synthesize_to_file(_WAKE_GREETING, str(cached))
        logger.info("已自动生成唤醒问候缓存(此后只播此文件): %s", cached)
        print(f"[TTS] 已写入问候缓存,下次起不再合成: {cached}", flush=True)
    except Exception as e:  # noqa: BLE001
        logger.warning(
            "自动生成问候 WAV 失败(需 scipy 写盘;将本次仍用实时合成): %s",
            e,
            exc_info=True,
        )
    return cached

def _play_wake_ready_beep(self, output_device: object | None) -> None:
    """Short beep after the greeting telling the user the listen window is open."""
    from voice_drone.core.tts import play_tts_audio

    if os.environ.get("ROCKET_WAKE_PROMPT_BEEP", "1").lower() in (
        "0",
        "false",
        "no",
    ):
        return
    sr = 24000

    def _env_float(name: str, default: float) -> float:
        # Malformed env values fall back to the default, same as the unset case.
        try:
            return float(os.environ.get(name, str(default)))
        except ValueError:
            return default

    dur = min(0.25, max(0.04, _env_float("ROCKET_WAKE_BEEP_SEC", 0.11)))
    hz = _env_float("ROCKET_WAKE_BEEP_HZ", 988.0)
    amp = min(0.45, max(0.05, _env_float("ROCKET_WAKE_BEEP_GAIN", 0.22)))
    audio = _synthesize_ready_beep(
        sr, duration_sec=dur, frequency_hz=hz, amplitude=amp
    )
    try:
        play_tts_audio(audio, sr, output_device=output_device)
        print("[唤醒] 提示音已播,请说指令。", flush=True)
    except Exception as e:  # noqa: BLE001
        logger.debug("唤醒提示音播放跳过: %s", e)
"0", + "false", + "no", + ): + return + try: + self._llm_playback_queue.put_nowait(_WAKE_HIT_BEEP_TAG) + except queue.Full: + logger.warning("播报队列已满,跳过唤醒确认短音") + + def _play_wake_word_hit_beep(self, output_device: object | None) -> None: + """刚识别到唤醒词时的一声「滴」,默认略短于问候后的滴声。""" + from voice_drone.core.tts import play_tts_audio + + if os.environ.get("ROCKET_WAKE_ACK_BEEP", "1").lower() in ( + "0", + "false", + "no", + ): + return + sr = 24000 + try: + raw = os.environ.get("ROCKET_WAKE_ACK_BEEP_SEC", "").strip() + if raw: + dur = float(raw) + else: + dur = float(os.environ.get("ROCKET_WAKE_BEEP_SEC", "0.11")) * 0.72 + except ValueError: + dur = 0.08 + dur = max(0.04, min(0.25, dur)) + try: + raw_h = os.environ.get("ROCKET_WAKE_ACK_BEEP_HZ", "").strip() + hz = float(raw_h) if raw_h else float(os.environ.get("ROCKET_WAKE_BEEP_HZ", "988")) + except ValueError: + hz = 1100.0 + try: + raw_g = os.environ.get("ROCKET_WAKE_ACK_BEEP_GAIN", "").strip() + amp = float(raw_g) if raw_g else float(os.environ.get("ROCKET_WAKE_BEEP_GAIN", "0.22")) + except ValueError: + amp = 0.22 + amp = max(0.05, min(0.45, amp)) + audio = _synthesize_ready_beep( + sr, duration_sec=dur, frequency_hz=hz, amplitude=amp + ) + try: + play_tts_audio(audio, sr, output_device=output_device) + except Exception as e: # noqa: BLE001 + logger.debug("唤醒确认短音播放失败: %s", e) + return + print("[唤醒] 确认短音已播。", flush=True) + + def _try_play_line_via_cloud_tts(self, s: str, dev: object | None) -> bool: + """docs/API.md §3.3 tts.synthesize:成功播放返回 True,否则 False(调用方回退 Kokoro)。""" + if not self._cloud_remote_tts_for_local or self._cloud_client is None: + return False + txt = (s or "").strip() + if not txt: + return False + from voice_drone.core.cloud_voice_client import CloudVoiceError + from voice_drone.core.tts import play_tts_audio + + t0 = time.monotonic() + try: + out = self._cloud_client.run_tts_synthesize(txt) + except CloudVoiceError as e: + logger.warning("云端 tts.synthesize 失败: %s", e) + return False + except Exception 
def _try_play_line_via_cloud_tts(self, s: str, dev: object | None) -> bool:
    """Cloud tts.synthesize (docs/API.md §3.3): play the returned PCM and
    return True; return False on any failure so the caller falls back to
    local Kokoro."""
    if not self._cloud_remote_tts_for_local or self._cloud_client is None:
        return False
    txt = (s or "").strip()
    if not txt:
        return False
    from voice_drone.core.cloud_voice_client import CloudVoiceError
    from voice_drone.core.tts import play_tts_audio

    t0 = time.monotonic()
    try:
        out = self._cloud_client.run_tts_synthesize(txt)
    except CloudVoiceError as e:
        logger.warning("云端 tts.synthesize 失败: %s", e)
        return False
    except Exception as e:  # noqa: BLE001
        logger.warning("云端 tts.synthesize 异常: %s", e, exc_info=True)
        return False
    pcm = out.get("pcm")
    try:
        sr = int(out.get("sample_rate_hz") or 24000)
    except (TypeError, ValueError):
        sr = 24000  # defensive default when the server omits/garbles the rate
    if pcm is None or np.asarray(pcm).size == 0:
        logger.warning("云端 tts.synthesize 返回空 PCM")
        return False
    pcm_i16 = np.asarray(pcm, dtype=np.int16).reshape(-1)
    logger.info(
        "云端 tts.synthesize: samples=%s int16_max_abs=%s elapsed=%.3fs",
        pcm_i16.size,
        int(np.max(np.abs(pcm_i16))),
        time.monotonic() - t0,
    )
    # int16 -> float32 in [-1, 1) for the playback helper.
    audio_f32 = pcm_i16.astype(np.float32) / 32768.0
    try:
        play_tts_audio(audio_f32, sr, output_device=dev)
    except Exception as e:  # noqa: BLE001
        logger.warning("播放云端 tts.synthesize 结果失败: %s", e, exc_info=True)
        return False
    return True

def _play_segment_end_cue(self, dev: object | None) -> None:
    """Very short cue after end-of-segment (§5); not counted as the chitchat
    re-prompt beep."""
    from voice_drone.core.tts import play_tts_audio

    sr = 24000
    dur = self._segment_cue_duration_ms / 1000.0
    dur = max(0.02, min(0.5, dur))
    audio = _synthesize_ready_beep(
        sr,
        duration_sec=dur,
        frequency_hz=1420.0,
        amplitude=0.18,
    )
    try:
        play_tts_audio(audio, sr, output_device=dev)
    except Exception as e:  # noqa: BLE001
        logger.debug("断句提示音: %s", e)

def _play_chitchat_reprompt_beep(self, dev: object | None) -> None:
    """One more beep after chitchat TTS finishes, before the next
    PROMPT_LISTEN round."""
    self._play_wake_word_hit_beep(dev)
def _handle_pcm_uplink_turn(self, pcm: np.ndarray, sample_rate_hz: int) -> None:
    """SEGMENT_END: play the segment cue, debounce the mic, then run one
    turn.audio round against the cloud."""
    with self._wake_flow_lock:
        # NOTE(review): assumed the lock only guards this phase check; the
        # mangled source makes the extent ambiguous — confirm vs original.
        if self._wake_phase != int(_WakeFlowPhase.ONE_SHOT_LISTEN):
            logger.debug("PCM 上行忽略:当前非 PROMPT_LISTEN")
            return
    self._cancel_prompt_listen_timer()
    try:
        self._mic_op_queue.put_nowait("stop")
    except queue.Full:
        pass
    self._segment_cue_done.clear()
    try:
        self._llm_playback_queue.put_nowait(_SEGMENT_END_CUE_TAG)
    except queue.Full:
        logger.error("播报队列满,无法播断句提示")
        try:
            self._mic_op_queue.put_nowait("start")
        except queue.Full:
            pass
        return
    # Wait until the main thread has actually played the cue before the
    # post-cue mute, so the mic does not pick up the cue itself.
    if not self._segment_cue_done.wait(timeout=15.0):
        logger.error("断句提示音同步超时")
        try:
            self._mic_op_queue.put_nowait("start")
        except queue.Full:
            pass
        return
    time.sleep(self._post_cue_mic_mute_ms / 1000.0)
    with self._wake_flow_lock:
        self._wake_phase = int(_WakeFlowPhase.LLM_BUSY)
    self._handle_llm_turn_cloud_pcm(
        pcm, sample_rate_hz, finish_wake_after_tts=False
    )
def _drain_llm_playback_queue(self, recover_mic: bool = True) -> None:
    """Main thread: play everything queued for LLM output — beep tags, cloud
    PCM tuples and plain text lines — then run the deferred post-TTS state
    transitions (flight confirm window / chitchat re-prompt / finish wake)."""
    from voice_drone.core.tts import play_tts_audio, play_wav_path

    lines: list[str] = []
    while True:
        try:
            lines.append(self._llm_playback_queue.get_nowait())
        except queue.Empty:
            break
    if not lines:
        # With streamed per-segment TTS the final drain may empty the queue
        # before _finalize_llm_turn sets _pending_finish_wake_cycle_after_tts;
        # catch up on finishing the wake round here.
        # NOTE: the flight-confirm window must only open after the batch that
        # contains this round's cloud TTS has played — do NOT consume
        # _pending_flight_confirm_after_tts here, or the main thread could
        # drain an empty queue before the PCM is enqueued, begin_confirm
        # early and clear the flag while the command thread later sets
        # _pending_finish_wake_cycle.
        if self._pending_finish_wake_cycle_after_tts:
            self._pending_finish_wake_cycle_after_tts = False
            self._finish_wake_cycle()
        return
    greeting_batch = self._playback_batch_is_greeting
    self._playback_batch_is_greeting = False
    mic_stopped = False
    if self.ack_pause_mic_for_playback:
        # Drop pending STT once more before muting: VAD may have submitted
        # segments between wake and this drain.
        self.discard_pending_stt_segments()
        try:
            self.audio_capture.stop_stream()
            mic_stopped = True
        except Exception as e:  # noqa: BLE001
            logger.warning("暂停麦克风失败: %s", e)
    try:
        tts = None  # Kokoro engine, loaded lazily on first local synthesis
        dev = self._llm_tts_output_device()
        for line in lines:
            if line == _WAKE_HIT_BEEP_TAG:
                self._play_wake_word_hit_beep(dev)
                continue
            if line == _SEGMENT_END_CUE_TAG:
                self._play_segment_end_cue(dev)
                self._segment_cue_done.set()
                continue
            if line == _CHITCHAT_REPROMPT_BEEP_TAG:
                self._play_chitchat_reprompt_beep(dev)
                self._arm_prompt_listen_timeout()
                continue
            if (
                isinstance(line, tuple)
                and len(line) == 3
                and line[0] == _CLOUD_PCM_TAG
            ):
                # Cloud PCM tuple: (_CLOUD_PCM_TAG, int16 samples, sample rate)
                _, pcm_i16, sr_cloud = line
                try:
                    pcm_i16 = np.asarray(pcm_i16, dtype=np.int16).reshape(-1)
                    if pcm_i16.size == 0:
                        continue
                    dbg_max = int(np.max(np.abs(pcm_i16)))
                    logger.info(
                        "云端 PCM 解码: samples=%s int16_max_abs=%s (若 max_abs=0 则为全零或"
                        "协议/端序与云端不一致;请在服务端导出同段 WAV 对比)",
                        pcm_i16.size,
                        dbg_max,
                    )
                    audio_f32 = pcm_i16.astype(np.float32) / 32768.0
                    t_play0 = time.monotonic()
                    play_tts_audio(
                        audio_f32, int(sr_cloud), output_device=dev
                    )
                    print(
                        f"[计时] 云端 TTS 播放 {time.monotonic() - t_play0:.3f}s "
                        f"({pcm_i16.size / int(sr_cloud):.2f}s 音频)",
                        flush=True,
                    )
                    print("[LLM] 已播报。", flush=True)
                except Exception as e:  # noqa: BLE001
                    logger.warning("云端 PCM 播放失败: %s", e, exc_info=True)
                continue

            s = (line or "").strip()
            if not s:
                continue
            try:
                if s == _WAKE_GREETING:
                    # Greeting: cloud TTS first, then the cached WAV, then
                    # live Kokoro synthesis as the last resort.
                    t_w0 = time.monotonic()
                    cloud_ok = self._try_play_line_via_cloud_tts(s, dev)
                    if not cloud_ok:
                        greet_wav = self._ensure_wake_greeting_wav_on_disk()
                        if greet_wav.is_file():
                            play_wav_path(greet_wav, output_device=dev)
                            print(
                                f"[计时] TTS 预生成问候 WAV 播完,耗时 "
                                f"{time.monotonic() - t_w0:.3f}s",
                                flush=True,
                            )
                        else:
                            if tts is None:
                                tts = self._ensure_llm_tts()
                            logger.info("TTS: 开始合成并播放: %r", s)
                            t_syn0 = time.monotonic()
                            audio, sr = tts.synthesize(s)
                            t_syn1 = time.monotonic()
                            play_tts_audio(audio, sr, output_device=dev)
                            t_play1 = time.monotonic()
                            print(
                                f"[计时] TTS 合成 {t_syn1 - t_syn0:.3f}s,"
                                f"播放 {t_play1 - t_syn1:.3f}s"
                                f"(本段合计 {t_play1 - t_syn0:.3f}s)",
                                flush=True,
                            )
                            logger.info("TTS: 播放完成")
                    else:
                        print(
                            f"[计时] 云端 tts.synthesize 问候,耗时 "
                            f"{time.monotonic() - t_w0:.3f}s",
                            flush=True,
                        )
                    if greeting_batch:
                        self._play_wake_ready_beep(dev)
                else:
                    # Ordinary text line: cloud TTS first, else local Kokoro.
                    t_line0 = time.monotonic()
                    cloud_ok = self._try_play_line_via_cloud_tts(s, dev)
                    if not cloud_ok:
                        if tts is None:
                            tts = self._ensure_llm_tts()
                        logger.info("TTS: 开始合成并播放: %r", s)
                        t_syn0 = time.monotonic()
                        audio, sr = tts.synthesize(s)
                        t_syn1 = time.monotonic()
                        play_tts_audio(audio, sr, output_device=dev)
                        t_play1 = time.monotonic()
                        print(
                            f"[计时] TTS 合成 {t_syn1 - t_syn0:.3f}s,"
                            f"播放 {t_play1 - t_syn1:.3f}s"
                            f"(本段合计 {t_play1 - t_syn0:.3f}s)",
                            flush=True,
                        )
                        logger.info("TTS: 播放完成")
                    else:
                        print(
                            f"[计时] 云端 tts.synthesize 本段合计 "
                            f"{time.monotonic() - t_line0:.3f}s",
                            flush=True,
                        )
                print("[LLM] 已播报。", flush=True)
            except Exception as e:  # noqa: BLE001
                logger.warning("LLM 播报失败: %s", e, exc_info=True)
    finally:
        if mic_stopped and recover_mic:
            try:
                self.audio_capture.start_stream()
                try:
                    settle_ms = float(
                        os.environ.get("ROCKET_MIC_RESTART_SETTLE_MS", "150")
                    )
                except ValueError:
                    settle_ms = 150.0
                settle_ms = max(0.0, min(2000.0, settle_ms))
                if settle_ms > 0:
                    time.sleep(settle_ms / 1000.0)
                try:
                    self.audio_preprocessor.reset()
                except Exception as e:  # noqa: BLE001
                    logger.debug("audio_preprocessor.reset: %s", e)
                self.vad.reset()
                with self.speech_buffer_lock:
                    self.speech_buffer.clear()
                    self.pre_speech_buffer.clear()
            except Exception as e:  # noqa: BLE001
                logger.error("麦克风恢复失败: %s", e)
        if greeting_batch:
            self._greeting_done.set()
        # Deferred post-TTS transitions, highest priority first.
        if self._pending_flight_confirm_after_tts:
            self._pending_flight_confirm_after_tts = False
            self._begin_flight_confirm_listen()
        elif self._pending_chitchat_reprompt_after_tts:
            self._pending_chitchat_reprompt_after_tts = False
            with self._wake_flow_lock:
                self._wake_phase = int(_WakeFlowPhase.ONE_SHOT_LISTEN)
            try:
                self._llm_playback_queue.put_nowait(_CHITCHAT_REPROMPT_BEEP_TAG)
            except queue.Full:
                logger.warning("播报队列已满,跳过闲聊再滴声")
        elif self._pending_finish_wake_cycle_after_tts:
            self._pending_finish_wake_cycle_after_tts = False
            self._finish_wake_cycle()
self._llm_playback_queue.get_nowait() + dropped += 1 + except queue.Empty: + break + if dropped: + logger.info("退出:已丢弃 %s 条待播 LLM 语音", dropped) + + @staticmethod + def _chunk_delta_text(chunk: object) -> str: + if not isinstance(chunk, dict): + return "" + choices = chunk.get("choices") or [] + if not choices: + return "" + c0 = choices[0] + d = c0.get("delta") if isinstance(c0, dict) else None + if not isinstance(d, dict): + d = c0.get("message") if isinstance(c0, dict) else None + if not isinstance(d, dict): + return "" + raw = d.get("content") + return raw if isinstance(raw, str) else "" + + def _enqueue_segment_capped(self, seg: str, budget: int) -> int: + seg = (seg or "").strip() + if not seg or budget <= 0: + return budget + if len(seg) <= budget: + self._enqueue_llm_speak(seg) + return budget - len(seg) + self._enqueue_llm_speak(seg[: max(0, budget - 1)] + "…") + return 0 + + def _finalize_llm_turn( + self, + reply: str, + finish_wake_after_tts: bool, + *, + streamed_chat: bool, + ) -> None: + if not reply: + self._enqueue_llm_speak("我没听清,请再说一遍。") + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + mode, payload = parse_flight_intent_reply(reply) + with self._chat_session_lock: + self._llm_messages.append({"role": "assistant", "content": reply}) + + print(f"[LLM] 判定={mode}", flush=True) + print(f"[LLM] 原文: {reply[:500]}{'…' if len(reply) > 500 else ''}", flush=True) + + if streamed_chat: + if payload is not None and self._flight_payload_requests_takeoff(payload): + threading.Thread( + target=self._run_takeoff_offboard_and_wavs, + daemon=True, + ).start() + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + + if payload is not None: + to_say = str(payload.get("summary") or "好的。").strip() + if self._flight_payload_requests_takeoff(payload): + threading.Thread( + target=self._run_takeoff_offboard_and_wavs, + daemon=True, + ).start() + else: + to_say = reply.strip() + + if len(to_say) > 
def _finalize_llm_turn(
    self,
    reply: str,
    finish_wake_after_tts: bool,
    *,
    streamed_chat: bool,
) -> None:
    """After one local LLM turn: record the reply, route flight-intent
    payloads (takeoff spawns the offboard thread), and queue the spoken
    summary unless it already streamed segment-by-segment."""
    if not reply:
        self._enqueue_llm_speak("我没听清,请再说一遍。")
        if finish_wake_after_tts:
            self._pending_finish_wake_cycle_after_tts = True
        return
    mode, payload = parse_flight_intent_reply(reply)
    with self._chat_session_lock:
        self._llm_messages.append({"role": "assistant", "content": reply})

    print(f"[LLM] 判定={mode}", flush=True)
    print(f"[LLM] 原文: {reply[:500]}{'…' if len(reply) > 500 else ''}", flush=True)

    if streamed_chat:
        # Text already spoken via streaming segments; only act on flight intent.
        if payload is not None and self._flight_payload_requests_takeoff(payload):
            threading.Thread(
                target=self._run_takeoff_offboard_and_wavs,
                daemon=True,
            ).start()
        if finish_wake_after_tts:
            self._pending_finish_wake_cycle_after_tts = True
        return

    if payload is not None:
        to_say = str(payload.get("summary") or "好的。").strip()
        if self._flight_payload_requests_takeoff(payload):
            threading.Thread(
                target=self._run_takeoff_offboard_and_wavs,
                daemon=True,
            ).start()
    else:
        to_say = reply.strip()

    # Cap the spoken text so TTS stays bounded.
    if len(to_say) > self._llm_tts_max_chars:
        to_say = to_say[: self._llm_tts_max_chars] + "…"
    self._enqueue_llm_speak(to_say)
    if finish_wake_after_tts:
        self._pending_finish_wake_cycle_after_tts = True

def _enqueue_cloud_pcm_playback(
    self, pcm_int16: np.ndarray, sample_rate_hz: int
) -> None:
    """Queue one cloud PCM clip (int16) for main-thread playback."""
    if pcm_int16 is None or np.asarray(pcm_int16).size == 0:
        return
    try:
        self._llm_playback_queue.put(
            (_CLOUD_PCM_TAG, np.asarray(pcm_int16, dtype=np.int16), int(sample_rate_hz)),
            block=False,
        )
    except queue.Full:
        logger.warning("LLM 播报队列已满,跳过云端 PCM")

def _send_socket_command(self, cmd: Command) -> bool:
    """Fill defaults and send one command over the socket with retry; returns
    True on delivery."""
    cmd.fill_defaults()
    if self.socket_client.send_command_with_retry(cmd):
        logger.info("✅ Socket 已发送: %s", cmd.command)
        return True
    logger.warning("Socket 未送达(已达 max_retries): %s", cmd.command)
    return False
def _publish_flight_intent_to_ros_bridge(self, flight: dict) -> None:
    """Validate the flight_intent, then publish it from a subprocess to a ROS
    std_msgs/String topic (companion bridge ~input)."""
    _parsed, errors = parse_flight_intent_dict(flight)
    if errors or _parsed is None:
        logger.warning("[飞控-ROS桥] flight_intent 校验失败,未发布: %s", errors)
        return
    setup = os.environ.get(
        "ROCKET_FLIGHT_BRIDGE_SETUP", "source /opt/ros/noetic/setup.bash"
    ).strip()
    topic = os.environ.get("ROCKET_FLIGHT_BRIDGE_TOPIC", "/input").strip() or "/input"
    wait_raw = os.environ.get("ROCKET_FLIGHT_BRIDGE_WAIT_SUB", "2").strip()
    try:
        wait_sub = float(wait_raw)
    except ValueError:
        wait_sub = 2.0

    root = str(_PROJECT_ROOT)
    body = json.dumps(flight, ensure_ascii=False)
    # Hand the JSON over via a temp file so the payload never hits the shell line.
    fd, tmp_path = tempfile.mkstemp(prefix="flight_intent_", suffix=".json", text=True)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(body)
    except OSError:
        try:
            os.close(fd)
        except OSError:
            pass
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        logger.warning("[飞控-ROS桥] 无法写入临时 JSON")
        return

    # PYTHONPATH must be APPENDED: setting it to only the project root would
    # shadow the /opt/ros/.../dist-packages path injected by the ROS setup
    # script and rospy would no longer import.
    cmd = (
        f"{setup} && cd {shlex.quote(root)} && "
        f"export PYTHONPATH={shlex.quote(root)}:$PYTHONPATH && "
        "python3 -m voice_drone.tools.publish_flight_intent_ros_once "
        f"--topic {shlex.quote(topic)} --wait-subscribers {wait_sub} "
        f"{shlex.quote(tmp_path)}"
    )
    try:
        r = subprocess.run(
            ["bash", "-lc", cmd],
            capture_output=True,
            text=True,
            timeout=60,
        )
    except subprocess.TimeoutExpired:
        logger.warning("[飞控-ROS桥] 子进程超时(>60s)")
        return
    except OSError as e:
        logger.warning("[飞控-ROS桥] 无法启动 bash: %s", e)
        return
    finally:
        # Always remove the temp file, even on timeout/launch failure.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    if r.returncode != 0:
        logger.warning(
            "[飞控-ROS桥] 发布失败 code=%s stderr=%s",
            r.returncode,
            (r.stderr or "").strip()[:800],
        )
    else:
        logger.info("[飞控-ROS桥] 已发布至 %s", topic)
def _run_cloud_flight_intent_sequence(self, flight: dict) -> None:
    """
    Execute a cloud flight_intent sequentially on a background thread
    (v1 validation; takeoff runs the offboard script, the rest go over the
    socket). When takeoff is present the offboard flow completes first, then
    hover/wait/land continue (fixes the earlier behavior where only takeoff
    fired and subsequent actions were dropped).
    """
    parsed, errors = parse_flight_intent_dict(flight)
    if errors:
        logger.warning("[飞控] flight_intent 校验失败: %s", errors)
        return
    tid = (parsed.trace_id or "").strip() or "-"
    logger.info("[飞控] 开始执行序列 trace_id=%s steps=%d", tid, len(parsed.actions))

    for step, action in enumerate(parsed.actions):
        if isinstance(action, ActionTakeoff):
            alt = action.args.relative_altitude_m
            if alt is not None:
                logger.info(
                    "[飞控] takeoff 请求相对高度 %.2fm(当前 offboard 脚本是否使用该参数请自行扩展)",
                    alt,
                )
            # Blocking: later actions wait for the offboard flow to finish.
            self._run_takeoff_offboard_and_wavs()
        elif isinstance(action, ActionLand):
            cmd = Command.create("land", self._get_next_sequence_id())
            self._send_socket_command(cmd)
        elif isinstance(action, ActionReturnHome):
            cmd = Command.create("return_home", self._get_next_sequence_id())
            self._send_socket_command(cmd)
        elif isinstance(action, (ActionHover, ActionHold)):
            cmd = Command.create("hover", self._get_next_sequence_id())
            self._send_socket_command(cmd)
        elif isinstance(action, ActionGoto):
            cmd, err = goto_action_to_command(action, self._get_next_sequence_id())
            if err:
                logger.warning("[飞控] step %d goto: %s", step, err)
                continue
            if cmd is not None:
                self._send_socket_command(cmd)
        elif isinstance(action, ActionWait):
            sec = float(action.args.seconds)
            logger.info("[飞控] step %d wait %.2fs", step, sec)
            time.sleep(sec)
        else:
            logger.warning("[飞控] step %d 未处理的动作类型: %r", step, action)

def _cancel_flight_confirm_timer(self) -> None:
    """Detach and cancel the confirm-window timer (cancel happens outside the
    lock; Timer.cancel is safe to call on a fired timer)."""
    with self._flight_confirm_timer_lock:
        t = self._flight_confirm_timer
        self._flight_confirm_timer = None
    if t is not None:
        try:
            t.cancel()
        except Exception:  # noqa: BLE001
            pass

def _begin_flight_confirm_listen(self) -> None:
    """After cloud TTS finished: open the verbal confirm window
    (cloud_voice_dialog_v1) with a timeout timer."""
    self._cancel_prompt_listen_timer()
    with self._flight_confirm_timer_lock:
        if self._pending_flight_confirm is None:
            logger.warning("[飞控] 无待确认意图,跳过确认窗")
            self._finish_wake_cycle()
            return
        cd = self._pending_flight_confirm["confirm"]
    # NOTE(review): _cancel_flight_confirm_timer re-acquires the same lock, so
    # it must run after releasing it here — assumed from the non-reentrant lock.
    timeout_sec = float(cd["timeout_sec"])
    phrases_repr = (cd["confirm_phrases"], cd["cancel_phrases"])
    self._cancel_flight_confirm_timer()
    with self._wake_flow_lock:
        self._wake_phase = int(_WakeFlowPhase.FLIGHT_CONFIRM_LISTEN)
    print(
        f"[飞控] 请口头确认 {phrases_repr[0]!r} 或取消 {phrases_repr[1]!r},"
        f"超时 {timeout_sec:.0f}s。",
        flush=True,
    )

    def _fire() -> None:
        try:
            self._on_flight_confirm_timeout()
        except Exception as e:  # noqa: BLE001
            logger.error("确认窗超时处理异常: %s", e, exc_info=True)

    with self._flight_confirm_timer_lock:
        self._flight_confirm_timer = threading.Timer(timeout_sec, _fire)
        self._flight_confirm_timer.daemon = True
        self._flight_confirm_timer.start()
def _on_flight_confirm_timeout(self) -> None:
    """Confirm-window timer fired: drop the pending intent and speak the
    timeout prompt."""
    with self._flight_confirm_timer_lock:
        if self._pending_flight_confirm is None:
            return
        self._pending_flight_confirm = None
        self._flight_confirm_timer = None
    logger.info("[飞控] 确认窗超时")
    self._enqueue_llm_speak(MSG_CONFIRM_TIMEOUT)
    self._pending_finish_wake_cycle_after_tts = True

def _handle_flight_confirm_text(self, raw: str) -> None:
    """Match one utterance against the confirm/cancel phrase lists of the
    pending flight intent; unmatched text is ignored (window stays open)."""
    utter = (raw or "").strip()
    if not utter:
        return
    norm = normalize_phrase_text(utter)
    print(f"[飞控-确认窗] {utter!r}", flush=True)

    action: str = "noop"
    fi_ok: dict | None = None
    t: threading.Timer | None = None
    with self._flight_confirm_timer_lock:
        pend = self._pending_flight_confirm
        if pend is None:
            return
        cd = pend["confirm"]
        # Cancel takes precedence over confirm when both match.
        cancel_hit = match_phrase_list(norm, cd["cancel_phrases"])
        confirm_hit = match_phrase_list(norm, cd["confirm_phrases"])
        if cancel_hit:
            action = "cancel"
            self._pending_flight_confirm = None
            t = self._flight_confirm_timer
            self._flight_confirm_timer = None
        elif confirm_hit:
            action = "confirm"
            fi_ok = pend["flight"]
            self._pending_flight_confirm = None
            t = self._flight_confirm_timer
            self._flight_confirm_timer = None
        else:
            logger.info("[飞控] 确认窗未命中短语,忽略: %s", utter[:80])
            return

    # Cancel the timer outside the lock to avoid re-entrancy.
    if t is not None:
        try:
            t.cancel()
        except Exception:  # noqa: BLE001
            pass

    if action == "cancel":
        logger.info("[飞控] 用户取消待执行意图")
        self._enqueue_llm_speak(MSG_CANCELLED)
        self._pending_finish_wake_cycle_after_tts = True
        return

    if action == "confirm" and fi_ok is not None:
        logger.info("[飞控] 用户已确认,开始执行 flight_intent")
        self._start_cloud_flight_execution(fi_ok)
        self._enqueue_llm_speak(MSG_CONFIRM_EXECUTING)
        self._pending_finish_wake_cycle_after_tts = True
def _start_cloud_flight_execution(self, fi: dict) -> None:
    """Run an already-validated flight_intent on a background thread, gated by
    ROCKET_CLOUD_EXECUTE_FLIGHT; ROCKET_FLIGHT_INTENT_ROS_BRIDGE selects the
    ROS bridge over the local socket sequence."""
    truthy = ("1", "true", "yes")
    if os.environ.get("ROCKET_CLOUD_EXECUTE_FLIGHT", "").lower() not in truthy:
        return
    use_bridge = os.environ.get("ROCKET_FLIGHT_INTENT_ROS_BRIDGE", "").lower() in truthy
    runner = (
        self._publish_flight_intent_to_ros_bridge
        if use_bridge
        else self._run_cloud_flight_intent_sequence
    )
    threading.Thread(target=runner, args=(fi,), daemon=True).start()

def _handle_llm_turn(
    self, user_msg: str, *, finish_wake_after_tts: bool = False
) -> None:
    """Route one user turn to the cloud dialog service when enabled and
    connected, otherwise to the local LLM."""
    if self._cloud_voice_enabled and self._cloud_client is not None:
        self._handle_llm_turn_cloud(user_msg, finish_wake_after_tts=finish_wake_after_tts)
    else:
        self._handle_llm_turn_local(user_msg, finish_wake_after_tts=finish_wake_after_tts)
def _apply_cloud_dialog_result(
    self,
    result: dict,
    *,
    finish_wake_after_tts: bool,
) -> None:
    """Apply one cloud dialog result: route flight_intent vs chitchat, enforce
    the v1 protocol + confirm contract before any flight execution, queue the
    returned PCM (or a local-TTS fallback), then set the post-TTS transition."""
    proto = result.get("protocol")
    routing = result.get("routing")
    fi = result.get("flight_intent")
    confirm_raw = result.get("confirm")
    scheduled_flight_confirm = False

    if routing == "flight_intent" and isinstance(fi, dict) and fi.get("is_flight_intent"):
        summary = str(fi.get("summary") or "好的。").strip()
        actions = fi.get("actions") or []
        print(f"[LLM] 判定=飞控意图(云端) summary={summary!r}", flush=True)
        print(f"[LLM] actions={actions!r}", flush=True)
        if proto != CLOUD_VOICE_DIALOG_V1:
            logger.error(
                "[云端] flight_intent 须 protocol=%r,收到 %r;按 v1 拒执行飞控",
                CLOUD_VOICE_DIALOG_V1,
                proto,
            )
        cd = parse_confirm_dict(confirm_raw)
        if cd is None:
            logger.error("[云端] flight_intent 须带合法 confirm 对象(v1),拒执行飞控")
        exec_enabled = os.environ.get("ROCKET_CLOUD_EXECUTE_FLIGHT", "").lower() in (
            "1",
            "true",
            "yes",
        )
        if (
            exec_enabled
            and proto == CLOUD_VOICE_DIALOG_V1
            and cd is not None
        ):
            if cd["required"]:
                # Defer execution until the user verbally confirms after TTS.
                scheduled_flight_confirm = True
                with self._flight_confirm_timer_lock:
                    self._pending_flight_confirm = {"flight": fi, "confirm": cd}
                    self._pending_flight_confirm_after_tts = True
                logger.info(
                    "[云端] flight_intent 待口头确认(pending_id=%s);"
                    "播完 TTS 后听确认/超时",
                    cd.get("pending_id"),
                )
            else:
                logger.info(
                    "[云端] flight_intent confirm.required=false,将直接执行(若已开执行开关)"
                )
                self._start_cloud_flight_execution(fi)
        elif exec_enabled and (
            proto != CLOUD_VOICE_DIALOG_V1 or cd is None
        ):
            logger.warning(
                "[云端] 协议或 confirm 不完整,本轮不执行飞控(仍播 TTS)"
            )
        else:
            logger.info(
                "[云端] flight_intent 已下发(未设 ROCKET_CLOUD_EXECUTE_FLIGHT,仅播报)"
            )
    elif routing == "chitchat":
        if proto != CLOUD_VOICE_DIALOG_V1:
            logger.warning(
                "[云端] chitchat 期望 protocol=%r,实际=%r",
                CLOUD_VOICE_DIALOG_V1,
                proto,
            )
        cr = (result.get("chat_reply") or "").strip()
        print(f"[LLM] 判定=闲聊(云端) reply={cr[:200]!r}", flush=True)
    else:
        logger.warning("未知 routing: %s", routing)

    # Prefer the cloud's own PCM; otherwise speak a local-TTS fallback line.
    pcm = result.get("pcm")
    sr = int(result.get("sample_rate_hz") or 24000)
    if pcm is not None and np.asarray(pcm).size > 0:
        self._enqueue_cloud_pcm_playback(np.asarray(pcm, dtype=np.int16), sr)
    elif self._cloud_fallback_local:
        if routing == "flight_intent" and isinstance(fi, dict):
            fallback_txt = str(fi.get("summary") or "好的。").strip()
        else:
            fallback_txt = (result.get("chat_reply") or "好的。").strip()
        if fallback_txt:
            self._enqueue_llm_speak(fallback_txt)
    else:
        self._enqueue_llm_speak("未收到云端语音。")

    # Post-TTS transition: chitchat re-prompts; a scheduled confirm window
    # suppresses the wake-cycle finish; everything else ends the round.
    if routing == "chitchat":
        self._pending_chitchat_reprompt_after_tts = True
    elif scheduled_flight_confirm:
        pass
    elif finish_wake_after_tts and not scheduled_flight_confirm:
        self._pending_finish_wake_cycle_after_tts = True
    elif routing == "flight_intent" and not scheduled_flight_confirm:
        self._pending_finish_wake_cycle_after_tts = True
    elif routing not in ("chitchat", "flight_intent"):
        self._pending_finish_wake_cycle_after_tts = True
print(f"[云端] 异常: {e}", flush=True) + logger.error("云端对话异常: %s", e, exc_info=True) + self._recover_from_cloud_failure( + user_msg, + finish_wake_after_tts=finish_wake_after_tts, + idle_speak="网络异常,请稍后再试。", + ) + return + + dt = time.monotonic() - t0 + metrics = result.get("metrics") or {} + print( + f"[计时] 云端一轮(turn.text) {dt:.3f}s " + f"(llm_ms={metrics.get('llm_ms')!r}, " + f"tts_first_byte_ms={metrics.get('tts_first_byte_ms')!r})", + flush=True, + ) + self._apply_cloud_dialog_result(result, finish_wake_after_tts=finish_wake_after_tts) + + def _handle_llm_turn_cloud_pcm( + self, + pcm_i16: np.ndarray, + sample_rate_hz: int, + *, + finish_wake_after_tts: bool = False, + ) -> None: + from voice_drone.core.cloud_voice_client import CloudVoiceError + + assert self._cloud_client is not None + t0 = time.monotonic() + try: + result = self._cloud_client.run_turn_audio(pcm_i16, int(sample_rate_hz)) + except CloudVoiceError as e: + print(f"[云端] turn.audio 失败: {e} (code={e.code!r})", flush=True) + logger.error("云端 turn.audio 失败: %s", e, exc_info=True) + self._recover_from_cloud_failure( + "", + finish_wake_after_tts=True, + idle_speak="云端语音识别失败,请稍后再试。", + ) + return + except Exception as e: # noqa: BLE001 + print(f"[云端] turn.audio 异常: {e}", flush=True) + logger.error("云端 turn.audio 异常: %s", e, exc_info=True) + self._recover_from_cloud_failure( + "", + finish_wake_after_tts=True, + idle_speak="网络异常,请稍后再试。", + ) + return + + dt = time.monotonic() - t0 + metrics = result.get("metrics") or {} + print( + f"[计时] 云端一轮(turn.audio) {dt:.3f}s " + f"(llm_ms={metrics.get('llm_ms')!r}, " + f"tts_first_byte_ms={metrics.get('tts_first_byte_ms')!r})", + flush=True, + ) + self._apply_cloud_dialog_result(result, finish_wake_after_tts=finish_wake_after_tts) + + def _handle_llm_turn_local( + self, user_msg: str, *, finish_wake_after_tts: bool = False + ) -> None: + llm = self._ensure_llm() + if llm is None: + self._enqueue_llm_speak( + "大模型未就绪。请确认已下载 GGUF,或设置环境变量 ROCKET_LLM_GGUF 指向模型文件。" + ) + 
if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + + with self._chat_session_lock: + self._llm_messages = [ + {"role": "system", "content": FLIGHT_INTENT_CHAT_SYSTEM}, + {"role": "user", "content": user_msg}, + ] + messages_snapshot = list(self._llm_messages) + + if not self._llm_stream_enabled: + t_llm0 = time.monotonic() + try: + out = llm.create_chat_completion( + messages=messages_snapshot, + max_tokens=self._llm_max_tokens, + ) + except Exception as e: # noqa: BLE001 + dt_llm = time.monotonic() - t_llm0 + print(f"[计时] LLM 推理 {dt_llm:.3f}s(失败)", flush=True) + logger.error("LLM 推理失败: %s", e, exc_info=True) + with self._chat_session_lock: + if self._llm_messages and self._llm_messages[-1].get("role") == "user": + self._llm_messages.pop() + self._enqueue_llm_speak("推理出错,请稍后再说。") + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + dt_llm = time.monotonic() - t_llm0 + print(f"[计时] LLM 推理 {dt_llm:.3f}s", flush=True) + + reply = ( + (out.get("choices") or [{}])[0].get("message") or {} + ).get("content", "").strip() + self._finalize_llm_turn( + reply, finish_wake_after_tts, streamed_chat=False + ) + return + + t_llm0 = time.monotonic() + try: + stream = llm.create_chat_completion( + messages=messages_snapshot, + max_tokens=self._llm_max_tokens, + stream=True, + ) + except Exception as e: # noqa: BLE001 + dt_llm = time.monotonic() - t_llm0 + print(f"[计时] LLM 推理 {dt_llm:.3f}s(失败)", flush=True) + logger.error("LLM 推理失败: %s", e, exc_info=True) + with self._chat_session_lock: + if self._llm_messages and self._llm_messages[-1].get("role") == "user": + self._llm_messages.pop() + self._enqueue_llm_speak("推理出错,请稍后再说。") + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + + full_reply = "" + pending = "" + tts_budget = self._llm_tts_max_chars + route: str | None = None + + try: + for chunk in stream: + content = self._chunk_delta_text(chunk) + if not content: + 
continue + full_reply += content + if route is None: + lead = full_reply.lstrip() + if lead: + route = "json" if lead[0] == "{" else "chat" + if route != "chat" or tts_budget <= 0: + continue + pending += content + while tts_budget > 0 and pending: + segs, pending = take_completed_sentences(pending) + if segs: + for seg in segs: + tts_budget = self._enqueue_segment_capped(seg, tts_budget) + if tts_budget <= 0: + break + continue + forced, pending = force_soft_split( + pending, self._stream_tts_chunk_chars + ) + if not forced: + break + for seg in forced: + tts_budget = self._enqueue_segment_capped(seg, tts_budget) + if tts_budget <= 0: + break + except Exception as e: # noqa: BLE001 + dt_llm = time.monotonic() - t_llm0 + print(f"[计时] LLM 推理 {dt_llm:.3f}s(失败)", flush=True) + logger.error("LLM 流式推理失败: %s", e, exc_info=True) + with self._chat_session_lock: + if self._llm_messages and self._llm_messages[-1].get("role") == "user": + self._llm_messages.pop() + self._enqueue_llm_speak("推理出错,请稍后再说。") + if finish_wake_after_tts: + self._pending_finish_wake_cycle_after_tts = True + return + + dt_llm = time.monotonic() - t_llm0 + print(f"[计时] LLM 推理 {dt_llm:.3f}s", flush=True) + + reply = full_reply.strip() + if route == "chat" and tts_budget > 0: + tail = pending.strip() + if tail: + self._enqueue_segment_capped(tail, tts_budget) + + self._finalize_llm_turn( + reply, finish_wake_after_tts, streamed_chat=(route == "chat") + ) + + def start(self) -> None: + if self.running: + logger.warning("识别器已在运营") + return + + self.running = True + + self.stt_thread = threading.Thread(target=self._stt_worker_thread, daemon=True) + self.stt_thread.start() + + self.command_thread = threading.Thread( + target=self._takeoff_only_command_worker, daemon=True + ) + self.command_thread.start() + + # 先预加载再开麦:否则 PortAudio 回调会一直往 audio_queue 塞数据,而主线程还没进入 + # process_audio_stream,默认仅 10 块的队列会迅速满并触发「音频队列已满,丢弃数据块」。 + logger.info("voice_drone_assistant: 准备预加载模型(若启用)…") + 
self._preload_llm_and_tts_if_enabled() + + try: + self.audio_capture.start_stream() + except BaseException: + self.running = False + try: + self.stt_queue.put(None, timeout=0.5) + except Exception: # noqa: BLE001 + pass + try: + self.command_queue.put(None, timeout=0.5) + except Exception: # noqa: BLE001 + pass + if self.stt_thread is not None: + self.stt_thread.join(timeout=2.0) + if self.command_thread is not None: + self.command_thread.join(timeout=2.0) + raise + + if self._cloud_voice_enabled: + logger.info( + "voice_drone_assistant: 已启动(对话走云端 WebSocket;TTS 为云端 PCM;飞控见 Socket/offboard)" + ) + else: + logger.info( + "voice_drone_assistant: 已启动(无试飞控 Socket;大模型答复走 Kokoro TTS)" + ) + ld = os.environ.get("LD_PRELOAD", "") + sys_asound = "libasound.so" in ld and "/usr/" in ld + if not sys_asound: + print( + "\n⚠ 建议用系统 ALSA 启动(conda 下否则常无声或 VAD 不触发):\n" + " bash with_system_alsa.sh python main.py\n", + flush=True, + ) + if self._llm_disabled and not self._cloud_voice_enabled: + if self._local_keyword_takeoff_enabled: + llm_hint = "已 ROCKET_LLM_DISABLE=1:除 keywords.yaml 中 takeoff 关键词外,其它指令仅打印,不调大模型。\n" + else: + llm_hint = ( + "已 ROCKET_LLM_DISABLE=1 且未启用本地口令起飞(assistant.local_keyword_takeoff_enabled / " + "ROCKET_LOCAL_KEYWORD_TAKEOFF):指令仅打印,不调大模型。\n" + ) + elif self._cloud_voice_enabled: + if self._local_keyword_takeoff_enabled: + llm_hint = "已启用云端对话:非 takeoff 关键词指令经 WebSocket 上云,播报为云端 TTS 流。\n" + else: + llm_hint = "已启用云端对话:指令经 WebSocket 上云,播报为云端 TTS 流(本地口令起飞已关闭)。\n" + else: + llm_hint = ( + "说「无人机」唤醒后会先播报问候,再听您说一句(不必再带唤醒词);说完后关麦推理,答句播完后再说「" + f"{self.wake_word_detector.primary}」开始下一轮。非起飞指令走大模型(" + "飞控相关→JSON,否则闲聊)。\n" + ) + if self._local_keyword_takeoff_enabled: + takeoff_banner = ( + "\n本地口令起飞已开启:说「无人机」+ keywords.yaml 里 takeoff 词(如「起飞演示」)→ 播提示音、" + "启动 scripts/run_px4_offboard_one_terminal.sh(串口真机)、再播返航提示并结束脚本。\n" + ) + else: + takeoff_banner = ( + "\n本地口令起飞已关闭(飞控请用云端 flight_intent / ROS 桥等);" + "若需恢复 keywords.yaml takeoff → offboard,设 
assistant.local_keyword_takeoff_enabled: true 或 " + "ROCKET_LOCAL_KEYWORD_TAKEOFF=1。\n" + ) + print( + f"{takeoff_banner}" + f"{llm_hint}" + "标记说明:[VAD] 已截段送 STT;[STT] 识别文字;[唤醒] 是否含唤醒词;[LLM] 对话与播报。\n" + "录音已在启动时选好;扬声器可设 ROCKET_TTS_DEVICE。建议:bash with_system_alsa.sh python …\n" + "Ctrl+C 退出。\n", + flush=True, + ) + + def _play_wav_serialized(self, path: Path) -> None: + if not path.is_file(): + logger.warning("WAV 文件不存在,跳过播放: %s", path) + return + with self._audio_play_lock: + try: + _play_wav_blocking(path) + except Exception as e: # noqa: BLE001 + logger.warning("播放 WAV 失败 %s: %s", path, e, exc_info=True) + + def _run_takeoff_offboard_and_wavs(self) -> None: + """独立线程:起 offboard 脚本;播第一段;第一段结束后等 10s;再播第二段;第二段结束后杀掉脚本进程组。""" + if not _OFFBOARD_SCRIPT.is_file(): + logger.error("未找到 offboard 脚本: %s", _OFFBOARD_SCRIPT) + return + + acquired = self._takeoff_side_task_busy.acquire(blocking=False) + if not acquired: + logger.warning("起飞联动已在执行,忽略重复触发") + return + + proc: subprocess.Popen | None = None + try: + log_path = Path( + os.environ.get("ROCKET_OFFBOARD_LOG", "/tmp/rocket_drone_offboard_script.log") + ).expanduser() + log_f = open(log_path, "ab", buffering=0) + try: + proc = subprocess.Popen( + [ + "bash", + str(_OFFBOARD_SCRIPT), + "/dev/ttyACM0", + "921600", + "20", + ], + cwd=str(_PROJECT_ROOT), + stdout=log_f, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + except Exception as e: # noqa: BLE001 + logger.error("启动 run_px4_offboard_one_terminal.sh 失败: %s", e, exc_info=True) + return + finally: + log_f.close() + + with self._offboard_proc_lock: + self._active_offboard_proc = proc + + time.sleep(0.5) + early_rc = proc.poll() + if early_rc is not None: + logger.error( + "offboard 一键脚本已立即结束 (exit=%s),未持续运行。日志: %s (常见原因:找不到 " + "px4_ctrl_offboard_demo.py、ROS 环境、或串口未连)", + early_rc, + log_path, + ) + + logger.info( + "已启动 offboard 一键脚本 (pid=%s),并播放起飞提示音;脚本输出见 %s", + proc.pid, + log_path, + ) + + self._play_wav_serialized(_TAKEOFF_ACK_WAV) + time.sleep(10.0) + 
self._play_wav_serialized(_TAKEOFF_DONE_WAV) + finally: + if proc is not None: + logger.info("第二段 WAV 已播完,终止 offboard 脚本进程组 (pid=%s)", proc.pid) + _terminate_process_group(proc) + with self._offboard_proc_lock: + if self._active_offboard_proc is proc: + self._active_offboard_proc = None + self._takeoff_side_task_busy.release() + + def _takeoff_only_command_worker(self) -> None: + """唤醒;同句带指令则直转 LLM/起飞;否则问候+滴声→再问一句→关麦播报。""" + logger.info("唤醒流程命令线程已启动") + while self.running: + try: + text = self.command_queue.get(timeout=0.1) + except queue.Empty: + continue + except Exception as e: # noqa: BLE001 + logger.error(f"命令处理线程错误: {e}", exc_info=True) + continue + + try: + if text is None: + break + + try: + if ( + isinstance(text, tuple) + and len(text) == 3 + and text[0] == _PCM_TURN_MARKER + ): + self._handle_pcm_uplink_turn(text[1], int(text[2])) + continue + + with self._wake_flow_lock: + phase = self._wake_phase + + if phase == int(_WakeFlowPhase.LLM_BUSY): + continue + if phase == int(_WakeFlowPhase.GREETING_WAIT): + continue + + if phase == int(_WakeFlowPhase.FLIGHT_CONFIRM_LISTEN): + self._handle_flight_confirm_text(text) + continue + + if phase == int(_WakeFlowPhase.ONE_SHOT_LISTEN): + with self._wake_flow_lock: + self._wake_phase = int(_WakeFlowPhase.LLM_BUSY) + self._process_one_shot_command(text) + continue + + is_wake, matched = self.wake_word_detector.detect(text) + if not is_wake: + logger.debug("未检测到唤醒词,忽略: %s", text) + if os.environ.get("ROCKET_PRINT_STT", "").lower() in ( + "1", + "true", + "yes", + ): + print( + f"[唤醒] 未命中「{self.wake_word_detector.primary}」,原文: {text!r}", + flush=True, + ) + continue + + logger.info("唤醒词命中: %s", matched) + command_text = self.wake_word_detector.extract_command_text(text) + follow = (command_text or "").strip() + if follow: + if not self._wake_fast_path_process_follow(follow): + continue + continue + self._begin_wake_cycle(None) + + except Exception as e: # noqa: BLE001 + logger.error("命令处理失败: %s", e, exc_info=True) + 
finally: + self.command_queue.task_done() + + logger.info("唤醒流程命令线程已停止") + + def stop(self) -> None: + """停止识别;不重连 Socket(从未连接)。""" + if not self.running: + return + + self.running = False + + self._cancel_prompt_listen_timer() + self._cancel_flight_confirm_timer() + with self._flight_confirm_timer_lock: + self._pending_flight_confirm = None + self._pending_flight_confirm_after_tts = False + + if self.stt_thread is not None: + self.stt_queue.put(None) + if self.command_thread is not None: + self.command_queue.put(None) + if self.stt_thread is not None: + self.stt_thread.join(timeout=2.0) + if self.command_thread is not None: + self.command_thread.join(timeout=2.0) + + # 不在此线程做 speak_text:会阻塞数秒至数十秒,用户多次 Ctrl+C 仍杀不掉进程 + self._discard_llm_playback_queue() + + with self._offboard_proc_lock: + op = self._active_offboard_proc + self._active_offboard_proc = None + if op is not None and op.poll() is None: + logger.info("主程序退出:终止仍在运行的 offboard 脚本") + _terminate_process_group(op) + + try: + self.audio_capture.stop_stream() + except KeyboardInterrupt: + logger.info("关闭麦克风流时中断,跳过") + except Exception as e: # noqa: BLE001 + logger.warning("关闭麦克风流失败: %s", e) + + if self._cloud_client is not None: + try: + self._cloud_client.close() + except Exception as e: # noqa: BLE001 + logger.debug("关闭云端 WebSocket: %s", e) + + if self.socket_client.connected: + self.socket_client.disconnect() + + logger.info("voice_drone_assistant 已停止") + print("\n已退出。", flush=True) + + +def main() -> None: + ap = argparse.ArgumentParser( + description="无人机语音:唤醒 → 问候 → 一句指令 → 起飞或 LLM 播报 → 再唤醒" + ) + ap.add_argument( + "--input-index", + "-I", + type=int, + default=None, + help="跳过交互菜单,直接指定 PyAudio 录音设备索引(与启动时「PyAudio_index=」一致)。", + ) + ap.add_argument( + "--non-interactive", + action="store_true", + help="不选设备:用 system.yaml 的 audio.input_device_index(为 null 时自动枚举默认可录音设备)。", + ) + ap.add_argument( + "--no-preload", + action="store_true", + help="不预加载 Qwen/Kokoro,缩短启动时间(首轮对话与首次播报会变慢)。", + ) + args = 
ap.parse_args() + non_inter = args.non_interactive or os.environ.get( + "ROCKET_NON_INTERACTIVE", "" + ).lower() in ("1", "true", "yes") + + idx = args.input_index + if idx is None: + raw_ix = os.environ.get("ROCKET_INPUT_DEVICE_INDEX", "").strip() + if raw_ix.isdigit() or (raw_ix.startswith("-") and raw_ix[1:].isdigit()): + idx = int(raw_ix) + + if idx is not None: + from voice_drone.core.mic_device_select import apply_input_device_index_only + + apply_input_device_index_only(idx) + logger.info("录音设备: PyAudio 索引 %s(CLI/环境变量)", idx) + elif not non_inter: + from voice_drone.core.mic_device_select import ( + apply_input_device_index_only, + prompt_for_input_device_index, + ) + + chosen = prompt_for_input_device_index() + apply_input_device_index_only(chosen) + else: + logger.info( + "非交互模式:使用 system.yaml 的 audio.input_device_index(null=自动探测)" + ) + + app = TakeoffPrintRecognizer(skip_model_preload=args.no_preload) + try: + app.run() + except KeyboardInterrupt: + logger.info("用户中断") + finally: + if app.running: + app.stop() + + +if __name__ == "__main__": + main() diff --git a/voice_drone/tools/__init__.py b/voice_drone/tools/__init__.py new file mode 100644 index 0000000..2cf02fb --- /dev/null +++ b/voice_drone/tools/__init__.py @@ -0,0 +1 @@ +"""One-off helpers (CLI tools).""" diff --git a/voice_drone/tools/config_loader.py b/voice_drone/tools/config_loader.py new file mode 100644 index 0000000..bebd851 --- /dev/null +++ b/voice_drone/tools/config_loader.py @@ -0,0 +1,15 @@ +import yaml + +def load_config(config_path): + """ + 加载配置 + + Args: + config_path (str): 配置文件路径 + + Returns: + dict: 返回解析后的配置字典 + """ + with open(config_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + return config \ No newline at end of file diff --git a/voice_drone/tools/publish_flight_intent_ros_once.py b/voice_drone/tools/publish_flight_intent_ros_once.py new file mode 100644 index 0000000..7792370 --- /dev/null +++ b/voice_drone/tools/publish_flight_intent_ros_once.py @@ 
-0,0 +1,66 @@ +#!/usr/bin/env python3 +"""Publish one flight_intent JSON to a ROS1 std_msgs/String topic (for flight bridge). + +Run after sourcing ROS, e.g.: + source /opt/ros/noetic/setup.bash && python3 -m voice_drone.tools.publish_flight_intent_ros_once /tmp/intent.json + +Or from repo root(须在已 source ROS 的 shell 中,且勿覆盖 ROS 的 PYTHONPATH;应 prepend): + export PYTHONPATH="$PWD:$PYTHONPATH" && python3 -m voice_drone.tools.publish_flight_intent_ros_once ... +""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Any + +import rospy +from std_msgs.msg import String + + +def _load_payload(path: str | None) -> dict[str, Any]: + if path and path != "-": + raw = Path(path).read_text(encoding="utf-8") + else: + raw = sys.stdin.read() + data = json.loads(raw) + if not isinstance(data, dict): + raise ValueError("root must be a JSON object") + return data + + +def main() -> None: + ap = argparse.ArgumentParser(description="Publish flight_intent JSON once to ROS String topic") + ap.add_argument( + "json_file", + nargs="?", + default="-", + help="Path to JSON file, or - for stdin (default: stdin)", + ) + ap.add_argument("--topic", default="/input", help="ROS topic (default: /input)") + ap.add_argument( + "--wait-subscribers", + type=float, + default=0.0, + help="Seconds to wait for at least one subscriber (default: 0)", + ) + args = ap.parse_args() + + payload = _load_payload(args.json_file if args.json_file != "-" else None) + json_str = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + + rospy.init_node("flight_intent_ros_publisher_once", anonymous=True) + pub = rospy.Publisher(args.topic, String, queue_size=1, latch=False) + deadline = rospy.Time.now() + rospy.Duration(args.wait_subscribers) + while args.wait_subscribers > 0 and pub.get_num_connections() < 1 and not rospy.is_shutdown(): + if rospy.Time.now() > deadline: + break + rospy.sleep(0.05) + rospy.sleep(0.15) + 
pub.publish(String(data=json_str)) + rospy.sleep(0.25) + + +if __name__ == "__main__": + main() diff --git a/voice_drone/tools/wrapper.py b/voice_drone/tools/wrapper.py new file mode 100644 index 0000000..4367106 --- /dev/null +++ b/voice_drone/tools/wrapper.py @@ -0,0 +1,17 @@ +import time + +def time_cost(tag=None): + """ + 装饰器,自定义函数耗时打印信息 + :param tag: 可选,自定义标志(说明当前执行的函数) + """ + def decorator(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + label = tag if tag is not None else func.__name__ + print(f"耗时统计[{label}]: {(end_time - start_time)*1000:.2f} 毫秒") + return result + return wrapper + return decorator \ No newline at end of file diff --git a/with_system_alsa.sh b/with_system_alsa.sh new file mode 100644 index 0000000..b66b4f1 --- /dev/null +++ b/with_system_alsa.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# Conda 里 PyAudio 自带的 libportaudio 会加载同目录 libasound,从而去 conda/.../lib/alsa-lib/ +# 找插件(多数环境不完整,终端刷屏且麦克风电平异常)。 +# 在启动 Python **之前** 预加载系统的 libasound.so.2 可恢复正常 ALSA 行为。 +# +# 用法(项目根目录): +# bash with_system_alsa.sh python src/mic_level_check.py +# bash with_system_alsa.sh python src/rocket_drone_audio.py + +set -euo pipefail +ARCH="$(uname -m)" +case "${ARCH}" in + aarch64) ASO="/usr/lib/aarch64-linux-gnu/libasound.so.2" ;; + x86_64) ASO="/usr/lib/x86_64-linux-gnu/libasound.so.2" ;; + *) ASO="" ;; +esac +if [[ -n "${ASO}" && -f "${ASO}" ]]; then + export LD_PRELOAD="${ASO}${LD_PRELOAD:+:${LD_PRELOAD}}" +fi +exec "$@"