diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py
index fa0ef8e4fc88ab295492772422bfee42e51ba906..30e487a5c4904a8011638b26bc78747edb301c65 100644
--- a/apps/llm/__init__.py
+++ b/apps/llm/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
 """Model invocation module"""
 
-from .embedding import Embedding
-from .generator import JsonGenerator
+from .embedding import Embedding, embedding
+from .generator import json_generator
 from .llm import LLM
 from .schema import LLMConfig
 from .token import token_calculator
@@ -10,7 +10,8 @@ from .token import token_calculator
 __all__ = [
     "LLM",
     "Embedding",
-    "JsonGenerator",
     "LLMConfig",
+    "embedding",
+    "json_generator",
     "token_calculator",
 ]
diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py
index 932ad983582d8fa547abbf846984a22855b6fb64..1e33fcff568b117ba6ac69f1f748aa2782679035 100644
--- a/apps/llm/embedding.py
+++ b/apps/llm/embedding.py
@@ -298,6 +298,8 @@ class Embedding:
     # Uses the global singleton VectorTableManager
     _table_manager = VectorTableManager()
+    _llm_config: LLMData | None = None
+    _provider: BaseProvider | None = None
 
     # Convenience accessors (pointing to the table classes of VectorTableManager)
     @property
@@ -330,38 +332,37 @@
         """Get MCPToolVector"""
         return self._table_manager.MCPToolVector
 
-    def __init__(self, llm_config: LLMData | None = None) -> None:
-        """Initialize the Embedding object"""
-        if not llm_config:
-            err = "[Embedding] LLM config not set"
-            _logger.error(err)
-            raise RuntimeError(err)
-        self._llm_config = llm_config
-        self._provider = _CLASS_DICT[llm_config.provider](llm_config)
-
     async def _get_embedding_dimension(self) -> int:
         """Get the dimension of the embedding"""
        embedding = await self.get_embedding(["test text"])
         return len(embedding[0])
 
-    async def init(self) -> None:
+    async def init(self, llm_config: LLMData | None) -> None:
         """
-        Initialize database tables and other resources before using Embedding
+        Initialize the Embedding config and resources
+
+        Sets the LLM config, probes the vector dimension, and makes sure the
+        database vector tables exist with the correct dimension. If the model
+        or the dimension has changed, the old tables are dropped and rebuilt.
 
-        Probes the embedding dimension and makes sure the vector tables exist
-        with the correct dimension. If the model or the dimension has changed,
-        the old tables are dropped and rebuilt to avoid vector-space mismatches.
+        :param llm_config: LLM config
+        :raises RuntimeError: raised when llm_config is None
         """
-        _logger.info(
-            "[Embedding] Start initializing vector tables, model=%s/%s",
-            self._llm_config.provider,
-            self._llm_config.modelName,
-        )
+        if llm_config is None:
+            err = "[Embedding] LLM config not set"
+            _logger.error(err)
+            raise RuntimeError(err)
+
+        _logger.info("[Embedding] Initializing Embedding, model=%s/%s", llm_config.provider, llm_config.modelName)
+        self._llm_config = llm_config
+        self._provider = _CLASS_DICT[llm_config.provider](llm_config)
+
         # Probe the dimension
         dim = await self._get_embedding_dimension()
         _logger.info("[Embedding] Detected vector dimension: %d", dim)
 
         # Use VectorTableManager to make sure the tables exist and both dimension and model are correct
         await self._table_manager.ensure_tables(dim, self._llm_config)
+        _logger.info("[Embedding] Vector table check finished")
 
     async def get_embedding(self, text: list[str]) -> list[list[float]]:
         """
@@ -370,4 +371,12 @@
         :param text: Texts to vectorize (a list of multiple texts)
         :return: Vectors for the texts (a list in the same order as text)
         """
+        if not self._provider:
+            err = "[Embedding] Provider not initialized; cannot get embeddings"
+            _logger.error(err)
+            raise RuntimeError(err)
         return await self._provider.embedding(text)
+
+
+# Global Embedding instance
+embedding = Embedding()
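
Reviewer note: `Embedding` is no longer configured through its constructor; the config is injected into a shared, module-level instance via `init()`. A minimal sketch of the new two-phase flow, using only symbols from this diff (the `embed_texts` helper itself is hypothetical):

```python
from apps.llm import embedding
from apps.services.llm import LLMManager


async def embed_texts(llm_id: str, texts: list[str]) -> list[list[float]]:
    """Hypothetical helper: configure the shared Embedding singleton, then use it."""
    cfg = await LLMManager.get_llm(llm_id)  # LLMData | None
    if cfg is None:
        err = f"LLM {llm_id} not found"
        raise ValueError(err)
    # init() stores the config, probes the vector dimension with a test call,
    # and rebuilds the vector tables if the model or dimension changed.
    await embedding.init(cfg)
    # Later callers can use the shared instance directly; get_embedding()
    # raises RuntimeError if init() was never called.
    return await embedding.get_embedding(texts)
```

Note that a process that never awaits `embedding.init(...)` now fails at the first `get_embedding()` call instead of at construction time, so startup ordering matters.
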
diff --git a/apps/llm/generator.py b/apps/llm/generator.py
index 3ce5a300924912e0cdd9758967cc7569ca64b094..1a32c1bacdac0ac44d01d593abb716b7c32ade21 100644
--- a/apps/llm/generator.py
+++ b/apps/llm/generator.py
@@ -19,31 +19,14 @@ from .schema import LLMConfig
 _logger = logging.getLogger(__name__)
 
-JSON_GEN_MAX_TRIAL = 3  # Maximum number of attempts
+JSON_GEN_MAX_TRIAL = 3
 
 
 class JsonGenerator:
-    """Combined JSON generator"""
+    """Combined JSON generator (global singleton)"""
 
-    def __init__(
-        self, llm_config: LLMConfig, query: str, conversation: list[dict[str, str]], function: dict[str, Any],
-    ) -> None:
-        """Initialize the JSON generator; function uses the OpenAI standard Function format"""
-        self._query = query
-        self._function = function
-
-        # Pick the LLM: prefer the Function model (if present and it supports FunctionCall), otherwise fall back to the Reasoning model
-        self._llm, self._support_function_call = self._select_llm(llm_config)
-
-        self._context = [
-            {
-                "role": "system",
-                "content": "You are a helpful assistant that can use tools to help answer user queries.",
-            },
-        ]
-        if conversation:
-            self._context.extend(conversation)
-
-        self._count = 0
+    def __init__(self) -> None:
+        """Create the JsonGenerator instance"""
+        # Jinja2 environment; reusable
         self._env = SandboxedEnvironment(
             loader=BaseLoader(),
             autoescape=False,
             trim_blocks=True,
             lstrip_blocks=True,
             extensions=["jinja2.ext.loopcontrols"],
         )
+        # None at construction time; set once init() is called
+        self._llm: LLM | None = None
+        self._support_function_call: bool = False
+
+    def init(self, llm_config: LLMConfig) -> None:
+        """Initialize the JsonGenerator and set the LLM config"""
+        # Pick the LLM: prefer the Function model (if present and it supports FunctionCall), otherwise fall back to the Reasoning model
+        self._llm, self._support_function_call = self._select_llm(llm_config)
 
     def _select_llm(self, llm_config: LLMConfig) -> tuple[LLM, bool]:
         """Pick the LLM: prefer the Function model (if present and it supports FunctionCall), otherwise fall back to the Reasoning model"""
@@ -61,31 +52,37 @@
             _logger.info("[JSONGenerator] Function model unavailable or lacks FunctionCall support; falling back to the Reasoning model")
         return llm_config.reasoning, False
 
-    async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]:
-        """Single attempt, including validation logic"""
-        # Get the schema and create a validator
-        schema = self._function["parameters"]
+    async def _single_trial(
+        self,
+        function: dict[str, Any],
+        query: str,
+        context: list[dict[str, str]],
+    ) -> dict[str, Any]:
+        """Single attempt, including validation logic; function uses the OpenAI standard Function format"""
+        if self._llm is None:
+            err = "[JSONGenerator] Not initialized; call init() first"
+            raise RuntimeError(err)
+
+        schema = function["parameters"]
         validator = Draft7Validator(schema)
 
         # Run the generation
         if self._support_function_call:
-            # If FunctionCall is supported, use the provider's function-calling logic
-            result = await self._call_with_function()
+            # FunctionCall is supported
+            result = await self._call_with_function(function, query, context)
         else:
-            # If FunctionCall is not supported, use JSON_GEN_BASIC
-            result = await self._call_without_function(max_tokens, temperature)
+            # FunctionCall is not supported
+            result = await self._call_without_function(function, query, context)
 
         # Validate the result
         try:
             validator.validate(result)
         except Exception as err:
-            # Capture the validation error message
             err_info = str(err)
             err_info = err_info.split("\n\n")[0]
             _logger.info("[JSONGenerator] Validation failed: %s", err_info)
-            # Append the error message to the context
-            self._context.append({
+            context.append({
                 "role": "assistant",
                 "content": f"Attempted to use tool but validation failed: {err_info}",
             })
@@ -93,43 +90,57 @@
         else:
             return result
 
-    async def _call_with_function(self) -> dict[str, Any]:
+    async def _call_with_function(
+        self,
+        function: dict[str, Any],
+        query: str,
+        context: list[dict[str, str]],
+    ) -> dict[str, Any]:
         """Call via FunctionCall"""
-        # Build the tool definition directly from the given function
+        if self._llm is None:
+            err = "[JSONGenerator] Not initialized; call init() first"
+            raise RuntimeError(err)
+
         tool = LLMFunctions(
-            name=self._function["name"],
-            description=self._function["description"],
-            param_schema=self._function["parameters"],
+            name=function["name"],
+            description=function["description"],
+            param_schema=function["parameters"],
         )
 
-        messages = self._context.copy()
-        messages.append({"role": "user", "content": self._query})
+        messages = context.copy()
+        messages.append({"role": "user", "content": query})
 
-        # Call the LLM's call method, passing in the tools
         tool_call_result = {}
         async for chunk in self._llm.call(messages, include_thinking=False, streaming=True, tools=[tool]):
             if chunk.tool_call:
                 tool_call_result.update(chunk.tool_call)
 
-        # Extract the JSON from the tool_call result, using the function name from function
-        function_name = self._function["name"]
+        function_name = function["name"]
         if tool_call_result and function_name in tool_call_result:
             return json.loads(tool_call_result[function_name])
 
         return {}
 
-    async def _call_without_function(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]:  # noqa: E501
+    async def _call_without_function(
+        self,
+        function: dict[str, Any],
+        query: str,
+        context: list[dict[str, str]],
+    ) -> dict[str, Any]:
         """Call without FunctionCall"""
-        # Render the template
+        if self._llm is None:
+            err = "[JSONGenerator] Not initialized; call init() first"
+            raise RuntimeError(err)
+
         template = self._env.from_string(JSON_GEN_BASIC + "\n\n" + JSON_NO_FUNCTION_CALL)
         prompt = template.render(
-            query=self._query,
-            conversation=self._context[1:] if self._context else [],
-            schema=self._function["parameters"],
+            query=query,
+            conversation=context[1:] if context else [],
+            schema=function["parameters"],
         )
 
         messages = [
-            self._context[0],
+            context[0],
             {"role": "user", "content": prompt},
         ]
@@ -148,25 +159,60 @@
 
         return {}
 
-    async def generate(self) -> dict[str, Any]:
-        """Generate JSON"""
+    async def generate(
+        self,
+        query: str,
+        function: dict[str, Any],
+        conversation: list[dict[str, str]] | None = None,
+    ) -> dict[str, Any]:
+        """
+        Generate JSON; function uses the OpenAI standard Function format
+
+        :param query: User query
+        :param function: Function definition in the OpenAI standard Function format
+        :param conversation: Conversation history; defaults to an empty list
+        :return: The generated JSON object
+        """
+        if self._llm is None:
+            err = "[JSONGenerator] Not initialized; call init() first"
+            raise RuntimeError(err)
+
         # Check that the schema is well-formed
-        schema = self._function["parameters"]
+        schema = function["parameters"]
         Draft7Validator.check_schema(schema)
 
-        while self._count < JSON_GEN_MAX_TRIAL:
-            self._count += 1
+        # Build the context
+        context = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that can use tools to help answer user queries.",
+            },
+        ]
+        if conversation:
+            context.extend(conversation)
+
+        count = 0
+        while count < JSON_GEN_MAX_TRIAL:
+            count += 1
             try:
                 # If _single_trial does not raise, return the result directly without retrying
-                return await self._single_trial()
+                return await self._single_trial(function, query, context)
             except Exception:
-                # Log the error every time an exception is caught
                 _logger.exception(
                     "[JSONGenerator] Attempt %d/%d failed",
-                    self._count,
+                    count,
                     JSON_GEN_MAX_TRIAL,
                 )
-                # If retries remain, continue with the next attempt
-                if self._count < JSON_GEN_MAX_TRIAL:
+                if count < JSON_GEN_MAX_TRIAL:
                     continue
 
         return {}
+
+
+# Global singleton instance
+json_generator = JsonGenerator()
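
Reviewer note: `JsonGenerator` follows the same singleton pattern: `init()` binds the LLM once, while `query`, `function`, and `conversation` move to per-call arguments of `generate()`. A sketch under that assumption (the helper and the example schema are hypothetical):

```python
from typing import Any

from apps.llm import LLMConfig, json_generator


async def ask_for_city(llm_config: LLMConfig, query: str) -> dict[str, Any]:
    """Hypothetical helper showing the init-once / call-anywhere pattern."""
    # One-time wiring, e.g. at scheduler startup; prefers the Function model
    # when it supports FunctionCall, otherwise falls back to Reasoning.
    json_generator.init(llm_config)

    # `function` uses the OpenAI standard Function format, as at the call sites below.
    function = {
        "name": "pick_city",
        "description": "Pick the city mentioned in the query.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
    # generate() builds a fresh context per call, validates the result against
    # the schema, and retries up to JSON_GEN_MAX_TRIAL times.
    return await json_generator.generate(query=query, function=function, conversation=[])
```

Since `generate()` no longer shares state across calls beyond the bound LLM, concurrent callers only contend on `init()`; re-initializing mid-flight would swap the model under running requests.
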
diff --git a/apps/models/__init__.py b/apps/models/__init__.py
index 98bd73674641913a678db2b29abd355de3701db1..6a0fcc3c016a92b70ab7cd134e249eef819ad5c2 100644
--- a/apps/models/__init__.py
+++ b/apps/models/__init__.py
@@ -11,6 +11,7 @@ from .flow import Flow
 from .llm import LLMData, LLMProvider, LLMType
 from .mcp import MCPActivated, MCPInfo, MCPInstallStatus, MCPTools, MCPType
 from .node import NodeInfo
 from .record import Record, RecordMetadata
 from .service import Service, ServiceACL, ServiceHashes
 from .session import Session, SessionActivity, SessionType
+from .settings import GlobalSettings
@@ -45,6 +46,7 @@ __all__ = [
     "ExecutorHistory",
     "ExecutorStatus",
     "Flow",
+    "GlobalSettings",
     "LLMData",
     "LLMProvider",
     "LLMType",
diff --git a/apps/models/settings.py b/apps/models/settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9bb0ad196098f3b28a64257337e1fd64e69b2c
--- /dev/null
+++ b/apps/models/settings.py
@@ -0,0 +1,41 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Global settings database table"""
+
+from datetime import UTC, datetime
+
+from sqlalchemy import DateTime, ForeignKey, String
+from sqlalchemy.orm import Mapped, mapped_column
+
+from .base import Base
+
+
+class GlobalSettings(Base):
+    """Global settings"""
+
+    __tablename__ = "framework_global_settings"
+
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
+    """Settings ID"""
+
+    functionLlmId: Mapped[str | None] = mapped_column(  # noqa: N815
+        String(255), ForeignKey("framework_llm.id"), nullable=True, default=None,
+    )
+    """ID of the LLM used for Function Call"""
+
+    embeddingLlmId: Mapped[str | None] = mapped_column(  # noqa: N815
+        String(255), ForeignKey("framework_llm.id"), nullable=True, default=None,
+    )
+    """ID of the LLM used for Embedding"""
+
+    updatedAt: Mapped[datetime] = mapped_column(  # noqa: N815
+        DateTime,
+        default_factory=lambda: datetime.now(tz=UTC),
+        onupdate=lambda: datetime.now(tz=UTC),
+        nullable=False,
+    )
+    """Settings update time"""
+
+    lastEditedBy: Mapped[str | None] = mapped_column(  # noqa: N815
+        String(50), ForeignKey("framework_user.userSub"), nullable=True, default=None,
+    )
+    """userSub of the last user who modified the settings"""
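
The new table is effectively a single-row store. For reference, the read-or-create pattern it implies looks like this (mirroring what `SettingsManager` does later in this diff; the helper name is hypothetical):

```python
from sqlalchemy import select

from apps.common.postgres import postgres
from apps.models import GlobalSettings


async def read_or_create_settings() -> GlobalSettings:
    """Hypothetical helper: fetch the single settings row, creating it on first use."""
    async with postgres.session() as session:
        settings = (await session.scalars(select(GlobalSettings))).first()
        if settings is None:
            # id is init=False (autoincrement) and updatedAt has a default_factory,
            # so only the nullable columns are passed explicitly.
            settings = GlobalSettings(functionLlmId=None, embeddingLlmId=None, lastEditedBy=None)
            session.add(settings)
            await session.commit()
        return settings
```
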
diff --git a/apps/routers/llm.py b/apps/routers/llm.py
index dea8e9cc5c9f6733a927baf513b06a071e1876d7..85420fb8a19fb35ee72c2ab6d39cb5f405f8f9c4 100644
--- a/apps/routers/llm.py
+++ b/apps/routers/llm.py
@@ -39,7 +39,7 @@ admin_router = APIRouter(
 )
 async def list_llm(llmId: str | None = None) -> JSONResponse:  # noqa: N803
     """GET /llm: Get the list of LLMs"""
-    llm_list = await LLMManager.list_provider(llmId)
+    llm_list = await LLMManager.list_llm(llmId, admin_view=False)
     return JSONResponse(
         status_code=status.HTTP_200_OK,
         content=ListLLMRsp(
@@ -55,7 +55,7 @@ async def list_llm(llmId: str | None = None) -> JSONResponse:  # noqa: N803
 )
 async def admin_list_llm(llmId: str | None = None) -> JSONResponse:  # noqa: N803
     """GET /llm/config: Get the list of LLM configs (admin view)"""
-    llm_list = await LLMManager.list_llm(llmId)
+    llm_list = await LLMManager.list_llm(llmId, admin_view=True)
     return JSONResponse(
         status_code=status.HTTP_200_OK,
         content=ListLLMAdminRsp(
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 9f34f33d3042dad3d2d7fe2ffc1aafa88ff84050..150bba2868d42bf146c26d6b4439dbf7177ee34e 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -9,7 +9,7 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from pydantic import Field
 
-from apps.llm import JsonGenerator
+from apps.llm import json_generator
 from apps.models import LanguageType, NodeInfo
 from apps.scheduler.call.core import CoreCall
 from apps.scheduler.slot.slot import Slot as SlotProcessor
@@ -54,11 +54,8 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
             trim_blocks=True,
             lstrip_blocks=True,
         )
-
-        # Get the current language
         language = self._sys_vars.language
 
-        # Render the query template (without history)
         query_template = env.from_string(SLOT_GEN_PROMPT[language])
         query = query_template.render(
             current_tool={
@@ -68,10 +65,8 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
             schema=remaining_schema,
         )
 
-        # Assemble the conversation: turn the summary and historical tool calls into dialogue form
         conversation = []
-
-        # Render the task summary with a Jinja2 template
+        # Task summary
         if self.summary or self.facts:
             summary_template = env.from_string(SLOT_SUMMARY_TEMPLATE[language])
             summary_content = summary_template.render(
@@ -92,7 +87,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
                 "content": assistant_response,
             })
 
-        # Render the historical tool calls with a Jinja2 template
+        # Historical tool calls
         if self._flow_history:
             history_template = env.from_string(SLOT_HISTORY_TEMPLATE[language])
             history_content = history_template.render(
@@ -104,22 +99,16 @@
                 "content": history_content,
             })
 
-        # Build the OpenAI standard FunctionCall format
         function = {
             "name": "fill_parameters",
             "description": f"Fill the missing parameters for {self.name}. {self.description}",
             "parameters": remaining_schema,
         }
-
-        # Fill the parameters with JsonGenerator
-        generator = JsonGenerator(
-            llm_config=self._llm_obj,
+        data = await json_generator.generate(
             query=query,
-            conversation=conversation,
             function=function,
+            conversation=conversation,
         )
-
-        data = await generator.generate()
         answer = json.dumps(data, ensure_ascii=False)
 
         return answer, data
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 610fd1b2fa9c9bb54d68c8903af1b51ff20c289e..9793cbfbf6cf261a157dd3ba4127ce7bca92ed21 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -10,7 +10,7 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from mcp.types import TextContent
 
-from apps.llm import JsonGenerator, LLMConfig
+from apps.llm import LLMConfig, json_generator
 from apps.models import LanguageType, MCPTools
 from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
 from apps.scheduler.pool.mcp.client import MCPClient
@@ -119,16 +119,14 @@ class MCPHost:
             "parameters": tool.inputSchema,
         }
 
-        # Run the generation
-        json_generator = JsonGenerator(
-            self._llm,
-            llm_query,
-            [
+        # Use the global json_generator instance
+        return await json_generator.generate(
+            query=llm_query,
+            function=function_definition,
+            conversation=[
                 {"role": "user", "content": await self.assemble_memory()},
             ],
-            function_definition,
         )
-        return await json_generator.generate()
 
 
     async def call_tool(self, tool: MCPTools, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py
index f24ae8c1fb64951722a11a2323f616066d0683b3..7dbb9facc240a1bac764d2722d770459aa2f6158 100644
--- a/apps/scheduler/mcp/plan.py
+++ b/apps/scheduler/mcp/plan.py
@@ -6,7 +6,7 @@ import logging
 from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 
-from apps.llm import JsonGenerator, LLMConfig
+from apps.llm import LLMConfig, json_generator
 from apps.models import LanguageType, MCPTools
 from apps.schemas.mcp import MCPPlan
 
@@ -83,16 +83,14 @@ class MCPPlanner:
             "parameters": schema,
         }
 
-        # Parse the result with the Function model
-        json_generator = JsonGenerator(
-            self._llm,
-            result,
-            [
+        # Use the global json_generator instance to parse the result
+        plan = await json_generator.generate(
+            query=result,
+            function=function_def,
+            conversation=[
                 {"role": "user", "content": result},
             ],
-            function_def,
         )
-        plan = await json_generator.generate()
 
         return MCPPlan.model_validate(plan)
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 0f4362de8a605d05219f21ba1ccdfb0770e5b047..dc32a0f5d369698d6630bfc9a9902a3d0474a677 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -6,7 +6,7 @@ import logging
 from sqlalchemy import select
 
 from apps.common.postgres import postgres
-from apps.llm import JsonGenerator, LLMConfig
+from apps.llm import LLMConfig, json_generator
 from apps.models import LanguageType, MCPTools
 from apps.schemas.mcp import MCPSelectResult
 from apps.services.mcp_service import MCPServiceManager
@@ -63,13 +63,11 @@ class MCPSelector:
         )
 
         # Generate the JSON with JsonGenerator
-        generator = JsonGenerator(
-            llm_config=self._llm,
+        result = await json_generator.generate(
             query=user_prompt,
-            conversation=[],
             function=function,
+            conversation=[],
         )
-        result = await generator.generate()
 
         try:
             result = MCPSelectResult.model_validate(result)
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
index 934e65a6eb3aecdf2306ae4de4ce9a3c5f8d4aa9..42524db9595d17ebaa3122fba43c99dccd8b6d29 100644
--- a/apps/scheduler/mcp_agent/base.py
+++ b/apps/scheduler/mcp_agent/base.py
@@ -4,7 +4,7 @@ import logging
 
 from typing import Any
 
-from apps.llm import JsonGenerator, LLMConfig
+from apps.llm import LLMConfig, json_generator
 from apps.models import LanguageType
 from apps.schemas.task import TaskData
 
@@ -44,12 +44,10 @@ class MCPBase:
 
     async def get_json_result(self, result: str, function: dict[str, Any]) -> dict[str, Any]:
         """Parse the reasoning result; function uses the OpenAI standard Function format"""
-        json_generator = JsonGenerator(
-            self._llm,
-            "Please provide a JSON response based on the above information and schema.\n\n",
-            [
+        return await json_generator.generate(
+            query="Please provide a JSON response based on the above information and schema.\n\n",
+            function=function,
+            conversation=[
                 {"role": "user", "content": result},
             ],
-            function,
         )
-        return await json_generator.generate()
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
index a6f850c705a28625d2311136957220c92bbfe7c8..1b346e583e4876e43944248069866cff52b3c7a5 100644
--- a/apps/scheduler/mcp_agent/host.py
+++ b/apps/scheduler/mcp_agent/host.py
@@ -8,7 +8,7 @@ from typing import Any
 from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 
-from apps.llm import JsonGenerator
+from apps.llm import json_generator
 from apps.models import ExecutorHistory, LanguageType, MCPTools, TaskRuntime
 from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
 from apps.scheduler.mcp_agent.base import MCPBase
@@ -100,12 +100,10 @@ class MCPHost(MCPBase):
             "parameters": mcp_tool.inputSchema,
         }
 
-        json_generator = JsonGenerator(
-            self._llm,
-            llm_query,
-            [
+        return await json_generator.generate(
+            query=llm_query,
+            function=function,
+            conversation=[
                 {"role": "user", "content": prompt},
             ],
-            function,
         )
-        return await json_generator.generate()
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 4c0ce897e7980b68941f1e199793bdf6aad78f95..c58744e51d89d11f9f7feca1ddf4e3c09a880c4c 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -11,7 +11,7 @@ from jinja2.sandbox import SandboxedEnvironment
 
 from apps.common.queue import MessageQueue
 from apps.common.security import Security
-from apps.llm import LLM, Embedding, JsonGenerator, LLMConfig
+from apps.llm import LLM, LLMConfig, embedding, json_generator
 from apps.models import (
     AppType,
     Conversation,
@@ -291,23 +291,26 @@ class Scheduler:
         else:
             function_llm = LLM(function_llm)
 
-        embedding_llm = None
+        # Fetch and apply the global embedding model
+        embedding_obj = None
         if not self.user.embeddingLLM:
             _logger.error("[Scheduler] User %s has no embedding model configured; related features will be disabled", self.user.userSub)
         else:
-            embedding_llm = await LLMManager.get_llm(self.user.embeddingLLM)
-            if not embedding_llm:
+            embedding_llm_config = await LLMManager.get_llm(self.user.embeddingLLM)
+            if not embedding_llm_config:
                 _logger.error(
                     "[Scheduler] User %s's configured embedding model ID %s does not exist; related features will be disabled",
                     self.user.userSub,
                     self.user.embeddingLLM,
                 )
             else:
-                embedding_llm = Embedding(embedding_llm)
+                # Apply the global embedding config
+                await embedding.init(embedding_llm_config)
+                embedding_obj = embedding
 
         return LLMConfig(
             reasoning=reasoning_llm,
             function=function_llm,
-            embedding=embedding_llm,
+            embedding=embedding_obj,
         )
@@ -346,10 +349,19 @@
         )
         schema = TopFlow.model_json_schema()
         schema["properties"]["choice"]["enum"] = [choice["name"] for choice in choices]
-        result_str = await JsonGenerator(self.llm.function, self.post_body.question, [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": prompt},
-        ], schema).generate()
+        function = {
+            "name": "select_flow",
+            "description": "Select the appropriate flow",
+            "parameters": schema,
+        }
+        result_str = await json_generator.generate(
+            query=self.post_body.question,
+            function=function,
+            conversation=[
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": prompt},
+            ],
+        )
         result = TopFlow.model_validate(result_str)
         return result.choice
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index c644b62b63fa2c47e1cdfe41a802d00b7beb93a3..5c7ee1ff4f3c27c6e2d76c98957491f5d4ec2f4b 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -62,7 +62,7 @@ class UpdateLLMReq(BaseModel):
     extra_data: dict[str, Any] | None = Field(default=None, description="Extra data", alias="extraData")
 
 
-class UpdateUserSelectedLLMReq(BaseModel):
-    """Request body for updating the user's special LLMs"""
+class UpdateSpecialLlmReq(BaseModel):
+    """Request body for updating the special LLMs"""
 
     functionLLM: str = Field(description="Function Call LLM ID")  # noqa: N815
diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py
index c9bf04a27dc2e97f59980db57ded0be1395f0242..f0580aa87d882772f1878bce2485e69cb0fa49d1 100644
--- a/apps/schemas/response_data.py
+++ b/apps/schemas/response_data.py
@@ -179,7 +179,7 @@ class GetOperaRsp(ResponseData):
     result: list[OperateAndBindType] = Field(..., title="Result")
 
 
-class SelectedSpecialLLMID(BaseModel):
-    """Data structure for the user-selected LLMs"""
+class SelectedSpecialLlmID(BaseModel):
+    """Data structure for the selected special LLM IDs"""
 
     functionLLM: str | None = Field(default=None, description="Function model ID")  # noqa: N815
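
Shape check for the renamed request/response pair (the request model's `embeddingLLM` field is defined outside the hunk above; it is assumed here to accept an LLM ID string, matching its use in `SettingsManager`):

```python
from apps.schemas.request_data import UpdateSpecialLlmReq
from apps.schemas.response_data import SelectedSpecialLlmID

# Request body an admin would send to update the global special LLMs.
req = UpdateSpecialLlmReq(functionLLM="llm-func-01", embeddingLLM="llm-embed-01")

# Response shape returned by SettingsManager.get_global_llm_settings();
# both IDs are optional and default to None.
rsp = SelectedSpecialLlmID(functionLLM=req.functionLLM, embeddingLLM=req.embeddingLLM)
print(rsp.model_dump())
```
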
diff --git a/apps/services/llm.py b/apps/services/llm.py
index c0b52ffda7a6995037dd20643bc1330e60b73703..416544190d0e93f65de37a797d7e55c2f7e053dd 100644
--- a/apps/services/llm.py
+++ b/apps/services/llm.py
@@ -6,17 +6,11 @@ import logging
 from sqlalchemy import select
 
 from apps.common.postgres import postgres
-from apps.llm import Embedding
 from apps.models import LLMData, User
-from apps.scheduler.pool.pool import pool
-from apps.schemas.request_data import (
-    UpdateLLMReq,
-    UpdateUserSelectedLLMReq,
-)
+from apps.schemas.request_data import UpdateLLMReq
 from apps.schemas.response_data import (
     LLMAdminInfo,
     LLMProviderInfo,
-    SelectedSpecialLLMID,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,11 +41,12 @@
 
     @staticmethod
-    async def list_provider(llm_id: str | None) -> list[LLMProviderInfo]:
+    async def list_llm(llm_id: str | None, *, admin_view: bool = False) -> list[LLMProviderInfo] | list[LLMAdminInfo]:
         """
         Get the list of LLMs
 
         :param llm_id: LLM ID
+        :param admin_view: Whether to return the admin view; True returns LLMAdminInfo, False returns LLMProviderInfo
         :return: List of LLMs
         """
         async with postgres.session() as session:
@@ -69,7 +64,28 @@
             logger.error("[LLMManager] Cannot find LLM %s", llm_id)
             return []
 
-        # Default LLMs
+        # Return different data structures depending on admin_view
+        if admin_view:
+            # Build the admin-view list
+            admin_list = []
+            for llm in llm_list:
+                llm_item = LLMAdminInfo(
+                    llmId=llm.id,
+                    llmDescription=llm.llmDescription,
+                    llmType=llm.llmType,
+                    baseUrl=llm.baseUrl,
+                    apiKey=llm.apiKey,
+                    modelName=llm.modelName,
+                    maxTokens=llm.maxToken,
+                    ctxLength=llm.ctxLength,
+                    temperature=llm.temperature,
+                    provider=llm.provider.value if llm.provider else None,
+                    extraConfig=llm.extraConfig,
+                )
+                admin_list.append(llm_item)
+            return admin_list
+
+        # Build the regular user-view list
         provider_list = []
         for llm in llm_list:
             llm_item = LLMProviderInfo(
@@ -83,49 +99,6 @@
 
         return provider_list
 
-    @staticmethod
-    async def list_llm(llm_id: str | None) -> list[LLMAdminInfo]:
-        """
-        Get the list of LLM data (admin view)
-
-        :param llm_id: LLM ID
-        :return: List of LLM admin info
-        """
-        async with postgres.session() as session:
-            if llm_id:
-                llm_list = (await session.scalars(
-                    select(LLMData).where(
-                        LLMData.id == llm_id,
-                    ),
-                )).all()
-            else:
-                llm_list = (await session.scalars(
-                    select(LLMData),
-                )).all()
-        if not llm_list:
-            logger.error("[LLMManager] Cannot find LLM %s", llm_id)
-            return []
-
-        # Build the admin-view list
-        admin_list = []
-        for llm in llm_list:
-            llm_item = LLMAdminInfo(
-                llmId=llm.id,
-                llmDescription=llm.llmDescription,
-                llmType=llm.llmType,
-                baseUrl=llm.baseUrl,
-                apiKey=llm.apiKey,
-                modelName=llm.modelName,
-                maxTokens=llm.maxToken,
-                ctxLength=llm.ctxLength,
-                temperature=llm.temperature,
-                provider=llm.provider.value if llm.provider else None,
-                extraConfig=llm.extraConfig,
-            )
-            admin_list.append(llm_item)
-        return admin_list
-
-
     @staticmethod
     async def update_llm(llm_id: str, req: UpdateLLMReq) -> str:
         """
@@ -203,46 +176,3 @@
             for item in user:
                 item.embeddingLLM = None
             await session.commit()
-
-
-    @staticmethod
-    async def update_special_llm(
-        user_sub: str,
-        req: UpdateUserSelectedLLMReq,
-    ) -> None:
-        """Update the user's default LLMs"""
-        # Check whether the embedding model changed
-        old_embedding_llm = None
-        new_embedding_llm = req.embeddingLLM
-
-        async with postgres.session() as session:
-            user = (await session.scalars(
-                select(User).where(User.userSub == user_sub),
-            )).one_or_none()
-            if not user:
-                err = f"[LLMManager] User {user_sub} does not exist"
-                raise ValueError(err)
-
-            old_embedding_llm = user.embeddingLLM
-            user.functionLLM = req.functionLLM
-            user.embeddingLLM = req.embeddingLLM
-            await session.commit()
-
-        # If the embedding model changed, trigger the vectorization process
-        if old_embedding_llm != new_embedding_llm and new_embedding_llm:
-            try:
-                # Fetch the config of the new embedding model
-                embedding_llm_config = await LLMManager.get_llm(new_embedding_llm)
-                if embedding_llm_config:
-                    # Create an Embedding instance
-                    embedding_model = Embedding(embedding_llm_config)
-                    await embedding_model.init()
-
-                    # Trigger vectorization
-                    await pool.set_vector(embedding_model)
-
-                    logger.info("[LLMManager] Embedding model for user %s updated; vectorization finished", user_sub)
-                else:
-                    logger.error("[LLMManager] User %s's selected embedding model %s does not exist", user_sub, new_embedding_llm)
-            except Exception:
-                logger.exception("[LLMManager] Vectorization of user %s's embedding model failed", user_sub)
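
The merged `list_llm` keys its return type off a keyword-only flag, so call sites must already know which variant they requested (a usage sketch; the helper is hypothetical):

```python
from apps.services.llm import LLMManager


async def show_llm_views() -> None:
    """Hypothetical helper contrasting the two views over the same rows."""
    providers = await LLMManager.list_llm(None, admin_view=False)  # LLMProviderInfo: public fields
    admins = await LLMManager.list_llm(None, admin_view=True)  # LLMAdminInfo: includes baseUrl/apiKey
    print(len(providers), len(admins))
```

Because `admin_view` fully determines which branch of the union is returned, a pair of `typing.overload` stubs (one per literal flag value) would let type checkers narrow each call site automatically.
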
+"""全局设置 Manager""" + +import asyncio +import logging + +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.llm import embedding +from apps.models import GlobalSettings, LLMType +from apps.scheduler.pool.pool import pool +from apps.schemas.request_data import UpdateSpecialLlmReq +from apps.schemas.response_data import SelectedSpecialLlmID +from apps.services.llm import LLMManager + +logger = logging.getLogger(__name__) + + +class SettingsManager: + """全局设置相关操作""" + + @staticmethod + async def get_global_llm_settings() -> SelectedSpecialLlmID: + """ + 获取全局LLM设置 + + :return: 全局设置中的functionLLM和embeddingLLM + """ + async with postgres.session() as session: + # 查询全局设置表,假设只有一条记录 + settings = (await session.scalars( + select(GlobalSettings), + )).first() + + if not settings: + logger.warning("[SettingsManager] 全局设置不存在,返回默认值") + return SelectedSpecialLlmID(functionLLM=None, embeddingLLM=None) + + return SelectedSpecialLlmID( + functionLLM=settings.functionLlmId, + embeddingLLM=settings.embeddingLlmId, + ) + + + @staticmethod + async def _trigger_vectorization(embedding_llm_id: str) -> None: + """ + 触发向量化过程(后台任务) + + :param embedding_llm_id: Embedding模型ID + """ + try: + # 获取新的embedding模型配置 + embedding_llm_config = await LLMManager.get_llm(embedding_llm_id) + if embedding_llm_config: + # 设置全局embedding配置 + await embedding.init(embedding_llm_config) + + # 触发向量化 + await pool.set_vector(embedding) + + logger.info("[SettingsManager] Embedding模型已更新,向量化过程已完成") + else: + logger.error("[SettingsManager] 选择的embedding模型 %s 不存在", embedding_llm_id) + except Exception: + logger.exception("[SettingsManager] Embedding模型向量化过程失败") + + + @staticmethod + async def update_global_llm_settings( + user_sub: str, + req: UpdateSpecialLlmReq, + ) -> None: + """ + 更新全局默认LLM(仅管理员) + + :param user_sub: 操作的管理员user_sub + :param req: 更新请求体 + """ + # 验证functionLLM是否支持Function Call + if req.functionLLM: + function_llm = await LLMManager.get_llm(req.functionLLM) + if not function_llm: + err = f"[SettingsManager] Function LLM {req.functionLLM} 不存在" + raise ValueError(err) + if LLMType.FUNCTION not in function_llm.llmType: + err = f"[SettingsManager] LLM {req.functionLLM} 不支持Function Call" + raise ValueError(err) + + # 验证embeddingLLM是否支持Embedding + if req.embeddingLLM: + embedding_llm = await LLMManager.get_llm(req.embeddingLLM) + if not embedding_llm: + err = f"[SettingsManager] Embedding LLM {req.embeddingLLM} 不存在" + raise ValueError(err) + if LLMType.EMBEDDING not in embedding_llm.llmType: + err = f"[SettingsManager] LLM {req.embeddingLLM} 不支持Embedding" + raise ValueError(err) + + # 读取旧的embedding配置 + old_embedding_llm = None + async with postgres.session() as session: + settings = (await session.scalars(select(GlobalSettings))).first() + if settings: + old_embedding_llm = settings.embeddingLlmId + else: + # 如果不存在设置记录,创建一条新记录 + settings = GlobalSettings( + functionLlmId=None, + embeddingLlmId=None, + lastEditedBy=None, + ) + session.add(settings) + + # 更新全局设置 + settings.functionLlmId = req.functionLLM + settings.embeddingLlmId = req.embeddingLLM + settings.lastEditedBy = user_sub + await session.commit() + + # 如果embedding模型发生变化,在新协程中触发向量化过程 + if old_embedding_llm != req.embeddingLLM and req.embeddingLLM: + task = asyncio.create_task(SettingsManager._trigger_vectorization(req.embeddingLLM)) + # 添加任务完成回调,避免未处理的异常 + task.add_done_callback(lambda _: None) + logger.info("[SettingsManager] 已启动后台向量化任务")