96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
"""
Single-turn Fun-ASR real-time recognition (DashScope WebSocket); shares DASHSCOPE_API_KEY with the Bailian LLM.
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import threading
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Callable, Coroutine, List, Optional
|
||
|
||
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
|
||
|
||
from app.config import settings
|
||
|
||
|
||
@dataclass
class FunAsrTurnState:
    """Mutable result buffer for a single Fun-ASR recognition turn.

    Written by the recognition callbacks on their own thread. After stop(),
    read it either from the thread that called stop() or while holding
    ``lock``.
    """

    # Identifier of the turn this state belongs to.
    turn_id: str
    # Non-empty texts of sentences that were reported as finished, in order.
    final_texts: List[str] = field(default_factory=list)
    # Text of the most recent sentence update (partial or final).
    last_text: str = ""
    # Message of the last reported ASR error, or None if none was reported.
    error: Optional[str] = None
    # Guards the fields above. The lock is synchronization plumbing, not data:
    # keeping it out of __repr__/__eq__ lets two states with the same content
    # compare equal and keeps the repr free of a lock object address.
    lock: threading.Lock = field(
        default_factory=threading.Lock, repr=False, compare=False
    )
|
||
|
||
|
||
@dataclass
class ActiveFunAsrTurn:
    """One in-progress Fun-ASR turn attached to a session."""

    # Identifier of the turn.
    turn_id: str
    # The recognizer object driving this turn.
    recognition: Recognition
    # Result buffer associated with this turn.
    state: FunAsrTurnState
|
||
|
||
|
||
def build_fun_asr_recognition(
    *,
    loop: asyncio.AbstractEventLoop,
    state: FunAsrTurnState,
    on_partial: Callable[[str, bool], Coroutine[Any, Any, None]],
    on_error_msg: Callable[[str], Coroutine[Any, Any, None]],
) -> Recognition:
    """Create a Fun-ASR ``Recognition`` whose callbacks feed ``state`` and ``loop``.

    The returned object is not started: calling start(), send_audio_frame()
    and stop() is the caller's responsibility (running stop() through
    ``asyncio.to_thread`` is recommended).

    Args:
        loop: Event loop on which the coroutines produced by ``on_partial``
            and ``on_error_msg`` are scheduled.
        state: Per-turn buffer; the callbacks update it under ``state.lock``.
        on_partial: Called as ``on_partial(text, is_end)`` for each sentence
            update that carries text or ends a sentence. ``is_end`` is the
            value of ``RecognitionResult.is_sentence_end`` for that update.
        on_error_msg: Called as ``on_error_msg(message)`` when the recognizer
            reports an error.

    Returns:
        The configured ``Recognition`` instance.
    """

    def _schedule(coro: Coroutine[Any, Any, None]) -> None:
        """Submit ``coro`` to ``loop`` from the callback thread, best effort."""
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # Raised when the loop has already been closed. Dropping the
            # update is deliberate, but the coroutine object must be closed
            # explicitly or Python emits "coroutine ... was never awaited".
            coro.close()

    def _sentence_dict(raw: Any) -> Optional[dict]:
        """Normalize a ``get_sentence()`` value to a single dict, or None."""
        if isinstance(raw, list):
            # A few SDK versions may return a list of segments; use the last
            # dict segment that actually carries a "text" value.
            raw = next(
                (
                    seg
                    for seg in reversed(raw)
                    if isinstance(seg, dict) and seg.get("text") is not None
                ),
                None,
            )
        return raw if isinstance(raw, dict) else None

    class _Cb(RecognitionCallback):
        def on_event(self, result: RecognitionResult) -> None:
            sentence = _sentence_dict(result.get_sentence())
            if sentence is None or sentence.get("heartbeat") is True:
                return

            text = (sentence.get("text") or "").strip()
            is_end = RecognitionResult.is_sentence_end(sentence)
            if not text and not is_end:
                # Empty update that does not close a sentence: nothing to report.
                return

            with state.lock:
                state.last_text = text
                if is_end and text:
                    state.final_texts.append(text)
            _schedule(on_partial(text, is_end))

        def on_error(self, result: RecognitionResult) -> None:
            msg = result.message or str(result) or "ASR error"
            with state.lock:
                state.error = msg
            _schedule(on_error_msg(msg))

    rec = Recognition(
        settings.DASHSCOPE_ASR_MODEL,
        _Cb(),
        "pcm",
        settings.ASR_AUDIO_SAMPLE_RATE,
        semantic_punctuation_enabled=settings.ASR_SEMANTIC_PUNCTUATION_ENABLED,
    )
    return rec
|