diff --git a/apps/common/lance.py b/apps/common/lance.py
deleted file mode 100644
index 7e08692d3661c3e5fab2e01c7f1c2c83a3203070..0000000000000000000000000000000000000000
--- a/apps/common/lance.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""向LanceDB中存储向量化数据"""
-
-import lancedb
-from lancedb.db import AsyncConnection
-from lancedb.index import HnswSq
-
-from apps.common.config import Config
-from apps.common.singleton import SingletonMeta
-from apps.models.vector import (
-    CallPoolVector,
-    FlowPoolVector,
-    NodePoolVector,
-    ServicePoolVector,
-)
-from apps.schemas.mcp import MCPToolVector, MCPVector
-
-
-class LanceDB(metaclass=SingletonMeta):
-    """LanceDB向量化存储"""
-
-    _engine: AsyncConnection | None = None
-
-    @staticmethod
-    async def init() -> None:
-        """
-        初始化LanceDB
-
-        此步骤包含创建LanceDB引擎、建表等操作
-
-        :return: 无
-        """
-        LanceDB._engine = await lancedb.connect_async(
-            Config().get_config().deploy.data_dir.rstrip("/") + "/vectors",
-        )
-
-        # 创建表
-        await LanceDB._engine.create_table(
-            "flow",
-            schema=FlowPoolVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("flow")
-        await LanceDB._engine.create_table(
-            "service",
-            schema=ServicePoolVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("service")
-        await LanceDB._engine.create_table(
-            "call",
-            schema=CallPoolVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("call")
-        await LanceDB._engine.create_table(
-            "node",
-            schema=NodePoolVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("node")
-        await LanceDB._engine.create_table(
-            "mcp",
-            schema=MCPVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("mcp")
-        await LanceDB._engine.create_table(
-            "mcp_tool",
-            schema=MCPToolVector,
-            exist_ok=True,
-        )
-        await LanceDB.create_index("mcp_tool")
-
-    @staticmethod
-    async def get_table(table_name: str) -> lancedb.AsyncTable:
-        """
-        获取LanceDB中的表
-
-        :param str table_name: 表名
-        :return: 表
-        :rtype: lancedb.AsyncTable
-        """
-        return await LanceDB._engine.open_table(table_name)
-
-    @staticmethod
-    async def create_index(table_name: str) -> None:
-        """
-        创建LanceDB中表的索引;使用HNSW算法
-
-        :param str table_name: 表名
-        :return: 无
-        """
-        table = await LanceDB.get_table(table_name)
-        await table.create_index(
-            "embedding",
-            config=HnswSq(),
-        )
diff --git a/apps/common/postgres.py b/apps/common/postgres.py
index 7669220694dcf78e4a736ead5dfdabfd58b9a23a..b10e1db79c56dfdb2a0de4642fce97b588c11e59 100644
--- a/apps/common/postgres.py
+++ b/apps/common/postgres.py
@@ -1,11 +1,180 @@
-"""Postgres连接器"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
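The `LanceDB` helper deleted above is superseded by the SQLAlchemy/pgvector module introduced in the hunk below. One call-site detail: `DataBase.get_session()` is a coroutine that returns a small `_ConnectionManager` wrapper rather than the session itself, so callers await it first and then enter it as an async context manager. A minimal usage sketch (the `load_flow_vector` helper is illustrative only, not part of this patch):

```python
from sqlalchemy import select

from apps.common.postgres import DataBase, FlowPoolVector


async def load_flow_vector(flow_id: str) -> FlowPoolVector | None:
    """Illustrative only: read one row through the new session wrapper."""
    # get_session() must be awaited first; the _ConnectionManager it returns
    # is the async context manager that closes the session on exit.
    async with await DataBase.get_session() as session:
        result = await session.execute(
            select(FlowPoolVector).where(FlowPoolVector.id == flow_id)
        )
        return result.scalar_one_or_none()
```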
+"""PostgreSQL/openGauss连接器,存储向量化数据"""
+
+import urllib.parse
+
+from pgvector.sqlalchemy import Vector
+from sqlalchemy import Column, Index, String
+from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
+from sqlalchemy.orm import declarative_base
+
+from apps.common.config import Config
 
-from sqlalchemy import Column
-from sqlalchemy.dialects.postgresql import JSONB
-
-
-class PGSQL:
-    """Postgres连接器"""
-
-    def __init__(self, db_url: str):
-        pass
+Base = declarative_base()
+
+
+class FlowPoolVector(Base):
+    """App向量信息"""
+
+    __tablename__ = "flow_pool_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    app_id = Column(String, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_flow_pool_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class ServicePoolVector(Base):
+    """Service向量信息"""
+
+    __tablename__ = "service_pool_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_service_pool_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class CallPoolVector(Base):
+    """Call向量信息"""
+
+    __tablename__ = "call_pool_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    service_id = Column(String, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_call_pool_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class NodePoolVector(Base):
+    """Node向量信息"""
+
+    __tablename__ = "node_pool_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    service_id = Column(String, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_node_pool_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class McpVector(Base):
+    """MCP向量化数据,存储在 ``mcp_vector`` 表中"""
+
+    __tablename__ = "mcp_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_mcp_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class McpToolVector(Base):
+    """MCP工具向量化数据,存储在 ``mcp_tool_vector`` 表中"""
+
+    __tablename__ = "mcp_tool_vector"
+
+    id = Column(String, primary_key=True, index=True)
+    mcp_id = Column(String, index=True)
+    embedding = Column(Vector(dim=1024))  # type: ignore[call-arg]
+
+    __table_args__ = (
+        Index(
+            'ix_mcp_tool_vector_embedding',
+            embedding,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 32, 'ef_construction': 200},
+            postgresql_ops={'embedding': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class DataBase:
+    """PostgreSQL/openGauss异步引擎与会话管理"""
+
+    # 对密码进行 URL 编码
+    user = Config().get_config().vectordb.user
+    host = Config().get_config().vectordb.host
+    port = Config().get_config().vectordb.port
+    password = Config().get_config().vectordb.password
+    db = Config().get_config().vectordb.db
+    encoded_password = 
urllib.parse.quote_plus(password) + + if Config().get_config().vectordb.type == 'opengauss': + database_url = f"opengauss+asyncpg://{user}:{encoded_password}@{host}:{port}/{db}" + else: + database_url = f"postgresql+asyncpg://{user}:{encoded_password}@{host}:{port}/{db}" + engine = create_async_engine( + database_url, + echo=False, + pool_recycle=300, + pool_pre_ping=True + ) + init_all_table_flag = False + + @classmethod + async def init(cls) -> None: + """初始化数据库连接""" + if Config().get_config().vectordb.type == 'opengauss': + from sqlalchemy import event + from opengauss_sqlalchemy.register_async import register_vector + + @event.listens_for(DataBase.engine.sync_engine, "connect") + def connect(dbapi_connection, connection_record): + dbapi_connection.run_async(register_vector) + async with DataBase.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + @classmethod + async def get_session(cls): + if DataBase.init_all_table_flag is False: + await DataBase.init() + DataBase.init_all_table_flag = True + connection = async_sessionmaker( + DataBase.engine, expire_on_commit=False)() + return cls._ConnectionManager(connection) + + class _ConnectionManager: + def __init__(self, connection): + self.connection = connection + + async def __aenter__(self): + return self.connection + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.connection.close() diff --git a/apps/main.py b/apps/main.py index 950789f98779cfde6ea896e1783915d25624b6c4..3f9f143838c4b4ae87345d90c4424b3a81340ef8 100644 --- a/apps/main.py +++ b/apps/main.py @@ -20,7 +20,7 @@ from rich.console import Console from rich.logging import RichHandler from apps.common.config import Config -from apps.common.lance import LanceDB +from apps.common.postgres import DataBase from apps.common.wordscheck import WordsCheck from apps.llm.token import TokenCalculator from apps.routers import ( @@ -50,11 +50,12 @@ from apps.middleware.error_handler import ErrorHandlerMiddleware, create_error_h # 全局变量用于跟踪后台任务 _cleanup_task = None + async def cleanup_on_shutdown(): """应用关闭时的清理函数""" logger = logging.getLogger(__name__) logger.info("开始清理应用资源...") - + try: # 取消定期清理任务 global _cleanup_task @@ -64,30 +65,31 @@ async def cleanup_on_shutdown(): await _cleanup_task except asyncio.CancelledError: logger.info("定期清理任务已取消") - + # 清理后台任务 await cleanup_background_tasks() - + # 关闭Redis连接 from apps.common.redis_cache import RedisCache redis_cache = RedisCache() if redis_cache.is_connected(): await redis_cache.close() logger.info("Redis连接已关闭") - + except Exception as e: logger.error(f"清理应用资源时出错: {e}") - + logger.info("应用资源清理完成") + @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理""" # 启动时的初始化 await init_resources() - + yield - + # 关闭时的清理 await cleanup_on_shutdown() @@ -148,6 +150,7 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) + async def add_no_auth_user() -> None: """ 添加无认证用户 @@ -155,21 +158,22 @@ async def add_no_auth_user() -> None: from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.common.config import Config - + config = Config().get_config() mongo = MongoDB() user_collection = mongo.get_collection("user") - + # 使用配置文件中的no_auth设置 user_sub = config.no_auth.user_sub user_name = config.no_auth.user_name - + # 检查是否已存在管理员用户,避免重复添加 existing_admin = await user_collection.find_one({"is_admin": True}) if existing_admin: - logger.info(f"[add_no_auth_user] 管理员用户已存在: {existing_admin.get('_id')}") + logger.info( + f"[add_no_auth_user] 管理员用户已存在: 
{existing_admin.get('_id')}") return - + try: await user_collection.insert_one(User( _id=user_sub, @@ -181,6 +185,7 @@ async def add_no_auth_user() -> None: except Exception as e: logger.error(f"[add_no_auth_user] 添加默认用户失败: {e}") + async def set_administrator() -> None: """ 设置管理员用户 @@ -189,19 +194,19 @@ async def set_administrator() -> None: from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.common.config import Config - + config = Config().get_config() mongo = MongoDB() user_collection = mongo.get_collection("user") - + # 获取管理员配置 admin_user_sub = config.admin.user_sub admin_user_name = config.admin.user_name - + try: # 检查用户是否已存在 existing_user = await user_collection.find_one({"_id": admin_user_sub}) - + if existing_user: # 用户存在,更新 is_admin 字段为 true await user_collection.update_one( @@ -218,10 +223,11 @@ async def set_administrator() -> None: auto_execute=False ).model_dump(by_alias=True)) logger.info(f"[set_administrator] 成功添加新管理员用户: {admin_user_sub}") - + except Exception as e: logger.error(f"[set_administrator] 设置管理员用户失败: {e}") + async def clear_user_activity() -> None: """清除所有用户的活跃状态""" from apps.services.activity import Activity @@ -231,17 +237,18 @@ async def clear_user_activity() -> None: await activity_collection.delete_many({}) logging.info("清除所有用户活跃状态完成") + async def init_model_registry() -> None: """初始化模型注册表""" import os from pathlib import Path from apps.llm.model_registry import model_registry - + models_config_file = os.getenv("MODELS_CONFIG") if models_config_file is None: logger.info(f"[init_model_registry] 配置文件不存在,使用默认模型配置") return - + config_path = Path(models_config_file) if config_path.exists(): try: @@ -252,6 +259,7 @@ async def init_model_registry() -> None: else: logger.info(f"[init_model_registry] 配置文件 {config_path} 不存在,使用默认模型配置") + async def init_system_models() -> None: """初始化系统模型""" try: @@ -261,18 +269,19 @@ async def init_system_models() -> None: except Exception as e: logger.error(f"[init_system_models] 系统模型初始化失败: {e}") + async def init_resources() -> None: """初始化必要资源""" - + WordsCheck() - await LanceDB.init() + await DataBase.init() await Pool.init() TokenCalculator() - + # 初始化模型注册表和系统模型 await init_model_registry() await init_system_models() - + if Config().get_config().no_auth.enable: await add_no_auth_user() if Config().get_config().admin.enable: @@ -282,7 +291,7 @@ async def init_resources() -> None: # 初始化变量池管理器 from apps.scheduler.variable.pool_manager import initialize_pool_manager await initialize_pool_manager() - + # 🔑 新增:启动时清理遗留文件 try: logger.info("开始启动时文件清理...") @@ -290,23 +299,24 @@ async def init_resources() -> None: logger.info("启动时文件清理完成") except Exception as e: logger.error(f"启动时文件清理失败: {e}") - + # 初始化前置节点变量缓存服务 try: from apps.services.predecessor_cache_service import PredecessorCacheService, periodic_cleanup_background_tasks await PredecessorCacheService.initialize_redis() - + # 项目启动时清空所有前置节点缓存,确保使用最新的算法逻辑 await PredecessorCacheService.clear_all_predecessor_cache() - + # 启动定期清理任务 global _cleanup_task _cleanup_task = asyncio.create_task(start_periodic_cleanup()) - + logging.info("前置节点变量缓存服务初始化成功") except Exception as e: logging.warning(f"前置节点变量缓存服务初始化失败(将降级使用实时解析): {e}") + async def startup_file_cleanup(): """启动时清理遗留文件(除了已绑定历史记录的文件)""" logger = logging.getLogger(__name__) @@ -314,18 +324,18 @@ async def startup_file_cleanup(): from apps.services.document import DocumentManager from apps.scheduler.variable.type import VariableType from apps.common.mongo import MongoDB - + mongo = MongoDB() doc_collection = 
mongo.get_collection("document") variables_collection = mongo.get_collection("variables") record_group_collection = mongo.get_collection("record_group") conversation_collection = mongo.get_collection("conversation") - + # 获取所有文档ID all_file_ids = set() async for doc in doc_collection.find({}, {"_id": 1}): all_file_ids.add(doc["_id"]) - + # 获取已绑定历史记录的文件ID protected_file_ids = set() async for record_group in record_group_collection.find({}, {"docs": 1}): @@ -334,13 +344,13 @@ async def startup_file_cleanup(): doc_id = doc.get("id") or doc.get("_id") if doc_id: protected_file_ids.add(doc_id) - + # 获取conversation中unused_docs的文件ID(这些是暂时的,不应清理) unused_file_ids = set() async for conv in conversation_collection.find({}, {"unused_docs": 1}): unused_docs = conv.get("unused_docs", []) unused_file_ids.update(unused_docs) - + # 获取变量中引用的文件ID variable_file_ids = set() async for var_doc in variables_collection.find({ @@ -359,14 +369,14 @@ async def startup_file_cleanup(): variable_file_ids.update(file_ids) except Exception as e: logger.warning(f"解析变量文件引用失败: {e}") - + # 计算可以清理的文件:不在历史记录、不在unused_docs、不在变量中引用的 protected_ids = protected_file_ids | unused_file_ids | variable_file_ids orphaned_file_ids = all_file_ids - protected_ids - + if orphaned_file_ids: logger.info(f"启动时发现 {len(orphaned_file_ids)} 个遗留文件,开始清理") - + # 批量删除遗留文件 cleaned_count = 0 for file_id in orphaned_file_ids: @@ -382,14 +392,15 @@ async def startup_file_cleanup(): logger.warning(f"遗留文件 {file_id} 缺少user_sub信息") except Exception as e: logger.error(f"删除遗留文件 {file_id} 失败: {e}") - + logger.info(f"启动时清理了 {cleaned_count} 个遗留文件") else: logger.info("启动时未发现遗留文件") - + except Exception as e: logger.error(f"启动时文件清理失败: {e}") + async def cleanup_orphaned_files(): """清理孤儿文件(不被任何变量引用且未绑定历史记录的文件)""" logger = logging.getLogger(__name__) @@ -397,18 +408,18 @@ async def cleanup_orphaned_files(): from apps.services.document import DocumentManager from apps.scheduler.variable.type import VariableType from apps.common.mongo import MongoDB - + mongo = MongoDB() doc_collection = mongo.get_collection("document") variables_collection = mongo.get_collection("variables") record_group_collection = mongo.get_collection("record_group") conversation_collection = mongo.get_collection("conversation") - + # 获取所有文档ID all_file_ids = set() async for doc in doc_collection.find({}, {"_id": 1}): all_file_ids.add(doc["_id"]) - + # 🔑 修正:获取已绑定历史记录的文件ID(这些是受保护的) protected_file_ids = set() async for record_group in record_group_collection.find({}, {"docs": 1}): @@ -417,13 +428,13 @@ async def cleanup_orphaned_files(): doc_id = doc.get("id") or doc.get("_id") if doc_id: protected_file_ids.add(doc_id) - + # 获取conversation中unused_docs的文件ID(这些也要保护) unused_file_ids = set() async for conv in conversation_collection.find({}, {"unused_docs": 1}): unused_docs = conv.get("unused_docs", []) unused_file_ids.update(unused_docs) - + # 获取所有变量中引用的文件ID referenced_file_ids = set() async for var_doc in variables_collection.find({ @@ -442,14 +453,14 @@ async def cleanup_orphaned_files(): referenced_file_ids.update(file_ids) except Exception as e: logger.warning(f"解析变量文件引用失败: {e}") - + # 找出孤儿文件:既不在历史记录中,也不在unused_docs中,也不被变量引用 protected_ids = protected_file_ids | unused_file_ids | referenced_file_ids orphaned_file_ids = all_file_ids - protected_ids - + if orphaned_file_ids: logger.info(f"发现 {len(orphaned_file_ids)} 个孤儿文件,开始清理") - + # 批量删除孤儿文件 cleaned_count = 0 for file_id in orphaned_file_ids: @@ -465,58 +476,59 @@ async def cleanup_orphaned_files(): logger.warning(f"孤儿文件 {file_id} 缺少user_sub信息") except 
Exception as e: logger.error(f"删除孤儿文件 {file_id} 失败: {e}") - + logger.info(f"孤儿文件清理完成,共清理 {cleaned_count} 个文件") else: logger.debug("未发现孤儿文件") - + except Exception as e: logger.error(f"孤儿文件清理失败: {e}") + async def start_periodic_cleanup(): """启动定期清理任务""" import asyncio from apps.scheduler.variable.pool_manager import get_pool_manager from apps.common.mongo import MongoDB from datetime import datetime, timedelta - + logger = logging.getLogger(__name__) - + async def cleanup_task(): """定期清理任务""" while True: try: # 每30分钟执行一次清理 await asyncio.sleep(30 * 60) - + logger.info("开始定期清理任务") - + # 获取活跃的对话列表 mongo = MongoDB() conv_collection = mongo.get_collection("conversation") - + # 获取最近24小时内有活动的对话ID cutoff_time = datetime.utcnow() - timedelta(hours=24) active_conversations = set() - + async for conv in conv_collection.find( - {"updated_at": {"$gte": cutoff_time}}, + {"updated_at": {"$gte": cutoff_time}}, {"_id": 1} ): active_conversations.add(conv["_id"]) - + # 清理未使用的对话变量池(现在包含文件清理) pool_manager = await get_pool_manager() await pool_manager.cleanup_unused_pools(active_conversations) - + # 🔑 新增:清理孤儿文件(在document表中存在但不被任何变量引用的文件) await cleanup_orphaned_files() - + logger.info("定期清理任务完成") - + except Exception as e: logger.error(f"定期清理任务失败: {e}") - + # 启动后台任务 asyncio.create_task(cleanup_task()) @@ -527,10 +539,11 @@ if __name__ == "__main__": logger = logging.getLogger(__name__) logger.info(f"收到信号 {signum},准备关闭应用...") sys.exit(0) - + # 注册信号处理器 signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - + # 启动FastAPI - uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info", log_config=None) + uvicorn.run(app, host="0.0.0.0", port=8002, + log_level="info", log_config=None) diff --git a/apps/models/vector.py b/apps/models/vector.py deleted file mode 100644 index 1cdc85c9b9fbe5aaa902e1b93238001e5bbe44c5..0000000000000000000000000000000000000000 --- a/apps/models/vector.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
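The LanceModel schemas deleted in this hunk were ported one-for-one to the SQLAlchemy models at the top of this diff. On the write path, the per-row `table.add()` calls with their "Commit conflict" retry loops give way to `VectorManager.add_vectors`, which the loader hunks below feed in `BATCH_SIZE = 1024` chunks. `VectorManager` itself is not shown in this diff; a plausible sketch of such a batched insert, stated as an assumption:

```python
from apps.common.postgres import Base, DataBase


async def add_vectors(entities: list[Base], batch_size: int = 1024) -> None:
    """Assumed shape of VectorManager.add_vectors: chunked bulk inserts."""
    async with await DataBase.get_session() as session:
        for i in range(0, len(entities), batch_size):
            # add_all() stages one chunk; committing per chunk keeps each
            # transaction bounded when thousands of vectors are loaded.
            session.add_all(entities[i:i + batch_size])
            await session.commit()
```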
-"""向量数据库数据结构;数据将存储在LanceDB中""" - -from lancedb.pydantic import LanceModel, Vector - - -class FlowPoolVector(LanceModel): - """App向量信息""" - - id: str - app_id: str - embedding: Vector(dim=1024) # type: ignore[call-arg] - - -class ServicePoolVector(LanceModel): - """Service向量信息""" - - id: str - embedding: Vector(dim=1024) # type: ignore[call-arg] - - -class CallPoolVector(LanceModel): - """Call向量信息""" - - id: str - embedding: Vector(dim=1024) # type: ignore[call-arg] - - -class NodePoolVector(LanceModel): - """Node向量信息""" - - id: str - service_id: str - embedding: Vector(dim=1024) # type: ignore[call-arg] diff --git a/apps/routers/record.py b/apps/routers/record.py index 663708b86b2dc57a0973c8b73a3d4cb02298b08b..c50a0b482d59dba11d0c0776965a37e31cc0f05c 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -26,6 +26,8 @@ from apps.services.conversation import ConversationManager from apps.services.document import DocumentManager from apps.services.record import RecordManager from apps.services.task import TaskManager +import logging +logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/record", @@ -43,7 +45,11 @@ router = APIRouter( ) async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: """获取某个对话的所有问答对""" + import time + st = time.time() cur_conv = await ConversationManager.get_conversation_by_conversation_id(user_sub, conversation_id) + en = time.time() + logger.error("get conversation time cost: %.4f", en - st) # 判断conversation是否合法 if not cur_conv: return JSONResponse( @@ -54,9 +60,12 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ result={}, ).model_dump(exclude_none=True), ) - + st = time.time() record_group_list = await RecordManager.query_record_group_by_conversation_id(conversation_id) + en = time.time() + logger.error("query record group time cost: %.4f", en - st) result = [] + st = time.time() for record_group in record_group_list: for record in record_group.records: record_data = Security.decrypt(record.content, record.key) @@ -85,7 +94,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_step_list: tmp_record.flow = RecordFlow( - id=record.flow.flow_id, # TODO: 此处前端应该用name + id=record.flow.flow_id, recordId=record.id, flowId=record.flow.flow_id, flowName=record.flow.flow_name, @@ -106,6 +115,8 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ ) result.append(tmp_record) + en = time.time() + logger.error("process record time cost: %.4f", en - st) return JSONResponse( status_code=status.HTTP_200_OK, content=RecordListRsp( diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 5b846efe41ed2a1374eb6ffabaf993b8f5fd9b96..7cbb719990cafc7dfe698815e4d3559017c929e8 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -6,11 +6,12 @@ import logging from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from apps.common.lance import LanceDB +from apps.common.postgres import DataBase from apps.common.mongo import MongoDB from apps.llm.embedding import Embedding from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM +from apps.services.vector import VectorManager from apps.scheduler.mcp.prompt import ( MCP_SELECT, ) @@ -32,136 +33,17 @@ class MCPSelector: self.input_tokens = 0 self.output_tokens = 0 - @staticmethod 
- def _assemble_sql(mcp_list: list[str]) -> str: - """组装SQL""" - sql = "(" - for mcp_id in mcp_list: - sql += f"'{mcp_id}', " - return sql.rstrip(", ") + ")" - - async def _get_top_mcp_by_embedding( - self, - query: str, - mcp_list: list[str], - ) -> list[dict[str, str]]: - """通过向量检索获取Top5 MCP Server""" - logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) - mcp_table = await LanceDB.get_table("mcp") - query_embedding = await Embedding.get_embedding([query]) - mcp_vecs = ( - await ( - await mcp_table.search( - query=query_embedding, - vector_column_name="embedding", - ) - ) - .where(f"id IN {MCPSelector._assemble_sql(mcp_list)}") - .limit(5) - .to_list() - ) - - # 拿到名称和description - logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs) - mcp_collection = MongoDB().get_collection("mcp") - llm_mcp_list: list[dict[str, str]] = [] - for mcp_vec in mcp_vecs: - mcp_id = mcp_vec["id"] - mcp_data = await mcp_collection.find_one({"_id": mcp_id}) - if not mcp_data: - logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id) - continue - mcp_data = MCPCollection.model_validate(mcp_data) - llm_mcp_list.extend([{ - "id": mcp_id, - "name": mcp_data.name, - "description": mcp_data.description, - }]) - return llm_mcp_list - - async def _get_mcp_by_llm( - self, query: str, mcp_list: list[dict[str, str]], mcp_ids: list[str], language: LanguageType = LanguageType.CHINESE - ) -> MCPSelectResult: - """通过LLM选择最合适的MCP Server""" - # 初始化jinja2环境 - env = SandboxedEnvironment( - loader=BaseLoader, - autoescape=True, - trim_blocks=True, - lstrip_blocks=True, - ) - template = env.from_string(MCP_SELECT[language]) - # 渲染模板 - mcp_prompt = template.render( - mcp_list=mcp_list, - goal=query, - ) - - # 调用大模型进行推理 - result = await self._call_reasoning(mcp_prompt) - - # 使用小模型提取JSON - return await self._call_function_mcp(result, mcp_ids) - - async def _call_reasoning(self, prompt: str) -> str: - """调用大模型进行推理""" - logger.info("[MCPHelper] 调用推理大模型") - llm = ReasoningLLM() - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ] - result = "" - async for chunk in llm.call(message): - result += chunk - self.input_tokens += llm.input_tokens - self.output_tokens += llm.output_tokens - return result - - async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: - """调用结构化输出小模型提取JSON""" - logger.info("[MCPHelper] 调用结构化输出小模型") - llm = FunctionLLM() - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": reasoning_result}, - ] - schema = MCPSelectResult.model_json_schema() - # schema中加入选项 - schema["properties"]["mcp_id"]["enum"] = mcp_ids - result = await llm.call(messages=message, schema=schema) - try: - result = MCPSelectResult.model_validate(result) - except Exception: - logger.exception("[MCPHelper] 解析MCP Select Result失败") - raise - return result - - async def select_top_mcp( - self, query: str, mcp_list: list[str], language: LanguageType = LanguageType.CHINESE - ) -> MCPSelectResult: - """ - 选择最合适的MCP Server - - 先通过Embedding选择Top5,然后通过LLM选择Top 1 - """ - # 通过向量检索获取Top5 - llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) - - # 通过LLM选择最合适的 - return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list, language) - @staticmethod async def select_top_tool( query: str, mcp_list: list[str], top_n: int = 10 ) -> list[MCPTool]: """选择最合适的工具""" - tool_vector = await LanceDB.get_table("mcp_tool") query_embedding = await 
Embedding.get_embedding([query]) - tool_vecs = await (await tool_vector.search( - query=query_embedding, - vector_column_name="embedding", - )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list() + tool_vecs = await VectorManager.select_topk_mcp_tool_by_mcp_ids( + vector=query_embedding[0], + mcp_ids=mcp_list, + top_k=top_n, + ) # 拿到工具 tool_collection = MongoDB().get_collection("mcp") @@ -169,11 +51,11 @@ class MCPSelector: for tool_vec in tool_vecs: # 到MongoDB里找对应的工具 - logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) + logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec.mcp_id) tool_data = await tool_collection.aggregate([ - {"$match": {"_id": tool_vec["mcp_id"]}}, + {"$match": {"_id": tool_vec.mcp_id}}, {"$unwind": "$tools"}, - {"$match": {"tools.id": tool_vec["id"]}}, + {"$match": {"tools.id": tool_vec.id}}, {"$project": {"_id": 0, "tools": 1}}, {"$replaceRoot": {"newRoot": "$tools"}}, ]) diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index fd97625271ace53c500f339f5d4f976dce4fcf83..45bee6773378d93e3c48841f802b9e4b44318807 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -7,15 +7,14 @@ import logging import sys from hashlib import shake_128 from pathlib import Path - import apps.scheduler.call as system_call +from apps.common.postgres import DataBase, CallPoolVector +from apps.services.vector import VectorManager from apps.common.config import Config from apps.common.singleton import SingletonMeta -from apps.schemas.enum_var import CallType +from apps.schemas.enum_var import CallType, VectorPoolType from apps.schemas.pool import CallPool, NodePool -from apps.models.vector import CallPoolVector from apps.llm.embedding import Embedding -from apps.common.lance import LanceDB from apps.common.mongo import MongoDB logger = logging.getLogger(__name__) @@ -158,19 +157,11 @@ class CallLoader(metaclass=SingletonMeta): err = f"[CallLoader] 从MongoDB删除Call失败:{e}" logger.exception(err) raise RuntimeError(err) from e - - # 从LanceDB中删除 - while True: - try: - table = await LanceDB.get_table("call") - await table.delete(f"id = '{call_name}'") - break - except RuntimeError as e: - if "Commit conflict" in str(e): - logger.error("[CallLoader] LanceDB删除call冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + # 从PostgreSQL/OpenGauss中删除 + await VectorManager.delete_vectors( + vector_type=VectorPoolType.CALL, + ids=[call_name], + ) # 更新数据库 @@ -201,42 +192,27 @@ class CallLoader(metaclass=SingletonMeta): err = "[CallLoader] 更新MongoDB失败" logger.exception(err) raise RuntimeError(err) from e - - while True: - try: - table = await LanceDB.get_table("call") - # 删除重复的ID - for call in call_metadata: - await table.delete(f"id = '{call.id}'") - break - except RuntimeError as e: - if "Commit conflict" in str(e): - logger.error("[CallLoader] LanceDB插入call冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + call_ids = [call.id for call in call_metadata] + await VectorManager.delete_vectors( + vector_type=VectorPoolType.CALL, + ids=call_ids, + ) # 进行向量化,更新LanceDB call_vecs = await Embedding.get_embedding(call_descriptions) - vector_data = [] + vector_entites = [] for i, vec in enumerate(call_vecs): - vector_data.append( + vector_entites.append( CallPoolVector( id=call_metadata[i].id, embedding=vec, ), ) - while True: - try: - table = await LanceDB.get_table("call") - await table.add(vector_data) - break - except RuntimeError as e: - if "Commit conflict" in 
str(e): - logger.error("[CallLoader] LanceDB插入call冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + # 分批次插入,防止数据量过大时插入失败 + BATCH_SIZE = 1024 + for i in range(0, len(vector_entites), BATCH_SIZE): + batch = vector_entites[i:i + BATCH_SIZE] + await VectorManager.add_vectors(batch) async def load(self) -> None: """初始化Call信息""" diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 724375b843f39b5e678352a5a01fcd417f5dc976..a332a84ceaf5f5a62b4f6e0e842ef0fcb7dda657 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -11,13 +11,13 @@ import yaml from anyio import Path from apps.common.config import Config -from apps.schemas.enum_var import NodeType, EdgeType +from apps.schemas.enum_var import NodeType, EdgeType, VectorPoolType from apps.schemas.flow import AppFlow, Flow from apps.schemas.pool import AppPool -from apps.models.vector import FlowPoolVector from apps.llm.embedding import Embedding +from apps.services.vector import VectorManager from apps.services.node import NodeManager -from apps.common.lance import LanceDB +from apps.common.postgres import DataBase, FlowPoolVector from apps.common.mongo import MongoDB from apps.scheduler.util import yaml_enum_presenter, yaml_str_presenter from apps.schemas.subflow import AppSubFlow @@ -240,12 +240,11 @@ class FlowLoader: except Exception: logger.exception("[FlowLoader] 删除工作流文件失败:%s", flow_path) return False - - table = await LanceDB.get_table("flow") - try: - await table.delete(f"id = '{flow_id}'") - except Exception: - logger.exception("[FlowLoader] LanceDB删除flow失败") + # 从数据库中删除 + await VectorManager.delete_vectors( + vector_type=VectorPoolType.FLOW, + ids=[flow_id], + ) return True logger.warning("[FlowLoader] 工作流文件不存在或不是文件:%s", flow_path) return True @@ -296,68 +295,27 @@ class FlowLoader: except Exception: logger.exception("[FlowLoader] 更新 MongoDB 失败") - # 删除重复的ID,增加重试次数限制 - max_retries = 10 - retry_count = 0 import time st = time.time() - while retry_count < max_retries: - try: - table = await LanceDB.get_table("flow") - await table.delete(f"id = '{metadata.id}'") - break - except RuntimeError as e: - if "Commit conflict" in str(e): - retry_count += 1 - logger.error(f"[FlowLoader] LanceDB删除flow冲突,重试中... ({retry_count}/{max_retries})") # noqa: TRY400 - # 指数退避,减少冲突概率 - await asyncio.sleep(0.01 * (2 ** min(retry_count, 5))) - else: - raise - except Exception as e: - logger.error(f"[FlowLoader] LanceDB删除操作异常: {e}") - break + await VectorManager.delete_vectors( + vector_type=VectorPoolType.FLOW, + ids=[metadata.id], + ) en = time.time() - logger.error(f"[FlowLoader] LanceDB删除flow耗时: {en-st} 秒") - if retry_count >= max_retries: - logger.warning( - f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}") - # 不抛出异常,继续执行后续操作 + logger.error(f"[FlowLoader] PostgreSQL/OpenGauss删除flow耗时: {en-st} 秒") + + # 不抛出异常,继续执行后续操作 # 进行向量化 service_embedding = await Embedding.get_embedding([metadata.description]) - vector_data = [ - FlowPoolVector( - id=metadata.id, - app_id=app_id, - embedding=service_embedding[0], - ), - ] st = time.time() - # 插入向量数据,增加重试次数限制 - max_retries_insert = 10 - retry_count_insert = 0 - while retry_count_insert < max_retries_insert: - try: - table = await LanceDB.get_table("flow") - await table.add(vector_data) - break - except RuntimeError as e: - if "Commit conflict" in str(e): - retry_count_insert += 1 - logger.error(f"[FlowLoader] LanceDB插入flow冲突,重试中... 
({retry_count_insert}/{max_retries_insert})") # noqa: TRY400 - # 指数退避,减少冲突概率 - await asyncio.sleep(0.01 * (2 ** min(retry_count_insert, 5))) - else: - raise - except Exception as e: - logger.error(f"[FlowLoader] LanceDB插入操作异常: {e}") - break + flow_pool_vector_entity = FlowPoolVector( + id=metadata.id, + app_id=app_id, + embedding=service_embedding[0], + ) + await VectorManager.add_vector(flow_pool_vector_entity) en = time.time() - logger.error(f"[FlowLoader] LanceDB插入flow耗时: {en-st} 秒") - if retry_count_insert >= max_retries_insert: - logger.error( - f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}") - raise RuntimeError(f"LanceDB插入flow失败,达到最大重试次数: {metadata.id}") + logger.error(f"[FlowLoader] PostgreSQL/OpenGauss添加flow耗时: {en-st} 秒") async def save_subflow(self, app_id: str, flow_id: str, sub_flow_id: str, flow: Flow) -> None: """保存子工作流到层次化路径""" diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index fb6a24cdd640f0fa65c8fa64c420084a9fca4c48..9b187689c507e42e9fea6668583d9acb540f0788 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -13,14 +13,16 @@ from anyio import Path from sqids.sqids import Sqids from typing import Any -from apps.common.lance import LanceDB +from apps.common.postgres import DataBase, McpVector, McpToolVector from apps.common.mongo import MongoDB from apps.common.process_handler import ProcessHandler from apps.common.singleton import SingletonMeta from apps.constants import MCP_PATH from apps.llm.embedding import Embedding +from apps.services.vector import VectorManager from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.install import install_npx, install_uvx +from apps.schemas.enum_var import VectorPoolType from apps.schemas.mcp import ( MCPCollection, MCPInstallStatus, @@ -28,9 +30,7 @@ from apps.schemas.mcp import ( MCPServerSSEConfig, MCPServerStdioConfig, MCPTool, - MCPToolVector, - MCPType, - MCPVector, + MCPType ) logger = logging.getLogger(__name__) @@ -313,46 +313,28 @@ class MCPLoader(metaclass=SingletonMeta): # 服务本身向量化 embedding = await Embedding.get_embedding([config.description]) - - while True: - try: - mcp_table = await LanceDB.get_table("mcp") - await mcp_table.add( - [MCPVector( - id=mcp_id, - embedding=embedding[0], - )] - ) - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[MCPLoader] LanceDB插入mcp冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + mcp_vector_entity = McpVector( + id=mcp_id, + embedding=embedding[0], + ) + await VectorManager.add_vector(mcp_vector_entity) # 工具向量化 tool_desc_list = [tool.description for tool in tool_list] tool_embedding = await Embedding.get_embedding(tool_desc_list) + mcp_tool_entity_list = [] for tool, embedding in zip(tool_list, tool_embedding, strict=True): - while True: - try: - mcp_tool_table = await LanceDB.get_table("mcp_tool") - await mcp_tool_table.add( - [MCPToolVector( - id=tool.id, - mcp_id=mcp_id, - embedding=embedding, - )] - ) - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[MCPLoader] LanceDB插入mcp_tool冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise - await LanceDB.create_index("mcp_tool") + mcp_tool_entity_list.append( + McpToolVector( + id=tool.id, + mcp_id=mcp_id, + embedding=embedding, + ) + ) + BATCH_SIZE = 1024 + for i in range(0, len(mcp_tool_entity_list), BATCH_SIZE): + batch = mcp_tool_entity_list[i:i + BATCH_SIZE] + await VectorManager.add_vectors(batch) @staticmethod async 
def save_one(mcp_id: str, config: MCPServerConfig) -> None: @@ -576,20 +558,9 @@ class MCPLoader(metaclass=SingletonMeta): await mcp_collection.delete_many({"_id": {"$in": deleted_mcp_list}}) logger.info("[MCPLoader] 清除数据库中无效的MCP") - # 从LanceDB中移除 - for mcp_id in deleted_mcp_list: - while True: - try: - mcp_table = await LanceDB.get_table("mcp") - await mcp_table.delete(f"id == '{mcp_id}'") - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[MCPLoader] LanceDB删除mcp冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise - logger.info("[MCPLoader] 清除LanceDB中无效的MCP") + # 从PostgreSQL/OpenGauss中移除 + await VectorManager.delete_mcp_tool_vectors_by_mcp_ids(deleted_mcp_list) + logger.info("[MCPLoader] 清除PostgreSQL/OpenGauss中无效的MCP向量数据") @staticmethod async def delete_mcp(mcp_id: str) -> None: @@ -603,6 +574,11 @@ class MCPLoader(metaclass=SingletonMeta): template_path = MCP_PATH / "template" / mcp_id if await template_path.exists(): await asyncer.asyncify(shutil.rmtree)(template_path.as_posix(), ignore_errors=True) + # 从PostgreSQL/OpenGauss中删除 + await VectorManager.delete_vectors( + vector_type=VectorPoolType.MCP, + ids=[mcp_id], + ) @staticmethod async def _load_user_mcp() -> None: diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index cf63aa27edd799d9aa82db86819febcf135b03b8..d0d20318238d1a809111a147c2f77f37497fcf35 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -8,14 +8,14 @@ import shutil from anyio import Path from fastapi.encoders import jsonable_encoder - +from apps.common.postgres import ServicePoolVector, NodePoolVector from apps.common.config import Config +from apps.schemas.enum_var import VectorPoolType from apps.schemas.flow import Permission, ServiceMetadata from apps.schemas.pool import NodePool, ServicePool -from apps.models.vector import NodePoolVector, ServicePoolVector from apps.llm.embedding import Embedding -from apps.common.lance import LanceDB from apps.common.mongo import MongoDB +from apps.services.vector import VectorManager from apps.scheduler.pool.check import FileChecker from apps.scheduler.pool.loader.metadata import MetadataLoader, MetadataType from apps.scheduler.pool.loader.openapi import OpenAPILoader @@ -82,16 +82,13 @@ class ServiceLoader: except Exception: logger.exception("[ServiceLoader] 删除Service失败") - try: - # 获取 LanceDB 表 - service_table = await LanceDB.get_table("service") - node_table = await LanceDB.get_table("node") - - # 删除数据 - await service_table.delete(f"id = '{service_id}'") - await node_table.delete(f"id = '{service_id}'") - except Exception: - logger.exception("[ServiceLoader] 删除数据库失败") + await VectorManager.delete_vectors( + vector_type=VectorPoolType.SERVICE, + ids=[service_id], + ) + await VectorManager.delete_call_vectors_by_service_ids( + service_ids=[service_id], + ) if not is_reload: path = BASE_PATH / service_id @@ -136,62 +133,37 @@ class ServiceLoader: raise RuntimeError(err) from e # 向量化所有数据并保存 - while True: - try: - service_table = await LanceDB.get_table("service") - node_table = await LanceDB.get_table("node") - await service_table.delete(f"id = '{metadata.id}'") - await node_table.delete(f"service_id = '{metadata.id}'") - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[ServiceLoader] LanceDB删除service冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + await VectorManager.delete_vectors( + vector_type=VectorPoolType.SERVICE, + 
ids=[metadata.id],
+        )
+        await VectorManager.delete_call_vectors_by_service_ids(
+            service_ids=[metadata.id],
+        )
 
         # 进行向量化,更新LanceDB
         service_vecs = await Embedding.get_embedding([metadata.description])
-        service_vector_data = [
-            ServicePoolVector(
-                id=metadata.id,
-                embedding=service_vecs[0],
-            ),
-        ]
-        while True:
-            try:
-                service_table = await LanceDB.get_table("service")
-                await service_table.add(service_vector_data)
-                break
-            except Exception as e:
-                if "Commit conflict" in str(e):
-                    logger.error("[ServiceLoader] LanceDB插入service冲突,重试中...")  # noqa: TRY400
-                    await asyncio.sleep(0.01)
-                else:
-                    raise
+        service_vector_pool_entity = ServicePoolVector(
+            id=metadata.id,
+            embedding=service_vecs[0],
+        )
+        await VectorManager.add_vector(service_vector_pool_entity)
 
         node_descriptions = []
         for node in nodes:
             node_descriptions += [node.description]
 
         node_vecs = await Embedding.get_embedding(node_descriptions)
-        node_vector_data = []
+        node_vector_pool_entities = []
         for i, vec in enumerate(node_vecs):
-            node_vector_data.append(
+            node_vector_pool_entities.append(
                 NodePoolVector(
                     id=nodes[i].id,
                     service_id=metadata.id,
                     embedding=vec,
-                ),
+                )
             )
-        while True:
-            try:
-                node_table = await LanceDB.get_table("node")
-                await node_table.add(node_vector_data)
-                break
-            except Exception as e:
-                if "Commit conflict" in str(e):
-                    logger.error("[ServiceLoader] LanceDB插入node冲突,重试中...")  # noqa: TRY400
-                    await asyncio.sleep(0.01)
-                else:
-                    raise
+        BATCH_SIZE = 1024
+        for i in range(0, len(node_vector_pool_entities), BATCH_SIZE):
+            batch = node_vector_pool_entities[i:i + BATCH_SIZE]
+            await VectorManager.add_vectors(batch)
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 439212aa035557dadd97381040a9a1e148092030..51d6c7d6ee54abcd48ffef1f55fcacbfeb409f19 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -283,9 +283,12 @@ class Scheduler:
             return
 
         # 如果用户选了特定的Flow
-        if app_info.flow_id:
+        if app_info.flow_id or len(flow_info) == 1:
             logger.info("[Scheduler] 获取工作流定义")
-            flow_id = app_info.flow_id
+            if app_info.flow_id:
+                flow_id = app_info.flow_id
+            else:
+                flow_id = flow_info[0].id
             flow_data = await Pool().get_flow(app_info.app_id, flow_id)
         else:
             # 如果用户没有选特定的Flow,则根据语义选择一个Flow
diff --git a/apps/schemas/config.py b/apps/schemas/config.py
index 95c9a72cd02eb8742fe3dece7a548f3ab03854b4..b685fc42454727254ae95da5277a6167dd936e9a 100644
--- a/apps/schemas/config.py
+++ b/apps/schemas/config.py
@@ -140,6 +140,17 @@ class RedisConfig(BaseModel):
     health_check_interval: int = Field(description="健康检查间隔(秒)", default=30)
 
 
+class VectorDBConfig(BaseModel):
+    """向量数据库配置"""
+
+    # postgres或者openGauss
+    type: str = Field(description="向量数据库类型", pattern="^(postgres|opengauss)$")
+    host: str = Field(description="向量数据库主机名")
+    port: int = Field(description="向量数据库端口号", default=5432)
+    user: str = Field(description="向量数据库用户名")
+    password: str = Field(description="向量数据库密码")
+    db: str = Field(description="向量数据库名称")
+
+
 class LLMConfig(BaseModel):
     """LLM配置"""
 
@@ -214,6 +225,7 @@ class ConfigModel(BaseModel):
     minio: MinioConfig
     mongodb: MongoDBConfig
     redis: RedisConfig
+    vectordb: VectorDBConfig
     llm: LLMConfig
     function_call: FunctionCallConfig
     mcp_config: McpConfig = Field(description="MCP配置", default=McpConfig())
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index cb81f556c7c7e199dcd76dc07a4a03dc8a6490cb..6ece78ba66283e9e589d4f015d5f4f6b01c4a4a3 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -4,6 +4,17 @@
 
 from 
enum import Enum +class VectorPoolType(str, Enum): + """向量池类型""" + + FLOW = "flow" + SERVICE = "service" + CALL = "call" + NODE = "node" + MCP = "mcp" + MCP_TOOL = "mcp_tool" + + class SlotType(str, Enum): """Slot类型""" @@ -224,4 +235,4 @@ class LanguageType(str, Enum): """语言类型""" CHINESE = "zh" - ENGLISH = "en" \ No newline at end of file + ENGLISH = "en" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 43277e0dbdbaca57fd6eea1bd240699ccb8e3406..5955fa7a98fa0f350b88ad882ee758a1d2f6d6a7 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -3,8 +3,6 @@ from enum import Enum from typing import Any - -from lancedb.pydantic import LanceModel, Vector from pydantic import BaseModel, Field @@ -37,10 +35,12 @@ class MCPType(str, Enum): class MCPBasicConfig(BaseModel): """MCP 基本配置""" - auto_approve: list[str] = Field(description="自动批准的MCP工具ID列表", default=[], alias="autoApprove") + auto_approve: list[str] = Field( + description="自动批准的MCP工具ID列表", default=[], alias="autoApprove") disabled: bool = Field(description="MCP 服务器是否禁用", default=False) auto_install: bool = Field(description="是否自动安装MCP服务器", default=True) - timeout: int = Field(description="MCP 服务器超时时间(秒)", default=60, alias="timeout") + timeout: int = Field(description="MCP 服务器超时时间(秒)", + default=60, alias="timeout") description: str = Field(description="MCP 服务器自然语言描述", default="") @@ -55,7 +55,8 @@ class MCPServerStdioConfig(MCPBasicConfig): class MCPServerSSEConfig(MCPBasicConfig): """MCP 服务器配置""" headers: dict[str, str] = Field(description="MCP 服务器请求头", default={}) - url: str = Field(description="MCP 服务器地址", default="http://example.com/sse", pattern=r"^https?://.*$") + url: str = Field(description="MCP 服务器地址", + default="http://example.com/sse", pattern=r"^https?://.*$") class MCPServerConfig(BaseModel): @@ -66,8 +67,10 @@ class MCPServerConfig(BaseModel): description: str = Field(description="MCP 服务器自然语言描述", default="") type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO) author: str = Field(description="MCP 服务器上传者", default="") - author_name: str = Field(description="MCP 服务器上传者用户名", default="", alias="authorName") - config: MCPServerStdioConfig | MCPServerSSEConfig = Field(description="MCP 服务器配置") + author_name: str = Field(description="MCP 服务器上传者用户名", + default="", alias="authorName") + config: MCPServerStdioConfig | MCPServerSSEConfig = Field( + description="MCP 服务器配置") class MCPTool(BaseModel): @@ -89,24 +92,11 @@ class MCPCollection(BaseModel): type: MCPType = Field(description="MCP 类型", default=MCPType.SSE) activated: list[str] = Field(description="激活该MCP的用户ID列表", default=[]) tools: list[MCPTool] = Field(description="MCP工具列表", default=[]) - status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INIT) + status: MCPInstallStatus = Field( + description="MCP服务状态", default=MCPInstallStatus.INIT) author: str = Field(description="MCP作者", default="") - author_name: str = Field(description="MCP作者用户名", default="", alias="authorName") - - -class MCPVector(LanceModel): - """MCP向量化数据,存储在LanceDB的 ``mcp`` 表中""" - - id: str = Field(description="MCP ID") - embedding: Vector(dim=1024) = Field(description="MCP描述的向量信息") # type: ignore[call-arg] - - -class MCPToolVector(LanceModel): - """MCP工具向量化数据,存储在LanceDB的 ``mcp_tool`` 表中""" - - id: str = Field(description="工具ID") - mcp_id: str = Field(description="MCP ID") - embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg] + author_name: str = Field(description="MCP作者用户名", + default="", alias="authorName") class 
GoalEvaluationResult(BaseModel):
@@ -167,7 +157,8 @@ class ErrorType(str, Enum):
 class ToolExcutionErrorType(BaseModel):
     """MCP工具执行错误"""
 
-    type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM)
+    type: ErrorType = Field(
+        description="错误类型", default=ErrorType.MISSING_PARAM)
     reason: str = Field(description="错误原因", default="")
 
 
@@ -213,4 +204,4 @@ class Step(BaseModel):
     """MCP步骤"""
 
     tool_id: str = Field(description="工具ID")
-    description: str = Field(description="步骤描述,15个字以下")
\ No newline at end of file
+    description: str = Field(description="步骤描述,15个字以下")
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index 8336d7ac38cdeb5fa89e2cbfcf495c0bb4e3bbad..02702f2fe0efeeba5eb632a25a87b8852bdf0c78 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -14,6 +14,7 @@ from apps.services.llm import LLMManager
 from apps.services.task import TaskManager
 from apps.templates.generate_llm_operator_config import llm_provider_dict
 from apps.llm.adapters import get_provider_from_endpoint
+from apps.llm.schema import DefaultModelId
 
 logger = logging.getLogger(__name__)
 
@@ -47,35 +48,29 @@ class ConversationManager:
         """通过用户ID新建对话"""
         if not llm_id:
             # 获取系统默认模型的UUID
-            from apps.services.llm import LLMManager
-            mongo = MongoDB()
-            llm_collection = mongo.get_collection("llm")
-            config = Config().get_config()
-
-            # 查找系统默认chat模型
-            system_llm = await llm_collection.find_one({
-                "user_sub": "",
-                "type": "chat",
-                "model_name": config.llm.model
-            })
-
+            llm = await LLMManager.get_llm_by_id(DefaultModelId.DEFAULT_CHAT_MODEL_ID.value)
             llm_item = LLMItem(
-                llm_id=str(system_llm["_id"]),
-                model_name=system_llm["model_name"],
-                icon=system_llm["icon"],
+                llm_id=llm.id,
+                model_name=llm.model_name,
+                icon=llm.icon,
             )
         else:
             # 首先尝试通过用户ID和LLM ID查找
-            try:
-                llm = await LLMManager.get_llm_by_user_sub_and_id(user_sub, llm_id)
-            except ValueError:
-                # 如果用户级别的LLM不存在,尝试查找系统级别的LLM
-                logger.info(f"[ConversationManager] 用户级别LLM {llm_id} 不存在,尝试查找系统级别LLM")
+            default_model_ids = [item.value for item in DefaultModelId]
+            if llm_id in default_model_ids:
+                llm = await LLMManager.get_llm_by_id(llm_id)
+            else:
                 try:
-                    llm = await LLMManager.get_llm_by_id(llm_id)
+                    llm = await LLMManager.get_llm_by_user_sub_and_id(user_sub, llm_id)
                 except ValueError:
-                    logger.error(f"[ConversationManager] 系统级别LLM {llm_id} 也不存在")
-                    llm = None
+                    # 如果用户级别的LLM不存在,尝试查找系统级别的LLM
+                    logger.info(
+                        f"[ConversationManager] 用户级别LLM {llm_id} 不存在,尝试查找系统级别LLM")
+                    try:
+                        llm = await LLMManager.get_llm_by_id(llm_id)
+                    except ValueError:
+                        logger.error(f"[ConversationManager] 系统级别LLM {llm_id} 也不存在")
+                        llm = None
         if llm is None:
             logger.error("[ConversationManager] 获取大模型失败")
             return None
diff --git a/apps/services/document.py b/apps/services/document.py
index 5d4238583a03467797ba8a53bda20f0b24afb5ea..980eeac17dfaf64ce71cd0ebdfe2e0ea1ddf7ece 100644
--- a/apps/services/document.py
+++ b/apps/services/document.py
@@ -57,9 +57,10 @@ class DocumentManager:
     async def validate_file_upload_for_variable(cls, user_sub: str, documents: list[UploadFile], var_name: str, var_type: str, scope: str, conversation_id: Optional[str] = None, flow_id: Optional[str] = None) -> tuple[bool, str]:
         """验证文件上传是否符合变量限制"""
         try:
-            logger.info(f"开始验证文件上传: scope={scope}, var_name={var_name}, conversation_id={conversation_id}, flow_id={flow_id}")
+            logger.info(
+                f"开始验证文件上传: scope={scope}, var_name={var_name}, conversation_id={conversation_id}, flow_id={flow_id}")
             pool_manager = await get_pool_manager()
-            
+
             # 根据scope获取变量池
             if scope == "system":
                 pool = await pool_manager.get_conversation_pool(conversation_id)
@@ -67,7 +68,8 @@
                 # 兜底:如果conversation pool不存在,自动创建
                 if not 
flow_id: flow_id = await cls._get_flow_id_for_conversation(conversation_id) - logger.info(f"自动创建system scope的conversation pool: {conversation_id}, flow_id={flow_id}") + logger.info( + f"自动创建system scope的conversation pool: {conversation_id}, flow_id={flow_id}") pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) elif scope == "user": pool = await pool_manager.get_user_pool(user_sub) @@ -79,63 +81,69 @@ class DocumentManager: # 兜底:如果conversation pool不存在,自动创建 if not flow_id: flow_id = await cls._get_flow_id_for_conversation(conversation_id) - logger.info(f"自动创建conversation scope的conversation pool: {conversation_id}, flow_id={flow_id}") + logger.info( + f"自动创建conversation scope的conversation pool: {conversation_id}, flow_id={flow_id}") pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) else: return False, f"不支持的scope类型: {scope}" - + if not pool: return False, f"无法获取变量池,scope: {scope}" - + # 获取变量 variable = await pool.get_variable(var_name) if variable is None: logger.error(f"变量 {var_name} 不存在") return False, f"变量 {var_name} 不存在" - + logger.info(f"找到变量: {var_name}, 类型: {variable.metadata.var_type}") - + # 检查变量类型 if variable.metadata.var_type not in [VariableType.FILE, VariableType.ARRAY_FILE]: return False, f"变量 {var_name} 不是文件类型" - + # 获取变量配置 var_config = variable.value if not isinstance(var_config, dict): - logger.error(f"变量 {var_name} 配置格式不正确: {type(var_config)}, 值: {var_config}") + logger.error( + f"变量 {var_name} 配置格式不正确: {type(var_config)}, 值: {var_config}") return False, f"变量 {var_name} 配置格式不正确" - + logger.info(f"变量配置: {var_config}") - + # 1. 检查文件数量限制 - max_files = var_config.get("max_files", 1 if variable.metadata.var_type == VariableType.FILE else 10) + max_files = var_config.get( + "max_files", 1 if variable.metadata.var_type == VariableType.FILE else 10) if len(documents) > max_files: return False, f"文件数量超过限制:最多允许 {max_files} 个文件,实际上传 {len(documents)} 个" - + # 2. 检查文件大小限制 - max_file_size = var_config.get("max_file_size", 10 * 1024 * 1024) # 默认10MB + max_file_size = var_config.get( + "max_file_size", 10 * 1024 * 1024) # 默认10MB for doc in documents: if doc.size and doc.size > max_file_size: max_size_mb = max_file_size / (1024 * 1024) return False, f"文件大小超过限制:{doc.filename} 大小为 {doc.size / (1024 * 1024):.1f}MB,最大允许 {max_size_mb:.1f}MB" - + # 3. 检查文件类型限制 supported_types = var_config.get("supported_types", []) if supported_types: for doc in documents: if doc.filename: - file_ext = "." + doc.filename.split(".")[-1].lower() if "." in doc.filename else "" + file_ext = "." + \ + doc.filename.split( + ".")[-1].lower() if "." in doc.filename else "" if file_ext not in supported_types: return False, f"文件类型不支持:{doc.filename},支持的类型:{supported_types}" - + # 4. 
检查上传方式限制(当前只支持文件流上传) upload_methods = var_config.get("upload_methods", ["manual"]) if "manual" not in upload_methods: return False, f"当前上传方式不支持,支持的上传方式:{upload_methods}" - + logger.info(f"文件上传验证通过: {var_name}") return True, "验证通过" - + except Exception as e: logger.exception(f"文件上传验证失败: {e}") return False, f"验证过程发生错误: {str(e)}" @@ -143,7 +151,7 @@ class DocumentManager: @classmethod async def storage_docs(cls, user_sub: str, conversation_id: str, documents: list[UploadFile], scope: str = "system", var_name: Optional[str] = None, var_type: Optional[str] = None) -> list[Document]: """存储多个文件 - 支持system和conversation scope - + 🔑 优化说明:支持前端两阶段处理逻辑 - 前端第一阶段:已通过updateVariable清空file_id/file_ids - 这里:上传文件到MinIO/MongoDB,然后更新变量池的file_id/file_ids @@ -153,19 +161,22 @@ class DocumentManager: mongo = MongoDB() doc_collection = mongo.get_collection("document") conversation_collection = mongo.get_collection("conversation") - + # 🔑 第一步:上传文件到MinIO和MongoDB for document in documents: if document.filename is None or document.size is None: - logger.warning(f"跳过无效文件: filename={document.filename}, size={document.size}") + logger.warning( + f"跳过无效文件: filename={document.filename}, size={document.size}") continue file_id = str(uuid.uuid4()) try: mime = await asyncer.asyncify(cls._storage_single_doc_minio)(file_id, document) - logger.info(f"文件上传到MinIO成功: {file_id}, filename={document.filename}") + logger.info( + f"文件上传到MinIO成功: {file_id}, filename={document.filename}") except Exception as e: - logger.exception(f"[DocumentManager] 上传文件到MinIO失败: filename={document.filename}, error={e}") + logger.exception( + f"[DocumentManager] 上传文件到MinIO失败: filename={document.filename}, error={e}") continue # 保存到MongoDB @@ -187,7 +198,8 @@ class DocumentManager: ) logger.info(f"文件元数据保存到MongoDB成功: {file_id}") except Exception as e: - logger.exception(f"[DocumentManager] 保存文件元数据到MongoDB失败: file_id={file_id}, error={e}") + logger.exception( + f"[DocumentManager] 保存文件元数据到MongoDB失败: file_id={file_id}, error={e}") # 尝试清理MinIO中的文件 try: from apps.common.minio import MinioClient @@ -203,16 +215,18 @@ class DocumentManager: # 🔑 第二步:更新变量池中的file_id/file_ids if uploaded_files: try: - logger.info(f"开始更新变量池: scope={scope}, var_name={var_name}, uploaded_count={len(uploaded_files)}") + logger.info( + f"开始更新变量池: scope={scope}, var_name={var_name}, uploaded_count={len(uploaded_files)}") await cls._store_files_in_variable_pool(user_sub, conversation_id, uploaded_files, scope, var_name, var_type) logger.info(f"✅ 变量池更新成功: {var_name}") except Exception as e: - logger.exception(f"❌ [DocumentManager] 存储文件变量到变量池失败: var_name={var_name}, error={e}") - + logger.exception( + f"❌ [DocumentManager] 存储文件变量到变量池失败: var_name={var_name}, error={e}") + # 🔑 重要:变量池更新失败时,需要清理已上传的文件 logger.warning(f"开始清理已上传的文件,数量: {len(uploaded_files)}") cleanup_errors = [] - + # 清理MongoDB记录 for doc in uploaded_files: try: @@ -222,26 +236,30 @@ class DocumentManager: {"$pull": {"unused_docs": doc.id}} ) except Exception as cleanup_error: - cleanup_errors.append(f"MongoDB清理失败 {doc.id}: {cleanup_error}") - + cleanup_errors.append( + f"MongoDB清理失败 {doc.id}: {cleanup_error}") + # 清理MinIO文件 for doc in uploaded_files: try: from apps.common.minio import MinioClient MinioClient.delete_file("document", doc.id) except Exception as cleanup_error: - cleanup_errors.append(f"MinIO清理失败 {doc.id}: {cleanup_error}") - + cleanup_errors.append( + f"MinIO清理失败 {doc.id}: {cleanup_error}") + if cleanup_errors: logger.warning(f"文件清理过程中出现错误: {cleanup_errors}") - + # 重新抛出原始异常 raise e elif scope == "conversation": # 🔑 
     @classmethod
@@ -333,7 +351,7 @@ class DocumentManager:

         mongo = MongoDB()
         doc_collection = mongo.get_collection("document")
-
+
         for document in files:
             if document.filename is None or document.size is None:
                 continue
@@ -374,13 +392,13 @@ class DocumentManager:
     @classmethod
     async def _store_files_in_variable_pool(cls, user_sub: str, conversation_id: str, uploaded_files: list[Document], scope: str, var_name: Optional[str] = None, var_type: Optional[str] = None):
         """Store files in the variable pool - supports system and conversation scopes
-
+
         🔑 Optimization note: supports the frontend's two-phase processing logic
         - Frontend phase 1 already cleared file_id; here we only need to update it safely
         - Avoids conflicts with the frontend's variable-update logic
         """
         pool_manager = await get_pool_manager()
-
+
         # Determine the variable pool and the is_system value based on scope
         if scope == "system":
             pool = await pool_manager.get_conversation_pool(conversation_id)
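Note: the update logic in the next hunk assumes a specific value shape per variable type — FILE carries file_id/filename, ARRAY_FILE carries parallel file_ids and files arrays that must stay the same length. A small sketch of that invariant as a checker; the function is illustrative, not the FileVariableHelper API, and it assumes the enum values serialize as "file"/"array_file":

def is_consistent(value: dict, var_type: str) -> bool:
    if var_type == "file":
        # Single file: a filename should only be present alongside a file_id
        return not (value.get("filename") and not value.get("file_id"))
    if var_type == "array_file":
        file_ids = value.get("file_ids", [])
        files = value.get("files", [])
        # Parallel arrays must stay in sync, element by element
        return len(file_ids) == len(files) and all(
            entry.get("file_id") == fid for fid, entry in zip(file_ids, files)
        )
    return False

assert is_consistent({"file_ids": ["a"], "files": [{"file_id": "a", "filename": "a.txt"}]}, "array_file")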
@@ -419,29 +437,34 @@ class DocumentManager:
             if existing_variable.metadata.var_type in [VariableType.FILE, VariableType.ARRAY_FILE]:
                 # 🔧 New: data validation and auto-repair mechanism
                 from apps.scheduler.variable.file_utils import FileVariableHelper
-
+
                 # Get the existing variable's config
                 current_value = existing_variable.value
                 if isinstance(current_value, dict):
                     # Validate and normalize the existing data
                     var_type_str = existing_variable.metadata.var_type.value
-                    is_valid, error_msg = FileVariableHelper.validate_file_variable_consistency(current_value, var_type_str)
-
+                    is_valid, error_msg = FileVariableHelper.validate_file_variable_consistency(
+                        current_value, var_type_str)
+
                     if not is_valid:
-                        logger.warning(f"Detected inconsistent data for variable {final_var_name}: {error_msg}; repairing automatically")
+                        logger.warning(
+                            f"Detected inconsistent data for variable {final_var_name}: {error_msg}; repairing automatically")
                         # Automatically normalize the data
-                        current_value = FileVariableHelper.normalize_file_variable(current_value, var_type_str)
+                        current_value = FileVariableHelper.normalize_file_variable(
+                            current_value, var_type_str)
                         logger.info(f"Automatically repaired the data format of variable {final_var_name}")
-
+
                     # 🔑 Key point: keep all existing config and update only the file fields
                     updated_value = current_value.copy()  # shallow copy (dict.copy()); keeps the top-level config
-
+
                     if existing_variable.metadata.var_type == VariableType.FILE:
                         # Single file: set file_id and filename
                         if uploaded_files:
                             updated_value["file_id"] = uploaded_files[0].id
-                            updated_value["filename"] = uploaded_files[0].name  # New: save the original filename
-                            logger.info(f"Updating single-file variable {final_var_name}: file_id={updated_value['file_id']}, filename={updated_value['filename']}")
+                            # New: save the original filename
+                            updated_value["filename"] = uploaded_files[0].name
+                            logger.info(
+                                f"Updating single-file variable {final_var_name}: file_id={updated_value['file_id']}, filename={updated_value['filename']}")
                         else:
                             updated_value["file_id"] = ""
                             updated_value.pop("filename", None)  # Clear the filename
@@ -450,20 +473,22 @@ class DocumentManager:
                         # File array: set the file_ids and files info, keeping the data consistent
                         file_ids_list = [doc.id for doc in uploaded_files]
                         files_list = [  # New: save the file details
-                            {"file_id": doc.id, "filename": doc.name} 
+                            {"file_id": doc.id, "filename": doc.name}
                             for doc in uploaded_files
                         ]
-
+
                         # 🔧 Important: keep the file_ids and files arrays in sync at all times
                         updated_value["file_ids"] = file_ids_list
                         updated_value["files"] = files_list
-
+
                         # 🔧 Verify data consistency
                         if len(file_ids_list) != len(files_list):
-                            logger.warning(f"File-array variable is inconsistent: file_ids={len(file_ids_list)}, files={len(files_list)}")
-
-                        logger.info(f"Updating file-array variable {final_var_name}: file_ids={updated_value['file_ids']}, files={len(updated_value['files'])} files")
-
+                            logger.warning(
+                                f"File-array variable is inconsistent: file_ids={len(file_ids_list)}, files={len(files_list)}")
+
+                        logger.info(
+                            f"Updating file-array variable {final_var_name}: file_ids={updated_value['file_ids']}, files={len(updated_value['files'])} files")
+
                     # 🔑 Use an atomic update to avoid concurrent conflicts
                     try:
                         await pool.update_variable(
@@ -471,18 +496,23 @@ class DocumentManager:
                             value=updated_value,
                             var_type=existing_variable.metadata.var_type,
                             description=existing_variable.metadata.description,
-                            force_system_update=(scope == "system")  # system.files may be updated
+                            # system.files may be updated
+                            force_system_update=(scope == "system")
                         )
-                        logger.info(f"✅ Successfully updated file variable: {final_var_name}, uploaded_files_count={len(uploaded_files)}")
+                        logger.info(
+                            f"✅ Successfully updated file variable: {final_var_name}, uploaded_files_count={len(uploaded_files)}")
                     except Exception as e:
-                        logger.error(f"❌ Failed to update file variable: {final_var_name}, error={e}")
+                        logger.error(
+                            f"❌ Failed to update file variable: {final_var_name}, error={e}")
                         # Raise so the caller can run the upload-failure cleanup
                         raise Exception(f"File variable update failed: {e}")
                 else:
-                    logger.warning(f"Existing variable {final_var_name} has an invalid value format and cannot be updated: type={type(current_value)}, value={current_value}")
+                    logger.warning(
+                        f"Existing variable {final_var_name} has an invalid value format and cannot be updated: type={type(current_value)}, value={current_value}")
                     raise Exception(f"Variable {final_var_name} has a malformed config")
             else:
-                logger.warning(f"Existing variable {final_var_name} is not a file type and cannot be updated: type={existing_variable.metadata.var_type}")
+                logger.warning(
+                    f"Existing variable {final_var_name} is not a file type and cannot be updated: type={existing_variable.metadata.var_type}")
                 raise Exception(f"Variable {final_var_name} is not a file type")
         else:
             # Only system scope may create new variables
@@ -490,10 +520,10 @@ class DocumentManager:
                 # Create a new system.files variable
                 file_ids_list = [doc.id for doc in uploaded_files]
                 files_list = [  # New: save the file details
-                    {"file_id": doc.id, "filename": doc.name} 
+                    {"file_id": doc.id, "filename": doc.name}
                     for doc in uploaded_files
                 ]
-
+
                 final_var_value = {
                     "file_ids": file_ids_list,  # ensure the same data source is used
                     "files": files_list,  # ensure the same data source is used
@@ -503,13 +533,15 @@ class DocumentManager:
                     "max_file_size": 10 * 1024 * 1024,  # 10MB
                     "required": False
                 }
-
+
                 # 🔧 Verify data consistency
                 if len(file_ids_list) != len(files_list):
-                    logger.warning(f"Newly created system variable is inconsistent: file_ids={len(file_ids_list)}, files={len(files_list)}")
-
-                logger.info(f"Creating a new system.files variable: file_ids={len(file_ids_list)} files, files={len(files_list)} detail entries")
-
+                    logger.warning(
+                        f"Newly created system variable is inconsistent: file_ids={len(file_ids_list)}, files={len(files_list)}")
+
+                logger.info(
+                    f"Creating a new system.files variable: file_ids={len(file_ids_list)} files, files={len(files_list)} detail entries")
+
                 await pool.add_variable(
                     name=final_var_name,
                     var_type=final_var_type,
@@ -521,10 +553,12 @@ class DocumentManager:
                 logger.info("Created a new system.files variable")
             else:
                 # 🔑 Important: conversation variables must already exist (created in frontend phase 1)
-                logger.error(f"Conversation variable {final_var_name} does not exist; frontend phase 1 may have failed")
+                logger.error(
+                    f"Conversation variable {final_var_name} does not exist; frontend phase 1 may have failed")
                 raise Exception(f"Variable {final_var_name} does not exist; cannot upload files")

-        logger.info(f"✅ Files stored in the variable pool, variable: {final_var_name}, type: {final_var_type}, scope: {scope}")
+        logger.info(
+            f"✅ Files stored in the variable pool, variable: {final_var_name}, type: {final_var_type}, scope: {scope}")
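Note: the creation branch above fixes the shape of a fresh system.files value. A sketch of a small builder for that payload — the helper name is hypothetical, and the keys mirror only what is visible in the hunks above (the keys elided by the hunk boundary are left as a comment):

def build_system_files_value(docs: list[tuple[str, str]]) -> dict:
    """docs is a list of (file_id, filename) pairs."""
    file_ids = [doc_id for doc_id, _ in docs]
    files = [{"file_id": doc_id, "filename": name} for doc_id, name in docs]
    return {
        "file_ids": file_ids,  # kept in sync with "files" by construction
        "files": files,
        # ... remaining config keys as in the hunk above ...
        "max_file_size": 10 * 1024 * 1024,  # 10MB
        "required": False,
    }

value = build_system_files_value([("id-1", "report.pdf")])
assert len(value["file_ids"]) == len(value["files"])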
     @classmethod
     async def _store_user_files_in_variable_pool(cls, user_sub: str, uploaded_files: list[Document], var_name: str, var_type: str):
@@ -548,8 +582,9 @@ class DocumentManager:
             if existing_variable.metadata.var_type == VariableType.FILE:
                 current_value["file_id"] = uploaded_files[0].id if uploaded_files else ""
             else:  # ARRAY_FILE
-                current_value["file_ids"] = [doc.id for doc in uploaded_files]
-
+                current_value["file_ids"] = [
+                    doc.id for doc in uploaded_files]
+
             # Update the variable
             await pool.update_variable(
                 name=var_name,
@@ -567,7 +602,8 @@ class DocumentManager:
             logger.warning(f"User variable {var_name} does not exist; cannot create a new one")
             return

-        logger.info(f"User files stored in the variable pool, variable: {var_name}, type: {var_type}, user_sub: {user_sub}")
+        logger.info(
+            f"User files stored in the variable pool, variable: {var_name}, type: {var_type}, user_sub: {user_sub}")

     @classmethod
     async def _store_env_files_in_variable_pool(cls, user_sub: str, flow_id: str, uploaded_files: list[Document]):
@@ -589,8 +625,9 @@ class DocumentManager:
             current_value = existing_variable.value
             if isinstance(current_value, dict):
                 # Update the file-ID list
-                current_value["file_ids"] = [doc.id for doc in uploaded_files]
-
+                current_value["file_ids"] = [
+                    doc.id for doc in uploaded_files]
+
                 # Update the variable
                 await pool.update_variable(
                     name=var_name,
@@ -608,7 +645,8 @@ class DocumentManager:
             logger.warning(f"Environment variable {var_name} does not exist; cannot create a new one")
             return

-        logger.info(f"Environment files stored in the variable pool, variable: {var_name}, user_sub: {user_sub}, flow_id: {flow_id}")
+        logger.info(
+            f"Environment files stored in the variable pool, variable: {var_name}, user_sub: {user_sub}, flow_id: {flow_id}")

     @classmethod
     async def _get_flow_id_for_conversation(cls, conversation_id: str) -> str:
@@ -617,7 +655,7 @@ class DocumentManager:
         mongo = MongoDB()
         conversation_collection = mongo.get_collection("conversation")
         conversation = await conversation_collection.find_one({"_id": conversation_id})
-
+
         if conversation and conversation.get("app_id"):
             # Use app_id as the flow_id
             return conversation["app_id"]
@@ -661,7 +699,8 @@ class DocumentManager:
         for doc in docs:
             if doc.associated == "question":
                 doc_info = await document_collection.find_one({"_id": doc.id, "user_sub": user_sub})
-                doc_info = Document.model_validate(doc_info) if doc_info else None
+                doc_info = Document.model_validate(
+                    doc_info) if doc_info else None
                 if doc_info:
                     doc.name = doc_info.name
                     doc.extension = doc_info.type
@@ -677,11 +716,61 @@ class DocumentManager:
                 size=doc.size,
                 conversation_id=record_group.get("conversation_id", ""),
                 associated=doc.associated,
-                created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3)
+                created_at=doc.created_at or round(
+                    datetime.now(tz=UTC).timestamp(), 3)
             )
             for doc in docs if type is None or doc.associated == type
         ]

+    @classmethod
+    async def get_used_docs_by_record_groups(
+            cls, user_sub: str, record_group_ids: list[str], type: str | None = None) -> list[RecordDocument]:
+        """Get the documents associated with multiple RecordGroups"""
+        mongo = MongoDB()
+        record_group_collection = mongo.get_collection("record_group")
+        document_collection = mongo.get_collection("document")
+        if type not in ["question", "answer", None]:
+            raise ValueError("type must be 'question', 'answer' or None")
+        docs = []
+        # Materialize the cursor once: it is iterated twice below, and an
+        # async cursor can neither be re-consumed nor iterated with a plain "for"
+        record_groups = await record_group_collection.find(
+            {"_id": {"$in": record_group_ids}, "user_sub": user_sub}).to_list(length=None)
+        question_doc_ids = []
+        for record_group in record_groups:
+            rg = RecordGroup.model_validate(record_group)
+            for doc in rg.docs:
+                if doc.associated == "question":
+                    question_doc_ids.append(doc.id)
+        question_docs_info = {}
+        if question_doc_ids:
+            async for doc in document_collection.find({"_id": {"$in": question_doc_ids}, "user_sub": user_sub}):
+                question_docs_info[doc["_id"]] = Document.model_validate(doc)
+        for record_group in record_groups:
+            rg = RecordGroup.model_validate(record_group)
+            for doc in rg.docs:
+                if doc.associated == "question":
+                    doc_info = question_docs_info.get(doc.id)
+                    if doc_info:
+                        doc.name = doc_info.name
+                        doc.extension = doc_info.type
+                        doc.size = doc_info.size
+            docs.extend([
+                RecordDocument(
+                    _id=doc.id,
+                    order=doc.order,
+                    author=doc.author,
+                    abstract=doc.abstract,
+                    name=doc.name,
+                    type=doc.extension,
+                    size=doc.size,
+                    conversation_id=record_group.get("conversation_id", ""),
+                    associated=doc.associated,
+                    created_at=doc.created_at or round(
+                        datetime.now(tz=UTC).timestamp(), 3)
+                )
+                for doc in rg.docs if type is None or doc.associated == type
+            ])
+        return docs
+
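Note: compared with the single-group variant above it, the batch method avoids an N+1 query pattern by collecting all question-document IDs first and fetching them in one $in query. The core pattern, sketched against a motor-style async collection (the collection handle is a placeholder):

from typing import Any

async def prefetch_by_ids(collection: Any, ids: list[str]) -> dict[str, dict]:
    """One round trip instead of len(ids) separate find_one calls."""
    if not ids:
        return {}
    found: dict[str, dict] = {}
    async for doc in collection.find({"_id": {"$in": ids}}):
        found[doc["_id"]] = doc
    return found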
     @classmethod
     async def get_used_docs(
             cls, user_sub: str, conversation_id: str, record_num: int | None = 10, type: str | None = None) -> list[Document]:
@@ -693,7 +782,8 @@ class DocumentManager:
             raise ValueError("type must be 'question', 'answer' or None")
         if record_num:
             record_groups = (
-                record_group_collection.find({"conversation_id": conversation_id, "user_sub": user_sub})
+                record_group_collection.find(
+                    {"conversation_id": conversation_id, "user_sub": user_sub})
                 .sort("created_at", -1)
                 .limit(record_num)
             )
diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py
index 786eb03b3ffaa57b7695bbac571a5e774d5a84bb..d8ef97baf1d16f0a5dff492afe8e3c1bf98af0a3 100644
--- a/apps/services/flow_validate.py
+++ b/apps/services/flow_validate.py
@@ -47,25 +47,28 @@ class FlowService:
             # Save the original serviceId and callId so API plugin nodes can be identified correctly
             original_service_id = node.service_id
             original_call_id = node.call_id
-
+
             # Use the presence of a serviceId to decide between an API plugin node and a plain Empty node
             if original_service_id and original_service_id.strip():
                 # Has a serviceId: mark as Plugin type (API plugin node)
                 node.node_id = SpecialCallType.PLUGIN.value
                 node.call_id = SpecialCallType.PLUGIN.value
                 node.service_id = original_service_id
-                logger.info(f"[FlowService] Marking node {original_call_id} as an API plugin node, serviceId: {original_service_id}")
+                logger.info(
+                    f"[FlowService] Marking node {original_call_id} as an API plugin node, serviceId: {original_service_id}")
             else:
                 # No serviceId: mark as Empty type
                 node.node_id = SpecialCallType.EMPTY.value
                 node.call_id = SpecialCallType.EMPTY.value
-                logger.info(f"[FlowService] Marking node {original_call_id} as an Empty node")
+                logger.info(
+                    f"[FlowService] Marking node {original_call_id} as an Empty node")
-
+
             # Update the description while keeping the original one
             original_description = node.description or ""
             node.description = f'[The corresponding API tool has been deleted! This node is unavailable! Please contact the maintainers!]\n\n{original_description}'
-
-            logger.error(f"[FlowService] Failed to get the step's call_id {original_call_id}, error: {e}")
+
+            logger.error(
+                f"[FlowService] Failed to get the step's call_id {original_call_id}, error: {e}")
         node_branch_map[node.step_id] = set()
         if node.call_id == NodeType.CHOICE.value:
             input_parameters = node.parameters["input_parameters"]
@@ -289,7 +292,7 @@ class FlowService:
     async def validate_subflow_illegal(cls, flow: FlowItem) -> None:
         """
         Validate whether a subflow is invalid (subflow-specific validation; an end node is not required)
-
+
         :param flow: the subflow
         :raises ValidationError: validation failed
         """
@@ -299,45 +302,45 @@ class FlowService:
     async def validate_subflow_connectivity(cls, flow: FlowItem) -> bool:
         """
         Validate subflow connectivity (subflow-specific; a connection to the end node is not required)
-
+
         :param flow: the subflow
         :return: whether the flow is connected
         """
         if not flow.nodes:
             return True
-
+
         # Build the graph structure
         graph = {}
         start_nodes = []
-
+
         for node in flow.nodes:
             graph[node.step_id] = []
             if node.call_id == 'start' or not any(
                 edge.target_node == node.step_id for edge in flow.edges
             ):
                 start_nodes.append(node.step_id)
-
+
         for edge in flow.edges:
             if edge.source_node in graph:
                 graph[edge.source_node].append(edge.target_node)
-
+
         # Check whether every other node is reachable from the start nodes
         if not start_nodes:
             return len(flow.nodes) <= 1  # With no start node, treat zero or one node as connected
-
+
         visited = set()
-
+
         def dfs(node_id):
             if node_id in visited:
                 return
             visited.add(node_id)
             for neighbor in graph.get(node_id, []):
                 dfs(neighbor)
-
+
         # Traverse from every start node
         for start_node in start_nodes:
             dfs(start_node)
-
+
         # Check whether every node was visited
         all_node_ids = {node.step_id for node in flow.nodes}
         return len(visited) == len(all_node_ids)
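Note: the recursive dfs above is fine for typical flows, but a very long linear chain of nodes could in principle hit Python's default recursion limit (about 1000 frames). An equivalent iterative traversal, should that hardening ever be needed:

def reachable(graph: dict[str, list[str]], starts: list[str]) -> set[str]:
    """Stack-based DFS over an adjacency-list graph; no recursion."""
    visited: set[str] = set()
    stack = list(starts)
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        stack.extend(graph.get(node, []))
    return visited

assert reachable({"a": ["b"], "b": ["c"], "c": []}, ["a"]) == {"a", "b", "c"}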
@@ -346,7 +349,7 @@ class FlowService:
     async def _validate_flow_nodes(cls, flow: FlowItem, is_subflow: bool = False) -> None:
         """
         Validate workflow nodes (supports subflow mode)
-
+
         :param flow: the workflow
         :param is_subflow: whether this is a subflow
         :raises ValidationError: validation failed
         """
@@ -356,13 +359,13 @@ class FlowService:

         start_count = 0
         end_count = 0
-
+
         for node in flow.nodes:
             if node.call_id == "start":
                 start_count += 1
             elif node.call_id == "end":
                 end_count += 1
-
+
         # The main workflow must have start and end nodes
         if not is_subflow:
             if start_count != 1:
@@ -383,9 +386,11 @@ class FlowService:

         # Verify that the nodes referenced by edges exist
         for edge in flow.edges:
-            source_exists = any(node.step_id == edge.source_node for node in flow.nodes)
-            target_exists = any(node.step_id == edge.target_node for node in flow.nodes)
-
+            source_exists = any(
+                node.step_id == edge.source_node for node in flow.nodes)
+            target_exists = any(
+                node.step_id == edge.target_node for node in flow.nodes)
+
             if not source_exists:
                 raise ValidationError(f"Source node referenced by an edge does not exist: {edge.source_node}")
             if not target_exists:
diff --git a/apps/services/task.py b/apps/services/task.py
index bb10eafd20c056c76d4c7ef9ca6d226c018b11a1..0f9438ba0babdba48c5dc0fa5c35accc1b610d57 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -86,13 +86,40 @@ class TaskManager:
             for flow_context_id in records[0]["records"]["flow"]["history_ids"]:
                 flow_context = await flow_context_collection.find_one({"_id": flow_context_id})
                 if flow_context:
-                    flow_context_list.append(FlowStepHistory.model_validate(flow_context))
+                    flow_context_list.append(
+                        FlowStepHistory.model_validate(flow_context))
         except Exception:
             logger.exception("[TaskManager] Failed to get the flow info for record_id")
             return []
         else:
             return flow_context_list

+    @staticmethod
+    async def get_context_by_record_ids(record_group_ids: list[str], record_ids: list[str]) -> list[FlowStepHistory]:
+        """Get flow info by record_group_ids"""
+        record_group_collection = MongoDB().get_collection("record_group")
+        flow_context_collection = MongoDB().get_collection("flow_context")
+        flow_context_list = []
+        # Query all matching records
+        try:
+            cursor = record_group_collection.aggregate([
+                {"$match": {"_id": {"$in": record_group_ids}}},
+                {"$unwind": "$records"},
+                {"$match": {"records.id": {"$in": record_ids}}},
+            ])
+            records = await cursor.to_list(length=None)
+
+            for record in records:
+                for flow_context_id in record["records"]["flow"]["history_ids"]:
+                    flow_context = await flow_context_collection.find_one({"_id": flow_context_id})
+                    if flow_context:
+                        flow_context_list.append(
+                            FlowStepHistory.model_validate(flow_context))
+            return flow_context_list
+        except Exception:
+            logger.exception("[TaskManager] Failed to get the flow info for record_ids")
+            return []
+
     @staticmethod
     async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]:
         """Get flow info by task_id"""
@@ -102,10 +129,12 @@ class TaskManager:
         try:
             if length is None:
                 async for context in flow_context_collection.find({"task_id": task_id}):
-                    flow_context.append(FlowStepHistory.model_validate(context))
+                    flow_context.append(
+                        FlowStepHistory.model_validate(context))
             else:
                 async for context in flow_context_collection.find({"task_id": task_id}).limit(length):
-                    flow_context.append(FlowStepHistory.model_validate(context))
+                    flow_context.append(
+                        FlowStepHistory.model_validate(context))
         except Exception:
             logger.exception("[TaskManager] Failed to get the flow info for task_id")
             return []
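Note: the aggregation in get_context_by_record_ids filters the groups, flattens the embedded records array, then filters again on the flattened documents — $unwind is what lets the second $match see individual records. The pipeline shape in isolation:

def build_record_pipeline(record_group_ids: list[str], record_ids: list[str]) -> list[dict]:
    return [
        {"$match": {"_id": {"$in": record_group_ids}}},   # narrow to the requested groups
        {"$unwind": "$records"},                          # one output doc per embedded record
        {"$match": {"records.id": {"$in": record_ids}}},  # keep only the requested records
    ]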
@@ -138,12 +167,13 @@ class TaskManager:
         """Save the flow info to flow_context"""
         flow_context_collection = MongoDB().get_collection("flow_context")
         try:
-            # Delete the old flow_context 
+            # Delete the old flow_context
             await flow_context_collection.delete_many({"task_id": task_id})
             if not flow_context:
                 return
             await flow_context_collection.insert_many(
-                [history.model_dump(exclude_none=True, by_alias=True) for history in flow_context],
+                [history.model_dump(exclude_none=True, by_alias=True)
+                 for history in flow_context],
                 ordered=False,
             )
         except Exception:
diff --git a/apps/services/vector.py b/apps/services/vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad67d3a511e69d7d3aa2035f96a1eb2c95ca178
--- /dev/null
+++ b/apps/services/vector.py
@@ -0,0 +1,142 @@
+import logging
+from sqlalchemy import delete, text
+from apps.common.postgres import DataBase, FlowPoolVector, ServicePoolVector, CallPoolVector, NodePoolVector, McpVector, McpToolVector
+from apps.schemas.enum_var import VectorPoolType
+
+logger = logging.getLogger(__name__)
+
+
+class VectorManager:
+    """Vector manager"""
+
+    @staticmethod
+    async def add_vector(
+        data: FlowPoolVector | ServicePoolVector | CallPoolVector | NodePoolVector
+    ) -> None:
+        """Add a single vector record"""
+        try:
+            async with await DataBase.get_session() as session:
+                session.add(data)
+                await session.commit()
+        except Exception as e:
+            # Log the failure; further error handling could be added here
+            logger.error(f"[VectorManager] Failed to add the vector record: {e}")
+
+    @staticmethod
+    async def add_vectors(
+        data_list: list[FlowPoolVector | ServicePoolVector |
+                        CallPoolVector | NodePoolVector]
+    ) -> None:
+        """Add vector records in bulk"""
+        try:
+            async with await DataBase.get_session() as session:
+                session.add_all(data_list)
+                await session.commit()
+        except Exception as e:
+            # Log the failure; further error handling could be added here
+            logger.error(f"[VectorManager] Failed to bulk-add vector records: {e}")
+
+    @staticmethod
+    async def delete_vectors(
+        vector_type: VectorPoolType,
+        ids: list[str],
+    ) -> None:
+        """Delete vector records"""
+        table_map = {
+            VectorPoolType.FLOW: FlowPoolVector,
+            VectorPoolType.SERVICE: ServicePoolVector,
+            VectorPoolType.CALL: CallPoolVector,
+            VectorPoolType.NODE: NodePoolVector,
+            VectorPoolType.MCP: McpVector,
+            VectorPoolType.MCP_TOOL: McpToolVector,
+        }
+        table = table_map.get(vector_type)
+        if not table:
+            err = f"[VectorManager] Unsupported vector type: {vector_type}"
+            logger.error(err)
+            raise ValueError(err)
+
+        stmt = (
+            delete(table)
+            .where(table.id.in_(ids))
+        )
+
+        try:
+            async with await DataBase.get_session() as session:
+                await session.execute(stmt)
+                await session.commit()
+        except Exception as e:
+            logger.error(f"[VectorManager] Failed to delete vector records: {e}")
+
+    @staticmethod
+    async def delete_mcp_tool_vectors_by_mcp_ids(mcp_ids: list[str]) -> None:
+        """Delete MCP tool vectors by MCP IDs"""
+        stmt = (
+            delete(McpToolVector)
+            .where(McpToolVector.mcp_id.in_(mcp_ids))
+        )
+
+        try:
+            async with await DataBase.get_session() as session:
+                await session.execute(stmt)
+                await session.commit()
+        except Exception as e:
+            logger.error(f"[VectorManager] Failed to delete MCP tool vectors by MCP IDs: {e}")
+
+    @staticmethod
+    async def delete_call_vectors_by_service_ids(service_ids: list[str]) -> None:
+        """Delete Call vectors by Service IDs"""
+        stmt = (
+            delete(CallPoolVector)
+            .where(CallPoolVector.service_id.in_(service_ids))
+        )
+
+        try:
+            async with await DataBase.get_session() as session:
+                await session.execute(stmt)
+                await session.commit()
+        except Exception as e:
+            logger.error(f"[VectorManager] Failed to delete Call vectors by Service IDs: {e}")
+
+    @staticmethod
+    async def select_topk_mcp_tool_by_mcp_ids(
+        vector: list[float],
+        mcp_ids: list[str],
+        top_k: int = 10
+    ) -> list[McpToolVector]:
+        """Select the top-K MCP tool vectors for the given MCP IDs"""
+        base_sql = """
+            SELECT
+                id, mcp_id, embedding
+            FROM mcp_tool_vector
+            WHERE mcp_id = ANY(:mcp_ids)
+            ORDER BY embedding <=> CAST(:vector AS vector) ASC
+            LIMIT :top_k
+        """
+        try:
+            async with await DataBase.get_session() as session:
+                result = await session.execute(
+                    text(base_sql),
+                    {
+                        # The <=> operator needs a vector operand, so pass the
+                        # query vector in its text form, e.g. "[0.1, 0.2]"
+                        "vector": str(vector),
+                        "mcp_ids": mcp_ids,
+                        "top_k": top_k,
+                    },
+                )
+                rows = result.fetchall()
+
+                mcp_tool_vectors = []
+                for row in rows:
+                    mcp_tool_vector = McpToolVector(
+                        id=row.id,
+                        mcp_id=row.mcp_id,
+                        embedding=row.embedding,
+                    )
+                    mcp_tool_vectors.append(mcp_tool_vector)
+
+                return mcp_tool_vectors
+
+        except Exception as e:
+            err = f"Failed to select the top-K MCP tool vectors by MCP IDs: {str(e)}"
+            logger.exception("[VectorManager] %s", err)
+            return []
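Note: the raw SQL above relies on the <=> cosine-distance operator, which is why the bound query vector has to be cast to the vector type. pgvector's SQLAlchemy integration exposes the same operator as a comparator method, which sidesteps the manual cast; an alternative sketch, assuming the same models and session factory vector.py imports:

from sqlalchemy import select
from apps.common.postgres import DataBase, McpToolVector

async def topk_mcp_tools_orm(vector: list[float], mcp_ids: list[str], top_k: int = 10) -> list[McpToolVector]:
    stmt = (
        select(McpToolVector)
        .where(McpToolVector.mcp_id.in_(mcp_ids))
        # cosine_distance compiles to "embedding <=> :vector"
        .order_by(McpToolVector.embedding.cosine_distance(vector))
        .limit(top_k)
    )
    async with await DataBase.get_session() as session:
        result = await session.execute(stmt)
        return list(result.scalars().all())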
diff --git a/pyproject.toml b/pyproject.toml
index bb829574e8ae83773239a8b5c554c8f93baa9cb5..0e460e9c6bb08788619a87807747ac5f232488d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "tiktoken==0.9.0",
     "toml==0.10.2",
     "uvicorn==0.34.0",
+    "opengauss-sqlalchemy==2.4.0",
 ]

[[tool.uv.index]]
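Note: VectorManager awaits DataBase.get_session() and then enters the result as an async context manager. A minimal sketch of a session factory with that shape, assuming the async "opengauss+asyncpg" dialect name registered by opengauss-sqlalchemy — the DSN details are placeholders, not the project's real config:

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

class DataBaseSketch:
    _engine = create_async_engine(
        "opengauss+asyncpg://user:password@localhost:5432/vectordb",  # placeholder DSN
        pool_pre_ping=True,
    )
    _session_factory = async_sessionmaker(_engine, class_=AsyncSession, expire_on_commit=False)

    @classmethod
    async def get_session(cls) -> AsyncSession:
        # The returned session is used as: async with await DataBase.get_session() as session
        return cls._session_factory()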