372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""
|
||
阿里云百炼 LLM 服务实现
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import copy
|
||
import queue
|
||
import threading
|
||
import time
|
||
from http import HTTPStatus
|
||
from typing import List, Dict, Any, Tuple, AsyncIterator, Optional
|
||
from loguru import logger
|
||
from dashscope import Generation
|
||
import dashscope
|
||
|
||
from app.services.llm_service import LLMServiceInterface, build_system_prompt
|
||
from app.services.tool_definitions import LLM_AGENT_TOOLS
|
||
from app.services.tool_executor import dispatch_tool
|
||
from app.config import settings
|
||
|
||
|
||
def _stream_delta_from_response(response) -> str:
|
||
"""从 DashScope 流式响应块解析增量文本(incremental_output=True)。"""
|
||
try:
|
||
if response.status_code != HTTPStatus.OK:
|
||
return ""
|
||
choices = response.output.choices
|
||
if not choices:
|
||
return ""
|
||
c0 = choices[0]
|
||
if isinstance(c0, dict):
|
||
msg = c0.get("message") or {}
|
||
return msg.get("content") or ""
|
||
msg = getattr(c0, "message", None)
|
||
if msg is None:
|
||
return ""
|
||
if isinstance(msg, dict):
|
||
return msg.get("content") or ""
|
||
return getattr(msg, "content", None) or ""
|
||
except (AttributeError, IndexError, KeyError, TypeError):
|
||
return ""
|
||
|
||
|
||
def _function_payload(fn_obj: Any) -> tuple[str, str]:
|
||
if isinstance(fn_obj, dict):
|
||
return (
|
||
str(fn_obj.get("name") or ""),
|
||
str(fn_obj.get("arguments") if fn_obj.get("arguments") is not None else "{}"),
|
||
)
|
||
name = getattr(fn_obj, "name", None) or ""
|
||
args = getattr(fn_obj, "arguments", None)
|
||
if args is None:
|
||
args = "{}"
|
||
return str(name), str(args)
|
||
|
||
|
||
def _normalize_tool_call(tc: Any) -> dict[str, Any]:
    """Normalize a single tool call (dict or SDK object) into the plain
    ``{"id", "type", "function": {"name", "arguments"}}`` shape the API
    accepts, defaulting ``type`` to "function" when absent."""
    if isinstance(tc, dict):
        call_id = tc.get("id")
        call_type = tc.get("type")
        fn_name, fn_args = _function_payload(tc.get("function"))
    else:
        call_id = getattr(tc, "id", None)
        call_type = getattr(tc, "type", None)
        fn_name, fn_args = _function_payload(getattr(tc, "function", None))
    return {
        "id": call_id,
        "type": call_type or "function",
        "function": {"name": fn_name, "arguments": fn_args},
    }
|
||
|
||
|
||
def _assistant_message_to_dict(msg: Any) -> dict[str, Any]:
|
||
"""将本轮 assistant 消息转为 API 可接受的 dict(含 tool_calls)。"""
|
||
if isinstance(msg, dict):
|
||
content = msg.get("content")
|
||
tcs = msg.get("tool_calls")
|
||
else:
|
||
content = getattr(msg, "content", None)
|
||
tcs = getattr(msg, "tool_calls", None)
|
||
out: dict[str, Any] = {"role": "assistant", "content": content}
|
||
if tcs:
|
||
out["tool_calls"] = [_normalize_tool_call(x) for x in tcs]
|
||
return out
|
||
|
||
|
||
def _run_tool_agent_sync(
    messages: List[Dict[str, Any]],
    session_id: str,
    turn_id: str,
) -> str:
    """Synchronous multi-step tool loop: call Generation, execute the tools
    the model requests, and repeat until a reply arrives without tool_calls.

    Args:
        messages: Full chat history; deep-copied, so the caller's list is
            never mutated.
        session_id: Session identifier, used here only for log correlation.
        turn_id: Turn identifier, used here only for log correlation.

    Returns:
        The model's final text answer, or a user-facing (Chinese) error
        string when the API call fails or the step budget is exhausted.
    """
    msgs: List[Dict[str, Any]] = copy.deepcopy(messages)
    # Guarantee at least one step even if the setting is misconfigured (<=0).
    max_steps = max(1, int(settings.LLM_AGENT_MAX_STEPS))
    # Per-tool-result character cap; the floor of 500 keeps results usable.
    cap = max(500, int(settings.LLM_TOOL_RESULT_MAX_CHARS))

    for step in range(max_steps):
        logger.debug(
            f"[LLM/tools] step={step} session={session_id[:8]} turn={turn_id[:8]}"
        )
        resp = Generation.call(
            model=settings.LLM_MODEL,
            messages=msgs,
            tools=LLM_AGENT_TOOLS,
            tool_choice="auto",
            result_format="message",
            max_tokens=settings.LLM_MAX_TOKENS,
            temperature=settings.LLM_TEMPERATURE,
            timeout=settings.LLM_TIMEOUT,
        )
        if resp.status_code != HTTPStatus.OK:
            err = getattr(resp, "message", None) or str(resp)
            logger.error(f"[LLM/tools] API error: {err}")
            return f"模型调用失败:{err}"

        msg = resp.output.choices[0].message
        # Normalize to a plain dict so the turn can be appended back to msgs.
        raw_assistant = _assistant_message_to_dict(msg)
        tool_calls = raw_assistant.get("tool_calls")

        if not tool_calls:
            # No tool requests: this is the final answer.
            text = raw_assistant.get("content")
            out = (text or "").strip() if text is not None else ""
            logger.info(
                f"[LLM/tools] done step={step} reply_chars={len(out)} "
                f"session={session_id[:8]} turn={turn_id[:8]}"
            )
            return text or ""

        # Record the assistant turn (with tool_calls) before the tool results,
        # as the messages API requires.
        msgs.append(raw_assistant)
        for tc in tool_calls:
            fn = tc.get("function") or {}
            name = fn.get("name", "") if isinstance(fn, dict) else ""
            raw_args = fn.get("arguments", "{}") if isinstance(fn, dict) else "{}"
            tid = str(tc.get("id") or "")
            logger.info(
                f"[LLM/tools] call {name!r} tool_call_id={tid[:20]!r} "
                f"session={session_id[:8]}"
            )
            try:
                result = dispatch_tool(name, raw_args)
            except Exception as e:
                # Feed the failure back to the model instead of aborting the loop.
                logger.exception("[LLM/tools] dispatch error")
                result = f"工具执行异常:{e}"
            if len(result) > cap:
                result = result[:cap] + "\n…(截断)"
            msgs.append(
                {
                    "role": "tool",
                    "content": result,
                    "tool_call_id": tid,
                    "name": name,
                }
            )

    logger.warning("[LLM/tools] max steps exceeded")
    return "步骤过多未能完成,请把问题说简短一些或稍后再试。"
