From 003ab2435796c96cda8f783cdc08f45572a33748 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 20 Nov 2025 22:20:11 +0800 Subject: [PATCH] =?UTF-8?q?=E9=97=AE=E9=A2=98=E6=8E=A8=E8=8D=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/facts/facts.py | 2 -- apps/scheduler/call/reply/direct_reply.py | 4 +++- apps/scheduler/call/suggest/suggest.py | 8 ++++++-- apps/scheduler/executor/step.py | 9 ++++----- apps/scheduler/pool/pool.py | 2 +- apps/services/llm.py | 12 ++++++++++++ apps/services/node.py | 21 ++++++++++++--------- 7 files changed, 38 insertions(+), 20 deletions(-) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index 6f8511825..64d05acdc 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -53,8 +53,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """初始化工具""" # 提取 llm_id 和 enable_thinking,避免重复传递 - llm_id = kwargs.pop("llm_id", None) - enable_thinking = kwargs.pop("enable_thinking", False) obj = cls( answer=executor.task.runtime.answer, diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py index 90a4b7c71..bb85a2d75 100644 --- a/apps/scheduler/call/reply/direct_reply.py +++ b/apps/scheduler/call/reply/direct_reply.py @@ -141,7 +141,9 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep # 首先返回文本内容 yield CallOutputChunk( type=CallOutputType.TEXT, - content=final_answer + content=DirectReplyOutput( + message=final_answer + ).model_dump(exclude_none=True, by_alias=True) ) # 处理附件 diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index dfaf59736..f533a78ec 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -79,8 +79,12 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO "content": executor.task.runtime.answer, }, ] + if "llm_id" in kwargs and kwargs["llm_id"]: + if not (await LLMManager.is_llm_exists(kwargs["llm_id"])): + kwargs["llm_id"] = executor.func_call_llm_id + else: + kwargs["llm_id"] = executor.func_call_llm_id obj = cls( - llm_id=executor.func_call_llm_id, name=executor.step.step.name, description=executor.step.step.description, node=node, @@ -138,7 +142,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO async def get_func_llm_id(self) -> FunctionLLM: """获取用于函数调用的LLM ID""" - func_call_llm = await LLMManager.get_llm_by_id(self.func_llm_id) + func_call_llm = await LLMManager.get_llm_by_id(self.llm_id) func_call_llm_config = FunctionCallConfig( provider=func_call_llm.provider, endpoint=func_call_llm.openai_base_url, diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index b248e80df..bf096b7d5 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -12,6 +12,7 @@ import jsonschema from pydantic import ConfigDict, Field from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.llm.llm import LLM from apps.scheduler.call.empty import Empty from apps.scheduler.call.facts.facts import FactsCall from apps.scheduler.call.reply.direct_reply import DirectReply @@ -98,7 +99,7 @@ class StepExecutor(BaseExecutor): self.task.state.step_id = self.step.step_id # type: ignore[arg-type] # type: ignore[arg-type] self.task.state.step_name = self.step.step.name - + logger.error(f"StepExecutor初始化时,step.step.node: {self.step}") # 获取并验证Call类 node_id = self.step.step.node # 获取node详情并存储 @@ -110,11 +111,9 @@ class StepExecutor(BaseExecutor): if self.node: call_cls = await StepExecutor.get_call_cls(self.node.call_id) - self._call_id = self.node.call_id else: # 可能是特殊的内置Node call_cls = await StepExecutor.get_call_cls(node_id) - self._call_id = node_id # 初始化Call Class,用户参数会覆盖node的参数 params: dict[str, Any] = ( @@ -126,10 +125,10 @@ class StepExecutor(BaseExecutor): params.update(input_params) # 对于LLM调用,注入enable_thinking参数 - if self._call_id == SpecialCallType.LLM.value: + if call_cls == LLM: params["llm_id"] = self.chat_llm_id params['enable_thinking'] = self.background.enable_thinking - + try: self.obj = await call_cls.instance(self, self.node, **params) except Exception: diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 13821f896..2410d4dec 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -163,7 +163,7 @@ class Pool: # 对于Plugin类型的节点,返回API插件Call类 from apps.scheduler.call.plugin import Plugin return Plugin - + # 从MongoDB里拿到数据 call_collection = MongoDB.get_collection("call") call_db_data = await call_collection.find_one({"_id": call_id}) diff --git a/apps/services/llm.py b/apps/services/llm.py index c6792cf61..babbb4e7b 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -128,6 +128,18 @@ class LLMManager: return "" return result.get("llm", {}).get("llm_id", "") + @staticmethod + async def is_llm_exists(llm_id: str) -> bool: + """ + 检查大模型是否存在 + + :param llm_id: 大模型ID + :return: 是否存在 + """ + llm_collection = MongoDB.get_collection("llm") + result = await llm_collection.find_one({"_id": llm_id}) + return result is not None + @staticmethod async def get_llm_by_id(llm_id: str) -> LLM: """ diff --git a/apps/services/node.py b/apps/services/node.py index 9a3505a9b..e864236c1 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -29,7 +29,7 @@ class NodeManager: return SpecialCallType.EMPTY.value elif node_id == SpecialCallType.PLUGIN.value: return SpecialCallType.PLUGIN.value - + # 其他节点类型:从数据库查询 node_collection = MongoDB.get_collection("node") node = await node_collection.find_one({"_id": node_id}, {"call_id": 1}) @@ -72,12 +72,14 @@ class NodeManager: # 如果在known_params中找到匹配的键,更新default值 properties[key]["default"] = known_params[key] # 递归处理嵌套的schema - properties[key] = NodeManager.merge_params_schema(value, known_params) + properties[key] = NodeManager.merge_params_schema( + value, known_params) elif params_schema.get("type") == "array": items = params_schema.get("items", {}) # 递归处理数组项 - params_schema["items"] = NodeManager.merge_params_schema(items, known_params) + params_schema["items"] = NodeManager.merge_params_schema( + items, known_params) return params_schema @@ -89,7 +91,7 @@ class NodeManager: if node_id == SpecialCallType.EMPTY.value: # 如果是空节点,返回空Schema return {}, {} - + if node_id == SpecialCallType.PLUGIN.value: # 如果是Plugin节点,返回API插件的默认Schema return { @@ -123,7 +125,7 @@ class NodeManager: "description": "HTTP状态码" } } - + # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) node_collection = MongoDB.get_collection("node") @@ -147,7 +149,7 @@ class NodeManager: output_schema = call_class.output_model.model_json_schema( # type: ignore[attr-defined] override=node_data.override_output if node_data.override_output else {}, ) - + # 特殊处理:对于循环节点,直接返回扁平化的输出参数结构 if call_id == "Loop": # 直接使用正确的扁平化格式,避免依赖JSON Schema转换 @@ -157,7 +159,7 @@ class NodeManager: "description": "实际执行的循环次数" }, "stop_reason": { - "type": "string", + "type": "string", "description": "停止原因" }, "variables": { @@ -173,9 +175,10 @@ class NodeManager: "description": "提取的文本内容" } } - + # 返回参数Schema return ( - NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), + NodeManager.merge_params_schema( + call_class.model_json_schema(), node_data.known_params or {}), output_schema, ) -- Gitee