Initial commit: voice drone assistant
Made-with: Cursor
This commit is contained in:
commit
157a34fe87
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
.Python
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
.env
|
||||||
|
*.egg-info/
|
||||||
|
.eggs/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# 大模型与缓存(见 models/README.txt,请单独拷贝或从原仓库同步)
|
||||||
|
models/
|
||||||
|
cache/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
94
README.md
Normal file
94
README.md
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# voice_drone_assistant
|
||||||
|
|
||||||
|
从原仓库抽离的**独立可运行**子工程:麦克风采集 → VAD 切段 → **SenseVoice STT** → **唤醒词** →(关键词起飞 / **Qwen + Kokoro 对话播报**)。
|
||||||
|
|
||||||
|
**部署与外场启动(推荐先读):[docs/DEPLOYMENT_AND_OPERATIONS.md](docs/DEPLOYMENT_AND_OPERATIONS.md)**
|
||||||
|
**日常配置索引:[docs/PROJECT_GUIDE.md](docs/PROJECT_GUIDE.md)** · 云端协议:[docs/llmcon.md](docs/llmcon.md)
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
| 路径 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `main.py` | 启动入口 |
|
||||||
|
| `with_system_alsa.sh` | Conda 下建议包一层启动,修正 ALSA/PortAudio |
|
||||||
|
| `voice_drone/core/` | 音频、VAD、STT、TTS、预处理、唤醒、配置、识别器主流程 |
|
||||||
|
| `voice_drone/main_app.py` | 唤醒流程 + LLM 流式 + 起飞脚本联动(原 `rocket_drone_audio.py`) |
|
||||||
|
| `voice_drone/config/` | `system.yaml`、`wake_word.yaml`、`keywords.yaml`、`command_.yaml` |
|
||||||
|
| `voice_drone/logging_/` | 彩色日志 |
|
||||||
|
| `voice_drone/tools/` | YAML 加载等 |
|
||||||
|
| `scripts/` | PX4 offboard、`generate_wake_greeting_wav.py` |
|
||||||
|
| `assets/tts_cache/` | 唤醒问候 WAV 缓存 |
|
||||||
|
| `models/` | **需自备或软链**,见 `models/README.txt` |
|
||||||
|
|
||||||
|
## 环境准备
|
||||||
|
|
||||||
|
1. Python 3.10+(与原项目一致即可),安装依赖:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 模型:将 STT / TTS /(可选)Silero VAD 放到 `models/`,或按 `models/README.txt` 从原仓库 `src/models` 创建符号链接。
|
||||||
|
|
||||||
|
3. 大模型:默认查找 `cache/qwen25-1.5b-gguf/qwen2.5-1.5b-instruct-q4_k_m.gguf`,或通过环境变量 `ROCKET_LLM_GGUF` 指定 GGUF 路径。
|
||||||
|
|
||||||
|
## 运行
|
||||||
|
|
||||||
|
在 **`voice_drone_assistant` 根目录** 执行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash with_system_alsa.sh python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
常用参数与环境变量与原 `rocket_drone_audio.py` 相同(如 `ROCKET_LLM_STREAM`、`ROCKET_INPUT_DEVICE_INDEX`、`--input-index`、`ROCKET_ENERGY_VAD` 等),说明见 `voice_drone/main_app.py` 文件头注释。
|
||||||
|
|
||||||
|
也可直接跑模块:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash with_system_alsa.sh python -m voice_drone.main_app
|
||||||
|
```
|
||||||
|
|
||||||
|
## 为什么不默认带上原仓库的 models?
|
||||||
|
|
||||||
|
- **ONNX / GGUF 体积大**(动辄数百 MB~数 GB),放进 Git 或重复拷贝会加重仓库和同步成本。
|
||||||
|
- 抽离时只保证 **代码与配置自给**;权重文件用 **本机拷贝 / U 盘 / 另一台预先 `bundle`** 更灵活。
|
||||||
|
|
||||||
|
若你本机仍摆着原仓库 `rocket_drone_audio`,且 `voice_drone_assistant` 在其子目录下,代码里有个**临时便利**:`models/...` 找不到时会尝试 **上一级 `src/models/...`**,所以在开发机上可以不改目录也能跑。
|
||||||
|
**这只在「子目录 + 上层仍有原仓库」时有效**,把 `voice_drone_assistant` **单独拷到另一台香橙派后,上层没有原仓库,必须在本目录自备 `models/`(和可选 `cache/`)**。
|
||||||
|
|
||||||
|
## 拷到另一台香橙派要做什么?
|
||||||
|
|
||||||
|
1. **整目录复制**(建议先在本机执行下面脚本打全模型,再打包 `voice_drone_assistant`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /path/to/voice_drone_assistant
|
||||||
|
bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio
|
||||||
|
```
|
||||||
|
|
||||||
|
会把 `SenseVoiceSmall`、`Kokoro-82M-v1.1-zh-ONNX`(及存在的 `SileroVad`)复制到本目录 `models/`;可按提示选择是否复制 Qwen GGUF。
|
||||||
|
|
||||||
|
2. **新机器上 Python 依赖**:另一台是**全新系统**时,需要再装一次(或整体迁移同一个 conda/env):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd voice_drone_assistant
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
二进制/系统库层面若仍用 conda + PortAudio,建议继续 **`bash with_system_alsa.sh python main.py`**。
|
||||||
|
|
||||||
|
3. **大模型路径**:若未打包 `cache/`,在新机器设环境变量或放入默认路径,例如:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export ROCKET_LLM_GGUF=/path/to/qwen2.5-1.5b-instruct-q4_k_m.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
综上:**工程可独立**,但必须带上 **`models/` + 已装依赖 +(可选)GGUF**;**`pip install` 每台新环境通常要做一次**,除非你把整个 conda env 目录一起迁移。
|
||||||
|
|
||||||
|
## 与原仓库关系
|
||||||
|
|
||||||
|
- 本目录为**代码与配置的复制 + 包名调整**(`src.*` → `voice_drone.*`),默认不把大体积 `models/`、`cache/` 放进版本库。
|
||||||
|
- 原仓库 `rocket_drone_audio` 仍可继续使用;开发阶段两者可并存,部署到单机时只带走 `voice_drone_assistant`(+ `bundle` 后的模型)即可。
|
||||||
|
|
||||||
|
## 未纳入本工程的模块
|
||||||
|
|
||||||
|
PX4 电机演示、独立录音脚本、Socket 试飞控协议服务端、ChatTTS 转换脚本等均留在原仓库,以减小篇幅;本工程仍通过 `SocketClient` 预留配置项(`TakeoffPrintRecognizer` 使用 `auto_connect_socket=False`,不依赖外置试飞控 Socket)。
|
||||||
BIN
assets/tts_cache/wake_greeting.wav
Normal file
BIN
assets/tts_cache/wake_greeting.wav
Normal file
Binary file not shown.
0
docs/API.md
Normal file
0
docs/API.md
Normal file
179
docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md
Normal file
179
docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
# 云端语音 · `dialog_result` 与飞控二次确认(v1)
|
||||||
|
|
||||||
|
供 **云端服务** 与 **机端 voice_drone_assistant** 同步实现。**尚无线上存量**:本文即 **`dialog_result` 的飞机位约定**,服务端可按 v1 直接改结构,无需迁就旧字段。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 目标
|
||||||
|
|
||||||
|
1. **`routing=chitchat`**:只走闲聊与对应 TTS,**不**下发可执行飞控负载。
|
||||||
|
2. **`routing=flight_intent`**:携 **`flight_intent`(v1)** + **`confirm`**;机端是否立刻执行仅由 **`confirm.required`** 决定,并支持 **确认 / 取消 / 超时** 交互。
|
||||||
|
3. **ASR**:飞控句是否改用云端识别见 **附录 A**;与 `confirm` 独立。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 术语
|
||||||
|
|
||||||
|
| 术语 | 含义 |
|
||||||
|
|------|------|
|
||||||
|
| **首轮** | 用户说一句;本轮 WS 收到 `dialog_result` 为止。 |
|
||||||
|
| **确认窗** | `confirm.required=true` 时,机端播完本轮 PCM 后 **仅收口令** 的时段,时长 **`confirm.timeout_sec`**。 |
|
||||||
|
| **`flight_intent`** | 见 `FLIGHT_INTENT_SCHEMA_v1.md`。 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. `dialog_result` 形状(云端 → 机端)
|
||||||
|
|
||||||
|
### 3.1 公共顶层(每轮必带)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `turn_id` | string | 是 | 与现有一致,关联本 turn。 |
|
||||||
|
| **`protocol`** | string | 是 | 固定 **`cloud_voice_dialog_v1`**,便于机端强校验、排障。 |
|
||||||
|
| `routing` | string | 是 | **`chitchat`** \| **`flight_intent`** |
|
||||||
|
| `user_input` | string | 建议 | 本回合用于生成回复的用户文本(可为云端 STT 结果)。 |
|
||||||
|
|
||||||
|
### 3.2 `routing=chitchat`
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `chat_reply` | string | 是 | 闲聊文本(与 TTS 语义一致或由服务端定义)。 |
|
||||||
|
| `flight_intent` | — | **禁止** | 不得出现。 |
|
||||||
|
| `confirm` | — | **禁止** | 不得出现。 |
|
||||||
|
|
||||||
|
### 3.3 `routing=flight_intent`
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `flight_intent` | object | 是 | v1:`is_flight_intent`、`version`、`actions`、`summary` 等。 |
|
||||||
|
| **`confirm`** | object | 是 | 见 §3.4;**每轮飞控必带**,机端拒收缺字段报文。 |
|
||||||
|
|
||||||
|
### 3.4 `confirm` 对象(`routing=flight_intent` 时必填)
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| **`required`** | bool | 是 | `true`:进入确认窗,**首轮禁止**执行飞控;`false`:首轮允许按机端执行开关立即执行(调试/免确认策略)。 |
|
||||||
|
| **`timeout_sec`** | number | 是 | 确认窗秒数;建议默认 **10**。 |
|
||||||
|
| **`confirm_phrases`** | string[] | 是 | 非空;与口播一致,推荐 **`["确认"]`**。 |
|
||||||
|
| **`cancel_phrases`** | string[] | 是 | 非空;推荐 **`["取消"]`**。 |
|
||||||
|
| **`pending_id`** | string | 是 | 本轮待定意图 ID(建议 UUID);日志、可选第二轮遥测(附录 B)。 |
|
||||||
|
| **`summary_for_user`** | string | 建议 | 与口播语义一致,供日志/本地 TTS 兜底;**最终以本轮 PCM 为准**。 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 播报(理解与提示)
|
||||||
|
|
||||||
|
- **TTS**:仍用 **`tts_audio_chunk` + PCM**;内容示例:复述理解 + **「请回复确认或取消」**;服务端在 `confirm_*_phrases` 中与口播保持一致(推荐 **`确认` / `取消`**)。
|
||||||
|
- 机端 **须** 在 **本轮 PCM 播放结束**(或播放管线给出「可收听下一句」)后再进入确认窗,避免抢话。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 机端短语匹配(确认窗内)
|
||||||
|
|
||||||
|
对用户 **一句** STT 规范化后,与 `confirm_phrases` / `cancel_phrases` 比对(机端实现见 `match_phrase_list`):
|
||||||
|
|
||||||
|
1. **取消优先**:若命中 `cancel_phrases` 任一 → 取消本轮。
|
||||||
|
2. **确认**:否则若命中 `confirm_phrases` 任一 → 执行 **`flight_intent`**。
|
||||||
|
3. **规则要点**:**全等**(去尾标点)算命中;或对 **很短** 的句子(长度 ≤ 短语长+2)允许 **子串** 命中,以便「好的确认」类说法;**整句复述**云端长提示(如「请回复确认或取消」)不会因同时含「确认」「取消」子串而误匹配。
|
||||||
|
4. **未命中**:可静候超时(v1 建议确认窗内 **可多句** 直至超时,由机端实现决定)。
|
||||||
|
4. **超时 / 取消** 固定中文播报见下表(机端本地 TTS,降低时延):
|
||||||
|
|
||||||
|
| 事件 | 文案 |
|
||||||
|
|------|------|
|
||||||
|
| 超时 | `未收到确认指令,请重新下发指令` |
|
||||||
|
| 取消 | `已取消指令,请重新唤醒后下发指令` |
|
||||||
|
| 确认并执行 | `开始执行飞控指令` |
|
||||||
|
|
||||||
|
若产品强制云端音色,见 **附录 C**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 机端执行条件(归纳)
|
||||||
|
|
||||||
|
| 条件 | 行为 |
|
||||||
|
|------|------|
|
||||||
|
| `routing=chitchat` | 不执行飞控。 |
|
||||||
|
| `routing=flight_intent` 且 `confirm.required=false` 且机端已开执行开关 | 首轮校验通过后 **可立即** 执行。 |
|
||||||
|
| `routing=flight_intent` 且 `confirm.required=true` | **仅**在确认窗内命中确认短语后执行;**首轮绝不**执行。 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 机端状态机(摘要)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
stateDiagram-v2
|
||||||
|
[*] --> Idle
|
||||||
|
Idle --> Chitchat: routing=chitchat
|
||||||
|
Idle --> ExecNow: routing=flight_intent 且 confirm.required=false
|
||||||
|
Idle --> ConfirmWin: routing=flight_intent 且 confirm.required=true
|
||||||
|
|
||||||
|
ConfirmWin --> ExecIntent: 命中 confirm_phrases
|
||||||
|
ConfirmWin --> SayCancel: 命中 cancel_phrases
|
||||||
|
ConfirmWin --> SayTimeout: timeout_sec
|
||||||
|
|
||||||
|
ExecNow --> Idle
|
||||||
|
ExecIntent --> Idle
|
||||||
|
SayCancel --> Idle
|
||||||
|
SayTimeout --> Idle
|
||||||
|
Chitchat --> Idle
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 会话握手
|
||||||
|
|
||||||
|
**`session.start`**(或等价)的 `client` **须** 带:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"protocol": {
|
||||||
|
"dialog_result": "cloud_voice_dialog_v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
服务端仅对声明该协议的客户端下发 §3 结构;机端若未声明,服务端可拒绝或返显式错误码(由服务端定义)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. 安全说明
|
||||||
|
|
||||||
|
二次确认减轻 **错词误飞**,不替代 **急停、遥控介入、场地规范**。
|
||||||
|
TTS 若为「请回复确认或取消」,服务端请在 `confirm_phrases` / `cancel_phrases` 中下发 **`确认`**、**`取消`**(与口播一致);**听与判均在机端**,云端无需再收一轮确认消息。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录 A:云端 ASR(可选)
|
||||||
|
|
||||||
|
服务端可将飞控相关 utterance 改为 **云端 STT** 结果填入 `user_input`,与 `flight_intent` 解析同源;**执行仍以 `flight_intent` + `confirm` 为准**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录 B:第二轮 `turn`(可选遥测)
|
||||||
|
|
||||||
|
用户确认后机端可再发一轮文本(ASR 原文),payload 可带 `pending_id`、`phase: confirm_ack`;**执行成功与否不依赖**该轮响应。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录 C:超时/取消走云端 TTS(可选)
|
||||||
|
|
||||||
|
若 `confirm.play_server_tts_on_timeout` 为真(服务端与机端扩展字段),则由云端推 PCM;**易增延迟**,v1 默认 **关**,以 §5 本地播报为准。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 文档关系
|
||||||
|
|
||||||
|
| 文档 | 关系 |
|
||||||
|
|------|------|
|
||||||
|
| `FLIGHT_INTENT_SCHEMA_v1.md` | `flight_intent` 体 |
|
||||||
|
| `DEPLOYMENT_AND_OPERATIONS.md` | 部署 |
|
||||||
|
|
||||||
|
**版本**:`cloud_voice_dialog_v1`(本文);后续 breaking 变更递增 `cloud_voice_dialog_v2` 等。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 机端实现状态(voice_drone_assistant)
|
||||||
|
|
||||||
|
- **`CloudVoiceClient`**:`session.start.client` 已带 `protocol.dialog_result: cloud_voice_dialog_v1`;`run_turn` 返回含 `protocol`、`confirm`。
|
||||||
|
- **`main_app.TakeoffPrintRecognizer`**:解析 `confirm`;`required=true` 且已开 `ROCKET_CLOUD_EXECUTE_FLIGHT` 时,播完本轮 PCM 后进入 **`FLIGHT_CONFIRM_LISTEN`**,本地匹配短语 / 超时文案见 **`voice_drone.core.cloud_dialog_v1`**。
|
||||||
|
- **服务端未升级前**:若缺 `protocol` 或 `confirm`,机端 **不执行** 飞控(仍播 TTS)。
|
||||||
55
docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md
Normal file
55
docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# PCM ASR 上行协议 v1(机端实现摘要)
|
||||||
|
|
||||||
|
与 `CloudVoiceClient`(`voice_drone/core/cloud_voice_client.py`)及 voicellmcloud 的 `pcm_asr_uplink` **session.start.transport_profile** 对齐。
|
||||||
|
|
||||||
|
## 上行:仅文本 WebSocket 帧
|
||||||
|
|
||||||
|
**禁止**用 WebSocket **binary** 发送用户 PCM。对端(Starlette)`receive_text()` 与 `receive_bytes()` 分流;binary 上发会导致对端异常,客户端可能表现为空文本等异常。
|
||||||
|
|
||||||
|
用户音频只出现在 **`turn.audio.chunk` 的 JSON 字段 `pcm_base64`** 中(标准 Base64,内容为 little-endian **pcm_s16le** 原始字节)。
|
||||||
|
|
||||||
|
## session.start
|
||||||
|
|
||||||
|
- `transport_profile`: **`pcm_asr_uplink`**
|
||||||
|
- 其余与会话通用字段相同(`client`、`auth_token`、`session_id` 等)。
|
||||||
|
|
||||||
|
## 单轮上行(一个 `turn_id`)
|
||||||
|
|
||||||
|
1. **文本 JSON**:`turn.audio.start`
|
||||||
|
- `type`: `"turn.audio.start"`
|
||||||
|
- `proto_version`: `"1.0"`
|
||||||
|
- `transport_profile`: `"pcm_asr_uplink"`
|
||||||
|
- `turn_id`: UUID 字符串
|
||||||
|
- `sample_rate_hz`: 整数(机端一般为 **16000**,与采集一致)
|
||||||
|
- `codec`: `"pcm_s16le"`
|
||||||
|
- `channels`: **1**
|
||||||
|
|
||||||
|
2. **文本 JSON**(可多条):`turn.audio.chunk`
|
||||||
|
- `type`: `"turn.audio.chunk"`
|
||||||
|
- `proto_version`、`transport_profile`、`turn_id` 与 start 一致
|
||||||
|
- `pcm_base64`: 本段 PCM 原始字节的 Base64(不传 WebSocket binary)
|
||||||
|
|
||||||
|
每段原始字节长度由环境变量 **`ROCKET_CLOUD_AUDIO_CHUNK_BYTES`** 控制(默认 8192,对 **解码前** 的 PCM 字节数做钳制)。
|
||||||
|
|
||||||
|
3. **文本 JSON**:`turn.audio.end`
|
||||||
|
- `type`: `"turn.audio.end"`
|
||||||
|
- `proto_version`、`transport_profile`、`turn_id` 与 start 一致。
|
||||||
|
|
||||||
|
**并发**:同一 WebSocket 会话内,**勿**在收到上一轮的 `turn.complete` 之前再发新一轮 `turn.audio.start`。
|
||||||
|
|
||||||
|
## 下行(与 turn.text 同形态)
|
||||||
|
|
||||||
|
- 可选:`asr.partial` — 机端仅日志/UI,**不参与状态机**。
|
||||||
|
- `llm.text_delta`(可选)
|
||||||
|
- `tts_audio_chunk`(JSON)后随 **binary PCM**(TTS 下行仍可为 binary,与上行约定无关)
|
||||||
|
- `dialog_result`
|
||||||
|
- `turn.complete`
|
||||||
|
|
||||||
|
机端对 **空文本帧** 会忽略并继续读(与云端「空文本忽略」一致)。
|
||||||
|
|
||||||
|
机端须 **收齐 `turn.complete` 且按序拼完该轮 TTS 二进制** 后再视为播报结束,再按产品规则分支(闲聊再滴声 / 飞控确认窗等)。
|
||||||
|
|
||||||
|
## 参考
|
||||||
|
|
||||||
|
- 会话产品流:[`CLOUD_VOICE_SESSION_SCHEME_v1.md`](./CLOUD_VOICE_SESSION_SCHEME_v1.md)
|
||||||
|
- 飞控确认:`CLOUD_VOICE_FLIGHT_CONFIRM_v1.md`
|
||||||
163
docs/CLOUD_VOICE_SESSION_SCHEME_v1.md
Normal file
163
docs/CLOUD_VOICE_SESSION_SCHEME_v1.md
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
# 语音助手会话方案 v1(服务端 + 机端对齐)
|
||||||
|
|
||||||
|
本文档描述 **「唤醒 → 问候 → 滴声开录 → 断句上云 → 提示音 → 云端理解与 TTS → 分支循环/待机」** 的端到端方案,供 **服务端(voicellmcloud)** 与 **机端(本仓库 voice_drone_assistant)** 分工落地。
|
||||||
|
|
||||||
|
**v1 明确不做**:播报中途 **抢话 / 打断 TTS(barge-in)**;播放 TTS 时机端 **关麦或不处理用户语音**。
|
||||||
|
|
||||||
|
**关联协议**:
|
||||||
|
|
||||||
|
- 音频上行与 Fun-ASR:[CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md](./CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md)
|
||||||
|
- 未唤醒不上云:由服务端/产品文档 `CLOUD_VOICE_CLIENT_WAKE_GATE_v1` 约定(机端须本地门禁后再建联上云)
|
||||||
|
- 总接口:以 voicellmcloud 仓库 `API_SPECIFICATION` 为准
|
||||||
|
- 飞控确认窗:[CLOUD_VOICE_FLIGHT_CONFIRM_v1.md](./CLOUD_VOICE_FLIGHT_CONFIRM_v1.md)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 产品流程(用户视角)
|
||||||
|
|
||||||
|
1. 用户说唤醒词(如「无人机」,由机端配置),**仅本地处理,不上云 ASR**。
|
||||||
|
2. 机端播放问候语(如「你好,有什么事儿吗」)— 可用本地 TTS 或 `tts.synthesize`。
|
||||||
|
3. 机端 **滴一声**,表示 **开始收音**;同时启动 **5 秒静默超时** 计时(见 §4)。
|
||||||
|
4. 用户说话;机端 **VAD/端点检测** 得到 **一整句** 后:
|
||||||
|
- 播放 **极短断句提示音**(表示「已截句、将上云」);
|
||||||
|
- **提示音播放期间闭麦或做回声隔离**,提示音结束后 **短消抖**(建议 **150~300 ms**)再恢复采集逻辑,避免把提示音当成用户语音。
|
||||||
|
5. 将该句 **PCM** 以 `turn.audio.*` 发云端;云端 **Fun-ASR → LLM → `dialog_result` → TTS**;机端播完 **全部** `tts_audio_chunk` 及收到 **`turn.complete`** 后,视为本轮播报结束(见 §3 服务端)。
|
||||||
|
6. **分支**:
|
||||||
|
- **`routing === flight_intent`**:进入 **飞控子状态机**(口头确认/取消/超时),**不使用** §4 的「闲聊滴声后 5s」规则覆盖确认窗;超时以 **`dialog_result.confirm.timeout_sec`** 及 [CLOUD_VOICE_FLIGHT_CONFIRM_v1.md](./CLOUD_VOICE_FLIGHT_CONFIRM_v1.md) 为准。
|
||||||
|
- **`routing === chitchat`**:本轮结束后 **再滴一声**,进入下一轮 **步骤 4**(同一 WebSocket 会话内 **新 `turn_id`** 再起一轮 `turn.audio.*`)。
|
||||||
|
7. 若在 **步骤 3 的滴声之后** 连续 **5 s** 内未检测到有效语音(见 §4),机端播 **超时提示音**,**不再收音**,回到 **待机**(仅唤醒)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 机端状态机(规范性)
|
||||||
|
|
||||||
|
| 状态 | 含义 | 开麦 | 上云 ASR | 备注 |
|
||||||
|
|------|------|------|----------|------|
|
||||||
|
| `STANDBY` | 仅监听唤醒 | 按现网 VAD | **禁止** `turn.audio.*` | 本地 STT + 唤醒词 |
|
||||||
|
| `GREETING` | 播问候 | 可关麦或忽略输入 | 否 | 避免问候进识别 |
|
||||||
|
| `PROMPT_LISTEN` | 已滴「开始录」,等用户一句 | 开 | 否 | **5s 超时**在此状态监控 |
|
||||||
|
| `SEGMENT_END` | 已断句,播短提示音 | **闭麦/屏蔽** | 否 | 消抖后再转 `UPLOADING` |
|
||||||
|
| `UPLOADING` | 发送 `turn.audio.*` | 否 | **是** | 一轮一个 `turn_id` |
|
||||||
|
| `PLAYING_CLOUD_TTS` | 播云端 TTS | **关麦**(v1 无抢话) | 否 | 至 `turn.complete` + PCM 播完 |
|
||||||
|
| `FLIGHT_CONFIRM` | 飞控确认窗 | 按飞控文档 | **可** `turn.text` 或按产品另定 | **独立超时**,不共用 5s |
|
||||||
|
| `CHITCHAT_TAIL` | 闲聊结束,将再滴一声 | — | 否 | 回到 `PROMPT_LISTEN` |
|
||||||
|
|
||||||
|
**并发**:同一时刻仅允许 **一路** `turn.audio.start`~`end`;须等 `turn.complete` 后再开下一轮(与现有 `pipeline_lock` 一致)。
|
||||||
|
|
||||||
|
**唤醒前**:须满足未唤醒不上云的产品/协议约定。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 服务端(voicellmcloud)职责 — v1 **无新消息类型**
|
||||||
|
|
||||||
|
### 3.1 单轮行为(不变)
|
||||||
|
|
||||||
|
对每个完整 `turn.audio.*` 或 `turn.text`:
|
||||||
|
|
||||||
|
1. Fun-ASR(仅 `pcm_asr_uplink` + 音频轮)→ 文本
|
||||||
|
2. LLM 流式 → `dialog_result`(`routing` / `flight_intent` / `chat_reply` 等)
|
||||||
|
3. `tts_audio_chunk*` → `turn.complete`
|
||||||
|
|
||||||
|
服务端 **不** 下发「请再滴一声」「进入待机」类机端 UX 信令;这些由机端根据 **`routing` + `turn.complete`** **固定规则** 驱动。
|
||||||
|
|
||||||
|
### 3.2 机端判定「播报完成」
|
||||||
|
|
||||||
|
须同时满足:
|
||||||
|
|
||||||
|
- 收到该轮 **`turn.complete`**
|
||||||
|
- 已按序播完该轮关联的 **binary PCM**(`tts_audio_chunk` 与现实现一致)
|
||||||
|
|
||||||
|
然后机端再执行 §1 步骤 6 的分支。
|
||||||
|
|
||||||
|
### 3.3 可选下行
|
||||||
|
|
||||||
|
- **`asr.partial`**:机端 **不得** 用于驱动状态跳转;仅可 UI 展示。
|
||||||
|
- **错误**:`error` / `ASR_FAILED` 等 → 机端播简短失败提示后,建议 **回 `STANDBY` 或回到 `PROMPT_LISTEN`**(产品定)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 5 秒静默超时(闲聊路径)
|
||||||
|
|
||||||
|
| 项 | 约定 |
|
||||||
|
|----|------|
|
||||||
|
| **起算点** | 「**开始收音**」的 **滴声播放结束** 时刻(或滴声后固定 **50~100 ms** 偏移,避免与滴声能量重叠)。 |
|
||||||
|
| **「无说话」** | 麦克 **RMS / VAD** 低于阈值,持续累计 ≥ **5 s**(建议可配置,默认 5)。 |
|
||||||
|
| **期间若开始说话** | 清零超时;**断句上云**后本超时在下一轮「滴声」后重新起算。 |
|
||||||
|
| **触发动作** | 播 **超时提示音** → 进入 **`STANDBY`**(不再滴声、不上云)。 |
|
||||||
|
| **不适用** | **`FLIGHT_CONFIRM`** 整段;确认窗用 **服务端给的 `timeout_sec`**。 |
|
||||||
|
|
||||||
|
**机端配置**(`system.yaml` `cloud_voice`):`listen_silence_timeout_sec`、`post_cue_mic_mute_ms`、`segment_cue_duration_ms`;环境变量见 `main_app.py` 头部说明。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 断句后提示音(工程)
|
||||||
|
|
||||||
|
| 项 | 约定 |
|
||||||
|
|----|------|
|
||||||
|
| 目的 | 用户感知「已截句,可等待播报」 |
|
||||||
|
| 实现 | 机端本地短 WAV / 蜂鸣;时长建议 **≤ 200 ms** |
|
||||||
|
| 回声 | **SEGMENT_END** 阶段闭麦或硬件 AEC;结束后 **≥ 150 ms** 再进入 `UPLOADING` |
|
||||||
|
| 与云端 | **无需** 上传该提示音 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 时序简图(闲聊多轮)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
participant U as 用户
|
||||||
|
participant D as 机端
|
||||||
|
participant S as 服务端
|
||||||
|
|
||||||
|
U->>D: 唤醒词(本地)
|
||||||
|
D->>D: GREETING 播问候
|
||||||
|
D->>D: 滴声 → PROMPT_LISTEN(起 5s 定时)
|
||||||
|
U->>D: 一句语音
|
||||||
|
D->>D: VAD 断句 → 短提示音 → UPLOADING
|
||||||
|
D->>S: turn.audio.start/chunk/end
|
||||||
|
S->>D: asr.partial(可选)
|
||||||
|
S->>D: dialog_result + TTS + turn.complete
|
||||||
|
D->>D: PLAYING_CLOUD_TTS(关麦)
|
||||||
|
alt chitchat
|
||||||
|
D->>D: 再滴声 → PROMPT_LISTEN
|
||||||
|
else flight_intent
|
||||||
|
D->>D: FLIGHT_CONFIRM(独立超时)
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 配置建议(机端)
|
||||||
|
|
||||||
|
| 键 | 默认值 | 说明 |
|
||||||
|
|----|--------|------|
|
||||||
|
| `listen_silence_timeout_sec` | `5` | 滴声后起算 |
|
||||||
|
| `post_cue_mic_mute_ms` | `150`~`300` | 断句提示音后再采集 |
|
||||||
|
| `cue_tone_duration_ms` / `segment_cue_duration_ms` | `≤200` | 断句提示 |
|
||||||
|
| `flight_confirm_handling` | 遵循飞控文档 | 禁用闲聊 5s 覆盖 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 机端开发自检
|
||||||
|
|
||||||
|
- [ ] `STANDBY` 下无 `turn.audio.start`。
|
||||||
|
- [ ] `PLAYING_CLOUD_TTS` 与 `SEGMENT_END` 提示音阶段 **不开麦**(v1)。
|
||||||
|
- [ ] 每轮新 `turn_id`;不并行两轮音频上行。
|
||||||
|
- [ ] `flight_intent` 后进入 `FLIGHT_CONFIRM`,**不**误用 5s 闲聊超时。
|
||||||
|
- [ ] `chitchat` 在 TTS 完成后 **再滴** 再 `PROMPT_LISTEN`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. 非目标(v1)
|
||||||
|
|
||||||
|
- 播报中抢话、打断 TTS、实时 re-prompt。
|
||||||
|
- 服务端驱动「滴声/待机」(均由机端规则实现)。
|
||||||
|
- 连续免唤醒「直接说指令」跨多轮(若需另开 v2)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. 修订记录
|
||||||
|
|
||||||
|
| 版本 | 日期 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| v1 | 2026-04-07 | 首版:小爱类会话 + 双端分工;不含 barge-in |
|
||||||
288
docs/DEPLOYMENT_AND_OPERATIONS.md
Normal file
288
docs/DEPLOYMENT_AND_OPERATIONS.md
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
# 部署与运维手册(项目总结)
|
||||||
|
|
||||||
|
本文面向 **生产/外场部署**:说明 **voice_drone_assistant** 是什么、与 **云端** / **ROS 伴飞桥** / **PX4** 如何衔接,以及 **推荐启动顺序**、**环境变量**与**常见问题**。
|
||||||
|
协议细节见 [`llmcon.md`](llmcon.md),通用配置索引见 [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md),伴飞桥行为见 [`FLIGHT_BRIDGE_ROS1.md`](FLIGHT_BRIDGE_ROS1.md),`flight_intent` 字段见 [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 项目总结
|
||||||
|
|
||||||
|
### 1.1 定位
|
||||||
|
|
||||||
|
**voice_drone_assistant** 是板端 **语音无人机助手**:麦克风 → 降噪/VAD → **SenseVoice STT** → **唤醒词** → 用户一句指令 → **云端 WebSocket**(LLM + TTS)或 **本地 Qwen + Kokoro** → 若服务端返回 **`flight_intent`**,可在本机 **校验后执行**(TCP Socket 旧路径,或 **ROS 伴飞桥** 推荐路径)。
|
||||||
|
|
||||||
|
### 1.2 推荐数据流(方案一:云 → 语音程序 → ROS 桥)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart LR
|
||||||
|
subgraph cloud [云端]
|
||||||
|
WS[WebSocket LLM+TTS]
|
||||||
|
end
|
||||||
|
subgraph board [机载 香橙派等]
|
||||||
|
MIC[麦克风]
|
||||||
|
MAIN[main.py TakeoffPrintRecognizer]
|
||||||
|
ROSPUB[子进程 publish JSON]
|
||||||
|
BRIDGE[flight_bridge ros1_node]
|
||||||
|
MAV[MAVROS]
|
||||||
|
end
|
||||||
|
FCU[PX4 飞控]
|
||||||
|
|
||||||
|
MIC --> MAIN
|
||||||
|
MAIN <-->|WSS pcm_asr_uplink + flight_intent| WS
|
||||||
|
MAIN -->|ROCKET_FLIGHT_INTENT_ROS_BRIDGE| ROSPUB
|
||||||
|
ROSPUB -->|std_msgs/String /input| BRIDGE
|
||||||
|
BRIDGE --> MAV --> FCU
|
||||||
|
```
|
||||||
|
|
||||||
|
- **不**把 ROS 直接暴露给公网:云端只连板子的 **WSS/WS**;飞控由 **本机 MAVROS + 伴飞桥** 执行。
|
||||||
|
- **TCP Socket**(`system.yaml` → `socket_server`)是另一条试飞控通道,与云端 **无关**;未起 Socket 服务端时仅会重连日志,不影响 ROS 方案。
|
||||||
|
|
||||||
|
### 1.3 目录与核心入口(仓库根 = `voice_drone_assistant/`)
|
||||||
|
|
||||||
|
| 路径 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `main.py` | 语音助手入口 |
|
||||||
|
| `with_system_alsa.sh` | 建议包装启动,修正 Conda 与系统 ALSA |
|
||||||
|
| `voice_drone/main_app.py` | 唤醒、云端/本地 LLM、TTS、`flight_intent` 执行策略 |
|
||||||
|
| `voice_drone/flight_bridge/ros1_node.py` | ROS1 订阅 `/input`,执行 `flight_intent` |
|
||||||
|
| `voice_drone/flight_bridge/ros1_mavros_executor.py` | MAVROS:offboard / AUTO.LAND / RTL |
|
||||||
|
| `voice_drone/tools/publish_flight_intent_ros_once.py` | 单次向 ROS 发布 JSON(主程序 ROS 桥会子进程调用) |
|
||||||
|
| `scripts/run_flight_bridge_with_mavros.sh` | 一键:roscore(可选)+ MAVROS + 伴飞桥 |
|
||||||
|
| `scripts/run_flight_intent_bridge_ros1.sh` | 仅伴飞桥(须已有 roscore + MAVROS) |
|
||||||
|
| `voice_drone/config/system.yaml` | 音频、STT、TTS、云端、`assistant` 等 |
|
||||||
|
| `requirements.txt` | Python 依赖;**rospy** 来自 `apt` 的 ROS Noetic,见文件内注释 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 环境与依赖
|
||||||
|
|
||||||
|
### 2.1 硬件与系统(典型)
|
||||||
|
|
||||||
|
- ARM64 板卡(如 RK3588)、ES8388 等音频编解码器、USB/内置麦克风。
|
||||||
|
- Ubuntu 20.04 + **ROS Noetic**(伴飞桥 / MAVROS 路径);同机运行语音进程与 `ros1_node`。
|
||||||
|
- 飞控串口(如 `/dev/ttyACM0`)与 MAVROS `fcu_url` 一致。
|
||||||
|
|
||||||
|
### 2.2 Python
|
||||||
|
|
||||||
|
- Python 3.10+(与原仓库一致即可)。
|
||||||
|
- 在 **`voice_drone_assistant`** 根目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 ROS / MAVROS(伴飞桥方案必选)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo apt install ros-noetic-ros-base ros-noetic-mavros ros-noetic-mavros-extras
|
||||||
|
# 按官方文档执行 mavros 地理库安装(如有)
|
||||||
|
```
|
||||||
|
|
||||||
|
- 语音主程序的 **ROS 桥**子进程会 `source /opt/ros/noetic/setup.bash` 并 **prepend** `PYTHONPATH`,**不要**在未 source ROS 的 shell 里把 `PYTHONPATH` 设成「只有工程根」,否则会找不到 `rospy`(参见 `main_app` 中 `_publish_flight_intent_to_ros_bridge`)。
|
||||||
|
|
||||||
|
### 2.4 模型与权重
|
||||||
|
|
||||||
|
- STT / TTS /(可选)VAD 放入 `models/`,或 `bash scripts/bundle_for_device.sh` 从原仓库打包。
|
||||||
|
- 本地 LLM:GGUF 默认路径或 `ROCKET_LLM_GGUF`;**纯云端对话**时可弱化本地模型,但回退/混合模式仍需。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 部署拓扑
|
||||||
|
|
||||||
|
### 3.1 单机一体化(常见)
|
||||||
|
|
||||||
|
同一台香橙派上同时运行:
|
||||||
|
|
||||||
|
1. **roscore**(若尚无 master,由 `run_flight_bridge_with_mavros.sh` 拉起)。
|
||||||
|
2. **MAVROS**(`px4.launch`,串口连 PX4)。
|
||||||
|
3. **伴飞桥** `python3 -m voice_drone.flight_bridge.ros1_node`(订阅 **`/input`**)。
|
||||||
|
4. **语音** `bash with_system_alsa.sh python main.py`。
|
||||||
|
|
||||||
|
`ROS_MASTER_URI` / `ROS_HOSTNAME`:一键脚本内默认 `http://127.0.0.1:11311` 与 `127.0.0.1`;**新开调试终端** 执行 `rostopic`/`rosservice` 前须自行 `source /opt/ros/noetic/setup.bash` 并 export **同一** `ROS_MASTER_URI`(见下文「常见问题」)。
|
||||||
|
|
||||||
|
### 3.2 网络
|
||||||
|
|
||||||
|
- 板子能访问 **云端 WebSocket**(`ROCKET_CLOUD_WS_URL`)。
|
||||||
|
- PX4 + 遥控 + 安全开关等按外场规范配置;本文不替代安全检校清单。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 启动顺序(推荐)
|
||||||
|
|
||||||
|
### 4.1 终端 A:飞控栈 + 伴飞桥
|
||||||
|
|
||||||
|
在 **`voice_drone_assistant`** 根目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /path/to/voice_drone_assistant
|
||||||
|
bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600
|
||||||
|
```
|
||||||
|
|
||||||
|
脚本会:
|
||||||
|
|
||||||
|
- 设置 `ROS_MASTER_URI`、`ROS_HOSTNAME`(未预设时默认为本机 master);
|
||||||
|
- 如无 master 则启动 **roscore**;
|
||||||
|
- 启动 **MAVROS** 并等待 `/mavros/state` **connected**;
|
||||||
|
- 前台启动伴飞桥,日志中应出现:`flight_intent_bridge 就绪:订阅 /input`。
|
||||||
|
|
||||||
|
**仅桥(已有 MAVROS 时)**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}"
|
||||||
|
export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}"
|
||||||
|
bash scripts/run_flight_intent_bridge_ros1.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 终端 B:语音助手 + 云端 + 执行飞控
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /path/to/voice_drone_assistant
|
||||||
|
|
||||||
|
export ROCKET_CLOUD_VOICE=1
|
||||||
|
export ROCKET_CLOUD_WS_URL='ws://<云主机>:8766/v1/voice/session'
|
||||||
|
export ROCKET_CLOUD_AUTH_TOKEN='<token>'
|
||||||
|
export ROCKET_CLOUD_DEVICE_ID='drone-001' # 可选
|
||||||
|
|
||||||
|
# 云端返回 flight_intent 时是否在机端执行
|
||||||
|
export ROCKET_CLOUD_EXECUTE_FLIGHT=1
|
||||||
|
# 走 ROS 伴飞桥(与 Socket/offboard 序列互斥,勿双开重复执行)
|
||||||
|
export ROCKET_FLIGHT_INTENT_ROS_BRIDGE=1
|
||||||
|
# 可选:ROCKET_FLIGHT_BRIDGE_TOPIC=/input ROCKET_FLIGHT_BRIDGE_WAIT_SUB=2
|
||||||
|
|
||||||
|
# 默认关闭本地「起飞演示」口令直起 offboard;需要时再设为 1
|
||||||
|
# export ROCKET_LOCAL_KEYWORD_TAKEOFF=1
|
||||||
|
|
||||||
|
bash with_system_alsa.sh python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
成功时日志类似:`[飞控-ROS桥] 已发布至 /input`;伴飞桥端出现 `执行 flight_intent:steps=...`。
|
||||||
|
|
||||||
|
### 4.3 配置写进 YAML(可选)
|
||||||
|
|
||||||
|
- 云端:`system.yaml` → `cloud_voice`(`enabled`、`server_url`、`auth_token` 等)。
|
||||||
|
- 本地口令起飞:`assistant.local_keyword_takeoff_enabled`(默认 `false`);环境变量 `ROCKET_LOCAL_KEYWORD_TAKEOFF` **非空时优先生效**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 环境变量速查(飞控与云端)
|
||||||
|
|
||||||
|
| 变量 | 含义 |
|
||||||
|
|------|------|
|
||||||
|
| `ROCKET_CLOUD_VOICE` | `1`:对话走云端 WebSocket |
|
||||||
|
| `ROCKET_CLOUD_WS_URL` | 云端会话地址 |
|
||||||
|
| `ROCKET_CLOUD_AUTH_TOKEN` | WS 鉴权 |
|
||||||
|
| `ROCKET_CLOUD_DEVICE_ID` | 设备 ID(可选) |
|
||||||
|
| `ROCKET_CLOUD_EXECUTE_FLIGHT` | `1`:云端 `flight_intent` 在机端执行 |
|
||||||
|
| `ROCKET_FLIGHT_INTENT_ROS_BRIDGE` | `1`:执行方式为 **发布到 ROS `/input`**,不跑机内 Socket+offboard 序列 |
|
||||||
|
| `ROCKET_FLIGHT_BRIDGE_TOPIC` | 默认 `/input` |
|
||||||
|
| `ROCKET_FLIGHT_BRIDGE_SETUP` | 子进程内 source ROS 的命令,默认 `source /opt/ros/noetic/setup.bash` |
|
||||||
|
| `ROCKET_FLIGHT_BRIDGE_WAIT_SUB` | 发布前等待订阅者的秒数,默认 `2`;`0` 即尽可能快发 |
|
||||||
|
| `ROCKET_LOCAL_KEYWORD_TAKEOFF` | 非空时:`1/true/yes` 开启 **`keywords.yaml` takeoff → 本地 offboard** |
|
||||||
|
| `ROCKET_CLOUD_PX4_CONTEXT_FILE` | 覆盖 `cloud_voice.px4_context_file`,合并进 session.start |
|
||||||
|
|
||||||
|
更多调试变量见 **`voice_drone/main_app.py` 文件头注释** 与 [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) 第 5 节。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 联调与自测
|
||||||
|
|
||||||
|
### 6.1 仅测 ROS 链(无语音)
|
||||||
|
|
||||||
|
终端已 `source /opt/ros/noetic/setup.bash` 且与 master 一致:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rostopic pub -1 /input std_msgs/String \
|
||||||
|
"data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"测\"}'"
|
||||||
|
```
|
||||||
|
|
||||||
|
注意:`std_msgs/String` 在命令行里只能写 **`data: '...json...'`**,不能把 JSON 放在消息顶层。
|
||||||
|
|
||||||
|
### 6.2 确认话题与 master
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
export ROS_MASTER_URI=http://127.0.0.1:11311
|
||||||
|
rosnode list
|
||||||
|
rostopic info /input
|
||||||
|
rosservice list | grep set_mode
|
||||||
|
```
|
||||||
|
|
||||||
|
若 `Unable to communicate with master!`:当前 shell 未连上正在运行的 **roscore**(或未 export 正确 `ROS_MASTER_URI`)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 常见问题(摘录)
|
||||||
|
|
||||||
|
| 现象 | 可能原因 | 处理 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `ModuleNotFoundError: rospy` | 子进程未继承 ROS 的 `PYTHONPATH` | 已修复为 `PYTHONPATH=<根>:$PYTHONPATH`;确保 `ROCKET_FLIGHT_BRIDGE_SETUP` 能 source Noetic |
|
||||||
|
| 语音端「已发布」但桥无日志 | 曾用相对 `input`,与全局 `/input` 不一致 | 伴飞桥默认已改为订阅 **`/input`**;重启桥 |
|
||||||
|
| `set_mode unavailable` / land 失败 | OFFBOARD 断流、MAVROS 异常等 | 伴飞桥降落逻辑已带持续 setpoint + 重连 proxy;仍失败则查 `rosservice`、`/mavros/state`、链路 |
|
||||||
|
| takeoff 超时 | 未进 OFFBOARD、未解锁、定位未就绪 | 查地面站、`/mavros/state`、适当增大 `~takeoff_timeout_sec`(ROS 私有参数) |
|
||||||
|
| ALSA underrun | 播放与采集竞争 | 板端常见;可调缓冲区/设备或 `recognizer.ack_pause_mic_for_playback` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 安全与运维建议
|
||||||
|
|
||||||
|
- 外场前在 **SITL 或系留** 环境验证完整 **`flight_intent`** 序列。
|
||||||
|
- 云端 token、WS URL 勿提交到公开仓库;用环境变量或本机 **overlay** 配置注入。
|
||||||
|
- 升级伴飞桥或 MAVROS 后清日志重试一遍 **`/input`** 手发 JSON。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. 迁移到另一台香橙派:是否只拷贝 `voice_drone_assistant` 即可?
|
||||||
|
|
||||||
|
**结论:目录是「代码 + 配置」的核心载体,但仅靠「整文件夹 scp 过去」通常不够;新板必须再装系统级依赖、模型与(可选)ROS,并按现场改配置。**
|
||||||
|
|
||||||
|
### 9.1 拷贝目录本身会带上什么
|
||||||
|
|
||||||
|
| 已包含 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 全部 Python 源码、`voice_drone/config/*.yaml` 默认配置 | 可直接改 YAML / 环境变量适配新环境 |
|
||||||
|
| `scripts/`、`with_system_alsa.sh`、`docs/` | 启动与说明在包内 |
|
||||||
|
|
||||||
|
### 9.2 新板必须单独准备(不随目录自动存在)
|
||||||
|
|
||||||
|
| 项 | 说明 |
|
||||||
|
|----|------|
|
||||||
|
| **Ubuntu + 音频/ALSA** | 与当前开发板同代或自行适配;录音设备索引可能变化,需重选或设 `ROCKET_INPUT_DEVICE_INDEX` |
|
||||||
|
| **`pip install -r requirements.txt`** | 每台新 Python 环境执行一次(或整体迁移同一 conda 目录) |
|
||||||
|
| **`models/`** | STT/TTS/VAD 体积大,**务必**在本机先 `bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio` 或手工拷入,见 `models/README.txt` |
|
||||||
|
| **`cache/` GGUF** | 纯云端可不强依赖;若需本地 Qwen 回退,拷贝或设 `ROCKET_LLM_GGUF` |
|
||||||
|
| **ROS Noetic + MAVROS** | **apt** 安装;伴飞桥方案 **必选**;`rospy` **不要**指望只靠 pip |
|
||||||
|
| **云端连通** | 新板 IP/防火墙能访问 `ROCKET_CLOUD_WS_URL`;token 用环境变量注入 |
|
||||||
|
| **`dialout` 等权限** | 访问 `/dev/ttyACM0` 的用户加入 `dialout`,否则 MAVROS 无串口 |
|
||||||
|
| **`system.yaml` 现场差异** | `socket_server` IP、可选 `tts.output_device`、若麦索引固定可写 `audio.input_device_index` |
|
||||||
|
|
||||||
|
### 9.3 推荐迁移流程(简表)
|
||||||
|
|
||||||
|
1. 在旧机或 CI:**bundle 模型** → 打包整个 `voice_drone_assistant`(含 `models/`,按需含 `cache/`)。
|
||||||
|
2. 新香橙派:解压到任意路径,安装 **`requirements.txt`**、**ROS+MAVROS**、系统音频工具。
|
||||||
|
3. 用 **`with_system_alsa.sh python main.py`** 试麦与 STT;再按本文 **§4** 双终端起 **桥 + 语音**。
|
||||||
|
4. 首次外场前做一次 **`rostopic pub /input`** 手发 JSON(见 **§6**)。
|
||||||
|
|
||||||
|
### 9.4 常见误区
|
||||||
|
|
||||||
|
- **只拉 Git、不拷 `models/`**:STT/TTS 启动即失败。
|
||||||
|
- **新板 Noetic 未装却开 `ROCKET_FLIGHT_INTENT_ROS_BRIDGE`**:发布子进程仍可能报错。
|
||||||
|
- **假设麦克风设备号一定相同**:Orange Pi 刷机或换内核后常变,以首次启动日志为准。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. 文档索引
|
||||||
|
|
||||||
|
| 文档 | 用途 |
|
||||||
|
|------|------|
|
||||||
|
| [`README.md`](../README.md) | 仓库简介、模型、`bundle` |
|
||||||
|
| [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) | 配置项与日常用法索引 |
|
||||||
|
| **本文** | 部署拓扑、启动顺序、环境变量、联调、**§9 迁移清单** |
|
||||||
|
| [`FLIGHT_BRIDGE_ROS1.md`](FLIGHT_BRIDGE_ROS1.md) | 伴飞桥参数、PX4 行为、`rostopic pub` 注意 |
|
||||||
|
| [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) | JSON 协议 |
|
||||||
|
| [`llmcon.md`](llmcon.md) | 云端协议 |
|
||||||
|
| [`CLOUD_VOICE_FLIGHT_CONFIRM_v1.md`](CLOUD_VOICE_FLIGHT_CONFIRM_v1.md) | **飞控口头二次确认**(闲聊不变、确认/取消/超时)云端与机端字段约定 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*文档版本与仓库同步;若行为与代码不一致,以当前 `main_app.py`、`flight_bridge/*.py` 为准。*
|
||||||
88
docs/FLIGHT_BRIDGE_ROS1.md
Normal file
88
docs/FLIGHT_BRIDGE_ROS1.md
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# Flight Intent 伴飞桥(ROS 1 + MAVROS)
|
||||||
|
|
||||||
|
本目录代码与语音助手 **`main.py` 独立进程**:在 **MAVROS 已连接 PX4** 的前提下,订阅一条 JSON,按 **`FLIGHT_INTENT_SCHEMA_v1.md`** 顺序执行。
|
||||||
|
|
||||||
|
## 依赖
|
||||||
|
|
||||||
|
- Ubuntu / 设备上已装 **ROS Noetic**、`mavros`(与 `scripts/run_px4_offboard_one_terminal.sh` 一致)
|
||||||
|
- Python 能 import:`rospy`、`std_msgs`、`geometry_msgs`、`mavros_msgs`
|
||||||
|
- 本仓库根目录 **`voice_drone_assistant`** 需在 `PYTHONPATH`(启动脚本已设置)
|
||||||
|
|
||||||
|
## 启动
|
||||||
|
|
||||||
|
**推荐(不会单独敲 roslaunch 时用)**:一键拉起 roscore(若尚无)→ MAVROS → 伴飞桥:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd voice_drone_assistant
|
||||||
|
bash scripts/run_flight_bridge_with_mavros.sh
|
||||||
|
bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600
|
||||||
|
```
|
||||||
|
|
||||||
|
**已有 MAVROS 时**只启桥:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd voice_drone_assistant
|
||||||
|
bash scripts/run_flight_intent_bridge_ros1.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
默认节点名(含 anonymous 后缀):`/flight_intent_mavros_bridge_<...>`,默认订阅 **全局 `/input`**(`std_msgs/String`,内容为 JSON)。私有参数 `~input_topic` 可改(例如专用名时填入完整话题)。桥启动日志会打印实际订阅名。
|
||||||
|
|
||||||
|
## JSON 格式
|
||||||
|
|
||||||
|
- **完整** `flight_intent`(与云端相同顶层字段),或
|
||||||
|
- **最小**:`{"actions":[...], "summary":"任意非空"}`(节点内会补 `is_flight_intent/version`)
|
||||||
|
|
||||||
|
校验失败会打 `rospy.logerr`,不执行。
|
||||||
|
|
||||||
|
## 行为映射(首版)
|
||||||
|
|
||||||
|
| `type` | 行为(简述) |
|
||||||
|
|--------|----------------|
|
||||||
|
| `takeoff` | Offboard 位姿:当前点预热 setpoint → `OFFBOARD` + arm → `z0 + Δz`(Δz 来自 `relative_altitude_m` 或参数 `~default_takeoff_relative_m`,与 `px4_ctrl_offboard_demo` 同号约定) |
|
||||||
|
| `hover` / `hold` | 当前位姿 hold 约 1s(持续发 setpoint) |
|
||||||
|
| `wait` | `rospy.Duration`,Offboard 时顺带维持当前点 setpoint |
|
||||||
|
| `goto` | `local_ned` / `body_ned` 增量 → 目标 NED 点,到位容差见 `~goto_position_tolerance` |
|
||||||
|
| `land` | `AUTO.LAND` |
|
||||||
|
| `return_home` | `AUTO.RTL` |
|
||||||
|
|
||||||
|
**注意**:真机前请 SITL 验证;不同 PX4/机型的 `custom_mode` 字符串若不一致需在 `ros1_mavros_executor.py` 中调整。
|
||||||
|
|
||||||
|
## 参数(私有命名空间)
|
||||||
|
|
||||||
|
| 参数 | 默认 | 含义 |
|
||||||
|
|------|------|------|
|
||||||
|
| `~input_topic` | `/input` | 订阅话题(建议绝对路径;勿再用相对名 `input`,否则与 `/input` 对不上) |
|
||||||
|
| `~default_takeoff_relative_m` | `0.5` | `takeoff` 无 `relative_altitude_m` 时 |
|
||||||
|
| `~takeoff_timeout_sec` | `15` | |
|
||||||
|
| `~goto_position_tolerance` | `0.15` | m |
|
||||||
|
| `~goto_timeout_sec` | `60` | |
|
||||||
|
| `~land_timeout_sec` | `45` | land/rtl 等待 disarm 超时 |
|
||||||
|
| `~offboard_pre_stream_count` | `80` | 与 demo 类似 |
|
||||||
|
|
||||||
|
## 与语音程序的关系
|
||||||
|
|
||||||
|
- **`main.py`**:仍可用 Socket / 本地 offboard **演示脚本**(产品过渡期)。
|
||||||
|
- **桥**:适合作为 **长期** MAVROS 执行端;后续可把语音侧改为 **向本节点 `input` 发布 JSON**(`rospy` 或 `roslibpy` 等),而不再直接起 bash demo。
|
||||||
|
|
||||||
|
## 与语音侧联调:`rostopic pub`(注意 `data:`)
|
||||||
|
|
||||||
|
`std_msgs/String` 的 YAML 只有字段 **`data`**,JSON 必须写在 **`data: '...'`** 里;不能把 JSON 直接当消息顶层(否则会 `ERROR: No field name [is_flight_intent]` / `Args are: [data]`)。
|
||||||
|
|
||||||
|
终端 A:已跑 `run_flight_bridge_with_mavros.sh`(或 MAVROS + 桥)。
|
||||||
|
终端 B:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
# 订阅名以桥日志为准,常见为全局 /input
|
||||||
|
rostopic pub -1 /input std_msgs/String \
|
||||||
|
"data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"联调降落\"}'"
|
||||||
|
```
|
||||||
|
|
||||||
|
**悬停 2 秒再降**(须已在空中时再试):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rostopic pub -1 /input std_msgs/String \
|
||||||
|
"data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"hover\",\"args\":{}},{\"type\":\"wait\",\"args\":{\"seconds\":2}},{\"type\":\"land\",\"args\":{}}],\"summary\":\"测\"}'"
|
||||||
|
```
|
||||||
|
|
||||||
|
在语音进程内集成时,建议 **不要在 asyncio 里直接调 rospy**,用 **独立桥进程 + topic** 或 **UNIX socket 转发到桥**。
|
||||||
113
docs/FLIGHT_INTENT_IMPLEMENTATION_PLAN.md
Normal file
113
docs/FLIGHT_INTENT_IMPLEMENTATION_PLAN.md
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
# Flight Intent v1 + 伴飞桥 — 实施计划
|
||||||
|
|
||||||
|
本文档与 [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) 配套,描述从协议闭环到 ROS/PX4 可控的**分阶段交付**。顺序建议按阶段 0→4;各阶段内任务可并行处已标注。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 目标与验收标准
|
||||||
|
|
||||||
|
| 维度 | 验收标准 |
|
||||||
|
|------|-----------|
|
||||||
|
| **协议** | 云端下发的 `flight_intent` 满足 v1:含 `wait`、`takeoff` 可选高度、`trace_id`;L1–L3 校验可自动化 |
|
||||||
|
| **语音客户端** | 能解析并记录完整 `actions`;在 `ROCKET_CLOUD_EXECUTE_FLIGHT=1` 时通过 Socket/桥 执行或与桥约定本地执行 `wait` |
|
||||||
|
| **桥** | 顺序执行 `actions`,每步有超时/失败策略;可对接 MAVROS(或既定 ROS 2 栈)驱动 PX4 |
|
||||||
|
| **安全** | 执行前 L4 门禁、执行中可中断、急停路径明确 |
|
||||||
|
| **回归** | SITL 或台架可重复跑通「起飞 → 悬停 → wait → 降落」等示例 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 阶段 0:对齐与基线(约 0.5~1 天)
|
||||||
|
|
||||||
|
- [ ] 全员精读 `FLIGHT_INTENT_SCHEMA_v1.md`,冻结 **v1 白名单**(`type` / `args` 键)。
|
||||||
|
- [ ] 确认伴飞侧技术选型:**ROS 2 + MAVROS**(或 `px4_ros_com`)与默认 **AUTO vs Offboard** 策略(写入桥 YAML,不写进 JSON)。
|
||||||
|
- [ ] 盘点现有 **Socket 服务**:是否即「桥」或仅转发;是否需新进程 `flight_intent_bridge`。
|
||||||
|
- [ ] 建立 **trace_id** 在日志中的格式(云端 / 语音 / 桥统一)。
|
||||||
|
|
||||||
|
**产出**:架构一页纸(谁消费 WebSocket、谁连 PX4)、桥配置模板路径约定。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 阶段 1:协议与云端(可与阶段 2 并行,约 2~4 天)
|
||||||
|
|
||||||
|
- [ ] **Schema 校验**:服务端对 `flight_intent` 做 L1–L3(必要时 L4 占位);非法则 `routing=error` 或产品协议兜底。
|
||||||
|
- [ ] **LLM 提示词**:只允许 §3.7 中 `type` 与允许键;强调 **时长必须用 `wait`**,禁止用 `summary` 控机。
|
||||||
|
- [ ] **示例与回归用例**:固定 JSON golden(§7.1~§7.3 + 边界:首步 `wait`、`seconds` 超界、多余 `args` 键)。
|
||||||
|
- [ ] **可选 `trace_id`**:服务端生成或在 bundle 层透传。
|
||||||
|
|
||||||
|
**产出**:校验测试集、提示词 MR、发布说明(对客户端可见的字段变更)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 阶段 2:语音客户端(`voice_drone_assistant`)(约 3~5 天)
|
||||||
|
|
||||||
|
可与阶段 1、3 部分并行。
|
||||||
|
|
||||||
|
- [x] **Pydantic**:`voice_drone/core/flight_intent.py`(v2)按 v1 文档收紧动作与 `args`。
|
||||||
|
- [x] **`parse_flight_intent_dict`**:等价 L1–L3 + 首步禁止 `wait`;白名单、`goto.frame`、`wait.seconds`、`takeoff.relative_altitude_m`。
|
||||||
|
- [x] **`main_app`**:`ROCKET_CLOUD_EXECUTE_FLIGHT=1` 时在后台线程 **`_run_cloud_flight_intent_sequence`** 顺序执行;`wait` 用 `time.sleep`;`goto` **单轴** 映射 Socket `Command`;`return_home` 已入 `Command`;**含 `takeoff` 的序列**在 offboard 完成后继续后续步(不再丢失)。
|
||||||
|
- [x] **日志**:序列开始时打印 `trace_id`;`takeoff` 打相对高度提示(offboard 是否消费须自行接参数)。
|
||||||
|
- [x] **单测**:`tests/test_flight_intent.py`(无完整依赖时 goto 用例自动 skip)。
|
||||||
|
|
||||||
|
**产出**:MR 合并后,本地无 PX4 也能跑通解析与 mock 执行。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 阶段 3:伴飞桥 + ROS/PX4(约 5~10 天,视现网复用程度)
|
||||||
|
|
||||||
|
- [x] **进程边界(首版)**:独立 ROS1 节点,订阅 `std_msgs/String` JSON;见 **`docs/FLIGHT_BRIDGE_ROS1.md`**、`scripts/run_flight_intent_bridge_ros1.sh`。
|
||||||
|
- [x] **执行器(首版)**:`voice_drone/flight_bridge/ros1_mavros_executor.py` 单线程顺序执行;`takeoff/goto` 带超时;`land/rtl` 等待 disarm 超时。
|
||||||
|
- [x] **翻译实现(首版 / MAVROS)**:
|
||||||
|
- `takeoff` / `hover` / `wait` / `goto`:`/mavros/setpoint_raw/local`(Offboard)+ `set_mode` / `arming`。
|
||||||
|
- `land` / `return_home`:`AUTO.LAND` / `AUTO.RTL`。
|
||||||
|
- [ ] **安全**:L4(电量、围栏、急停 topic);`wait` 中异常策略。
|
||||||
|
- [ ] **回执**:result topic / 与 `main.py` 的 topic 串联。
|
||||||
|
- [ ] **ROS2 / 仅 TCP 无 ROS**:按需另起接口。
|
||||||
|
|
||||||
|
**产出(当前)**:ROS1 桥可 `rostopic pub` 联调;**待** launch、与语音侧发布 JSON、SITL CI。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 阶段 4:联调、硬化与发布(约 3~7 天)
|
||||||
|
|
||||||
|
- [ ] **端到端**:真机或 SITL:语音 → 云 → 客户端 → 桥 → PX4,带 `trace_id` 串 log。
|
||||||
|
- [ ] **压测与失败注入**:断 WebSocket、桥崩溃重启、Offboard 丢失等(预期行为写进运维文档)。
|
||||||
|
- [ ] **配置与门禁**:默认关闭实飞执行;仅生产镜像打开;参数与围栏双人复核。
|
||||||
|
- [ ] **文档**:更新 `PROJECT_GUIDE.md` 中「飞控路径」链接到本文与 SCHEMA。
|
||||||
|
|
||||||
|
**产出**:发布 checklist、已知限制列表(如某机型仅支持 AUTO 等)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 依赖与风险
|
||||||
|
|
||||||
|
| 风险 | 缓解 |
|
||||||
|
|------|------|
|
||||||
|
| Socket 协议与 `Command` 无法表达多步 | **推荐**由桥消费**完整** `flight_intent` JSON,客户端只负责下发一份;少步经 Socket 逐条 |
|
||||||
|
| Offboard 与 AUTO 混用冲突 | 桥配置单一「主策略」;`goto` 仅在 Offboard 就绪时接受 |
|
||||||
|
| LLM 仍产出非法 JSON | L2 硬拒绝 + 提示词回归 + golden 测试 |
|
||||||
|
| 排期膨胀 | 先交付 **AUTO 模式族 + wait + land**,再迭代复杂 `goto` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 建议里程碑(日历为估算)
|
||||||
|
|
||||||
|
| 里程碑 | 内容 |
|
||||||
|
|--------|------|
|
||||||
|
| **M1** | 阶段 0–1 完成:云校验 + 提示词 + golden |
|
||||||
|
| **M2** | 阶段 2 完成:客户端 strict 模型 + `wait` + 执行路径单一数据源 |
|
||||||
|
| **M3** | 阶段 3 完成:桥 + SITL 跑通 §7.2 |
|
||||||
|
| **M4** | 阶段 4:联调签字 + 生产策略 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 文档索引
|
||||||
|
|
||||||
|
| 文档 | 用途 |
|
||||||
|
|------|------|
|
||||||
|
| [`FLIGHT_INTENT_SCHEMA_v1.md`](FLIGHT_INTENT_SCHEMA_v1.md) | 字段、校验、桥分层、ROS 参考 |
|
||||||
|
| [`PROJECT_GUIDE.md`](PROJECT_GUIDE.md) | 仓库总览与运行方式 |
|
||||||
|
| 本文 | 任务拆解、顺序、验收 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**版本**:2026-04-07;随 SCHEMA v1 修订同步更新本计划中的阶段勾选与工期估算。
|
||||||
372
docs/FLIGHT_INTENT_SCHEMA_v1.md
Normal file
372
docs/FLIGHT_INTENT_SCHEMA_v1.md
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
# 云端高层飞控意图 JSON 规范 v1(完整版)
|
||||||
|
|
||||||
|
> **定位**:定义 WebSocket `dialog_result.flight_intent` 中的**语义对象**与**伴飞侧执行约定**,不是 MAVLink 二进制帧。
|
||||||
|
> **目标**:
|
||||||
|
>
|
||||||
|
> 1. **协议**:客户端与云端可 **100% 按字段表 strict 解析**;**禁止**用自然语言或 `summary` 驱动机控。
|
||||||
|
> 2. **桥(companion)**:按本文执行 **有序 `actions`**、校验、排队与安全门,再译为 PX4 可接受的模式 / 指令 / Offboard setpoint。
|
||||||
|
> 3. **ROS**:为 MAVROS(或等价 ROS 2 封装)提供**参考映射表**;具体 topic/service 名称以装机软件栈为准。
|
||||||
|
|
||||||
|
**协议(bundle)**:`proto_version: "1.0"`,会话 `transport_profile: "pcm_asr_uplink"`(与机端 CloudVoiceClient 一致)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 顶层对象 `flight_intent`
|
||||||
|
|
||||||
|
当 `routing === "flight_intent"` 时,`flight_intent` **必须非 null**,且为下表 JSON 对象(键顺序任意)。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `is_flight_intent` | `boolean` | 是 | **必须为 `true`** |
|
||||||
|
| `version` | `integer` | 是 | **Schema 版本,本文档固定为 `1`** |
|
||||||
|
| `actions` | `array` | 是 | **有序**动作列表,按**时间先后**执行(见 §5、§10) |
|
||||||
|
| `summary` | `string` | 是 | 一句人类可读中文摘要(播报/日志);**不参与机控解析** |
|
||||||
|
| `trace_id` | `string` | 否 | **端到端追踪 ID**(建议 UUID 或雪花 ID);用于桥、ROS 节点与日志关联;长度建议 ≤ 128 |
|
||||||
|
|
||||||
|
**禁止字段**:除上表外,顶层不得出现其它键(便于 `strict` 解析)。扩展须 **递增 `version`**(见 §8)。
|
||||||
|
|
||||||
|
**字符编码**:UTF-8。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. `actions[]` 通用形式
|
||||||
|
|
||||||
|
每个元素:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "<ActionType>", "args": { ... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 |
|
||||||
|
|------|------|------|
|
||||||
|
| `type` | `string` | 是,取值限于 §3 枚举 |
|
||||||
|
| `args` | `object` | 是,允许为空对象 `{}`;**仅允许**各小节表中列出的键 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. `ActionType` 与白名单 `args`
|
||||||
|
|
||||||
|
以下 **`type` 仅允许小写**。未列出的值在 v1 **非法**。
|
||||||
|
|
||||||
|
### 3.1 `takeoff`
|
||||||
|
|
||||||
|
| args 键 | 类型 | 必填 | 说明 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| `relative_altitude_m` | `number` | 否 | 相对起飞点(或飞控定义的 TAKEOFF 参考)的**目标高度**,单位 **米**,须 **> 0**。省略则 **完全由机端/飞控缺省参数** 决定(与旧版「空 args」语义一致) |
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "takeoff", "args": {} }
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "takeoff", "args": { "relative_altitude_m": 5 } }
|
||||||
|
```
|
||||||
|
|
||||||
|
**桥 / ROS 映射提示**:PX4 `TAKEOFF` 模式、`MAV_CMD_NAV_TAKEOFF`;相对高度常与 `MIS_TAKEOFF_ALT` 或命令参数结合,**以装机参数为准**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.2 `land`
|
||||||
|
|
||||||
|
| args 键 | 类型 | 必填 | 说明 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| — | — | — | v1 无参;降落行为由飞控 `AUTO.LAND` 等策略决定 |
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "land", "args": {} }
|
||||||
|
```
|
||||||
|
|
||||||
|
**桥 / ROS 映射提示**:`AUTO.LAND`、`MAV_CMD_NAV_LAND`;固定翼 / 多旋翼路径不同,由机端处理。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.3 `return_home`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "return_home", "args": {} }
|
||||||
|
```
|
||||||
|
|
||||||
|
语义:**返航至 Home 并按飞控策略降落或盘旋**(与 PX4 **RTL** 概义一致)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.4 `hover` 与 `hold`
|
||||||
|
|
||||||
|
二者在 v1 **语义等价**:**在当前位置附近保持**(多旋翼常见为位置保持;固定翼可能映射为 Loiter,由机端按机型解释)。
|
||||||
|
|
||||||
|
| args 键 | 类型 | 必填 | 说明 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| — | — | — | 仅 `{}` 合法;表示进入保持,**不隐含**持续时间(时长用 §3.6 `wait`) |
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "hover", "args": {} }
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "hold", "args": {} }
|
||||||
|
```
|
||||||
|
|
||||||
|
**约定**:同一 `actions` 序列建议只选 `hover` 或 `hold` 一种命名;解析端可映射到同一 PX4 行为。
|
||||||
|
|
||||||
|
**互操作(非规范首选)**:个别上游可能错误输出 `"args": { "duration": 3 }`(秒)。**伴飞客户端**(如本仓库 `flight_intent.py`)可在校验时将其**折叠**为:`hover`(无 `duration`)+ `wait`,与上表典型组合等价;**新开发的上游仍应只产 `wait`**。
|
||||||
|
|
||||||
|
**典型组合**(「悬停 3 秒后降落」):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{ "type": "takeoff", "args": {} },
|
||||||
|
{ "type": "hover", "args": {} },
|
||||||
|
{ "type": "wait", "args": { "seconds": 3 } },
|
||||||
|
{ "type": "land", "args": {} }
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.5 `goto` — 相对/局部位移
|
||||||
|
|
||||||
|
| args 键 | 类型 | 必填 | 说明 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| `frame` | `string` | 是 | 坐标系,取值见 §4 |
|
||||||
|
| `x` | `number` \| `null` | 否 | **米**;`null` 或 **省略** 表示该轴**无位移意图**(机端保持) |
|
||||||
|
| `y` | `number` \| `null` | 否 | 同上 |
|
||||||
|
| `z` | `number` \| `null` | 否 | 同上 |
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "goto",
|
||||||
|
"args": {
|
||||||
|
"frame": "local_ned",
|
||||||
|
"x": 100,
|
||||||
|
"y": 0,
|
||||||
|
"z": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**语义**:在 `frame` 下相对当前位置的**增量**。v1 **不**定义绝对经纬度航点;若未来需要,应 **v2** 增加 `goto_global` / `waypoint`。
|
||||||
|
|
||||||
|
**口语映射示例**:「向前飞 10 米」可建模为 `frame: "body_ned"`, `x: 10`(前为 x+,与 §4 一致)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.6 `wait` — 纯时间等待
|
||||||
|
|
||||||
|
**不包含**模式切换;桥在**本地计时**后继续下一步。用于「悬停多久」「停顿再执行」等。
|
||||||
|
|
||||||
|
| args 键 | 类型 | 必填 | 说明 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| `seconds` | `number` | 是 | 等待秒数,须满足 **0 < seconds ≤ 3600**(上限可防止 LLM 写极大值;产品可改小) |
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "type": "wait", "args": { "seconds": 3 } }
|
||||||
|
```
|
||||||
|
|
||||||
|
**安全**:等待期间桥须持续监测遥测(失联、低电量、姿态异常等),**可中断**序列并转入 `RTL` / `LAND` / `HOLD`(策略见 §10.4)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.7 v1 动作类型一览
|
||||||
|
|
||||||
|
| `type` | 必填 `args` 键 | 备注 |
|
||||||
|
|--------|----------------|------|
|
||||||
|
| `takeoff` | 无 | 可选 `relative_altitude_m` |
|
||||||
|
| `land` | 无 | |
|
||||||
|
| `return_home` | 无 | |
|
||||||
|
| `hover` | 无 | |
|
||||||
|
| `hold` | 无 | 与 `hover` 等价 |
|
||||||
|
| `goto` | `frame` | 可选 x/y/z |
|
||||||
|
| `wait` | `seconds` | |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. `frame`(仅 `goto`)
|
||||||
|
|
||||||
|
| 取值 | 含义 |
|
||||||
|
|------|------|
|
||||||
|
| `local_ned` | **局部 NED**:北(x)-东(y)-地(z),单位 m;**向下为 z 正**(与 PX4 `LOCAL_NED` 常见用法一致) |
|
||||||
|
| `body_ned` | **机体系**:**前(x)-右(y)-下(z)**,单位 m;桥或 ROS 侧需转换到 NED / setpoint |
|
||||||
|
|
||||||
|
**v1 仅此两值**;其它字符串 **L2 非法**。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 序列语义与组合规则
|
||||||
|
|
||||||
|
- **`actions` 有序**:严格对应口语**时间顺序**(先执行索引 0,再 1,…)。
|
||||||
|
- **空列表**:**不允许**(至少 1 个元素)。
|
||||||
|
- **`wait`**:不改变飞控模式;若需「边悬停边等」,应先 `hover`/`hold` 再 `wait`(或在上一步已进入位置模式的假定下仅 `wait`,由机端策略定义;**推荐**显式 `hover` 再 `wait`)。
|
||||||
|
- **首步**:首元素为 `wait` **不推荐**(飞机未起飞则等待无控飞意义);服务端可做 **L4 警告或拒绝**。
|
||||||
|
- **`takeoff` 后出现 `goto`**:桥应确保已有位置估计/GPS 等前置条件,否则拒绝并回报原因。
|
||||||
|
- **重复动作**:不禁止连续多个 `goto` / `wait`;机端可合并或排队。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 校验分级(服务端 + 桥建议共用)
|
||||||
|
|
||||||
|
| 级别 | 内容 |
|
||||||
|
|------|------|
|
||||||
|
| **L1 结构** | JSON 可解析;`is_flight_intent===true`;`version===1`;`actions` 为非空数组;`summary` 为非空字符串;`trace_id` 若存在则为 string。 |
|
||||||
|
| **L2 枚举** | 每个 `action.type` ∈ §3.7;`goto` 含合法 `frame`;各 `args` **仅含**该 `type` 允许的键(无多余键)。 |
|
||||||
|
| **L3 数值** | `relative_altitude_m` 若存在则 **> 0** 且建议 capped(如 ≤ 500);`wait.seconds` 在 **(0, 3600]**;`goto` 的 x/y/z 为有限 number 或 null;位移模长可设上限(如 10e3 m)。 |
|
||||||
|
| **L4 语义** | 结合 `session.start.client.px4`(机型、是否支持 Offboard、地理围栏等);禁止不合法序列(如无定位时 `goto`);不通过则 `error` 或带工程约定 `warnings`(v2 可标准化 `warnings` 数组)。 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 完整示例
|
||||||
|
|
||||||
|
### 7.1 起飞 → 北飞 100m → 悬停
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"is_flight_intent": true,
|
||||||
|
"version": 1,
|
||||||
|
"trace_id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"actions": [
|
||||||
|
{ "type": "takeoff", "args": {} },
|
||||||
|
{
|
||||||
|
"type": "goto",
|
||||||
|
"args": { "frame": "local_ned", "x": 100, "y": 0, "z": 0 }
|
||||||
|
},
|
||||||
|
{ "type": "hover", "args": {} }
|
||||||
|
],
|
||||||
|
"summary": "起飞后向北飞约100米并悬停"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 起飞 → 悬停 3 秒 → 降落
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"is_flight_intent": true,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{ "type": "takeoff", "args": { "relative_altitude_m": 3 } },
|
||||||
|
{ "type": "hover", "args": {} },
|
||||||
|
{ "type": "wait", "args": { "seconds": 3 } },
|
||||||
|
{ "type": "land", "args": {} }
|
||||||
|
],
|
||||||
|
"summary": "起飞至约3米高,悬停3秒后降落"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.3 返航
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"is_flight_intent": true,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{ "type": "return_home", "args": {} }
|
||||||
|
],
|
||||||
|
"summary": "返航至 Home"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 演进与扩展
|
||||||
|
|
||||||
|
- **新增 `type`、新 `frame`、顶层字段**:须 **递增 `version`**(如 `2`)并附迁移说明。
|
||||||
|
- **严禁**:在 `flight_intent` 内增加自由文本字段用于机动解释(仅 `summary` 可读)。
|
||||||
|
- **调试**:可在外层 bundle(非 `flight_intent` 体内)附加 `schema: "cloud_voice.flight_intent@1"`,由工程约定。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. JSON → PX4 责任边界(摘要)
|
||||||
|
|
||||||
|
| JSON `type` | 机端典型职责(PX4 侧,非规范强制) |
|
||||||
|
|-------------|--------------------------------------|
|
||||||
|
| `return_home` | RTL / `MAV_CMD_NAV_RETURN_TO_LAUNCH` 等 |
|
||||||
|
| `takeoff` | TAKEOFF / `MAV_CMD_NAV_TAKEOFF`,高度来自 args 或参数 |
|
||||||
|
| `land` | LAND 模式 / `MAV_CMD_NAV_LAND` |
|
||||||
|
| `goto` | Offboard 轨迹、外部跟踪或 Mission 航点(**桥根据策略选一**) |
|
||||||
|
| `hover` / `hold` | LOITER / HOLD / 位置保持 setpoint |
|
||||||
|
| `wait` | 仅伴飞计时;**不发**模式切换 MAV 命令(除非实现为「保持当前模式下的阻塞」) |
|
||||||
|
|
||||||
|
**不重样规定**:MAVLink messageId、发送频率、Offboard 心跳、EKF 就绪条件由 **companion + PX4 装机** 保证。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. 伴飞桥(Bridge)设计要点
|
||||||
|
|
||||||
|
本节约定:**桥** = 运行在伴飞计算机上的进程(可与语音同源或独立),负责消费 `flight_intent`(或等价 JSON),**绝不**把原始 LLM 文本直接发给 PX4。
|
||||||
|
|
||||||
|
### 10.1 逻辑分层
|
||||||
|
|
||||||
|
1. **接入**:WebSocket 回调 → 解析 `flight_intent`;或订阅 ROS Topic `flight_intent/json`;或 TCP 接收与本文相同的 JSON。
|
||||||
|
2. **校验**:至少 L1–L3;有 px4 上下文时做 L4。
|
||||||
|
3. **执行器**:对 `actions` **单线程顺序**执行;内部每步调用 **翻译器**(见 §10.3)。
|
||||||
|
4. **遥测与安全**:每步前置检查(模式、解锁、定位、电量、围栏);执行中 watchdog;可打断队列。
|
||||||
|
5. **回执(建议)**:ROS 发布 `flight_intent/result` 或写日志/Socket:success / rejected / aborted + `trace_id` + 步号 + 原因码。
|
||||||
|
|
||||||
|
### 10.2 与语音客户端的关系(本仓库)
|
||||||
|
|
||||||
|
- 语音侧可将 **`flight_intent` 映射** 为现有 `Command`(`command` + `params` + `sequence_id` + `timestamp`)经 **Socket** 发到桥;或由桥 **直接订阅云端结果**(二选一,避免双源)。
|
||||||
|
- **`wait`**:若 Socket 协议暂无对应 `Command`,桥在本地对「已解析的 `actions` 列表」执行 `wait`,**不必**经 Socket 转发计时。
|
||||||
|
- **扩展 `Command`**:若希望所有步骤可经 Socket 观测,可增加 `command: "noop"` + `params.duration` 仅作日志,但 **推荐** 桥本地处理 `wait`。
|
||||||
|
|
||||||
|
### 10.3 翻译器(`type` → 行为)
|
||||||
|
|
||||||
|
实现为代码表 + 机型分支,示例:
|
||||||
|
|
||||||
|
| `type` | 桥内典型步骤(抽象) |
|
||||||
|
|--------|----------------------|
|
||||||
|
| `takeoff` | 检查 arming 策略 → 发送起飞命令/切 TAKEOFF → 等待「达到 hover 可接受高度」或超时 |
|
||||||
|
| `land` | 切 LAND 或发 NAV_LAND → 监测直到 disarm 或超时 |
|
||||||
|
| `return_home` | 切 RTL |
|
||||||
|
| `hover`/`hold` | 切 AUTO.LOITER 或发位置保持 setpoint(Offboard 路径则发零速/当前位 setpoint) |
|
||||||
|
| `goto` | 按 `frame` 解算目标 → Offboard 轨迹或上传迷你 mission → 等待到达容差或超时 |
|
||||||
|
| `wait` | `sleep(seconds)` + 可中断环形检查遥测 |
|
||||||
|
|
||||||
|
每步应定义 **超时** 与 **失败策略**(中止整段序列 / 仅跳过一步)。
|
||||||
|
|
||||||
|
### 10.4 安全与中断
|
||||||
|
|
||||||
|
- **急停 / 人机优先级**:本地硬件或 ROS `/emergency_hold` 等应能 **清空队列** 并进入安全模式。
|
||||||
|
- **云断连**:不要求中断已在执行的序列(产品可配置「断连即 RTL」)。
|
||||||
|
- **`wait` 期间**:持续判据;触发阈值则 **中止等待** 并执行安全动作。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. ROS / MAVROS 实施参考
|
||||||
|
|
||||||
|
以下为方便对接 **ROS 2 + MAVROS**(或 `px4_ros_com`)的**参考映射**;实际包名、话题名、QoS 以你方 `mavros` 版本与 launch 为准。
|
||||||
|
|
||||||
|
### 11.1 常用接口类型
|
||||||
|
|
||||||
|
| 目的 | 常见 ROS 2 形态 | 说明 |
|
||||||
|
|------|------------------|------|
|
||||||
|
| 模式切换 | `mavros_msgs/srv/VehicleCmd` 或 SetMode 等价服务 | 切 `AUTO.TAKEOFF`, `AUTO.LAND`, `AUTO.LOITER`, `AUTO.RTL` 等 |
|
||||||
|
| 解锁/上锁 | `cmd/arming` 服务或 VehicleCommand | 桥策略决定是否自动 arm |
|
||||||
|
| Offboard 轨迹 | `trajectory_setpoint`、`offboard_control_mode`(PX4 官方 ROS 2 示例) | 用于 `goto` / `hover` 的 setpoint 路径 |
|
||||||
|
| 状态反馈 | `vehicle_status`、`local_position`、电池 topic | L4 与每步完成判定 |
|
||||||
|
| 长航指令 | `Mission`、`CMD` 接口 | 复杂航迹可选用 mission 上传 |
|
||||||
|
|
||||||
|
### 11.2 JSON → ROS 责任划分建议
|
||||||
|
|
||||||
|
- **桥节点**订阅或接收 `flight_intent`,执行 §10.3,并调用 **MAVROS / px4_ros_com** 客户端。
|
||||||
|
- **飞控仿真**:同一套 `flight_intent` 可在 SITL 上回放,便于 CI。
|
||||||
|
- **单飞控单 writer**:同一时刻建议只有一个节点向 Offboard 端口写 setpoint,避免竞争。
|
||||||
|
|
||||||
|
### 11.3 与 PX4 模式的关系(概念)
|
||||||
|
|
||||||
|
- **AUTO 模式族**(TAKEOFF / LAND / LOITER / RTL):适合 `takeoff`、`land`、`return_home`、部分 `hover`。
|
||||||
|
- **Offboard**:适合连续 `goto`、精细悬停;桥需负责 **先切 Offboard 再发 setpoint**,并满足 PX4 Offboard 丢包监测。
|
||||||
|
- 具体选 AUTO 还是 Offboard 由 **桥配置**(YAML)决定,**不写入** `flight_intent` JSON(保持云侧与机型解耦)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. 与当前仓库实现的对齐清单
|
||||||
|
|
||||||
|
| 项 | 建议 |
|
||||||
|
|----|------|
|
||||||
|
| Pydantic | `FlightIntentPayload` / `FlightIntentAction`:收紧 `type` Literal;`args` 按 §3 分类型或 discriminated union |
|
||||||
|
| 云端校验 | `_validate_flight_intent`:L2 白名单 + `goto.frame` + `wait.seconds` + `takeoff.relative_altitude_m` |
|
||||||
|
| LLM 提示词 | 仅允许 §3.7 中 `type` 与各 `args` 键;**必须**用 `wait` 表达明确停顿时长 |
|
||||||
|
| `main_app` | `land`/`hover` 已有 Socket 映射;`goto`/`return_home`/`takeoff`/`wait` 需在桥或 Socket 侧补全 |
|
||||||
|
| `Command` | 可扩展 `Literal` 与 `CommandParams`,或与桥约定「语音只发 Socket,复杂序列由桥执行」 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**文档版本**:2026-04-07(修订:增加 `wait`、`takeoff` 可选高度、`trace_id`、桥与 ROS 章节)。与 **`flight_intent.version === 1`** 对应。
|
||||||
148
docs/PROJECT_GUIDE.md
Normal file
148
docs/PROJECT_GUIDE.md
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# voice_drone_assistant — 项目说明与配置指南
|
||||||
|
|
||||||
|
面向部署与二次开发:**目录结构**、**配置文件用法**、**启动与日常操作**、**与云端/飞控的关系**。**外场统一部署与双终端启动顺序**见 **`docs/DEPLOYMENT_AND_OPERATIONS.md`**;协议细节以 `docs/llmcon.md` 为准。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 项目做什么
|
||||||
|
|
||||||
|
- **麦克风** → 预处理(降噪/AGC)→ **VAD 切段** → **SenseVoice STT** → **唤醒词**
|
||||||
|
- **关键词起飞(offboard 演示,默认关闭)**:`system.yaml` → **`assistant.local_keyword_takeoff_enabled`** 或 **`ROCKET_LOCAL_KEYWORD_TAKEOFF=1`** 开启后,`keywords.yaml` 里 **`takeoff` 词表**(如「起飞演示」)→ 提示音 + offboard 脚本;飞控主路径推荐 **云端 `flight_intent` + ROS 伴飞桥**
|
||||||
|
- **其它语音**:本地 **Qwen + Kokoro**,或 **云端 WebSocket**(LLM + TTS 上云,见 `cloud_voice`)
|
||||||
|
- 可选通过 **TCP Socket** 下发结构化飞控命令(`VoiceCommandRecognizer` 路径;`TakeoffPrintRecognizer` 默认不在启动时连 Socket,飞控多为云端 JSON + 可选 `ROCKET_CLOUD_EXECUTE_FLIGHT`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 目录结构(仓库根 = `voice_drone_assistant/`)
|
||||||
|
|
||||||
|
| 路径 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `main.py` | 程序入口(会 `chdir` 到本目录并跑 `voice_drone.main_app`) |
|
||||||
|
| `with_system_alsa.sh` | 在 Conda/残缺 ALSA 环境下修正 `LD_LIBRARY_PATH`,建议始终包一层启动 |
|
||||||
|
| `requirements.txt` | Python 依赖(含 `websocket-client` 等) |
|
||||||
|
| **`voice_drone/main_app.py`** | 主流程:唤醒、问候/快路径、关麦、LLM/云端、TTS、offboard |
|
||||||
|
| **`voice_drone/core/`** | 音频采集、预处理、VAD、STT、TTS、Socket、云端 WS、唤醒、命令、文本预处理、配置加载 |
|
||||||
|
| **`voice_drone/flight_bridge/`** | 伴飞桥(ROS1+MAVROS):`flight_intent` → 飞控;说明见 **`docs/FLIGHT_BRIDGE_ROS1.md`** |
|
||||||
|
| **`voice_drone/config/`** | 各类 YAML,见下文「配置文件」 |
|
||||||
|
| **`voice_drone/logging_/`** | 日志与彩色输出 |
|
||||||
|
| **`voice_drone/tools/`** | `config_loader` 等工具 |
|
||||||
|
| **`docs/`** | `PROJECT_GUIDE.md`(本文)、`llmcon.md`(云端协议)、`clientguide.md`(联调与示例) |
|
||||||
|
| **`scripts/`** | `run_px4_offboard_one_terminal.sh`;**伴飞桥** `run_flight_bridge_with_mavros.sh`(含 MAVROS)、`run_flight_intent_bridge_ros1.sh`(仅桥);另有 `generate_wake_greeting_wav.py`、`bundle_for_device.sh` |
|
||||||
|
| **`assets/tts_cache/`** | 唤醒问候等预生成 WAV(可自动生成) |
|
||||||
|
| **`models/`** | STT/TTS/VAD ONNX 等(需自备或 bundle,见 `models/README.txt`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 配置文件一览
|
||||||
|
|
||||||
|
配置由 `voice_drone/core/configuration.py` 在进程启动时读入;主文件为 **`voice_drone/config/system.yaml`**(路径相对 **`voice_drone_assistant` 根目录**)。
|
||||||
|
|
||||||
|
| 文件 | 作用 |
|
||||||
|
|------|------|
|
||||||
|
| **`system.yaml`** | **总控**:`audio`(采样、设备、AGC、降噪)、`vad`、`stt`、`tts`、`cloud_voice`、`socket_server`、`text_preprocessor`、`recognizer`(VAD 能量门槛、尾静音、问候/TTS 关麦等) |
|
||||||
|
| **`wake_word.yaml`** | 唤醒词主词、变体、模糊/部分匹配策略 |
|
||||||
|
| **`keywords.yaml`** | 命令关键词与同义词(供文本预处理映射到 `Command`) |
|
||||||
|
| **`command_.yaml`** | 各飞行动作默认 `distance/speed/duration`(与 `Command` 联动) |
|
||||||
|
| **`cloud_voice_px4_context.yaml`** | 云端 **`session.start.client` 扩展**:`vehicle_class`、`mav_type`、`default_setpoint_frame`、`extras` 等,供服务端 LLM 生成 PX4 相关指令;路径在 `system.yaml` → `cloud_voice.px4_context_file`,也可用环境变量 **`ROCKET_CLOUD_PX4_CONTEXT_FILE`** 覆盖 |
|
||||||
|
|
||||||
|
修改 YAML 后需**重启** `main.py` 生效(`SYSTEM_CLOUD_VOICE_PX4_CONTEXT` 等在 import 时加载一次)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. `system.yaml` 常用区块(索引)
|
||||||
|
|
||||||
|
- **`audio`**:采样率、`frame_size`、`input_device_index`(`null` 则枚举设备)、`prefer_stereo_capture`(ES8388 等)、`noise_reduce`、`agc*`、`agc_release_alpha`
|
||||||
|
- **`vad`**:Silero 用阈值、`end_frame` 等(能量 VAD 时部分由 `recognizer` 覆盖)
|
||||||
|
- **`stt`**:SenseVoice 模型路径、ORT 线程等
|
||||||
|
- **`tts`**:Kokoro 目录、音色 `voice`、`speed`、`output_device`、`playback_*`
|
||||||
|
- **`cloud_voice`**:`enabled`、`server_url`、`auth_token`、`device_id`、`timeout`、`fallback_to_local`、`px4_context_file`
|
||||||
|
- **`socket_server`**:试飞控 TCP 地址、`reconnect_interval`、`max_retries`(`-1` 为断线持续重连直至成功)
|
||||||
|
- **`recognizer`**:`trailing_silence_seconds`、`vad_backend`(`energy`/`silero`)、`energy_vad_*`、`energy_vad_utt_peak_decay`、`energy_vad_end_peak_ratio`、`pre_speech_max_seconds`、`ack_pause_mic_for_playback`、应答 TTS 等
|
||||||
|
|
||||||
|
更细的参数含义以各 YAML 内注释为准。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 系统使用方式
|
||||||
|
|
||||||
|
### 5.1 推荐启动命令
|
||||||
|
|
||||||
|
在 **`voice_drone_assistant` 根目录**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash with_system_alsa.sh python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或使用模块方式:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash with_system_alsa.sh python -m voice_drone.main_app
|
||||||
|
```
|
||||||
|
|
||||||
|
录音设备:**首次**可交互选择;非交互时可设 `ROCKET_INPUT_DEVICE_INDEX` 或使用 `main.py --input-index N` / `--non-interactive`(详见 `main_app` 内 `argparse` 与文件头注释)。
|
||||||
|
|
||||||
|
### 5.2 典型工作流(默认 `TakeoffPrintRecognizer`)
|
||||||
|
|
||||||
|
1. 说唤醒词(如「无人机」);若**同句带指令**,会**跳过问候与滴声**,直接关麦处理:命中 `keywords.yaml` 的 **takeoff** 则 offboard,否则走 LLM/云端。
|
||||||
|
2. 若**只唤醒**,则问候(或缓存 WAV)+ 可选滴声 → 再说**一句**指令。
|
||||||
|
3. 云端模式:指令以文本上云,TTS 多为服务端 PCM;本地模式:Qwen 推理 + Kokoro 播报。
|
||||||
|
|
||||||
|
### 5.3 云端语音(可选)
|
||||||
|
|
||||||
|
- `system.yaml` 里 `cloud_voice.enabled: true`,或环境变量 **`ROCKET_CLOUD_VOICE=1`**
|
||||||
|
- **`ROCKET_CLOUD_WS_URL`**、`ROCKET_CLOUD_AUTH_TOKEN`、可选 **`ROCKET_CLOUD_DEVICE_ID`**(可覆盖 yaml)
|
||||||
|
- PX4 语境:见 `cloud_voice_px4_context.yaml` / **`ROCKET_CLOUD_PX4_CONTEXT_FILE`**
|
||||||
|
- 协议与消息类型: **`docs/llmcon.md`**
|
||||||
|
- 飞控 JSON 是否机端执行: **`ROCKET_CLOUD_EXECUTE_FLIGHT=1`**;走 ROS 伴飞桥时再设 **`ROCKET_FLIGHT_INTENT_ROS_BRIDGE=1`**(详见 **`docs/DEPLOYMENT_AND_OPERATIONS.md`**)
|
||||||
|
|
||||||
|
### 5.4 本地大模型与 TTS
|
||||||
|
|
||||||
|
- GGUF:`cache/` 默认路径或 **`ROCKET_LLM_GGUF`**
|
||||||
|
- 关闭对话:**`ROCKET_LLM_DISABLE=1`**
|
||||||
|
- 流式输出:**`ROCKET_LLM_STREAM=0`** 可改为整段生成后再播(调试)
|
||||||
|
- 详细列表见 **`voice_drone/main_app.py` 文件头部注释**。
|
||||||
|
|
||||||
|
### 5.5 其它实用环境变量(摘录)
|
||||||
|
|
||||||
|
| 变量 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `ROCKET_ENERGY_VAD` | `1` 时使用能量 VAD(板载麦常见) |
|
||||||
|
| `ROCKET_PRINT_STT` / `ROCKET_PRINT_VAD` | 终端打印 STT/VAD 诊断 |
|
||||||
|
| `ROCKET_CLOUD_TURN_RETRIES` | 云端 WS 单轮失败重连重试次数(默认 3) |
|
||||||
|
| `ROCKET_PRINT_LLM_STREAM` | 云端流式字 `llm.text_delta` 打印到终端 |
|
||||||
|
| `ROCKET_WAKE_PROMPT_BEEP` | `0` 关闭问候后滴声 |
|
||||||
|
| `ROCKET_MIC_RESTART_SETTLE_MS` | 播完 TTS 恢复麦克风后的等待毫秒 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 相关文档与代码入口
|
||||||
|
|
||||||
|
| 文档 | 内容 |
|
||||||
|
|------|------|
|
||||||
|
| **`README.md`** | 简版说明、bundle 到香橙派、与原仓库关系 |
|
||||||
|
| **`docs/DEPLOYMENT_AND_OPERATIONS.md`** | **部署与外场启动**:拓扑、`ROS_MASTER_URI`、双终端启动顺序、环境变量速查、联调 |
|
||||||
|
| **`docs/PROJECT_GUIDE.md`** | 本文:目录、配置、使用方式总览 |
|
||||||
|
| **`docs/FLIGHT_BRIDGE_ROS1.md`** | ROS1 伴飞桥、MAVROS、`/input`、`rostopic pub` |
|
||||||
|
| **`docs/llmcon.md`** | 云端 WebSocket 消息类型与客户端约定 |
|
||||||
|
| **`docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md`** | 云端 **`dialog_result` v1**(`protocol=cloud_voice_dialog_v1`,闲聊/飞控分流 + `confirm`) |
|
||||||
|
|
||||||
|
| 能力 | 主要代码 |
|
||||||
|
|------|-----------|
|
||||||
|
| 音频采集/AGC | `voice_drone/core/audio.py` |
|
||||||
|
| 能量/Silero VAD | `voice_drone/core/recognizer.py`、`voice_drone/core/vad.py` |
|
||||||
|
| STT | `voice_drone/core/stt.py` |
|
||||||
|
| 本地 LLM 提示词 | `voice_drone/core/qwen_intent_chat.py`(`FLIGHT_INTENT_CHAT_SYSTEM`) |
|
||||||
|
| 云端会话 | `voice_drone/core/cloud_voice_client.py` |
|
||||||
|
| 主流程 | `voice_drone/main_app.py` |
|
||||||
|
| 配置聚合 | `voice_drone/core/configuration.py` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 版本与维护
|
||||||
|
|
||||||
|
- 配置项会随功能迭代增加;若与运行日志或 `llmcon` 不一致,以**当前仓库 YAML + 代码**为准。
|
||||||
|
- 新增仅与云端相关的字段时,请同时通知服务端解析 **`session.start.client`**(含 PX4 扩展块)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*文档版本:与仓库同步维护;更新日期见 Git 提交。*
|
||||||
29
main.py
Normal file
29
main.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""入口:请在工程根目录 voice_drone_assistant 下运行。
|
||||||
|
|
||||||
|
bash with_system_alsa.sh python main.py
|
||||||
|
|
||||||
|
或使用包方式:python -m voice_drone.main_app(需先 cd 到本目录)。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parent
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
try:
|
||||||
|
os.chdir(ROOT)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio
|
||||||
|
|
||||||
|
fix_ld_path_for_portaudio()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from voice_drone.main_app import main
|
||||||
|
|
||||||
|
main()
|
||||||
25
requirements.txt
Normal file
25
requirements.txt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# 语音助手最小依赖(从大工程 requirements 精简,若 import 报错再补包)
|
||||||
|
numpy>=1.21.0
|
||||||
|
scipy>=1.10.0
|
||||||
|
onnx>=1.14.0
|
||||||
|
onnxruntime>=1.16.0
|
||||||
|
librosa>=0.10.0
|
||||||
|
soundfile>=0.12.0
|
||||||
|
pyaudio>=0.2.14
|
||||||
|
noisereduce>=2.0.0
|
||||||
|
sounddevice>=0.4.6
|
||||||
|
pyyaml>=5.4.0
|
||||||
|
jieba>=0.42.1
|
||||||
|
pypinyin>=0.50.0
|
||||||
|
opencc-python-reimplemented>=0.1.7
|
||||||
|
cn2an>=0.5.0
|
||||||
|
misaki[zh]>=0.8.2
|
||||||
|
pydantic>=2.4.0
|
||||||
|
# 大模型(Qwen GGUF)
|
||||||
|
llama-cpp-python>=0.2.0
|
||||||
|
# 云端语音 WebSocket(websocket-client)
|
||||||
|
websocket-client>=1.6.0
|
||||||
|
|
||||||
|
# ROS1:publish_flight_intent_ros_once / flight_bridge 需 rospy(std_msgs)。Noetic 一般 apt 安装,勿用 pip 覆盖:
|
||||||
|
# sudo apt install ros-noetic-ros-base # 或桌面版,须含 rospy
|
||||||
|
# 使用 conda/venv 时请在启动前 source /opt/ros/noetic/setup.bash,且保持 PYTHONPATH 含上述 dist-packages(主程序 ROS 桥子进程已把工程根 prepend 到 $PYTHONPATH)。
|
||||||
54
scripts/bundle_for_device.sh
Normal file
54
scripts/bundle_for_device.sh
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# 在本机把「原仓库」里的 models(及可选 GGUF 缓存)拷进本子工程,便于整包 scp/rsync 到另一台香橙派。
|
||||||
|
#
|
||||||
|
# 用法(在 voice_drone_assistant 根目录):
|
||||||
|
# bash scripts/bundle_for_device.sh
|
||||||
|
# bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio
|
||||||
|
#
|
||||||
|
# 默认上一级目录为原仓库(即 voice_drone_assistant 仍放在 rocket_drone_audio 子目录时的布局)。
|
||||||
|
set -euo pipefail
|
||||||
|
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
SRC="${1:-$ROOT/..}"
|
||||||
|
M="$ROOT/models"
|
||||||
|
|
||||||
|
if [[ ! -d "$SRC/src/models" ]]; then
|
||||||
|
echo "未找到 $SRC/src/models ,请传入正确的原仓库根路径:" >&2
|
||||||
|
echo " bash scripts/bundle_for_device.sh /path/to/rocket_drone_audio" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p "$M"
|
||||||
|
echo "源目录: $SRC"
|
||||||
|
echo "目标 models: $M"
|
||||||
|
|
||||||
|
copy_dir() {
|
||||||
|
local name="$1"
|
||||||
|
if [[ -d "$SRC/src/models/$name" ]]; then
|
||||||
|
echo " 复制 models/$name ..."
|
||||||
|
rm -rf "$M/$name"
|
||||||
|
cp -a "$SRC/src/models/$name" "$M/"
|
||||||
|
else
|
||||||
|
echo " 跳过(不存在): src/models/$name"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_dir "SenseVoiceSmall"
|
||||||
|
copy_dir "Kokoro-82M-v1.1-zh-ONNX"
|
||||||
|
copy_dir "SileroVad"
|
||||||
|
|
||||||
|
# 可选:大模型 GGUF(体积大,按需)
|
||||||
|
if [[ -d "$SRC/cache/qwen25-1.5b-gguf" ]]; then
|
||||||
|
read -r -p "是否复制 Qwen GGUF 到 cache/?(可能数百 MB~数 GB)[y/N] " ans
|
||||||
|
if [[ "${ans:-}" =~ ^[yY] ]]; then
|
||||||
|
mkdir -p "$ROOT/cache/qwen25-1.5b-gguf"
|
||||||
|
cp -a "$SRC/cache/qwen25-1.5b-gguf/"* "$ROOT/cache/qwen25-1.5b-gguf/" 2>/dev/null || true
|
||||||
|
echo " 已复制 cache/qwen25-1.5b-gguf"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo " 未找到 $SRC/cache/qwen25-1.5b-gguf ,大模型请在新机器上再下载或单独拷贝"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "完成。可将整个目录打包拷贝到另一台设备:"
|
||||||
|
echo " $ROOT"
|
||||||
|
echo "新设备上请执行: pip install -r requirements.txt(或使用相同 conda 环境)"
|
||||||
42
scripts/generate_wake_greeting_wav.py
Normal file
42
scripts/generate_wake_greeting_wav.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""生成 assets/tts_cache/wake_greeting.wav。须与 voice_drone.main_app._WAKE_GREETING 一致。
|
||||||
|
|
||||||
|
用法(在 voice_drone_assistant 根目录):
|
||||||
|
python scripts/generate_wake_greeting_wav.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_ROOT))
|
||||||
|
try:
|
||||||
|
os.chdir(_ROOT)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio
|
||||||
|
|
||||||
|
fix_ld_path_for_portaudio()
|
||||||
|
|
||||||
|
_WAKE_TEXT = "你好,我在呢"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
out_dir = _ROOT / "assets" / "tts_cache"
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = out_dir / "wake_greeting.wav"
|
||||||
|
|
||||||
|
from voice_drone.core.tts import KokoroOnnxTTS
|
||||||
|
|
||||||
|
tts = KokoroOnnxTTS()
|
||||||
|
tts.synthesize_to_file(_WAKE_TEXT, str(out_path))
|
||||||
|
print(f"已写入: {out_path}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
146
scripts/run_flight_bridge_with_mavros.sh
Normal file
146
scripts/run_flight_bridge_with_mavros.sh
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# 一键:roscore(若还没有)→ MAVROS(串口连 PX4)→ flight_intent 伴飞桥(前台,直到 Ctrl+C)
|
||||||
|
#
|
||||||
|
# 适合「不想自己敲 roslaunch」时直接跑;用法与 run_px4_offboard_one_terminal.sh 前两参一致。
|
||||||
|
#
|
||||||
|
# 在 voice_drone_assistant 根目录:
|
||||||
|
# bash scripts/run_flight_bridge_with_mavros.sh
|
||||||
|
# bash scripts/run_flight_bridge_with_mavros.sh /dev/ttyACM0 921600
|
||||||
|
#
|
||||||
|
# 飞控连上后,另开终端发 JSON(std_msgs/String 必须用 YAML 字段 data,见 docs/FLIGHT_BRIDGE_ROS1.md):
|
||||||
|
# source /opt/ros/noetic/setup.bash
|
||||||
|
# rostopic pub -1 /input std_msgs/String \
|
||||||
|
# "data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}'"
|
||||||
|
#
|
||||||
|
# 环境变量(可选):
|
||||||
|
# ROS_MASTER_URI 默认 http://127.0.0.1:11311
|
||||||
|
# ROS_HOSTNAME 默认 127.0.0.1
|
||||||
|
# BRIDGE_PYTHON 若不设则 python3(与 mavros 同机即可,不必 conda)
|
||||||
|
# OFFBOARD_PYTHON 未设 BRIDGE_PYTHON 时也试 yanshi 环境(与 offboard 脚本习惯一致)
|
||||||
|
#
|
||||||
|
# Ctrl+C:结束伴飞桥,并停止本脚本拉起的 MAVROS;若 roscore 由本脚本启动则一并结束。
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||||
|
cd "$ROOT"
|
||||||
|
|
||||||
|
if [[ ! -f /opt/ros/noetic/setup.bash ]]; then
|
||||||
|
echo "未找到 /opt/ros/noetic/setup.bash(伴飞桥当前仅支持 ROS1 Noetic)" >&2
|
||||||
|
exit 2
|
||||||
|
fi
|
||||||
|
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}"
|
||||||
|
export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}"
|
||||||
|
|
||||||
|
DEV="${1:-/dev/ttyACM0}"
|
||||||
|
BAUD="${2:-921600}"
|
||||||
|
FCU_URL="${DEV}:${BAUD}"
|
||||||
|
|
||||||
|
ROSCORE_PID=""
|
||||||
|
MAVROS_PID=""
|
||||||
|
WE_STARTED_ROSCORE=0
|
||||||
|
|
||||||
|
master_ok() {
|
||||||
|
timeout 3 rosnode list &>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
stop_mavros() {
|
||||||
|
pkill -f '/opt/ros/noetic/lib/mavros/mavros_node' 2>/dev/null || true
|
||||||
|
}
|
||||||
|
|
||||||
|
kill_children() {
|
||||||
|
if [[ -n "${MAVROS_PID:-}" ]] && kill -0 "$MAVROS_PID" 2>/dev/null; then
|
||||||
|
kill "$MAVROS_PID" 2>/dev/null || true
|
||||||
|
wait "$MAVROS_PID" 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
if [[ "${WE_STARTED_ROSCORE}" -eq 1 ]] && [[ -n "${ROSCORE_PID:-}" ]] && kill -0 "$ROSCORE_PID" 2>/dev/null; then
|
||||||
|
kill "$ROSCORE_PID" 2>/dev/null || true
|
||||||
|
wait "$ROSCORE_PID" 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
trap 'kill_children; exit 130' INT
|
||||||
|
trap 'kill_children; exit 143' TERM
|
||||||
|
|
||||||
|
resolve_python() {
|
||||||
|
if [[ -n "${BRIDGE_PYTHON:-}" ]] && [[ -x "$BRIDGE_PYTHON" ]]; then
|
||||||
|
echo "$BRIDGE_PYTHON"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
if [[ -n "${OFFBOARD_PYTHON:-}" ]] && [[ -x "$OFFBOARD_PYTHON" ]]; then
|
||||||
|
echo "$OFFBOARD_PYTHON"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
for cand in \
|
||||||
|
"${HOME}/miniconda3/envs/yanshi/bin/python" \
|
||||||
|
"${HOME}/anaconda3/envs/yanshi/bin/python" \
|
||||||
|
"${HOME}/mambaforge/envs/yanshi/bin/python"; do
|
||||||
|
if [[ -x "$cand" ]]; then
|
||||||
|
echo "$cand"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
command -v python3
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "===== flight_intent 伴飞桥(含 MAVROS)fcu_url=${FCU_URL} ====="
|
||||||
|
|
||||||
|
if ! master_ok; then
|
||||||
|
echo "[1/3] 启动 roscore …"
|
||||||
|
roscore > /tmp/roscore_flight_bridge.log 2>&1 &
|
||||||
|
ROSCORE_PID=$!
|
||||||
|
WE_STARTED_ROSCORE=1
|
||||||
|
for _ in $(seq 1 50); do
|
||||||
|
master_ok && break
|
||||||
|
sleep 0.2
|
||||||
|
done
|
||||||
|
if ! master_ok; then
|
||||||
|
echo "roscore 未起来: tail -40 /tmp/roscore_flight_bridge.log" >&2
|
||||||
|
kill_children
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "roscore 已就绪 (pid=$ROSCORE_PID)"
|
||||||
|
else
|
||||||
|
echo "[1/3] 已有 ROS master,跳过 roscore"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "[2/3] 启动 MAVROS …"
|
||||||
|
stop_mavros
|
||||||
|
sleep 1
|
||||||
|
roslaunch mavros px4.launch fcu_url:="$FCU_URL" > /tmp/mavros_flight_bridge.log 2>&1 &
|
||||||
|
MAVROS_PID=$!
|
||||||
|
|
||||||
|
echo "等待飞控 connected=true(超时 60s)…"
|
||||||
|
connected=0
|
||||||
|
for _ in $(seq 1 60); do
|
||||||
|
if timeout 4 rostopic echo /mavros/state -n 1 2>/dev/null | grep -qE 'connected: [Tt]rue'; then
|
||||||
|
connected=1
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
if [[ "$connected" -ne 1 ]]; then
|
||||||
|
echo "仍未连上飞控。检查串口/波特率: tail -60 /tmp/mavros_flight_bridge.log" >&2
|
||||||
|
echo "可重试: bash $0 ${DEV} 57600" >&2
|
||||||
|
kill_children
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "MAVROS 已连接飞控"
|
||||||
|
|
||||||
|
PY="$(resolve_python)"
|
||||||
|
export PYTHONPATH="${PYTHONPATH:-}:${ROOT}"
|
||||||
|
|
||||||
|
echo "[3/3] 启动伴飞桥(Python: $PY)…"
|
||||||
|
echo " Ctrl+C 退出并清理本脚本启动的 MAVROS$([[ "$WE_STARTED_ROSCORE" -eq 1 ]] && echo +roscore)。"
|
||||||
|
echo " 发意图: 另开终端 source noetic 后 rostopic pub ... 见脚本头注释。"
|
||||||
|
set +e
|
||||||
|
"$PY" -m voice_drone.flight_bridge.ros1_node
|
||||||
|
BRIDGE_EXIT=$?
|
||||||
|
set -e
|
||||||
|
|
||||||
|
kill_children
|
||||||
|
exit "$BRIDGE_EXIT"
|
||||||
37
scripts/run_flight_intent_bridge_ros1.sh
Normal file
37
scripts/run_flight_intent_bridge_ros1.sh
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# 在已有 ROS master + MAVROS(已连飞控)的前提下启动 flight_intent 伴飞桥节点。
|
||||||
|
# 若希望一键连 MAVROS:用同目录 run_flight_bridge_with_mavros.sh
|
||||||
|
#
|
||||||
|
# 用法(在 voice_drone_assistant 根目录):
|
||||||
|
# bash scripts/run_flight_intent_bridge_ros1.sh
|
||||||
|
# bash scripts/run_flight_intent_bridge_ros1.sh my_bridge # 节点名前缀(anonymous 仍会加后缀)
|
||||||
|
#
|
||||||
|
# 另开终端发意图(示例降落,默认订阅全局 /input):
|
||||||
|
# rostopic pub -1 /input std_msgs/String \
|
||||||
|
# "{data: '{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}'}"
|
||||||
|
#
|
||||||
|
# 若改了 ~input_topic:rosnode info <节点名> 查看订阅话题
|
||||||
|
#
|
||||||
|
# 环境变量(与 run_flight_bridge_with_mavros.sh 一致,未设置时给默认值):
|
||||||
|
# ROS_MASTER_URI 默认 http://127.0.0.1:11311
|
||||||
|
# ROS_HOSTNAME 默认 127.0.0.1
|
||||||
|
# 注意:这些只在「本脚本进程」里生效;另开终端调试 rostopic/rosservice 时须自行 source noetic 并 export 相同 URI,或与跑 roscore 的机器一致。
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
cd "$ROOT"
|
||||||
|
|
||||||
|
if [[ ! -f /opt/ros/noetic/setup.bash ]]; then
|
||||||
|
echo "未找到 /opt/ros/noetic/setup.bash(当前桥仅支持 ROS1 Noetic)" >&2
|
||||||
|
exit 2
|
||||||
|
fi
|
||||||
|
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
|
||||||
|
export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}"
|
||||||
|
export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}"
|
||||||
|
|
||||||
|
export PYTHONPATH="${PYTHONPATH}:${ROOT}"
|
||||||
|
|
||||||
|
exec python3 -m voice_drone.flight_bridge.ros1_node "$@"
|
||||||
172
scripts/run_px4_offboard_one_terminal.sh
Normal file
172
scripts/run_px4_offboard_one_terminal.sh
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# 单终端一键:按需 roscore → MAVROS 串口(MAVLink) → px4_ctrl_offboard_demo.py
|
||||||
|
#
|
||||||
|
# 用法(在仓库根目录):
|
||||||
|
# bash scripts/run_px4_offboard_one_terminal.sh
|
||||||
|
# bash scripts/run_px4_offboard_one_terminal.sh /dev/ttyACM0 921600
|
||||||
|
# bash scripts/run_px4_offboard_one_terminal.sh /dev/ttyACM0 921600 120
|
||||||
|
# 第三参 DEMO_MAX_SEC:仅限制「px4_ctrl_offboard_demo.py」运行时长(秒),到时发 SIGTERM 结束
|
||||||
|
# demo,随后脚本清理 MAVROS / 可选 roscore 并退出;0 或不写表示不限制。
|
||||||
|
# 也可用环境变量 OFFBOARD_DEMO_MAX_SEC(第三参优先)。
|
||||||
|
#
|
||||||
|
# 环境变量(可选):
|
||||||
|
# ROS_MASTER_URI 默认 http://127.0.0.1:11311
|
||||||
|
# ROS_HOSTNAME 默认 127.0.0.1
|
||||||
|
# OFFBOARD_PYTHON 默认优先用 conda env「yanshi」里的 python,否则 python3
|
||||||
|
# OFFBOARD_DEMO_MAX_SEC 同第三参;整条脚本从头限时请用: timeout 600 bash 本脚本 ...
|
||||||
|
# ROCKET_OFFBOARD_DEMO_PY 显式指定 px4_ctrl_offboard_demo.py 路径(缺省时先 $ROOT/src/ 再 $ROOT/../src/)
|
||||||
|
#
|
||||||
|
# 退出:Ctrl+C 会结束 demo,并停止本脚本拉起的 MAVROS;若 roscore 由本脚本启动,会一并结束。
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||||
|
cd "$ROOT"
|
||||||
|
|
||||||
|
resolve_demo_py() {
|
||||||
|
if [[ -n "${ROCKET_OFFBOARD_DEMO_PY:-}" ]] && [[ -f "${ROCKET_OFFBOARD_DEMO_PY}" ]]; then
|
||||||
|
echo "$(cd "$(dirname "${ROCKET_OFFBOARD_DEMO_PY}")" && pwd)/$(basename "${ROCKET_OFFBOARD_DEMO_PY}")"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
local c1="$ROOT/src/px4_ctrl_offboard_demo.py"
|
||||||
|
if [[ -f "$c1" ]]; then
|
||||||
|
echo "$c1"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
local c2="$ROOT/../src/px4_ctrl_offboard_demo.py"
|
||||||
|
if [[ -f "$c2" ]]; then
|
||||||
|
echo "$(cd "$(dirname "$c2")" && pwd)/px4_ctrl_offboard_demo.py"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
echo "错误: 找不到 px4_ctrl_offboard_demo.py。已尝试:" >&2
|
||||||
|
echo " ROCKET_OFFBOARD_DEMO_PY=${ROCKET_OFFBOARD_DEMO_PY:-(未设置)}" >&2
|
||||||
|
echo " $c1" >&2
|
||||||
|
echo " $c2" >&2
|
||||||
|
exit 3
|
||||||
|
}
|
||||||
|
|
||||||
|
DEMO_PY="$(resolve_demo_py)"
|
||||||
|
|
||||||
|
source /opt/ros/noetic/setup.bash
|
||||||
|
export ROS_MASTER_URI="${ROS_MASTER_URI:-http://127.0.0.1:11311}"
|
||||||
|
export ROS_HOSTNAME="${ROS_HOSTNAME:-127.0.0.1}"
|
||||||
|
|
||||||
|
DEV="${1:-/dev/ttyACM0}"
|
||||||
|
BAUD="${2:-921600}"
|
||||||
|
DEMO_MAX_SEC="${3:-${OFFBOARD_DEMO_MAX_SEC:-0}}"
|
||||||
|
FCU_URL="${DEV}:${BAUD}"
|
||||||
|
|
||||||
|
if ! [[ "$DEMO_MAX_SEC" =~ ^[0-9]+$ ]]; then
|
||||||
|
echo "错误: 第三参(demo 最长运行秒数)须为非负整数,当前=${DEMO_MAX_SEC}"
|
||||||
|
exit 2
|
||||||
|
fi
|
||||||
|
|
||||||
|
ROSCORE_PID=""
|
||||||
|
MAVROS_PID=""
|
||||||
|
WE_STARTED_ROSCORE=0
|
||||||
|
|
||||||
|
master_ok() {
|
||||||
|
timeout 3 rosnode list &>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
stop_mavros() {
|
||||||
|
pkill -f '/opt/ros/noetic/lib/mavros/mavros_node' 2>/dev/null || true
|
||||||
|
}
|
||||||
|
|
||||||
|
kill_children() {
|
||||||
|
if [[ -n "${MAVROS_PID:-}" ]] && kill -0 "$MAVROS_PID" 2>/dev/null; then
|
||||||
|
kill "$MAVROS_PID" 2>/dev/null || true
|
||||||
|
wait "$MAVROS_PID" 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
if [[ "${WE_STARTED_ROSCORE}" -eq 1 ]] && [[ -n "${ROSCORE_PID:-}" ]] && kill -0 "$ROSCORE_PID" 2>/dev/null; then
|
||||||
|
kill "$ROSCORE_PID" 2>/dev/null || true
|
||||||
|
wait "$ROSCORE_PID" 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
trap 'kill_children; exit 130' INT
|
||||||
|
trap 'kill_children; exit 143' TERM
|
||||||
|
|
||||||
|
resolve_python() {
|
||||||
|
if [[ -n "${OFFBOARD_PYTHON:-}" ]] && [[ -x "$OFFBOARD_PYTHON" ]]; then
|
||||||
|
echo "$OFFBOARD_PYTHON"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
for cand in \
|
||||||
|
"${HOME}/miniconda3/envs/yanshi/bin/python" \
|
||||||
|
"${HOME}/anaconda3/envs/yanshi/bin/python" \
|
||||||
|
"${HOME}/mambaforge/envs/yanshi/bin/python"; do
|
||||||
|
if [[ -x "$cand" ]]; then
|
||||||
|
echo "$cand"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
command -v python3
|
||||||
|
}
|
||||||
|
|
||||||
|
if ! master_ok; then
|
||||||
|
echo "[1/3] 启动 roscore ..."
|
||||||
|
roscore > /tmp/roscore_offboard_one_terminal.log 2>&1 &
|
||||||
|
ROSCORE_PID=$!
|
||||||
|
WE_STARTED_ROSCORE=1
|
||||||
|
for _ in $(seq 1 50); do
|
||||||
|
master_ok && break
|
||||||
|
sleep 0.2
|
||||||
|
done
|
||||||
|
if ! master_ok; then
|
||||||
|
echo "roscore 未起来,请查看: tail -40 /tmp/roscore_offboard_one_terminal.log"
|
||||||
|
kill_children
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "roscore 已就绪 (pid=$ROSCORE_PID)"
|
||||||
|
else
|
||||||
|
echo "[1/3] 已有 ROS master,跳过 roscore"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "[2/3] 启动 MAVROS fcu_url=${FCU_URL}"
|
||||||
|
stop_mavros
|
||||||
|
sleep 1
|
||||||
|
roslaunch mavros px4.launch fcu_url:="$FCU_URL" > /tmp/mavros_offboard_one_terminal.log 2>&1 &
|
||||||
|
MAVROS_PID=$!
|
||||||
|
|
||||||
|
echo "等待飞控连接 (connected=true),超时 60s ..."
|
||||||
|
connected=0
|
||||||
|
for _ in $(seq 1 60); do
|
||||||
|
if timeout 4 rostopic echo /mavros/state -n 1 2>/dev/null | grep -qE 'connected: [Tt]rue'; then
|
||||||
|
connected=1
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
if [[ "$connected" -ne 1 ]]; then
|
||||||
|
echo "仍未连上飞控。请检查串口与波特率,日志: tail -60 /tmp/mavros_offboard_one_terminal.log"
|
||||||
|
echo "可重试: bash $0 ${DEV} 57600"
|
||||||
|
kill_children
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "MAVROS 已连接飞控"
|
||||||
|
|
||||||
|
PY="$(resolve_python)"
|
||||||
|
echo "[3/3] 使用 Python: $PY"
|
||||||
|
if [[ "$DEMO_MAX_SEC" -gt 0 ]]; then
|
||||||
|
echo "运行 $DEMO_PY,${DEMO_MAX_SEC}s 后自动结束 demo 并退出本脚本(随后清理子进程)"
|
||||||
|
else
|
||||||
|
echo "运行 $DEMO_PY (Ctrl+C 结束并将停止本脚本启动的 MAVROS)"
|
||||||
|
fi
|
||||||
|
set +e
|
||||||
|
if [[ "$DEMO_MAX_SEC" -gt 0 ]]; then
|
||||||
|
timeout --signal=TERM --kill-after=8 "$DEMO_MAX_SEC" \
|
||||||
|
"$PY" "$DEMO_PY"
|
||||||
|
DEMO_EXIT=$?
|
||||||
|
if [[ "$DEMO_EXIT" -eq 124 ]]; then
|
||||||
|
echo "[timeout] 已到 ${DEMO_MAX_SEC}s,已终止 px4_ctrl_offboard_demo.py(退出码 124)"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
"$PY" "$DEMO_PY"
|
||||||
|
DEMO_EXIT=$?
|
||||||
|
fi
|
||||||
|
set -e
|
||||||
|
|
||||||
|
kill_children
|
||||||
|
exit "$DEMO_EXIT"
|
||||||
44
tests/test_cloud_dialog_v1.py
Normal file
44
tests/test_cloud_dialog_v1.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""cloud_voice_dialog_v1 短语与 confirm 解析。"""
|
||||||
|
from voice_drone.core.cloud_dialog_v1 import (
|
||||||
|
match_phrase_list,
|
||||||
|
normalize_phrase_text,
|
||||||
|
parse_confirm_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize():
|
||||||
|
assert normalize_phrase_text(" a b ") == "a b"
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_phrases():
|
||||||
|
assert match_phrase_list(normalize_phrase_text("确认"), ["确认"])
|
||||||
|
assert match_phrase_list(normalize_phrase_text("取消"), ["取消"])
|
||||||
|
assert match_phrase_list(normalize_phrase_text("好的确认"), ["确认"])
|
||||||
|
# 复述长提示:不应仅因子串「取消」命中
|
||||||
|
long_prompt = normalize_phrase_text("请回复确认或取消")
|
||||||
|
assert match_phrase_list(long_prompt, ["取消"]) is False
|
||||||
|
assert match_phrase_list(long_prompt, ["确认"]) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_confirm_ok():
|
||||||
|
d = parse_confirm_dict(
|
||||||
|
{
|
||||||
|
"required": True,
|
||||||
|
"timeout_sec": 10,
|
||||||
|
"confirm_phrases": ["确认"],
|
||||||
|
"cancel_phrases": ["取消"],
|
||||||
|
"pending_id": "p1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert d is not None
|
||||||
|
assert d["required"] is True
|
||||||
|
assert d["timeout_sec"] == 10.0
|
||||||
|
assert "确认" in d["confirm_phrases"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_confirm_reject_bad_required():
|
||||||
|
assert parse_confirm_dict({"required": "yes"}) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_priority_concept():
|
||||||
|
assert match_phrase_list(normalize_phrase_text("取消"), ["取消"])
|
||||||
160
tests/test_flight_intent.py
Normal file
160
tests/test_flight_intent.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
"""flight_intent v1 校验与 goto→Command 映射。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from voice_drone.core.flight_intent import (
|
||||||
|
goto_action_to_command,
|
||||||
|
parse_flight_intent_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _command_stack_available() -> bool:
|
||||||
|
try:
|
||||||
|
import yaml # noqa: F401
|
||||||
|
|
||||||
|
from voice_drone.core.command import Command # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
needs_command_stack = pytest.mark.skipif(
|
||||||
|
not _command_stack_available(),
|
||||||
|
reason="需要 pyyaml 与工程配置以加载 Command",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_minimal_ok():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [{"type": "land", "args": {}}],
|
||||||
|
"summary": "降落",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert err == []
|
||||||
|
assert v is not None
|
||||||
|
assert v.actions[0].type == "land"
|
||||||
|
|
||||||
|
|
||||||
|
def test_hover_duration_cloud_legacy_expands_to_wait():
|
||||||
|
"""云端误将停顿时长写在 hover.args.duration 时,客户端规范化为 hover + wait。"""
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{"type": "takeoff", "args": {}},
|
||||||
|
{"type": "hover", "args": {"duration": 3}},
|
||||||
|
{"type": "land", "args": {}},
|
||||||
|
],
|
||||||
|
"summary": "test",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert err == []
|
||||||
|
assert v is not None
|
||||||
|
assert len(v.actions) == 4
|
||||||
|
assert v.actions[0].type == "takeoff"
|
||||||
|
assert v.actions[1].type == "hover"
|
||||||
|
assert v.actions[2].type == "wait"
|
||||||
|
assert float(v.actions[2].args.seconds) == 3.0
|
||||||
|
assert v.actions[3].type == "land"
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_after_hover_ok():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{"type": "takeoff", "args": {}},
|
||||||
|
{"type": "hover", "args": {}},
|
||||||
|
{"type": "wait", "args": {"seconds": 2.5}},
|
||||||
|
{"type": "land", "args": {}},
|
||||||
|
],
|
||||||
|
"summary": "test",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert err == []
|
||||||
|
assert v is not None
|
||||||
|
assert len(v.actions) == 4
|
||||||
|
assert v.actions[2].type == "wait"
|
||||||
|
assert v.trace_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_wait_rejected():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{"type": "wait", "args": {"seconds": 1}},
|
||||||
|
{"type": "land", "args": {}},
|
||||||
|
],
|
||||||
|
"summary": "x",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert v is None
|
||||||
|
assert err
|
||||||
|
|
||||||
|
|
||||||
|
def test_extra_top_key_rejected():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [{"type": "land", "args": {}}],
|
||||||
|
"summary": "x",
|
||||||
|
"foo": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert v is None
|
||||||
|
assert err
|
||||||
|
|
||||||
|
|
||||||
|
@needs_command_stack
|
||||||
|
def test_goto_body_forward():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"type": "goto",
|
||||||
|
"args": {"frame": "body_ned", "x": 3},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"summary": "s",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert v is not None and not err
|
||||||
|
cmd, reason = goto_action_to_command(v.actions[0], sequence_id=7)
|
||||||
|
assert reason is None
|
||||||
|
assert cmd is not None
|
||||||
|
assert cmd.command == "forward"
|
||||||
|
assert cmd.sequence_id == 7
|
||||||
|
|
||||||
|
|
||||||
|
@needs_command_stack
|
||||||
|
def test_goto_multi_axis_no_command():
|
||||||
|
v, err = parse_flight_intent_dict(
|
||||||
|
{
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"type": "goto",
|
||||||
|
"args": {"frame": "body_ned", "x": 1, "y": 1},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"summary": "s",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert v is not None and not err
|
||||||
|
cmd, reason = goto_action_to_command(v.actions[0], 1)
|
||||||
|
assert cmd is None
|
||||||
|
assert reason and "multi-axis" in reason
|
||||||
3
voice_drone/__init__.py
Normal file
3
voice_drone/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""语音无人机助手:采集 → VAD → STT → 唤醒 → LLM/起飞 → Kokoro 播报。"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
25
voice_drone/config/cloud_voice_px4_context.yaml
Normal file
25
voice_drone/config/cloud_voice_px4_context.yaml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# 云端 session.start → client 扩展字段(PX4 / MAVLink 语境)
|
||||||
|
# 与服务端约定:原样进入 LLM;vehicle_class 与 mav_type 应一致,冲突时以 mav_type 为准。
|
||||||
|
#
|
||||||
|
# 修改后重启 main.py;可用环境变量 ROCKET_CLOUD_PX4_CONTEXT_FILE 覆盖本文件路径。
|
||||||
|
|
||||||
|
# 机体类别(自由文本,与 mav_type 语义对齐即可,例如 multicopter / fixed_wing)
|
||||||
|
vehicle_class: multicopter
|
||||||
|
|
||||||
|
# MAV_TYPE 枚举整数值,参见 https://mavlink.io/en/messages/common.html#MAV_TYPE
|
||||||
|
mav_type: 2
|
||||||
|
|
||||||
|
# 口语「前/右/上」未说明坐标系时,云端生成 goto 的默认系:
|
||||||
|
# local_ned — 北东地(与常见 Offboard 位置设定一致)
|
||||||
|
# body_ned / body — 机体系前右下(仅当机端按体轴解析相对位移时填写)
|
||||||
|
default_setpoint_frame: local_ned
|
||||||
|
|
||||||
|
# 以下为可选运行时状态(可由 MAVROS / ROS2 写入同文件,或由机载进程覆写后再连云端)
|
||||||
|
# home_position_valid: true
|
||||||
|
# offboard_capable: true
|
||||||
|
# current_nav_state: "OFFBOARD"
|
||||||
|
|
||||||
|
# 任意短键名 JSON,进入 LLM;适合「电池低」「室内无 GPS」「仅限 mission」等
|
||||||
|
extras:
|
||||||
|
platform: companion_orangepi
|
||||||
|
# indoor_no_gps: true
|
||||||
59
voice_drone/config/command_.yaml
Normal file
59
voice_drone/config/command_.yaml
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# 命令配置文件 用于命令生成和填充默认值
|
||||||
|
control_params:
|
||||||
|
# 起飞默认参数
|
||||||
|
takeoff:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 降落默认参数
|
||||||
|
land:
|
||||||
|
distance: 0
|
||||||
|
speed: 0
|
||||||
|
duration: 0
|
||||||
|
# 跟随默认参数
|
||||||
|
follow:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 2
|
||||||
|
# 向前默认参数
|
||||||
|
forward:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 向后默认参数
|
||||||
|
backward:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 向左默认参数
|
||||||
|
left:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 向右默认参数
|
||||||
|
right:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 向上默认参数
|
||||||
|
up:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 向下默认参数
|
||||||
|
down:
|
||||||
|
distance: 0.5
|
||||||
|
speed: 0.5
|
||||||
|
duration: 1
|
||||||
|
# 悬停默认参数
|
||||||
|
hover:
|
||||||
|
distance: 0
|
||||||
|
speed: 0
|
||||||
|
duration: 5
|
||||||
|
# 返航(Socket 协议与 land/hover 同类占位参数)
|
||||||
|
return_home:
|
||||||
|
distance: 0
|
||||||
|
speed: 0
|
||||||
|
duration: 0
|
||||||
|
|
||||||
|
|
||||||
71
voice_drone/config/keywords.yaml
Normal file
71
voice_drone/config/keywords.yaml
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
keywords:
|
||||||
|
# takeoff 仅用于「一键 offboard 演示」唤醒路径;用「起飞演示」避免句子里单独出现「起飞」误触(如「起飞,悬停再降落」)
|
||||||
|
takeoff:
|
||||||
|
- "起飞演示"
|
||||||
|
- "演示起飞"
|
||||||
|
|
||||||
|
land:
|
||||||
|
- "立刻降落"
|
||||||
|
- "紧急降落"
|
||||||
|
- "降落"
|
||||||
|
- "落地"
|
||||||
|
- "着陆"
|
||||||
|
|
||||||
|
follow:
|
||||||
|
- "跟随"
|
||||||
|
- "跟着我"
|
||||||
|
- "跟我飞"
|
||||||
|
- "跟随模式"
|
||||||
|
|
||||||
|
hover:
|
||||||
|
- "马上停下"
|
||||||
|
- "立刻停下"
|
||||||
|
- "悬停"
|
||||||
|
- "停下"
|
||||||
|
- "停止"
|
||||||
|
- "停"
|
||||||
|
|
||||||
|
forward:
|
||||||
|
- "向前飞"
|
||||||
|
- "往前飞"
|
||||||
|
- "向前"
|
||||||
|
- "往前"
|
||||||
|
- "前面飞"
|
||||||
|
- "前进"
|
||||||
|
|
||||||
|
backward:
|
||||||
|
- "向后飞"
|
||||||
|
- "往后飞"
|
||||||
|
- "向后"
|
||||||
|
- "往后"
|
||||||
|
- "后退"
|
||||||
|
|
||||||
|
left:
|
||||||
|
- "向左飞"
|
||||||
|
- "往左飞"
|
||||||
|
- "向左"
|
||||||
|
- "往左"
|
||||||
|
- "左移"
|
||||||
|
|
||||||
|
right:
|
||||||
|
- "向右飞"
|
||||||
|
- "往右飞"
|
||||||
|
- "向右"
|
||||||
|
- "往右"
|
||||||
|
- "右移"
|
||||||
|
|
||||||
|
up:
|
||||||
|
- "向上飞"
|
||||||
|
- "往上飞"
|
||||||
|
- "向上"
|
||||||
|
- "往上"
|
||||||
|
- "上升"
|
||||||
|
- "升高"
|
||||||
|
|
||||||
|
down:
|
||||||
|
- "向下飞"
|
||||||
|
- "往下飞"
|
||||||
|
- "向下"
|
||||||
|
- "往下"
|
||||||
|
- "下降"
|
||||||
|
- "降低"
|
||||||
246
voice_drone/config/system.yaml
Normal file
246
voice_drone/config/system.yaml
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
# ***********音频采集配置 *************
|
||||||
|
audio:
|
||||||
|
sample_rate: 16000
|
||||||
|
channels: 1
|
||||||
|
sample_width: 2
|
||||||
|
frame_size: 1024
|
||||||
|
audio_format: "wav"
|
||||||
|
# 麦克风:仅使用 input_device_index(PyAudio 整数)。null = 自动尝试默认输入 + 所有 in>0 设备。
|
||||||
|
# 运行 python src/rocket_drone_audio.py 会默认打印 arecord -l、PyAudio 列表与 hw 映射并交互选择。
|
||||||
|
# 自动化可加 --non-interactive 并在此写入整数索引;或传 --input-index / ROCKET_INPUT_DEVICE_INDEX。
|
||||||
|
input_device_index: null
|
||||||
|
# 以下字段已不再参与选麦逻辑(可删或保留作备忘)
|
||||||
|
input_strict_selection: false
|
||||||
|
input_hw_card_device: null
|
||||||
|
input_device_name_match: null
|
||||||
|
# 设备报告双声道、但 channels 为 1 时,用立体声打开再下混为 mono(Orange Pi ES8388 等需开启)
|
||||||
|
prefer_stereo_capture: true
|
||||||
|
# ES8388 常需 48k 才能打开;打开后会在采集里重采样回 sample_rate
|
||||||
|
audio_open_try_rates: [16000, 48000, 44100]
|
||||||
|
# conda 下 PyAudio 常链到残缺 ALSA 插件:请在 shell 用 bash with_system_alsa.sh python … 启动
|
||||||
|
# ES8388 大声说话仍只有 RMS≈30:多为 ALSA「Left/Right Channel」采集=0%,先执行 scripts/es8388_capture_up.sh 再 sudo alsactl store
|
||||||
|
|
||||||
|
# 高性能优化配置
|
||||||
|
high_performance_mode: true # 启用高性能模式(多线程+异步处理)
|
||||||
|
use_callback_mode: true # 使用回调模式(非阻塞)
|
||||||
|
buffer_queue_size: 10 # 音频缓冲队列大小
|
||||||
|
processing_threads: 2 # 处理线程数(采集和处理并行)
|
||||||
|
batch_processing: true # 启用批处理优化
|
||||||
|
|
||||||
|
# 降噪配置
|
||||||
|
noise_reduce: true
|
||||||
|
noise_reduction_method: "lightweight" # "lightweight" (轻量级) 或 "noisereduce" (完整版)
|
||||||
|
noise_sample_duration_ms: 500 # 噪声样本收集时长(毫秒)
|
||||||
|
noise_reduction_cutoff_hz: 80.0 # 轻量级降噪高通滤波截止频率(Hz)
|
||||||
|
|
||||||
|
# 自动增益控制配置
|
||||||
|
agc: true
|
||||||
|
agc_method: "incremental" # "incremental" (增量) 或 "standard" (标准)
|
||||||
|
agc_target_db: -20.0 # AGC 目标音量(dB)
|
||||||
|
# 过小(如 0.1)时在短时强噪声后易把波形压到 int16 近 0,能量 VAD 长期收不到音;建议 0.25~0.5
|
||||||
|
agc_gain_min: 0.25 # AGC 最小增益倍数
|
||||||
|
agc_gain_max: 10.0 # AGC 最大增益倍数
|
||||||
|
agc_rms_threshold: 1e-6 # AGC RMS 阈值(避免除零)
|
||||||
|
agc_smoothing_alpha: 0.1 # 压低增益时的平滑系数(0-1,越小越慢)
|
||||||
|
agc_release_alpha: 0.45 # 需要抬增益时(巨响/小声后恢复)用更大系数,由 audio.py 读取
|
||||||
|
|
||||||
|
# ***********语音活动检测配置 *************
|
||||||
|
vad:
|
||||||
|
# 略降低门槛,避免板载麦/经 AGC 后仍达不到 0.65
|
||||||
|
threshold: 0.45
|
||||||
|
start_frame: 2
|
||||||
|
# 句尾连续静音块数(每块时长≈audio.frame_size/sample_rate);recognizer.trailing_silence_seconds 优先覆盖此项与 energy_vad_end_chunks
|
||||||
|
end_frame: 10
|
||||||
|
min_silence_duration_s: 0.5
|
||||||
|
max_silence_duration_s: 30
|
||||||
|
model_path: "models/SileroVad/silero_vad.onnx"
|
||||||
|
|
||||||
|
# ***********语音识别配置 *************
|
||||||
|
stt:
|
||||||
|
# 模型路径配置
|
||||||
|
model_dir: "models/SenseVoiceSmall" # 模型目录
|
||||||
|
model_path: "models/SenseVoiceSmall/model.int8.onnx" # 直接指定模型路径(优先级高于 model_dir)
|
||||||
|
prefer_int8: true # 是否优先使用 INT8 量化模型
|
||||||
|
warmup_file: "" # 预热音频文件
|
||||||
|
|
||||||
|
# 音频预处理配置
|
||||||
|
sample_rate: 16000 # 采样率(与 audio 配置保持一致)
|
||||||
|
n_mels: 80 # Mel 滤波器数量
|
||||||
|
frame_length_ms: 25 # 帧长度(毫秒)
|
||||||
|
frame_shift_ms: 10 # 帧移(毫秒)
|
||||||
|
log_eps: 1e-10 # log 计算时的极小值(避免 log(0))
|
||||||
|
|
||||||
|
# ARM 设备优化配置
|
||||||
|
arm_optimization:
|
||||||
|
enabled: true # 是否启用 ARM 优化
|
||||||
|
max_threads: 4 # 最大线程数(RK3588 使用 4 个大核)
|
||||||
|
|
||||||
|
# CTC 解码配置
|
||||||
|
ctc_decode:
|
||||||
|
blank_id: 0 # 空白 token ID
|
||||||
|
|
||||||
|
# 语言和文本规范化配置(默认值,实际从模型元数据读取)
|
||||||
|
language:
|
||||||
|
zh_id: 3 # 中文语言ID(默认值)
|
||||||
|
text_norm:
|
||||||
|
with_itn_id: 14 # 使用 ITN 的 ID(默认值)
|
||||||
|
without_itn_id: 15 # 不使用 ITN 的 ID(默认值)
|
||||||
|
|
||||||
|
# 后处理配置
|
||||||
|
postprocess:
|
||||||
|
special_tokens: # 需要移除的特殊 token
|
||||||
|
- "<|zh|>"
|
||||||
|
- "<|NEUTRAL|>"
|
||||||
|
- "<|Speech|>"
|
||||||
|
- "<|woitn|>"
|
||||||
|
- "<|withitn|>"
|
||||||
|
|
||||||
|
# ***********文本转语音配置 *************
|
||||||
|
tts:
|
||||||
|
# Kokoro ONNX 中文模型目录
|
||||||
|
model_dir: "models/Kokoro-82M-v1.1-zh-ONNX"
|
||||||
|
|
||||||
|
# ONNX 子目录中的模型文件名
|
||||||
|
# 清晰度优先: model_fp16.onnx / model.onnx > model_q4f16.onnx > int8/uint8
|
||||||
|
# 速度(板端 CPU):可试 model_q4f16.onnx;uint8 有时更慢(取决于 ORT/CPU)
|
||||||
|
# 若仓库中仅有量化版,可改回 model_uint8.onnx 或 model_q4f16.onnx
|
||||||
|
model_name: "model_uint8.onnx"
|
||||||
|
|
||||||
|
# 语音风格(对应 voices 目录下的 *.bin, 这里不写扩展名)
|
||||||
|
# 女声常用 zf_001 / zf_002;可换 zm_* 男声。以本机 voices 目录实际文件为准。
|
||||||
|
voice: "zm_009"
|
||||||
|
|
||||||
|
# 语速系数(1.0 最自然; >1 易显赶、含糊; <1 更稳但略慢)
|
||||||
|
speed: 1.15
|
||||||
|
# 输出采样率(与 Kokoro 模型保持一致, 官方为 24000Hz)
|
||||||
|
sample_rate: 24000
|
||||||
|
|
||||||
|
# sounddevice 播放输出设备(命令应答语音走扬声器时请指定,避免播到虚拟声卡/耳机)
|
||||||
|
# null:使用系统默认输出设备
|
||||||
|
# 整数:设备索引(启动日志会列出「sounddevice 输出设备列表」供对照)
|
||||||
|
# 字符串:设备名称子串匹配(不区分大小写),例如 "扬声器"、"Speakers"、"Realtek"
|
||||||
|
# 香橙派走 HDMI 出声时 PortAudio 名常含 rockchip-hdmi0,可设 output_device: "rockchip-hdmi0"(与 ROCKET_TTS_DEVICE 一致)
|
||||||
|
output_device: null
|
||||||
|
|
||||||
|
# 播放前重采样到该输出设备的 default_samplerate(Windows/WASAPI 下 24000Hz 常无声,强烈建议 true)
|
||||||
|
playback_resample_to_device_native: true
|
||||||
|
|
||||||
|
# 播放前将峰值压到约 0.92,减轻削波导致的爆音/杂音
|
||||||
|
playback_peak_normalize: true
|
||||||
|
# 播放音量增益(波形幅度乘法,1.0 不变;1.3~1.8 更响,过大可能削波失真)
|
||||||
|
playback_gain: 2.5
|
||||||
|
# 首尾淡入淡出(毫秒),减轻驱动/缓冲区切换时的爆音与「咔哒」声
|
||||||
|
playback_edge_fade_ms: 8
|
||||||
|
# sounddevice OutputStream 延迟: low / medium / high(high 易积缓冲,部分机器听感发闷或拖尾)
|
||||||
|
playback_output_latency: low
|
||||||
|
|
||||||
|
# ***********云端语音(LLM + TTS 上云,见 clientguide.md)*************
|
||||||
|
cloud_voice:
|
||||||
|
enabled: false
|
||||||
|
server_url: "ws://192.168.0.186:8766/v1/voice/session"
|
||||||
|
auth_token: "drone-voice-cloud-token-2024"
|
||||||
|
device_id: "drone-001"
|
||||||
|
timeout: 120
|
||||||
|
# PROMPT_LISTEN:麦克 RMS 持续低于 recognizer.energy_vad_rms_low 的累计秒数 ≥ 此值则超时(非滴声后固定墙上时钟);消抖/提示音见下
|
||||||
|
listen_silence_timeout_sec: 5
|
||||||
|
post_cue_mic_mute_ms: 200
|
||||||
|
segment_cue_duration_ms: 120
|
||||||
|
# 问候语 / 本地 LLM 文案 / 飞控确认超时等字符串是否走 WebSocket tts.synthesize(见 docs/API.md §3.3);失败回退 Kokoro
|
||||||
|
remote_tts_for_local: true
|
||||||
|
# 云端失败时是否回退本地 Qwen + Kokoro(需本地模型)
|
||||||
|
fallback_to_local: true
|
||||||
|
# PX4/MAV 语境:合并进 WebSocket session.start 的 client,供服务端 LLM;也可用 ROCKET_CLOUD_PX4_CONTEXT_FILE 覆盖路径
|
||||||
|
px4_context_file: "voice_drone/config/cloud_voice_px4_context.yaml"
|
||||||
|
|
||||||
|
# ***********socket服务器配置 *************
|
||||||
|
socket_server:
|
||||||
|
# deployed
|
||||||
|
host: "192.168.43.200"
|
||||||
|
port: 6666
|
||||||
|
|
||||||
|
#local
|
||||||
|
# host: "127.0.0.1"
|
||||||
|
# port: 8888
|
||||||
|
connection_timeout: 5.0
|
||||||
|
send_timeout: 2.0
|
||||||
|
reconnect_interval: 3.0
|
||||||
|
# -1:断线后持续重连并发送直到成功(仅打 warning,不当作一次性致命错误);正整数:最多尝试次数
|
||||||
|
max_retries: -1
|
||||||
|
|
||||||
|
# ***********文本预处理配置 *************
|
||||||
|
text_preprocessor:
|
||||||
|
# 功能开关
|
||||||
|
enable_traditional_to_simplified: true # 启用繁简转换
|
||||||
|
enable_segmentation: true # 启用分词
|
||||||
|
enable_correction: true # 启用纠错
|
||||||
|
enable_number_extraction: true # 启用数字提取
|
||||||
|
enable_keyword_detection: true # 启用关键词检测
|
||||||
|
|
||||||
|
# 性能配置
|
||||||
|
lru_cache_size: 512 # LRU缓存大小(分词结果缓存)
|
||||||
|
|
||||||
|
# ***********识别器流程配置 *************
|
||||||
|
recognizer:
|
||||||
|
# 句尾连续静音达到该秒数后才切段送 STT,减少句中停顿被切开、识别半句。按 audio.frame_size 与 sample_rate 换算块数,
|
||||||
|
# 并同时设置 Silero 的 vad.end_frame 与 energy 的 energy_vad_end_chunks。不配置则分别用 yaml 中上述两项。
|
||||||
|
trailing_silence_seconds: 1.5
|
||||||
|
# VAD 后端:energy(默认,无需 Silero ONNX)或 silero(需 models/SileroVad/silero_vad.onnx)
|
||||||
|
vad_backend: energy
|
||||||
|
energy_vad_rms_high: 8000 # int16 块 RMS,连续达到 start_chunks 块判为开始说话
|
||||||
|
energy_vad_rms_low: 5000 # 连续 end_chunks 块低于此判为结束(高噪底时单独不够)
|
||||||
|
# 相对峰值判停:首字很响时若仍用 0.88,句中略轻易误判「已说完」。可改为 0.75~0.82;设 0 则关闭相对判据(只靠 rms_low + 尾静音)
|
||||||
|
energy_vad_end_peak_ratio: 0.80
|
||||||
|
# 每音频块后对句内峰值衰减系数(0.95~0.999),与 end_peak_ratio 配合减少「长句说到一半被切」
|
||||||
|
energy_vad_utt_peak_decay: 0.988
|
||||||
|
energy_vad_start_chunks: 4
|
||||||
|
energy_vad_end_chunks: 15
|
||||||
|
# 预缓冲:在检测到语音开始时,把开始前的一小段音频也拼进语音段
|
||||||
|
# 用于避免 VAD 起始判定稍慢导致“丢字/丢开头”
|
||||||
|
pre_speech_max_seconds: 0.8
|
||||||
|
# Socket 命令发送成功后,是否用 TTS 语音回复(需 Kokoro 模型与 sounddevice)
|
||||||
|
ack_tts_enabled: true
|
||||||
|
# 若配置了 ack_tts_phrases 且非空:仅下列命令会播报,且每次从对应列表随机选一句。
|
||||||
|
# 若未配置或为空 {}:回退为全局 ack_tts_text,所有成功命令均播报同一句(并可预缓存波形)。
|
||||||
|
# 阻塞预加载时会对所有「不重复」的备选句逐条合成并缓存;句数越多启动阶段越久,但播报为低延迟播放缓存。
|
||||||
|
ack_tts_phrases:
|
||||||
|
takeoff:
|
||||||
|
- "收到,正在控制无人机起飞"
|
||||||
|
- "明白,正在准备起飞"
|
||||||
|
- "懂你意思, 这就开始起飞"
|
||||||
|
land:
|
||||||
|
- "收到,正在控制无人机降落"
|
||||||
|
- "明白,马上降落"
|
||||||
|
- "懂你意思, 这就开始降落"
|
||||||
|
follow:
|
||||||
|
- "好的,我将持续跟随你,但请你不要移动太快"
|
||||||
|
- "主人,我已经开始跟随模式,请不要突然离开我的视线"
|
||||||
|
- "我已经开启了跟随模式"
|
||||||
|
hover:
|
||||||
|
- "我已悬停"
|
||||||
|
- "我已经停下脚步了"
|
||||||
|
- "放心,我现在开始不会乱动了"
|
||||||
|
# 未使用 ack_tts_phrases 时的全局固定应答(旧行为)
|
||||||
|
ack_tts_text: "收到!执行命令!"
|
||||||
|
# 应答波形磁盘缓存:文案与 tts 配置未变时从 cache/ack_tts_pcm/ 读取,跳过后续启动时的逐条合成(可明显加快二次启动)
|
||||||
|
ack_tts_disk_cache: true
|
||||||
|
# 启动时预加载 Kokoro(首次加载约需十余秒)
|
||||||
|
ack_tts_prewarm: true
|
||||||
|
# true:启动阶段阻塞到 TTS 就绪后再进入监听(命令成功后可马上播报);false:后台加载(首条命令可能仍要等十余秒)
|
||||||
|
ack_tts_prewarm_blocking: true
|
||||||
|
# 播报应答前暂时停止 PyAudio 麦克风(Windows 上输入/输出同时占用时常导致扬声器无声;单独跑 tts 脚本时无此问题)
|
||||||
|
# 香橙派 ES8388:暂停/恢复采集后出现「VAD RMS≈0、像没电平」时,可改为 false(避免 stop/start 采集),或加大 ROCKET_MIC_RESTART_SETTLE_MS。
|
||||||
|
ack_pause_mic_for_playback: true
|
||||||
|
|
||||||
|
# ***********主程序 main_app(TakeoffPrintRecognizer)*************
|
||||||
|
assistant:
|
||||||
|
# 「keywords.yaml 里 takeoff 词 → 本地 offboard + WAV」捷径;默认关,飞控走云端 flight_intent / ROS 桥即可。
|
||||||
|
# 若需恢复口令起飞:此处改为 true,或启动前 export ROCKET_LOCAL_KEYWORD_TAKEOFF=1(非空环境变量优先于本项)。
|
||||||
|
local_keyword_takeoff_enabled: false
|
||||||
|
|
||||||
|
# ***********日志配置 *************
|
||||||
|
logging:
|
||||||
|
level: "INFO"
|
||||||
|
debug: false
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
71
voice_drone/config/wake_word.yaml
Normal file
71
voice_drone/config/wake_word.yaml
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# 唤醒词配置
|
||||||
|
wake_word:
|
||||||
|
# 主唤醒词(标准形式)
|
||||||
|
primary: "无人机"
|
||||||
|
|
||||||
|
# 唤醒词变体映射(支持同音字、拼音、错别字等)
|
||||||
|
variants:
|
||||||
|
# 标准形式
|
||||||
|
- "无人机"
|
||||||
|
- "无 人 机" # 带空格
|
||||||
|
- "无人机,"
|
||||||
|
- "无人机 。"
|
||||||
|
- "无人机。"
|
||||||
|
- "喂 无人机"
|
||||||
|
- "喂,无人机"
|
||||||
|
- "嗨 无人机"
|
||||||
|
- "嘿 无人机"
|
||||||
|
- "哈喽 无人机"
|
||||||
|
|
||||||
|
# 常见误识别 / 近音
|
||||||
|
- "五人机"
|
||||||
|
- "吾人机"
|
||||||
|
- "无人鸡"
|
||||||
|
- "无认机"
|
||||||
|
- "无任机"
|
||||||
|
|
||||||
|
# 拼音变体(小写)
|
||||||
|
- "wu ren ji"
|
||||||
|
- "wurenj"
|
||||||
|
- "wu renji"
|
||||||
|
- "wuren ji"
|
||||||
|
- "wu ren, ji"
|
||||||
|
- "wu ren ji."
|
||||||
|
- "hey wu ren ji"
|
||||||
|
- "hi wu ren ji"
|
||||||
|
- "hello wu ren ji"
|
||||||
|
- "ok wu ren ji"
|
||||||
|
|
||||||
|
# 拼音变体(大小写混合)
|
||||||
|
- "Wu Ren Ji"
|
||||||
|
- "WU REN JI"
|
||||||
|
- "Wu ren ji"
|
||||||
|
- "Hey Wu Ren Ji"
|
||||||
|
- "Hi Wu Ren Ji"
|
||||||
|
|
||||||
|
# 短说 / 部分匹配(易与日常用语冲突时可删去「无人」单独一项)
|
||||||
|
- "无人"
|
||||||
|
# 勿单独使用「人机」:易在「牛人机学庭」等误识别里子串命中,导致假唤醒
|
||||||
|
- "hey 无人"
|
||||||
|
- "嗨 无人"
|
||||||
|
- "喂 无人"
|
||||||
|
|
||||||
|
# 匹配模式配置
|
||||||
|
matching:
|
||||||
|
# 是否启用模糊匹配(同音字、拼音)
|
||||||
|
enable_fuzzy: true
|
||||||
|
|
||||||
|
# 是否启用部分匹配(代码侧:主词长度足够时取前半段;短词依赖上面 variants)
|
||||||
|
enable_partial: true
|
||||||
|
|
||||||
|
# 是否忽略大小写
|
||||||
|
ignore_case: true
|
||||||
|
|
||||||
|
# 是否忽略空格
|
||||||
|
ignore_spaces: true
|
||||||
|
|
||||||
|
# 最小匹配长度(字符数,用于部分匹配)
|
||||||
|
min_match_length: 2
|
||||||
|
|
||||||
|
# 相似度阈值(0-1,用于模糊匹配)
|
||||||
|
similarity_threshold: 0.7
|
||||||
714
voice_drone/core/audio.py
Normal file
714
voice_drone/core/audio.py
Normal file
@ -0,0 +1,714 @@
|
|||||||
|
"""
|
||||||
|
音频采集模块 - 优化版本
|
||||||
|
|
||||||
|
输入设备(麦克风)选择(已简化):
|
||||||
|
- 若 system.yaml 中 audio.input_device_index 为整数:只尝试该 PyAudio 索引(无则启动失败并列设备)。
|
||||||
|
- 若为 null:依次尝试系统默认输入、所有 maxInputChannels>0 的设备。
|
||||||
|
rocket_drone_audio 启动时可交互选择并写入 input_device_index(见 src.core.mic_device_select)。
|
||||||
|
"""
|
||||||
|
from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio
|
||||||
|
|
||||||
|
fix_ld_path_for_portaudio()
|
||||||
|
|
||||||
|
import re
|
||||||
|
import pyaudio
|
||||||
|
import numpy as np
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from voice_drone.core.configuration import SYSTEM_AUDIO_CONFIG
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("audio.capture.optimized")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioCaptureOptimized:
|
||||||
|
"""
|
||||||
|
优化版音频采集器
|
||||||
|
|
||||||
|
使用回调模式 + 队列,实现非阻塞音频采集
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化音频采集器
|
||||||
|
"""
|
||||||
|
# 确保数值类型正确(从 YAML 读取可能是字符串)
|
||||||
|
self.sample_rate = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000))
|
||||||
|
self.channels = int(SYSTEM_AUDIO_CONFIG.get("channels", 1))
|
||||||
|
self.chunk_size = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024))
|
||||||
|
self.sample_width = int(SYSTEM_AUDIO_CONFIG.get("sample_width", 2))
|
||||||
|
|
||||||
|
# 高性能模式配置
|
||||||
|
self.buffer_queue_size = int(SYSTEM_AUDIO_CONFIG.get("buffer_queue_size", 10))
|
||||||
|
self._prefer_stereo_capture = bool(
|
||||||
|
SYSTEM_AUDIO_CONFIG.get("prefer_stereo_capture", True)
|
||||||
|
)
|
||||||
|
raw_idx = SYSTEM_AUDIO_CONFIG.get("input_device_index", None)
|
||||||
|
self._input_device_index_cfg: Optional[int] = (
|
||||||
|
int(raw_idx) if raw_idx is not None and str(raw_idx).strip() != "" else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tr = SYSTEM_AUDIO_CONFIG.get("audio_open_try_rates")
|
||||||
|
if tr:
|
||||||
|
raw_rates: List[int] = [int(x) for x in tr if x is not None]
|
||||||
|
else:
|
||||||
|
raw_rates = [self.sample_rate, 48000, 44100, 32000]
|
||||||
|
seen_r: set[int] = set()
|
||||||
|
self._open_try_rates: List[int] = []
|
||||||
|
for r in raw_rates:
|
||||||
|
if r not in seen_r:
|
||||||
|
seen_r.add(r)
|
||||||
|
self._open_try_rates.append(r)
|
||||||
|
|
||||||
|
# 逻辑通道(送给 VAD/STT 的 mono);_pa_channels 为 PortAudio 实际打开的通道数
|
||||||
|
self._pa_channels = self.channels
|
||||||
|
self._stereo_downmix = False
|
||||||
|
self._pa_open_sample_rate: int = self.sample_rate
|
||||||
|
|
||||||
|
self.audio = pyaudio.PyAudio()
|
||||||
|
self.format = self.audio.get_format_from_width(self.sample_width)
|
||||||
|
|
||||||
|
# 使用队列缓冲音频数据(非阻塞)
|
||||||
|
self.audio_queue = queue.Queue(maxsize=self.buffer_queue_size)
|
||||||
|
self.stream: Optional[pyaudio.Stream] = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"优化版音频采集器初始化成功: "
|
||||||
|
f"采样率={self.sample_rate}Hz, "
|
||||||
|
f"块大小={self.chunk_size}, "
|
||||||
|
f"使用回调模式+队列缓冲"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _device_hw_tuple_in_name(self, dev_name: str) -> Optional[Tuple[int, int]]:
|
||||||
|
m = re.search(r"\(hw:(\d+),\s*(\d+)\)", dev_name)
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
return int(m.group(1)), int(m.group(2))
|
||||||
|
|
||||||
|
def _ordered_input_candidates(self) -> Tuple[List[int], List[int]]:
|
||||||
|
preferred: List[int] = []
|
||||||
|
seen: set[int] = set()
|
||||||
|
|
||||||
|
def add(idx: Optional[int]) -> None:
|
||||||
|
if idx is None:
|
||||||
|
return
|
||||||
|
ii = int(idx)
|
||||||
|
if ii in seen:
|
||||||
|
return
|
||||||
|
seen.add(ii)
|
||||||
|
preferred.append(ii)
|
||||||
|
|
||||||
|
# 配置了整数索引:只打开该设备(与交互选择 / CLI 写入一致)
|
||||||
|
if self._input_device_index_cfg is not None:
|
||||||
|
add(self._input_device_index_cfg)
|
||||||
|
return preferred, []
|
||||||
|
|
||||||
|
try:
|
||||||
|
add(int(self.audio.get_default_input_device_info()["index"]))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
for i in range(self.audio.get_device_count()):
|
||||||
|
try:
|
||||||
|
inf = self.audio.get_device_info_by_index(i)
|
||||||
|
if int(inf.get("maxInputChannels", 0)) > 0:
|
||||||
|
add(i)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
fallback: List[int] = []
|
||||||
|
for i in range(self.audio.get_device_count()):
|
||||||
|
if i in seen:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
self.audio.get_device_info_by_index(i)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
fallback.append(i)
|
||||||
|
return preferred, fallback
|
||||||
|
|
||||||
|
def _channel_plan(self, max_in: int, dev_name: str) -> List[Tuple[int, bool]]:
|
||||||
|
ch = self.channels
|
||||||
|
pref = self._prefer_stereo_capture
|
||||||
|
if ch != 1:
|
||||||
|
return [(ch, False)]
|
||||||
|
if max_in <= 0:
|
||||||
|
logger.warning(
|
||||||
|
"设备 %s 报告 maxInputChannels=%s,将尝试 mono / stereo",
|
||||||
|
dev_name or "?",
|
||||||
|
max_in,
|
||||||
|
)
|
||||||
|
return [(1, False), (2, True)]
|
||||||
|
if max_in == 1:
|
||||||
|
return [(1, False)]
|
||||||
|
if pref:
|
||||||
|
return [(2, True), (1, False)]
|
||||||
|
return [(1, False)]
|
||||||
|
|
||||||
|
def _frames_per_buffer_for_rate(self, pa_rate: int) -> int:
|
||||||
|
if pa_rate <= 0:
|
||||||
|
pa_rate = self.sample_rate
|
||||||
|
return max(128, int(round(self.chunk_size * pa_rate / self.sample_rate)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resample_linear_int16(
|
||||||
|
x: np.ndarray, sr_in: int, sr_out: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
if sr_in == sr_out or x.size == 0:
|
||||||
|
return x
|
||||||
|
n_out = max(1, int(round(x.size * (sr_out / sr_in))))
|
||||||
|
t_in = np.arange(x.size, dtype=np.float64)
|
||||||
|
t_out = np.linspace(0.0, float(x.size - 1), n_out, dtype=np.float64)
|
||||||
|
y = np.interp(t_out, t_in, x.astype(np.float32))
|
||||||
|
return np.clip(np.round(y), -32768, 32767).astype(np.int16)
|
||||||
|
|
||||||
|
def _try_open_on_device(self, input_device_index: int) -> bool:
|
||||||
|
try:
|
||||||
|
dev = self.audio.get_device_info_by_index(input_device_index)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
max_in = int(dev.get("maxInputChannels", 0))
|
||||||
|
dev_name = str(dev.get("name", ""))
|
||||||
|
if (
|
||||||
|
max_in <= 0
|
||||||
|
and self._input_device_index_cfg is not None
|
||||||
|
and int(input_device_index) == int(self._input_device_index_cfg)
|
||||||
|
and self._prefer_stereo_capture
|
||||||
|
and self.channels == 1
|
||||||
|
):
|
||||||
|
max_in = 2
|
||||||
|
logger.warning(
|
||||||
|
"设备 %s 上报 maxInputChannels=0,假定 2 通道以尝试 ES8388 立体声采集",
|
||||||
|
input_device_index,
|
||||||
|
)
|
||||||
|
plan = self._channel_plan(max_in, dev_name)
|
||||||
|
hw_t = self._device_hw_tuple_in_name(dev_name)
|
||||||
|
|
||||||
|
for pa_ch, stereo_dm in plan:
|
||||||
|
self._pa_channels = pa_ch
|
||||||
|
self._stereo_downmix = stereo_dm
|
||||||
|
if stereo_dm and pa_ch == 2:
|
||||||
|
logger.info(
|
||||||
|
"输入按立体声打开并下混 mono(%s index=%s)",
|
||||||
|
dev_name,
|
||||||
|
input_device_index,
|
||||||
|
)
|
||||||
|
for rate in self._open_try_rates:
|
||||||
|
fpb = self._frames_per_buffer_for_rate(int(rate))
|
||||||
|
try:
|
||||||
|
self.stream = self.audio.open(
|
||||||
|
format=self.format,
|
||||||
|
channels=self._pa_channels,
|
||||||
|
rate=int(rate),
|
||||||
|
input=True,
|
||||||
|
input_device_index=input_device_index,
|
||||||
|
frames_per_buffer=fpb,
|
||||||
|
stream_callback=self._audio_callback,
|
||||||
|
start=False,
|
||||||
|
)
|
||||||
|
self.stream.start_stream()
|
||||||
|
self._pa_open_sample_rate = int(rate)
|
||||||
|
extra = (
|
||||||
|
f" hw=card{hw_t[0]}dev{hw_t[1]}" if hw_t else ""
|
||||||
|
)
|
||||||
|
if self._pa_open_sample_rate != self.sample_rate:
|
||||||
|
logger.warning(
|
||||||
|
"输入实际 %s Hz,将重采样为 %s Hz 供 VAD/STT",
|
||||||
|
self._pa_open_sample_rate,
|
||||||
|
self.sample_rate,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"音频流启动成功 index=%s name=%r PA_ch=%s PA_rate=%s 逻辑rate=%s%s",
|
||||||
|
input_device_index,
|
||||||
|
dev_name,
|
||||||
|
self._pa_channels,
|
||||||
|
self._pa_open_sample_rate,
|
||||||
|
self.sample_rate,
|
||||||
|
extra,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
if self.stream is not None:
|
||||||
|
try:
|
||||||
|
self.stream.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.stream = None
|
||||||
|
logger.warning(
|
||||||
|
"打开失败 index=%s ch=%s rate=%s: %s",
|
||||||
|
input_device_index,
|
||||||
|
pa_ch,
|
||||||
|
rate,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _audio_callback(self, in_data, frame_count, time_info, status):
|
||||||
|
"""
|
||||||
|
音频回调函数(非阻塞)
|
||||||
|
"""
|
||||||
|
if status:
|
||||||
|
logger.warning(f"音频流状态: {status}")
|
||||||
|
|
||||||
|
# 将数据放入队列(非阻塞)
|
||||||
|
try:
|
||||||
|
self.audio_queue.put(in_data, block=False)
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("音频队列已满,丢弃数据块")
|
||||||
|
|
||||||
|
return (None, pyaudio.paContinue)
|
||||||
|
|
||||||
|
def _log_input_devices_for_user(self) -> None:
|
||||||
|
"""列出 PortAudio 全部设备(含 in_ch=0),便于选 --input-index / 核对子串。"""
|
||||||
|
n_dev = self.audio.get_device_count()
|
||||||
|
if n_dev <= 0:
|
||||||
|
print(
|
||||||
|
"[audio] PyAudio get_device_count()=0,多为 ALSA/PortAudio 未初始化;"
|
||||||
|
"请用 bash with_system_alsa.sh python … 启动。",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
logger.error("PyAudio 枚举不到任何设备")
|
||||||
|
return
|
||||||
|
lines: List[str] = []
|
||||||
|
for i in range(n_dev):
|
||||||
|
try:
|
||||||
|
inf = self.audio.get_device_info_by_index(i)
|
||||||
|
mic = int(inf.get("maxInputChannels", 0))
|
||||||
|
outc = int(inf.get("maxOutputChannels", 0))
|
||||||
|
name = str(inf.get("name", "?"))
|
||||||
|
mark = " <- 可录音" if mic > 0 else ""
|
||||||
|
lines.append(f" [{i}] in={mic} out={outc} {name}{mark}")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
msg = "\n".join(lines)
|
||||||
|
logger.error("PortAudio 设备列表:\n%s", msg)
|
||||||
|
print(
|
||||||
|
"[audio] PortAudio 设备列表(in>0 才可作输入;若板载显示 in=0 仍可用 probe 试采):\n"
|
||||||
|
+ msg,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start_stream(self) -> None:
|
||||||
|
"""启动音频流(回调模式)"""
|
||||||
|
if self.stream is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
preferred, fallback = self._ordered_input_candidates()
|
||||||
|
to_try: List[int] = preferred + fallback
|
||||||
|
if not to_try:
|
||||||
|
print(
|
||||||
|
"[audio] 无任何输入候选。请检查 PortAudio/ALSA(建议:bash with_system_alsa.sh python …)。",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if self._input_device_index_cfg is not None:
|
||||||
|
logger.error(
|
||||||
|
"已配置 input_device_index=%s 但无效或不可打开",
|
||||||
|
self._input_device_index_cfg,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[audio] 当前配置的 PyAudio 索引 {self._input_device_index_cfg} 不可用,"
|
||||||
|
"请改 system.yaml 或重新运行交互选设备。",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._log_input_devices_for_user()
|
||||||
|
raise OSError("未找到任何 PyAudio 输入候选设备")
|
||||||
|
|
||||||
|
for input_device_index in to_try:
|
||||||
|
if self._try_open_on_device(input_device_index):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.error("启动音频流失败:全部候选设备无法打开")
|
||||||
|
self._log_input_devices_for_user()
|
||||||
|
raise OSError("启动音频流失败:全部候选设备无法打开")
|
||||||
|
|
||||||
|
def stop_stream(self) -> None:
|
||||||
|
"""停止音频流"""
|
||||||
|
if self.stream is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.stream.stop_stream()
|
||||||
|
self.stream.close()
|
||||||
|
self.stream = None
|
||||||
|
self._pa_open_sample_rate = self.sample_rate
|
||||||
|
# 清空队列
|
||||||
|
while not self.audio_queue.empty():
|
||||||
|
try:
|
||||||
|
self.audio_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
logger.info("音频流已停止")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"停止音频流失败: {e}")
|
||||||
|
|
||||||
|
def read_chunk(self, timeout: float = 0.1) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
读取一个音频块(非阻塞)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
音频数据(bytes),如果超时则返回 None
|
||||||
|
"""
|
||||||
|
if self.stream is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.audio_queue.get(timeout=timeout)
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def read_chunk_numpy(self, timeout: float = 0.1) -> Optional[np.ndarray]:
|
||||||
|
"""读取一个音频块并转换为 numpy 数组(非阻塞)"""
|
||||||
|
data = self.read_chunk(timeout)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sample_size = self._pa_channels * self.sample_width
|
||||||
|
if len(data) % sample_size != 0:
|
||||||
|
aligned_len = (len(data) // sample_size) * sample_size
|
||||||
|
if aligned_len == 0:
|
||||||
|
return None
|
||||||
|
data = data[:aligned_len]
|
||||||
|
|
||||||
|
mono = np.frombuffer(data, dtype="<i2")
|
||||||
|
if self._stereo_downmix and self._pa_channels == 2:
|
||||||
|
n = mono.size // 2
|
||||||
|
if n == 0:
|
||||||
|
return None
|
||||||
|
s = mono[: n * 2].reshape(n, 2).astype(np.int32)
|
||||||
|
mono = ((s[:, 0] + s[:, 1]) // 2).astype(np.int16)
|
||||||
|
if self._pa_open_sample_rate != self.sample_rate:
|
||||||
|
mono = self._resample_linear_int16(
|
||||||
|
mono, self._pa_open_sample_rate, self.sample_rate
|
||||||
|
)
|
||||||
|
return mono
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start_stream()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop_stream()
|
||||||
|
self.audio.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalRMS:
|
||||||
|
"""
|
||||||
|
增量 RMS 计算器(滑动窗口)
|
||||||
|
|
||||||
|
用于 AGC,避免每次重新计算整个音频块的 RMS
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int = 1024):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
window_size: 滑动窗口大小
|
||||||
|
"""
|
||||||
|
self.window_size = window_size
|
||||||
|
self.buffer = np.zeros(window_size, dtype=np.float32)
|
||||||
|
self.sum_sq = 0.0
|
||||||
|
self.idx = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, sample: float) -> float:
|
||||||
|
"""
|
||||||
|
更新 RMS 值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: 新的采样值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
当前 RMS 值
|
||||||
|
"""
|
||||||
|
if self.count < self.window_size:
|
||||||
|
# 填充阶段
|
||||||
|
self.buffer[self.count] = sample
|
||||||
|
self.sum_sq += sample * sample
|
||||||
|
self.count += 1
|
||||||
|
if self.count == 0:
|
||||||
|
return 0.0
|
||||||
|
return np.sqrt(self.sum_sq / self.count)
|
||||||
|
else:
|
||||||
|
# 滑动窗口阶段
|
||||||
|
old_sq = self.buffer[self.idx] * self.buffer[self.idx]
|
||||||
|
self.sum_sq = self.sum_sq - old_sq + sample * sample
|
||||||
|
self.buffer[self.idx] = sample
|
||||||
|
self.idx = (self.idx + 1) % self.window_size
|
||||||
|
return np.sqrt(self.sum_sq / self.window_size)
|
||||||
|
|
||||||
|
def update_batch(self, samples: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
批量更新 RMS 值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples: 采样数组
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
当前 RMS 值
|
||||||
|
"""
|
||||||
|
for sample in samples:
|
||||||
|
self.update(sample)
|
||||||
|
return np.sqrt(self.sum_sq / min(self.count, self.window_size))
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""重置计算器"""
|
||||||
|
self.buffer.fill(0.0)
|
||||||
|
self.sum_sq = 0.0
|
||||||
|
self.idx = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
|
||||||
|
class LightweightNoiseReduction:
|
||||||
|
"""
|
||||||
|
轻量级降噪算法
|
||||||
|
|
||||||
|
使用简单的高通滤波 + 谱减法,性能比 noisereduce 快 10-20 倍
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sample_rate: int = 16000, cutoff: float = 80.0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
sample_rate: 采样率
|
||||||
|
cutoff: 高通滤波截止频率(Hz)
|
||||||
|
"""
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.cutoff = cutoff
|
||||||
|
|
||||||
|
# 简单的 IIR 高通滤波器系数(一阶 Butterworth)
|
||||||
|
# H(z) = (1 - z^-1) / (1 - 0.99*z^-1)
|
||||||
|
self.alpha = np.exp(-2.0 * np.pi * cutoff / sample_rate)
|
||||||
|
self.prev_input = 0.0
|
||||||
|
self.prev_output = 0.0
|
||||||
|
|
||||||
|
def process(self, audio: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
处理音频(高通滤波)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: 音频数据(float32,范围 [-1, 1])
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的音频
|
||||||
|
"""
|
||||||
|
if audio.dtype != np.float32:
|
||||||
|
audio = audio.astype(np.float32)
|
||||||
|
|
||||||
|
# 简单的一阶高通滤波
|
||||||
|
output = np.zeros_like(audio)
|
||||||
|
for i in range(len(audio)):
|
||||||
|
output[i] = self.alpha * (self.prev_output + audio[i] - self.prev_input)
|
||||||
|
self.prev_input = audio[i]
|
||||||
|
self.prev_output = output[i]
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""重置滤波器状态"""
|
||||||
|
self.prev_input = 0.0
|
||||||
|
self.prev_output = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPreprocessorOptimized:
|
||||||
|
"""
|
||||||
|
优化版音频预处理器
|
||||||
|
|
||||||
|
性能优化:
|
||||||
|
1. 轻量级降噪(替代 noisereduce)
|
||||||
|
2. 增量 AGC 计算
|
||||||
|
3. 减少类型转换
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, enable_noise_reduction: Optional[bool] = None,
|
||||||
|
enable_agc: Optional[bool] = None):
|
||||||
|
"""
|
||||||
|
初始化音频预处理器
|
||||||
|
"""
|
||||||
|
# 从配置读取
|
||||||
|
if enable_noise_reduction is None:
|
||||||
|
enable_noise_reduction = SYSTEM_AUDIO_CONFIG.get("noise_reduce", True)
|
||||||
|
if enable_agc is None:
|
||||||
|
enable_agc = SYSTEM_AUDIO_CONFIG.get("agc", True)
|
||||||
|
|
||||||
|
self.enable_noise_reduction = enable_noise_reduction
|
||||||
|
self.enable_agc = enable_agc
|
||||||
|
self.sample_rate = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000))
|
||||||
|
|
||||||
|
# AGC 参数(确保类型正确,从 YAML 读取可能是字符串)
|
||||||
|
self.agc_target_db = float(SYSTEM_AUDIO_CONFIG.get("agc_target_db", -20.0))
|
||||||
|
self.agc_gain_min = float(SYSTEM_AUDIO_CONFIG.get("agc_gain_min", 0.1))
|
||||||
|
self.agc_gain_max = float(SYSTEM_AUDIO_CONFIG.get("agc_gain_max", 10.0))
|
||||||
|
self.agc_rms_threshold = float(SYSTEM_AUDIO_CONFIG.get("agc_rms_threshold", 1e-6))
|
||||||
|
self._agc_alpha = float(SYSTEM_AUDIO_CONFIG.get("agc_smoothing_alpha", 0.1))
|
||||||
|
self._agc_alpha = max(0.02, min(0.95, self._agc_alpha))
|
||||||
|
# 当需要抬增益(小声/巨响过后)时用更大系数,避免长时间压在 agc_gain_min
|
||||||
|
self._agc_release_alpha = float(
|
||||||
|
SYSTEM_AUDIO_CONFIG.get("agc_release_alpha", 0.45)
|
||||||
|
)
|
||||||
|
self._agc_release_alpha = max(self._agc_alpha, min(0.95, self._agc_release_alpha))
|
||||||
|
|
||||||
|
# 初始化组件
|
||||||
|
if enable_noise_reduction:
|
||||||
|
self.noise_reducer = LightweightNoiseReduction(
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
cutoff=80.0 # 可配置
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.noise_reducer = None
|
||||||
|
|
||||||
|
if enable_agc:
|
||||||
|
# 使用增量 RMS 计算器
|
||||||
|
window_size = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024))
|
||||||
|
self.rms_calculator = IncrementalRMS(window_size=window_size)
|
||||||
|
self.current_gain = 1.0 # 缓存当前增益
|
||||||
|
else:
|
||||||
|
self.rms_calculator = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"优化版音频预处理器初始化完成: "
|
||||||
|
f"降噪={'启用(轻量级)' if enable_noise_reduction else '禁用'}, "
|
||||||
|
f"自动增益控制={'启用(增量)' if enable_agc else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
重置高通滤波与 AGC 状态。应在「暂停采集再重新 start_stream」之后调用,
|
||||||
|
避免停麦/播 TTS 期间的状态带到新流上(否则易出现恢复后长时间 RMS≈0 或电平怪异)。
|
||||||
|
"""
|
||||||
|
if self.noise_reducer is not None:
|
||||||
|
self.noise_reducer.reset()
|
||||||
|
if self.rms_calculator is not None:
|
||||||
|
self.rms_calculator.reset()
|
||||||
|
if self.enable_agc:
|
||||||
|
self.current_gain = 1.0
|
||||||
|
|
||||||
|
def reset_agc_state(self) -> None:
|
||||||
|
"""
|
||||||
|
每段语音结束或需Recovery时调用:清空 RMS 滑窗并将增益重置为 1。
|
||||||
|
避免短时强噪声把 current_gain 压在 agc_gain_min、滑窗仍含高能量导致后续 RMS≈0。
|
||||||
|
"""
|
||||||
|
if not self.enable_agc or self.rms_calculator is None:
|
||||||
|
return
|
||||||
|
self.rms_calculator.reset()
|
||||||
|
self.current_gain = 1.0
|
||||||
|
|
||||||
|
def reduce_noise(self, audio_data: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
轻量级降噪处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: 音频数据(int16 或 float32)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
降噪后的音频数据
|
||||||
|
"""
|
||||||
|
if not self.enable_noise_reduction or self.noise_reducer is None:
|
||||||
|
return audio_data
|
||||||
|
|
||||||
|
# 转换为 float32
|
||||||
|
if audio_data.dtype == np.int16:
|
||||||
|
audio_float = audio_data.astype(np.float32) / 32768.0
|
||||||
|
is_int16 = True
|
||||||
|
else:
|
||||||
|
audio_float = audio_data.astype(np.float32)
|
||||||
|
is_int16 = False
|
||||||
|
|
||||||
|
# 轻量级降噪
|
||||||
|
reduced = self.noise_reducer.process(audio_float)
|
||||||
|
|
||||||
|
# 转换回原始格式
|
||||||
|
if is_int16:
|
||||||
|
reduced = (reduced * 32768.0).astype(np.int16)
|
||||||
|
|
||||||
|
return reduced
|
||||||
|
|
||||||
|
def automatic_gain_control(self, audio_data: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
自动增益控制(使用增量 RMS)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: 音频数据(int16 或 float32)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
增益调整后的音频数据
|
||||||
|
"""
|
||||||
|
if not self.enable_agc or self.rms_calculator is None:
|
||||||
|
return audio_data
|
||||||
|
|
||||||
|
# 转换为 float32
|
||||||
|
if audio_data.dtype == np.int16:
|
||||||
|
audio_float = audio_data.astype(np.float32) / 32768.0
|
||||||
|
is_int16 = True
|
||||||
|
else:
|
||||||
|
audio_float = audio_data.astype(np.float32)
|
||||||
|
is_int16 = False
|
||||||
|
|
||||||
|
# 使用增量 RMS 计算
|
||||||
|
rms = self.rms_calculator.update_batch(audio_float)
|
||||||
|
|
||||||
|
if rms < self.agc_rms_threshold:
|
||||||
|
return audio_data
|
||||||
|
|
||||||
|
# 计算增益(可以进一步优化:使用滑动平均)
|
||||||
|
current_db = 20 * np.log10(rms)
|
||||||
|
gain_db = self.agc_target_db - current_db
|
||||||
|
gain_linear = 10 ** (gain_db / 20.0)
|
||||||
|
gain_linear = np.clip(gain_linear, self.agc_gain_min, self.agc_gain_max)
|
||||||
|
|
||||||
|
# 压低增益用较小 alpha;需要恢复(gain_linear 明显高于当前)时用 release alpha
|
||||||
|
if gain_linear > self.current_gain * 1.08:
|
||||||
|
alpha = self._agc_release_alpha
|
||||||
|
else:
|
||||||
|
alpha = self._agc_alpha
|
||||||
|
self.current_gain = alpha * gain_linear + (1 - alpha) * self.current_gain
|
||||||
|
|
||||||
|
# 应用增益
|
||||||
|
adjusted = audio_float * self.current_gain
|
||||||
|
adjusted = np.clip(adjusted, -1.0, 1.0)
|
||||||
|
|
||||||
|
# 转换回原始格式
|
||||||
|
if is_int16:
|
||||||
|
adjusted = (adjusted * 32768.0).astype(np.int16)
|
||||||
|
|
||||||
|
return adjusted
|
||||||
|
|
||||||
|
def process(self, audio_data: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
完整的预处理流程(优化版)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: 音频数据(numpy array)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预处理后的音频数据
|
||||||
|
"""
|
||||||
|
processed = audio_data.copy()
|
||||||
|
|
||||||
|
# 降噪
|
||||||
|
if self.enable_noise_reduction:
|
||||||
|
processed = self.reduce_noise(processed)
|
||||||
|
|
||||||
|
# 自动增益控制
|
||||||
|
if self.enable_agc:
|
||||||
|
processed = self.automatic_gain_control(processed)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
# 向后兼容别名(保持API一致性)
|
||||||
|
AudioCapture = AudioCaptureOptimized
|
||||||
|
AudioPreprocessor = AudioPreprocessorOptimized
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 优化版使用
|
||||||
|
with AudioCapture() as capture:
|
||||||
|
preprocessor = AudioPreprocessor()
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
chunk = capture.read_chunk_numpy(timeout=0.1)
|
||||||
|
if chunk is not None:
|
||||||
|
processed = preprocessor.process(chunk)
|
||||||
|
print(f"处理了 {len(processed)} 个采样点")
|
||||||
78
voice_drone/core/cloud_dialog_v1.py
Normal file
78
voice_drone/core/cloud_dialog_v1.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""cloud_voice_dialog_v1:dialog_result 约定(见 docs/CLOUD_VOICE_FLIGHT_CONFIRM_v1.md)。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
CLOUD_VOICE_DIALOG_V1 = "cloud_voice_dialog_v1"
|
||||||
|
|
||||||
|
MSG_CONFIRM_TIMEOUT = "未收到确认指令,请重新下发指令。"
|
||||||
|
MSG_CANCELLED = "已取消指令,请重新唤醒后下发指令。"
|
||||||
|
MSG_CONFIRM_EXECUTING = "开始执行飞控指令。"
|
||||||
|
MSG_PROMPT_LISTEN_TIMEOUT = "未检测到语音,请重新唤醒后再说。"
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_phrase_text(s: str) -> str:
|
||||||
|
"""去首尾空白、合并连续空白。"""
|
||||||
|
return " ".join((s or "").strip().split())
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_tail_punct(s: str) -> str:
|
||||||
|
return s.rstrip("。!!??,, \t")
|
||||||
|
|
||||||
|
|
||||||
|
def match_phrase_list(norm: str, phrases: Any) -> bool:
|
||||||
|
"""
|
||||||
|
命中规则(适配「请回复确认或取消」类长提示 + 只说「确认」「取消」):
|
||||||
|
- 去尾标点后 **全等** 短语;或
|
||||||
|
- 短语为 **子串** 且整句长度 <= len(短语)+2,避免用户复述整段提示时同时含「确认」「取消」而误触。
|
||||||
|
"""
|
||||||
|
if not isinstance(phrases, list) or not norm:
|
||||||
|
return False
|
||||||
|
base = _strip_tail_punct(normalize_phrase_text(norm))
|
||||||
|
if not base:
|
||||||
|
return False
|
||||||
|
for p in phrases:
|
||||||
|
raw = _strip_tail_punct((p or "").strip())
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
if base == raw:
|
||||||
|
return True
|
||||||
|
if raw in base and len(base) <= len(raw) + 2:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_confirm_dict(raw: Any) -> dict[str, Any] | None:
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
return None
|
||||||
|
required = raw.get("required")
|
||||||
|
if not isinstance(required, bool):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
timeout_sec = float(raw.get("timeout_sec", 10))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
timeout_sec = 10.0
|
||||||
|
timeout_sec = max(1.0, min(120.0, timeout_sec))
|
||||||
|
cp = raw.get("confirm_phrases")
|
||||||
|
kp = raw.get("cancel_phrases")
|
||||||
|
if not isinstance(cp, list) or not cp:
|
||||||
|
return None
|
||||||
|
if not isinstance(kp, list) or not kp:
|
||||||
|
return None
|
||||||
|
pending = raw.get("pending_id")
|
||||||
|
if pending is not None and not isinstance(pending, str):
|
||||||
|
pending = str(pending)
|
||||||
|
cplist = [str(x) for x in cp if str(x).strip()]
|
||||||
|
kplist = [str(x) for x in kp if str(x).strip()]
|
||||||
|
if not cplist or not kplist:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"required": required,
|
||||||
|
"timeout_sec": timeout_sec,
|
||||||
|
"confirm_phrases": cplist,
|
||||||
|
"cancel_phrases": kplist,
|
||||||
|
"pending_id": pending,
|
||||||
|
"summary_for_user": raw.get("summary_for_user"),
|
||||||
|
"raw": raw,
|
||||||
|
}
|
||||||
999
voice_drone/core/cloud_voice_client.py
Normal file
999
voice_drone/core/cloud_voice_client.py
Normal file
@ -0,0 +1,999 @@
|
|||||||
|
"""
|
||||||
|
云端语音 WebSocket 客户端:会话 `session.start.transport_profile` 固定为 pcm_asr_uplink。
|
||||||
|
|
||||||
|
- 主路径:`turn.audio.start` → 若干 `turn.audio.chunk`(每条仅文本 JSON,含 `pcm_base64`)→ `turn.audio.end`;**禁止**用 WebSocket binary 上发 PCM(与 Starlette receive 语义一致)。
|
||||||
|
- 辅助:`run_turn` 发 `turn.text`(如同句快路径仅有文本);`run_tts_synthesize` 仅 TTS。
|
||||||
|
- `asr.partial` 仅调试展示,不参与机端状态机。
|
||||||
|
|
||||||
|
文档:`docs/CLOUD_VOICE_SESSION_SCHEME_v1.md`,`docs/CLOUD_VOICE_PROTOCOL_pcm_asr_uplink_v1.md`。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from voice_drone.core.cloud_dialog_v1 import CLOUD_VOICE_DIALOG_V1
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("voice_drone.cloud_voice")
|
||||||
|
|
||||||
|
_CLOUD_PROTO = "1.0"
|
||||||
|
TRANSPORT_PCM_ASR_UPLINK = "pcm_asr_uplink"
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_session_client(
|
||||||
|
device_id: str,
|
||||||
|
*,
|
||||||
|
session_client_extensions: dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""session.start 的 client:capabilities 与设备信息 + 可选 PX4 等扩展(不覆盖 device_id/locale)。"""
|
||||||
|
client: dict[str, Any] = {
|
||||||
|
"device_id": device_id,
|
||||||
|
"locale": "zh-CN",
|
||||||
|
"capabilities": {
|
||||||
|
"playback_sample_rate_hz": 24000,
|
||||||
|
"prefer_tts_codec": "pcm_s16le",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ext = session_client_extensions or {}
|
||||||
|
for k, v in ext.items():
|
||||||
|
if v is None or k in ("device_id", "locale", "capabilities", "protocol"):
|
||||||
|
continue
|
||||||
|
if k == "extras" and isinstance(v, dict) and len(v) == 0:
|
||||||
|
continue
|
||||||
|
client[k] = v
|
||||||
|
client["protocol"] = {"dialog_result": CLOUD_VOICE_DIALOG_V1}
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def _transient_ws_exc(exc: BaseException) -> bool:
|
||||||
|
"""可通过对端已关、网络抖动等通过重连重发 turn 恢复的异常。"""
|
||||||
|
import websocket as _websocket # noqa: PLC0415
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
exc,
|
||||||
|
(
|
||||||
|
BrokenPipeError,
|
||||||
|
ConnectionResetError,
|
||||||
|
ConnectionAbortedError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if isinstance(
|
||||||
|
exc,
|
||||||
|
(
|
||||||
|
_websocket.WebSocketConnectionClosedException,
|
||||||
|
_websocket.WebSocketTimeoutException,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if isinstance(exc, OSError) and getattr(exc, "errno", None) in (
|
||||||
|
32,
|
||||||
|
104,
|
||||||
|
110,
|
||||||
|
): # EPIPE, ECONNRESET, ETIMEDOUT
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_tts_pcm_chunks(
|
||||||
|
chunk_entries: list[tuple[int | None, int, bytes]],
|
||||||
|
) -> bytes:
|
||||||
|
"""按 seq 升序拼接;无 seq 时按到达顺序。chunk_entries: (seq|None, arrival, pcm)。"""
|
||||||
|
if not chunk_entries:
|
||||||
|
return b""
|
||||||
|
if all(s is not None for s, _, _ in chunk_entries):
|
||||||
|
ordered = sorted(chunk_entries, key=lambda x: (x[0], x[1]))
|
||||||
|
seqs = [x[0] for x in ordered]
|
||||||
|
for a, b in zip(seqs, seqs[1:]):
|
||||||
|
if b != a + 1:
|
||||||
|
logger.warning("TTS seq 不连续(仍按序拼接): %s → %s", a, b)
|
||||||
|
break
|
||||||
|
return b"".join(x[2] for x in ordered)
|
||||||
|
return b"".join(x[2] for x in sorted(chunk_entries, key=lambda x: x[1]))
|
||||||
|
|
||||||
|
|
||||||
|
class CloudVoiceError(RuntimeError):
|
||||||
|
"""云端返回 error 消息或协议不符合预期。"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, *, code: str | None = None, retryable: bool = False):
|
||||||
|
super().__init__(message)
|
||||||
|
self.code = code
|
||||||
|
self.retryable = retryable
|
||||||
|
|
||||||
|
|
||||||
|
class CloudVoiceClient:
|
||||||
|
"""连接 ws://…/v1/voice/session;session 为 pcm_asr_uplink,含 run_turn_audio / run_turn / tts.synthesize。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
auth_token: str,
|
||||||
|
device_id: str,
|
||||||
|
recv_timeout: float = 120.0,
|
||||||
|
session_client_extensions: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.server_url = server_url.strip()
|
||||||
|
self.auth_token = auth_token.strip()
|
||||||
|
self.device_id = (device_id or "drone-001").strip()
|
||||||
|
self.recv_timeout = float(recv_timeout)
|
||||||
|
self._session_client_extensions: dict[str, Any] = dict(
|
||||||
|
session_client_extensions or {}
|
||||||
|
)
|
||||||
|
self._transport_profile: str = TRANSPORT_PCM_ASR_UPLINK
|
||||||
|
self._ws: Any = None
|
||||||
|
self._session_id: str | None = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return self._ws is not None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._close_nolock()
|
||||||
|
|
||||||
|
def _close_nolock(self) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
self._session_id = None
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if self._session_id:
|
||||||
|
try:
|
||||||
|
self._ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "session.end",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"session_id": self._session_id,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
self._ws.close()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
self._ws = None
|
||||||
|
self._session_id = None
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""建立 WSS,发送 session.start,等待 session.ready。"""
|
||||||
|
with self._lock:
|
||||||
|
self._connect_nolock()
|
||||||
|
|
||||||
|
def _connect_nolock(self) -> None:
|
||||||
|
import websocket # websocket-client
|
||||||
|
|
||||||
|
self._close_nolock()
|
||||||
|
hdr = [f"Authorization: Bearer {self.auth_token}"]
|
||||||
|
try:
|
||||||
|
self._ws = websocket.create_connection(
|
||||||
|
self.server_url,
|
||||||
|
header=hdr,
|
||||||
|
timeout=self.recv_timeout,
|
||||||
|
)
|
||||||
|
self._ws.settimeout(self.recv_timeout)
|
||||||
|
self._session_id = str(uuid.uuid4())
|
||||||
|
client_payload = _merge_session_client(
|
||||||
|
self.device_id,
|
||||||
|
session_client_extensions=self._session_client_extensions,
|
||||||
|
)
|
||||||
|
if self._session_client_extensions:
|
||||||
|
logger.info(
|
||||||
|
"session.start 已附加 client 扩展键: %s",
|
||||||
|
sorted(self._session_client_extensions.keys()),
|
||||||
|
)
|
||||||
|
start = {
|
||||||
|
"type": "session.start",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"session_id": self._session_id,
|
||||||
|
"auth_token": self.auth_token,
|
||||||
|
"client": client_payload,
|
||||||
|
}
|
||||||
|
self._ws.send(json.dumps(start, ensure_ascii=False))
|
||||||
|
raw = self._ws.recv()
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raise CloudVoiceError("session.ready 期望 JSON 文本帧,收到二进制")
|
||||||
|
data = json.loads(raw)
|
||||||
|
if data.get("type") != "session.ready":
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"期望 session.ready,收到: {data.get('type')!r}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
logger.info("云端会话已就绪 session_id=%s", self._session_id)
|
||||||
|
except Exception:
|
||||||
|
self._close_nolock()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def ensure_connected(self) -> None:
|
||||||
|
with self._lock:
|
||||||
|
if self._ws is None:
|
||||||
|
self._connect_nolock()
|
||||||
|
|
||||||
|
def _execute_turn_nolock(self, t: str) -> dict[str, Any]:
|
||||||
|
"""已持锁且 _ws 已连接:发送 turn.text 并收齐本轮帧。"""
|
||||||
|
import websocket # websocket-client
|
||||||
|
|
||||||
|
ws = self._ws
|
||||||
|
if ws is None:
|
||||||
|
raise CloudVoiceError("WebSocket 未连接")
|
||||||
|
|
||||||
|
turn_id = str(uuid.uuid4())
|
||||||
|
turn_msg = {
|
||||||
|
"type": "turn.text",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
"text": t,
|
||||||
|
"is_final": True,
|
||||||
|
"source": "device_stt",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
ws.send(json.dumps(turn_msg, ensure_ascii=False))
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise CloudVoiceError(f"发送 turn.text 失败: {e}", code="INTERNAL") from e
|
||||||
|
logger.debug("→ turn.text turn_id=%s", turn_id)
|
||||||
|
|
||||||
|
expecting_binary = False
|
||||||
|
_pending_tts_seq: int | None = None
|
||||||
|
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||||||
|
_pcm_arrival = 0
|
||||||
|
llm_stream_parts: list[str] = []
|
||||||
|
dialog: dict[str, Any] | None = None
|
||||||
|
metrics: dict[str, Any] = {}
|
||||||
|
sample_rate_hz = 24000
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = ws.recv()
|
||||||
|
except websocket.WebSocketConnectionClosedException as e:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"连接已断开: {e}",
|
||||||
|
code="DISCONNECTED",
|
||||||
|
retryable=True,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
if expecting_binary:
|
||||||
|
expecting_binary = False
|
||||||
|
else:
|
||||||
|
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||||||
|
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||||||
|
_pcm_arrival += 1
|
||||||
|
_pending_tts_seq = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(msg, str):
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
text_frame = msg.strip()
|
||||||
|
if not text_frame:
|
||||||
|
logger.debug("跳过空 WebSocket 文本帧")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(text_frame)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
head = text_frame[:200].replace("\n", "\\n")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
) from e
|
||||||
|
mtype = data.get("type")
|
||||||
|
|
||||||
|
if mtype == "asr.partial":
|
||||||
|
logger.debug("← asr.partial(机端不参与状态跳转)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "llm.text_delta":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.debug(
|
||||||
|
"llm.text_delta turn_id 与当前不一致,忽略 type=%s",
|
||||||
|
mtype,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
raw_d = data.get("delta")
|
||||||
|
delta = "" if raw_d is None else str(raw_d)
|
||||||
|
llm_stream_parts.append(delta)
|
||||||
|
_print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
)
|
||||||
|
if _print_stream:
|
||||||
|
print(delta, end="", flush=True)
|
||||||
|
if data.get("done"):
|
||||||
|
print(flush=True)
|
||||||
|
logger.debug(
|
||||||
|
"← llm.text_delta done=%s delta_chars=%s",
|
||||||
|
data.get("done"),
|
||||||
|
len(delta),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "tts_audio_chunk":
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
sample_rate_hz = int(
|
||||||
|
data.get("sample_rate_hz") or sample_rate_hz
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
_s = data.get("seq")
|
||||||
|
try:
|
||||||
|
if _s is not None:
|
||||||
|
_pending_tts_seq = int(_s)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("is_final"):
|
||||||
|
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||||||
|
expecting_binary = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "dialog_result":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
"dialog_result turn_id 不匹配", code="INVALID_MESSAGE"
|
||||||
|
)
|
||||||
|
dialog = data
|
||||||
|
logger.info(
|
||||||
|
"← dialog_result routing=%s", data.get("routing")
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "turn.complete":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||||||
|
)
|
||||||
|
metrics = data.get("metrics") or {}
|
||||||
|
break
|
||||||
|
|
||||||
|
if mtype == "error":
|
||||||
|
code = str(data.get("code") or "INTERNAL")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
data.get("message") or code,
|
||||||
|
code=code,
|
||||||
|
retryable=bool(data.get("retryable")),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("忽略服务端消息 type=%s", mtype)
|
||||||
|
|
||||||
|
if dialog is None:
|
||||||
|
raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE")
|
||||||
|
|
||||||
|
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||||||
|
pcm = (
|
||||||
|
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||||||
|
if full_pcm
|
||||||
|
else np.array([], dtype=np.int16)
|
||||||
|
)
|
||||||
|
if pcm.size > 0:
|
||||||
|
mx = int(np.max(np.abs(pcm)))
|
||||||
|
if mx == 0:
|
||||||
|
logger.warning(
|
||||||
|
"云端 TTS 已收齐二进制总长 %s 字节(≈%s 个 s16 采样),但全为 0x00,"
|
||||||
|
"属于服务端发出的静音占位或未写入合成结果;机端无法通过重采样/扬声器修复。"
|
||||||
|
"请在服务端对同一次 synthesize 写 WAV 核对非零采样,并确认 WS 先发 tts_audio_chunk JSON、"
|
||||||
|
"再发 raw PCM 帧、且未把 JSON/base64 误当 binary 发出。",
|
||||||
|
len(full_pcm),
|
||||||
|
pcm.size,
|
||||||
|
)
|
||||||
|
if os.environ.get("ROCKET_CLOUD_PCM_HEX", "").strip().lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
):
|
||||||
|
head = full_pcm[:64]
|
||||||
|
logger.warning(
|
||||||
|
"ROCKET_CLOUD_PCM_HEX: 前 %s 字节 hex=%s",
|
||||||
|
len(head),
|
||||||
|
head.hex(),
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_stream_text = "".join(llm_stream_parts)
|
||||||
|
return {
|
||||||
|
"protocol": dialog.get("protocol"),
|
||||||
|
"routing": dialog.get("routing"),
|
||||||
|
"flight_intent": dialog.get("flight_intent"),
|
||||||
|
"confirm": dialog.get("confirm"),
|
||||||
|
"chat_reply": dialog.get("chat_reply"),
|
||||||
|
"user_input": dialog.get("user_input"),
|
||||||
|
"pcm": pcm,
|
||||||
|
"sample_rate_hz": sample_rate_hz,
|
||||||
|
"metrics": metrics,
|
||||||
|
"llm_stream_text": llm_stream_text,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _execute_turn_audio_nolock(
|
||||||
|
self, pcm_int16: np.ndarray, sample_rate_hz: int
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""发送 turn.audio.start → 多条 turn.audio.chunk(pcm_base64 文本帧)→ turn.audio.end;禁止 binary 上发 PCM。"""
|
||||||
|
import websocket # websocket-client
|
||||||
|
|
||||||
|
ws = self._ws
|
||||||
|
if ws is None:
|
||||||
|
raise CloudVoiceError("WebSocket 未连接")
|
||||||
|
|
||||||
|
pcm_int16 = np.asarray(pcm_int16, dtype=np.int16).reshape(-1)
|
||||||
|
if pcm_int16.size == 0:
|
||||||
|
raise CloudVoiceError("turn.audio PCM 为空")
|
||||||
|
|
||||||
|
pcm_mx = int(np.max(np.abs(pcm_int16)))
|
||||||
|
pcm_rms = float(np.sqrt(np.mean(pcm_int16.astype(np.float64) ** 2)))
|
||||||
|
dur_sec = float(pcm_int16.size) / max(1, int(sample_rate_hz))
|
||||||
|
logger.info(
|
||||||
|
"turn.audio 上行: samples=%s sr_hz=%s dur≈%.2fs abs_max=%s rms=%.1f dtype=int16",
|
||||||
|
pcm_int16.size,
|
||||||
|
int(sample_rate_hz),
|
||||||
|
dur_sec,
|
||||||
|
pcm_mx,
|
||||||
|
pcm_rms,
|
||||||
|
)
|
||||||
|
if pcm_mx == 0:
|
||||||
|
logger.warning(
|
||||||
|
"turn.audio 上行波形全零,云端 ASR 通常会判无有效语音(请查麦/切段/VAD 是否误交静音)"
|
||||||
|
)
|
||||||
|
elif pcm_mx < 200:
|
||||||
|
logger.warning(
|
||||||
|
"turn.audio 上行幅值极小 abs_max=%s(仍发送);若云端反复无识别请调 AGC/VAD/麦增益",
|
||||||
|
pcm_mx,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_id = str(uuid.uuid4())
|
||||||
|
start = {
|
||||||
|
"type": "turn.audio.start",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
"sample_rate_hz": int(sample_rate_hz),
|
||||||
|
"codec": "pcm_s16le",
|
||||||
|
"channels": 1,
|
||||||
|
}
|
||||||
|
raw = pcm_int16.tobytes()
|
||||||
|
try:
|
||||||
|
ws.send(json.dumps(start, ensure_ascii=False))
|
||||||
|
try:
|
||||||
|
raw_chunk = int(os.environ.get("ROCKET_CLOUD_AUDIO_CHUNK_BYTES", "8192"))
|
||||||
|
except ValueError:
|
||||||
|
raw_chunk = 8192
|
||||||
|
raw_chunk = max(2048, min(256 * 1024, raw_chunk))
|
||||||
|
n_chunks = 0
|
||||||
|
for i in range(0, len(raw), raw_chunk):
|
||||||
|
piece = raw[i : i + raw_chunk]
|
||||||
|
chunk_msg = {
|
||||||
|
"type": "turn.audio.chunk",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
"pcm_base64": base64.b64encode(piece).decode("ascii"),
|
||||||
|
}
|
||||||
|
ws.send(json.dumps(chunk_msg, ensure_ascii=False))
|
||||||
|
n_chunks += 1
|
||||||
|
end = {
|
||||||
|
"type": "turn.audio.end",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
}
|
||||||
|
ws.send(json.dumps(end, ensure_ascii=False))
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise CloudVoiceError(f"发送 turn.audio 失败: {e}", code="INTERNAL") from e
|
||||||
|
logger.debug(
|
||||||
|
"→ turn.audio start/%s chunk(s)/end turn_id=%s samples=%s",
|
||||||
|
n_chunks,
|
||||||
|
turn_id,
|
||||||
|
pcm_int16.size,
|
||||||
|
)
|
||||||
|
|
||||||
|
expecting_binary = False
|
||||||
|
_pending_tts_seq: int | None = None
|
||||||
|
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||||||
|
_pcm_arrival = 0
|
||||||
|
llm_stream_parts: list[str] = []
|
||||||
|
dialog: dict[str, Any] | None = None
|
||||||
|
metrics: dict[str, Any] = {}
|
||||||
|
out_sr = 24000
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = ws.recv()
|
||||||
|
except websocket.WebSocketConnectionClosedException as e:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"连接已断开: {e}",
|
||||||
|
code="DISCONNECTED",
|
||||||
|
retryable=True,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
if expecting_binary:
|
||||||
|
expecting_binary = False
|
||||||
|
else:
|
||||||
|
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||||||
|
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||||||
|
_pcm_arrival += 1
|
||||||
|
_pending_tts_seq = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(msg, str):
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
text_frame = msg.strip()
|
||||||
|
if not text_frame:
|
||||||
|
logger.debug("跳过空 WebSocket 文本帧")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(text_frame)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
head = text_frame[:200].replace("\n", "\\n")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
) from e
|
||||||
|
mtype = data.get("type")
|
||||||
|
|
||||||
|
if mtype == "asr.partial":
|
||||||
|
logger.debug("← asr.partial(机端不参与状态跳转)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "llm.text_delta":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.debug(
|
||||||
|
"llm.text_delta turn_id 与当前不一致,忽略 type=%s",
|
||||||
|
mtype,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
raw_d = data.get("delta")
|
||||||
|
delta = "" if raw_d is None else str(raw_d)
|
||||||
|
llm_stream_parts.append(delta)
|
||||||
|
_print_stream = os.environ.get("ROCKET_PRINT_LLM_STREAM", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
)
|
||||||
|
if _print_stream:
|
||||||
|
print(delta, end="", flush=True)
|
||||||
|
if data.get("done"):
|
||||||
|
print(flush=True)
|
||||||
|
logger.debug(
|
||||||
|
"← llm.text_delta done=%s delta_chars=%s",
|
||||||
|
data.get("done"),
|
||||||
|
len(delta),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "tts_audio_chunk":
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.warning("tts_audio_chunk turn_id 与当前不一致,仍消费后续二进制")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
out_sr = int(data.get("sample_rate_hz") or out_sr)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
_s = data.get("seq")
|
||||||
|
try:
|
||||||
|
if _s is not None:
|
||||||
|
_pending_tts_seq = int(_s)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("is_final"):
|
||||||
|
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||||||
|
expecting_binary = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "dialog_result":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
"dialog_result turn_id 不匹配", code="INVALID_MESSAGE"
|
||||||
|
)
|
||||||
|
dialog = data
|
||||||
|
logger.info(
|
||||||
|
"← dialog_result routing=%s", data.get("routing")
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "turn.complete":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||||||
|
)
|
||||||
|
metrics = data.get("metrics") or {}
|
||||||
|
break
|
||||||
|
|
||||||
|
if mtype == "error":
|
||||||
|
code = str(data.get("code") or "INTERNAL")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
data.get("message") or code,
|
||||||
|
code=code,
|
||||||
|
retryable=bool(data.get("retryable")),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("忽略服务端消息 type=%s", mtype)
|
||||||
|
|
||||||
|
if dialog is None:
|
||||||
|
raise CloudVoiceError("未收到 dialog_result", code="INVALID_MESSAGE")
|
||||||
|
|
||||||
|
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||||||
|
out_pcm = (
|
||||||
|
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||||||
|
if full_pcm
|
||||||
|
else np.array([], dtype=np.int16)
|
||||||
|
)
|
||||||
|
if out_pcm.size > 0:
|
||||||
|
mx = int(np.max(np.abs(out_pcm)))
|
||||||
|
if mx == 0:
|
||||||
|
logger.warning(
|
||||||
|
"云端 TTS 已收齐但全零采样,请核对服务端。",
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_stream_text = "".join(llm_stream_parts)
|
||||||
|
return {
|
||||||
|
"protocol": dialog.get("protocol"),
|
||||||
|
"routing": dialog.get("routing"),
|
||||||
|
"flight_intent": dialog.get("flight_intent"),
|
||||||
|
"confirm": dialog.get("confirm"),
|
||||||
|
"chat_reply": dialog.get("chat_reply"),
|
||||||
|
"user_input": dialog.get("user_input"),
|
||||||
|
"pcm": out_pcm,
|
||||||
|
"sample_rate_hz": out_sr,
|
||||||
|
"metrics": metrics,
|
||||||
|
"llm_stream_text": llm_stream_text,
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_turn_audio(
|
||||||
|
self, pcm_int16: np.ndarray, sample_rate_hz: int
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""上行一轮麦克风 PCM:chunk 均为含 pcm_base64 的文本 JSON;收齐 dialog_result + TTS + turn.complete。"""
|
||||||
|
try:
|
||||||
|
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||||||
|
except ValueError:
|
||||||
|
raw_attempts = 3
|
||||||
|
attempts = max(1, raw_attempts)
|
||||||
|
try:
|
||||||
|
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||||||
|
except ValueError:
|
||||||
|
delay = 0.35
|
||||||
|
delay = max(0.0, delay)
|
||||||
|
|
||||||
|
for attempt in range(attempts):
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self._ws is None:
|
||||||
|
self._connect_nolock()
|
||||||
|
return self._execute_turn_audio_nolock(pcm_int16, sample_rate_hz)
|
||||||
|
except CloudVoiceError as e:
|
||||||
|
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||||||
|
if retry and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"turn.audio 可恢复错误,重连重试 (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"turn.audio WebSocket 瞬断,重连重试 (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
raise CloudVoiceError("run_turn_audio 未执行", code="INTERNAL")
|
||||||
|
|
||||||
|
def _execute_tts_synthesize_nolock(self, text: str) -> dict[str, Any]:
|
||||||
|
"""已持锁且 _ws 已连接:发送 tts.synthesize,仅收 tts_audio_chunk* 与 turn.complete(无 dialog_result)。"""
|
||||||
|
import websocket # websocket-client
|
||||||
|
|
||||||
|
ws = self._ws
|
||||||
|
if ws is None:
|
||||||
|
raise CloudVoiceError("WebSocket 未连接")
|
||||||
|
|
||||||
|
turn_id = str(uuid.uuid4())
|
||||||
|
synth_msg = {
|
||||||
|
"type": "tts.synthesize",
|
||||||
|
"proto_version": _CLOUD_PROTO,
|
||||||
|
"transport_profile": self._transport_profile,
|
||||||
|
"turn_id": turn_id,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
ws.send(json.dumps(synth_msg, ensure_ascii=False))
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise CloudVoiceError(f"发送 tts.synthesize 失败: {e}", code="INTERNAL") from e
|
||||||
|
logger.debug("→ tts.synthesize turn_id=%s", turn_id)
|
||||||
|
|
||||||
|
expecting_binary = False
|
||||||
|
_pending_tts_seq: int | None = None
|
||||||
|
pcm_entries: list[tuple[int | None, int, bytes]] = []
|
||||||
|
_pcm_arrival = 0
|
||||||
|
metrics: dict[str, Any] = {}
|
||||||
|
sample_rate_hz = 24000
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = ws.recv()
|
||||||
|
except websocket.WebSocketConnectionClosedException as e:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"连接已断开: {e}",
|
||||||
|
code="DISCONNECTED",
|
||||||
|
retryable=True,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e):
|
||||||
|
raise
|
||||||
|
raise
|
||||||
|
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
if expecting_binary:
|
||||||
|
expecting_binary = False
|
||||||
|
else:
|
||||||
|
logger.warning("收到未预期的二进制帧,仍作为 TTS 数据处理")
|
||||||
|
pcm_entries.append((_pending_tts_seq, _pcm_arrival, msg))
|
||||||
|
_pcm_arrival += 1
|
||||||
|
_pending_tts_seq = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(msg, str):
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"期望文本帧为 str,实际为 {type(msg).__name__}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
text_frame = msg.strip()
|
||||||
|
if not text_frame:
|
||||||
|
logger.debug("跳过空 WebSocket 文本帧")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(text_frame)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
head = text_frame[:200].replace("\n", "\\n")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
f"服务端文本帧不是合法 JSON: {e}; 前 {len(head)} 字符: {head!r}",
|
||||||
|
code="INVALID_MESSAGE",
|
||||||
|
) from e
|
||||||
|
mtype = data.get("type")
|
||||||
|
|
||||||
|
if mtype == "asr.partial":
|
||||||
|
logger.debug("← asr.partial(tts 轮次,忽略)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "llm.text_delta":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.debug(
|
||||||
|
"llm.text_delta turn_id 与当前 tts 不一致,忽略",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "tts_audio_chunk":
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
logger.warning(
|
||||||
|
"tts_audio_chunk turn_id 与 tts.synthesize 不一致,仍消费后续二进制",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
sample_rate_hz = int(
|
||||||
|
data.get("sample_rate_hz") or sample_rate_hz
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
_s = data.get("seq")
|
||||||
|
try:
|
||||||
|
if _s is not None:
|
||||||
|
_pending_tts_seq = int(_s)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_pending_tts_seq = None
|
||||||
|
if data.get("is_final"):
|
||||||
|
logger.debug("← tts_audio_chunk is_final=true seq=%s", _s)
|
||||||
|
expecting_binary = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "dialog_result":
|
||||||
|
logger.debug("tts.synthesize 收到 dialog_result(非预期),忽略")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if mtype == "turn.complete":
|
||||||
|
if data.get("turn_id") != turn_id:
|
||||||
|
raise CloudVoiceError(
|
||||||
|
"turn.complete turn_id 不匹配", code="INVALID_MESSAGE"
|
||||||
|
)
|
||||||
|
metrics = data.get("metrics") or {}
|
||||||
|
break
|
||||||
|
|
||||||
|
if mtype == "error":
|
||||||
|
code = str(data.get("code") or "INTERNAL")
|
||||||
|
raise CloudVoiceError(
|
||||||
|
data.get("message") or code,
|
||||||
|
code=code,
|
||||||
|
retryable=bool(data.get("retryable")),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("忽略服务端消息 type=%s", mtype)
|
||||||
|
|
||||||
|
full_pcm = _merge_tts_pcm_chunks(pcm_entries)
|
||||||
|
pcm = (
|
||||||
|
np.frombuffer(full_pcm, dtype=np.int16).copy()
|
||||||
|
if full_pcm
|
||||||
|
else np.array([], dtype=np.int16)
|
||||||
|
)
|
||||||
|
if pcm.size > 0:
|
||||||
|
mx = int(np.max(np.abs(pcm)))
|
||||||
|
if mx == 0:
|
||||||
|
logger.warning(
|
||||||
|
"tts.synthesize 收齐 PCM 但全零(服务端静音占位);总长 %s 字节",
|
||||||
|
len(full_pcm),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"pcm": pcm,
|
||||||
|
"sample_rate_hz": sample_rate_hz,
|
||||||
|
"metrics": metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_tts_synthesize(self, text: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
发送 tts.synthesize,收齐 TTS 块与 turn.complete(无 dialog_result)。
|
||||||
|
与 run_turn 共用连接,互斥由服务端排队;重试策略同 ROCKET_CLOUD_TURN_RETRIES。
|
||||||
|
"""
|
||||||
|
t = (text or "").strip()
|
||||||
|
if not t:
|
||||||
|
raise CloudVoiceError("tts.synthesize text 不能为空")
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||||||
|
except ValueError:
|
||||||
|
raw_attempts = 3
|
||||||
|
attempts = max(1, raw_attempts)
|
||||||
|
try:
|
||||||
|
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||||||
|
except ValueError:
|
||||||
|
delay = 0.35
|
||||||
|
delay = max(0.0, delay)
|
||||||
|
|
||||||
|
for attempt in range(attempts):
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self._ws is None:
|
||||||
|
self._connect_nolock()
|
||||||
|
return self._execute_tts_synthesize_nolock(t)
|
||||||
|
except CloudVoiceError as e:
|
||||||
|
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||||||
|
if retry and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"tts.synthesize 可恢复错误,将重连并重试 (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"tts.synthesize WebSocket 瞬断,重连并重试 (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
raise CloudVoiceError("run_tts_synthesize 未执行", code="INTERNAL")
|
||||||
|
|
||||||
|
def run_turn(self, text: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
发送一轮用户文本,收齐 dialog_result、TTS 块、turn.complete。
|
||||||
|
|
||||||
|
支持流式下行:可先于 dialog_result 收到 tts_audio_chunk+PCM 与 llm.text_delta;
|
||||||
|
飞控与最终文案仍以 dialog_result 为准。
|
||||||
|
|
||||||
|
若中间因对端已关 TCP、ping/pong Broken pipe 等断开,会自动关连接、
|
||||||
|
重连 session 并重发本轮(次数由 ROCKET_CLOUD_TURN_RETRIES 控制,默认 3)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: routing, flight_intent, chat_reply, user_input, pcm, sample_rate_hz,
|
||||||
|
metrics, llm_stream_text(llm.text_delta 拼接,可选调试/UI)
|
||||||
|
"""
|
||||||
|
t = (text or "").strip()
|
||||||
|
if not t:
|
||||||
|
raise CloudVoiceError("turn.text 不能为空")
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_attempts = int(os.environ.get("ROCKET_CLOUD_TURN_RETRIES", "3"))
|
||||||
|
except ValueError:
|
||||||
|
raw_attempts = 3
|
||||||
|
attempts = max(1, raw_attempts)
|
||||||
|
try:
|
||||||
|
delay = float(os.environ.get("ROCKET_CLOUD_TURN_RETRY_DELAY_SEC", "0.35"))
|
||||||
|
except ValueError:
|
||||||
|
delay = 0.35
|
||||||
|
delay = max(0.0, delay)
|
||||||
|
|
||||||
|
for attempt in range(attempts):
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self._ws is None:
|
||||||
|
self._connect_nolock()
|
||||||
|
return self._execute_turn_nolock(t)
|
||||||
|
except CloudVoiceError as e:
|
||||||
|
retry = bool(e.retryable) or e.code == "DISCONNECTED"
|
||||||
|
if retry and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"云端回合可恢复错误,将重连并重试 (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if _transient_ws_exc(e) and attempt < attempts - 1:
|
||||||
|
logger.warning(
|
||||||
|
"云端 WebSocket 瞬断(如对端先关、PONG 写失败),"
|
||||||
|
"重连并重发 turn (%s/%s): %s",
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._close_nolock()
|
||||||
|
if delay:
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
raise CloudVoiceError("run_turn 未执行", code="INTERNAL")
|
||||||
205
voice_drone/core/command.py
Normal file
205
voice_drone/core/command.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Dict, Any, Optional, Literal
|
||||||
|
from datetime import datetime
|
||||||
|
from voice_drone.core.configuration import (
|
||||||
|
TAKEOFF_CONFIG,
|
||||||
|
LAND_CONFIG,
|
||||||
|
FOLLOW_CONFIG,
|
||||||
|
FORWARD_CONFIG,
|
||||||
|
BACKWARD_CONFIG,
|
||||||
|
LEFT_CONFIG,
|
||||||
|
RIGHT_CONFIG,
|
||||||
|
UP_CONFIG,
|
||||||
|
DOWN_CONFIG,
|
||||||
|
HOVER_CONFIG,
|
||||||
|
RETURN_HOME_CONFIG,
|
||||||
|
)
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
class CommandParams(BaseModel):
|
||||||
|
"""
|
||||||
|
命令参数
|
||||||
|
"""
|
||||||
|
distance: Optional[float] = Field(
|
||||||
|
None,
|
||||||
|
description="飞行距离,单位:米(m),必须大于等于0(land/hover 可以为0)",
|
||||||
|
ge=0
|
||||||
|
)
|
||||||
|
speed: Optional[float] = Field(
|
||||||
|
None,
|
||||||
|
description="飞行速度,单位:米每秒(m/s),必须大于等于0(land/hover 可以为0)",
|
||||||
|
ge=0
|
||||||
|
)
|
||||||
|
duration: Optional[float] = Field(
|
||||||
|
None,
|
||||||
|
description="飞行持续时间,单位:秒(s),必须大于0",
|
||||||
|
gt=0
|
||||||
|
)
|
||||||
|
|
||||||
|
class Command(BaseModel):
|
||||||
|
"""
|
||||||
|
无人机控制命令
|
||||||
|
"""
|
||||||
|
command: Literal[
|
||||||
|
"takeoff",
|
||||||
|
"follow",
|
||||||
|
"forward",
|
||||||
|
"backward",
|
||||||
|
"left",
|
||||||
|
"right",
|
||||||
|
"up",
|
||||||
|
"down",
|
||||||
|
"hover",
|
||||||
|
"land",
|
||||||
|
"return_home",
|
||||||
|
] = Field(
|
||||||
|
...,
|
||||||
|
description="无人机控制动作: takeoff(起飞), follow(跟随), forward(向前), backward(向后), left(向左), right(向右), up(向上), down(向下), hover(悬停), land(降落), return_home(返航)",
|
||||||
|
)
|
||||||
|
params: CommandParams = Field(..., description="命令参数")
|
||||||
|
timestamp: str = Field(..., description="命令生成时间戳,ISO 8601 格式(如:2024-01-01T12:00:00.000Z)")
|
||||||
|
sequence_id: int = Field(..., description="命令序列号,用于保证命令顺序和去重")
|
||||||
|
|
||||||
|
# 命令配置映射字典
|
||||||
|
_CONFIG_MAP = {
|
||||||
|
"takeoff": TAKEOFF_CONFIG,
|
||||||
|
"follow": FOLLOW_CONFIG,
|
||||||
|
"land": LAND_CONFIG,
|
||||||
|
"forward": FORWARD_CONFIG,
|
||||||
|
"backward": BACKWARD_CONFIG,
|
||||||
|
"left": LEFT_CONFIG,
|
||||||
|
"right": RIGHT_CONFIG,
|
||||||
|
"up": UP_CONFIG,
|
||||||
|
"down": DOWN_CONFIG,
|
||||||
|
"hover": HOVER_CONFIG,
|
||||||
|
"return_home": RETURN_HOME_CONFIG,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建命令
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
command: str,
|
||||||
|
sequence_id: int,
|
||||||
|
distance: Optional[float] = None,
|
||||||
|
speed: Optional[float] = None,
|
||||||
|
duration: Optional[float] = None
|
||||||
|
) -> "Command":
|
||||||
|
return cls(
|
||||||
|
command=command,
|
||||||
|
params=CommandParams(distance=distance, speed=speed, duration=duration),
|
||||||
|
timestamp=datetime.utcnow().isoformat() + "Z",
|
||||||
|
sequence_id=sequence_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_default_config(self):
|
||||||
|
"""获取当前命令的默认配置"""
|
||||||
|
return self._CONFIG_MAP.get(self.command)
|
||||||
|
|
||||||
|
# 填充默认值
|
||||||
|
def fill_defaults(self) -> None:
|
||||||
|
"""填充缺失的参数值"""
|
||||||
|
# 如果所有参数都已提供,直接返回
|
||||||
|
if (self.params.distance is not None and
|
||||||
|
self.params.speed is not None and
|
||||||
|
self.params.duration is not None):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果有缺失的参数,调用智能填充方法
|
||||||
|
self._fill_smart_params()
|
||||||
|
|
||||||
|
def _fill_smart_params(self):
|
||||||
|
"""智能填充缺失的参数值"""
|
||||||
|
default = self._get_default_config()
|
||||||
|
if default is None:
|
||||||
|
# 若命令未知,则直接返回不填充
|
||||||
|
return
|
||||||
|
|
||||||
|
d = self.params.distance
|
||||||
|
s = self.params.speed
|
||||||
|
t = self.params.duration
|
||||||
|
|
||||||
|
# 统计 None 个数
|
||||||
|
none_cnt = sum(x is None for x in [d, s, t])
|
||||||
|
|
||||||
|
# 三个都为None,直接填默认值
|
||||||
|
if none_cnt == 3:
|
||||||
|
self.params.distance = default["distance"]
|
||||||
|
self.params.speed = default["speed"]
|
||||||
|
self.params.duration = default["duration"]
|
||||||
|
return
|
||||||
|
|
||||||
|
# 只有一个参数有值的情况
|
||||||
|
if none_cnt == 2:
|
||||||
|
if s is not None and d is None and t is None:
|
||||||
|
# 仅速度:使用默认持续时间,计算距离
|
||||||
|
self.params.duration = default["duration"]
|
||||||
|
self.params.distance = s * self.params.duration
|
||||||
|
return
|
||||||
|
|
||||||
|
if t is not None and d is None and s is None:
|
||||||
|
# 仅持续时间:使用默认速度,计算距离
|
||||||
|
self.params.speed = default["speed"]
|
||||||
|
self.params.distance = self.params.speed * t
|
||||||
|
return
|
||||||
|
|
||||||
|
if d is not None and s is None and t is None:
|
||||||
|
# 仅距离:使用默认速度,计算持续时间
|
||||||
|
self.params.speed = default["speed"]
|
||||||
|
# 防止除以0
|
||||||
|
if self.params.speed == 0:
|
||||||
|
self.params.duration = 0
|
||||||
|
else:
|
||||||
|
self.params.duration = d / self.params.speed
|
||||||
|
return
|
||||||
|
|
||||||
|
# 两个参数有值,一个None,自动计算缺失的参数
|
||||||
|
if none_cnt == 1:
|
||||||
|
if d is None and s is not None and t is not None:
|
||||||
|
# 缺失距离:distance = speed * duration
|
||||||
|
self.params.distance = s * t
|
||||||
|
return
|
||||||
|
|
||||||
|
if s is None and d is not None and t is not None:
|
||||||
|
# 缺失速度:speed = distance / duration
|
||||||
|
if t == 0:
|
||||||
|
self.params.speed = 0
|
||||||
|
else:
|
||||||
|
self.params.speed = d / t
|
||||||
|
return
|
||||||
|
|
||||||
|
if t is None and d is not None and s is not None:
|
||||||
|
# 缺失持续时间:duration = distance / speed
|
||||||
|
if s == 0:
|
||||||
|
self.params.duration = 0
|
||||||
|
else:
|
||||||
|
self.params.duration = d / s
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# 转换为字典
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"command": self.command,
|
||||||
|
"params": {},
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"sequence_id": self.sequence_id
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.params.distance is None or self.params.speed is None or self.params.duration is None:
|
||||||
|
self.fill_defaults()
|
||||||
|
|
||||||
|
result["params"]["distance"] = self.params.distance
|
||||||
|
result["params"]["speed"] = self.params.speed
|
||||||
|
result["params"]["duration"] = self.params.duration
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
command = Command.create("takeoff", 1, speed=2)
|
||||||
|
print(command.to_dict())
|
||||||
209
voice_drone/core/configuration.py
Normal file
209
voice_drone/core/configuration.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from voice_drone.tools.config_loader import load_config
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
_cfg_log = get_logger("voice_drone.configuration")
|
||||||
|
|
||||||
|
# voice_drone/core/configuration.py -> 工程根目录 voice_drone_assistant 为 parents[2]
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
def _abs_config_path(relative: str) -> str:
|
||||||
|
p = Path(relative)
|
||||||
|
if p.is_absolute():
|
||||||
|
return str(p)
|
||||||
|
return str(_PROJECT_ROOT / p)
|
||||||
|
|
||||||
|
|
||||||
|
# 系统配置加载器
|
||||||
|
class SystemConfigLoader:
|
||||||
|
def __init__(self, config_path="voice_drone/config/system.yaml"):
|
||||||
|
self.config_path = _abs_config_path(config_path)
|
||||||
|
self.config = load_config(self.config_path)
|
||||||
|
|
||||||
|
# 获取音频配置
|
||||||
|
def get_audio_config(self):
|
||||||
|
return self.config["audio"]
|
||||||
|
|
||||||
|
# 获取语音活动检测配置
|
||||||
|
def get_vad_config(self):
|
||||||
|
return self.config["vad"]
|
||||||
|
|
||||||
|
# 获取socket配置
|
||||||
|
def get_socket_server_config(self):
|
||||||
|
return self.config["socket_server"]
|
||||||
|
|
||||||
|
# 获取日志配置
|
||||||
|
def get_logging_config(self):
|
||||||
|
return self.config["logging"]
|
||||||
|
|
||||||
|
# 获取STT配置
|
||||||
|
def get_stt_config(self):
|
||||||
|
return self.config["stt"]
|
||||||
|
|
||||||
|
# 获取TTS配置
|
||||||
|
def get_tts_config(self):
|
||||||
|
"""
|
||||||
|
获取文本转语音(TTS)配置
|
||||||
|
|
||||||
|
结构示例:
|
||||||
|
tts:
|
||||||
|
model_dir: "src/models/Kokoro-82M-v1.1-zh-ONNX"
|
||||||
|
model_name: "model_q4.onnx" # 可选: model.onnx, model_fp16.onnx, model_quantized.onnx 等
|
||||||
|
voice: "zf_001" # 语音风格文件名(不含扩展名)
|
||||||
|
speed: 1.0 # 语速系数
|
||||||
|
sample_rate: 24000 # 输出采样率
|
||||||
|
"""
|
||||||
|
return self.config.get("tts", {})
|
||||||
|
|
||||||
|
# 获取文本预处理配置
|
||||||
|
def get_text_preprocessor_config(self):
|
||||||
|
return self.config.get("text_preprocessor", {})
|
||||||
|
|
||||||
|
# 获取识别器流程配置
|
||||||
|
def get_recognizer_config(self):
|
||||||
|
return self.config.get("recognizer", {})
|
||||||
|
|
||||||
|
# 云端语音(WebSocket / pcm_asr_uplink 会话)
|
||||||
|
def get_cloud_voice_config(self):
|
||||||
|
return self.config.get("cloud_voice", {})
|
||||||
|
|
||||||
|
# 主程序 TakeoffPrintRecognizer(main_app)
|
||||||
|
def get_assistant_config(self):
|
||||||
|
return self.config.get("assistant", {})
|
||||||
|
|
||||||
|
# 命令配置加载器
|
||||||
|
class CommandConfigLoader:
|
||||||
|
def __init__(self, config_path="voice_drone/config/command_.yaml"):
|
||||||
|
self.config_path = _abs_config_path(config_path)
|
||||||
|
self.config = load_config(self.config_path)["control_params"]
|
||||||
|
|
||||||
|
def get_takeoff_config(self):
|
||||||
|
return self.config["takeoff"]
|
||||||
|
def get_land_config(self):
|
||||||
|
return self.config["land"]
|
||||||
|
def get_follow_config(self):
|
||||||
|
return self.config["follow"]
|
||||||
|
def get_forward_config(self):
|
||||||
|
return self.config["forward"]
|
||||||
|
def get_backward_config(self):
|
||||||
|
return self.config["backward"]
|
||||||
|
def get_left_config(self):
|
||||||
|
return self.config["left"]
|
||||||
|
def get_right_config(self):
|
||||||
|
return self.config["right"]
|
||||||
|
def get_up_config(self):
|
||||||
|
return self.config["up"]
|
||||||
|
def get_down_config(self):
|
||||||
|
return self.config["down"]
|
||||||
|
def get_hover_config(self):
|
||||||
|
return self.config["hover"]
|
||||||
|
|
||||||
|
def get_return_home_config(self):
|
||||||
|
return self.config["return_home"]
|
||||||
|
|
||||||
|
# 关键词配置加载器
|
||||||
|
class KeywordsConfigLoader:
|
||||||
|
def __init__(self, config_path="voice_drone/config/keywords.yaml"):
|
||||||
|
self.config_path = _abs_config_path(config_path)
|
||||||
|
self.config = load_config(self.config_path)["keywords"]
|
||||||
|
|
||||||
|
def get_keywords(self):
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
# 唤醒词配置加载器
|
||||||
|
class WakeWordConfigLoader:
|
||||||
|
def __init__(self, config_path="voice_drone/config/wake_word.yaml"):
|
||||||
|
self.config_path = _abs_config_path(config_path)
|
||||||
|
self.config = load_config(self.config_path)["wake_word"]
|
||||||
|
|
||||||
|
def get_primary(self):
|
||||||
|
return self.config.get("primary", "")
|
||||||
|
|
||||||
|
def get_variants(self):
|
||||||
|
return self.config.get("variants", [])
|
||||||
|
|
||||||
|
def get_matching_config(self):
|
||||||
|
return self.config.get("matching", {})
|
||||||
|
|
||||||
|
|
||||||
|
# 系统配置常量
|
||||||
|
system_config = SystemConfigLoader()
|
||||||
|
SYSTEM_AUDIO_CONFIG = system_config.get_audio_config()
|
||||||
|
SYSTEM_VAD_CONFIG = system_config.get_vad_config()
|
||||||
|
SYSTEM_SOCKET_SERVER_CONFIG = system_config.get_socket_server_config()
|
||||||
|
SYSTEM_LOGGING_CONFIG = system_config.get_logging_config()
|
||||||
|
SYSTEM_STT_CONFIG = system_config.get_stt_config()
|
||||||
|
SYSTEM_TTS_CONFIG = system_config.get_tts_config()
|
||||||
|
SYSTEM_TEXT_PREPROCESSOR_CONFIG = system_config.get_text_preprocessor_config()
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG = system_config.get_recognizer_config()
|
||||||
|
SYSTEM_CLOUD_VOICE_CONFIG = system_config.get_cloud_voice_config()
|
||||||
|
SYSTEM_ASSISTANT_CONFIG = system_config.get_assistant_config()
|
||||||
|
|
||||||
|
|
||||||
|
def load_cloud_voice_px4_context() -> dict:
|
||||||
|
"""
|
||||||
|
加载合并到云端 session.start.client 的 PX4/MAV 扩展字段。
|
||||||
|
路径:环境变量 ROCKET_CLOUD_PX4_CONTEXT_FILE,否则 cloud_voice.px4_context_file(相对工程根)。
|
||||||
|
"""
|
||||||
|
cv = SYSTEM_CLOUD_VOICE_CONFIG if isinstance(SYSTEM_CLOUD_VOICE_CONFIG, dict) else {}
|
||||||
|
raw = (os.environ.get("ROCKET_CLOUD_PX4_CONTEXT_FILE") or "").strip()
|
||||||
|
rel = raw or str(cv.get("px4_context_file") or "").strip()
|
||||||
|
if not rel:
|
||||||
|
return {}
|
||||||
|
p = Path(rel)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = _PROJECT_ROOT / rel
|
||||||
|
if not p.is_file():
|
||||||
|
_cfg_log.warning("cloud_voice PX4 上下文文件不存在,已跳过: %s", p)
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
data = load_config(str(p))
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
_cfg_log.warning("读取 PX4 上下文 YAML 失败: %s — %s", p, e)
|
||||||
|
return {}
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return {}
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_CLOUD_VOICE_PX4_CONTEXT = load_cloud_voice_px4_context()
|
||||||
|
|
||||||
|
# 命令配置常量
|
||||||
|
command_config = CommandConfigLoader()
|
||||||
|
TAKEOFF_CONFIG = command_config.get_takeoff_config()
|
||||||
|
LAND_CONFIG = command_config.get_land_config()
|
||||||
|
FOLLOW_CONFIG = command_config.get_follow_config()
|
||||||
|
FORWARD_CONFIG = command_config.get_forward_config()
|
||||||
|
BACKWARD_CONFIG = command_config.get_backward_config()
|
||||||
|
LEFT_CONFIG = command_config.get_left_config()
|
||||||
|
RIGHT_CONFIG = command_config.get_right_config()
|
||||||
|
UP_CONFIG = command_config.get_up_config()
|
||||||
|
DOWN_CONFIG = command_config.get_down_config()
|
||||||
|
HOVER_CONFIG = command_config.get_hover_config()
|
||||||
|
RETURN_HOME_CONFIG = command_config.get_return_home_config()
|
||||||
|
|
||||||
|
# 关键词配置常量
|
||||||
|
keywords_config = KeywordsConfigLoader()
|
||||||
|
KEYWORDS_CONFIG = keywords_config.get_keywords()
|
||||||
|
|
||||||
|
# 唤醒词配置常量
|
||||||
|
wake_word_config = WakeWordConfigLoader()
|
||||||
|
WAKE_WORD_PRIMARY = wake_word_config.get_primary()
|
||||||
|
WAKE_WORD_VARIANTS = wake_word_config.get_variants()
|
||||||
|
WAKE_WORD_MATCHING_CONFIG = wake_word_config.get_matching_config()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(TAKEOFF_CONFIG)
|
||||||
|
print(LAND_CONFIG)
|
||||||
|
print(FORWARD_CONFIG)
|
||||||
|
print(BACKWARD_CONFIG)
|
||||||
|
print(LEFT_CONFIG)
|
||||||
|
print(RIGHT_CONFIG)
|
||||||
|
print(UP_CONFIG)
|
||||||
|
print(DOWN_CONFIG)
|
||||||
|
print(HOVER_CONFIG)
|
||||||
|
print(KEYWORDS_CONFIG)
|
||||||
338
voice_drone/core/flight_intent.py
Normal file
338
voice_drone/core/flight_intent.py
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
"""
|
||||||
|
flight_intent v1 校验与辅助(对齐 docs/FLIGHT_INTENT_SCHEMA_v1.md)。
|
||||||
|
|
||||||
|
兼容 Pydantic v2(field_validator / ConfigDict)。
|
||||||
|
|
||||||
|
互操作:部分云端会把「悬停 N 秒」写成 hover.args.duration;规范建议用 hover + wait.seconds。
|
||||||
|
解析时会折叠为 hover(无 duration)+ wait(seconds),与机端执行一致。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
_COORD_CAP = 10_000.0
|
||||||
|
|
||||||
|
|
||||||
|
def _check_coord(name: str, v: Optional[float]) -> Optional[float]:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(v, (int, float)) or not math.isfinite(float(v)):
|
||||||
|
raise ValueError(f"{name} must be a finite number")
|
||||||
|
fv = float(v)
|
||||||
|
if abs(fv) > _COORD_CAP:
|
||||||
|
raise ValueError(f"{name} out of range (|.| <= {_COORD_CAP})")
|
||||||
|
return fv
|
||||||
|
|
||||||
|
|
||||||
|
# --- args -----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TakeoffArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
relative_altitude_m: Optional[float] = None
|
||||||
|
|
||||||
|
@field_validator("relative_altitude_m")
|
||||||
|
@classmethod
|
||||||
|
def _alt(cls, v: Optional[float]) -> Optional[float]:
|
||||||
|
if v is not None and v <= 0:
|
||||||
|
raise ValueError("relative_altitude_m must be > 0 when set")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
class HoverHoldArgs(BaseModel):
|
||||||
|
"""hover / hold:规范仅 {}; 为兼容云端可带 duration(秒),解析后展开为 wait。"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
duration: Optional[float] = None
|
||||||
|
|
||||||
|
@field_validator("duration")
|
||||||
|
@classmethod
|
||||||
|
def _dur(cls, v: Optional[float]) -> Optional[float]:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
fv = float(v)
|
||||||
|
if not (0 < fv <= 3600):
|
||||||
|
raise ValueError("duration must satisfy 0 < duration <= 3600 when set")
|
||||||
|
return fv
|
||||||
|
|
||||||
|
|
||||||
|
class WaitArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
seconds: float
|
||||||
|
|
||||||
|
@field_validator("seconds")
|
||||||
|
@classmethod
|
||||||
|
def _rng(cls, v: float) -> float:
|
||||||
|
if not (0 < v <= 3600):
|
||||||
|
raise ValueError("seconds must satisfy 0 < seconds <= 3600")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class GotoArgs(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
frame: str
|
||||||
|
x: Optional[float] = None
|
||||||
|
y: Optional[float] = None
|
||||||
|
z: Optional[float] = None
|
||||||
|
|
||||||
|
@field_validator("frame")
|
||||||
|
@classmethod
|
||||||
|
def _frame(cls, v: str) -> str:
|
||||||
|
if v not in ("local_ned", "body_ned"):
|
||||||
|
raise ValueError('frame must be "local_ned" or "body_ned"')
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("x", "y", "z")
|
||||||
|
@classmethod
|
||||||
|
def _coord(cls, v: Optional[float], info: ValidationInfo) -> Optional[float]:
|
||||||
|
return _check_coord(str(info.field_name), v)
|
||||||
|
|
||||||
|
|
||||||
|
# --- actions --------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ActionTakeoff(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["takeoff"] = "takeoff"
|
||||||
|
args: TakeoffArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionLand(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["land"] = "land"
|
||||||
|
args: EmptyArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionReturnHome(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["return_home"] = "return_home"
|
||||||
|
args: EmptyArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionHover(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["hover"] = "hover"
|
||||||
|
args: HoverHoldArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionHold(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["hold"] = "hold"
|
||||||
|
args: HoverHoldArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionGoto(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["goto"] = "goto"
|
||||||
|
args: GotoArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ActionWait(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
type: Literal["wait"] = "wait"
|
||||||
|
args: WaitArgs
|
||||||
|
|
||||||
|
|
||||||
|
FlightAction = Union[
|
||||||
|
ActionTakeoff,
|
||||||
|
ActionLand,
|
||||||
|
ActionReturnHome,
|
||||||
|
ActionHover,
|
||||||
|
ActionHold,
|
||||||
|
ActionGoto,
|
||||||
|
ActionWait,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FlightIntentPayload(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
is_flight_intent: bool
|
||||||
|
version: int
|
||||||
|
actions: List[Any]
|
||||||
|
summary: str
|
||||||
|
trace_id: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator("is_flight_intent")
|
||||||
|
@classmethod
|
||||||
|
def _flag(cls, v: bool) -> bool:
|
||||||
|
if v is not True:
|
||||||
|
raise ValueError("is_flight_intent must be true")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("version")
|
||||||
|
@classmethod
|
||||||
|
def _ver(cls, v: int) -> int:
|
||||||
|
if v != 1:
|
||||||
|
raise ValueError("version must be 1")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("summary")
|
||||||
|
@classmethod
|
||||||
|
def _sum(cls, v: str) -> str:
|
||||||
|
if not (isinstance(v, str) and v.strip()):
|
||||||
|
raise ValueError("summary must be non-empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("trace_id")
|
||||||
|
@classmethod
|
||||||
|
def _tid(cls, v: Optional[str]) -> Optional[str]:
|
||||||
|
if v is not None and len(v) > 128:
|
||||||
|
raise ValueError("trace_id length must be <= 128")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("actions")
|
||||||
|
@classmethod
|
||||||
|
def _actions_nonempty(cls, v: List[Any]) -> List[Any]:
|
||||||
|
if not isinstance(v, list) or len(v) == 0:
|
||||||
|
raise ValueError("actions must be a non-empty array")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidatedFlightIntent:
|
||||||
|
summary: str
|
||||||
|
trace_id: Optional[str]
|
||||||
|
actions: List[FlightAction]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_one_action(raw: dict) -> FlightAction:
|
||||||
|
t = raw.get("type")
|
||||||
|
if not isinstance(t, str):
|
||||||
|
raise ValueError("action.type must be a string")
|
||||||
|
args = raw.get("args")
|
||||||
|
if not isinstance(args, dict):
|
||||||
|
raise ValueError("action.args must be an object")
|
||||||
|
if t == "takeoff":
|
||||||
|
return ActionTakeoff(args=TakeoffArgs.model_validate(args))
|
||||||
|
if t == "land":
|
||||||
|
return ActionLand(args=EmptyArgs.model_validate(args))
|
||||||
|
if t == "return_home":
|
||||||
|
return ActionReturnHome(args=EmptyArgs.model_validate(args))
|
||||||
|
if t == "hover":
|
||||||
|
return ActionHover(args=HoverHoldArgs.model_validate(args))
|
||||||
|
if t == "hold":
|
||||||
|
return ActionHold(args=HoverHoldArgs.model_validate(args))
|
||||||
|
if t == "goto":
|
||||||
|
return ActionGoto(args=GotoArgs.model_validate(args))
|
||||||
|
if t == "wait":
|
||||||
|
return ActionWait(args=WaitArgs.model_validate(args))
|
||||||
|
raise ValueError(f"unknown action.type: {t!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_hover_duration(actions: List[FlightAction]) -> List[FlightAction]:
|
||||||
|
"""将 hover/hold 上附带的 duration 转为标准 wait(seconds)。"""
|
||||||
|
out: List[FlightAction] = []
|
||||||
|
for a in actions:
|
||||||
|
if isinstance(a, ActionHover):
|
||||||
|
d = a.args.duration
|
||||||
|
if d is not None:
|
||||||
|
out.append(ActionHover(args=HoverHoldArgs()))
|
||||||
|
out.append(ActionWait(args=WaitArgs(seconds=float(d))))
|
||||||
|
else:
|
||||||
|
out.append(a)
|
||||||
|
elif isinstance(a, ActionHold):
|
||||||
|
d = a.args.duration
|
||||||
|
if d is not None:
|
||||||
|
out.append(ActionHold(args=HoverHoldArgs()))
|
||||||
|
out.append(ActionWait(args=WaitArgs(seconds=float(d))))
|
||||||
|
else:
|
||||||
|
out.append(a)
|
||||||
|
else:
|
||||||
|
out.append(a)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def parse_flight_intent_dict(data: dict) -> Tuple[Optional[ValidatedFlightIntent], List[str]]:
|
||||||
|
"""
|
||||||
|
L1–L3 校验。成功返回 (ValidatedFlightIntent, []);失败返回 (None, [错误信息, ...])。
|
||||||
|
"""
|
||||||
|
errors: List[str] = []
|
||||||
|
try:
|
||||||
|
top = FlightIntentPayload.model_validate(data)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return None, [str(e)]
|
||||||
|
|
||||||
|
parsed_actions: List[FlightAction] = []
|
||||||
|
for i, item in enumerate(top.actions):
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
errors.append(f"actions[{i}] must be an object")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
parsed_actions.append(_parse_one_action(item))
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
errors.append(f"actions[{i}]: {e}")
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
return None, errors
|
||||||
|
|
||||||
|
parsed_actions = _expand_hover_duration(parsed_actions)
|
||||||
|
|
||||||
|
if isinstance(parsed_actions[0], ActionWait):
|
||||||
|
return None, ["first action must not be wait (nothing to control yet)"]
|
||||||
|
|
||||||
|
return (
|
||||||
|
ValidatedFlightIntent(
|
||||||
|
summary=top.summary.strip(),
|
||||||
|
trace_id=top.trace_id,
|
||||||
|
actions=parsed_actions,
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def goto_action_to_command(action: ActionGoto, sequence_id: int) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
|
"""
|
||||||
|
将单轴 goto 映射为现有 Socket Command。
|
||||||
|
返回 (Command | None, error_reason | None)。
|
||||||
|
"""
|
||||||
|
from voice_drone.core.command import Command
|
||||||
|
|
||||||
|
a = action.args
|
||||||
|
coords = [
|
||||||
|
("x", a.x),
|
||||||
|
("y", a.y),
|
||||||
|
("z", a.z),
|
||||||
|
]
|
||||||
|
active = [(name, val) for name, val in coords if val is not None and val != 0]
|
||||||
|
if len(active) == 0:
|
||||||
|
return None, "goto: all axes omit or zero (no-op)"
|
||||||
|
if len(active) > 1:
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
f"goto: multi-axis ({', '.join(n for n, _ in active)}) not sent via Socket "
|
||||||
|
"(use bridge or decompose)",
|
||||||
|
)
|
||||||
|
|
||||||
|
name, val = active[0]
|
||||||
|
dist = abs(float(val))
|
||||||
|
body_map = {
|
||||||
|
"x": ("forward", "backward"),
|
||||||
|
"y": ("right", "left"),
|
||||||
|
"z": ("down", "up"),
|
||||||
|
}
|
||||||
|
pos, neg = body_map[name]
|
||||||
|
cmd_name = pos if val > 0 else neg
|
||||||
|
cmd = Command.create(cmd_name, sequence_id, distance=dist)
|
||||||
|
cmd.fill_defaults()
|
||||||
|
return cmd, None
|
||||||
186
voice_drone/core/mic_device_select.py
Normal file
186
voice_drone/core/mic_device_select.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""启动时列出 arecord -l 与 PyAudio 输入设备,并把 ALSA card/device 映射到 PyAudio 索引供交互选择。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from typing import List, Optional, Tuple, Any
|
||||||
|
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("mic_device_select")
|
||||||
|
|
||||||
|
|
||||||
|
def run_arecord_l() -> str:
|
||||||
|
try:
|
||||||
|
r = subprocess.run(
|
||||||
|
["arecord", "-l"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
out = (r.stdout or "").rstrip()
|
||||||
|
err = (r.stderr or "").strip()
|
||||||
|
body = out + (f"\n{err}" if err else "")
|
||||||
|
return body.strip() if body.strip() else "(arecord 无输出)"
|
||||||
|
except FileNotFoundError:
|
||||||
|
return "(未找到 arecord,可安装 alsa-utils)"
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
return f"(执行 arecord -l 失败: {e})"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arecord_capture_lines(text: str) -> List[Tuple[int, int, str]]:
|
||||||
|
rows: List[Tuple[int, int, str]] = []
|
||||||
|
for line in text.splitlines():
|
||||||
|
m = re.search(r"card\s+(\d+):.+?,\s*device\s+(\d+):", line, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
rows.append((int(m.group(1)), int(m.group(2)), line.strip()))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def _pyaudio_input_devices() -> List[Tuple[int, Any]]:
|
||||||
|
from voice_drone.core.portaudio_env import fix_ld_path_for_portaudio
|
||||||
|
|
||||||
|
fix_ld_path_for_portaudio()
|
||||||
|
import pyaudio
|
||||||
|
|
||||||
|
pa = pyaudio.PyAudio()
|
||||||
|
out: List[Tuple[int, Any]] = []
|
||||||
|
try:
|
||||||
|
for i in range(pa.get_device_count()):
|
||||||
|
try:
|
||||||
|
inf = pa.get_device_info_by_index(i)
|
||||||
|
if int(inf.get("maxInputChannels", 0)) <= 0:
|
||||||
|
continue
|
||||||
|
out.append((i, inf))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return out
|
||||||
|
finally:
|
||||||
|
pa.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def match_alsa_hw_to_pyaudio_index(
|
||||||
|
card: int,
|
||||||
|
dev: int,
|
||||||
|
pa_items: List[Tuple[int, Any]],
|
||||||
|
) -> Optional[int]:
|
||||||
|
want1 = f"(hw:{card},{dev})"
|
||||||
|
want2 = f"(hw:{card}, {dev})"
|
||||||
|
for idx, inf in pa_items:
|
||||||
|
name = str(inf.get("name", ""))
|
||||||
|
if want1 in name or want2 in name:
|
||||||
|
return idx
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def print_mic_device_menu() -> List[int]:
|
||||||
|
"""
|
||||||
|
打印 arecord + PyAudio + 映射表。
|
||||||
|
返回本菜单中列出的 PyAudio 索引列表(顺序与 [1]、[2]… 一致)。
|
||||||
|
"""
|
||||||
|
alsa_text = run_arecord_l()
|
||||||
|
pa_items = _pyaudio_input_devices()
|
||||||
|
|
||||||
|
print("\n" + "=" * 72, flush=True)
|
||||||
|
print("录音设备(先 ALSA,再 PortAudio;请记下要用的 PyAudio 索引)", flush=True)
|
||||||
|
print("=" * 72, flush=True)
|
||||||
|
print("\n--- arecord -l(系统硬件视角)---\n", flush=True)
|
||||||
|
print(alsa_text, flush=True)
|
||||||
|
|
||||||
|
print("\n--- PyAudio 可录音设备(maxInputChannels > 0)---\n", flush=True)
|
||||||
|
ordered_indices: List[int] = []
|
||||||
|
if not pa_items:
|
||||||
|
print("(无)\n", flush=True)
|
||||||
|
for rank, (idx, inf) in enumerate(pa_items, start=1):
|
||||||
|
ordered_indices.append(idx)
|
||||||
|
mic = int(inf.get("maxInputChannels", 0))
|
||||||
|
outc = int(inf.get("maxOutputChannels", 0))
|
||||||
|
name = str(inf.get("name", "?"))
|
||||||
|
print(
|
||||||
|
f" [{rank}] PyAudio_index={idx} in={mic} out={outc} {name}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
alsa_rows = parse_arecord_capture_lines(alsa_text)
|
||||||
|
print(
|
||||||
|
"\n--- 映射:arecord 的 card / device → PyAudio 索引(匹配设备名中的 hw:X,Y)---\n",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if not alsa_rows:
|
||||||
|
print(" (未解析到 card/device 行,请直接用上一表的 PyAudio_index)", flush=True)
|
||||||
|
for card, dev, line in alsa_rows:
|
||||||
|
pidx = match_alsa_hw_to_pyaudio_index(card, dev, pa_items)
|
||||||
|
short = line if len(line) <= 76 else line[:73] + "..."
|
||||||
|
if pidx is not None:
|
||||||
|
print(
|
||||||
|
f" card {card}, device {dev} → PyAudio 索引 {pidx}\n {short}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f" card {card}, device {dev} → (无 in>0 设备名含 hw:{card},{dev})\n {short}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\n说明:程序只会用「一个」PyAudio 索引打开麦克风;"
|
||||||
|
"HDMI 等若 in=0 不会出现在可录音列表。\n"
|
||||||
|
+ "=" * 72
|
||||||
|
+ "\n",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
return ordered_indices
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_for_input_device_index() -> int:
|
||||||
|
"""交互式选择,返回写入 audio.input_device_index 的 PyAudio 索引。"""
|
||||||
|
ordered = print_mic_device_menu()
|
||||||
|
if not ordered:
|
||||||
|
print("错误:没有发现可录音的 PyAudio 设备。", flush=True)
|
||||||
|
raise SystemExit(2)
|
||||||
|
|
||||||
|
valid_set = set(ordered)
|
||||||
|
print(
|
||||||
|
"请输入菜单序号 [1-"
|
||||||
|
f"{len(ordered)}](推荐),或直接输入 PyAudio_index 数字;q 退出。",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
raw = input("录音设备> ").strip()
|
||||||
|
except EOFError:
|
||||||
|
raise SystemExit(1) from None
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
if raw.lower() in ("q", "quit", "exit"):
|
||||||
|
raise SystemExit(0)
|
||||||
|
if not raw.isdigit():
|
||||||
|
print("请输入正整数或 q。", flush=True)
|
||||||
|
continue
|
||||||
|
n = int(raw)
|
||||||
|
if 1 <= n <= len(ordered):
|
||||||
|
chosen = ordered[n - 1]
|
||||||
|
print(f"已选择:菜单 [{n}] → PyAudio 索引 {chosen}\n", flush=True)
|
||||||
|
logger.info("交互选择录音设备 PyAudio index=%s", chosen)
|
||||||
|
return chosen
|
||||||
|
if n in valid_set:
|
||||||
|
print(f"已选择:PyAudio 索引 {n}\n", flush=True)
|
||||||
|
logger.info("交互选择录音设备 PyAudio index=%s", n)
|
||||||
|
return n
|
||||||
|
print(
|
||||||
|
f"无效:{n} 不在可选列表。可选序号为 1~{len(ordered)},"
|
||||||
|
f"或索引之一 {sorted(valid_set)}。",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_input_device_index_only(index: int) -> None:
|
||||||
|
"""写入运行时配置:仅用索引选设备,其余 yaml 中的 hw/名称匹配不再参与 logic。"""
|
||||||
|
from voice_drone.core.configuration import SYSTEM_AUDIO_CONFIG
|
||||||
|
|
||||||
|
SYSTEM_AUDIO_CONFIG["input_device_index"] = int(index)
|
||||||
|
SYSTEM_AUDIO_CONFIG["input_hw_card_device"] = None
|
||||||
|
SYSTEM_AUDIO_CONFIG["input_device_name_match"] = None
|
||||||
|
SYSTEM_AUDIO_CONFIG["input_strict_selection"] = False
|
||||||
60
voice_drone/core/portaudio_env.py
Normal file
60
voice_drone/core/portaudio_env.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""PortAudio/PyAudio 启动前调整动态库搜索路径。
|
||||||
|
|
||||||
|
conda 环境下 `.../envs/xxx/lib` 里的 libasound 会到同前缀的 alsa-lib 子目录找插件,
|
||||||
|
该目录常缺 libasound_module_*.so,日志里刷屏且采集电平可能异常。
|
||||||
|
|
||||||
|
处理:1) 去掉仅含插件目录的路径;2) 把系统 /usr/lib/<triplet> 插到 LD_LIBRARY_PATH 最前,
|
||||||
|
让动态链接优先用系统的 libasound。"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
|
|
||||||
|
def strip_conda_alsa_from_ld_library_path() -> None:
|
||||||
|
ld = os.environ.get("LD_LIBRARY_PATH", "")
|
||||||
|
if not ld:
|
||||||
|
return
|
||||||
|
parts: list[str] = []
|
||||||
|
for p in ld.split(":"):
|
||||||
|
if not p:
|
||||||
|
continue
|
||||||
|
pl = p.lower()
|
||||||
|
if "conda" in pl or "miniconda" in pl or "mamba" in pl:
|
||||||
|
if "alsa-lib" in pl or "alsa_lib" in pl:
|
||||||
|
continue
|
||||||
|
parts.append(p)
|
||||||
|
if parts:
|
||||||
|
os.environ["LD_LIBRARY_PATH"] = ":".join(parts)
|
||||||
|
else:
|
||||||
|
os.environ.pop("LD_LIBRARY_PATH", None)
|
||||||
|
|
||||||
|
|
||||||
|
def prepend_system_lib_dirs_for_alsa() -> None:
|
||||||
|
"""Linux:把系统 lib 目录放在 LD_LIBRARY_PATH 最前面。"""
|
||||||
|
if platform.system() != "Linux":
|
||||||
|
return
|
||||||
|
triplet = {
|
||||||
|
"aarch64": ("/usr/lib/aarch64-linux-gnu", "/lib/aarch64-linux-gnu"),
|
||||||
|
"x86_64": ("/usr/lib/x86_64-linux-gnu", "/lib/x86_64-linux-gnu"),
|
||||||
|
"amd64": ("/usr/lib/x86_64-linux-gnu", "/lib/x86_64-linux-gnu"),
|
||||||
|
}.get(platform.machine().lower())
|
||||||
|
if not triplet:
|
||||||
|
return
|
||||||
|
prepend: list[str] = []
|
||||||
|
for d in triplet:
|
||||||
|
if os.path.isdir(d):
|
||||||
|
prepend.append(d)
|
||||||
|
if not prepend:
|
||||||
|
return
|
||||||
|
rest = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(":") if p]
|
||||||
|
out: list[str] = []
|
||||||
|
for p in prepend + rest:
|
||||||
|
if p not in out:
|
||||||
|
out.append(p)
|
||||||
|
os.environ["LD_LIBRARY_PATH"] = ":".join(out)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_ld_path_for_portaudio() -> None:
|
||||||
|
prepend_system_lib_dirs_for_alsa()
|
||||||
|
strip_conda_alsa_from_ld_library_path()
|
||||||
115
voice_drone/core/qwen_intent_chat.py
Normal file
115
voice_drone/core/qwen_intent_chat.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
"""与 scripts/qwen_flight_intent_sim.py 对齐:飞控意图 JSON vs 闲聊,供语音主程序内调用。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
# 与 qwen_flight_intent_sim._SYSTEM 保持一致
|
||||||
|
FLIGHT_INTENT_CHAT_SYSTEM = """你是无人机飞控意图助手,只做两件事(必须二选一):
|
||||||
|
|
||||||
|
【规则 A — 飞控相关】当用户话里包含对无人机的飞行任务、航线、起降、返航、悬停、等待、速度高度、坐标点、offboard、PX4/MAVROS 等操作意图时:
|
||||||
|
只输出一行 JSON,且不要有任何其它字符、不要 Markdown、不要代码块。
|
||||||
|
JSON Schema 含义(见仓库 docs/FLIGHT_INTENT_SCHEMA_v1.md):
|
||||||
|
{
|
||||||
|
"is_flight_intent": true,
|
||||||
|
"version": 1,
|
||||||
|
"actions": [ // 按时间顺序排列
|
||||||
|
{"type": "takeoff", "args": {}},
|
||||||
|
{"type": "takeoff", "args": {"relative_altitude_m": number}},
|
||||||
|
{"type": "goto", "args": {"frame": "local_ned"|"body_ned", "x": number|null, "y": number|null, "z": number|null}},
|
||||||
|
{"type": "land" | "return_home" | "hover" | "hold", "args": {}},
|
||||||
|
{"type": "wait", "args": {"seconds": number}}
|
||||||
|
],
|
||||||
|
"summary": "一句话概括",
|
||||||
|
"trace_id": "可选,简短追踪ID"
|
||||||
|
}
|
||||||
|
- 停多久、延迟多久必须用 wait,例如「悬停 3 秒再降落」应为 hover → wait(3) → land;不要把秒数写进 summary 代替 wait。
|
||||||
|
- 坐标缺省 frame 时用 "local_ned";无法确定的数字可省略字段或用 null。
|
||||||
|
- 返程/返航映射为 {"type":"return_home","args":{}}。
|
||||||
|
- 仅允许小写 type;args 只含规范允许键,禁止多余键。
|
||||||
|
|
||||||
|
【规则 B — 非飞控】若只是日常聊天、与无人机任务无关:用正常的自然中文回复,不要输出 JSON,不要用花括号开头。"""
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_fenced_json(text: str) -> str:
|
||||||
|
text = text.strip()
|
||||||
|
m = re.match(r"^```(?:json)?\s*\n?(.*)\n?```\s*$", text, re.DOTALL | re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
return m.group(1).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _first_balanced_json_object(text: str) -> Optional[str]:
|
||||||
|
t = _strip_fenced_json(text)
|
||||||
|
start = t.find("{")
|
||||||
|
if start < 0:
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
for i in range(start, len(t)):
|
||||||
|
if t[i] == "{":
|
||||||
|
depth += 1
|
||||||
|
elif t[i] == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return t[start : i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_flight_intent_reply(raw: str) -> Tuple[str, Optional[dict[str, Any]]]:
|
||||||
|
"""返回 (模式标签, 若为飞控则 dict 否则 None)。"""
|
||||||
|
chunk = _first_balanced_json_object(raw)
|
||||||
|
if chunk:
|
||||||
|
try:
|
||||||
|
obj = json.loads(chunk)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return "闲聊", None
|
||||||
|
if isinstance(obj, dict) and obj.get("is_flight_intent") is True:
|
||||||
|
return "飞控意图JSON", obj
|
||||||
|
return "闲聊", None
|
||||||
|
|
||||||
|
|
||||||
|
def default_qwen_gguf_path(project_root: Path) -> Path:
|
||||||
|
"""子工程优先本目录 cache/;不存在时回退到上级仓库(同级 rocket_drone_audio/cache/)。"""
|
||||||
|
name = "qwen2.5-1.5b-instruct-q4_k_m.gguf"
|
||||||
|
primary = project_root / "cache" / "qwen25-1.5b-gguf" / name
|
||||||
|
if primary.is_file():
|
||||||
|
return primary
|
||||||
|
legacy = project_root.parent / "cache" / "qwen25-1.5b-gguf" / name
|
||||||
|
if legacy.is_file():
|
||||||
|
return legacy
|
||||||
|
return primary
|
||||||
|
|
||||||
|
|
||||||
|
def load_llama_qwen(
|
||||||
|
model_path: Path,
|
||||||
|
n_ctx: int = 4096,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
llama-cpp-python 封装。可选环境变量(详见 rocket_drone_audio 文件头):
|
||||||
|
ROCKET_LLM_N_THREADS、ROCKET_LLM_N_GPU_LAYERS、ROCKET_LLM_N_BATCH。
|
||||||
|
"""
|
||||||
|
if not model_path.is_file():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from llama_cpp import Llama
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
opts: dict = {
|
||||||
|
"model_path": str(model_path),
|
||||||
|
"n_ctx": int(n_ctx),
|
||||||
|
"verbose": False,
|
||||||
|
}
|
||||||
|
nt = os.environ.get("ROCKET_LLM_N_THREADS", "").strip()
|
||||||
|
if nt.isdigit() or (nt.startswith("-") and nt[1:].isdigit()):
|
||||||
|
opts["n_threads"] = max(1, int(nt))
|
||||||
|
ng = os.environ.get("ROCKET_LLM_N_GPU_LAYERS", "").strip()
|
||||||
|
if ng.isdigit():
|
||||||
|
opts["n_gpu_layers"] = int(ng)
|
||||||
|
nb = os.environ.get("ROCKET_LLM_N_BATCH", "").strip()
|
||||||
|
if nb.isdigit():
|
||||||
|
opts["n_batch"] = max(1, int(nb))
|
||||||
|
return Llama(**opts)
|
||||||
969
voice_drone/core/recognizer.py
Normal file
969
voice_drone/core/recognizer.py
Normal file
@ -0,0 +1,969 @@
|
|||||||
|
"""
|
||||||
|
高性能实时语音识别与命令生成系统
|
||||||
|
|
||||||
|
整合所有模块,实现从语音检测到命令发送的完整流程:
|
||||||
|
1. 音频采集(高性能模式)
|
||||||
|
2. 音频预处理(降噪+AGC)
|
||||||
|
3. VAD语音活动检测
|
||||||
|
4. STT语音识别
|
||||||
|
5. 文本预处理(纠错+参数提取)
|
||||||
|
6. 命令生成
|
||||||
|
7. Socket发送
|
||||||
|
|
||||||
|
性能优化:
|
||||||
|
- 多线程异步处理
|
||||||
|
- 非阻塞音频采集
|
||||||
|
- LRU缓存优化
|
||||||
|
- 低延迟设计
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
from voice_drone.core.audio import AudioCapture, AudioPreprocessor
|
||||||
|
from voice_drone.core.vad import VAD
|
||||||
|
from voice_drone.core.stt import STT
|
||||||
|
from voice_drone.core.text_preprocessor import TextPreprocessor, get_preprocessor
|
||||||
|
from voice_drone.core.command import Command
|
||||||
|
from voice_drone.core.scoket_client import SocketClient
|
||||||
|
from voice_drone.core.configuration import (
|
||||||
|
SYSTEM_AUDIO_CONFIG,
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG,
|
||||||
|
SYSTEM_SOCKET_SERVER_CONFIG,
|
||||||
|
)
|
||||||
|
from voice_drone.core.tts_ack_cache import (
|
||||||
|
compute_ack_pcm_fingerprint,
|
||||||
|
load_cached_phrases,
|
||||||
|
persist_phrases,
|
||||||
|
)
|
||||||
|
from voice_drone.core.wake_word import WakeWordDetector, get_wake_word_detector
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from voice_drone.core.tts import KokoroOnnxTTS
|
||||||
|
|
||||||
|
logger = get_logger("recognizer")
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceCommandRecognizer:
|
||||||
|
"""
|
||||||
|
高性能实时语音命令识别器
|
||||||
|
|
||||||
|
完整的语音转命令系统,包括:
|
||||||
|
- 音频采集和预处理
|
||||||
|
- 语音活动检测
|
||||||
|
- 语音识别
|
||||||
|
- 文本预处理和参数提取
|
||||||
|
- 命令生成
|
||||||
|
- Socket发送
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auto_connect_socket: bool = True):
|
||||||
|
"""
|
||||||
|
初始化语音命令识别器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_connect_socket: 是否自动连接Socket服务器
|
||||||
|
"""
|
||||||
|
logger.info("初始化语音命令识别系统...")
|
||||||
|
|
||||||
|
# 初始化各模块
|
||||||
|
self.audio_capture = AudioCapture()
|
||||||
|
self.audio_preprocessor = AudioPreprocessor()
|
||||||
|
self.vad = VAD()
|
||||||
|
self.stt = STT()
|
||||||
|
self.text_preprocessor = get_preprocessor() # 使用全局单例
|
||||||
|
self.wake_word_detector = get_wake_word_detector() # 使用全局单例
|
||||||
|
|
||||||
|
# Socket客户端
|
||||||
|
self.socket_client = SocketClient(SYSTEM_SOCKET_SERVER_CONFIG)
|
||||||
|
self.auto_connect_socket = auto_connect_socket
|
||||||
|
if self.auto_connect_socket:
|
||||||
|
if not self.socket_client.connect():
|
||||||
|
logger.warning("Socket连接失败,将在发送命令时重试")
|
||||||
|
|
||||||
|
# 语音段缓冲区
|
||||||
|
self.speech_buffer: list = [] # 存储语音音频块
|
||||||
|
self.speech_buffer_lock = threading.Lock()
|
||||||
|
|
||||||
|
# 预缓冲区:保存语音检测前一小段音频,避免丢失开头
|
||||||
|
# 例如:pre_speech_max_seconds = 0.8 表示保留最近约 0.8 秒音频
|
||||||
|
self.pre_speech_buffer: list = [] # 存储最近的静音/背景音块
|
||||||
|
# 从系统配置读取(确保类型正确:YAML 可能把数值当字符串)
|
||||||
|
self.pre_speech_max_seconds: float = float(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("pre_speech_max_seconds", 0.8)
|
||||||
|
)
|
||||||
|
self.pre_speech_max_chunks: Optional[int] = None # 根据采样率和chunk大小动态计算
|
||||||
|
|
||||||
|
# 命令发送成功后的 TTS 反馈(懒加载 Kokoro,避免拖慢启动)
|
||||||
|
self.ack_tts_enabled = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_enabled", True))
|
||||||
|
self.ack_tts_text = str(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_text", "好的收到")).strip()
|
||||||
|
self.ack_tts_phrases: Dict[str, List[str]] = self._normalize_ack_tts_phrases(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_phrases")
|
||||||
|
)
|
||||||
|
# True:仅 ack_tts_phrases 中出现的命令会播报,且每次随机一句;False:全局 ack_tts_text(所有成功命令同一应答)
|
||||||
|
self._ack_mode_phrases: bool = bool(self.ack_tts_phrases)
|
||||||
|
self.ack_tts_prewarm = bool(SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm", True))
|
||||||
|
self.ack_tts_prewarm_blocking = bool(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_prewarm_blocking", True)
|
||||||
|
)
|
||||||
|
self.ack_pause_mic_for_playback = bool(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("ack_pause_mic_for_playback", True)
|
||||||
|
)
|
||||||
|
self.ack_tts_disk_cache = bool(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("ack_tts_disk_cache", True)
|
||||||
|
)
|
||||||
|
self._tts_engine: Optional["KokoroOnnxTTS"] = None
|
||||||
|
# 阻塞预加载时缓存波形:全局单句 _tts_ack_pcm,或按命令随机模式下的 _tts_phrase_pcm_cache(每句一条)
|
||||||
|
self._tts_ack_pcm: Optional[Tuple[np.ndarray, int]] = None
|
||||||
|
self._tts_phrase_pcm_cache: Dict[str, Tuple[np.ndarray, int]] = {}
|
||||||
|
self._tts_lock = threading.Lock()
|
||||||
|
# 命令线程只入队,主线程 process_audio_stream 中统一播放(避免 Windows 下后台线程 sd.play 无声)
|
||||||
|
self._ack_playback_queue: queue.Queue = queue.Queue(maxsize=8)
|
||||||
|
|
||||||
|
# STT识别线程和队列
|
||||||
|
self.stt_queue = queue.Queue(maxsize=5) # STT识别队列
|
||||||
|
self.stt_thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
|
# 命令处理线程和队列
|
||||||
|
self.command_queue = queue.Queue(maxsize=10) # 命令处理队列
|
||||||
|
self.command_thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
|
# 运行状态
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# 命令序列号(用于去重和顺序保证)
|
||||||
|
self.sequence_id = 0
|
||||||
|
self.sequence_lock = threading.Lock()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"应答TTS配置: enabled={self.ack_tts_enabled}, "
|
||||||
|
f"mode={'按命令随机短语' if self._ack_mode_phrases else '全局固定文案'}, "
|
||||||
|
f"prewarm_blocking={self.ack_tts_prewarm_blocking}, "
|
||||||
|
f"pause_mic={self.ack_pause_mic_for_playback}, "
|
||||||
|
f"disk_cache={self.ack_tts_disk_cache}"
|
||||||
|
)
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
logger.info(f" 仅播报命令: {list(self.ack_tts_phrases.keys())}")
|
||||||
|
|
||||||
|
# VAD 后端:silero(默认)或 energy(按块 RMS,Silero 在部分板载麦上长期无段时使用)
|
||||||
|
_ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
)
|
||||||
|
_yaml_backend = str(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero")
|
||||||
|
).lower()
|
||||||
|
self._use_energy_vad: bool = _ev_env or _yaml_backend == "energy"
|
||||||
|
self._energy_rms_high: float = float(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_high", 280)
|
||||||
|
)
|
||||||
|
self._energy_rms_low: float = float(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_rms_low", 150)
|
||||||
|
)
|
||||||
|
self._energy_start_chunks: int = int(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_start_chunks", 4)
|
||||||
|
)
|
||||||
|
self._energy_end_chunks: int = int(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_chunks", 15)
|
||||||
|
)
|
||||||
|
# 高噪底/AGC 下 RMS 几乎不低于 energy_vad_rms_low 时,用「相对本段峰值」辅助判停
|
||||||
|
self._energy_end_peak_ratio: float = float(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_end_peak_ratio", 0.88)
|
||||||
|
)
|
||||||
|
# 说话过程中对 utt 峰值每块乘衰减再与当前 rms 取 max,避免前几个字特响导致后半句一直被判「相对衰减」而误切段
|
||||||
|
self._energy_utt_peak_decay: float = float(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("energy_vad_utt_peak_decay", 0.988)
|
||||||
|
)
|
||||||
|
self._energy_utt_peak_decay = max(0.95, min(0.9999, self._energy_utt_peak_decay))
|
||||||
|
self._ev_speaking: bool = False
|
||||||
|
self._ev_high_run: int = 0
|
||||||
|
self._ev_low_run: int = 0
|
||||||
|
self._ev_rms_peak: float = 0.0
|
||||||
|
self._ev_last_diag_time: float = 0.0
|
||||||
|
self._ev_utt_peak: float = 0.0
|
||||||
|
# 可选:能量 VAD 刚进入「正在说话」时回调(用于机端 PROMPT_LISTEN 计时清零等)
|
||||||
|
self._vad_speech_start_hook: Optional[Callable[[], None]] = None
|
||||||
|
|
||||||
|
_trail_raw = SYSTEM_RECOGNIZER_CONFIG.get("trailing_silence_seconds")
|
||||||
|
if _trail_raw is not None:
|
||||||
|
_trail = float(_trail_raw)
|
||||||
|
if _trail > 0:
|
||||||
|
fs = int(SYSTEM_AUDIO_CONFIG.get("frame_size", 1024))
|
||||||
|
sr = int(SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000))
|
||||||
|
if fs > 0 and sr > 0:
|
||||||
|
n_end = max(1, int(math.ceil(_trail * sr / fs)))
|
||||||
|
self._energy_end_chunks = n_end
|
||||||
|
self.vad.silence_end_frames = n_end
|
||||||
|
logger.info(
|
||||||
|
"VAD 句尾切段:trailing_silence_seconds=%.2f → 连续静音块数=%d "
|
||||||
|
"(每块≈%.0fms,Silero 与 energy 共用)",
|
||||||
|
_trail,
|
||||||
|
n_end,
|
||||||
|
1000.0 * fs / sr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._use_energy_vad:
|
||||||
|
logger.info(
|
||||||
|
"VAD 后端: energy(RMS)"
|
||||||
|
f" high={self._energy_rms_high} low={self._energy_rms_low} "
|
||||||
|
f"start_chunks={self._energy_start_chunks} end_chunks={self._energy_end_chunks}"
|
||||||
|
f" end_peak_ratio={self._energy_end_peak_ratio}"
|
||||||
|
f" utt_peak_decay={self._energy_utt_peak_decay}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("语音命令识别系统初始化完成")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_ack_tts_phrases(raw) -> Dict[str, List[str]]:
|
||||||
|
"""YAML: ack_tts_phrases: { takeoff: [\"...\", ...], ... }"""
|
||||||
|
result: Dict[str, List[str]] = {}
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
return result
|
||||||
|
for k, v in raw.items():
|
||||||
|
key = str(k).strip()
|
||||||
|
if not key:
|
||||||
|
continue
|
||||||
|
if isinstance(v, list):
|
||||||
|
phrases = [str(x).strip() for x in v if str(x).strip()]
|
||||||
|
elif isinstance(v, str) and v.strip():
|
||||||
|
phrases = [v.strip()]
|
||||||
|
else:
|
||||||
|
phrases = []
|
||||||
|
if phrases:
|
||||||
|
result[key] = phrases
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _has_ack_tts_content(self) -> bool:
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
return any(bool(v) for v in self.ack_tts_phrases.values())
|
||||||
|
return bool(self.ack_tts_text)
|
||||||
|
|
||||||
|
def _pick_ack_phrase(self, command_name: str) -> Optional[str]:
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
phrases = self.ack_tts_phrases.get(command_name)
|
||||||
|
if not phrases:
|
||||||
|
return None
|
||||||
|
return random.choice(phrases)
|
||||||
|
return self.ack_tts_text or None
|
||||||
|
|
||||||
|
def _get_cached_pcm_for_phrase(self, phrase: str) -> Optional[Tuple[np.ndarray, int]]:
|
||||||
|
"""若启动阶段已预合成该句,则返回缓存,播报时不再跑 ONNX(低延迟)。"""
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
return self._tts_phrase_pcm_cache.get(phrase)
|
||||||
|
if self._tts_ack_pcm is not None:
|
||||||
|
return self._tts_ack_pcm
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _ensure_tts_engine(self) -> "KokoroOnnxTTS":
|
||||||
|
"""懒加载 Kokoro(双检锁,避免多线程重复加载)。"""
|
||||||
|
from voice_drone.core.tts import KokoroOnnxTTS
|
||||||
|
|
||||||
|
if self._tts_engine is not None:
|
||||||
|
return self._tts_engine
|
||||||
|
with self._tts_lock:
|
||||||
|
if self._tts_engine is None:
|
||||||
|
logger.info("TTS: 正在加载 Kokoro 模型(首次约需十余秒)…")
|
||||||
|
self._tts_engine = KokoroOnnxTTS()
|
||||||
|
logger.info("TTS: Kokoro 模型加载完成")
|
||||||
|
assert self._tts_engine is not None
|
||||||
|
return self._tts_engine
|
||||||
|
|
||||||
|
def _enqueue_ack_playback(self, command_name: str) -> None:
|
||||||
|
"""
|
||||||
|
命令已成功发出后,将待播音频交给主线程队列。
|
||||||
|
|
||||||
|
不在此线程直接调用 sounddevice:Windows 上后台线程常出现播放完全无声。
|
||||||
|
"""
|
||||||
|
if not self.ack_tts_enabled:
|
||||||
|
return
|
||||||
|
phrase = self._pick_ack_phrase(command_name)
|
||||||
|
if not phrase:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
cached = self._get_cached_pcm_for_phrase(phrase)
|
||||||
|
if cached is not None:
|
||||||
|
audio, sr = cached
|
||||||
|
self._ack_playback_queue.put(("pcm", audio.copy(), sr), block=False)
|
||||||
|
logger.info(
|
||||||
|
f"命令已发送,已排队语音应答(主线程播放,预缓存): {phrase!r}"
|
||||||
|
)
|
||||||
|
print(f"[TTS] 已排队语音应答(主线程播放,预缓存): {phrase!r}", flush=True)
|
||||||
|
else:
|
||||||
|
self._ack_playback_queue.put(("synth", phrase), block=False)
|
||||||
|
logger.info(
|
||||||
|
f"命令已发送,已排队语音应答(主线程合成+播放,无缓存,可能有数秒延迟): {phrase!r}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[TTS] 已排队语音应答(主线程合成+播放,无缓存): {phrase!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("应答语音播放队列已满,跳过本次")
|
||||||
|
|
||||||
|
def _before_audio_iteration(self) -> None:
|
||||||
|
"""主循环每轮开头(主线程):子类可扩展以播放其它排队 TTS。"""
|
||||||
|
self._drain_ack_playback_queue()
|
||||||
|
|
||||||
|
def _drain_ack_playback_queue(self, recover_mic: bool = True) -> None:
|
||||||
|
"""在主线程中播放队列中的应答(与麦克风采集同进程、同主循环线程)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recover_mic: 播完后是否恢复麦克风;退出 shutdown 时应为 False,避免与 stop() 中关流冲突。
|
||||||
|
"""
|
||||||
|
from voice_drone.core.tts import play_tts_audio, speak_text
|
||||||
|
|
||||||
|
items: list = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
items.append(self._ack_playback_queue.get_nowait())
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
if not items:
|
||||||
|
return
|
||||||
|
|
||||||
|
mic_stopped = False
|
||||||
|
if self.ack_pause_mic_for_playback:
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
"TTS: 已暂停麦克风采集以便扬声器播放(避免 Windows 下输入/输出同时开无声)"
|
||||||
|
)
|
||||||
|
self.audio_capture.stop_stream()
|
||||||
|
mic_stopped = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"暂停麦克风失败,将尝试直接播放: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in items:
|
||||||
|
try:
|
||||||
|
kind = item[0]
|
||||||
|
if kind == "pcm":
|
||||||
|
_, audio, sr = item
|
||||||
|
logger.info("TTS: 主线程播放应答(预缓存波形)")
|
||||||
|
play_tts_audio(audio, sr)
|
||||||
|
logger.info("TTS: 播放完成")
|
||||||
|
elif kind == "synth":
|
||||||
|
logger.info("TTS: 主线程合成并播放应答(无预缓存)")
|
||||||
|
tts = self._ensure_tts_engine()
|
||||||
|
text = item[1] if len(item) >= 2 else (self.ack_tts_text or "")
|
||||||
|
speak_text(text, tts=tts)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"应答语音播放失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
if mic_stopped and recover_mic:
|
||||||
|
try:
|
||||||
|
self.audio_capture.start_stream()
|
||||||
|
try:
|
||||||
|
self.audio_preprocessor.reset()
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.debug("audio_preprocessor.reset: %s", e)
|
||||||
|
# TTS 暂停期间若未凑齐「尾静音」帧,VAD 会一直保持 is_speaking=True;
|
||||||
|
# 恢复后 detect_speech_start 会直接放弃,表现为「能恢复采集但再也不识别」。
|
||||||
|
self.vad.reset()
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
self.speech_buffer.clear()
|
||||||
|
self.pre_speech_buffer.clear()
|
||||||
|
logger.info("TTS: 麦克风采集已恢复(已重置 VAD 与语音缓冲)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"麦克风采集恢复失败,请重启程序: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _prewarm_tts_async(self) -> None:
|
||||||
|
"""后台预加载 TTS(仅当未使用阻塞预加载时)。"""
|
||||||
|
if not self.ack_tts_enabled or not self._has_ack_tts_content() or not self.ack_tts_prewarm:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _run() -> None:
|
||||||
|
try:
|
||||||
|
self._ensure_tts_engine()
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
logger.warning(
|
||||||
|
"TTS: 当前为「按命令随机短语」且未使用阻塞预加载,"
|
||||||
|
"各句首次播报可能仍有数秒延迟;若需低延迟请将 ack_tts_prewarm_blocking 设为 true。"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS 预加载失败(将在首次播报时重试): {e}", exc_info=True)
|
||||||
|
|
||||||
|
threading.Thread(target=_run, daemon=True, name="tts-prewarm").start()
|
||||||
|
|
||||||
|
def _prewarm_tts_blocking(self) -> None:
|
||||||
|
"""启动时准备应答 PCM:优先读磁盘缓存(文案与 TTS 配置未变则跳过合成);必要时加载 Kokoro 并合成。"""
|
||||||
|
if not self.ack_tts_enabled or not self._has_ack_tts_content() or not self.ack_tts_prewarm:
|
||||||
|
return
|
||||||
|
use_disk = self.ack_tts_disk_cache
|
||||||
|
logger.info("TTS: 正在准备语音反馈(磁盘缓存 / 合成)…")
|
||||||
|
print("正在加载语音反馈…")
|
||||||
|
try:
|
||||||
|
if self._ack_mode_phrases:
|
||||||
|
self._tts_phrase_pcm_cache.clear()
|
||||||
|
seen: set = set()
|
||||||
|
unique: List[str] = []
|
||||||
|
for lst in self.ack_tts_phrases.values():
|
||||||
|
for t in lst:
|
||||||
|
p = str(t).strip()
|
||||||
|
if p and p not in seen:
|
||||||
|
seen.add(p)
|
||||||
|
unique.append(p)
|
||||||
|
if not unique:
|
||||||
|
return
|
||||||
|
|
||||||
|
fingerprint = compute_ack_pcm_fingerprint(unique, mode_phrases=True)
|
||||||
|
missing = list(unique)
|
||||||
|
if use_disk:
|
||||||
|
loaded, missing = load_cached_phrases(unique, fingerprint)
|
||||||
|
for ph, pcm in loaded.items():
|
||||||
|
self._tts_phrase_pcm_cache[ph] = pcm
|
||||||
|
|
||||||
|
if not missing:
|
||||||
|
self._tts_ack_pcm = None
|
||||||
|
logger.info(
|
||||||
|
"TTS: 已从磁盘加载全部应答波形(%d 句),跳过 Kokoro 加载与合成",
|
||||||
|
len(unique),
|
||||||
|
)
|
||||||
|
print("语音反馈已就绪(本地缓存),可以开始说话下指令。")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._ensure_tts_engine()
|
||||||
|
assert self._tts_engine is not None
|
||||||
|
need = [p for p in unique if p not in self._tts_phrase_pcm_cache]
|
||||||
|
for j, phrase in enumerate(need, start=1):
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 合成应答句 {j}/{len(need)}: {phrase!r}"
|
||||||
|
)
|
||||||
|
audio, sr = self._tts_engine.synthesize(phrase)
|
||||||
|
self._tts_phrase_pcm_cache[phrase] = (audio, sr)
|
||||||
|
self._tts_ack_pcm = None
|
||||||
|
if use_disk:
|
||||||
|
persist_phrases(fingerprint, dict(self._tts_phrase_pcm_cache))
|
||||||
|
logger.info(
|
||||||
|
"TTS: 语音反馈已就绪(随机应答已缓存,播报低延迟)"
|
||||||
|
)
|
||||||
|
print("语音反馈引擎已就绪,可以开始说话下指令。")
|
||||||
|
else:
|
||||||
|
text = (self.ack_tts_text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
fingerprint = compute_ack_pcm_fingerprint(
|
||||||
|
[], global_text=text, mode_phrases=False
|
||||||
|
)
|
||||||
|
missing = [text]
|
||||||
|
if use_disk:
|
||||||
|
loaded, missing = load_cached_phrases([text], fingerprint)
|
||||||
|
if text in loaded:
|
||||||
|
self._tts_ack_pcm = loaded[text]
|
||||||
|
|
||||||
|
if not missing:
|
||||||
|
logger.info(
|
||||||
|
"TTS: 已从磁盘加载全局应答波形,跳过 Kokoro 加载与合成"
|
||||||
|
)
|
||||||
|
print("语音反馈已就绪(本地缓存),可以开始说话下指令。")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._ensure_tts_engine()
|
||||||
|
assert self._tts_engine is not None
|
||||||
|
audio, sr = self._tts_engine.synthesize(text)
|
||||||
|
self._tts_ack_pcm = (audio, sr)
|
||||||
|
if use_disk:
|
||||||
|
persist_phrases(fingerprint, {text: self._tts_ack_pcm})
|
||||||
|
logger.info(
|
||||||
|
"TTS: 语音反馈引擎已就绪;已缓存应答语音,命令成功后将快速播报"
|
||||||
|
)
|
||||||
|
print("语音反馈引擎已就绪,可以开始说话下指令。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"TTS: 启动阶段预加载失败,命令成功后可能延迟或无语音反馈: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _init_sounddevice_output_probe() -> None:
|
||||||
|
"""在主线程探测默认输出设备;应答播报必须在主线程调用 sd.play。"""
|
||||||
|
try:
|
||||||
|
from voice_drone.core.tts import log_sounddevice_output_devices
|
||||||
|
|
||||||
|
log_sounddevice_output_devices()
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
from voice_drone.core.tts import _sounddevice_default_output_index
|
||||||
|
|
||||||
|
out_idx = _sounddevice_default_output_index()
|
||||||
|
if out_idx is not None and int(out_idx) >= 0:
|
||||||
|
info = sd.query_devices(int(out_idx))
|
||||||
|
logger.info(
|
||||||
|
f"sounddevice 默认输出设备: {info.get('name', '?')} (index={out_idx})"
|
||||||
|
)
|
||||||
|
sd.check_output_settings(samplerate=24000, channels=1, dtype="float32")
|
||||||
|
# 预解析 tts.output_device,启动日志中可见实际用于播放的设备
|
||||||
|
from voice_drone.core.tts import get_playback_output_device_id
|
||||||
|
|
||||||
|
get_playback_output_device_id()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"sounddevice 输出设备探测失败,可能导致无法播音: {e}")
|
||||||
|
|
||||||
|
def _get_next_sequence_id(self) -> int:
|
||||||
|
"""获取下一个命令序列号"""
|
||||||
|
with self.sequence_lock:
|
||||||
|
self.sequence_id += 1
|
||||||
|
return self.sequence_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _int16_chunk_rms(chunk: np.ndarray) -> float:
|
||||||
|
if chunk.size == 0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.sqrt(np.mean(chunk.astype(np.float64) ** 2)))
|
||||||
|
|
||||||
|
def _submit_concatenated_speech_to_stt(self) -> None:
|
||||||
|
"""在持有 speech_buffer_lock 时调用:合并 speech_buffer 并送 STT,然后清空。"""
|
||||||
|
if len(self.speech_buffer) == 0:
|
||||||
|
return
|
||||||
|
speech_audio = np.concatenate(self.speech_buffer)
|
||||||
|
self.speech_buffer.clear()
|
||||||
|
min_samples = int(self.audio_capture.sample_rate * 0.5)
|
||||||
|
if len(speech_audio) >= min_samples:
|
||||||
|
try:
|
||||||
|
self.stt_queue.put(speech_audio.copy(), block=False)
|
||||||
|
logger.debug(
|
||||||
|
f"提交语音段到STT队列,长度: {len(speech_audio)} 采样点"
|
||||||
|
)
|
||||||
|
if os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"[VAD] 已送 STT,{len(speech_audio)} 采样点(≈{len(speech_audio) / float(self.audio_capture.sample_rate):.2f}s)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("STT队列已满,跳过本次识别")
|
||||||
|
elif os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"[VAD] 语音段太短已丢弃({len(speech_audio)} < {min_samples} 采样)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _energy_vad_on_chunk(self, processed_chunk: np.ndarray) -> None:
|
||||||
|
rms = self._int16_chunk_rms(processed_chunk)
|
||||||
|
_vad_diag = os.environ.get("ROCKET_PRINT_VAD", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
)
|
||||||
|
if _vad_diag:
|
||||||
|
self._ev_rms_peak = max(self._ev_rms_peak, rms)
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._ev_last_diag_time >= 3.0:
|
||||||
|
print(
|
||||||
|
f"[VAD] energy 诊断:近 3s 块 RMS 峰值≈{self._ev_rms_peak:.0f} "
|
||||||
|
f"(high={self._energy_rms_high} low={self._energy_rms_low})",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._ev_rms_peak = 0.0
|
||||||
|
self._ev_last_diag_time = now
|
||||||
|
|
||||||
|
if not self._ev_speaking:
|
||||||
|
if rms >= self._energy_rms_high:
|
||||||
|
self._ev_high_run += 1
|
||||||
|
else:
|
||||||
|
self._ev_high_run = 0
|
||||||
|
if self._ev_high_run >= self._energy_start_chunks:
|
||||||
|
self._ev_speaking = True
|
||||||
|
self._ev_high_run = 0
|
||||||
|
self._ev_low_run = 0
|
||||||
|
self._ev_utt_peak = rms
|
||||||
|
hook = self._vad_speech_start_hook
|
||||||
|
if hook is not None:
|
||||||
|
try:
|
||||||
|
hook()
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.debug("vad_speech_start_hook: %s", e, exc_info=True)
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
if self.pre_speech_buffer:
|
||||||
|
self.speech_buffer = list(self.pre_speech_buffer)
|
||||||
|
else:
|
||||||
|
self.speech_buffer.clear()
|
||||||
|
self.speech_buffer.append(processed_chunk)
|
||||||
|
logger.debug(
|
||||||
|
"energy VAD: 开始收集语音段(含预缓冲约 %.2f s)",
|
||||||
|
self.pre_speech_max_seconds,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
self.speech_buffer.append(processed_chunk)
|
||||||
|
|
||||||
|
self._ev_utt_peak = max(rms, self._ev_utt_peak * self._energy_utt_peak_decay)
|
||||||
|
below_abs = rms <= self._energy_rms_low
|
||||||
|
below_rel = (
|
||||||
|
self._energy_end_peak_ratio > 0
|
||||||
|
and self._ev_utt_peak >= self._energy_rms_high
|
||||||
|
and rms <= self._ev_utt_peak * self._energy_end_peak_ratio
|
||||||
|
)
|
||||||
|
if below_abs or below_rel:
|
||||||
|
self._ev_low_run += 1
|
||||||
|
else:
|
||||||
|
self._ev_low_run = 0
|
||||||
|
|
||||||
|
if self._ev_low_run >= self._energy_end_chunks:
|
||||||
|
self._ev_speaking = False
|
||||||
|
self._ev_low_run = 0
|
||||||
|
self._ev_utt_peak = 0.0
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
self._submit_concatenated_speech_to_stt()
|
||||||
|
self._reset_agc_after_utterance_end()
|
||||||
|
logger.debug("energy VAD: 语音段结束,已提交")
|
||||||
|
|
||||||
|
def _reset_agc_after_utterance_end(self) -> None:
|
||||||
|
"""VAD 句尾:清 AGC 滑窗,避免巨响后 RMS 卡死。"""
|
||||||
|
try:
|
||||||
|
self.audio_preprocessor.reset_agc_state()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def discard_pending_stt_segments(self) -> int:
|
||||||
|
"""丢弃尚未被 STT 线程取走的整句,避免唤醒/播 TTS 关麦后仍识别旧段。"""
|
||||||
|
n = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.stt_queue.get_nowait()
|
||||||
|
self.stt_queue.task_done()
|
||||||
|
n += 1
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
if n:
|
||||||
|
logger.info(
|
||||||
|
"已丢弃 %s 条待 STT 的语音段(流程切换,避免与播 TTS 重叠)",
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
return n
|
||||||
|
|
||||||
|
def _stt_worker_thread(self):
|
||||||
|
"""STT识别工作线程(异步处理,不阻塞主流程)"""
|
||||||
|
logger.info("STT识别线程已启动")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
audio_data = self.stt_queue.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"STT工作线程错误: {e}", exc_info=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if audio_data is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
text = self.stt.invoke_numpy(audio_data)
|
||||||
|
|
||||||
|
if os.environ.get("ROCKET_PRINT_STT", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"[STT] {text!r}"
|
||||||
|
if (text and text.strip())
|
||||||
|
else "[STT] <空或不识别>",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if text and text.strip():
|
||||||
|
logger.info(f"🎤 STT识别结果: {text}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.command_queue.put(text, block=False)
|
||||||
|
logger.debug(f"文本已提交到命令处理队列: {text}")
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning("命令处理队列已满,跳过本次识别结果")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"STT识别失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
self.stt_queue.task_done()
|
||||||
|
|
||||||
|
logger.info("STT识别线程已停止")
|
||||||
|
|
||||||
|
def _command_worker_thread(self):
|
||||||
|
"""命令处理工作线程(文本预处理+命令生成+Socket发送)"""
|
||||||
|
logger.info("命令处理线程已启动")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
text = self.command_queue.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"命令处理线程错误: {e}", exc_info=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if text is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 检测唤醒词
|
||||||
|
is_wake, matched_wake_word = self.wake_word_detector.detect(text)
|
||||||
|
|
||||||
|
if not is_wake:
|
||||||
|
logger.debug(f"未检测到唤醒词,忽略文本: {text}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"🔔 检测到唤醒词: {matched_wake_word}")
|
||||||
|
|
||||||
|
# 2. 提取命令文本(移除唤醒词)
|
||||||
|
command_text = self.wake_word_detector.extract_command_text(text)
|
||||||
|
if not command_text or not command_text.strip():
|
||||||
|
logger.warning(f"唤醒词后无命令内容: {text}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"提取的命令文本: {command_text}")
|
||||||
|
|
||||||
|
# 3. 文本预处理(快速模式,不进行分词)
|
||||||
|
normalized_text, params = self.text_preprocessor.preprocess_fast(command_text)
|
||||||
|
|
||||||
|
logger.debug(f"文本预处理结果:")
|
||||||
|
logger.debug(f" 规范化文本: {normalized_text}")
|
||||||
|
logger.debug(f" 命令关键词: {params.command_keyword}")
|
||||||
|
logger.debug(f" 距离: {params.distance} 米")
|
||||||
|
logger.debug(f" 速度: {params.speed} 米/秒")
|
||||||
|
logger.debug(f" 时间: {params.duration} 秒")
|
||||||
|
|
||||||
|
# 4. 检查是否识别到命令关键词
|
||||||
|
if not params.command_keyword:
|
||||||
|
logger.warning(f"未识别到有效命令关键词: {normalized_text}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 5. 生成命令
|
||||||
|
sequence_id = self._get_next_sequence_id()
|
||||||
|
command = Command.create(
|
||||||
|
command=params.command_keyword,
|
||||||
|
sequence_id=sequence_id,
|
||||||
|
distance=params.distance,
|
||||||
|
speed=params.speed,
|
||||||
|
duration=params.duration
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"📝 生成命令: {command.command}")
|
||||||
|
logger.debug(f"命令详情: {command.to_dict()}")
|
||||||
|
|
||||||
|
# 6. 发送命令到Socket服务器
|
||||||
|
if self.socket_client.send_command_with_retry(command):
|
||||||
|
logger.info(f"✅ 命令已发送: {command.command} (序列号: {sequence_id})")
|
||||||
|
self._enqueue_ack_playback(command.command)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"命令未送达(已达 max_retries): %s (序列号: %s)",
|
||||||
|
command.command,
|
||||||
|
sequence_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"命令处理失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.command_queue.task_done()
|
||||||
|
|
||||||
|
logger.info("命令处理线程已停止")
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""启动语音命令识别系统"""
|
||||||
|
if self.running:
|
||||||
|
logger.warning("语音命令识别系统已在运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 先完成阻塞式 TTS 预加载,再开麦与识别线程,避免用户在预加载期间下指令导致无波形缓存、播报延迟
|
||||||
|
print("[TTS] 探测扬声器并预加载应答语音(可能需十余秒,请勿说话)…", flush=True)
|
||||||
|
self._init_sounddevice_output_probe()
|
||||||
|
if self.ack_tts_enabled and self._has_ack_tts_content() and self.ack_tts_prewarm:
|
||||||
|
if self.ack_tts_prewarm_blocking:
|
||||||
|
self._prewarm_tts_blocking()
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"[TTS] 已跳过启动预加载(ack_tts_enabled/应答文案/ack_tts_prewarm)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
# 启动STT识别线程
|
||||||
|
self.stt_thread = threading.Thread(target=self._stt_worker_thread, daemon=True)
|
||||||
|
self.stt_thread.start()
|
||||||
|
|
||||||
|
# 启动命令处理线程
|
||||||
|
self.command_thread = threading.Thread(target=self._command_worker_thread, daemon=True)
|
||||||
|
self.command_thread.start()
|
||||||
|
|
||||||
|
# 启动音频采集
|
||||||
|
self.audio_capture.start_stream()
|
||||||
|
|
||||||
|
if self.ack_tts_enabled and self._has_ack_tts_content() and self.ack_tts_prewarm:
|
||||||
|
if not self.ack_tts_prewarm_blocking:
|
||||||
|
self._prewarm_tts_async()
|
||||||
|
|
||||||
|
logger.info("语音命令识别系统已启动")
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("🎙️ 高性能实时语音命令识别系统已启动")
|
||||||
|
print("=" * 70)
|
||||||
|
print("💡 功能说明:")
|
||||||
|
print(" - 系统会自动检测语音并识别")
|
||||||
|
print(f" - 🔔 唤醒词: {self.wake_word_detector.primary}")
|
||||||
|
print(" - 只有包含唤醒词的语音才会被处理")
|
||||||
|
print(" - 识别结果会自动转换为无人机控制命令")
|
||||||
|
print(" - 命令会自动发送到Socket服务器")
|
||||||
|
print(" - 按 Ctrl+C 退出")
|
||||||
|
print("=" * 70 + "\n")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止语音命令识别系统"""
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# 先通知工作线程结束,再播放尚未 drain 的应答(避免 Ctrl+C 时主循环未跑下一轮导致无声)
|
||||||
|
if self.stt_thread is not None:
|
||||||
|
self.stt_queue.put(None)
|
||||||
|
if self.command_thread is not None:
|
||||||
|
self.command_queue.put(None)
|
||||||
|
if self.stt_thread is not None:
|
||||||
|
self.stt_thread.join(timeout=2.0)
|
||||||
|
if self.command_thread is not None:
|
||||||
|
self.command_thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
if self.ack_tts_enabled:
|
||||||
|
try:
|
||||||
|
self._drain_ack_playback_queue(recover_mic=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"退出前播放应答失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
self.audio_capture.stop_stream()
|
||||||
|
|
||||||
|
# 断开Socket连接
|
||||||
|
if self.socket_client.connected:
|
||||||
|
self.socket_client.disconnect()
|
||||||
|
|
||||||
|
logger.info("语音命令识别系统已停止")
|
||||||
|
print("\n语音命令识别系统已停止")
|
||||||
|
|
||||||
|
def process_audio_stream(self):
|
||||||
|
"""
|
||||||
|
处理音频流(主循环)
|
||||||
|
|
||||||
|
高性能实时处理流程:
|
||||||
|
1. 采集音频块(非阻塞)
|
||||||
|
2. 预处理(降噪+AGC)
|
||||||
|
3. VAD检测语音开始/结束
|
||||||
|
4. 收集语音段
|
||||||
|
5. 异步STT识别(不阻塞主流程)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
# 0. 主线程播放命令应答(必须在采集循环线程中执行 sd.play,见 tts.play_tts_audio 说明)
|
||||||
|
self._before_audio_iteration()
|
||||||
|
|
||||||
|
# 1. 采集音频块(非阻塞,高性能模式)
|
||||||
|
chunk = self.audio_capture.read_chunk_numpy(timeout=0.1)
|
||||||
|
if chunk is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. 音频预处理(降噪+AGC)
|
||||||
|
processed_chunk = self.audio_preprocessor.process(chunk)
|
||||||
|
|
||||||
|
# 初始化预缓冲区的最大块数(只需计算一次)
|
||||||
|
if self.pre_speech_max_chunks is None:
|
||||||
|
# 每个chunk包含的采样点数
|
||||||
|
samples_per_chunk = processed_chunk.shape[0]
|
||||||
|
if samples_per_chunk > 0:
|
||||||
|
# 0.8 秒需要的chunk数量 = 预缓冲秒数 * 采样率 / 每块采样数
|
||||||
|
chunks = int(
|
||||||
|
self.pre_speech_max_seconds * self.audio_capture.sample_rate
|
||||||
|
/ samples_per_chunk
|
||||||
|
)
|
||||||
|
# 至少保留 1 块,避免被算成 0
|
||||||
|
self.pre_speech_max_chunks = max(chunks, 1)
|
||||||
|
else:
|
||||||
|
self.pre_speech_max_chunks = 1
|
||||||
|
|
||||||
|
# 将当前块加入预缓冲区(环形缓冲)
|
||||||
|
# 注意:预缓冲区保存的是“最近的一段音频”,无论当下是否在说话
|
||||||
|
self.pre_speech_buffer.append(processed_chunk)
|
||||||
|
if (
|
||||||
|
self.pre_speech_max_chunks is not None
|
||||||
|
and len(self.pre_speech_buffer) > self.pre_speech_max_chunks
|
||||||
|
):
|
||||||
|
# 超出最大长度时,丢弃最早的块
|
||||||
|
self.pre_speech_buffer.pop(0)
|
||||||
|
|
||||||
|
# 3. VAD:Silero 或能量(RMS)分段
|
||||||
|
if self._use_energy_vad:
|
||||||
|
self._energy_vad_on_chunk(processed_chunk)
|
||||||
|
else:
|
||||||
|
chunk_bytes = processed_chunk.tobytes()
|
||||||
|
|
||||||
|
if self.vad.detect_speech_start(chunk_bytes):
|
||||||
|
hook = self._vad_speech_start_hook
|
||||||
|
if hook is not None:
|
||||||
|
try:
|
||||||
|
hook()
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.debug(
|
||||||
|
"vad_speech_start_hook: %s", e, exc_info=True
|
||||||
|
)
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
if self.pre_speech_buffer:
|
||||||
|
self.speech_buffer = list(self.pre_speech_buffer)
|
||||||
|
else:
|
||||||
|
self.speech_buffer.clear()
|
||||||
|
self.speech_buffer.append(processed_chunk)
|
||||||
|
logger.debug(
|
||||||
|
"检测到语音开始,使用预缓冲音频(约 %.2f 秒)作为前缀,开始收集语音段",
|
||||||
|
self.pre_speech_max_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.vad.is_speaking:
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
self.speech_buffer.append(processed_chunk)
|
||||||
|
|
||||||
|
if self.vad.detect_speech_end(chunk_bytes):
|
||||||
|
with self.speech_buffer_lock:
|
||||||
|
self._submit_concatenated_speech_to_stt()
|
||||||
|
self._reset_agc_after_utterance_end()
|
||||||
|
logger.debug("检测到语音结束,提交识别")
|
||||||
|
|
||||||
|
hook = getattr(self, "_after_processed_audio_chunk", None)
|
||||||
|
if hook is not None:
|
||||||
|
try:
|
||||||
|
hook(processed_chunk)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.debug(
|
||||||
|
"after_processed_audio_chunk: %s", e, exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("用户中断")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理音频流时发生错误: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""运行语音命令识别系统(完整流程)"""
|
||||||
|
try:
|
||||||
|
self.start()
|
||||||
|
self.process_audio_stream()
|
||||||
|
finally:
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
recognizer = VoiceCommandRecognizer()
|
||||||
|
recognizer.run()
|
||||||
0
voice_drone/core/rule.py
Normal file
0
voice_drone/core/rule.py
Normal file
239
voice_drone/core/scoket_client.py
Normal file
239
voice_drone/core/scoket_client.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
"""
|
||||||
|
Socket客户端
|
||||||
|
"""
|
||||||
|
import socket
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from voice_drone.core.configuration import SYSTEM_SOCKET_SERVER_CONFIG
|
||||||
|
from voice_drone.core.command import Command
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("socket.client")
|
||||||
|
|
||||||
|
|
||||||
|
class SocketClient:
|
||||||
|
# 初始化Socket客户端
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
self.host = config.get("host")
|
||||||
|
self.port = config.get("port")
|
||||||
|
self.connect_timeout = config.get("connect_timeout")
|
||||||
|
self.send_timeout = config.get("send_timeout")
|
||||||
|
self.reconnect_interval = float(config.get("reconnect_interval") or 3.0)
|
||||||
|
# max_retries:-1 表示断线后持续重连并发送,直到成功(不视为致命错误)
|
||||||
|
_mr = config.get("max_retries", -1)
|
||||||
|
try:
|
||||||
|
self.max_reconnect_attempts = int(_mr)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
self.max_reconnect_attempts = -1
|
||||||
|
|
||||||
|
self.sock = None
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
# 连接到socket服务器
|
||||||
|
def connect(self) -> bool:
|
||||||
|
if self.connected and self.sock is not None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.settimeout(self.connect_timeout)
|
||||||
|
self.sock.connect((self.host, self.port))
|
||||||
|
self.sock.settimeout(self.send_timeout)
|
||||||
|
self.connected = True
|
||||||
|
print(
|
||||||
|
f"[SocketClient] 连接成功: {self.host}:{self.port}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
logger.info("Socket 已连接 %s:%s", self.host, self.port)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except socket.timeout:
|
||||||
|
logger.warning(
|
||||||
|
"Socket 连接超时 host=%s port=%s timeout=%s",
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
self.connect_timeout,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[SocketClient] connect: 连接超时 "
|
||||||
|
f"(host={self.host!r}, port={self.port!r}, timeout={self.connect_timeout!r})",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except ConnectionRefusedError as e:
|
||||||
|
logger.warning("Socket 连接被拒绝: %s", e)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] connect: 连接被拒绝: {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("Socket connect OSError (%s): %s", type(e).__name__, e)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] connect: OSError ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"[SocketClient] connect: 未预期异常 ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 断开与socket服务器的连接
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self._cleanup()
|
||||||
|
|
||||||
|
# 清理资源
|
||||||
|
def _cleanup(self) -> None:
|
||||||
|
if self.sock is not None:
|
||||||
|
try:
|
||||||
|
self.sock.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.sock = None
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
# 确保连接已建立
|
||||||
|
def _ensure_connected(self) -> bool:
|
||||||
|
if self.connected and self.sock is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return self.connect()
|
||||||
|
|
||||||
|
# 发送命令
|
||||||
|
def send_command(self, command) -> bool:
|
||||||
|
print("[SocketClient] 正在发送命令…", flush=True)
|
||||||
|
|
||||||
|
if not self._ensure_connected():
|
||||||
|
logger.warning(
|
||||||
|
"Socket 未连接且 connect 失败,跳过本次发送 host=%s port=%s",
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[SocketClient] 未连接或 connect 失败,跳过发送",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
command_dict = command.to_dict()
|
||||||
|
json_str = json.dumps(command_dict, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 添加换行符(根据 JSON格式说明.md,命令以换行符分隔)
|
||||||
|
message = json_str + "\n"
|
||||||
|
|
||||||
|
# 发送数据
|
||||||
|
self.sock.sendall(message.encode("utf-8"))
|
||||||
|
print("[SocketClient] sendall 成功", flush=True)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except socket.timeout:
|
||||||
|
logger.warning("Socket send 超时,将断开以便重连")
|
||||||
|
print("[SocketClient] send_command: socket 超时", flush=True)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except ConnectionResetError as e:
|
||||||
|
logger.warning("Socket 连接被重置(将重连): %s", e)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] send_command: 连接被重置 ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except BrokenPipeError as e:
|
||||||
|
logger.warning("Socket 管道破裂(将重连): %s", e)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] send_command: 管道破裂 ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except OSError as e:
|
||||||
|
# 断网、对端关闭等:可恢复,不当作未捕获致命错误
|
||||||
|
logger.warning("Socket send OSError (%s): %s(将重连)", type(e).__name__, e)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] send_command: OSError ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Socket send 异常 (%s): %s(将重连)", type(e).__name__, e
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] send_command: 异常 ({type(e).__name__}): {e!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
self._cleanup()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 发送命令并重试
|
||||||
|
def send_command_with_retry(self, command) -> bool:
|
||||||
|
"""失败后清理连接并按 reconnect_interval 重试;max_retries=-1 时直到发送成功。"""
|
||||||
|
unlimited = self.max_reconnect_attempts < 0
|
||||||
|
cap = max(1, self.max_reconnect_attempts) if not unlimited else None
|
||||||
|
attempt = 0
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
self._cleanup()
|
||||||
|
if self.send_command(command):
|
||||||
|
if attempt > 1:
|
||||||
|
print(
|
||||||
|
f"[SocketClient] 重试后发送成功(第 {attempt} 次)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
logger.info("Socket 重连后命令已发送(第 %s 次尝试)", attempt)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not unlimited and cap is not None and attempt >= cap:
|
||||||
|
logger.warning(
|
||||||
|
"Socket 已达 max_retries=%s,本次命令未送达,稍后可再试",
|
||||||
|
self.max_reconnect_attempts,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[SocketClient] 已达最大重试次数,本次命令未送达(可稍后重试)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 无限重试时每 10 次打一条日志,避免刷屏
|
||||||
|
if unlimited and attempt % 10 == 1:
|
||||||
|
logger.warning(
|
||||||
|
"Socket 发送失败,%ss 后第 %s 次重连重试…",
|
||||||
|
self.reconnect_interval,
|
||||||
|
attempt,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[SocketClient] 发送失败,{self.reconnect_interval}s 后重试 "
|
||||||
|
f"(第 {attempt} 次)",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
time.sleep(self.reconnect_interval)
|
||||||
|
|
||||||
|
# 上下文管理器入口
|
||||||
|
def __enter__(self):
|
||||||
|
self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
# 上下文管理器出口
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from voice_drone.core.configuration import SYSTEM_SOCKET_SERVER_CONFIG
|
||||||
|
from voice_drone.core.command import Command
|
||||||
|
|
||||||
|
config = SYSTEM_SOCKET_SERVER_CONFIG
|
||||||
|
client = SocketClient(config)
|
||||||
|
client.connect()
|
||||||
|
command = Command.create("takeoff", 1)
|
||||||
|
client.send_command(command)
|
||||||
|
client.disconnect()
|
||||||
46
voice_drone/core/streaming_llm_tts.py
Normal file
46
voice_drone/core/streaming_llm_tts.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""流式闲聊:按句切分文本入队 TTS;飞控 JSON 路径由调用方整块推理后再播。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# 句末:切段送合成
|
||||||
|
_SENTENCE_END = frozenset("。!?;\n")
|
||||||
|
# 过长且无句末时,优先在以下标点处断开
|
||||||
|
_SOFT_BREAK = ",、,"
|
||||||
|
|
||||||
|
|
||||||
|
def take_completed_sentences(buffer: str) -> tuple[list[str], str]:
|
||||||
|
"""从 buffer 开头取出所有「以句末标点结尾」的完整小段。"""
|
||||||
|
segments: list[str] = []
|
||||||
|
i = 0
|
||||||
|
n = len(buffer)
|
||||||
|
while i < n:
|
||||||
|
j = i
|
||||||
|
while j < n and buffer[j] not in _SENTENCE_END:
|
||||||
|
j += 1
|
||||||
|
if j >= n:
|
||||||
|
break
|
||||||
|
raw = buffer[i : j + 1].strip()
|
||||||
|
if raw:
|
||||||
|
segments.append(raw)
|
||||||
|
i = j + 1
|
||||||
|
return segments, buffer[i:]
|
||||||
|
|
||||||
|
|
||||||
|
def force_soft_split(remainder: str, max_chars: int) -> tuple[list[str], str]:
|
||||||
|
"""remainder 长度 >= max_chars 且无句末时,强制切下第一段。"""
|
||||||
|
if max_chars <= 0 or len(remainder) < max_chars:
|
||||||
|
return [], remainder
|
||||||
|
window = remainder[:max_chars]
|
||||||
|
cut = -1
|
||||||
|
for sep in _SOFT_BREAK:
|
||||||
|
p = window.rfind(sep)
|
||||||
|
if p > cut:
|
||||||
|
cut = p
|
||||||
|
if cut <= 0:
|
||||||
|
cut = max_chars
|
||||||
|
first = remainder[: cut + 1].strip()
|
||||||
|
rest = remainder[cut + 1 :]
|
||||||
|
out: list[str] = []
|
||||||
|
if first:
|
||||||
|
out.append(first)
|
||||||
|
return out, rest
|
||||||
494
voice_drone/core/stt.py
Normal file
494
voice_drone/core/stt.py
Normal file
@ -0,0 +1,494 @@
|
|||||||
|
"""
|
||||||
|
语音识别(Speech-to-Text)类 - 纯 ONNX Runtime 极致性能推理
|
||||||
|
针对 RK3588 等 ARM 设备进行了深度优化,完全移除 FunASR 依赖。
|
||||||
|
前处理(fbank + CMVN + LFR)与解码均手写实现。
|
||||||
|
"""
|
||||||
|
import platform
|
||||||
|
import os
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
from voice_drone.tools.wrapper import time_cost
|
||||||
|
import scipy.special
|
||||||
|
from voice_drone.core.configuration import SYSTEM_STT_CONFIG, SYSTEM_AUDIO_CONFIG
|
||||||
|
|
||||||
|
# voice_drone/core/stt.py -> 工程根(含 voice_drone_assistant 与本仓库根两种布局)
|
||||||
|
_STT_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
def _stt_path_candidates(path: Path) -> List[Path]:
|
||||||
|
"""相对配置路径的候选绝对路径:优先工程目录,其次嵌套在上一级仓库时的 src/models/。"""
|
||||||
|
if path.is_absolute():
|
||||||
|
return [path]
|
||||||
|
out: List[Path] = [_STT_PROJECT_ROOT / path]
|
||||||
|
if path.parts and path.parts[0] == "models":
|
||||||
|
out.append(_STT_PROJECT_ROOT.parent / "src" / path)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class STT:
|
||||||
|
"""
|
||||||
|
语音识别(Speech-to-Text)类
|
||||||
|
使用 ONNX Runtime 进行最优性能推理
|
||||||
|
针对 RK3588 等 ARM 设备进行了深度优化
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化 STT 模型
|
||||||
|
"""
|
||||||
|
stt_conf = SYSTEM_STT_CONFIG
|
||||||
|
self.logger = get_logger("stt.onnx")
|
||||||
|
|
||||||
|
# 从配置读取参数
|
||||||
|
self.model_dir = stt_conf.get("model_dir")
|
||||||
|
self.model_path = stt_conf.get("model_path")
|
||||||
|
self.prefer_int8 = stt_conf.get("prefer_int8", True)
|
||||||
|
_wf = stt_conf.get("warmup_file")
|
||||||
|
self.warmup_file: Optional[str] = None
|
||||||
|
if _wf:
|
||||||
|
wf_path = Path(_wf)
|
||||||
|
if wf_path.is_absolute() and wf_path.is_file():
|
||||||
|
self.warmup_file = str(wf_path)
|
||||||
|
else:
|
||||||
|
for c in _stt_path_candidates(wf_path):
|
||||||
|
if c.is_file():
|
||||||
|
self.warmup_file = str(c)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 音频预处理参数(确保数值类型正确)
|
||||||
|
self.sample_rate = int(stt_conf.get("sample_rate", SYSTEM_AUDIO_CONFIG.get("sample_rate", 16000)))
|
||||||
|
self.n_mels = int(stt_conf.get("n_mels", 80))
|
||||||
|
self.frame_length_ms = float(stt_conf.get("frame_length_ms", 25))
|
||||||
|
self.frame_shift_ms = float(stt_conf.get("frame_shift_ms", 10))
|
||||||
|
self.log_eps = float(stt_conf.get("log_eps", 1e-10))
|
||||||
|
|
||||||
|
# ARM 优化配置
|
||||||
|
arm_conf = stt_conf.get("arm_optimization", {})
|
||||||
|
self.arm_enabled = arm_conf.get("enabled", True)
|
||||||
|
self.arm_max_threads = arm_conf.get("max_threads", 4)
|
||||||
|
|
||||||
|
# CTC 解码配置
|
||||||
|
ctc_conf = stt_conf.get("ctc_decode", {})
|
||||||
|
self.blank_id = ctc_conf.get("blank_id", 0)
|
||||||
|
|
||||||
|
# 语言和文本规范化配置(默认值)
|
||||||
|
lang_conf = stt_conf.get("language", {})
|
||||||
|
text_norm_conf = stt_conf.get("text_norm", {})
|
||||||
|
self.lang_zh_default = lang_conf.get("zh_id", 3)
|
||||||
|
self.with_itn_default = text_norm_conf.get("with_itn_id", 14)
|
||||||
|
self.without_itn_default = text_norm_conf.get("without_itn_id", 15)
|
||||||
|
|
||||||
|
# 后处理配置
|
||||||
|
postprocess_conf = stt_conf.get("postprocess", {})
|
||||||
|
self.special_tokens = postprocess_conf.get("special_tokens", [
|
||||||
|
"<|zh|>", "<|NEUTRAL|>", "<|Speech|>", "<|woitn|>", "<|withitn|>"
|
||||||
|
])
|
||||||
|
|
||||||
|
# 检测是否为 RK3588 或 ARM 设备
|
||||||
|
ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
|
||||||
|
RK3588 = 'rk3588' in platform.platform().lower() or os.path.exists('/proc/device-tree/compatible')
|
||||||
|
|
||||||
|
# ARM 设备性能优化配置
|
||||||
|
if self.arm_enabled and (ARM or RK3588):
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
optimal_threads = min(self.arm_max_threads, cpu_count)
|
||||||
|
|
||||||
|
# 设置 OpenMP 线程数
|
||||||
|
os.environ['OMP_NUM_THREADS'] = str(optimal_threads)
|
||||||
|
os.environ['MKL_NUM_THREADS'] = str(optimal_threads)
|
||||||
|
os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0'
|
||||||
|
os.environ['OMP_DYNAMIC'] = 'FALSE'
|
||||||
|
os.environ['MKL_DYNAMIC'] = 'FALSE'
|
||||||
|
|
||||||
|
self.logger.info("ARM/RK3588 优化已启用")
|
||||||
|
self.logger.info(f" CPU 核心数: {cpu_count}")
|
||||||
|
self.logger.info(f" 优化线程数: {optimal_threads}")
|
||||||
|
|
||||||
|
# 确定模型路径
|
||||||
|
onnx_model_path = self._resolve_model_path()
|
||||||
|
|
||||||
|
# 保存模型目录路径(用于加载 tokens.txt)
|
||||||
|
self.onnx_model_dir = onnx_model_path.parent
|
||||||
|
|
||||||
|
self.logger.info(f"加载 ONNX 模型: {onnx_model_path}")
|
||||||
|
self._load_onnx_model(str(onnx_model_path))
|
||||||
|
|
||||||
|
# 模型预热
|
||||||
|
if self.warmup_file and os.path.exists(self.warmup_file):
|
||||||
|
try:
|
||||||
|
self.logger.info(f"正在预热模型(使用: {self.warmup_file})...")
|
||||||
|
_ = self.invoke(self.warmup_file)
|
||||||
|
self.logger.info("模型预热完成")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"预热失败(可忽略): {e}")
|
||||||
|
elif self.warmup_file:
|
||||||
|
self.logger.warning(f"预热文件不存在: {self.warmup_file},跳过预热步骤")
|
||||||
|
|
||||||
|
def _resolve_existing_model_file(self, raw: Optional[str]) -> Optional[Path]:
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
p = Path(raw)
|
||||||
|
for c in _stt_path_candidates(p):
|
||||||
|
if c.is_file():
|
||||||
|
return c
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resolve_existing_model_dir(self, raw: Optional[str]) -> Optional[Path]:
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
p = Path(raw)
|
||||||
|
for c in _stt_path_candidates(p):
|
||||||
|
if c.is_dir():
|
||||||
|
return c
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resolve_model_path(self) -> Path:
|
||||||
|
"""
|
||||||
|
解析模型路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型文件路径
|
||||||
|
"""
|
||||||
|
if self.model_path:
|
||||||
|
hit = self._resolve_existing_model_file(self.model_path)
|
||||||
|
if hit is not None:
|
||||||
|
return hit
|
||||||
|
|
||||||
|
if not self.model_dir:
|
||||||
|
raise ValueError("配置中必须指定 model_path 或 model_dir")
|
||||||
|
|
||||||
|
model_dir = self._resolve_existing_model_dir(self.model_dir)
|
||||||
|
if model_dir is None:
|
||||||
|
tried = ", ".join(str(x) for x in _stt_path_candidates(Path(self.model_dir)))
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"ONNX 模型目录不存在。config model_dir={self.model_dir!r},已尝试: {tried}。"
|
||||||
|
f"请将 SenseVoice 放入 {_STT_PROJECT_ROOT / 'models'},或 ln -s ../src/models "
|
||||||
|
f"{_STT_PROJECT_ROOT / 'models'}(见 models/README.txt)。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 优先使用 INT8 量化模型(如果启用)
|
||||||
|
if self.prefer_int8:
|
||||||
|
int8_path = model_dir / "model.int8.onnx"
|
||||||
|
if int8_path.exists():
|
||||||
|
return int8_path
|
||||||
|
|
||||||
|
# 回退到普通模型
|
||||||
|
onnx_path = model_dir / "model.onnx"
|
||||||
|
if onnx_path.exists():
|
||||||
|
return onnx_path
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"ONNX 模型文件不存在: 在 {model_dir} 中未找到 model.int8.onnx 或 model.onnx")
|
||||||
|
|
||||||
|
def _load_onnx_model(self, onnx_model_path: str):
|
||||||
|
"""加载 ONNX 模型"""
|
||||||
|
# 创建 ONNX Runtime 会话选项
|
||||||
|
sess_options = ort.SessionOptions()
|
||||||
|
|
||||||
|
# ARM 设备优化
|
||||||
|
ARM = platform.machine().startswith('arm') or platform.machine().startswith('aarch64')
|
||||||
|
if self.arm_enabled and ARM:
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
optimal_threads = min(self.arm_max_threads, cpu_count)
|
||||||
|
sess_options.intra_op_num_threads = optimal_threads
|
||||||
|
sess_options.inter_op_num_threads = optimal_threads
|
||||||
|
|
||||||
|
# 启用所有图优化(最优性能)
|
||||||
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
|
||||||
|
# 创建推理会话
|
||||||
|
self.onnx_session = ort.InferenceSession(
|
||||||
|
onnx_model_path,
|
||||||
|
sess_options=sess_options,
|
||||||
|
providers=['CPUExecutionProvider']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取模型元数据
|
||||||
|
onnx_model = onnx.load(onnx_model_path)
|
||||||
|
self.model_metadata = {prop.key: prop.value for prop in onnx_model.metadata_props}
|
||||||
|
|
||||||
|
# 解析元数据
|
||||||
|
self.lfr_window_size = int(self.model_metadata.get('lfr_window_size', 7))
|
||||||
|
self.lfr_window_shift = int(self.model_metadata.get('lfr_window_shift', 6))
|
||||||
|
self.vocab_size = int(self.model_metadata.get('vocab_size', 25055))
|
||||||
|
|
||||||
|
# 解析 CMVN 参数
|
||||||
|
neg_mean_str = self.model_metadata.get('neg_mean', '')
|
||||||
|
inv_stddev_str = self.model_metadata.get('inv_stddev', '')
|
||||||
|
self.neg_mean = np.array([float(x) for x in neg_mean_str.split(',')]) if neg_mean_str else None
|
||||||
|
self.inv_stddev = np.array([float(x) for x in inv_stddev_str.split(',')]) if inv_stddev_str else None
|
||||||
|
|
||||||
|
# 语言和文本规范化 ID(从元数据获取,如果没有则使用配置默认值)
|
||||||
|
self.lang_zh = int(self.model_metadata.get('lang_zh', self.lang_zh_default))
|
||||||
|
self.with_itn = int(self.model_metadata.get('with_itn', self.with_itn_default))
|
||||||
|
self.without_itn = int(self.model_metadata.get('without_itn', self.without_itn_default))
|
||||||
|
|
||||||
|
self.logger.info("ONNX 模型加载完成")
|
||||||
|
self.logger.info(f" LFR窗口大小: {self.lfr_window_size}")
|
||||||
|
self.logger.info(f" LFR窗口偏移: {self.lfr_window_shift}")
|
||||||
|
self.logger.info(f" 词汇表大小: {self.vocab_size}")
|
||||||
|
|
||||||
|
def _load_tokens(self):
|
||||||
|
"""加载 tokens 映射"""
|
||||||
|
tokens_file = self.onnx_model_dir / "tokens.txt"
|
||||||
|
if tokens_file.exists():
|
||||||
|
self.tokens = {}
|
||||||
|
with open(tokens_file, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
token = parts[0]
|
||||||
|
token_id = int(parts[-1])
|
||||||
|
self.tokens[token_id] = token
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _preprocess_audio_array(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> tuple:
|
||||||
|
"""
|
||||||
|
预处理音频数组:提取特征并转换为 ONNX 模型输入格式(纯 numpy 实现)
|
||||||
|
|
||||||
|
支持实时音频流处理(numpy数组输入)
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 输入 16k 单声道 numpy 数组(int16 或 float32)
|
||||||
|
2. 计算 80 维 log-mel fbank
|
||||||
|
3. 应用 CMVN(使用 ONNX 元数据中的 neg_mean / inv_stddev)
|
||||||
|
4. 应用 LFR(lfr_m, lfr_n)堆叠,得到 560 维特征
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_array: 音频数据(numpy array,int16 或 float32)
|
||||||
|
sample_rate: 采样率,None时使用配置值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(features, lengths): 特征和长度
|
||||||
|
"""
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
sr = sample_rate if sample_rate is not None else self.sample_rate
|
||||||
|
|
||||||
|
# 1. 转换为float32格式(如果输入是int16)
|
||||||
|
if audio_array.dtype == np.int16:
|
||||||
|
audio = audio_array.astype(np.float32) / 32768.0
|
||||||
|
else:
|
||||||
|
audio = audio_array.astype(np.float32)
|
||||||
|
|
||||||
|
# 确保是单声道
|
||||||
|
if len(audio.shape) > 1:
|
||||||
|
audio = np.mean(audio, axis=1)
|
||||||
|
|
||||||
|
if audio.size == 0:
|
||||||
|
raise ValueError("音频数组为空")
|
||||||
|
|
||||||
|
# 2. 计算 fbank 特征
|
||||||
|
n_fft = int(self.frame_length_ms / 1000.0 * sr)
|
||||||
|
hop_length = int(self.frame_shift_ms / 1000.0 * sr)
|
||||||
|
|
||||||
|
mel_spec = librosa.feature.melspectrogram(
|
||||||
|
y=audio,
|
||||||
|
sr=sr,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
n_mels=self.n_mels,
|
||||||
|
window="hann",
|
||||||
|
center=True,
|
||||||
|
power=1.0, # 线性能量
|
||||||
|
)
|
||||||
|
|
||||||
|
# log-mel
|
||||||
|
log_mel = np.log(np.maximum(mel_spec, self.log_eps)).T # (T, n_mels)
|
||||||
|
|
||||||
|
# 3. CMVN:使用 ONNX 元数据中的 neg_mean / inv_stddev
|
||||||
|
if self.neg_mean is not None and self.inv_stddev is not None:
|
||||||
|
if self.neg_mean.shape[0] == log_mel.shape[1]:
|
||||||
|
log_mel = (log_mel + self.neg_mean) * self.inv_stddev
|
||||||
|
|
||||||
|
# 4. LFR:按窗口 lfr_window_size 堆叠,步长 lfr_window_shift
|
||||||
|
T, D = log_mel.shape
|
||||||
|
m = self.lfr_window_size
|
||||||
|
n = self.lfr_window_shift
|
||||||
|
|
||||||
|
if T < m:
|
||||||
|
# 帧数不够,补到 m 帧
|
||||||
|
pad = np.tile(log_mel[-1], (m - T, 1))
|
||||||
|
log_mel = np.vstack([log_mel, pad])
|
||||||
|
T = m
|
||||||
|
|
||||||
|
# 计算 LFR 后的帧数
|
||||||
|
T_lfr = 1 + (T - m) // n
|
||||||
|
lfr_feats = []
|
||||||
|
for i in range(T_lfr):
|
||||||
|
start = i * n
|
||||||
|
end = start + m
|
||||||
|
chunk = log_mel[start:end, :] # (m, D)
|
||||||
|
lfr_feats.append(chunk.reshape(-1)) # 展平为 560 维
|
||||||
|
|
||||||
|
lfr_feats = np.stack(lfr_feats, axis=0) # (T_lfr, m*D=560)
|
||||||
|
|
||||||
|
# 增加 batch 维度: (1, T_lfr, 560)
|
||||||
|
lfr_feats = lfr_feats[np.newaxis, :, :].astype(np.float32)
|
||||||
|
lengths = np.array([lfr_feats.shape[1]], dtype=np.int32)
|
||||||
|
|
||||||
|
return lfr_feats, lengths
|
||||||
|
|
||||||
|
def _preprocess_audio(self, audio_path: str) -> tuple:
|
||||||
|
"""
|
||||||
|
预处理音频:提取特征并转换为 ONNX 模型输入格式(纯 numpy 实现)
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 读入 16k 单声道 wav
|
||||||
|
2. 计算 80 维 log-mel fbank
|
||||||
|
3. 应用 CMVN(使用 ONNX 元数据中的 neg_mean / inv_stddev)
|
||||||
|
4. 应用 LFR(lfr_m, lfr_n)堆叠,得到 560 维特征
|
||||||
|
"""
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
# 1. 读入音频
|
||||||
|
audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
|
||||||
|
if audio.size == 0:
|
||||||
|
raise ValueError(f"音频为空: {audio_path}")
|
||||||
|
|
||||||
|
# 使用_preprocess_audio_array处理
|
||||||
|
return self._preprocess_audio_array(audio, sr)
|
||||||
|
|
||||||
|
def _ctc_decode(self, logits: np.ndarray, length: np.ndarray) -> str:
|
||||||
|
"""
|
||||||
|
CTC 解码:将 logits 转换为文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: CTC logits,形状为 (N, T, vocab_size)
|
||||||
|
length: 序列长度,形状为 (N,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解码后的文本
|
||||||
|
"""
|
||||||
|
# 加载 tokens(如果还没加载)
|
||||||
|
if not hasattr(self, 'tokens') or len(self.tokens) == 0:
|
||||||
|
self._load_tokens()
|
||||||
|
|
||||||
|
# Greedy CTC 解码
|
||||||
|
# 应用 softmax 获取概率
|
||||||
|
probs = scipy.special.softmax(logits[0][:length[0]], axis=-1)
|
||||||
|
|
||||||
|
# 获取每个时间步的最大概率 token
|
||||||
|
token_ids = np.argmax(probs, axis=-1)
|
||||||
|
|
||||||
|
# CTC 解码:移除空白和重复
|
||||||
|
prev_token = -1
|
||||||
|
decoded_tokens = []
|
||||||
|
|
||||||
|
for token_id in token_ids:
|
||||||
|
if token_id != self.blank_id and token_id != prev_token:
|
||||||
|
decoded_tokens.append(token_id)
|
||||||
|
prev_token = token_id
|
||||||
|
|
||||||
|
# Token ID 转文本
|
||||||
|
text_parts = []
|
||||||
|
for token_id in decoded_tokens:
|
||||||
|
if token_id in self.tokens:
|
||||||
|
token = self.tokens[token_id]
|
||||||
|
# 处理 SentencePiece 标记
|
||||||
|
if token.startswith('▁'):
|
||||||
|
if text_parts: # 如果不是第一个token,添加空格
|
||||||
|
text_parts.append(' ')
|
||||||
|
text_parts.append(token[1:])
|
||||||
|
elif not token.startswith('<|'): # 忽略特殊标记
|
||||||
|
text_parts.append(token)
|
||||||
|
|
||||||
|
text = ''.join(text_parts)
|
||||||
|
|
||||||
|
# 后处理:移除残留的特殊标记
|
||||||
|
for special in self.special_tokens:
|
||||||
|
text = text.replace(special, '')
|
||||||
|
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
@time_cost("STT-语音识别推理耗时")
|
||||||
|
def invoke(self, audio_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
执行语音识别推理(从文件)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path: 音频文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
识别结果列表,格式: [{"text": "识别文本"}]
|
||||||
|
"""
|
||||||
|
# 预处理音频
|
||||||
|
features, features_length = self._preprocess_audio(audio_path)
|
||||||
|
|
||||||
|
# 执行推理
|
||||||
|
text = self._inference(features, features_length)
|
||||||
|
|
||||||
|
return [{"text": text}]
|
||||||
|
|
||||||
|
def invoke_numpy(self, audio_array: np.ndarray, sample_rate: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
执行语音识别推理(从numpy数组,实时处理)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_array: 音频数据(numpy array,int16 或 float32)
|
||||||
|
sample_rate: 采样率,None时使用配置值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
识别文本
|
||||||
|
"""
|
||||||
|
# 预处理音频数组
|
||||||
|
features, features_length = self._preprocess_audio_array(audio_array, sample_rate)
|
||||||
|
|
||||||
|
# 执行推理
|
||||||
|
text = self._inference(features, features_length)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _inference(self, features: np.ndarray, features_length: np.ndarray) -> str:
|
||||||
|
"""
|
||||||
|
执行ONNX推理(内部方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: 特征数组
|
||||||
|
features_length: 特征长度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
识别文本
|
||||||
|
"""
|
||||||
|
# 准备 ONNX 模型输入
|
||||||
|
N, T, C = features.shape
|
||||||
|
|
||||||
|
# 语言ID
|
||||||
|
language = np.array([self.lang_zh], dtype=np.int32)
|
||||||
|
|
||||||
|
# 文本规范化
|
||||||
|
text_norm = np.array([self.with_itn], dtype=np.int32)
|
||||||
|
|
||||||
|
# ONNX 推理
|
||||||
|
inputs = {
|
||||||
|
'x': features.astype(np.float32),
|
||||||
|
'x_length': features_length.astype(np.int32),
|
||||||
|
'language': language,
|
||||||
|
'text_norm': text_norm
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs = self.onnx_session.run(None, inputs)
|
||||||
|
logits = outputs[0] # 形状: (N, T, vocab_size)
|
||||||
|
|
||||||
|
# CTC 解码
|
||||||
|
text = self._ctc_decode(logits, features_length)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 使用 ONNX 模型进行推理
|
||||||
|
import os
|
||||||
|
stt = STT()
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
result = stt.invoke("/home/lktx/projects/audio_controll_drone_without_llm/test/测试音频.wav")
|
||||||
|
# result = stt.invoke_numpy(np.random.rand(16000), 16000)
|
||||||
|
print(f"第{i+1}次识别结果: {result}")
|
||||||
716
voice_drone/core/text_preprocessor.py
Normal file
716
voice_drone/core/text_preprocessor.py
Normal file
@ -0,0 +1,716 @@
|
|||||||
|
"""
|
||||||
|
文本预处理模块 - 高性能实时语音转命令文本处理
|
||||||
|
|
||||||
|
本模块主要用于对语音识别输出的文本进行清洗、纠错、简繁转换、分词和参数提取,
|
||||||
|
便于后续命令意图分析和参数解析。
|
||||||
|
|
||||||
|
主要功能:
|
||||||
|
1. 文本清理:去除杂音、特殊字符、多余空格
|
||||||
|
2. 纠错:同音字纠正、常见错误修正
|
||||||
|
3. 简繁转换:统一文本格式(繁体转简体)
|
||||||
|
4. 分词:使用jieba分词,便于关键词匹配
|
||||||
|
5. 数字提取:提取距离(米)、速度(米/秒)、时间(秒)
|
||||||
|
6. 关键词识别:识别命令关键词(起飞、降落、前进等)
|
||||||
|
|
||||||
|
性能优化:
|
||||||
|
- LRU缓存常用处理结果(分词、中文数字解析、完整预处理)
|
||||||
|
- 预编译正则表达式
|
||||||
|
- 优化字符串操作(使用正则表达式批量替换)
|
||||||
|
- 延迟加载可选依赖
|
||||||
|
- 缓存关键词排序结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Dict, Optional, List, Tuple, Set
|
||||||
|
from functools import lru_cache
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
from voice_drone.core.configuration import KEYWORDS_CONFIG, SYSTEM_TEXT_PREPROCESSOR_CONFIG
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
logger = get_logger("text_preprocessor")
|
||||||
|
|
||||||
|
# 延迟加载可选依赖
|
||||||
|
try:
|
||||||
|
from opencc import OpenCC
|
||||||
|
OPENCC_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
OPENCC_AVAILABLE = False
|
||||||
|
logger.warning("opencc 未安装,将跳过简繁转换功能")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jieba
|
||||||
|
JIEBA_AVAILABLE = True
|
||||||
|
# 初始化jieba,加载词典
|
||||||
|
jieba.initialize()
|
||||||
|
except ImportError:
|
||||||
|
JIEBA_AVAILABLE = False
|
||||||
|
logger.warning("jieba 未安装,将跳过分词功能")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
PYPINYIN_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYPINYIN_AVAILABLE = False
|
||||||
|
logger.warning("pypinyin 未安装,将跳过拼音相关功能")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtractedParams:
|
||||||
|
"""提取的参数信息"""
|
||||||
|
distance: Optional[float] = None # 距离(米)
|
||||||
|
speed: Optional[float] = None # 速度(米/秒)
|
||||||
|
duration: Optional[float] = None # 时间(秒)
|
||||||
|
command_keyword: Optional[str] = None # 识别的命令关键词
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessedText:
|
||||||
|
"""预处理后的文本结果"""
|
||||||
|
cleaned_text: str # 清理后的文本
|
||||||
|
normalized_text: str # 规范化后的文本(简繁转换后)
|
||||||
|
words: List[str] # 分词结果
|
||||||
|
params: ExtractedParams # 提取的参数
|
||||||
|
original_text: str # 原始文本
|
||||||
|
|
||||||
|
|
||||||
|
class TextPreprocessor:
|
||||||
|
"""
|
||||||
|
高性能文本预处理器
|
||||||
|
|
||||||
|
针对实时语音转命令场景优化,支持:
|
||||||
|
- 文本清理和规范化
|
||||||
|
- 同音字纠错
|
||||||
|
- 简繁转换
|
||||||
|
- 分词
|
||||||
|
- 数字和单位提取
|
||||||
|
- 命令关键词识别
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
enable_traditional_to_simplified: Optional[bool] = None,
|
||||||
|
enable_segmentation: Optional[bool] = None,
|
||||||
|
enable_correction: Optional[bool] = None,
|
||||||
|
enable_number_extraction: Optional[bool] = None,
|
||||||
|
enable_keyword_detection: Optional[bool] = None,
|
||||||
|
lru_cache_size: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
初始化文本预处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_traditional_to_simplified: 是否启用繁简转换(None时从配置读取)
|
||||||
|
enable_segmentation: 是否启用分词(None时从配置读取)
|
||||||
|
enable_correction: 是否启用纠错(None时从配置读取)
|
||||||
|
enable_number_extraction: 是否启用数字提取(None时从配置读取)
|
||||||
|
enable_keyword_detection: 是否启用关键词检测(None时从配置读取)
|
||||||
|
lru_cache_size: LRU缓存大小(None时从配置读取)
|
||||||
|
"""
|
||||||
|
# 从配置读取参数(如果未提供)
|
||||||
|
config = SYSTEM_TEXT_PREPROCESSOR_CONFIG or {}
|
||||||
|
|
||||||
|
self.enable_traditional_to_simplified = (
|
||||||
|
enable_traditional_to_simplified
|
||||||
|
if enable_traditional_to_simplified is not None
|
||||||
|
else config.get("enable_traditional_to_simplified", True)
|
||||||
|
) and OPENCC_AVAILABLE
|
||||||
|
|
||||||
|
self.enable_segmentation = (
|
||||||
|
enable_segmentation
|
||||||
|
if enable_segmentation is not None
|
||||||
|
else config.get("enable_segmentation", True)
|
||||||
|
) and JIEBA_AVAILABLE
|
||||||
|
|
||||||
|
self.enable_correction = (
|
||||||
|
enable_correction
|
||||||
|
if enable_correction is not None
|
||||||
|
else config.get("enable_correction", True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.enable_number_extraction = (
|
||||||
|
enable_number_extraction
|
||||||
|
if enable_number_extraction is not None
|
||||||
|
else config.get("enable_number_extraction", True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.enable_keyword_detection = (
|
||||||
|
enable_keyword_detection
|
||||||
|
if enable_keyword_detection is not None
|
||||||
|
else config.get("enable_keyword_detection", True)
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_size = (
|
||||||
|
lru_cache_size
|
||||||
|
if lru_cache_size is not None
|
||||||
|
else config.get("lru_cache_size", 512)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化OpenCC(如果可用)
|
||||||
|
if self.enable_traditional_to_simplified:
|
||||||
|
self.opencc = OpenCC('t2s') # 繁体转简体
|
||||||
|
else:
|
||||||
|
self.opencc = None
|
||||||
|
|
||||||
|
# 加载关键词映射(命令关键词 -> 命令类型)
|
||||||
|
self._load_keyword_mapping()
|
||||||
|
|
||||||
|
# 预编译正则表达式(性能优化)
|
||||||
|
self._compile_regex_patterns()
|
||||||
|
|
||||||
|
# 加载纠错字典
|
||||||
|
self._load_correction_dict()
|
||||||
|
|
||||||
|
# 设置LRU缓存大小
|
||||||
|
self._cache_size = cache_size
|
||||||
|
|
||||||
|
# 创建缓存装饰器(用于分词、中文数字解析、完整预处理)
|
||||||
|
self._segment_text_cached = lru_cache(maxsize=cache_size)(self._segment_text_impl)
|
||||||
|
self._parse_chinese_number_cached = lru_cache(maxsize=128)(self._parse_chinese_number_impl)
|
||||||
|
self._preprocess_cached = lru_cache(maxsize=cache_size)(self._preprocess_impl)
|
||||||
|
self._preprocess_fast_cached = lru_cache(maxsize=cache_size)(self._preprocess_fast_impl)
|
||||||
|
|
||||||
|
logger.info(f"文本预处理器初始化完成")
|
||||||
|
logger.info(f" 繁简转换: {'启用' if self.enable_traditional_to_simplified else '禁用'}")
|
||||||
|
logger.info(f" 分词: {'启用' if self.enable_segmentation else '禁用'}")
|
||||||
|
logger.info(f" 纠错: {'启用' if self.enable_correction else '禁用'}")
|
||||||
|
logger.info(f" 数字提取: {'启用' if self.enable_number_extraction else '禁用'}")
|
||||||
|
logger.info(f" 关键词检测: {'启用' if self.enable_keyword_detection else '禁用'}")
|
||||||
|
|
||||||
|
def _load_keyword_mapping(self):
|
||||||
|
"""加载关键词映射表(命令关键词 -> 命令类型)"""
|
||||||
|
self.keyword_to_command: Dict[str, str] = {}
|
||||||
|
|
||||||
|
if KEYWORDS_CONFIG:
|
||||||
|
for command_type, keywords in KEYWORDS_CONFIG.items():
|
||||||
|
if isinstance(keywords, list):
|
||||||
|
for keyword in keywords:
|
||||||
|
self.keyword_to_command[keyword] = command_type
|
||||||
|
elif isinstance(keywords, str):
|
||||||
|
self.keyword_to_command[keywords] = command_type
|
||||||
|
|
||||||
|
# 预计算排序结果(按长度降序,优先匹配长关键词)
|
||||||
|
self.sorted_keywords = sorted(
|
||||||
|
self.keyword_to_command.keys(),
|
||||||
|
key=len,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"加载了 {len(self.keyword_to_command)} 个关键词映射")
|
||||||
|
|
||||||
|
def _compile_regex_patterns(self):
|
||||||
|
"""预编译正则表达式(性能优化)"""
|
||||||
|
# 清理文本:去除特殊字符、多余空格
|
||||||
|
self.pattern_clean_special = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9\s米每秒秒分小时\.]')
|
||||||
|
self.pattern_clean_spaces = re.compile(r'\s+')
|
||||||
|
|
||||||
|
# 数字提取模式
|
||||||
|
# 距离:数字 + (米|m|M|公尺)(排除速度单位)
|
||||||
|
self.pattern_distance = re.compile(
|
||||||
|
r'(\d+\.?\d*)\s*(?:米|m|M|公尺|meter|meters)(?!\s*[/每]?\s*秒)',
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 速度:数字 + (米每秒|m/s|米/秒|米秒|mps|MPS)(优先匹配)
|
||||||
|
# 支持"三米每秒"、"5米/秒"、"2.5米每秒"等格式
|
||||||
|
self.pattern_speed = re.compile(
|
||||||
|
r'(?:速度\s*[::]?\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:米\s*[/每]?\s*秒|m\s*/\s*s|mps|MPS)',
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 时间:数字 + (秒|s|S|分钟|分|min|小时|时|h|H)
|
||||||
|
# 支持"持续10秒"、"5分钟"等格式
|
||||||
|
self.pattern_duration = re.compile(
|
||||||
|
r'(?:持续\s*|持续\s*)?(\d+\.?\d*|[零一二三四五六七八九十]+)\s*(?:秒|s|S|分钟|分|min|小时|时|h|H)',
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 中文数字映射(用于识别"十米"、"五秒"等)
|
||||||
|
self.chinese_numbers = {
|
||||||
|
'零': 0, '一': 1, '二': 2, '三': 3, '四': 4, '五': 5,
|
||||||
|
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
|
||||||
|
'壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5,
|
||||||
|
'陆': 6, '柒': 7, '捌': 8, '玖': 9, '拾': 10,
|
||||||
|
'百': 100, '千': 1000, '万': 10000
|
||||||
|
}
|
||||||
|
|
||||||
|
# 中文数字模式(如"十米"、"五秒"、"二十米"、"三米每秒")
|
||||||
|
# 支持"二十"、"三十"等复合数字
|
||||||
|
self.pattern_chinese_number = re.compile(
|
||||||
|
r'([零一二三四五六七八九十壹贰叁肆伍陆柒捌玖拾百千万]+)\s*(?:米|秒|分|小时|米\s*[/每]?\s*秒)'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_correction_dict(self):
|
||||||
|
"""加载纠错字典(同音字、常见错误)并编译正则表达式"""
|
||||||
|
# 无人机控制相关的常见同音字/错误字映射(只保留实际需要纠错的)
|
||||||
|
correction_pairs = [
|
||||||
|
# 动作相关(同音字纠错)
|
||||||
|
('起非', '起飞'),
|
||||||
|
('降洛', '降落'),
|
||||||
|
('悬廷', '悬停'),
|
||||||
|
('停只', '停止'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 构建正则表达式模式(按长度降序,优先匹配长模式)
|
||||||
|
if correction_pairs:
|
||||||
|
# 按长度降序排序,优先匹配长模式
|
||||||
|
sorted_pairs = sorted(correction_pairs, key=lambda x: len(x[0]), reverse=True)
|
||||||
|
|
||||||
|
# 构建替换映射字典(用于快速查找)
|
||||||
|
self.correction_replacements = {wrong: correct for wrong, correct in sorted_pairs}
|
||||||
|
|
||||||
|
# 编译单一正则表达式
|
||||||
|
patterns = [re.escape(wrong) for wrong, _ in sorted_pairs]
|
||||||
|
self.correction_pattern = re.compile('|'.join(patterns))
|
||||||
|
else:
|
||||||
|
self.correction_pattern = None
|
||||||
|
self.correction_replacements = {}
|
||||||
|
|
||||||
|
logger.debug(f"加载了 {len(correction_pairs)} 个纠错规则")
|
||||||
|
|
||||||
|
def clean_text(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
清理文本:去除特殊字符、多余空格
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的文本
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 去除特殊字符(保留中文、英文、数字、空格、常用标点)
|
||||||
|
text = self.pattern_clean_special.sub('', text)
|
||||||
|
|
||||||
|
# 统一空格(多个空格合并为一个)
|
||||||
|
text = self.pattern_clean_spaces.sub(' ', text)
|
||||||
|
|
||||||
|
# 去除首尾空格
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def correct_text(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
纠错:同音字、常见错误修正(使用正则表达式优化)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待纠错文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
纠错后的文本
|
||||||
|
"""
|
||||||
|
if not self.enable_correction or not text or not self.correction_pattern:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# 使用正则表达式一次性替换所有模式(性能优化)
|
||||||
|
def replacer(match):
|
||||||
|
matched = match.group(0)
|
||||||
|
# 直接从字典中查找替换(O(1)查找)
|
||||||
|
return self.correction_replacements.get(matched, matched)
|
||||||
|
|
||||||
|
return self.correction_pattern.sub(replacer, text)
|
||||||
|
|
||||||
|
def traditional_to_simplified(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
繁体转简体
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待转换文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的文本
|
||||||
|
"""
|
||||||
|
if not self.enable_traditional_to_simplified or not self.opencc or not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.opencc.convert(text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"繁简转换失败: {e}")
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _segment_text_impl(self, text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
分词实现(内部方法,不带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待分词文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分词结果列表
|
||||||
|
"""
|
||||||
|
if not self.enable_segmentation or not text:
|
||||||
|
return [text] if text else []
|
||||||
|
|
||||||
|
try:
|
||||||
|
words = list(jieba.cut(text, cut_all=False))
|
||||||
|
# 过滤空字符串
|
||||||
|
words = [w.strip() for w in words if w.strip()]
|
||||||
|
return words
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"分词失败: {e}")
|
||||||
|
return [text] if text else []
|
||||||
|
|
||||||
|
def segment_text(self, text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
分词(带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待分词文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分词结果列表
|
||||||
|
"""
|
||||||
|
return self._segment_text_cached(text)
|
||||||
|
|
||||||
|
def extract_numbers(self, text: str) -> ExtractedParams:
|
||||||
|
"""
|
||||||
|
提取数字和单位(距离、速度、时间)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待提取文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedParams对象,包含提取的参数
|
||||||
|
"""
|
||||||
|
params = ExtractedParams()
|
||||||
|
|
||||||
|
if not self.enable_number_extraction or not text:
|
||||||
|
return params
|
||||||
|
|
||||||
|
# 优先提取速度(避免被误识别为距离)
|
||||||
|
speed_match = self.pattern_speed.search(text)
|
||||||
|
if speed_match:
|
||||||
|
try:
|
||||||
|
speed_str = speed_match.group(1)
|
||||||
|
# 尝试解析中文数字
|
||||||
|
if speed_str.isdigit() or '.' in speed_str:
|
||||||
|
params.speed = float(speed_str)
|
||||||
|
else:
|
||||||
|
# 中文数字(使用缓存)
|
||||||
|
chinese_speed = self._parse_chinese_number(speed_str)
|
||||||
|
if chinese_speed is not None:
|
||||||
|
params.speed = float(chinese_speed)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 提取距离(米,排除速度单位)
|
||||||
|
distance_match = self.pattern_distance.search(text)
|
||||||
|
if distance_match:
|
||||||
|
try:
|
||||||
|
params.distance = float(distance_match.group(1))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 提取时间(秒)
|
||||||
|
duration_matches = self.pattern_duration.finditer(text) # 查找所有匹配
|
||||||
|
for duration_match in duration_matches:
|
||||||
|
try:
|
||||||
|
duration_str = duration_match.group(1)
|
||||||
|
duration_unit = duration_match.group(2).lower() if len(duration_match.groups()) > 1 else '秒'
|
||||||
|
|
||||||
|
# 解析数字(支持中文数字)
|
||||||
|
if duration_str.isdigit() or '.' in duration_str:
|
||||||
|
duration_value = float(duration_str)
|
||||||
|
else:
|
||||||
|
# 中文数字
|
||||||
|
chinese_duration = self._parse_chinese_number(duration_str)
|
||||||
|
if chinese_duration is None:
|
||||||
|
continue
|
||||||
|
duration_value = float(chinese_duration)
|
||||||
|
|
||||||
|
# 转换为秒
|
||||||
|
if '分' in duration_unit or 'min' in duration_unit:
|
||||||
|
params.duration = duration_value * 60
|
||||||
|
break # 取第一个匹配
|
||||||
|
elif '小时' in duration_unit or 'h' in duration_unit:
|
||||||
|
params.duration = duration_value * 3600
|
||||||
|
break
|
||||||
|
else: # 秒
|
||||||
|
params.duration = duration_value
|
||||||
|
break
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试提取中文数字(如"十米"、"五秒"、"二十米"、"三米每秒")
|
||||||
|
chinese_matches = self.pattern_chinese_number.finditer(text)
|
||||||
|
for chinese_match in chinese_matches:
|
||||||
|
try:
|
||||||
|
chinese_num_str = chinese_match.group(1)
|
||||||
|
full_match = chinese_match.group(0)
|
||||||
|
|
||||||
|
# 解析中文数字(使用缓存)
|
||||||
|
num_value = self._parse_chinese_number(chinese_num_str)
|
||||||
|
|
||||||
|
if num_value is not None:
|
||||||
|
# 判断单位类型
|
||||||
|
if '米每秒' in full_match or '米/秒' in full_match or '米每' in full_match:
|
||||||
|
# 速度单位
|
||||||
|
if params.speed is None:
|
||||||
|
params.speed = float(num_value)
|
||||||
|
elif '米' in full_match and '秒' not in full_match:
|
||||||
|
# 距离单位(不包含"秒")
|
||||||
|
if params.distance is None:
|
||||||
|
params.distance = float(num_value)
|
||||||
|
elif '秒' in full_match and '米' not in full_match:
|
||||||
|
# 时间单位(不包含"米")
|
||||||
|
if params.duration is None:
|
||||||
|
params.duration = float(num_value)
|
||||||
|
elif '分' in full_match and '米' not in full_match:
|
||||||
|
# 时间单位(分钟)
|
||||||
|
if params.duration is None:
|
||||||
|
params.duration = float(num_value) * 60
|
||||||
|
except (ValueError, IndexError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _parse_chinese_number_impl(self, chinese_num: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
解析中文数字实现(内部方法,不带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chinese_num: 中文数字字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对应的阿拉伯数字,解析失败返回None
|
||||||
|
"""
|
||||||
|
if not chinese_num:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 单个数字
|
||||||
|
if chinese_num in self.chinese_numbers:
|
||||||
|
return self.chinese_numbers[chinese_num]
|
||||||
|
|
||||||
|
# "十" -> 10
|
||||||
|
if chinese_num == '十' or chinese_num == '拾':
|
||||||
|
return 10
|
||||||
|
|
||||||
|
# "十一" -> 11, "十二" -> 12, ...
|
||||||
|
if chinese_num.startswith('十') or chinese_num.startswith('拾'):
|
||||||
|
rest = chinese_num[1:]
|
||||||
|
if rest in self.chinese_numbers:
|
||||||
|
return 10 + self.chinese_numbers[rest]
|
||||||
|
|
||||||
|
# "二十" -> 20, "三十" -> 30, ...
|
||||||
|
if chinese_num.endswith('十') or chinese_num.endswith('拾'):
|
||||||
|
prefix = chinese_num[:-1]
|
||||||
|
if prefix in self.chinese_numbers:
|
||||||
|
return self.chinese_numbers[prefix] * 10
|
||||||
|
|
||||||
|
# "二十五" -> 25, "三十五" -> 35, ...
|
||||||
|
if '十' in chinese_num or '拾' in chinese_num:
|
||||||
|
parts = chinese_num.replace('拾', '十').split('十')
|
||||||
|
if len(parts) == 2:
|
||||||
|
tens_part = parts[0] if parts[0] else '一' # "十五" -> parts[0]为空
|
||||||
|
ones_part = parts[1] if parts[1] else ''
|
||||||
|
|
||||||
|
tens = self.chinese_numbers.get(tens_part, 1) if tens_part else 1
|
||||||
|
ones = self.chinese_numbers.get(ones_part, 0) if ones_part else 0
|
||||||
|
|
||||||
|
return tens * 10 + ones
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_chinese_number(self, chinese_num: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
解析中文数字(支持"十"、"二十"、"三十"、"五"等,带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chinese_num: 中文数字字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对应的阿拉伯数字,解析失败返回None
|
||||||
|
"""
|
||||||
|
return self._parse_chinese_number_cached(chinese_num)
|
||||||
|
|
||||||
|
def detect_keyword(self, text: str, words: Optional[List[str]] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
检测命令关键词(使用缓存的排序结果)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待检测文本
|
||||||
|
words: 分词结果(如果已分词,可传入以提高性能)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检测到的命令类型(如"takeoff"、"forward"等),未检测到返回None
|
||||||
|
"""
|
||||||
|
if not self.enable_keyword_detection or not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果已分词,优先使用分词结果匹配
|
||||||
|
if words:
|
||||||
|
for word in words:
|
||||||
|
if word in self.keyword_to_command:
|
||||||
|
return self.keyword_to_command[word]
|
||||||
|
|
||||||
|
# 使用缓存的排序结果(按长度降序,优先匹配长关键词)
|
||||||
|
for keyword in self.sorted_keywords:
|
||||||
|
if keyword in text:
|
||||||
|
return self.keyword_to_command[keyword]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def preprocess(self, text: str) -> PreprocessedText:
|
||||||
|
"""
|
||||||
|
完整的文本预处理流程(带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PreprocessedText对象,包含所有预处理结果
|
||||||
|
"""
|
||||||
|
return self._preprocess_cached(text)
|
||||||
|
|
||||||
|
def _preprocess_impl(self, text: str) -> PreprocessedText:
|
||||||
|
"""
|
||||||
|
完整的文本预处理流程实现(内部方法,不带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PreprocessedText对象,包含所有预处理结果
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return PreprocessedText(
|
||||||
|
cleaned_text="",
|
||||||
|
normalized_text="",
|
||||||
|
words=[],
|
||||||
|
params=ExtractedParams(),
|
||||||
|
original_text=text
|
||||||
|
)
|
||||||
|
|
||||||
|
original_text = text
|
||||||
|
|
||||||
|
# 1. 清理文本
|
||||||
|
cleaned_text = self.clean_text(text)
|
||||||
|
|
||||||
|
# 2. 纠错
|
||||||
|
corrected_text = self.correct_text(cleaned_text)
|
||||||
|
|
||||||
|
# 3. 繁简转换
|
||||||
|
normalized_text = self.traditional_to_simplified(corrected_text)
|
||||||
|
|
||||||
|
# 4. 分词
|
||||||
|
words = self.segment_text(normalized_text)
|
||||||
|
|
||||||
|
# 5. 提取数字和单位
|
||||||
|
params = self.extract_numbers(normalized_text)
|
||||||
|
|
||||||
|
# 6. 检测关键词
|
||||||
|
command_keyword = self.detect_keyword(normalized_text, words)
|
||||||
|
params.command_keyword = command_keyword
|
||||||
|
|
||||||
|
return PreprocessedText(
|
||||||
|
cleaned_text=cleaned_text,
|
||||||
|
normalized_text=normalized_text,
|
||||||
|
words=words,
|
||||||
|
params=params,
|
||||||
|
original_text=original_text
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess_fast(self, text: str) -> Tuple[str, ExtractedParams]:
|
||||||
|
"""
|
||||||
|
快速预处理(仅返回规范化文本和参数,不进行分词,带缓存)
|
||||||
|
|
||||||
|
适用于实时场景,性能优先
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(规范化文本, 提取的参数)
|
||||||
|
"""
|
||||||
|
return self._preprocess_fast_cached(text)
|
||||||
|
|
||||||
|
def _preprocess_fast_impl(self, text: str) -> Tuple[str, ExtractedParams]:
|
||||||
|
"""
|
||||||
|
快速预处理实现(内部方法,不带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(规范化文本, 提取的参数)
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return "", ExtractedParams()
|
||||||
|
|
||||||
|
# 1. 清理
|
||||||
|
cleaned = self.clean_text(text)
|
||||||
|
|
||||||
|
# 2. 纠错
|
||||||
|
corrected = self.correct_text(cleaned)
|
||||||
|
|
||||||
|
# 3. 繁简转换
|
||||||
|
normalized = self.traditional_to_simplified(corrected)
|
||||||
|
|
||||||
|
# 4. 提取参数(不进行分词,提高性能)
|
||||||
|
params = self.extract_numbers(normalized)
|
||||||
|
|
||||||
|
# 5. 检测关键词(在完整文本中搜索)
|
||||||
|
params.command_keyword = self.detect_keyword(normalized, words=None)
|
||||||
|
|
||||||
|
return normalized, params
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例(可选,用于提高性能)
|
||||||
|
_global_preprocessor: Optional[TextPreprocessor] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_preprocessor() -> TextPreprocessor:
|
||||||
|
"""获取全局预处理器实例(单例模式)"""
|
||||||
|
global _global_preprocessor
|
||||||
|
if _global_preprocessor is None:
|
||||||
|
_global_preprocessor = TextPreprocessor()
|
||||||
|
return _global_preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from command import Command
|
||||||
|
|
||||||
|
# 测试代码
|
||||||
|
preprocessor = TextPreprocessor()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
"现在起飞往前飞,飞10米,速度为5米每秒",
|
||||||
|
"向前飞二十米,速度三米每秒",
|
||||||
|
"立刻降落",
|
||||||
|
"悬停五秒",
|
||||||
|
"向右飛十米", # 繁体测试
|
||||||
|
"往左飛,速度2.5米/秒,持續10秒",
|
||||||
|
]
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("文本预处理器测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for i, test_text in enumerate(test_cases, 1):
|
||||||
|
print(f"\n测试 {i}: {test_text}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# 完整预处理
|
||||||
|
result = preprocessor.preprocess(test_text)
|
||||||
|
print(f"原始文本: {result.original_text}")
|
||||||
|
print(f"清理后: {result.cleaned_text}")
|
||||||
|
print(f"规范化: {result.normalized_text}")
|
||||||
|
print(f"分词: {result.words}")
|
||||||
|
print(f"提取参数:")
|
||||||
|
print(f" 距离: {result.params.distance} 米")
|
||||||
|
print(f" 速度: {result.params.speed} 米/秒")
|
||||||
|
print(f" 时间: {result.params.duration} 秒")
|
||||||
|
print(f" 命令关键词: {result.params.command_keyword}")
|
||||||
|
|
||||||
|
|
||||||
|
command = Command.create(result.params.command_keyword, 1, result.params.distance, result.params.speed, result.params.duration)
|
||||||
|
print(f"命令: {command.to_dict()}")
|
||||||
|
|
||||||
|
# 快速预处理
|
||||||
|
fast_text, fast_params = preprocessor.preprocess_fast(test_text)
|
||||||
|
print(f"\n快速预处理结果:")
|
||||||
|
print(f" 规范化文本: {fast_text}")
|
||||||
|
print(f" 命令关键词: {fast_params.command_keyword}")
|
||||||
695
voice_drone/core/tts.py
Normal file
695
voice_drone/core/tts.py
Normal file
@ -0,0 +1,695 @@
|
|||||||
|
"""
|
||||||
|
TTS(Text-to-Speech)模块 - 基于 Kokoro ONNX 的中文实时合成
|
||||||
|
|
||||||
|
使用 Kokoro-82M-v1.1-zh-ONNX 模型进行文本转语音合成:
|
||||||
|
1. 文本 -> (可选)使用 misaki[zh] 做 G2P,得到音素串
|
||||||
|
2. 音素字符 -> 根据 tokenizer vocab 映射为 token id 序列
|
||||||
|
3. 通过 ONNX Runtime 推理生成 24kHz 单声道语音
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 主要依赖: onnxruntime + numpy
|
||||||
|
- 如果已安装 misaki[zh] (推荐),效果更好:
|
||||||
|
pip install "misaki[zh]" cn2an pypinyin jieba
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
# 仅保留 ERROR,避免加载 Kokoro 时大量 ConstantFolding/Reciprocal 警告刷屏(不影响推理结果)
|
||||||
|
try:
|
||||||
|
ort.set_default_logger_severity(3)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from voice_drone.core.configuration import SYSTEM_TTS_CONFIG
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
# voice_drone/core/tts.py -> voice_drone_assistant 根目录
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
logger = get_logger("tts.kokoro_onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def _tts_model_dir_candidates(rel: Path) -> List[Path]:
|
||||||
|
if rel.is_absolute():
|
||||||
|
return [rel]
|
||||||
|
out: List[Path] = [_PROJECT_ROOT / rel]
|
||||||
|
if rel.parts and rel.parts[0] == "models":
|
||||||
|
out.append(_PROJECT_ROOT.parent / "src" / rel)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_kokoro_model_dir(raw: str | Path) -> Path:
|
||||||
|
"""含 tokenizer.json 的目录;支持子工程 models/ 缺失时回退到上级仓库 src/models/。"""
|
||||||
|
p = Path(raw)
|
||||||
|
for c in _tts_model_dir_candidates(p):
|
||||||
|
if (c / "tokenizer.json").is_file():
|
||||||
|
return c.resolve()
|
||||||
|
for c in _tts_model_dir_candidates(p):
|
||||||
|
if c.is_dir():
|
||||||
|
logger.warning(
|
||||||
|
"Kokoro 目录存在但未找到 tokenizer.json: %s(将仍使用该路径,后续可能报错)",
|
||||||
|
c,
|
||||||
|
)
|
||||||
|
return c.resolve()
|
||||||
|
return (_PROJECT_ROOT / p).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
class KokoroOnnxTTS:
|
||||||
|
"""
|
||||||
|
Kokoro 中文 ONNX 文本转语音封装
|
||||||
|
|
||||||
|
基本用法:
|
||||||
|
tts = KokoroOnnxTTS()
|
||||||
|
audio, sr = tts.synthesize("你好,世界")
|
||||||
|
|
||||||
|
返回:
|
||||||
|
audio: np.ndarray[float32] 形状为 (N,), 范围约 [-1, 1]
|
||||||
|
sr: int 采样率(默认 24000)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[dict] = None) -> None:
|
||||||
|
# 读取系统 TTS 配置
|
||||||
|
self.config = config or SYSTEM_TTS_CONFIG or {}
|
||||||
|
|
||||||
|
# 模型根目录(包含 onnx/、tokenizer.json、voices/)
|
||||||
|
_raw_dir = self.config.get(
|
||||||
|
"model_dir", "models/Kokoro-82M-v1.1-zh-ONNX"
|
||||||
|
)
|
||||||
|
model_dir = _resolve_kokoro_model_dir(_raw_dir)
|
||||||
|
self.model_dir = model_dir
|
||||||
|
|
||||||
|
# ONNX 模型文件名(位于 model_dir/onnx 下;若 onnx/ 下没有可改配置为根目录文件名)
|
||||||
|
self.model_name = self.config.get("model_name", "model_q4.onnx")
|
||||||
|
self.onnx_path = model_dir / "onnx" / self.model_name
|
||||||
|
if not self.onnx_path.is_file():
|
||||||
|
alt = model_dir / self.model_name
|
||||||
|
if alt.is_file():
|
||||||
|
self.onnx_path = alt
|
||||||
|
|
||||||
|
# 语音风格(voices 子目录下的 *.bin, 这里不含扩展名)
|
||||||
|
self.voice_name = self.config.get("voice", "zf_001")
|
||||||
|
self.voice_path = model_dir / "voices" / f"{self.voice_name}.bin"
|
||||||
|
|
||||||
|
# 语速与输出采样率
|
||||||
|
self.speed = float(self.config.get("speed", 1.0))
|
||||||
|
self.sample_rate = int(self.config.get("sample_rate", 24000))
|
||||||
|
|
||||||
|
# tokenizer.json 路径(本地随 ONNX 模型一起提供)
|
||||||
|
self.tokenizer_path = model_dir / "tokenizer.json"
|
||||||
|
|
||||||
|
# 初始化组件
|
||||||
|
self._session: Optional[ort.InferenceSession] = None
|
||||||
|
self._vocab: Optional[dict] = None
|
||||||
|
self._voices: Optional[np.ndarray] = None
|
||||||
|
self._g2p = None # misaki[zh] G2P, 如不可用则退化为直接使用原始文本
|
||||||
|
|
||||||
|
self._load_all()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# 对外主接口
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
def synthesize(self, text: str) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
文本转语音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本(推荐为简体中文)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(audio, sample_rate)
|
||||||
|
"""
|
||||||
|
if not text or not text.strip():
|
||||||
|
raise ValueError("TTS 输入文本不能为空")
|
||||||
|
|
||||||
|
phonemes = self._text_to_phonemes(text)
|
||||||
|
token_ids = self._phonemes_to_token_ids(phonemes)
|
||||||
|
|
||||||
|
if len(token_ids) == 0:
|
||||||
|
raise ValueError(f"TTS: 文本在当前 vocab 下无法映射到任何 token, text={text!r}")
|
||||||
|
|
||||||
|
# 按 Kokoro-ONNX 官方示例约定:
|
||||||
|
# - token 序列长度 <= 510
|
||||||
|
# - 前后各添加 pad token 0
|
||||||
|
if len(token_ids) > 510:
|
||||||
|
logger.warning(f"TTS: token 长度 {len(token_ids)} > 510, 将被截断为 510")
|
||||||
|
token_ids = token_ids[:510]
|
||||||
|
|
||||||
|
tokens = np.array([[0, *token_ids, 0]], dtype=np.int64) # shape: (1, <=512)
|
||||||
|
|
||||||
|
# 根据 token 数量选择 style 向量
|
||||||
|
assert self._voices is not None, "TTS: voices 尚未初始化"
|
||||||
|
voices = self._voices # shape: (N, 1, 256)
|
||||||
|
idx = min(len(token_ids), voices.shape[0] - 1)
|
||||||
|
style = voices[idx] # shape: (1, 256)
|
||||||
|
|
||||||
|
speed = np.array([self.speed], dtype=np.float32)
|
||||||
|
|
||||||
|
# ONNX 输入名约定: input_ids, style, speed
|
||||||
|
assert self._session is not None, "TTS: ONNX Session 尚未初始化"
|
||||||
|
session = self._session
|
||||||
|
audio = session.run(
|
||||||
|
None,
|
||||||
|
{
|
||||||
|
"input_ids": tokens,
|
||||||
|
"style": style,
|
||||||
|
"speed": speed,
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# 兼容不同导出形状:
|
||||||
|
# - 标准 Kokoro ONNX: (1, N)
|
||||||
|
# - 也有可能是 (N,)
|
||||||
|
audio = audio.astype(np.float32)
|
||||||
|
if audio.ndim == 2 and audio.shape[0] == 1:
|
||||||
|
audio = audio[0]
|
||||||
|
elif audio.ndim > 2:
|
||||||
|
# 极端情况: 压缩多余维度
|
||||||
|
audio = np.squeeze(audio)
|
||||||
|
|
||||||
|
return audio, self.sample_rate
|
||||||
|
|
||||||
|
def synthesize_to_file(self, text: str, wav_path: str) -> str:
|
||||||
|
"""
|
||||||
|
文本合成并保存为 wav 文件(16-bit PCM)
|
||||||
|
需要依赖 scipy, 可选:
|
||||||
|
pip install scipy
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import scipy.io.wavfile as wavfile # type: ignore
|
||||||
|
except Exception as e: # pragma: no cover - 仅在未安装时触发
|
||||||
|
raise RuntimeError("保存到 wav 需要安装 scipy, 请先执行: pip install scipy") from e
|
||||||
|
|
||||||
|
audio, sr = self.synthesize(text)
|
||||||
|
# 简单归一化并转为 int16
|
||||||
|
max_val = float(np.max(np.abs(audio)) or 1.0)
|
||||||
|
audio_int16 = np.clip(audio / max_val, -1.0, 1.0)
|
||||||
|
audio_int16 = (audio_int16 * 32767.0).astype(np.int16)
|
||||||
|
|
||||||
|
# 某些 SciPy 版本对一维/零维数组支持不统一,这里显式加上通道维度
|
||||||
|
if audio_int16.ndim == 0:
|
||||||
|
audio_to_save = audio_int16.reshape(-1, 1) # 标量 -> (1,1)
|
||||||
|
elif audio_int16.ndim == 1:
|
||||||
|
audio_to_save = audio_int16.reshape(-1, 1) # (N,) -> (N,1) 单声道
|
||||||
|
else:
|
||||||
|
audio_to_save = audio_int16
|
||||||
|
|
||||||
|
wavfile.write(wav_path, sr, audio_to_save)
|
||||||
|
return wav_path
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# 内部初始化
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
def _load_all(self) -> None:
|
||||||
|
self._load_tokenizer_vocab()
|
||||||
|
self._load_voices()
|
||||||
|
self._load_onnx_session()
|
||||||
|
self._init_g2p()
|
||||||
|
|
||||||
|
def _load_tokenizer_vocab(self) -> None:
|
||||||
|
"""
|
||||||
|
从本地 tokenizer.json 载入 vocab 映射:
|
||||||
|
token(str) -> id(int)
|
||||||
|
"""
|
||||||
|
if not self.tokenizer_path.exists():
|
||||||
|
raise FileNotFoundError(f"TTS: 未找到 tokenizer.json: {self.tokenizer_path}")
|
||||||
|
|
||||||
|
with open(self.tokenizer_path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
model = data.get("model") or {}
|
||||||
|
vocab = model.get("vocab")
|
||||||
|
if not isinstance(vocab, dict):
|
||||||
|
raise ValueError("TTS: tokenizer.json 格式不正确, 未找到 model.vocab 字段")
|
||||||
|
|
||||||
|
# 保存为: 字符 -> id
|
||||||
|
self._vocab = {k: int(v) for k, v in vocab.items()}
|
||||||
|
logger.info(f"TTS: tokenizer vocab 加载完成, 词表大小: {len(self._vocab)}")
|
||||||
|
|
||||||
|
def _load_voices(self) -> None:
|
||||||
|
"""
|
||||||
|
载入语音风格向量(voices/*.bin)
|
||||||
|
"""
|
||||||
|
if not self.voice_path.exists():
|
||||||
|
raise FileNotFoundError(f"TTS: 未找到语音风格文件: {self.voice_path}")
|
||||||
|
|
||||||
|
voices = np.fromfile(self.voice_path, dtype=np.float32)
|
||||||
|
try:
|
||||||
|
voices = voices.reshape(-1, 1, 256)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"TTS: 语音风格文件形状不符合预期, 无法 reshape 为 (-1,1,256): {self.voice_path}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self._voices = voices
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 语音风格文件加载完成: {self.voice_name}, 可用 style 数量: {voices.shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_onnx_session(self) -> None:
|
||||||
|
"""
|
||||||
|
创建 ONNX Runtime 推理会话
|
||||||
|
"""
|
||||||
|
if not self.onnx_path.exists():
|
||||||
|
raise FileNotFoundError(f"TTS: 未找到 ONNX 模型文件: {self.onnx_path}")
|
||||||
|
|
||||||
|
sess_options = ort.SessionOptions()
|
||||||
|
# 启用所有图优化
|
||||||
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
# RK3588 等多核 CPU:可用环境变量固定 ORT 线程,避免过小/过大(默认 0 表示交给 ORT 自动)
|
||||||
|
_ti = os.environ.get("ROCKET_TTS_ORT_INTRA_OP_THREADS", "").strip()
|
||||||
|
if _ti.isdigit() and int(_ti) > 0:
|
||||||
|
sess_options.intra_op_num_threads = int(_ti)
|
||||||
|
_te = os.environ.get("ROCKET_TTS_ORT_INTER_OP_THREADS", "").strip()
|
||||||
|
if _te.isdigit() and int(_te) > 0:
|
||||||
|
sess_options.inter_op_num_threads = int(_te)
|
||||||
|
|
||||||
|
# 简单的 CPU 推理(如需 GPU, 可在此扩展 providers)
|
||||||
|
self._session = ort.InferenceSession(
|
||||||
|
str(self.onnx_path),
|
||||||
|
sess_options=sess_options,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"TTS: Kokoro ONNX 模型加载完成: {self.onnx_path}")
|
||||||
|
|
||||||
|
def _init_g2p(self) -> None:
|
||||||
|
"""
|
||||||
|
初始化中文 G2P (基于 misaki[zh])。
|
||||||
|
如果环境中未安装 misaki, 则退化为直接使用原始文本字符做映射。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from misaki import zh # type: ignore
|
||||||
|
|
||||||
|
# 兼容不同版本的 misaki:
|
||||||
|
# - 新版: ZHG2P(version=...) 可用
|
||||||
|
# - 旧版: ZHG2P() 不接受参数
|
||||||
|
try:
|
||||||
|
self._g2p = zh.ZHG2P(version="1.1") # type: ignore[call-arg]
|
||||||
|
except TypeError:
|
||||||
|
self._g2p = zh.ZHG2P() # type: ignore[call-arg]
|
||||||
|
|
||||||
|
logger.info("TTS: 已启用 misaki[zh] G2P, 将使用音素级别映射")
|
||||||
|
except Exception as e:
|
||||||
|
self._g2p = None
|
||||||
|
logger.warning(
|
||||||
|
"TTS: 未安装或无法导入 misaki[zh], 将直接基于原始文本字符做 token 映射, "
|
||||||
|
"合成效果可能较差。建议执行: pip install \"misaki[zh]\" cn2an pypinyin jieba"
|
||||||
|
)
|
||||||
|
logger.debug(f"TTS: G2P 初始化失败原因: {e!r}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# 文本 -> 音素 / token
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
def _text_to_phonemes(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
文本 -> 音素串
|
||||||
|
|
||||||
|
- 若 misaki[zh] 可用, 则使用 ZHG2P(version='1.1') 得到音素串
|
||||||
|
- 否则, 直接返回原始文本(后续按字符映射)
|
||||||
|
"""
|
||||||
|
if self._g2p is None:
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
# 兼容不同版本的 misaki:
|
||||||
|
# - 有的返回 (phonemes, tokens)
|
||||||
|
# - 有的只返回 phonemes 字符串
|
||||||
|
result = self._g2p(text)
|
||||||
|
if isinstance(result, tuple) or isinstance(result, list):
|
||||||
|
ps = result[0]
|
||||||
|
else:
|
||||||
|
ps = result
|
||||||
|
|
||||||
|
if not ps:
|
||||||
|
# 回退: 如果 G2P 返回空, 使用原始文本
|
||||||
|
logger.warning("TTS: G2P 结果为空, 回退为原始文本")
|
||||||
|
return text.strip()
|
||||||
|
return ps
|
||||||
|
|
||||||
|
def _phonemes_to_token_ids(self, phonemes: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
将音素串映射为 token id 序列
|
||||||
|
|
||||||
|
直接按字符级别查表:
|
||||||
|
- 每个字符在 vocab 中有唯一 id
|
||||||
|
- 空格本身也是一个 token (id=16)
|
||||||
|
"""
|
||||||
|
assert self._vocab is not None, "TTS: vocab 尚未初始化"
|
||||||
|
vocab = self._vocab
|
||||||
|
|
||||||
|
token_ids: List[int] = []
|
||||||
|
unknown_chars = set()
|
||||||
|
|
||||||
|
for ch in phonemes:
|
||||||
|
if ch == "\n":
|
||||||
|
continue
|
||||||
|
tid = vocab.get(ch)
|
||||||
|
if tid is None:
|
||||||
|
unknown_chars.add(ch)
|
||||||
|
continue
|
||||||
|
token_ids.append(int(tid))
|
||||||
|
|
||||||
|
if unknown_chars:
|
||||||
|
logger.debug(f"TTS: 存在无法映射到 vocab 的字符: {unknown_chars}")
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_output_device_id(raw: object) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
将配置中的 output_device 解析为 sounddevice 设备索引。
|
||||||
|
None / 空:返回 None,表示使用 sd 默认输出。
|
||||||
|
"""
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, bool):
|
||||||
|
return None
|
||||||
|
if isinstance(raw, int):
|
||||||
|
return raw if raw >= 0 else None
|
||||||
|
s = str(raw).strip()
|
||||||
|
if not s or s.lower() in ("null", "none", "default", ""):
|
||||||
|
return None
|
||||||
|
if s.isdigit():
|
||||||
|
return int(s)
|
||||||
|
needle = s.lower()
|
||||||
|
devices = sd.query_devices()
|
||||||
|
matches: List[int] = []
|
||||||
|
for i, d in enumerate(devices):
|
||||||
|
if int(d.get("max_output_channels", 0) or 0) <= 0:
|
||||||
|
continue
|
||||||
|
name = (d.get("name") or "").lower()
|
||||||
|
if needle in name:
|
||||||
|
matches.append(i)
|
||||||
|
if not matches:
|
||||||
|
logger.warning(
|
||||||
|
f"TTS: 未找到名称包含 {s!r} 的输出设备,将使用系统默认输出。"
|
||||||
|
"请检查 system.yaml 中 tts.output_device 或查看启动日志中的设备列表。"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if len(matches) > 1:
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 名称 {s!r} 匹配到多个输出设备索引 {matches},使用第一个 {matches[0]}"
|
||||||
|
)
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
|
|
||||||
|
_playback_dev_cache: Optional[int] = None
|
||||||
|
_playback_dev_cache_key: Optional[str] = None
|
||||||
|
_playback_dev_cache_ready: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_playback_output_device_id() -> Optional[int]:
|
||||||
|
"""从 SYSTEM_TTS_CONFIG 解析并缓存播放设备索引(None=默认输出)。"""
|
||||||
|
global _playback_dev_cache, _playback_dev_cache_key, _playback_dev_cache_ready
|
||||||
|
|
||||||
|
cfg = SYSTEM_TTS_CONFIG or {}
|
||||||
|
raw = cfg.get("output_device")
|
||||||
|
key = repr(raw)
|
||||||
|
if _playback_dev_cache_ready and _playback_dev_cache_key == key:
|
||||||
|
return _playback_dev_cache
|
||||||
|
dev_id = _resolve_output_device_id(raw)
|
||||||
|
_playback_dev_cache = dev_id
|
||||||
|
_playback_dev_cache_key = key
|
||||||
|
_playback_dev_cache_ready = True
|
||||||
|
if dev_id is not None:
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
info = sd.query_devices(dev_id)
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 播放将使用输出设备 index={dev_id} name={info.get('name', '?')!r}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("TTS: 播放使用系统默认输出设备(未指定或无法匹配 tts.output_device)")
|
||||||
|
return dev_id
|
||||||
|
|
||||||
|
|
||||||
|
def _sounddevice_default_output_index():
|
||||||
|
"""sounddevice 0.5+ 的 default.device 可能是 _InputOutputPair,需取 [1] 为输出索引。"""
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
default = sd.default.device
|
||||||
|
if isinstance(default, (list, tuple)):
|
||||||
|
return int(default[1])
|
||||||
|
if hasattr(default, "__getitem__"):
|
||||||
|
try:
|
||||||
|
return int(default[1])
|
||||||
|
except (IndexError, TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return int(default)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def log_sounddevice_output_devices() -> None:
|
||||||
|
"""列出所有可用输出设备及当前默认输出,便于配置 tts.output_device。"""
|
||||||
|
try:
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
out_idx = _sounddevice_default_output_index()
|
||||||
|
logger.info("sounddevice 输出设备列表(用于配置 tts.output_device 索引或名称子串):")
|
||||||
|
for i, d in enumerate(sd.query_devices()):
|
||||||
|
if int(d.get("max_output_channels", 0) or 0) <= 0:
|
||||||
|
continue
|
||||||
|
mark = " <- 当前默认输出" if out_idx is not None and int(out_idx) == i else ""
|
||||||
|
logger.info(f" [{i}] {d.get('name', '?')}{mark}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法枚举 sounddevice 输出设备: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _resample_playback_audio(audio: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
|
||||||
|
"""将波形重采样到设备采样率;优先 librosa(kaiser_best),失败则回退 scipy 多相。"""
|
||||||
|
from math import gcd
|
||||||
|
|
||||||
|
from scipy.signal import resample # type: ignore
|
||||||
|
from scipy.signal import resample_poly # type: ignore
|
||||||
|
|
||||||
|
if abs(sr_in - sr_out) < 1:
|
||||||
|
return np.asarray(audio, dtype=np.float32)
|
||||||
|
try:
|
||||||
|
import librosa # type: ignore
|
||||||
|
|
||||||
|
return librosa.resample(
|
||||||
|
np.asarray(audio, dtype=np.float32),
|
||||||
|
orig_sr=sr_in,
|
||||||
|
target_sr=sr_out,
|
||||||
|
res_type="kaiser_best",
|
||||||
|
).astype(np.float32)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"TTS: librosa 重采样不可用,使用多相重采样: {e!r}")
|
||||||
|
try:
|
||||||
|
g = gcd(int(sr_in), int(sr_out))
|
||||||
|
if g > 0:
|
||||||
|
up = int(sr_out) // g
|
||||||
|
down = int(sr_in) // g
|
||||||
|
return resample_poly(audio, up, down).astype(np.float32)
|
||||||
|
except Exception as e2:
|
||||||
|
logger.debug(f"TTS: resample_poly 失败,回退 FFT resample: {e2!r}")
|
||||||
|
num = max(1, int(len(audio) * float(sr_out) / float(sr_in)))
|
||||||
|
return resample(audio, num).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _fade_playback_edges(audio: np.ndarray, sample_rate: int, fade_ms: float) -> np.ndarray:
|
||||||
|
"""极短线性淡入淡出,减轻扬声器/驱动在段首段尾的爆音与杂音感。"""
|
||||||
|
if fade_ms <= 0 or audio.size < 16:
|
||||||
|
return audio
|
||||||
|
n = int(float(sample_rate) * fade_ms / 1000.0)
|
||||||
|
n = min(n, len(audio) // 4)
|
||||||
|
if n <= 0:
|
||||||
|
return audio
|
||||||
|
out = np.asarray(audio, dtype=np.float32, order="C").copy()
|
||||||
|
fade = np.linspace(0.0, 1.0, n, dtype=np.float32)
|
||||||
|
out[:n] *= fade
|
||||||
|
out[-n:] *= fade[::-1]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def play_tts_audio(
|
||||||
|
audio: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
*,
|
||||||
|
output_device: Optional[object] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
使用 sounddevice 播放单声道 float32 音频(阻塞至播放结束)。
|
||||||
|
|
||||||
|
在 Windows 上 PortAudio/sounddevice 从非主线程调用时经常出现「无声音、无报错」,
|
||||||
|
因此本项目中应答播报应在主线程(采集循环所在线程)调用本函数。
|
||||||
|
|
||||||
|
另:多数 Realtek/WASAPI 设备对 24000Hz 播放会「完全无声」且不报错,需重采样到设备
|
||||||
|
default_samplerate(常见 48000/44100),并用 OutputStream 写出。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_device: 若指定,覆盖 system.yaml 的 tts.output_device(设备索引或名称子串)。
|
||||||
|
"""
|
||||||
|
import sounddevice as sd # type: ignore
|
||||||
|
|
||||||
|
cfg = SYSTEM_TTS_CONFIG or {}
|
||||||
|
force_native = bool(cfg.get("playback_resample_to_device_native", True))
|
||||||
|
do_normalize = bool(cfg.get("playback_peak_normalize", True))
|
||||||
|
gain = float(cfg.get("playback_gain", 1.0))
|
||||||
|
if gain <= 0:
|
||||||
|
gain = 1.0
|
||||||
|
fade_ms = float(cfg.get("playback_edge_fade_ms", 8.0))
|
||||||
|
latency = cfg.get("playback_output_latency", "low")
|
||||||
|
if latency not in ("low", "medium", "high"):
|
||||||
|
latency = "low"
|
||||||
|
|
||||||
|
audio = np.asarray(audio, dtype=np.float32).squeeze()
|
||||||
|
if audio.ndim > 1:
|
||||||
|
audio = np.squeeze(audio)
|
||||||
|
if audio.size == 0:
|
||||||
|
logger.warning("TTS: 播放跳过,音频长度为 0")
|
||||||
|
return
|
||||||
|
|
||||||
|
if output_device is not None:
|
||||||
|
dev = _resolve_output_device_id(output_device)
|
||||||
|
else:
|
||||||
|
dev = get_playback_output_device_id()
|
||||||
|
if dev is None:
|
||||||
|
dev = _sounddevice_default_output_index()
|
||||||
|
if dev is None:
|
||||||
|
logger.warning("TTS: 无法解析输出设备索引,使用 sounddevice 默认输出")
|
||||||
|
else:
|
||||||
|
dev = int(dev)
|
||||||
|
|
||||||
|
info = sd.query_devices(dev) if dev is not None else sd.query_devices(kind="output")
|
||||||
|
native_sr = int(float(info.get("default_samplerate", 48000)))
|
||||||
|
sr_out = int(sample_rate)
|
||||||
|
if force_native and native_sr > 0 and abs(native_sr - sr_out) > 1:
|
||||||
|
audio = _resample_playback_audio(audio, sr_out, native_sr)
|
||||||
|
sr_out = native_sr
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 播放重采样 {sample_rate}Hz -> {sr_out}Hz(匹配设备 default_samplerate,避免 Windows 无声)"
|
||||||
|
)
|
||||||
|
|
||||||
|
peak_before = float(np.max(np.abs(audio)))
|
||||||
|
if do_normalize and peak_before > 1e-8 and peak_before > 0.95:
|
||||||
|
audio = (audio / peak_before * 0.92).astype(np.float32, copy=False)
|
||||||
|
|
||||||
|
if gain != 1.0:
|
||||||
|
audio = (audio * np.float32(gain)).astype(np.float32, copy=False)
|
||||||
|
|
||||||
|
audio = _fade_playback_edges(audio, sr_out, fade_ms)
|
||||||
|
|
||||||
|
peak = float(np.max(np.abs(audio)))
|
||||||
|
rms = float(np.sqrt(np.mean(np.square(audio))))
|
||||||
|
dname = info.get("name", "?") if isinstance(info, dict) else "?"
|
||||||
|
logger.info(
|
||||||
|
f"TTS: 播放 峰值={peak:.5f} RMS={rms:.5f} sr={sr_out}Hz 设备={dev!r} ({dname!r})"
|
||||||
|
)
|
||||||
|
if peak < 1e-8:
|
||||||
|
logger.warning("TTS: 波形接近静音,请检查合成是否异常")
|
||||||
|
|
||||||
|
audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)
|
||||||
|
block = audio.reshape(-1, 1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with sd.OutputStream(
|
||||||
|
device=dev,
|
||||||
|
channels=1,
|
||||||
|
samplerate=sr_out,
|
||||||
|
dtype="float32",
|
||||||
|
latency=latency,
|
||||||
|
) as stream:
|
||||||
|
stream.write(block)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS: OutputStream 失败,回退 sd.play: {e}", exc_info=True)
|
||||||
|
sd.play(block, samplerate=sr_out, device=dev)
|
||||||
|
sd.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def play_wav_path(
|
||||||
|
path: str | Path,
|
||||||
|
*,
|
||||||
|
output_device: Optional[object] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
播放 16-bit PCM WAV(单声道或立体声下混为单声道),走与 synthesize + play_tts_audio
|
||||||
|
相同的 sounddevice 输出路径(含 ROCKET_TTS_DEVICE / yaml 设备解析)。
|
||||||
|
"""
|
||||||
|
import wave
|
||||||
|
|
||||||
|
p = Path(path)
|
||||||
|
with wave.open(str(p), "rb") as wf:
|
||||||
|
ch = int(wf.getnchannels())
|
||||||
|
sw = int(wf.getsampwidth())
|
||||||
|
sr = int(wf.getframerate())
|
||||||
|
nframes = int(wf.getnframes())
|
||||||
|
if sw != 2:
|
||||||
|
raise ValueError(f"仅支持 16-bit PCM: {p}")
|
||||||
|
raw = wf.readframes(nframes)
|
||||||
|
mono = np.frombuffer(raw, dtype="<i2").astype(np.float32, copy=False)
|
||||||
|
if ch == 2:
|
||||||
|
mono = mono.reshape(-1, 2).mean(axis=1).astype(np.float32, copy=False)
|
||||||
|
elif ch != 1:
|
||||||
|
raise ValueError(f"仅支持 1 或 2 通道: {p} (ch={ch})")
|
||||||
|
mono = mono * np.float32(1.0 / 32768.0)
|
||||||
|
logger.info("TTS: 播放预生成 WAV %s (%sHz, %s 采样)", p.name, sr, mono.size)
|
||||||
|
play_tts_audio(mono, sr, output_device=output_device)
|
||||||
|
|
||||||
|
|
||||||
|
def speak_text(
|
||||||
|
text: str,
|
||||||
|
tts: Optional["KokoroOnnxTTS"] = None,
|
||||||
|
*,
|
||||||
|
output_device: Optional[object] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
合成并播放一段语音;失败时仅打日志,不向外抛异常(适合命令成功后的反馈)。
|
||||||
|
"""
|
||||||
|
if not text or not str(text).strip():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
engine = tts or KokoroOnnxTTS()
|
||||||
|
t = str(text).strip()
|
||||||
|
logger.info(f"TTS: 开始合成并播放: {t!r}")
|
||||||
|
audio, sr = engine.synthesize(t)
|
||||||
|
play_tts_audio(audio, sr, output_device=output_device)
|
||||||
|
logger.info("TTS: 播放完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS 播放失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"KokoroOnnxTTS",
|
||||||
|
"play_tts_audio",
|
||||||
|
"play_wav_path",
|
||||||
|
"speak_text",
|
||||||
|
"get_playback_output_device_id",
|
||||||
|
"log_sounddevice_output_devices",
|
||||||
|
]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 与主程序一致:使用 play_tts_audio(含重采样到设备 native 采样率)
|
||||||
|
tts = KokoroOnnxTTS()
|
||||||
|
|
||||||
|
text = "任务执行完成,开始返航降落"
|
||||||
|
|
||||||
|
print(f"正在合成语音: {text}")
|
||||||
|
audio, sr = tts.synthesize(text)
|
||||||
|
|
||||||
|
print("正在播放(与主程序相同 play_tts_audio 路径)...")
|
||||||
|
try:
|
||||||
|
play_tts_audio(audio, sr)
|
||||||
|
print("播放结束。")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"播放失败: {e}")
|
||||||
|
|
||||||
|
# === 保存为 WAV 文件(可选)===
|
||||||
|
try:
|
||||||
|
output_path = "任务执行完成,开始返航降落.wav"
|
||||||
|
tts.synthesize_to_file(text, output_path)
|
||||||
|
print(f"音频已保存至: {output_path}")
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"保存失败(可能缺少 scipy): {e}")
|
||||||
|
|
||||||
152
voice_drone/core/tts_ack_cache.py
Normal file
152
voice_drone/core/tts_ack_cache.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""
|
||||||
|
应答 TTS 波形磁盘缓存:文案与 TTS 配置未变时跳过逐条合成,加快启动。
|
||||||
|
|
||||||
|
缓存目录:项目根下 cache/ack_tts_pcm/
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from voice_drone.core.configuration import SYSTEM_TTS_CONFIG
|
||||||
|
|
||||||
|
# 与 src/core/configuration.py 一致:src/core/tts_ack_cache.py -> parents[2]
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
ACK_PCM_CACHE_DIR = _PROJECT_ROOT / "cache" / "ack_tts_pcm"
|
||||||
|
MANIFEST_NAME = "manifest.json"
|
||||||
|
CACHE_FORMAT = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _tts_signature() -> dict:
|
||||||
|
tts = SYSTEM_TTS_CONFIG or {}
|
||||||
|
return {
|
||||||
|
"model_dir": str(tts.get("model_dir", "")),
|
||||||
|
"model_name": str(tts.get("model_name", "")),
|
||||||
|
"voice": str(tts.get("voice", "")),
|
||||||
|
"speed": round(float(tts.get("speed", 1.0)), 6),
|
||||||
|
"sample_rate": int(tts.get("sample_rate", 24000)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ack_pcm_fingerprint(
|
||||||
|
unique_phrases: List[str],
|
||||||
|
*,
|
||||||
|
global_text: Optional[str] = None,
|
||||||
|
mode_phrases: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""文案 + TTS 签名变化则指纹变,磁盘缓存失效。"""
|
||||||
|
payload = {
|
||||||
|
"cache_format": CACHE_FORMAT,
|
||||||
|
"tts": _tts_signature(),
|
||||||
|
"mode_phrases": mode_phrases,
|
||||||
|
}
|
||||||
|
if mode_phrases:
|
||||||
|
payload["phrases"] = sorted(unique_phrases)
|
||||||
|
else:
|
||||||
|
payload["global_text"] = (global_text or "").strip()
|
||||||
|
raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
|
||||||
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _phrase_file_stem(fingerprint: str, phrase: str) -> str:
|
||||||
|
h = hashlib.sha256(fingerprint.encode("utf-8"))
|
||||||
|
h.update(b"\0")
|
||||||
|
h.update(phrase.encode("utf-8"))
|
||||||
|
return h.hexdigest()[:40]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_one_npz(path: Path) -> Optional[Tuple[np.ndarray, int]]:
|
||||||
|
try:
|
||||||
|
z = np.load(path, allow_pickle=False)
|
||||||
|
audio = np.asarray(z["audio"], dtype=np.float32).squeeze()
|
||||||
|
sr = int(np.asarray(z["sr"]).reshape(-1)[0])
|
||||||
|
if audio.size == 0 or sr <= 0:
|
||||||
|
return None
|
||||||
|
return (audio, sr)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_cached_phrases(
|
||||||
|
unique_phrases: List[str],
|
||||||
|
fingerprint: str,
|
||||||
|
) -> Tuple[Dict[str, Tuple[np.ndarray, int]], List[str]]:
|
||||||
|
"""
|
||||||
|
从磁盘加载与 fingerprint 匹配的缓存。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(已加载的 phrase -> (audio, sr), 仍需合成的 phrase 列表)
|
||||||
|
"""
|
||||||
|
out: Dict[str, Tuple[np.ndarray, int]] = {}
|
||||||
|
if not unique_phrases:
|
||||||
|
return {}, []
|
||||||
|
|
||||||
|
cache_dir = ACK_PCM_CACHE_DIR
|
||||||
|
manifest_path = cache_dir / MANIFEST_NAME
|
||||||
|
if not manifest_path.is_file():
|
||||||
|
return {}, list(unique_phrases)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
except Exception:
|
||||||
|
return {}, list(unique_phrases)
|
||||||
|
|
||||||
|
if int(manifest.get("format", 0)) != CACHE_FORMAT:
|
||||||
|
return {}, list(unique_phrases)
|
||||||
|
if manifest.get("fingerprint") != fingerprint:
|
||||||
|
return {}, list(unique_phrases)
|
||||||
|
|
||||||
|
files: Dict[str, str] = manifest.get("files") or {}
|
||||||
|
missing: List[str] = []
|
||||||
|
|
||||||
|
for phrase in unique_phrases:
|
||||||
|
fname = files.get(phrase)
|
||||||
|
if not fname:
|
||||||
|
missing.append(phrase)
|
||||||
|
continue
|
||||||
|
path = cache_dir / fname
|
||||||
|
if not path.is_file():
|
||||||
|
missing.append(phrase)
|
||||||
|
continue
|
||||||
|
loaded = _load_one_npz(path)
|
||||||
|
if loaded is None:
|
||||||
|
missing.append(phrase)
|
||||||
|
continue
|
||||||
|
out[phrase] = loaded
|
||||||
|
|
||||||
|
return out, missing
|
||||||
|
|
||||||
|
|
||||||
|
def persist_phrases(fingerprint: str, phrase_pcm: Dict[str, Tuple[np.ndarray, int]]) -> None:
|
||||||
|
"""写入/更新整包 manifest 与各句 npz(覆盖同名 manifest)。"""
|
||||||
|
if not phrase_pcm:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache_dir = ACK_PCM_CACHE_DIR
|
||||||
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
files: Dict[str, str] = {}
|
||||||
|
for phrase, (audio, sr) in phrase_pcm.items():
|
||||||
|
stem = _phrase_file_stem(fingerprint, phrase)
|
||||||
|
fname = f"{stem}.npz"
|
||||||
|
path = cache_dir / fname
|
||||||
|
audio = np.asarray(audio, dtype=np.float32).squeeze()
|
||||||
|
np.savez_compressed(path, audio=audio, sr=np.array([int(sr)], dtype=np.int32))
|
||||||
|
files[phrase] = fname
|
||||||
|
|
||||||
|
manifest = {
|
||||||
|
"format": CACHE_FORMAT,
|
||||||
|
"fingerprint": fingerprint,
|
||||||
|
"files": files,
|
||||||
|
}
|
||||||
|
tmp = cache_dir / (MANIFEST_NAME + ".tmp")
|
||||||
|
with open(tmp, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(manifest, f, ensure_ascii=False, indent=0)
|
||||||
|
tmp.replace(cache_dir / MANIFEST_NAME)
|
||||||
429
voice_drone/core/vad.py
Normal file
429
voice_drone/core/vad.py
Normal file
@ -0,0 +1,429 @@
|
|||||||
|
"""
|
||||||
|
语音活动检测(VAD)模块 - 纯 ONNX Runtime 版 Silero VAD
|
||||||
|
|
||||||
|
使用 Silero VAD 的 ONNX 模型检测语音活动,识别语音的开始和结束。
|
||||||
|
不依赖 PyTorch/silero_vad 包,只依赖 onnxruntime + numpy。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import multiprocessing
|
||||||
|
from voice_drone.core.configuration import (
|
||||||
|
SYSTEM_AUDIO_CONFIG,
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG,
|
||||||
|
SYSTEM_VAD_CONFIG,
|
||||||
|
)
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
from voice_drone.tools.wrapper import time_cost
|
||||||
|
logger = get_logger("vad.silero_onnx")
|
||||||
|
|
||||||
|
|
||||||
|
class VAD:
|
||||||
|
"""
|
||||||
|
语音活动检测器
|
||||||
|
|
||||||
|
使用 ONNX Runtime
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
初始化 VAD 检测器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 可选的配置字典,用于覆盖默认配置
|
||||||
|
"""
|
||||||
|
# ---- 从 system.yaml 的 vad 部分读取默认配置 ----
|
||||||
|
# system.yaml:
|
||||||
|
# vad:
|
||||||
|
# threshold: 0.65
|
||||||
|
# start_frame: 3
|
||||||
|
# end_frame: 10
|
||||||
|
# min_silence_duration_s: 0.5
|
||||||
|
# max_silence_duration_s: 30
|
||||||
|
# model_path: "src/models/silero_vad.onnx"
|
||||||
|
vad_conf = SYSTEM_VAD_CONFIG
|
||||||
|
|
||||||
|
# 语音概率阈值(YAML 可能是字符串)
|
||||||
|
self.speech_threshold = float(vad_conf.get("threshold", 0.5))
|
||||||
|
|
||||||
|
# 连续多少帧检测到语音才认为“开始说话”
|
||||||
|
self.speech_start_frames = int(vad_conf.get("start_frame", 3))
|
||||||
|
|
||||||
|
# 连续多少帧检测到静音才认为“结束说话”
|
||||||
|
self.silence_end_frames = int(vad_conf.get("end_frame", 10))
|
||||||
|
|
||||||
|
# 可选: 最短/最长语音段时长(秒),可以在上层按需使用
|
||||||
|
self.min_speech_duration = vad_conf.get("min_silence_duration_s")
|
||||||
|
self.max_speech_duration = vad_conf.get("max_silence_duration_s")
|
||||||
|
|
||||||
|
# 采样率来自 audio 配置
|
||||||
|
self.sample_rate = SYSTEM_AUDIO_CONFIG.get("sample_rate")
|
||||||
|
|
||||||
|
# 与 recognizer 一致:能量 VAD 时不加载 Silero(避免无模型文件仍强加载)
|
||||||
|
_ev_env = os.environ.get("ROCKET_ENERGY_VAD", "").lower() in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"yes",
|
||||||
|
)
|
||||||
|
_yaml_backend = str(
|
||||||
|
SYSTEM_RECOGNIZER_CONFIG.get("vad_backend", "silero")
|
||||||
|
).lower()
|
||||||
|
if _ev_env or _yaml_backend == "energy":
|
||||||
|
self.onnx_session = None
|
||||||
|
self.vad_model_path = None
|
||||||
|
self.window_size = 512 if int(self.sample_rate or 16000) == 16000 else 256
|
||||||
|
self.input_name = None
|
||||||
|
self.sr_input_name = None
|
||||||
|
self.state_input_name = None
|
||||||
|
self.output_name = None
|
||||||
|
self.state = None
|
||||||
|
self.speech_frame_count = 0
|
||||||
|
self.silence_frame_count = 0
|
||||||
|
self.is_speaking = False
|
||||||
|
logger.info(
|
||||||
|
"VAD:能量(RMS)分段模式,跳过 Silero ONNX(与 ROCKET_ENERGY_VAD / vad_backend 一致)"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ---- 加载 Silero VAD ONNX 模型 ----
|
||||||
|
raw_mp = SYSTEM_VAD_CONFIG.get("model_path")
|
||||||
|
if not raw_mp:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"vad.model_path 未配置。若只用能量 VAD,请在 system.yaml 中设 "
|
||||||
|
"recognizer.vad_backend: energy 并设置 ROCKET_ENERGY_VAD=1"
|
||||||
|
)
|
||||||
|
mp = Path(raw_mp)
|
||||||
|
if not mp.is_absolute():
|
||||||
|
mp = Path(__file__).resolve().parents[2] / mp
|
||||||
|
self.vad_model_path = str(mp)
|
||||||
|
if not mp.is_file():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Silero VAD 模型不存在: {self.vad_model_path}。请下载 silero_vad.onnx 到该路径,"
|
||||||
|
"或改用能量 VAD:recognizer.vad_backend: energy 且 ROCKET_ENERGY_VAD=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"正在加载 Silero VAD ONNX 模型: {self.vad_model_path}")
|
||||||
|
sess_options = ort.SessionOptions()
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
optimal_threads = min(4, cpu_count)
|
||||||
|
sess_options.intra_op_num_threads = optimal_threads
|
||||||
|
sess_options.inter_op_num_threads = optimal_threads
|
||||||
|
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
|
||||||
|
self.onnx_session = ort.InferenceSession(
|
||||||
|
str(mp),
|
||||||
|
sess_options=sess_options,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = self.onnx_session.get_inputs()
|
||||||
|
outputs = self.onnx_session.get_outputs()
|
||||||
|
if not inputs:
|
||||||
|
raise RuntimeError("VAD ONNX 模型没有输入节点")
|
||||||
|
|
||||||
|
# ---- 解析输入 / 输出 ----
|
||||||
|
# 典型的 Silero VAD ONNX 会有:
|
||||||
|
# - 输入: audio(input), 采样率(sr), 状态(state 或 h/c)
|
||||||
|
# - 输出: 语音概率 + 可选的新状态
|
||||||
|
self.input_name = None
|
||||||
|
self.sr_input_name: Optional[str] = None
|
||||||
|
self.state_input_name: Optional[str] = None
|
||||||
|
|
||||||
|
for inp in inputs:
|
||||||
|
name = inp.name
|
||||||
|
if self.input_name is None:
|
||||||
|
# 优先匹配常见名称,否则退回第一个
|
||||||
|
if name in ("input", "audio", "waveform"):
|
||||||
|
self.input_name = name
|
||||||
|
else:
|
||||||
|
self.input_name = name
|
||||||
|
if name == "sr":
|
||||||
|
self.sr_input_name = name
|
||||||
|
if name in ("state", "h", "c", "hidden"):
|
||||||
|
self.state_input_name = name
|
||||||
|
|
||||||
|
# 如果依然没有确定 input_name,兜底使用第一个
|
||||||
|
if self.input_name is None:
|
||||||
|
self.input_name = inputs[0].name
|
||||||
|
|
||||||
|
self.output_name = outputs[0].name if outputs else None
|
||||||
|
|
||||||
|
# 预分配状态向量(如果模型需要)
|
||||||
|
self.state: Optional[np.ndarray] = None
|
||||||
|
if self.state_input_name is not None:
|
||||||
|
state_inp = next(i for i in inputs if i.name == self.state_input_name)
|
||||||
|
# state 的 shape 通常是 [1, N] 或 [N], 这里用 0 初始化
|
||||||
|
state_shape = [
|
||||||
|
int(d) if isinstance(d, int) and d > 0 else 1
|
||||||
|
for d in (state_inp.shape or [1])
|
||||||
|
]
|
||||||
|
self.state = np.zeros(state_shape, dtype=np.float32)
|
||||||
|
|
||||||
|
# 从输入 shape 推断窗口大小: (batch, samples) 或 (samples,)
|
||||||
|
input_shape = inputs[0].shape
|
||||||
|
win_size = None
|
||||||
|
if isinstance(input_shape, (list, tuple)) and len(input_shape) >= 1:
|
||||||
|
last_dim = input_shape[-1]
|
||||||
|
if isinstance(last_dim, int):
|
||||||
|
win_size = last_dim
|
||||||
|
if win_size is None:
|
||||||
|
win_size = 512 if self.sample_rate == 16000 else 256
|
||||||
|
self.window_size = int(win_size)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Silero VAD ONNX 模型加载完成: 输入={self.input_name}, 输出={self.output_name}, "
|
||||||
|
f"window_size={self.window_size}, sample_rate={self.sample_rate}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Silero VAD ONNX 模型加载失败: {e}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"无法加载 Silero VAD: {e}。若无需 Silero,请设 ROCKET_ENERGY_VAD=1 且 "
|
||||||
|
"recognizer.vad_backend: energy"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# State tracking
|
||||||
|
self.speech_frame_count = 0 # Consecutive speech frame count
|
||||||
|
self.silence_frame_count = 0 # Consecutive silence frame count
|
||||||
|
self.is_speaking = False # Currently speaking
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"VADDetector 初始化完成: "
|
||||||
|
f"speech_threshold={self.speech_threshold}, "
|
||||||
|
f"speech_start_frames={self.speech_start_frames}, "
|
||||||
|
f"silence_end_frames={self.silence_end_frames}, "
|
||||||
|
f"sample_rate={self.sample_rate}Hz"
|
||||||
|
)
|
||||||
|
|
||||||
|
# @time_cost("VAD-语音检测耗时")
|
||||||
|
def is_speech(self, audio_chunk: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
检测音频块是否包含语音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_chunk: 音频数据(bytes),必须是 16kHz, 16-bit, 单声道 PCM
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示检测到语音,False 表示静音
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.onnx_session is None:
|
||||||
|
return False
|
||||||
|
# 将 bytes 转换为 numpy array(int16),确保 little-endian 字节序
|
||||||
|
audio_array = np.frombuffer(audio_chunk, dtype="<i2")
|
||||||
|
|
||||||
|
# 转换为 float32 并归一化到 [-1, 1]
|
||||||
|
audio_float = audio_array.astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
required_samples = getattr(self, "window_size", 512 if self.sample_rate == 16000 else 256)
|
||||||
|
|
||||||
|
# 如果音频块小于要求的大小,填充零
|
||||||
|
if len(audio_float) < required_samples:
|
||||||
|
audio_float = np.pad(
|
||||||
|
audio_float, (0, required_samples - len(audio_float)), mode="constant"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果音频块大于要求的大小,分割成多个小块并取平均值
|
||||||
|
if len(audio_float) > required_samples:
|
||||||
|
num_chunks = len(audio_float) // required_samples
|
||||||
|
speech_probs = []
|
||||||
|
|
||||||
|
for i in range(num_chunks):
|
||||||
|
start_idx = i * required_samples
|
||||||
|
end_idx = start_idx + required_samples
|
||||||
|
chunk = audio_float[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 模型通常期望输入形状为 (1, samples)
|
||||||
|
input_data = chunk[np.newaxis, :].astype(np.float32)
|
||||||
|
ort_inputs = {self.input_name: input_data}
|
||||||
|
|
||||||
|
# 如果模型需要 sr,state 等附加输入,一并提供
|
||||||
|
if getattr(self, "sr_input_name", None) is not None:
|
||||||
|
# Silero VAD 一般期望 int64 采样率
|
||||||
|
ort_inputs[self.sr_input_name] = np.array(
|
||||||
|
[self.sample_rate], dtype=np.int64
|
||||||
|
)
|
||||||
|
if getattr(self, "state_input_name", None) is not None and self.state is not None:
|
||||||
|
ort_inputs[self.state_input_name] = self.state
|
||||||
|
|
||||||
|
outputs = self.onnx_session.run(None, ort_inputs)
|
||||||
|
|
||||||
|
# 如果模型返回新的 state,更新内部状态
|
||||||
|
if (
|
||||||
|
getattr(self, "state_input_name", None) is not None
|
||||||
|
and len(outputs) > 1
|
||||||
|
):
|
||||||
|
self.state = outputs[1]
|
||||||
|
|
||||||
|
prob = float(outputs[0].reshape(-1)[0])
|
||||||
|
speech_probs.append(prob)
|
||||||
|
|
||||||
|
speech_prob = float(np.mean(speech_probs))
|
||||||
|
else:
|
||||||
|
input_data = audio_float[:required_samples][np.newaxis, :].astype(np.float32)
|
||||||
|
ort_inputs = {self.input_name: input_data}
|
||||||
|
|
||||||
|
if getattr(self, "sr_input_name", None) is not None:
|
||||||
|
ort_inputs[self.sr_input_name] = np.array(
|
||||||
|
[self.sample_rate], dtype=np.int64
|
||||||
|
)
|
||||||
|
if getattr(self, "state_input_name", None) is not None and self.state is not None:
|
||||||
|
ort_inputs[self.state_input_name] = self.state
|
||||||
|
|
||||||
|
outputs = self.onnx_session.run(None, ort_inputs)
|
||||||
|
|
||||||
|
if (
|
||||||
|
getattr(self, "state_input_name", None) is not None
|
||||||
|
and len(outputs) > 1
|
||||||
|
):
|
||||||
|
self.state = outputs[1]
|
||||||
|
|
||||||
|
speech_prob = float(outputs[0].reshape(-1)[0])
|
||||||
|
|
||||||
|
return speech_prob >= self.speech_threshold
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VAD detection failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_speech_numpy(self, audio_array: np.ndarray) -> bool:
|
||||||
|
"""
|
||||||
|
检测音频数组是否包含语音
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_array: 音频数据(numpy array,dtype=int16)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示检测到语音,False 表示静音
|
||||||
|
"""
|
||||||
|
# 转换为 bytes
|
||||||
|
audio_bytes = audio_array.tobytes()
|
||||||
|
return self.is_speech(audio_bytes)
|
||||||
|
|
||||||
|
def detect_speech_start(self, audio_chunk: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
检测语音开始
|
||||||
|
|
||||||
|
需要连续检测到多帧语音才认为语音开始
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_chunk: 音频数据块
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示检测到语音开始
|
||||||
|
"""
|
||||||
|
if self.is_speaking:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.is_speech(audio_chunk):
|
||||||
|
self.speech_frame_count += 1
|
||||||
|
self.silence_frame_count = 0
|
||||||
|
|
||||||
|
if self.speech_frame_count >= self.speech_start_frames:
|
||||||
|
self.is_speaking = True
|
||||||
|
self.speech_frame_count = 0
|
||||||
|
logger.info("Speech start detected")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.speech_frame_count = 0
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def detect_speech_end(self, audio_chunk: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
检测语音结束
|
||||||
|
|
||||||
|
需要连续检测到多帧静音才认为语音结束
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_chunk: 音频数据块
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示检测到语音结束
|
||||||
|
"""
|
||||||
|
if not self.is_speaking:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not self.is_speech(audio_chunk):
|
||||||
|
self.silence_frame_count += 1
|
||||||
|
self.speech_frame_count = 0
|
||||||
|
|
||||||
|
if self.silence_frame_count >= self.silence_end_frames:
|
||||||
|
self.is_speaking = False
|
||||||
|
self.silence_frame_count = 0
|
||||||
|
logger.info("Speech end detected")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.silence_frame_count = 0
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""
|
||||||
|
重置检测器状态
|
||||||
|
|
||||||
|
清除帧计数、是否在说话标记,以及 Silero 的 RNN 状态(长间隔后应清零,避免与后续音频错位)。
|
||||||
|
"""
|
||||||
|
self.speech_frame_count = 0
|
||||||
|
self.silence_frame_count = 0
|
||||||
|
self.is_speaking = False
|
||||||
|
if self.state is not None:
|
||||||
|
self.state.fill(0)
|
||||||
|
logger.debug("VAD detector state reset")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
使用测试音频按帧扫描,统计语音帧比例,更直观地验证 VAD 是否工作正常。
|
||||||
|
"""
|
||||||
|
import wave
|
||||||
|
|
||||||
|
vad = VAD()
|
||||||
|
audio_file = "test/测试音频.wav"
|
||||||
|
|
||||||
|
# 1. 读取 wav
|
||||||
|
with wave.open(audio_file, "rb") as wf:
|
||||||
|
n_channels = wf.getnchannels()
|
||||||
|
sampwidth = wf.getsampwidth()
|
||||||
|
framerate = wf.getframerate()
|
||||||
|
n_frames = wf.getnframes()
|
||||||
|
raw = wf.readframes(n_frames)
|
||||||
|
|
||||||
|
# 2. 转成 int16 数组
|
||||||
|
audio = np.frombuffer(raw, dtype="<i2")
|
||||||
|
|
||||||
|
# 3. 双声道 -> 单声道
|
||||||
|
if n_channels == 2:
|
||||||
|
audio = audio.reshape(-1, 2)
|
||||||
|
audio = audio.mean(axis=1).astype(np.int16)
|
||||||
|
|
||||||
|
# 4. 重采样到 VAD 所需采样率(通常 16k)
|
||||||
|
target_sr = vad.sample_rate
|
||||||
|
if framerate != target_sr:
|
||||||
|
x_old = np.linspace(0, 1, num=len(audio), endpoint=False)
|
||||||
|
x_new = np.linspace(0, 1, num=int(len(audio) * target_sr / framerate), endpoint=False)
|
||||||
|
audio = np.interp(x_new, x_old, audio).astype(np.int16)
|
||||||
|
|
||||||
|
print("wav info:", n_channels, "ch,", framerate, "Hz")
|
||||||
|
print("audio len (samples):", len(audio), " target_sr:", target_sr)
|
||||||
|
|
||||||
|
# 5. 按 VAD 窗口大小逐帧扫描
|
||||||
|
frame_samples = vad.window_size
|
||||||
|
frame_bytes = frame_samples * 2 # int16 -> 2 字节
|
||||||
|
audio_bytes = audio.tobytes()
|
||||||
|
num_frames = len(audio_bytes) // frame_bytes
|
||||||
|
|
||||||
|
speech_frames = 0
|
||||||
|
for i in range(num_frames):
|
||||||
|
chunk = audio_bytes[i * frame_bytes : (i + 1) * frame_bytes]
|
||||||
|
if vad.is_speech(chunk):
|
||||||
|
speech_frames += 1
|
||||||
|
|
||||||
|
speech_ratio = speech_frames / num_frames if num_frames > 0 else 0.0
|
||||||
|
print("total frames:", num_frames)
|
||||||
|
print("speech frames:", speech_frames)
|
||||||
|
print("speech ratio:", speech_ratio)
|
||||||
|
print("has_any_speech:", speech_frames > 0)
|
||||||
375
voice_drone/core/wake_word.py
Normal file
375
voice_drone/core/wake_word.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
"""
|
||||||
|
唤醒词检测模块 - 高性能实时唤醒词识别
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- 精确匹配
|
||||||
|
- 模糊匹配(同音字、拼音)
|
||||||
|
- 部分匹配
|
||||||
|
- 配置化变体映射
|
||||||
|
|
||||||
|
性能优化:
|
||||||
|
- 预编译正则表达式
|
||||||
|
- LRU缓存匹配结果
|
||||||
|
- 优化字符串操作
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
from functools import lru_cache
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
from voice_drone.core.configuration import (
|
||||||
|
WAKE_WORD_PRIMARY,
|
||||||
|
WAKE_WORD_VARIANTS,
|
||||||
|
WAKE_WORD_MATCHING_CONFIG
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger("wake_word")
|
||||||
|
|
||||||
|
# 延迟加载可选依赖
|
||||||
|
try:
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
PYPINYIN_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYPINYIN_AVAILABLE = False
|
||||||
|
logger.warning("pypinyin 未安装,拼音匹配功能将受限")
|
||||||
|
|
||||||
|
|
||||||
|
class WakeWordDetector:
|
||||||
|
"""
|
||||||
|
唤醒词检测器
|
||||||
|
|
||||||
|
支持多种匹配模式:
|
||||||
|
- 精确匹配:完全匹配唤醒词
|
||||||
|
- 模糊匹配:同音字、拼音变体
|
||||||
|
- 部分匹配:只匹配部分唤醒词
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化唤醒词检测器"""
|
||||||
|
logger.info("初始化唤醒词检测器...")
|
||||||
|
|
||||||
|
# 从配置加载
|
||||||
|
self.primary = WAKE_WORD_PRIMARY
|
||||||
|
self.variants = WAKE_WORD_VARIANTS or []
|
||||||
|
self.matching_config = WAKE_WORD_MATCHING_CONFIG or {}
|
||||||
|
|
||||||
|
# 匹配配置
|
||||||
|
self.enable_fuzzy = self.matching_config.get("enable_fuzzy", True)
|
||||||
|
self.enable_partial = self.matching_config.get("enable_partial", True)
|
||||||
|
self.ignore_case = self.matching_config.get("ignore_case", True)
|
||||||
|
self.ignore_spaces = self.matching_config.get("ignore_spaces", True)
|
||||||
|
self.min_match_length = self.matching_config.get("min_match_length", 2)
|
||||||
|
self.similarity_threshold = self.matching_config.get("similarity_threshold", 0.7)
|
||||||
|
|
||||||
|
# 构建匹配模式
|
||||||
|
self._build_patterns()
|
||||||
|
|
||||||
|
logger.info(f"唤醒词检测器初始化完成")
|
||||||
|
logger.info(f" 主唤醒词: {self.primary}")
|
||||||
|
logger.info(f" 变体数量: {len(self.variants)}")
|
||||||
|
logger.info(f" 模糊匹配: {'启用' if self.enable_fuzzy else '禁用'}")
|
||||||
|
logger.info(f" 部分匹配: {'启用' if self.enable_partial else '禁用'}")
|
||||||
|
|
||||||
|
def _build_patterns(self):
|
||||||
|
"""构建匹配模式(预编译正则表达式)"""
|
||||||
|
# 标准化所有变体(去除空格、转小写等)
|
||||||
|
self.normalized_variants = []
|
||||||
|
|
||||||
|
for variant in self.variants:
|
||||||
|
normalized = self._normalize_text(variant)
|
||||||
|
if normalized:
|
||||||
|
self.normalized_variants.append(normalized)
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
self.normalized_variants = list(set(self.normalized_variants))
|
||||||
|
|
||||||
|
# 按长度降序排序(优先匹配长模式)
|
||||||
|
self.normalized_variants.sort(key=len, reverse=True)
|
||||||
|
|
||||||
|
# 构建正则表达式模式
|
||||||
|
patterns = []
|
||||||
|
for variant in self.normalized_variants:
|
||||||
|
# 转义特殊字符
|
||||||
|
escaped = re.escape(variant)
|
||||||
|
patterns.append(escaped)
|
||||||
|
|
||||||
|
# 编译单一正则表达式
|
||||||
|
if patterns:
|
||||||
|
self.pattern = re.compile('|'.join(patterns), re.IGNORECASE if self.ignore_case else 0)
|
||||||
|
else:
|
||||||
|
self.pattern = None
|
||||||
|
|
||||||
|
logger.debug(f"构建了 {len(self.normalized_variants)} 个匹配模式")
|
||||||
|
|
||||||
|
def _normalize_text(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
标准化文本(用于匹配)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的文本
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
normalized = text
|
||||||
|
|
||||||
|
# 忽略大小写
|
||||||
|
if self.ignore_case:
|
||||||
|
normalized = normalized.lower()
|
||||||
|
|
||||||
|
# 忽略空格
|
||||||
|
if self.ignore_spaces:
|
||||||
|
normalized = normalized.replace(' ', '').replace('\t', '').replace('\n', '')
|
||||||
|
|
||||||
|
return normalized.strip()
|
||||||
|
|
||||||
|
def _get_pinyin(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
获取文本的拼音(用于拼音匹配)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 中文文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
拼音字符串(小写,无空格)
|
||||||
|
"""
|
||||||
|
if not PYPINYIN_AVAILABLE:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
|
||||||
|
return ''.join(pinyin_list).lower()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"拼音转换失败: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _fuzzy_match(self, text: str, variant: str) -> bool:
|
||||||
|
"""
|
||||||
|
模糊匹配(同音字、拼音)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待匹配文本
|
||||||
|
variant: 变体文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否匹配
|
||||||
|
"""
|
||||||
|
# 1. 精确匹配(已标准化)
|
||||||
|
normalized_text = self._normalize_text(text)
|
||||||
|
normalized_variant = self._normalize_text(variant)
|
||||||
|
|
||||||
|
if normalized_text == normalized_variant:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 2. 拼音匹配
|
||||||
|
if PYPINYIN_AVAILABLE:
|
||||||
|
text_pinyin = self._get_pinyin(text)
|
||||||
|
variant_pinyin = self._get_pinyin(variant)
|
||||||
|
|
||||||
|
if text_pinyin and variant_pinyin:
|
||||||
|
# 完全匹配拼音
|
||||||
|
if text_pinyin == variant_pinyin:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 部分匹配拼音(至少匹配一半)
|
||||||
|
if len(variant_pinyin) >= 2:
|
||||||
|
# 检查是否包含变体的拼音
|
||||||
|
if variant_pinyin in text_pinyin or text_pinyin in variant_pinyin:
|
||||||
|
# 计算相似度
|
||||||
|
similarity = min(len(variant_pinyin), len(text_pinyin)) / max(len(variant_pinyin), len(text_pinyin))
|
||||||
|
if similarity >= self.similarity_threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 3. 字符级相似度匹配(简单实现)
|
||||||
|
if len(normalized_text) >= self.min_match_length and len(normalized_variant) >= self.min_match_length:
|
||||||
|
# 检查是否包含变体
|
||||||
|
if normalized_variant in normalized_text or normalized_text in normalized_variant:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _partial_match(self, text: str) -> bool:
|
||||||
|
"""
|
||||||
|
部分匹配(只匹配部分唤醒词,如主词较长时取前半段;短词请在配置 variants 中列出)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待匹配文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否匹配
|
||||||
|
"""
|
||||||
|
if not self.enable_partial:
|
||||||
|
return False
|
||||||
|
|
||||||
|
normalized_text = self._normalize_text(text)
|
||||||
|
|
||||||
|
# 检查是否包含主唤醒词的一部分
|
||||||
|
if self.primary:
|
||||||
|
# 提取主唤醒词的前半部分(如四字词可拆成前两字)
|
||||||
|
primary_normalized = self._normalize_text(self.primary)
|
||||||
|
if len(primary_normalized) >= self.min_match_length * 2:
|
||||||
|
half_length = len(primary_normalized) // 2
|
||||||
|
half_wake_word = primary_normalized[:half_length]
|
||||||
|
|
||||||
|
if len(half_wake_word) >= self.min_match_length:
|
||||||
|
if half_wake_word in normalized_text:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@lru_cache(maxsize=256)
|
||||||
|
def detect(self, text: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
检测文本中是否包含唤醒词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待检测文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否匹配, 匹配到的唤醒词)
|
||||||
|
"""
|
||||||
|
if not text or not self.pattern:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
normalized_text = self._normalize_text(text)
|
||||||
|
|
||||||
|
# 1. 精确匹配(使用正则表达式)
|
||||||
|
if self.pattern:
|
||||||
|
match = self.pattern.search(normalized_text)
|
||||||
|
if match:
|
||||||
|
matched_text = match.group(0)
|
||||||
|
logger.debug(f"精确匹配到唤醒词: {matched_text}")
|
||||||
|
return True, matched_text
|
||||||
|
|
||||||
|
# 2. 模糊匹配(同音字、拼音)
|
||||||
|
if self.enable_fuzzy:
|
||||||
|
for variant in self.normalized_variants:
|
||||||
|
if self._fuzzy_match(normalized_text, variant):
|
||||||
|
logger.debug(f"模糊匹配到唤醒词变体: {variant}")
|
||||||
|
return True, variant
|
||||||
|
|
||||||
|
# 3. 部分匹配
|
||||||
|
if self._partial_match(normalized_text):
|
||||||
|
logger.debug(f"部分匹配到唤醒词")
|
||||||
|
return True, self.primary[:len(self.primary)//2] if self.primary else None
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def extract_command_text(self, text: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
从文本中提取命令部分(移除唤醒词)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 包含唤醒词的完整文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取的命令文本,如果未检测到唤醒词返回None
|
||||||
|
"""
|
||||||
|
is_wake, matched_wake_word = self.detect(text)
|
||||||
|
|
||||||
|
if not is_wake:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化文本用于查找
|
||||||
|
normalized_text = self._normalize_text(text)
|
||||||
|
normalized_wake = self._normalize_text(matched_wake_word) if matched_wake_word else ""
|
||||||
|
|
||||||
|
if not normalized_wake or normalized_wake not in normalized_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 找到唤醒词在标准化文本中的位置
|
||||||
|
idx = normalized_text.find(normalized_wake)
|
||||||
|
if idx < 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 方法1:尝试在原始文本中精确查找匹配的变体
|
||||||
|
original_text = text
|
||||||
|
text_lower = original_text.lower()
|
||||||
|
|
||||||
|
# 查找所有可能的变体在原始文本中的位置
|
||||||
|
best_match_idx = -1
|
||||||
|
best_match_length = 0
|
||||||
|
|
||||||
|
# 检查配置中的所有变体
|
||||||
|
for variant in self.variants:
|
||||||
|
variant_normalized = self._normalize_text(variant)
|
||||||
|
if variant_normalized == normalized_wake:
|
||||||
|
# 这个变体匹配到了,尝试在原始文本中找到它
|
||||||
|
variant_lower = variant.lower()
|
||||||
|
if variant_lower in text_lower:
|
||||||
|
variant_idx = text_lower.find(variant_lower)
|
||||||
|
if variant_idx >= 0:
|
||||||
|
# 选择最长的匹配(更准确)
|
||||||
|
if len(variant) > best_match_length:
|
||||||
|
best_match_idx = variant_idx
|
||||||
|
best_match_length = len(variant)
|
||||||
|
|
||||||
|
# 如果找到了匹配的变体
|
||||||
|
if best_match_idx >= 0:
|
||||||
|
command_start = best_match_idx + best_match_length
|
||||||
|
command_text = original_text[command_start:].strip()
|
||||||
|
# 移除开头的标点符号
|
||||||
|
command_text = command_text.lstrip(',。、,.').strip()
|
||||||
|
return command_text if command_text else None
|
||||||
|
|
||||||
|
# 方法2:回退方案 - 使用字符计数近似定位
|
||||||
|
# 计算标准化文本中唤醒词结束位置对应的原始文本位置
|
||||||
|
wake_end_in_normalized = idx + len(normalized_wake)
|
||||||
|
|
||||||
|
# 计算原始文本中对应的字符位置
|
||||||
|
char_count = 0
|
||||||
|
for i, char in enumerate(original_text):
|
||||||
|
normalized_char = self._normalize_text(char)
|
||||||
|
if normalized_char:
|
||||||
|
if char_count >= wake_end_in_normalized:
|
||||||
|
command_text = original_text[i:].strip()
|
||||||
|
command_text = command_text.lstrip(',。、,.').strip()
|
||||||
|
return command_text if command_text else None
|
||||||
|
char_count += 1
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_detector: Optional[WakeWordDetector] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_wake_word_detector() -> WakeWordDetector:
|
||||||
|
"""获取全局唤醒词检测器实例(单例模式)"""
|
||||||
|
global _global_detector
|
||||||
|
if _global_detector is None:
|
||||||
|
_global_detector = WakeWordDetector()
|
||||||
|
return _global_detector
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
detector = WakeWordDetector()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
("无人机,现在起飞", True),
|
||||||
|
("wu ren ji 现在起飞", True),
|
||||||
|
("Wu Ren Ji 现在起飞", True),
|
||||||
|
("五人机,现在起飞", True),
|
||||||
|
("现在起飞", False),
|
||||||
|
("无人,现在起飞", True), # 变体列表中的短说
|
||||||
|
("人机 前进", False), # 已移除单独「人机」变体,避免子串误唤醒
|
||||||
|
("无人机 前进", True),
|
||||||
|
]
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("唤醒词检测测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for text, expected in test_cases:
|
||||||
|
is_wake, matched = detector.detect(text)
|
||||||
|
command_text = detector.extract_command_text(text)
|
||||||
|
status = "OK" if is_wake == expected else "FAIL"
|
||||||
|
print(f"{status} 文本: {text}")
|
||||||
|
print(f" 匹配: {is_wake} (期望: {expected})")
|
||||||
|
print(f" 匹配词: {matched}")
|
||||||
|
print(f" 提取命令: {command_text}")
|
||||||
|
print()
|
||||||
BIN
voice_drone/core/任务执行完成,开始返航降落.wav
Normal file
BIN
voice_drone/core/任务执行完成,开始返航降落.wav
Normal file
Binary file not shown.
BIN
voice_drone/core/好的收到,开始起飞.wav
Normal file
BIN
voice_drone/core/好的收到,开始起飞.wav
Normal file
Binary file not shown.
1
voice_drone/flight_bridge/__init__.py
Normal file
1
voice_drone/flight_bridge/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""伴飞桥:将 flight_intent v1 译为 MAVROS/PX4 行为(ROS 1 Noetic 首版)。"""
|
||||||
330
voice_drone/flight_bridge/ros1_mavros_executor.py
Normal file
330
voice_drone/flight_bridge/ros1_mavros_executor.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
"""
|
||||||
|
ROS1 + MAVROS:按 ValidatedFlightIntent 顺序执行(Offboard 位姿 / AUTO.LAND / AUTO.RTL)。
|
||||||
|
|
||||||
|
需在已连接飞控的 MAVROS 环境中运行(与 src/px4_ctrl_offboard_demo.py 相同前提)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
from geometry_msgs.msg import PoseStamped
|
||||||
|
from mavros_msgs.msg import PositionTarget, State
|
||||||
|
from mavros_msgs.srv import CommandBool, SetMode
|
||||||
|
|
||||||
|
from voice_drone.core.flight_intent import (
|
||||||
|
ActionGoto,
|
||||||
|
ActionHold,
|
||||||
|
ActionHover,
|
||||||
|
ActionLand,
|
||||||
|
ActionReturnHome,
|
||||||
|
ActionTakeoff,
|
||||||
|
ActionWait,
|
||||||
|
FlightAction,
|
||||||
|
ValidatedFlightIntent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _yaw_from_quaternion(x: float, y: float, z: float, w: float) -> float:
|
||||||
|
siny_cosp = 2.0 * (w * z + x * y)
|
||||||
|
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
||||||
|
return math.atan2(siny_cosp, cosy_cosp)
|
||||||
|
|
||||||
|
|
||||||
|
class MavrosFlightExecutor:
|
||||||
|
"""单次连接 MAVROS;对外提供 execute(intent)。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
rospy.init_node("flight_intent_mavros_bridge", anonymous=True)
|
||||||
|
|
||||||
|
self.state = State()
|
||||||
|
self.pose = PoseStamped()
|
||||||
|
self.has_pose = False
|
||||||
|
|
||||||
|
rospy.Subscriber("/mavros/state", State, self._state_cb, queue_size=10)
|
||||||
|
rospy.Subscriber(
|
||||||
|
"/mavros/local_position/pose",
|
||||||
|
PoseStamped,
|
||||||
|
self._pose_cb,
|
||||||
|
queue_size=10,
|
||||||
|
)
|
||||||
|
self.sp_pub = rospy.Publisher(
|
||||||
|
"/mavros/setpoint_raw/local",
|
||||||
|
PositionTarget,
|
||||||
|
queue_size=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
rospy.wait_for_service("/mavros/cmd/arming", timeout=60.0)
|
||||||
|
rospy.wait_for_service("/mavros/set_mode", timeout=60.0)
|
||||||
|
self._arming_name = "/mavros/cmd/arming"
|
||||||
|
self._set_mode_name = "/mavros/set_mode"
|
||||||
|
self.arm_srv = rospy.ServiceProxy(self._arming_name, CommandBool)
|
||||||
|
self.mode_srv = rospy.ServiceProxy(self._set_mode_name, SetMode)
|
||||||
|
|
||||||
|
self.rate = rospy.Rate(20)
|
||||||
|
|
||||||
|
self.default_takeoff_relative_m = rospy.get_param(
|
||||||
|
"~default_takeoff_relative_m",
|
||||||
|
0.5,
|
||||||
|
)
|
||||||
|
self.takeoff_timeout_sec = rospy.get_param("~takeoff_timeout_sec", 15.0)
|
||||||
|
self.goto_tol = rospy.get_param("~goto_position_tolerance", 0.15)
|
||||||
|
self.goto_timeout_sec = rospy.get_param("~goto_timeout_sec", 60.0)
|
||||||
|
self.land_timeout_sec = rospy.get_param("~land_timeout_sec", 45.0)
|
||||||
|
self.pre_stream_count = int(rospy.get_param("~offboard_pre_stream_count", 80))
|
||||||
|
|
||||||
|
def _state_cb(self, msg: State) -> None:
|
||||||
|
self.state = msg
|
||||||
|
|
||||||
|
def _pose_cb(self, msg: PoseStamped) -> None:
|
||||||
|
self.pose = msg
|
||||||
|
self.has_pose = True
|
||||||
|
|
||||||
|
def _wait_connected_and_pose(self) -> None:
|
||||||
|
rospy.loginfo("等待 FCU 与本地位置 …")
|
||||||
|
while not rospy.is_shutdown() and not self.state.connected:
|
||||||
|
self.rate.sleep()
|
||||||
|
while not rospy.is_shutdown() and not self.has_pose:
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _position_sp(x: float, y: float, z: float) -> PositionTarget:
|
||||||
|
sp = PositionTarget()
|
||||||
|
sp.header.stamp = rospy.Time.now()
|
||||||
|
sp.coordinate_frame = PositionTarget.FRAME_LOCAL_NED
|
||||||
|
sp.type_mask = (
|
||||||
|
PositionTarget.IGNORE_VX
|
||||||
|
| PositionTarget.IGNORE_VY
|
||||||
|
| PositionTarget.IGNORE_VZ
|
||||||
|
| PositionTarget.IGNORE_AFX
|
||||||
|
| PositionTarget.IGNORE_AFY
|
||||||
|
| PositionTarget.IGNORE_AFZ
|
||||||
|
| PositionTarget.IGNORE_YAW
|
||||||
|
| PositionTarget.IGNORE_YAW_RATE
|
||||||
|
)
|
||||||
|
sp.position.x = x
|
||||||
|
sp.position.y = y
|
||||||
|
sp.position.z = z
|
||||||
|
return sp
|
||||||
|
|
||||||
|
def _publish_sp(self, sp: PositionTarget) -> None:
|
||||||
|
sp.header.stamp = rospy.Time.now()
|
||||||
|
self.sp_pub.publish(sp)
|
||||||
|
|
||||||
|
def _stream_init_setpoint(self, x: float, y: float, z: float) -> None:
|
||||||
|
sp = self._position_sp(x, y, z)
|
||||||
|
for _ in range(self.pre_stream_count):
|
||||||
|
if rospy.is_shutdown():
|
||||||
|
return
|
||||||
|
self._publish_sp(sp)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
def _refresh_mode_arm_proxies(self) -> None:
|
||||||
|
"""MAVROS 重启或链路抖事后,旧 ServiceProxy 可能一直报 unavailable,需重建。"""
|
||||||
|
try:
|
||||||
|
rospy.wait_for_service(self._set_mode_name, timeout=2.0)
|
||||||
|
rospy.wait_for_service(self._arming_name, timeout=2.0)
|
||||||
|
except rospy.ROSException:
|
||||||
|
return
|
||||||
|
self.mode_srv = rospy.ServiceProxy(self._set_mode_name, SetMode)
|
||||||
|
self.arm_srv = rospy.ServiceProxy(self._arming_name, CommandBool)
|
||||||
|
|
||||||
|
def _try_set_mode_arm_offboard(self) -> None:
|
||||||
|
try:
|
||||||
|
if self.state.mode != "OFFBOARD":
|
||||||
|
self.mode_srv(base_mode=0, custom_mode="OFFBOARD")
|
||||||
|
if not self.state.armed:
|
||||||
|
self.arm_srv(True)
|
||||||
|
except (rospy.ServiceException, rospy.ROSException) as exc:
|
||||||
|
rospy.logwarn_throttle(2.0, "set_mode/arm: %s", exc)
|
||||||
|
|
||||||
|
def _current_xyz(self) -> Tuple[float, float, float]:
|
||||||
|
p = self.pose.pose.position
|
||||||
|
return (p.x, p.y, p.z)
|
||||||
|
|
||||||
|
def _do_takeoff(self, step: ActionTakeoff) -> None:
|
||||||
|
alt = step.args.relative_altitude_m
|
||||||
|
dz = float(alt) if alt is not None else float(self.default_takeoff_relative_m)
|
||||||
|
self._wait_connected_and_pose()
|
||||||
|
x0, y0, z0 = self._current_xyz()
|
||||||
|
# 与 px4_ctrl_offboard_demo.py 一致:z_tgt = z0 + dz
|
||||||
|
z_tgt = z0 + dz
|
||||||
|
rospy.loginfo(
|
||||||
|
"takeoff: z0=%.2f Δ=%.2f m -> z_target=%.2f(与 demo 相同约定)",
|
||||||
|
z0,
|
||||||
|
dz,
|
||||||
|
z_tgt,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._stream_init_setpoint(x0, y0, z0)
|
||||||
|
t0 = rospy.Time.now()
|
||||||
|
deadline = t0 + rospy.Duration(self.takeoff_timeout_sec)
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < deadline:
|
||||||
|
self._try_set_mode_arm_offboard()
|
||||||
|
sp = self._position_sp(x0, y0, z_tgt)
|
||||||
|
self._publish_sp(sp)
|
||||||
|
z_now = self.pose.pose.position.z
|
||||||
|
if (
|
||||||
|
abs(z_now - z_tgt) < 0.08
|
||||||
|
and self.state.mode == "OFFBOARD"
|
||||||
|
and self.state.armed
|
||||||
|
):
|
||||||
|
rospy.loginfo("takeoff reached")
|
||||||
|
break
|
||||||
|
self.rate.sleep()
|
||||||
|
else:
|
||||||
|
rospy.logwarn("takeoff timeout,继续后续步骤")
|
||||||
|
|
||||||
|
def _do_hold_position(self, _label: str = "hover") -> None:
|
||||||
|
x, y, z = self._current_xyz()
|
||||||
|
rospy.loginfo("%s: 保持当前点 (%.2f, %.2f, %.2f) NED", _label, x, y, z)
|
||||||
|
sp = self._position_sp(x, y, z)
|
||||||
|
t_end = rospy.Time.now() + rospy.Duration(1.0)
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < t_end:
|
||||||
|
self._try_set_mode_arm_offboard()
|
||||||
|
self._publish_sp(sp)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
def _do_wait(self, step: ActionWait) -> None:
|
||||||
|
sec = float(step.args.seconds)
|
||||||
|
rospy.loginfo("wait %.2f s", sec)
|
||||||
|
t_end = rospy.Time.now() + rospy.Duration(sec)
|
||||||
|
x, y, z = self._current_xyz()
|
||||||
|
sp = self._position_sp(x, y, z)
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < t_end:
|
||||||
|
# Offboard 下建议保持 stream
|
||||||
|
if self.state.mode == "OFFBOARD" and self.state.armed:
|
||||||
|
self._publish_sp(sp)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
def _ned_delta_from_goto(self, step: ActionGoto) -> Optional[Tuple[float, float, float]]:
|
||||||
|
a = step.args
|
||||||
|
dx = 0.0 if a.x is None else float(a.x)
|
||||||
|
dy = 0.0 if a.y is None else float(a.y)
|
||||||
|
dz = 0.0 if a.z is None else float(a.z)
|
||||||
|
if a.frame == "local_ned":
|
||||||
|
return (dx, dy, dz)
|
||||||
|
# body_ned: x前 y右 z下 → 转到 NED 水平增量
|
||||||
|
q = self.pose.pose.orientation
|
||||||
|
yaw = _yaw_from_quaternion(q.x, q.y, q.z, q.w)
|
||||||
|
north = math.cos(yaw) * dx - math.sin(yaw) * dy
|
||||||
|
east = math.sin(yaw) * dx + math.cos(yaw) * dy
|
||||||
|
return (north, east, dz)
|
||||||
|
|
||||||
|
def _do_goto(self, step: ActionGoto) -> None:
|
||||||
|
delta = self._ned_delta_from_goto(step)
|
||||||
|
if delta is None:
|
||||||
|
rospy.logwarn("goto: unsupported frame")
|
||||||
|
return
|
||||||
|
dn, de, dd = delta
|
||||||
|
if dn == 0.0 and de == 0.0 and dd == 0.0:
|
||||||
|
rospy.loginfo("goto: 零位移,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
x0, y0, z0 = self._current_xyz()
|
||||||
|
xt, yt, zt = x0 + dn, y0 + de, z0 + dd
|
||||||
|
rospy.loginfo(
|
||||||
|
"goto: (%.2f,%.2f,%.2f) -> (%.2f,%.2f,%.2f) NED",
|
||||||
|
x0,
|
||||||
|
y0,
|
||||||
|
z0,
|
||||||
|
xt,
|
||||||
|
yt,
|
||||||
|
zt,
|
||||||
|
)
|
||||||
|
deadline = rospy.Time.now() + rospy.Duration(self.goto_timeout_sec)
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < deadline:
|
||||||
|
self._try_set_mode_arm_offboard()
|
||||||
|
sp = self._position_sp(xt, yt, zt)
|
||||||
|
self._publish_sp(sp)
|
||||||
|
px, py, pz = self._current_xyz()
|
||||||
|
err = math.sqrt(
|
||||||
|
(px - xt) ** 2 + (py - yt) ** 2 + (pz - zt) ** 2
|
||||||
|
)
|
||||||
|
if err < self.goto_tol and self.state.mode == "OFFBOARD":
|
||||||
|
rospy.loginfo("goto: reached (err=%.3f)", err)
|
||||||
|
break
|
||||||
|
self.rate.sleep()
|
||||||
|
else:
|
||||||
|
rospy.logwarn("goto: timeout")
|
||||||
|
|
||||||
|
def _do_land(self) -> None:
|
||||||
|
rospy.loginfo("land: AUTO.LAND")
|
||||||
|
t0 = rospy.Time.now()
|
||||||
|
deadline = t0 + rospy.Duration(self.land_timeout_sec)
|
||||||
|
fails = 0
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < deadline:
|
||||||
|
x, y, z = self._current_xyz()
|
||||||
|
# OFFBOARD 时若停发 setpoint,PX4 会很快退出该模式,MAVROS 侧 set_mode 可能长期不可用
|
||||||
|
if self.state.armed and self.state.mode == "OFFBOARD":
|
||||||
|
self._publish_sp(self._position_sp(x, y, z))
|
||||||
|
try:
|
||||||
|
if self.state.mode != "AUTO.LAND":
|
||||||
|
self.mode_srv(base_mode=0, custom_mode="AUTO.LAND")
|
||||||
|
except (rospy.ServiceException, rospy.ROSException) as exc:
|
||||||
|
fails += 1
|
||||||
|
rospy.logwarn_throttle(2.0, "AUTO.LAND: %s", exc)
|
||||||
|
if fails == 1 or fails % 15 == 0:
|
||||||
|
self._refresh_mode_arm_proxies()
|
||||||
|
if not self.state.armed:
|
||||||
|
rospy.loginfo("land: disarmed")
|
||||||
|
return
|
||||||
|
self.rate.sleep()
|
||||||
|
rospy.logwarn("land: timeout")
|
||||||
|
|
||||||
|
def _do_rtl(self) -> None:
|
||||||
|
rospy.loginfo("return_home: AUTO.RTL")
|
||||||
|
t0 = rospy.Time.now()
|
||||||
|
deadline = t0 + rospy.Duration(self.land_timeout_sec * 2)
|
||||||
|
fails = 0
|
||||||
|
while not rospy.is_shutdown() and rospy.Time.now() < deadline:
|
||||||
|
x, y, z = self._current_xyz()
|
||||||
|
if self.state.armed and self.state.mode == "OFFBOARD":
|
||||||
|
self._publish_sp(self._position_sp(x, y, z))
|
||||||
|
try:
|
||||||
|
if self.state.mode != "AUTO.RTL":
|
||||||
|
self.mode_srv(base_mode=0, custom_mode="AUTO.RTL")
|
||||||
|
except (rospy.ServiceException, rospy.ROSException) as exc:
|
||||||
|
fails += 1
|
||||||
|
rospy.logwarn_throttle(2.0, "RTL: %s", exc)
|
||||||
|
if fails == 1 or fails % 15 == 0:
|
||||||
|
self._refresh_mode_arm_proxies()
|
||||||
|
if not self.state.armed:
|
||||||
|
rospy.loginfo("RTL Finished(已 disarm)")
|
||||||
|
return
|
||||||
|
self.rate.sleep()
|
||||||
|
rospy.logwarn("rtl: timeout")
|
||||||
|
|
||||||
|
def execute(self, intent: ValidatedFlightIntent) -> None:
|
||||||
|
rospy.loginfo(
|
||||||
|
"执行 flight_intent:steps=%d summary=%s",
|
||||||
|
len(intent.actions),
|
||||||
|
intent.summary[:80],
|
||||||
|
)
|
||||||
|
self._wait_connected_and_pose()
|
||||||
|
|
||||||
|
for i, act in enumerate(intent.actions):
|
||||||
|
if rospy.is_shutdown():
|
||||||
|
break
|
||||||
|
rospy.loginfo("--- step %d/%d: %s", i + 1, len(intent.actions), type(act).__name__)
|
||||||
|
self._dispatch(act)
|
||||||
|
|
||||||
|
rospy.loginfo("flight_intent 序列结束")
|
||||||
|
|
||||||
|
def _dispatch(self, act: FlightAction) -> None:
|
||||||
|
if isinstance(act, ActionTakeoff):
|
||||||
|
self._do_takeoff(act)
|
||||||
|
elif isinstance(act, (ActionHover, ActionHold)):
|
||||||
|
self._do_hold_position("hover" if isinstance(act, ActionHover) else "hold")
|
||||||
|
elif isinstance(act, ActionWait):
|
||||||
|
self._do_wait(act)
|
||||||
|
elif isinstance(act, ActionGoto):
|
||||||
|
self._do_goto(act)
|
||||||
|
elif isinstance(act, ActionLand):
|
||||||
|
self._do_land()
|
||||||
|
elif isinstance(act, ActionReturnHome):
|
||||||
|
self._do_rtl()
|
||||||
|
else:
|
||||||
|
rospy.logwarn("未支持的动作: %r", act)
|
||||||
94
voice_drone/flight_bridge/ros1_node.py
Normal file
94
voice_drone/flight_bridge/ros1_node.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
ROS1 伴飞桥节点:订阅 JSON(std_msgs/String),校验 flight_intent v1 后调用 MAVROS 执行。
|
||||||
|
|
||||||
|
话题(默认):
|
||||||
|
全局 /input(std_msgs/String),与 rostopic pub、语音端 ROCKET_FLIGHT_BRIDGE_TOPIC 一致。
|
||||||
|
可通过私有参数 ~input_topic 覆盖(须带前导 / 才是全局名)。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
rostopic pub -1 /input std_msgs/String \\
|
||||||
|
'{data: "{\"is_flight_intent\":true,\"version\":1,\"actions\":[{\"type\":\"land\",\"args\":{}}],\"summary\":\"降\"}"}'
|
||||||
|
|
||||||
|
前提是:已 roslaunch mavros px4.launch …,且 /mavros/state connected。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
from voice_drone.core.flight_intent import parse_flight_intent_dict
|
||||||
|
from voice_drone.flight_bridge.ros1_mavros_executor import MavrosFlightExecutor
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_flight_intent_dict(raw: dict) -> dict:
|
||||||
|
"""允许仅传 {actions, summary?},补全顶层字段。"""
|
||||||
|
if raw.get("is_flight_intent") is True and raw.get("version") == 1:
|
||||||
|
return raw
|
||||||
|
actions = raw.get("actions")
|
||||||
|
if isinstance(actions, list) and actions:
|
||||||
|
summary = str(raw.get("summary") or "bridge").strip() or "bridge"
|
||||||
|
return {
|
||||||
|
"is_flight_intent": True,
|
||||||
|
"version": 1,
|
||||||
|
"actions": actions,
|
||||||
|
"summary": summary,
|
||||||
|
}
|
||||||
|
raise ValueError("JSON 须为完整 flight_intent 或含 actions 数组")
|
||||||
|
|
||||||
|
|
||||||
|
class FlightIntentBridgeNode:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._exec = MavrosFlightExecutor()
|
||||||
|
self._busy = threading.Lock()
|
||||||
|
# 默认用绝对名 /input:若用相对名 "input",anonymous 节点下会变成 /flight_intent_mavros_bridge_*/input,与 rostopic pub /input 不一致。
|
||||||
|
topic = rospy.get_param("~input_topic", "/input")
|
||||||
|
self._sub = rospy.Subscriber(
|
||||||
|
topic,
|
||||||
|
String,
|
||||||
|
self._on_input,
|
||||||
|
queue_size=1,
|
||||||
|
)
|
||||||
|
rospy.loginfo("flight_intent_bridge 就绪:订阅 %s", topic)
|
||||||
|
|
||||||
|
def _on_input(self, msg: String) -> None:
|
||||||
|
data = (msg.data or "").strip()
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
if not self._busy.acquire(blocking=False):
|
||||||
|
rospy.logwarn("上一段 flight_intent 仍在执行,忽略本条")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _run() -> None:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
raw = json.loads(data)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
rospy.logerr("JSON 解析失败: %s", e)
|
||||||
|
return
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
rospy.logerr("顶层须为 JSON object")
|
||||||
|
return
|
||||||
|
raw = _coerce_flight_intent_dict(raw)
|
||||||
|
parsed, errors = parse_flight_intent_dict(raw)
|
||||||
|
if errors or parsed is None:
|
||||||
|
rospy.logerr("flight_intent 校验失败: %s", errors)
|
||||||
|
return
|
||||||
|
self._exec.execute(parsed)
|
||||||
|
finally:
|
||||||
|
self._busy.release()
|
||||||
|
|
||||||
|
threading.Thread(target=_run, daemon=True, name="flight-intent-exec").start()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
FlightIntentBridgeNode()
|
||||||
|
rospy.spin()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
14
voice_drone/logging_/__init__.py
Normal file
14
voice_drone/logging_/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
日志系统入口
|
||||||
|
|
||||||
|
提供 get_logger 接口,返回带颜色控制台输出的 logger。
|
||||||
|
|
||||||
|
注意:文件夹名已改为 logging_ 以避免与标准库的 logging 模块冲突。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 由于文件夹名已改为 logging_,不再与标准库的 logging 冲突
|
||||||
|
# 直接导入标准库的 logging 模块
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 现在可以安全导入 color_logger
|
||||||
|
from .color_logger import get_logger
|
||||||
107
voice_drone/logging_/color_logger.py
Normal file
107
voice_drone/logging_/color_logger.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# 直接导入标准库的 logging(不再有命名冲突)
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取系统日志配置: level / debug 等
|
||||||
|
from voice_drone.core.configuration import SYSTEM_LOGGING_CONFIG
|
||||||
|
except Exception:
|
||||||
|
SYSTEM_LOGGING_CONFIG = {"level": "INFO", "debug": False}
|
||||||
|
|
||||||
|
|
||||||
|
class ColorFormatter(logging.Formatter):
|
||||||
|
"""简单彩色日志格式化器."""
|
||||||
|
|
||||||
|
COLORS = {
|
||||||
|
"DEBUG": "\033[36m", # 青色
|
||||||
|
"INFO": "\033[32m", # 绿色
|
||||||
|
"WARNING": "\033[33m", # 黄色
|
||||||
|
"ERROR": "\033[31m", # 红色
|
||||||
|
"CRITICAL": "\033[41m", # 红底
|
||||||
|
}
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
level = record.levelname
|
||||||
|
color = self.COLORS.get(level, "")
|
||||||
|
msg = super().format(record)
|
||||||
|
return f"{color}{msg}{self.RESET}" if color else msg
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_level(config_level: Optional[str], debug_flag: bool) -> int:
|
||||||
|
"""根据配置字符串和 debug 标志,解析出 logging 等级."""
|
||||||
|
if config_level is None:
|
||||||
|
config_level = "INFO"
|
||||||
|
level_str = str(config_level).upper()
|
||||||
|
|
||||||
|
# 特殊值: 关闭日志
|
||||||
|
if level_str in ("OFF", "NONE", "DISABLE", "DISABLED"):
|
||||||
|
return logging.CRITICAL + 1 # 实际等同于全关
|
||||||
|
|
||||||
|
if debug_flag:
|
||||||
|
return logging.DEBUG
|
||||||
|
|
||||||
|
mapping = {
|
||||||
|
"DEBUG": logging.DEBUG,
|
||||||
|
"INFO": logging.INFO,
|
||||||
|
"WARNING": logging.WARNING,
|
||||||
|
"WARN": logging.WARNING,
|
||||||
|
"ERROR": logging.ERROR,
|
||||||
|
"CRITICAL": logging.CRITICAL,
|
||||||
|
}
|
||||||
|
return mapping.get(level_str, logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str = "app") -> logging.Logger:
|
||||||
|
"""
|
||||||
|
获取一个带颜色控制台输出的 logger。
|
||||||
|
|
||||||
|
- name: logger 名称,按模块/子系统区分即可,例如 "stt.onnx"
|
||||||
|
日志级别与启用/禁用由配置文件 `system.yaml` 决定:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
logging:
|
||||||
|
level: "INFO" # DEBUG/INFO/WARNING/ERROR/CRITICAL/OFF
|
||||||
|
debug: false # true 时强制 DEBUG
|
||||||
|
```
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from voice_drone.logging_ import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("stt.onnx")
|
||||||
|
logger.info("模型加载完成")
|
||||||
|
logger.warning("预热失败,将继续执行")
|
||||||
|
logger.error("推理出错")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
if logger.handlers:
|
||||||
|
# 已经配置过,直接复用,但更新 level
|
||||||
|
level = _resolve_level(
|
||||||
|
SYSTEM_LOGGING_CONFIG.get("level") if isinstance(SYSTEM_LOGGING_CONFIG, dict) else None,
|
||||||
|
SYSTEM_LOGGING_CONFIG.get("debug", False) if isinstance(SYSTEM_LOGGING_CONFIG, dict) else False,
|
||||||
|
)
|
||||||
|
logger.setLevel(level)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
# 解析配置的日志级别
|
||||||
|
if isinstance(SYSTEM_LOGGING_CONFIG, dict):
|
||||||
|
level = _resolve_level(SYSTEM_LOGGING_CONFIG.get("level"), SYSTEM_LOGGING_CONFIG.get("debug", False))
|
||||||
|
else:
|
||||||
|
level = logging.INFO
|
||||||
|
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
fmt = "[%(asctime)s] [%(levelname)s] %(message)s"
|
||||||
|
datefmt = "%H:%M:%S"
|
||||||
|
handler.setFormatter(ColorFormatter(fmt=fmt, datefmt=datefmt))
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
# 如果 level 被解析为 "关闭",则仍然返回 logger,但不会输出普通日志
|
||||||
|
return logger
|
||||||
|
|
||||||
2271
voice_drone/main_app.py
Normal file
2271
voice_drone/main_app.py
Normal file
File diff suppressed because it is too large
Load Diff
1
voice_drone/tools/__init__.py
Normal file
1
voice_drone/tools/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""One-off helpers (CLI tools)."""
|
||||||
15
voice_drone/tools/config_loader.py
Normal file
15
voice_drone/tools/config_loader.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import yaml
|
||||||
|
|
||||||
|
def load_config(config_path):
|
||||||
|
"""
|
||||||
|
加载配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path (str): 配置文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 返回解析后的配置字典
|
||||||
|
"""
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
return config
|
||||||
66
voice_drone/tools/publish_flight_intent_ros_once.py
Normal file
66
voice_drone/tools/publish_flight_intent_ros_once.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Publish one flight_intent JSON to a ROS1 std_msgs/String topic (for flight bridge).
|
||||||
|
|
||||||
|
Run after sourcing ROS, e.g.:
|
||||||
|
source /opt/ros/noetic/setup.bash && python3 -m voice_drone.tools.publish_flight_intent_ros_once /tmp/intent.json
|
||||||
|
|
||||||
|
Or from repo root(须在已 source ROS 的 shell 中,且勿覆盖 ROS 的 PYTHONPATH;应 prepend):
|
||||||
|
export PYTHONPATH="$PWD:$PYTHONPATH" && python3 -m voice_drone.tools.publish_flight_intent_ros_once ...
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
|
||||||
|
def _load_payload(path: str | None) -> dict[str, Any]:
|
||||||
|
if path and path != "-":
|
||||||
|
raw = Path(path).read_text(encoding="utf-8")
|
||||||
|
else:
|
||||||
|
raw = sys.stdin.read()
|
||||||
|
data = json.loads(raw)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("root must be a JSON object")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(description="Publish flight_intent JSON once to ROS String topic")
|
||||||
|
ap.add_argument(
|
||||||
|
"json_file",
|
||||||
|
nargs="?",
|
||||||
|
default="-",
|
||||||
|
help="Path to JSON file, or - for stdin (default: stdin)",
|
||||||
|
)
|
||||||
|
ap.add_argument("--topic", default="/input", help="ROS topic (default: /input)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--wait-subscribers",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Seconds to wait for at least one subscriber (default: 0)",
|
||||||
|
)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
payload = _load_payload(args.json_file if args.json_file != "-" else None)
|
||||||
|
json_str = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|
||||||
|
rospy.init_node("flight_intent_ros_publisher_once", anonymous=True)
|
||||||
|
pub = rospy.Publisher(args.topic, String, queue_size=1, latch=False)
|
||||||
|
deadline = rospy.Time.now() + rospy.Duration(args.wait_subscribers)
|
||||||
|
while args.wait_subscribers > 0 and pub.get_num_connections() < 1 and not rospy.is_shutdown():
|
||||||
|
if rospy.Time.now() > deadline:
|
||||||
|
break
|
||||||
|
rospy.sleep(0.05)
|
||||||
|
rospy.sleep(0.15)
|
||||||
|
pub.publish(String(data=json_str))
|
||||||
|
rospy.sleep(0.25)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
17
voice_drone/tools/wrapper.py
Normal file
17
voice_drone/tools/wrapper.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
def time_cost(tag=None):
|
||||||
|
"""
|
||||||
|
装饰器,自定义函数耗时打印信息
|
||||||
|
:param tag: 可选,自定义标志(说明当前执行的函数)
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
start_time = time.time()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
end_time = time.time()
|
||||||
|
label = tag if tag is not None else func.__name__
|
||||||
|
print(f"耗时统计[{label}]: {(end_time - start_time)*1000:.2f} 毫秒")
|
||||||
|
return result
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
20
with_system_alsa.sh
Normal file
20
with_system_alsa.sh
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Conda 里 PyAudio 自带的 libportaudio 会加载同目录 libasound,从而去 conda/.../lib/alsa-lib/
|
||||||
|
# 找插件(多数环境不完整,终端刷屏且麦克风电平异常)。
|
||||||
|
# 在启动 Python **之前** 预加载系统的 libasound.so.2 可恢复正常 ALSA 行为。
|
||||||
|
#
|
||||||
|
# 用法(项目根目录):
|
||||||
|
# bash with_system_alsa.sh python src/mic_level_check.py
|
||||||
|
# bash with_system_alsa.sh python src/rocket_drone_audio.py
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
ARCH="$(uname -m)"
|
||||||
|
case "${ARCH}" in
|
||||||
|
aarch64) ASO="/usr/lib/aarch64-linux-gnu/libasound.so.2" ;;
|
||||||
|
x86_64) ASO="/usr/lib/x86_64-linux-gnu/libasound.so.2" ;;
|
||||||
|
*) ASO="" ;;
|
||||||
|
esac
|
||||||
|
if [[ -n "${ASO}" && -f "${ASO}" ]]; then
|
||||||
|
export LD_PRELOAD="${ASO}${LD_PRELOAD:+:${LD_PRELOAD}}"
|
||||||
|
fi
|
||||||
|
exec "$@"
|
||||||
Loading…
x
Reference in New Issue
Block a user