""" 阿里云百炼 LLM 服务实现 """ from __future__ import annotations import asyncio import copy import queue import threading import time from http import HTTPStatus from typing import List, Dict, Any, Tuple, AsyncIterator, Optional from loguru import logger from dashscope import Generation import dashscope from app.services.llm_service import LLMServiceInterface, build_system_prompt from app.services.tool_definitions import LLM_AGENT_TOOLS from app.services.tool_executor import dispatch_tool from app.config import settings def _stream_delta_from_response(response) -> str: """从 DashScope 流式响应块解析增量文本(incremental_output=True)。""" try: if response.status_code != HTTPStatus.OK: return "" choices = response.output.choices if not choices: return "" c0 = choices[0] if isinstance(c0, dict): msg = c0.get("message") or {} return msg.get("content") or "" msg = getattr(c0, "message", None) if msg is None: return "" if isinstance(msg, dict): return msg.get("content") or "" return getattr(msg, "content", None) or "" except (AttributeError, IndexError, KeyError, TypeError): return "" def _function_payload(fn_obj: Any) -> tuple[str, str]: if isinstance(fn_obj, dict): return ( str(fn_obj.get("name") or ""), str(fn_obj.get("arguments") if fn_obj.get("arguments") is not None else "{}"), ) name = getattr(fn_obj, "name", None) or "" args = getattr(fn_obj, "arguments", None) if args is None: args = "{}" return str(name), str(args) def _normalize_tool_call(tc: Any) -> dict[str, Any]: if isinstance(tc, dict): fn_raw = tc.get("function") name, arguments = _function_payload(fn_raw) return { "id": tc.get("id"), "type": tc.get("type") or "function", "function": {"name": name, "arguments": arguments}, } fn = getattr(tc, "function", None) name, arguments = _function_payload(fn) return { "id": getattr(tc, "id", None), "type": getattr(tc, "type", None) or "function", "function": {"name": name, "arguments": arguments}, } def _assistant_message_to_dict(msg: Any) -> dict[str, Any]: """将本轮 assistant 消息转为 API 可接受的 dict(含 tool_calls)。""" if isinstance(msg, dict): content = msg.get("content") tcs = msg.get("tool_calls") else: content = getattr(msg, "content", None) tcs = getattr(msg, "tool_calls", None) out: dict[str, Any] = {"role": "assistant", "content": content} if tcs: out["tool_calls"] = [_normalize_tool_call(x) for x in tcs] return out def _run_tool_agent_sync( messages: List[Dict[str, Any]], session_id: str, turn_id: str, ) -> str: """同步:多步 Generation + 执行工具,直至模型输出无 tool_calls。""" msgs: List[Dict[str, Any]] = copy.deepcopy(messages) max_steps = max(1, int(settings.LLM_AGENT_MAX_STEPS)) cap = max(500, int(settings.LLM_TOOL_RESULT_MAX_CHARS)) for step in range(max_steps): logger.debug( f"[LLM/tools] step={step} session={session_id[:8]} turn={turn_id[:8]}" ) resp = Generation.call( model=settings.LLM_MODEL, messages=msgs, tools=LLM_AGENT_TOOLS, tool_choice="auto", result_format="message", max_tokens=settings.LLM_MAX_TOKENS, temperature=settings.LLM_TEMPERATURE, timeout=settings.LLM_TIMEOUT, ) if resp.status_code != HTTPStatus.OK: err = getattr(resp, "message", None) or str(resp) logger.error(f"[LLM/tools] API error: {err}") return f"模型调用失败:{err}" msg = resp.output.choices[0].message raw_assistant = _assistant_message_to_dict(msg) tool_calls = raw_assistant.get("tool_calls") if not tool_calls: text = raw_assistant.get("content") out = (text or "").strip() if text is not None else "" logger.info( f"[LLM/tools] done step={step} reply_chars={len(out)} " f"session={session_id[:8]} turn={turn_id[:8]}" ) return text or "" msgs.append(raw_assistant) for tc in tool_calls: fn = tc.get("function") or {} name = fn.get("name", "") if isinstance(fn, dict) else "" raw_args = fn.get("arguments", "{}") if isinstance(fn, dict) else "{}" tid = str(tc.get("id") or "") logger.info( f"[LLM/tools] call {name!r} tool_call_id={tid[:20]!r} " f"session={session_id[:8]}" ) try: result = dispatch_tool(name, raw_args) except Exception as e: logger.exception("[LLM/tools] dispatch error") result = f"工具执行异常:{e}" if len(result) > cap: result = result[:cap] + "\n…(截断)" msgs.append( { "role": "tool", "content": result, "tool_call_id": tid, "name": name, } ) logger.warning("[LLM/tools] max steps exceeded") return "步骤过多未能完成,请把问题说简短一些或稍后再试。" class DashScopeLLMService(LLMServiceInterface): """阿里云百炼 LLM 服务实现""" def __init__(self): self._initialized = False async def initialize(self) -> bool: """初始化 DashScope SDK""" try: dashscope.api_key = settings.DASHSCOPE_API_KEY logger.info(f"DashScope LLM 初始化成功, 模型: {settings.LLM_MODEL}") self._initialized = True return True except Exception as e: logger.error(f"DashScope LLM 初始化失败: {e}") return False async def chat( self, messages: List[Dict[str, str]], session_id: str = "", turn_id: str = "", ) -> Tuple[str, float]: """ 调用阿里云百炼 LLM(非流式,兼容旧逻辑与测试) """ if not self._initialized: raise RuntimeError("LLM 服务未初始化") t0 = time.time() try: if settings.LLM_TOOLS_ENABLED: text = await asyncio.to_thread( _run_tool_agent_sync, copy.deepcopy(messages), session_id, turn_id, ) elapsed = time.time() - t0 logger.debug( f"[LLM] session={session_id[:8]} turn={turn_id[:8]} " f"推理成功( tools ), 耗时={elapsed:.2f}s, 内容={text[:80]}..." ) return text, elapsed response = Generation.call( model=settings.LLM_MODEL, messages=messages, result_format="message", max_tokens=settings.LLM_MAX_TOKENS, temperature=settings.LLM_TEMPERATURE, timeout=settings.LLM_TIMEOUT, ) elapsed = time.time() - t0 if response.status_code == 200: msg = response.output.choices[0].message if isinstance(msg, dict): content = msg.get("content") or "" else: content = getattr(msg, "content", None) or "" logger.debug( f"[LLM] session={session_id[:8]} turn={turn_id[:8]} " f"推理成功, 耗时={elapsed:.2f}s, 内容={content[:80]}..." ) return content, elapsed else: logger.error( f"[LLM] 调用失败: code={response.status_code}, " f"message={response.message}" ) raise Exception(f"LLM API 错误: {response.message}") except Exception as e: elapsed = time.time() - t0 logger.error( f"[LLM] 异常: session={session_id[:8]} turn={turn_id[:8]} " f"错误={e}, 耗时={elapsed:.2f}s" ) raise async def chat_stream( self, messages: List[Dict[str, str]], session_id: str = "", turn_id: str = "", ) -> AsyncIterator[str]: """ DashScope 流式:stream=True + incremental_output=True。 在单独线程中迭代 SDK 同步生成器,经 asyncio.Queue 推到事件循环。 """ if not self._initialized: raise RuntimeError("LLM 服务未初始化") t0 = time.time() q: queue.Queue = queue.Queue() def worker(): try: stream = Generation.call( model=settings.LLM_MODEL, messages=messages, result_format="message", max_tokens=settings.LLM_MAX_TOKENS, temperature=settings.LLM_TEMPERATURE, timeout=settings.LLM_TIMEOUT, stream=True, incremental_output=True, ) for chunk in stream: if chunk.status_code != HTTPStatus.OK: err = getattr(chunk, "message", None) or str(chunk) logger.error( f"[LLM] 流式错误: code={chunk.status_code}, message={err}" ) q.put(("err", RuntimeError(f"LLM API 错误: {err}"))) return delta = _stream_delta_from_response(chunk) if delta: q.put(("delta", delta)) q.put(("done", None)) except Exception as e: logger.error(f"[LLM] 流式异常: {e}") q.put(("ex", e)) threading.Thread(target=worker, daemon=True).start() try: while True: kind, payload = await asyncio.to_thread(q.get) if kind == "delta": yield payload elif kind == "done": break elif kind == "err": raise payload elif kind == "ex": raise payload finally: elapsed = time.time() - t0 logger.debug( f"[LLM/stream] session={session_id[:8]} turn={turn_id[:8]} " f"结束, 总耗时={elapsed:.2f}s" ) async def chat_stream_with_tools( self, messages: List[Dict[str, Any]], session_id: str = "", turn_id: str = "", ) -> AsyncIterator[str]: """先跑 function-calling 闭环,再分块 yield,兼容现有 handler 流式 TTS。""" if not self._initialized: raise RuntimeError("LLM 服务未初始化") if not settings.LLM_TOOLS_ENABLED: async for chunk in self.chat_stream( messages, # type: ignore[arg-type] session_id=session_id, turn_id=turn_id, ): yield chunk return t0 = time.time() text = await asyncio.to_thread( _run_tool_agent_sync, messages, session_id, turn_id ) if text: n = len(text) if n <= 72: yield text else: step = max(24, n // 24) for i in range(0, n, step): yield text[i : i + step] logger.debug( f"[LLM/tools/stream] session={session_id[:8]} turn={turn_id[:8]} " f"总耗时={time.time() - t0:.2f}s" ) async def build_messages( self, user_text: str, history: List[Dict[str, str]] = None, px4: Optional[Dict[str, Any]] = None, *, enable_tools: bool = False, ) -> List[Dict[str, Any]]: messages: List[Dict[str, Any]] = [ { "role": "system", "content": build_system_prompt(px4, enable_tools=enable_tools), } ] if history: messages.extend(history[-settings.LLM_CONTEXT_TURNS * 2 :]) messages.append({"role": "user", "content": user_text}) return messages async def shutdown(self): """关闭服务""" self._initialized = False logger.info("DashScope LLM 服务已关闭")