|
||
|
||
|
||
class DashScopeLLMService(LLMServiceInterface):
    """Aliyun Bailian (DashScope) LLM service implementation."""

    def __init__(self):
        # Set to True only after initialize() succeeds; every entry point
        # checks this flag before touching the SDK.
        self._initialized = False

    async def initialize(self) -> bool:
        """Initialize the DashScope SDK (sets the API key).

        Returns:
            True on success, False on any failure (logged, not raised).
        """
        try:
            dashscope.api_key = settings.DASHSCOPE_API_KEY

            logger.info(f"DashScope LLM 初始化成功, 模型: {settings.LLM_MODEL}")
            self._initialized = True
            return True

        except Exception as e:
            logger.error(f"DashScope LLM 初始化失败: {e}")
            return False

    async def chat(
        self,
        messages: List[Dict[str, str]],
        session_id: str = "",
        turn_id: str = "",
    ) -> Tuple[str, float]:
        """Call the LLM non-streaming (kept for legacy logic and tests).

        When tools are enabled, delegates to the synchronous tool loop in a
        worker thread; otherwise performs a single Generation call.

        Returns:
            (reply_text, elapsed_seconds).

        Raises:
            RuntimeError: if the service was never initialized.
            Exception: on API errors (re-raised after logging).
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")

        t0 = time.time()

        try:
            if settings.LLM_TOOLS_ENABLED:
                # The tool loop is blocking; run it off the event loop.
                text = await asyncio.to_thread(
                    _run_tool_agent_sync,
                    copy.deepcopy(messages),
                    session_id,
                    turn_id,
                )
                elapsed = time.time() - t0
                logger.debug(
                    f"[LLM] session={session_id[:8]} turn={turn_id[:8]} "
                    f"推理成功( tools ), 耗时={elapsed:.2f}s, 内容={text[:80]}..."
                )
                return text, elapsed

            response = Generation.call(
                model=settings.LLM_MODEL,
                messages=messages,
                result_format="message",
                max_tokens=settings.LLM_MAX_TOKENS,
                temperature=settings.LLM_TEMPERATURE,
                timeout=settings.LLM_TIMEOUT,
            )

            elapsed = time.time() - t0

            # Consistency fix: compare against HTTPStatus.OK like the rest of
            # this module (was a bare `== 200`).
            if response.status_code == HTTPStatus.OK:
                msg = response.output.choices[0].message
                # The SDK may return the message as a dict or as an object.
                if isinstance(msg, dict):
                    content = msg.get("content") or ""
                else:
                    content = getattr(msg, "content", None) or ""
                logger.debug(
                    f"[LLM] session={session_id[:8]} turn={turn_id[:8]} "
                    f"推理成功, 耗时={elapsed:.2f}s, 内容={content[:80]}..."
                )
                return content, elapsed
            else:
                logger.error(
                    f"[LLM] 调用失败: code={response.status_code}, "
                    f"message={response.message}"
                )
                raise Exception(f"LLM API 错误: {response.message}")

        except Exception as e:
            elapsed = time.time() - t0
            logger.error(
                f"[LLM] 异常: session={session_id[:8]} turn={turn_id[:8]} "
                f"错误={e}, 耗时={elapsed:.2f}s"
            )
            raise

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        session_id: str = "",
        turn_id: str = "",
    ) -> AsyncIterator[str]:
        """
        DashScope streaming: stream=True + incremental_output=True.

        The SDK's synchronous generator is iterated in a daemon worker
        thread; text deltas are handed to the event loop through a
        thread-safe queue as (kind, payload) pairs.

        Yields:
            Incremental text deltas.

        Raises:
            RuntimeError: if the service was never initialized, or on an
                API error chunk.
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")

        t0 = time.time()
        q: queue.Queue = queue.Queue()

        def worker():
            # Runs in its own thread; all results flow through q.
            try:
                stream = Generation.call(
                    model=settings.LLM_MODEL,
                    messages=messages,
                    result_format="message",
                    max_tokens=settings.LLM_MAX_TOKENS,
                    temperature=settings.LLM_TEMPERATURE,
                    timeout=settings.LLM_TIMEOUT,
                    stream=True,
                    incremental_output=True,
                )
                for chunk in stream:
                    if chunk.status_code != HTTPStatus.OK:
                        err = getattr(chunk, "message", None) or str(chunk)
                        logger.error(
                            f"[LLM] 流式错误: code={chunk.status_code}, message={err}"
                        )
                        q.put(("err", RuntimeError(f"LLM API 错误: {err}")))
                        return
                    delta = _stream_delta_from_response(chunk)
                    if delta:
                        q.put(("delta", delta))
                q.put(("done", None))
            except Exception as e:
                logger.error(f"[LLM] 流式异常: {e}")
                q.put(("ex", e))

        threading.Thread(target=worker, daemon=True).start()

        try:
            while True:
                # q.get blocks; run it off the loop so asyncio isn't stalled.
                kind, payload = await asyncio.to_thread(q.get)
                if kind == "delta":
                    yield payload
                elif kind == "done":
                    break
                else:
                    # "err" (API error chunk) or "ex" (worker exception):
                    # propagate the exception object from the worker thread.
                    raise payload
        finally:
            elapsed = time.time() - t0
            logger.debug(
                f"[LLM/stream] session={session_id[:8]} turn={turn_id[:8]} "
                f"结束, 总耗时={elapsed:.2f}s"
            )

    async def chat_stream_with_tools(
        self,
        messages: List[Dict[str, Any]],
        session_id: str = "",
        turn_id: str = "",
    ) -> AsyncIterator[str]:
        """Run the function-calling loop to completion, then yield the final
        text in chunks so the existing handler's streaming TTS keeps working.

        Falls back to plain chat_stream when tools are disabled.
        """
        if not self._initialized:
            raise RuntimeError("LLM 服务未初始化")
        if not settings.LLM_TOOLS_ENABLED:
            async for chunk in self.chat_stream(
                messages,  # type: ignore[arg-type]
                session_id=session_id,
                turn_id=turn_id,
            ):
                yield chunk
            return

        t0 = time.time()
        text = await asyncio.to_thread(
            _run_tool_agent_sync, messages, session_id, turn_id
        )
        if text:
            n = len(text)
            if n <= 72:
                # Short replies go out in one piece.
                yield text
            else:
                # Emit roughly 24 chunks (each at least 24 chars) to mimic
                # incremental streaming for downstream consumers.
                step = max(24, n // 24)
                for i in range(0, n, step):
                    yield text[i : i + step]
        logger.debug(
            f"[LLM/tools/stream] session={session_id[:8]} turn={turn_id[:8]} "
            f"总耗时={time.time() - t0:.2f}s"
        )

    async def build_messages(
        self,
        user_text: str,
        history: Optional[List[Dict[str, str]]] = None,
        px4: Optional[Dict[str, Any]] = None,
        *,
        enable_tools: bool = False,
    ) -> List[Dict[str, Any]]:
        """Assemble the message list: system prompt + trimmed history + user turn.

        Args:
            user_text: The current user utterance.
            history: Prior turns; only the most recent
                LLM_CONTEXT_TURNS * 2 entries are kept.
            px4: Optional context passed into the system prompt builder.
            enable_tools: Whether the system prompt should advertise tools.
        """
        messages: List[Dict[str, Any]] = [
            {
                "role": "system",
                "content": build_system_prompt(px4, enable_tools=enable_tools),
            }
        ]

        if history:
            # Keep the last N user/assistant pairs to bound the context size.
            messages.extend(history[-settings.LLM_CONTEXT_TURNS * 2 :])

        messages.append({"role": "user", "content": user_text})

        return messages

    async def shutdown(self):
        """Shut down the service (clears the initialized flag)."""
        self._initialized = False
        logger.info("DashScope LLM 服务已关闭")
|