"""
阿里云百炼 LLM 服务实现
"""
from __future__ import annotations
import asyncio
import copy
import queue
import threading
import time
from http import HTTPStatus
from typing import List, Dict, Any, Tuple, AsyncIterator, Optional
from loguru import logger
from dashscope import Generation
import dashscope
from app.services.llm_service import LLMServiceInterface, build_system_prompt
from app.services.tool_definitions import LLM_AGENT_TOOLS
from app.services.tool_executor import dispatch_tool
from app.config import settings
def _stream_delta_from_response(response) -> str:
"""从 DashScope 流式响应块解析增量文本incremental_output=True"""
try:
if response.status_code != HTTPStatus.OK:
return ""
choices = response.output.choices
if not choices:
return ""
c0 = choices[0]
if isinstance(c0, dict):
msg = c0.get("message") or {}
return msg.get("content") or ""
msg = getattr(c0, "message", None)
if msg is None:
return ""
if isinstance(msg, dict):
return msg.get("content") or ""
return getattr(msg, "content", None) or ""
except (AttributeError, IndexError, KeyError, TypeError):
return ""
def _function_payload(fn_obj: Any) -> tuple[str, str]:
if isinstance(fn_obj, dict):
return (
str(fn_obj.get("name") or ""),
str(fn_obj.get("arguments") if fn_obj.get("arguments") is not None else "{}"),
)
name = getattr(fn_obj, "name", None) or ""
args = getattr(fn_obj, "arguments", None)
if args is None:
args = "{}"
return str(name), str(args)
def _normalize_tool_call(tc: Any) -> dict[str, Any]:
    """Normalize one tool call (dict- or attribute-shaped) into the
    OpenAI-compatible ``{"id", "type", "function": {...}}`` dict form.

    A missing/falsy type defaults to "function".
    """
    if isinstance(tc, dict):
        call_id = tc.get("id")
        call_type = tc.get("type") or "function"
        name, arguments = _function_payload(tc.get("function"))
    else:
        call_id = getattr(tc, "id", None)
        call_type = getattr(tc, "type", None) or "function"
        name, arguments = _function_payload(getattr(tc, "function", None))
    return {
        "id": call_id,
        "type": call_type,
        "function": {"name": name, "arguments": arguments},
    }
def _assistant_message_to_dict(msg: Any) -> dict[str, Any]:
    """Convert this turn's assistant message (dict- or attribute-shaped)
    into an API-acceptable dict, including normalized ``tool_calls``.

    ``tool_calls`` is present in the result only when the message carries a
    non-empty tool-call list.
    """
    if isinstance(msg, dict):
        content = msg.get("content")
        calls = msg.get("tool_calls")
    else:
        content = getattr(msg, "content", None)
        calls = getattr(msg, "tool_calls", None)
    result: dict[str, Any] = {"role": "assistant", "content": content}
    if calls:
        result["tool_calls"] = [_normalize_tool_call(item) for item in calls]
    return result
def _run_tool_agent_sync(
    messages: List[Dict[str, Any]],
    session_id: str,
    turn_id: str,
) -> str:
    """Synchronous multi-step tool loop: call Generation, execute any requested
    tools, append their results, and repeat until the model replies without
    ``tool_calls`` (or the step budget runs out).

    Args:
        messages: Conversation history in API message format; deep-copied
            before any mutation, so the caller's list is untouched.
        session_id: Session identifier, used only for log correlation.
        turn_id: Turn identifier, used only for log correlation.

    Returns:
        The model's final text reply, or a Chinese-language error string on
        API failure / step-budget exhaustion.
    """
    # Private working copy — this list is mutated (assistant + tool messages
    # appended) on every loop iteration.
    msgs: List[Dict[str, Any]] = copy.deepcopy(messages)
    # At least one step, regardless of configuration.
    max_steps = max(1, int(settings.LLM_AGENT_MAX_STEPS))
    # Per-tool-result character cap (floor of 500) to keep prompts bounded.
    cap = max(500, int(settings.LLM_TOOL_RESULT_MAX_CHARS))
    for step in range(max_steps):
        logger.debug(
            f"[LLM/tools] step={step} session={session_id[:8]} turn={turn_id[:8]}"
        )
        resp = Generation.call(
            model=settings.LLM_MODEL,
            messages=msgs,
            tools=LLM_AGENT_TOOLS,
            tool_choice="auto",
            result_format="message",
            max_tokens=settings.LLM_MAX_TOKENS,
            temperature=settings.LLM_TEMPERATURE,
            timeout=settings.LLM_TIMEOUT,
        )
        if resp.status_code != HTTPStatus.OK:
            err = getattr(resp, "message", None) or str(resp)
            logger.error(f"[LLM/tools] API error: {err}")
            return f"模型调用失败:{err}"
        msg = resp.output.choices[0].message
        raw_assistant = _assistant_message_to_dict(msg)
        tool_calls = raw_assistant.get("tool_calls")
        if not tool_calls:
            # No tool calls: the model produced its final answer.
            text = raw_assistant.get("content")
            # `out` is computed only for the reply-length log line below;
            # the unstripped text is what gets returned.
            out = (text or "").strip() if text is not None else ""
            logger.info(
                f"[LLM/tools] done step={step} reply_chars={len(out)} "
                f"session={session_id[:8]} turn={turn_id[:8]}"
            )
            return text or ""
        # Record the assistant turn (with tool_calls) before the tool results,
        # as the messages API requires.
        msgs.append(raw_assistant)
        for tc in tool_calls:
            fn = tc.get("function") or {}
            name = fn.get("name", "") if isinstance(fn, dict) else ""
            raw_args = fn.get("arguments", "{}") if isinstance(fn, dict) else "{}"
            tid = str(tc.get("id") or "")
            logger.info(
                f"[LLM/tools] call {name!r} tool_call_id={tid[:20]!r} "
                f"session={session_id[:8]}"
            )
            try:
                result = dispatch_tool(name, raw_args)
            except Exception as e:
                # Tool failures are fed back to the model as text rather than
                # aborting the whole turn.
                logger.exception("[LLM/tools] dispatch error")
                result = f"工具执行异常:{e}"
            # Truncate oversized tool output so it cannot blow up the prompt.
            if len(result) > cap:
                result = result[:cap] + "\n…(截断)"
            msgs.append(
                {
                    "role": "tool",
                    "content": result,
                    "tool_call_id": tid,
                    "name": name,
                }
            )
    # Budget exhausted without a tool-free reply.
    logger.warning("[LLM/tools] max steps exceeded")
    return "步骤过多未能完成,请把问题说简短一些或稍后再试。"
class DashScopeLLMService(LLMServiceInterface):
    """Alibaba Cloud Bailian (DashScope) LLM service implementation."""

    def __init__(self) -> None:
        # Guards every public call; flipped to True by initialize().
        self._initialized = False

    async def initialize(self) -> bool:
        """Initialize the DashScope SDK by installing the API key.

        Returns:
            True on success, False if configuration fails (logged, not raised).
        """
        try:
            dashscope.api_key = settings.DASHSCOPE_API_KEY
            logger.info(f"DashScope LLM 初始化成功, 模型: {settings.LLM_MODEL}")
            self._initialized = True
            return True
        except Exception as e:
            logger.error(f"DashScope LLM 初始化失败: {e}")
            return False

    async def chat(
        self,
        messages: List[Dict[str, str]],
        session_id: str = "",
        turn_id: str = "",
    ) -> Tuple[str, float]:
        """
        Call the Bailian LLM, non-streaming (kept for legacy logic and tests).

        When LLM_TOOLS_ENABLED is set, runs the synchronous tool-agent loop
        in a worker thread instead of a single Generation call.

        Returns:
            (reply_text, elapsed_seconds).

        Raises:
            RuntimeError: if the service has not been initialized.
            Exception: on API error or any underlying SDK failure.
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")
        t0 = time.time()
        try:
            if settings.LLM_TOOLS_ENABLED:
                # Deep copy is defensive here; _run_tool_agent_sync also
                # copies internally before mutating.
                text = await asyncio.to_thread(
                    _run_tool_agent_sync,
                    copy.deepcopy(messages),
                    session_id,
                    turn_id,
                )
                elapsed = time.time() - t0
                logger.debug(
                    f"[LLM] session={session_id[:8]} turn={turn_id[:8]} "
                    f"推理成功( tools ), 耗时={elapsed:.2f}s, 内容={text[:80]}..."
                )
                return text, elapsed
            # NOTE(review): this blocking SDK call runs on the event loop
            # thread (unlike the tools path above) — confirm intentional.
            response = Generation.call(
                model=settings.LLM_MODEL,
                messages=messages,
                result_format="message",
                max_tokens=settings.LLM_MAX_TOKENS,
                temperature=settings.LLM_TEMPERATURE,
                timeout=settings.LLM_TIMEOUT,
            )
            elapsed = time.time() - t0
            # NOTE(review): compares against the literal 200 rather than
            # HTTPStatus.OK used elsewhere in this file (equivalent value).
            if response.status_code == 200:
                msg = response.output.choices[0].message
                # Message may be dict-shaped or attribute-shaped depending on
                # SDK version.
                if isinstance(msg, dict):
                    content = msg.get("content") or ""
                else:
                    content = getattr(msg, "content", None) or ""
                logger.debug(
                    f"[LLM] session={session_id[:8]} turn={turn_id[:8]} "
                    f"推理成功, 耗时={elapsed:.2f}s, 内容={content[:80]}..."
                )
                return content, elapsed
            else:
                logger.error(
                    f"[LLM] 调用失败: code={response.status_code}, "
                    f"message={response.message}"
                )
                raise Exception(f"LLM API 错误: {response.message}")
        except Exception as e:
            elapsed = time.time() - t0
            logger.error(
                f"[LLM] 异常: session={session_id[:8]} turn={turn_id[:8]} "
                f"错误={e}, 耗时={elapsed:.2f}s"
            )
            raise

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        session_id: str = "",
        turn_id: str = "",
    ) -> AsyncIterator[str]:
        """
        DashScope streaming (stream=True + incremental_output=True).

        The SDK's synchronous generator is iterated in a dedicated thread and
        its deltas are forwarded to the event loop through a thread-safe
        queue; this coroutine re-raises any error the worker reports.

        Yields:
            Incremental text deltas as they arrive.

        Raises:
            RuntimeError: if the service has not been initialized, or on a
                non-OK stream chunk.
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")
        t0 = time.time()
        # Thread-safe handoff channel: worker thread -> event loop.
        q: queue.Queue = queue.Queue()

        def worker():
            # Runs in a daemon thread; communicates only via tagged tuples
            # ("delta" | "done" | "err" | "ex") pushed onto q.
            try:
                stream = Generation.call(
                    model=settings.LLM_MODEL,
                    messages=messages,
                    result_format="message",
                    max_tokens=settings.LLM_MAX_TOKENS,
                    temperature=settings.LLM_TEMPERATURE,
                    timeout=settings.LLM_TIMEOUT,
                    stream=True,
                    incremental_output=True,
                )
                for chunk in stream:
                    if chunk.status_code != HTTPStatus.OK:
                        err = getattr(chunk, "message", None) or str(chunk)
                        logger.error(
                            f"[LLM] 流式错误: code={chunk.status_code}, message={err}"
                        )
                        q.put(("err", RuntimeError(f"LLM API 错误: {err}")))
                        return
                    delta = _stream_delta_from_response(chunk)
                    if delta:
                        q.put(("delta", delta))
                q.put(("done", None))
            except Exception as e:
                logger.error(f"[LLM] 流式异常: {e}")
                q.put(("ex", e))

        threading.Thread(target=worker, daemon=True).start()
        try:
            while True:
                # q.get blocks, so it is offloaded to a thread to avoid
                # stalling the event loop.
                kind, payload = await asyncio.to_thread(q.get)
                if kind == "delta":
                    yield payload
                elif kind == "done":
                    break
                elif kind == "err":
                    raise payload
                elif kind == "ex":
                    raise payload
        finally:
            elapsed = time.time() - t0
            logger.debug(
                f"[LLM/stream] session={session_id[:8]} turn={turn_id[:8]} "
                f"结束, 总耗时={elapsed:.2f}s"
            )

    async def chat_stream_with_tools(
        self,
        messages: List[Dict[str, Any]],
        session_id: str = "",
        turn_id: str = "",
    ) -> AsyncIterator[str]:
        """Run the function-calling loop to completion, then yield the final
        text in chunks (keeps the existing handler's streaming-TTS contract).

        Falls back to plain chat_stream when tools are disabled.
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")
        if not settings.LLM_TOOLS_ENABLED:
            async for chunk in self.chat_stream(
                messages,  # type: ignore[arg-type]
                session_id=session_id,
                turn_id=turn_id,
            ):
                yield chunk
            return
        t0 = time.time()
        text = await asyncio.to_thread(
            _run_tool_agent_sync, messages, session_id, turn_id
        )
        if text:
            n = len(text)
            # Short replies go out whole; longer ones are split into at most
            # ~24 chunks (minimum 24 chars each) to pace downstream TTS.
            if n <= 72:
                yield text
            else:
                step = max(24, n // 24)
                for i in range(0, n, step):
                    yield text[i : i + step]
        logger.debug(
            f"[LLM/tools/stream] session={session_id[:8]} turn={turn_id[:8]} "
            f"总耗时={time.time() - t0:.2f}s"
        )

    async def build_messages(
        self,
        user_text: str,
        history: Optional[List[Dict[str, str]]] = None,
        px4: Optional[Dict[str, Any]] = None,
        *,
        enable_tools: bool = False,
    ) -> List[Dict[str, Any]]:
        """Assemble the message list: system prompt, truncated history, then
        the current user turn.

        History is trimmed to the most recent LLM_CONTEXT_TURNS exchanges
        (two messages per turn).
        """
        messages: List[Dict[str, Any]] = [
            {
                "role": "system",
                "content": build_system_prompt(px4, enable_tools=enable_tools),
            }
        ]
        if history:
            messages.extend(history[-settings.LLM_CONTEXT_TURNS * 2 :])
        messages.append({"role": "user", "content": user_text})
        return messages

    async def shutdown(self) -> None:
        """Shut down the service (marks it uninitialized; no SDK teardown needed)."""
        self._initialized = False
        logger.info("DashScope LLM 服务已关闭")