diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_server/cli/handle.py b/mcp_center/servers/oe-cli-mcp-server/mcp_server/cli/handle.py index 2a52caa1d6a63e81db61a0da80176197b96c7fad..cc5d972ee1f0a7e59c7a0aad038fcb9761b3e1d0 100644 --- a/mcp_center/servers/oe-cli-mcp-server/mcp_server/cli/handle.py +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_server/cli/handle.py @@ -43,7 +43,7 @@ def send_http_request(action: str, params: dict = None): def handle_add(pkg_input): """处理 -add 命令""" type_map = {"智能运维": ToolType.BASE.value, "智算调优": ToolType.AI.value, - "通算调优": ToolType.CAL.value, "镜像运维": ToolType.MIRROR.value, "个性化": ToolType.PERSONAL.value} + "通算调优": ToolType.CAL.value, "镜像运维": ToolType.MIRROR.value, "个性化": ToolType.PERSONAL.value, "知识库": ToolType.RAG.value} if pkg_input in type_map: params = {"type": "system", "value": type_map[pkg_input]} @@ -61,7 +61,7 @@ def handle_add(pkg_input): def handle_remove(pkg_input): """处理 -remove 命令""" type_map = {"智能运维": ToolType.BASE.value, "智算调优": ToolType.AI.value, - "通算调优": ToolType.CAL.value, "镜像运维": ToolType.MIRROR.value, "个性化": ToolType.PERSONAL.value} + "通算调优": ToolType.CAL.value, "镜像运维": ToolType.MIRROR.value, "个性化": ToolType.PERSONAL.value, "知识库": ToolType.RAG.value} params = {"type": "system" if pkg_input in type_map else "custom", "value": type_map.get(pkg_input, pkg_input)} diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/config.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7c889abb3e9fa99927c33021bb2956138f536186 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/config.py @@ -0,0 +1,174 @@ +import os +import json +import logging +from typing import Optional, Dict, Any + +logger = logging.getLogger(__name__) + +# 配置缓存 +_config: Optional[Dict[str, Any]] = None +_config_file_path = "rag_config.json" + + +def _load_config() -> Dict[str, Any]: + """ + 加载配置文件 + :return: 配置字典 + """ + global _config + + if _config is not None: + return _config + + default_config = { + "embedding": { + "type": "openai", + "api_key": "", + "endpoint": "", + "model_name": "text-embedding-ada-002", + "timeout": 30, + "vector_dimension": 1024 + }, + "token": { + "model": "gpt-4", + "max_tokens": 8192, + "default_chunk_size": 1024 + }, + "search": { + "default_top_k": 5, + "max_top_k": 100 + } + } + + config_file = _get_config_file_path() + if os.path.exists(config_file): + try: + with open(config_file, 'r', encoding='utf-8') as f: + file_config = json.load(f) + # 合并配置(文件配置优先) + _config = _merge_config(default_config, file_config) + except Exception as e: + logger.warning(f"[Config] 加载配置文件失败: {e}") + _config = default_config + else: + _config = default_config + + _apply_env_overrides(_config) + + return _config + + +def _cfg() -> Dict[str, Any]: + return _load_config() + + +def _get_config_file_path() -> str: + """ + 获取配置文件路径 + 优先使用项目根目录下的 rag_config.json + :return: 配置文件路径 + """ + # 尝试从项目根目录查找 + current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + config_file = os.path.join(current_dir, _config_file_path) + if os.path.exists(config_file): + return config_file + + # 尝试从当前工作目录查找 + cwd_config = os.path.join(os.getcwd(), _config_file_path) + if os.path.exists(cwd_config): + return cwd_config + + # 返回项目根目录路径(即使文件不存在) + return config_file + + +def _merge_config(default: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + """ + 合并配置字典(递归合并) + :param default: 默认配置 + :param override: 覆盖配置 + 
:return: 合并后的配置 + """ + result = default.copy() + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _merge_config(result[key], value) + else: + result[key] = value + return result + + +def _apply_env_overrides(config: Dict[str, Any]): + """ + 应用环境变量覆盖(优先级最高) + :param config: 配置字典 + """ + if os.getenv("EMBEDDING_TYPE"): + config["embedding"]["type"] = os.getenv("EMBEDDING_TYPE") + if os.getenv("EMBEDDING_API_KEY"): + config["embedding"]["api_key"] = os.getenv("EMBEDDING_API_KEY") + if os.getenv("EMBEDDING_ENDPOINT"): + config["embedding"]["endpoint"] = os.getenv("EMBEDDING_ENDPOINT") + if os.getenv("EMBEDDING_MODEL_NAME"): + config["embedding"]["model_name"] = os.getenv("EMBEDDING_MODEL_NAME") + + if os.getenv("TOKEN_MODEL"): + config["token"]["model"] = os.getenv("TOKEN_MODEL") + if os.getenv("MAX_TOKENS"): + try: + config["token"]["max_tokens"] = int(os.getenv("MAX_TOKENS")) + except ValueError: + pass + if os.getenv("DEFAULT_CHUNK_SIZE"): + try: + config["token"]["default_chunk_size"] = int(os.getenv("DEFAULT_CHUNK_SIZE")) + except ValueError: + pass + + +def get_embedding_type() -> str: + return _cfg()["embedding"]["type"] + + +def get_embedding_api_key() -> str: + return _cfg()["embedding"]["api_key"] + + +def get_embedding_endpoint() -> str: + return _cfg()["embedding"]["endpoint"] + + +def get_embedding_model_name() -> str: + return _cfg()["embedding"]["model_name"] + + +def get_embedding_timeout() -> int: + return _cfg()["embedding"]["timeout"] + + +def get_embedding_vector_dimension() -> int: + return _cfg()["embedding"]["vector_dimension"] + + +def get_token_model() -> str: + return _cfg()["token"]["model"] + + +def get_max_tokens() -> int: + return _cfg()["token"]["max_tokens"] + + +def get_default_chunk_size() -> int: + return _cfg()["token"]["default_chunk_size"] + + +def get_default_top_k() -> int: + return _cfg()["search"]["default_top_k"] + + +def reload_config(): + """当 rag_config.json 更新后重新加载缓存""" + global _config + _config = None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/embedding.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0a23bd198a63682cb374c55d90730ccf76907d19 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/embedding.py @@ -0,0 +1,145 @@ +import json +import logging +import asyncio +import aiohttp +from typing import Optional, List +from base.config import ( + get_embedding_type, + get_embedding_api_key, + get_embedding_endpoint, + get_embedding_model_name, + get_embedding_timeout, + get_embedding_vector_dimension +) + +logger = logging.getLogger(__name__) + + +class Embedding: + """Embedding 服务类""" + + @staticmethod + def _get_config(): + """获取配置(延迟加载)""" + return { + "type": get_embedding_type(), + "api_key": get_embedding_api_key(), + "endpoint": get_embedding_endpoint(), + "model_name": get_embedding_model_name(), + "timeout": get_embedding_timeout(), + "vector_dimension": get_embedding_vector_dimension() + } + + @staticmethod + def is_configured() -> bool: + config = Embedding._get_config() + return bool(config["api_key"] and config["endpoint"]) + + @staticmethod + async def vectorize_embedding(text: str, session: Optional[aiohttp.ClientSession] = None) -> Optional[List[float]]: + """ + 将文本向量化(异步实现) + :param text: 文本内容 + :param session: 可选的 aiohttp 会话 + :return: 向量列表 + """ + config = Embedding._get_config() + 
vector = None + should_close_session = False + + # 如果没有提供会话,创建一个新的 + if session is None: + timeout = aiohttp.ClientTimeout(total=config["timeout"]) + connector = aiohttp.TCPConnector(ssl=False) + session = aiohttp.ClientSession(timeout=timeout, connector=connector) + should_close_session = True + + try: + if config["type"] == "openai": + headers = { + "Authorization": f"Bearer {config['api_key']}" + } + data = { + "input": text, + "model": config["model_name"], + "encoding_format": "float" + } + try: + async with session.post( + url=config["endpoint"], + headers=headers, + json=data + ) as res: + if res.status != 200: + return None + result = await res.json() + vector = result['data'][0]['embedding'] + except Exception: + return None + elif config["type"] == "mindie": + try: + data = { + "inputs": text, + } + async with session.post( + url=config["endpoint"], + json=data + ) as res: + if res.status != 200: + return None + text_result = await res.text() + vector = json.loads(text_result)[0] + except Exception: + return None + else: + return None + + # 确保向量长度为配置的维度(不足补0,超过截断) + if vector: + vector_dim = config["vector_dimension"] + while len(vector) < vector_dim: + vector.append(0.0) + return vector[:vector_dim] + return None + finally: + if should_close_session: + await session.close() + + @staticmethod + async def vectorize_embeddings_batch(texts: List[str], max_concurrent: int = 5) -> List[Optional[List[float]]]: + """ + 批量向量化(并发处理) + :param texts: 文本列表 + :param max_concurrent: 最大并发数 + :return: 向量列表(与输入文本顺序对应) + """ + config = Embedding._get_config() + if not config["api_key"] or not config["endpoint"]: + return [None] * len(texts) + + # 创建共享的 aiohttp 会话(复用连接) + timeout = aiohttp.ClientTimeout(total=config["timeout"]) + connector = aiohttp.TCPConnector(ssl=False) + async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: + # 使用信号量控制并发数 + semaphore = asyncio.Semaphore(max_concurrent) + + async def vectorize_with_semaphore(text: str, index: int) -> tuple: + async with semaphore: + vector = await Embedding.vectorize_embedding(text, session=session) + return index, vector + + tasks = [vectorize_with_semaphore(text, i) for i, text in enumerate(texts)] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + vectors = [None] * len(texts) + for result in results: + if isinstance(result, Exception): + continue + index, vector = result + vectors[index] = vector + + return vectors + + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e77808c4df807b3e0b0c6e9520bc1a4069129439 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py @@ -0,0 +1,257 @@ +""" +数据库操作类 - 使用 SQLAlchemy ORM +""" +import os +import struct +import uuid +from typing import List, Optional, Dict, Any +from datetime import datetime +import logging +from sqlalchemy import create_engine, text, inspect +from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.exc import SQLAlchemyError + +from base.models import Base, KnowledgeBase, Document, Chunk +from base.config import get_embedding_vector_dimension +from base.manager.document_manager import DocumentManager +import sqlite_vec + +logger = logging.getLogger(__name__) + + +class Database: + """SQLite 数据库操作类 - 使用 SQLAlchemy ORM""" + + def __init__(self, db_path: str = 
"knowledge_base.db"): + """ + 初始化数据库连接 + :param db_path: 数据库文件路径 + """ + db_dir = os.path.dirname(os.path.abspath(db_path)) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + + self.db_path = os.path.abspath(db_path) + self.engine = create_engine( + f'sqlite:///{self.db_path}', + echo=False, + connect_args={'check_same_thread': False} + ) + self.SessionLocal = sessionmaker(bind=self.engine, autocommit=False, autoflush=False) + self._init_database() + + def _init_database(self): + """初始化数据库表结构""" + try: + # 创建所有表 + Base.metadata.create_all(self.engine) + + # 加载 sqlite-vec 扩展并创建 FTS5 和 vec_index 表 + with self.engine.begin() as conn: + # 创建 FTS5 虚拟表(需要使用原生 SQL) + conn.execute(text(""" + CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + id UNINDEXED, + content, + content_rowid=id + ) + """)) + + # 加载 sqlite-vec 扩展 + try: + raw_conn = conn.connection.dbapi_connection + raw_conn.enable_load_extension(True) + sqlite_vec.load(raw_conn) + raw_conn.enable_load_extension(False) + except Exception as e: + logger.warning(f"加载 sqlite-vec 扩展失败: {e}") + + # 创建 vec_index 虚拟表 + try: + vector_dim = get_embedding_vector_dimension() + conn.execute(text(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_index USING vec0( + embedding float[{vector_dim}] + ) + """)) + except Exception as e: + logger.warning(f"创建 vec_index 表失败: {e}") + except Exception as e: + logger.exception(f"[Database] 初始化数据库失败: {e}") + raise e + + def get_session(self) -> Session: + """获取数据库会话""" + return self.SessionLocal() + + def get_connection(self): + """ + 获取原始数据库连接(用于特殊操作,如 FTS5 和 vec_index) + 注意:此方法保留以兼容现有代码,但推荐使用 get_session() + 返回一个上下文管理器,使用后会自动关闭 + """ + return self.engine.connect() + + def add_knowledge_base(self, kb_id: str, name: str, chunk_size: int, + embedding_model: Optional[str] = None, + embedding_endpoint: Optional[str] = None, + embedding_api_key: Optional[str] = None) -> bool: + """添加知识库""" + session = self.get_session() + try: + kb = KnowledgeBase( + id=kb_id, + name=name, + chunk_size=chunk_size, + embedding_model=embedding_model, + embedding_endpoint=embedding_endpoint, + embedding_api_key=embedding_api_key + ) + session.add(kb) + session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[Database] 添加知识库失败: {e}") + session.rollback() + return False + finally: + session.close() + + def get_knowledge_base(self, kb_name: str) -> Optional[KnowledgeBase]: + """获取知识库""" + session = self.get_session() + try: + return session.query(KnowledgeBase).filter_by(name=kb_name).first() + finally: + session.close() + + def delete_knowledge_base(self, kb_id: str) -> bool: + """删除知识库(级联删除相关文档和chunks)""" + session = self.get_session() + try: + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if kb: + session.delete(kb) + session.commit() + return True + return False + except SQLAlchemyError as e: + logger.exception(f"[Database] 删除知识库失败: {e}") + session.rollback() + return False + finally: + session.close() + + def list_knowledge_bases(self) -> List[KnowledgeBase]: + """列出所有知识库""" + session = self.get_session() + try: + return session.query(KnowledgeBase).order_by(KnowledgeBase.created_at.desc()).all() + finally: + session.close() + + def import_database(self, source_db_path: str) -> tuple[int, int]: + """ + 导入数据库,将其中的内容合并到当前数据库 + + :param source_db_path: 源数据库文件路径 + :return: (imported_kb_count, imported_doc_count) + """ + source_db = Database(source_db_path) + source_session = source_db.get_session() + + try: + # 读取源数据库的知识库 + source_kbs = 
source_session.query(KnowledgeBase).all() + if not source_kbs: + return 0, 0 + + # 读取源数据库的文档 + source_docs = source_session.query(Document).all() + + # 合并到当前数据库 + target_session = self.get_session() + + try: + imported_kb_count = 0 + imported_doc_count = 0 + + for source_kb in source_kbs: + # 检查知识库是否已存在,如果存在则生成唯一名称 + kb_name = source_kb.name + existing_kb = self.get_knowledge_base(kb_name) + if existing_kb: + # 生成唯一名称 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + counter = 1 + unique_kb_name = f"{kb_name}_{timestamp}" + while self.get_knowledge_base(unique_kb_name): + unique_kb_name = f"{kb_name}_{timestamp}_{counter}" + counter += 1 + kb_name = unique_kb_name + + # 导入知识库 + new_kb_id = str(uuid.uuid4()) + if self.add_knowledge_base(new_kb_id, kb_name, source_kb.chunk_size, + source_kb.embedding_model, source_kb.embedding_endpoint, + source_kb.embedding_api_key): + imported_kb_count += 1 + + # 导入该知识库下的文档 + kb_docs = [doc for doc in source_docs if doc.kb_id == source_kb.id] + manager = DocumentManager(target_session) + + for source_doc in kb_docs: + # 检查文档是否已存在,如果存在则生成唯一名称 + doc_name = source_doc.name + existing_doc = manager.get_document(new_kb_id, doc_name) + if existing_doc: + # 生成唯一名称 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 分离文件名和扩展名 + if '.' in doc_name: + name_part, ext_part = doc_name.rsplit('.', 1) + unique_doc_name = f"{name_part}_{timestamp}.{ext_part}" + else: + unique_doc_name = f"{doc_name}_{timestamp}" + + # 如果新名称仍然存在,继续添加后缀 + counter = 1 + final_doc_name = unique_doc_name + while manager.get_document(new_kb_id, final_doc_name): + if '.' in doc_name: + name_part, ext_part = doc_name.rsplit('.', 1) + final_doc_name = f"{name_part}_{timestamp}_{counter}.{ext_part}" + else: + final_doc_name = f"{doc_name}_{timestamp}_{counter}" + counter += 1 + doc_name = final_doc_name + + # 导入文档 + new_doc_id = str(uuid.uuid4()) + if manager.add_document(new_doc_id, new_kb_id, doc_name, + source_doc.file_path, source_doc.file_type, + source_doc.content, source_doc.chunk_size): + imported_doc_count += 1 + + # 导入chunks(包含向量) + source_chunks = source_session.query(Chunk).filter_by(doc_id=source_doc.id).all() + for source_chunk in source_chunks: + new_chunk_id = str(uuid.uuid4()) + # 提取向量(如果存在) + embedding = None + if source_chunk.embedding: + embedding_bytes = source_chunk.embedding + if len(embedding_bytes) > 0 and len(embedding_bytes) % 4 == 0: + embedding = list(struct.unpack(f'{len(embedding_bytes)//4}f', embedding_bytes)) + + manager.add_chunk(new_chunk_id, new_doc_id, source_chunk.content, + source_chunk.tokens, source_chunk.chunk_index, embedding) + return imported_kb_count, imported_doc_count + finally: + target_session.close() + finally: + source_session.close() + source_db = None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5186ee241584f7dc8b8d5cc6a6fc2d4c40f02238 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py @@ -0,0 +1,394 @@ +""" +文档操作模块 - 使用 SQLAlchemy ORM +""" +import os +import struct +import uuid +import asyncio +from typing import List, Optional, Tuple +from datetime import datetime +import logging +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from base.models import Document, Chunk +from base.embedding 
import Embedding +from base.parser.parser import Parser +from base.token_tool import TokenTool +import jieba + +logger = logging.getLogger(__name__) + + +class DocumentManager: + """文档操作管理器""" + + def __init__(self, session: Session): + """ + 初始化文档管理器 + :param session: 数据库会话 + """ + self.session = session + + def add_document(self, doc_id: str, kb_id: str, name: str, file_path: str, + file_type: str, content: Optional[str] = None, chunk_size: Optional[int] = None) -> bool: + """添加文档""" + try: + document = Document( + id=doc_id, + kb_id=kb_id, + name=name, + file_path=file_path, + file_type=file_type, + content=content, + chunk_size=chunk_size, + updated_at=datetime.now() + ) + self.session.add(document) + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 添加文档失败: {e}") + self.session.rollback() + return False + + def delete_document(self, kb_id: str, doc_name: str) -> bool: + """删除文档(级联删除相关chunks)""" + try: + doc = self.session.query(Document).filter_by(kb_id=kb_id, name=doc_name).first() + if doc: + self.session.delete(doc) + self.session.commit() + return True + return False + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 删除文档失败: {e}") + self.session.rollback() + return False + + def get_document(self, kb_id: str, doc_name: str) -> Optional[Document]: + """获取文档""" + return self.session.query(Document).filter_by(kb_id=kb_id, name=doc_name).first() + + def list_documents_by_kb(self, kb_id: str) -> List[Document]: + """列出知识库下的所有文档""" + return self.session.query(Document).filter_by(kb_id=kb_id).order_by(Document.created_at.desc()).all() + + def add_chunk(self, chunk_id: str, doc_id: str, content: str, tokens: int, chunk_index: int, + embedding: Optional[List[float]] = None) -> bool: + """添加 chunk(可包含向量)""" + try: + embedding_bytes = None + if embedding: + embedding_bytes = struct.pack(f'{len(embedding)}f', *embedding) + + chunk = Chunk( + id=chunk_id, + doc_id=doc_id, + content=content, + tokens=tokens, + chunk_index=chunk_index, + embedding=embedding_bytes + ) + self.session.add(chunk) + self.session.flush() + + # 添加 FTS5 索引(需要使用原生 SQL) + fts_content = self._prepare_fts_content(content) + self.session.execute(text(""" + INSERT INTO chunks_fts (id, content) + VALUES (:chunk_id, :content) + """), {"chunk_id": chunk_id, "content": fts_content}) + + # 检查并更新 vec_index(需要使用原生 SQL) + if embedding_bytes: + conn = self.session.connection() + result = conn.execute(text(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='vec_index' + """)) + if result.fetchone(): + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk_id}) + row = result.fetchone() + if row: + vec_rowid = row[0] + # 先删除可能存在的旧记录,避免 UNIQUE constraint 冲突 + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": vec_rowid}) + # 然后插入新记录 + conn.execute(text(""" + INSERT INTO vec_index(rowid, embedding) + VALUES (:rowid, :embedding) + """), {"rowid": vec_rowid, "embedding": embedding_bytes}) + + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 添加chunk失败: {e}") + self.session.rollback() + return False + + def _prepare_fts_content(self, content: str) -> str: + """ + 准备 FTS5 内容(对中文进行 jieba 分词) + :param content: 原始内容 + :return: 分词后的内容(用空格连接) + """ + try: + words = jieba.cut(content) + words = [word.strip() for word in words if word.strip()] + return ' '.join(words) + except Exception: + return content + + def 
update_chunk_embedding(self, chunk_id: str, embedding: List[float]) -> bool: + """更新 chunk 的向量""" + try: + embedding_bytes = struct.pack(f'{len(embedding)}f', *embedding) + + chunk = self.session.query(Chunk).filter_by(id=chunk_id).first() + if not chunk: + return False + + chunk.embedding = embedding_bytes + self.session.flush() + + # 检查并更新 vec_index(需要使用原生 SQL) + conn = self.session.connection() + result = conn.execute(text(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='vec_index' + """)) + if result.fetchone(): + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk_id}) + row = result.fetchone() + if row: + vec_rowid = row[0] + # 先删除可能存在的旧记录,避免 UNIQUE constraint 冲突 + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": vec_rowid}) + # 然后插入新记录 + conn.execute(text(""" + INSERT INTO vec_index(rowid, embedding) + VALUES (:rowid, :embedding) + """), {"rowid": vec_rowid, "embedding": embedding_bytes}) + + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 更新chunk向量失败: {e}") + self.session.rollback() + return False + + def delete_document_chunks(self, doc_id: str) -> None: + """删除文档的所有chunks""" + chunks = self.session.query(Chunk).filter_by(doc_id=doc_id).all() + conn = self.session.connection() + for chunk in chunks: + # 删除FTS5索引 + conn.execute(text(""" + DELETE FROM chunks_fts WHERE id = :chunk_id + """), {"chunk_id": chunk.id}) + # 删除向量索引(如果chunk有向量) + if chunk.embedding: + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk.id}) + row = result.fetchone() + if row: + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": row[0]}) + # 删除chunk + self.session.delete(chunk) + self.session.commit() + + def update_document_content(self, doc_id: str, content: str, chunk_size: int) -> None: + """更新文档的content和chunk_size""" + doc = self.session.query(Document).filter_by(id=doc_id).first() + if doc: + doc.chunk_size = chunk_size + doc.content = content + doc.updated_at = datetime.now() + self.session.commit() + + +def _generate_unique_name(base_name: str, check_exists_func) -> str: + """ + 生成唯一名称,如果已存在则添加时间戳 + + :param base_name: 基础名称 + :param check_exists_func: 检查是否存在的函数,接受名称参数,返回是否存在 + :return: 唯一名称 + """ + if not check_exists_func(base_name): + return base_name + + # 如果已存在,添加时间戳 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 分离文件名和扩展名 + if '.' in base_name: + name_part, ext_part = base_name.rsplit('.', 1) + new_name = f"{name_part}_{timestamp}.{ext_part}" + else: + new_name = f"{base_name}_{timestamp}" + + # 如果新名称仍然存在,继续添加后缀 + counter = 1 + final_name = new_name + while check_exists_func(final_name): + if '.' 
in base_name: + name_part, ext_part = base_name.rsplit('.', 1) + final_name = f"{name_part}_{timestamp}_{counter}.{ext_part}" + else: + final_name = f"{base_name}_{timestamp}_{counter}" + counter += 1 + + return final_name + + +async def import_document(session: Session, kb_id: str, file_path: str, + chunk_size: int) -> Tuple[bool, str, Optional[dict]]: + """ + 导入文档(异步) + + :param session: 数据库会话 + :param kb_id: 知识库ID + :param file_path: 文件路径 + :param chunk_size: chunk大小 + :return: (success, message, data) + """ + try: + doc_name = os.path.basename(file_path) + content = Parser.parse(file_path) + if not content: + return False, "文档解析失败", None + + chunks = TokenTool.split_content_to_chunks(content, chunk_size) + if not chunks: + return False, "文档内容为空", None + + manager = DocumentManager(session) + + # 检查文档是否已存在,如果存在则生成唯一名称 + def check_doc_exists(name: str) -> bool: + return manager.get_document(kb_id, name) is not None + + unique_doc_name = _generate_unique_name(doc_name, check_doc_exists) + + doc_id = str(uuid.uuid4()) + file_type = file_path.lower().split('.')[-1] + + if not manager.add_document(doc_id, kb_id, unique_doc_name, file_path, file_type, content, chunk_size): + return False, "添加文档失败", None + + chunk_ids = [] + chunk_data = [] + + # 先收集所有chunk数据 + for idx, chunk_content in enumerate(chunks): + chunk_id = str(uuid.uuid4()) + tokens = TokenTool.get_tokens(chunk_content) + chunk_data.append((chunk_id, chunk_content, tokens, idx)) + + # 批量生成向量(异步) + embeddings_list = [None] * len(chunk_data) + if Embedding.is_configured() and chunk_data: + try: + chunk_contents = [content for _, content, _, _ in chunk_data] + embeddings_list = await Embedding.vectorize_embeddings_batch(chunk_contents, max_concurrent=5) + except Exception as e: + logger.warning(f"批量生成向量失败: {e}") + + # 添加chunks(包含向量) + for (chunk_id, chunk_content, tokens, idx), embedding in zip(chunk_data, embeddings_list): + if manager.add_chunk(chunk_id, doc_id, chunk_content, tokens, idx, embedding): + chunk_ids.append(chunk_id) + + return True, f"成功导入文档,共 {len(chunk_ids)} 个 chunks", { + "doc_id": doc_id, + "doc_name": unique_doc_name, + "original_name": doc_name if unique_doc_name != doc_name else None, + "chunk_count": len(chunk_ids), + "file_path": file_path + } + except Exception as e: + logger.exception(f"[import_document] 导入文档失败: {e}") + return False, "导入文档失败", None + + +async def update_document(session: Session, kb_id: str, doc_name: str, chunk_size: int) -> Tuple[bool, str, Optional[dict]]: + """ + 更新文档的chunk_size并重新解析(异步) + + :param session: 数据库会话 + :param kb_id: 知识库ID + :param doc_name: 文档名称 + :param chunk_size: 新的chunk大小 + :return: (success, message, data) + """ + try: + manager = DocumentManager(session) + doc = manager.get_document(kb_id, doc_name) + if not doc: + return False, f"文档 '{doc_name}' 不存在", None + + # 删除旧文档的所有chunks + manager.delete_document_chunks(doc.id) + + # 重新解析文档 + if not doc.file_path or not os.path.exists(doc.file_path): + return False, "文档文件不存在", None + + content = Parser.parse(doc.file_path) + if not content: + return False, "文档解析失败", None + + chunks = TokenTool.split_content_to_chunks(content, chunk_size) + if not chunks: + return False, "文档内容为空", None + + # 收集所有chunk数据 + chunk_ids = [] + chunk_data = [] + + for idx, chunk_content in enumerate(chunks): + chunk_id = str(uuid.uuid4()) + tokens = TokenTool.get_tokens(chunk_content) + chunk_data.append((chunk_id, chunk_content, tokens, idx)) + + # 批量生成向量(异步) + embeddings_list = [None] * len(chunk_data) + if Embedding.is_configured() and chunk_data: + 
try: + chunk_contents = [content for _, content, _, _ in chunk_data] + embeddings_list = await Embedding.vectorize_embeddings_batch(chunk_contents, max_concurrent=5) + except Exception as e: + logger.warning(f"批量生成向量失败: {e}") + + # 添加chunks(包含向量) + for (chunk_id, chunk_content, tokens, idx), embedding in zip(chunk_data, embeddings_list): + if manager.add_chunk(chunk_id, doc.id, chunk_content, tokens, idx, embedding): + chunk_ids.append(chunk_id) + + # 更新文档的chunk_size和content + manager.update_document_content(doc.id, content, chunk_size) + + return True, f"成功修改文档,共 {len(chunk_ids)} 个 chunks", { + "doc_id": doc.id, + "doc_name": doc_name, + "chunk_count": len(chunk_ids), + "chunk_size": chunk_size + } + except Exception as e: + logger.exception(f"[update_document] 修改文档失败: {e}") + return False, "修改文档失败", None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/models.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/models.py new file mode 100644 index 0000000000000000000000000000000000000000..4b197d4418476acfdb69e996a0ae05f08a31b32d --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/models.py @@ -0,0 +1,79 @@ +from sqlalchemy import ( + Column, String, Integer, Text, DateTime, ForeignKey, + LargeBinary, Index, func +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship +from datetime import datetime + +Base = declarative_base() + + +class KnowledgeBase(Base): + """知识库表""" + __tablename__ = 'knowledge_bases' + + id = Column(String, primary_key=True) + name = Column(String, nullable=False, unique=True) + chunk_size = Column(Integer, nullable=False) + embedding_model = Column(Text) + embedding_endpoint = Column(Text) + embedding_api_key = Column(Text) + created_at = Column(DateTime, default=datetime.now, server_default=func.current_timestamp()) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, server_default=func.current_timestamp()) + + # 关系 + documents = relationship("Document", back_populates="knowledge_base", cascade="all, delete-orphan") + + # 索引 + __table_args__ = ( + Index('idx_kb_name', 'name'), + ) + + +class Document(Base): + """文档表""" + __tablename__ = 'documents' + + id = Column(String, primary_key=True) + kb_id = Column(String, ForeignKey('knowledge_bases.id', ondelete='CASCADE'), nullable=False) + name = Column(String, nullable=False) + file_path = Column(Text) + file_type = Column(String) + content = Column(Text) # 文档完整内容 + chunk_size = Column(Integer) + created_at = Column(DateTime, default=datetime.now, server_default=func.current_timestamp()) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, server_default=func.current_timestamp()) + + # 关系 + knowledge_base = relationship("KnowledgeBase", back_populates="documents") + chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan") + + # 索引 + __table_args__ = ( + Index('idx_doc_kb_id', 'kb_id'), + Index('idx_doc_name', 'name'), + ) + + +class Chunk(Base): + """文档分块表""" + __tablename__ = 'chunks' + + id = Column(String, primary_key=True) + doc_id = Column(String, ForeignKey('documents.id', ondelete='CASCADE'), nullable=False) + content = Column(Text, nullable=False) + tokens = Column(Integer) + chunk_index = Column(Integer) + embedding = Column(LargeBinary) # 向量嵌入 + created_at = Column(DateTime, default=datetime.now, server_default=func.current_timestamp()) + + # 关系 + document = relationship("Document", back_populates="chunks") + + # 
索引 + __table_args__ = ( + Index('idx_chunk_doc_id', 'doc_id'), + Index('idx_chunk_index', 'chunk_index'), + ) + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/doc.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/doc.py new file mode 100644 index 0000000000000000000000000000000000000000..7f5cf90f9e52fc0f548bf11b5883eb406741a83e --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/doc.py @@ -0,0 +1,61 @@ +import logging +from typing import Optional +from docx import Document as DocxDocument + +logger = logging.getLogger(__name__) + + +def parse_docx(file_path: str) -> Optional[str]: + """ + 解析 DOCX 文件 + :param file_path: 文件路径 + :return: 文件内容 + """ + try: + doc = DocxDocument(file_path) + if not doc: + logger.error("[DocParser] 无法打开docx文件") + return None + + paragraphs = [] + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + paragraphs.append(paragraph.text) + + for table in doc.tables: + for row in table.rows: + for cell in row.cells: + if cell.text.strip(): + paragraphs.append(cell.text) + + content = '\n'.join(paragraphs) + return content + except Exception as e: + logger.exception(f"[DocParser] 解析DOCX文件失败: {e}") + return None + + +def parse_doc(file_path: str) -> Optional[str]: + """ + 解析 DOC 文件(旧版 Word 格式) + :param file_path: 文件路径 + :return: 文件内容 + """ + try: + doc = DocxDocument(file_path) + paragraphs = [] + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + paragraphs.append(paragraph.text) + for table in doc.tables: + for row in table.rows: + for cell in row.cells: + if cell.text.strip(): + paragraphs.append(cell.text) + content = '\n'.join(paragraphs) + return content + except Exception: + logger.warning("[DocParser] python-docx 不支持 DOC 格式,尝试其他方法") + logger.warning("[DocParser] DOC 格式解析需要额外工具,当前仅支持 DOCX") + return None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/parser.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5fba0955e081cdd68546ef8d7b4c7e5425d4afc7 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/parser.py @@ -0,0 +1,59 @@ +""" +文档解析器模块 +""" +import logging +from typing import Optional, Dict + +logger = logging.getLogger(__name__) + +from base.parser.txt import parse_txt +from base.parser.doc import parse_docx, parse_doc +from base.parser.pdf import parse_pdf + +_parsers: Dict[str, callable] = {} + + +def register_parser(file_ext: str, parser_func: callable): + """ + 注册解析器 + :param file_ext: 文件扩展名(如 'txt', 'docx') + :param parser_func: 解析函数,接收 file_path 参数,返回 Optional[str] + """ + _parsers[file_ext.lower()] = parser_func + logger.debug(f"[Parser] 注册解析器: {file_ext}") + + +def parse(file_path: str) -> Optional[str]: + """ + 根据文件类型自动选择解析器 + :param file_path: 文件路径 + :return: 文件内容 + """ + file_ext = file_path.lower().split('.')[-1] + + if file_ext not in _parsers: + logger.error(f"[Parser] 不支持的文件类型: {file_ext}") + return None + + try: + parser_func = _parsers[file_ext] + return parser_func(file_path) + except Exception as e: + logger.exception(f"[Parser] 解析文件失败: {file_path}, {e}") + return None + + +# 注册解析器 +register_parser('txt', parse_txt) +register_parser('docx', parse_docx) +register_parser('doc', parse_doc) +register_parser('pdf', parse_pdf) + + +class Parser: + """文档解析器类""" + + @staticmethod + def parse(file_path: str) -> Optional[str]: + return parse(file_path) + diff 
--git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/pdf.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..a3549c4e9fad101992d4c69a8c064933c1e8bbe3 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/pdf.py @@ -0,0 +1,80 @@ +""" +PDF 文件解析器 +使用 PyMuPDF (fitz) 提取 PDF 中的文本内容 +""" +import logging +from typing import Optional +import fitz + +logger = logging.getLogger(__name__) + + +def parse_pdf(file_path: str) -> Optional[str]: + """ + 解析 PDF 文件,提取文本内容 + + :param file_path: PDF 文件路径 + :return: 提取的文本内容,如果失败则返回 None + """ + try: + # 打开 PDF 文件 + pdf_doc = fitz.open(file_path) + + if not pdf_doc: + logger.error("[PdfParser] 无法打开 PDF 文件") + return None + + text_blocks = [] + + # 遍历每一页 + for page_num in range(len(pdf_doc)): + page = pdf_doc.load_page(page_num) + + # 获取文本块 + blocks = page.get_text("blocks") + + # 提取文本块内容 + for block in blocks: + if block[6] == 0: # 确保是文本块(block[6] == 0 表示文本块) + text = block[4].strip() # block[4] 是文本内容 + if text: + # 保存文本和位置信息用于排序 + bbox = block[:4] # (x0, y0, x1, y1) + text_blocks.append({ + 'text': text, + 'y0': bbox[1], # 上边界,用于排序 + 'x0': bbox[0] # 左边界,用于排序 + }) + + # 关闭 PDF 文档 + pdf_doc.close() + + if not text_blocks: + logger.warning("[PdfParser] PDF 文件中没有找到文本内容") + return None + + # 按位置排序(从上到下,从左到右) + text_blocks.sort(key=lambda x: (x['y0'], x['x0'])) + + # 合并文本块,添加换行 + paragraphs = [] + prev_y0 = None + + for block in text_blocks: + text = block['text'] + y0 = block['y0'] + + # 如果当前块与上一个块在垂直方向上有较大距离,添加换行 + if prev_y0 is not None and y0 - prev_y0 > 10: # 10 像素的阈值,表示新段落 + paragraphs.append('') + + paragraphs.append(text) + prev_y0 = y0 + + content = '\n'.join(paragraphs) + return content + + except Exception as e: + logger.exception(f"[PdfParser] 解析 PDF 文件失败: {e}") + return None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/txt.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/txt.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed639ff4cc5fb0085782a785ea3edb7b46511b7 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/parser/txt.py @@ -0,0 +1,30 @@ +import chardet +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +def detect_encoding(file_path: str) -> str: + try: + with open(file_path, 'rb') as file: + raw_data = file.read() + result = chardet.detect(raw_data) + encoding = result['encoding'] + if encoding is None: + encoding = 'utf-8' + return encoding + except Exception as e: + logger.exception(f"[TxtParser] 检测编码失败: {e}") + return 'utf-8' + + +def parse_txt(file_path: str) -> Optional[str]: + try: + encoding = detect_encoding(file_path) + with open(file_path, 'r', encoding=encoding, errors='ignore') as file: + content = file.read() + return content + except Exception as e: + logger.exception(f"[TxtParser] 解析TXT文件失败: {e}") + return None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/rerank.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..1e004ed03503a59470db9577594bc798a19b0069 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/rerank.py @@ -0,0 +1,64 @@ +import jieba +import logging +from typing import List, Dict, Any + +logger = logging.getLogger(__name__) + +class Rerank: + """Rerank 类(使用 
Jaccard 相似度)""" + + stopwords = set(['的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好', '自己', '这']) + + @staticmethod + def split_words(content: str) -> List[str]: + try: + return list(jieba.cut(str(content))) + except Exception: + return [] + + @staticmethod + def cal_jaccard(str1: str, str2: str) -> float: + try: + if len(str1) == 0 and len(str2) == 0: + return 100.0 + + words1 = Rerank.split_words(str1) + words2 = Rerank.split_words(str2) + + new_words1 = [word for word in words1 if word not in Rerank.stopwords and word.strip()] + new_words2 = [word for word in words2 if word not in Rerank.stopwords and word.strip()] + + if len(new_words1) == 0 or len(new_words2) == 0: + return 0.0 + + set1 = set(new_words1) + set2 = set(new_words2) + intersection = len(set1.intersection(set2)) + union = len(set1.union(set2)) + + if union == 0: + return 0.0 + + score = intersection / union * 100.0 + return score + except Exception: + return 0.0 + + @staticmethod + def rerank_chunks(chunks: List[Dict[str, Any]], query: str) -> List[Dict[str, Any]]: + try: + score_chunks = [] + for chunk in chunks: + content = chunk.get('content', '') + score = Rerank.cal_jaccard(content, query) + chunk['jaccard_score'] = score + score_chunks.append((score, chunk)) + + # 按 Jaccard 分数降序排序 + score_chunks.sort(key=lambda x: x[0], reverse=True) + sorted_chunks = [chunk for _, chunk in score_chunks] + + return sorted_chunks + except Exception: + return chunks + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py new file mode 100644 index 0000000000000000000000000000000000000000..d1994d0b3be3ae46aff3d84961d1870891ba3e5c --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py @@ -0,0 +1,92 @@ +""" +关键词检索模块 - 使用 SQLAlchemy +""" +import logging +from typing import List, Dict, Any, Optional +from sqlalchemy import text +import jieba + +logger = logging.getLogger(__name__) + + +def _prepare_fts_query(query: str) -> str: + """ + 准备 FTS5 查询 + :param query: 原始查询文本 + :return: FTS5 查询字符串 + """ + def escape_fts_word(word: str) -> str: + # 包含以下任意字符时,整体作为短语用双引号包裹,避免触发 FTS5 语法解析 + # 特别是 '%' 在 FTS5 MATCH 语法中会导致 "syntax error near '%'" + special_chars = [ + '"', "'", '(', ')', '*', ':', '?', '+', '-', '|', '&', + '{', '}', '[', ']', '^', '$', '\\', '/', '!', '~', ';', + ',', '.', ' ', '%' + ] + if any(char in word for char in special_chars): + escaped_word = word.replace('"', '""') + return f'"{escaped_word}"' + return word + + try: + words = jieba.cut(query) + words = [word.strip() for word in words if word.strip()] + if not words: + return escape_fts_word(query) + + escaped_words = [escape_fts_word(word) for word in words] + fts_query = ' OR '.join(escaped_words) + return fts_query + except Exception: + return escape_fts_word(query) + + +def search_by_keyword(conn, query: str, top_k: int = 5, doc_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """ + 关键词检索(FTS5,使用 jieba 对中文进行分词) + :param conn: 数据库连接对象(SQLAlchemy Connection) + :param query: 查询文本 + :param top_k: 返回数量 + :param doc_ids: 可选的文档ID列表,用于过滤 + :return: chunk 列表 + """ + try: + fts_query = _prepare_fts_query(query) + + params = {"fts_query": fts_query, "top_k": top_k} + where_clause = "WHERE chunks_fts MATCH :fts_query" + + if doc_ids: + placeholders = ','.join([f':doc_id_{i}' for i in range(len(doc_ids))]) + for i, doc_id in 
enumerate(doc_ids): + params[f'doc_id_{i}'] = doc_id + where_clause += f" AND c.doc_id IN ({placeholders})" + + sql = f""" + SELECT c.id, c.doc_id, c.content, c.tokens, c.chunk_index, + d.name as doc_name, + chunks_fts.rank + FROM chunks_fts + JOIN chunks c ON c.id = chunks_fts.id + JOIN documents d ON d.id = c.doc_id + {where_clause} + ORDER BY chunks_fts.rank + LIMIT :top_k + """ + result = conn.execute(text(sql), params) + + results = [] + for row in result: + results.append({ + 'id': row.id, + 'doc_id': row.doc_id, + 'content': row.content, + 'tokens': row.tokens, + 'chunk_index': row.chunk_index, + 'doc_name': row.doc_name, + 'score': row.rank if row.rank is not None else 0.0 + }) + return results + except Exception as e: + logger.exception(f"[KeywordSearch] 关键词检索失败: {e}") + return [] diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..179423caabe42a6fca900affe490cdd46fcd5036 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py @@ -0,0 +1,67 @@ +""" +向量检索模块 - 使用 SQLAlchemy +""" +import logging +import struct +from typing import List, Dict, Any, Optional +from sqlalchemy import text + +logger = logging.getLogger(__name__) + + +def search_by_vector(conn, query_vector: List[float], top_k: int = 5, doc_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """ + 向量检索 + :param conn: 数据库连接对象(SQLAlchemy Connection) + :param query_vector: 查询向量 + :param top_k: 返回数量 + :param doc_ids: 可选的文档ID列表,用于过滤 + :return: chunk 列表 + """ + try: + # 检查 vec_index 表是否存在 + result = conn.execute(text(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='vec_index' + """)) + if not result.fetchone(): + return [] + + query_vector_bytes = struct.pack(f'{len(query_vector)}f', *query_vector) + + params = {"query_vector": query_vector_bytes, "top_k": top_k} + where_clause = "WHERE v.embedding MATCH :query_vector AND k = :top_k" + + if doc_ids: + placeholders = ','.join([f':doc_id_{i}' for i in range(len(doc_ids))]) + for i, doc_id in enumerate(doc_ids): + params[f'doc_id_{i}'] = doc_id + where_clause += f" AND c.doc_id IN ({placeholders})" + + sql = f""" + SELECT c.id, c.doc_id, c.content, c.tokens, c.chunk_index, + d.name as doc_name, + distance + FROM vec_index v + JOIN chunks c ON c.rowid = v.rowid + JOIN documents d ON d.id = c.doc_id + {where_clause} + ORDER BY distance + """ + result = conn.execute(text(sql), params) + + results = [] + for row in result: + results.append({ + 'id': row.id, + 'doc_id': row.doc_id, + 'content': row.content, + 'tokens': row.tokens, + 'chunk_index': row.chunk_index, + 'doc_name': row.doc_name, + 'score': row.distance + }) + return results + except Exception as e: + logger.exception(f"[VectorSearch] 向量检索失败: {e}") + return [] diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py new file mode 100644 index 0000000000000000000000000000000000000000..f824151409345d2fa67aa4c593938e0f5dba1495 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py @@ -0,0 +1,122 @@ +import logging +import asyncio +from typing import List, Dict, Any, Optional +from base.search.keyword import search_by_keyword as 
keyword_search +from base.search.vector import search_by_vector as vector_search +from base.embedding import Embedding +from base.rerank import Rerank + +logger = logging.getLogger(__name__) + + +async def weighted_keyword_and_vector_search( + conn, + query: str, + top_k: int = 5, + weight_keyword: float = 0.3, + weight_vector: float = 0.7, + doc_ids: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """ + 加权关键词和向量混合检索(异步) + + :param conn: 数据库连接对象(SQLAlchemy Connection) + :param query: 查询文本 + :param top_k: 返回数量 + :param weight_keyword: 关键词搜索权重 + :param weight_vector: 向量搜索权重 + :return: 合并后的 chunk 列表 + """ + try: + # 同时进行关键词和向量搜索,每个获取 2*topk 个结果 + keyword_chunks = [] + vector_chunks = [] + + # 关键词搜索 + try: + keyword_chunks = keyword_search(conn, query, 2 * top_k, doc_ids) + except Exception as e: + logger.warning(f"[WeightedSearch] 关键词检索失败: {e}") + + # 向量搜索(需要 embedding 配置) + if Embedding.is_configured(): + try: + query_vector = await Embedding.vectorize_embedding(query) + if query_vector: + vector_chunks = vector_search(conn, query_vector, 2 * top_k, doc_ids) + except Exception as e: + logger.warning(f"[WeightedSearch] 向量检索失败: {e}") + + # 如果没有结果 + if not keyword_chunks and not vector_chunks: + return [] + + # 归一化并合并结果 + merged_chunks = {} + + # 处理关键词搜索结果 + if keyword_chunks: + # 归一化 rank 分数(rank 越小越好,转换为越大越好) + keyword_scores = [chunk.get('score', 0.0) for chunk in keyword_chunks if chunk.get('score') is not None] + if keyword_scores: + min_rank = min(keyword_scores) + max_rank = max(keyword_scores) + rank_range = max_rank - min_rank + + for chunk in keyword_chunks: + chunk_id = chunk['id'] + rank = chunk.get('score', 0.0) + # 转换为越大越好的分数(归一化到 0-1) + if rank_range > 0: + normalized_score = 1.0 - ((rank - min_rank) / rank_range) + else: + normalized_score = 1.0 + weighted_score = normalized_score * weight_keyword + + if chunk_id not in merged_chunks: + merged_chunks[chunk_id] = chunk.copy() + merged_chunks[chunk_id]['score'] = weighted_score + else: + merged_chunks[chunk_id]['score'] += weighted_score + + # 处理向量搜索结果 + if vector_chunks: + # 归一化 distance 分数(distance 越小越好,转换为越大越好) + vector_scores = [chunk.get('score', 0.0) for chunk in vector_chunks if chunk.get('score') is not None] + if vector_scores: + min_distance = min(vector_scores) + max_distance = max(vector_scores) + distance_range = max_distance - min_distance + + for chunk in vector_chunks: + chunk_id = chunk['id'] + distance = chunk.get('score', 0.0) + # 转换为越大越好的分数(归一化到 0-1) + if distance_range > 0: + normalized_score = 1.0 - ((distance - min_distance) / distance_range) + else: + normalized_score = 1.0 + weighted_score = normalized_score * weight_vector + + if chunk_id not in merged_chunks: + merged_chunks[chunk_id] = chunk.copy() + merged_chunks[chunk_id]['score'] = weighted_score + else: + merged_chunks[chunk_id]['score'] += weighted_score + + # 转换为列表并按分数排序 + merged_list = list(merged_chunks.values()) + merged_list.sort(key=lambda x: x.get('score', 0.0), reverse=True) + + # Rerank + reranked_chunks = Rerank.rerank_chunks(merged_list, query) + + # 取前 top_k 个 + final_chunks = reranked_chunks[:top_k] + + return final_chunks + + except Exception as e: + logger.exception(f"[WeightedSearch] 混合检索失败: {e}") + return [] + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/token_tool.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/token_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..66bce6f177975cb03cdde675e9f6a4794f7c352f --- /dev/null +++ 
b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/token_tool.py @@ -0,0 +1,157 @@ +import tiktoken +import logging +import re +import uuid +from typing import List +from base.config import get_token_model, get_max_tokens + +logger = logging.getLogger(__name__) + + +class TokenTool: + """Token 工具类""" + + _encoding_cache = {} + + @staticmethod + def _get_encoding(): + """ + 获取编码器(带缓存) + :return: tiktoken 编码器 + """ + model = get_token_model() + if model not in TokenTool._encoding_cache: + try: + TokenTool._encoding_cache[model] = tiktoken.encoding_for_model(model) + except Exception: + TokenTool._encoding_cache[model] = tiktoken.get_encoding("cl100k_base") + return TokenTool._encoding_cache[model] + + @staticmethod + def get_tokens(content: str) -> int: + """ + 获取文本的 token 数量 + :param content: 文本内容 + :return: token 数量 + """ + try: + enc = TokenTool._get_encoding() + return len(enc.encode(str(content))) + except Exception: + return 0 + + @staticmethod + def get_k_tokens_words_from_content(content: str, k: int = 1024) -> str: + """ + 从内容中获取 k 个 token 的文本 + :param content: 文本内容 + :param k: token 数量 + :return: 截取后的文本 + """ + try: + if TokenTool.get_tokens(content) <= k: + return content + + # 使用二分查找找到合适的截取位置 + l = 0 + r = len(content) + while l + 1 < r: + mid = (l + r) // 2 + if TokenTool.get_tokens(content[:mid]) <= k: + l = mid + else: + r = mid + return content[:l] + except Exception: + return "" + + @staticmethod + def content_to_sentences(content: str) -> List[str]: + """ + 将内容分割为句子 + :param content: 文本内容 + :return: 句子列表 + """ + protected_phrases = [ + 'e.g.', 'i.e.', 'U.S.', 'U.K.', 'A.M.', 'P.M.', 'a.m.', 'p.m.', + 'Inc.', 'Ltd.', 'No.', 'vs.', 'approx.', 'Dr.', 'Mr.', 'Ms.', 'Prof.', + ] + + placeholder_map = {} + for phrase in protected_phrases: + placeholder = f"__PROTECTED_{uuid.uuid4().hex}__" + placeholder_map[placeholder] = phrase + content = content.replace(phrase, placeholder) + + # 分句正则模式 + chinese_punct = r'[。!?!?;;]' + right_quotes = r'["'""'】】》〕〉)\\]]' + pattern = re.compile( + rf'(?<={chinese_punct}{right_quotes})' + rf'|(?<={chinese_punct})(?=[^{right_quotes}])' + r'|(?<=[\.\?!;])(?=\s|$)' + ) + + # 分割并还原 + sentences = [] + for segment in pattern.split(content): + segment = segment.strip() + if not segment: + continue + for placeholder, original in placeholder_map.items(): + segment = segment.replace(placeholder, original) + sentences.append(segment) + + return sentences + + @staticmethod + def split_content_to_chunks(content: str, chunk_size: int = 1024) -> List[str]: + """ + 将内容切分为 chunks + :param content: 文本内容 + :param chunk_size: chunk 大小(token 数) + :return: chunk 列表 + """ + try: + sentences = TokenTool.content_to_sentences(content) + chunks = [] + current_chunk = "" + current_tokens = 0 + + for sentence in sentences: + sentence_tokens = TokenTool.get_tokens(sentence) + + # 如果单个句子超过 chunk_size,需要进一步切分 + if sentence_tokens > chunk_size: + if current_chunk: + chunks.append(current_chunk) + current_chunk = "" + current_tokens = 0 + + # 切分长句子 + sub_content = sentence + while TokenTool.get_tokens(sub_content) > chunk_size: + sub_chunk = TokenTool.get_k_tokens_words_from_content(sub_content, chunk_size) + chunks.append(sub_chunk) + sub_content = sub_content[len(sub_chunk):] + + if sub_content: + current_chunk = sub_content + current_tokens = TokenTool.get_tokens(sub_content) + else: + if current_tokens + sentence_tokens > chunk_size: + if current_chunk: + chunks.append(current_chunk) + current_chunk = sentence + current_tokens = sentence_tokens + else: + 
current_chunk += sentence + current_tokens += sentence_tokens + + if current_chunk: + chunks.append(current_chunk) + + return chunks + except Exception: + return [] + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/config.json b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/config.json new file mode 100644 index 0000000000000000000000000000000000000000..776f4c252380ca6abf07caa43931c3ebcea5e256 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/config.json @@ -0,0 +1,48 @@ +{ + "tools": { + "create_knowledge_base": { + "zh": "创建一个新的知识库。知识库是文档的容器,每个知识库可以有自己的chunk_size和embedding配置。创建后需要调用select_knowledge_base来选择该知识库才能使用。\n\n参数说明:\n- kb_name:知识库名称(必填,必须唯一)\n- chunk_size:chunk大小,单位token(必填,例如512、1024)\n- embedding_model:向量化模型名称(可选,例如text-embedding-ada-002)\n- embedding_endpoint:向量化服务端点URL(可选)\n- embedding_api_key:向量化服务API Key(可选)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含创建结果\n - kb_id:知识库ID\n - kb_name:知识库名称\n - chunk_size:chunk大小", + "en": "Create a new knowledge base. A knowledge base is a container for documents, and each knowledge base can have its own chunk_size and embedding configuration. After creation, you need to call select_knowledge_base to select this knowledge base before using it.\n\nParameters:\n- kb_name: Knowledge base name (required, must be unique)\n- chunk_size: Chunk size in tokens (required, e.g., 512, 1024)\n- embedding_model: Embedding model name (optional, e.g., text-embedding-ada-002)\n- embedding_endpoint: Embedding service endpoint URL (optional)\n- embedding_api_key: Embedding service API Key (optional)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing creation results\n - kb_id: Knowledge base ID\n - kb_name: Knowledge base name\n - chunk_size: Chunk size" + }, + "delete_knowledge_base": { + "zh": "删除指定的知识库。不能删除当前正在使用的知识库。删除知识库会级联删除该知识库下的所有文档和chunks。\n\n参数说明:\n- kb_name:知识库名称(必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含删除结果\n - kb_name:已删除的知识库名称", + "en": "Delete the specified knowledge base. Cannot delete the currently active knowledge base. Deleting a knowledge base will cascade delete all documents and chunks under it.\n\nParameters:\n- kb_name: Knowledge base name (required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing deletion results\n - kb_name: Deleted knowledge base name" + }, + "list_knowledge_bases": { + "zh": "列出所有可用的知识库。返回所有知识库的详细信息,包括当前选中的知识库。\n\n参数说明:\n无参数\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含知识库列表\n - knowledge_bases:知识库列表,每个知识库包含:\n - id:知识库ID\n - name:知识库名称\n - chunk_size:chunk大小\n - embedding_model:向量化模型\n - created_at:创建时间\n - is_current:是否为当前选中的知识库\n - count:知识库数量\n - current_kb_id:当前选中的知识库ID", + "en": "List all available knowledge bases. 
Returns detailed information about all knowledge bases, including the currently selected one.\n\nParameters:\nNo parameters\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing knowledge base list\n - knowledge_bases: List of knowledge bases, each containing:\n - id: Knowledge base ID\n - name: Knowledge base name\n - chunk_size: Chunk size\n - embedding_model: Embedding model\n - created_at: Creation time\n - is_current: Whether this is the currently selected knowledge base\n - count: Number of knowledge bases\n - current_kb_id: Currently selected knowledge base ID" + }, + "select_knowledge_base": { + "zh": "选择一个知识库作为当前使用的知识库。选择后,后续的文档导入、查询等操作都会在该知识库中进行。\n\n参数说明:\n- kb_name:知识库名称(必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含选择结果\n - kb_id:知识库ID\n - kb_name:知识库名称\n - document_count:该知识库下的文档数量", + "en": "Select a knowledge base as the currently active one. After selection, subsequent operations such as document import and search will be performed in this knowledge base.\n\nParameters:\n- kb_name: Knowledge base name (required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing selection results\n - kb_id: Knowledge base ID\n - kb_name: Knowledge base name\n - document_count: Number of documents in this knowledge base" + }, + "import_document": { + "zh": "导入文档到当前选中的知识库(支持多文件并发导入)。支持TXT、DOCX、DOC格式。文档会被解析、切分为chunks,并异步批量生成向量存储到数据库中。多个文档会并发处理,提高导入效率。如果文档名称已存在,会自动添加时间戳避免冲突。\n\n参数说明:\n- file_paths:文件路径列表(绝对路径),支持1~n个文件(必填)\n- chunk_size:chunk大小,单位token(可选,默认使用知识库的chunk_size)\n\n返回值:\n- success:布尔值,表示是否成功(只要有文件成功导入即为true)\n- message:字符串,描述操作结果(包含成功和失败的数量)\n- data:字典,包含导入结果\n - total:总文件数\n - success_count:成功导入的文件数\n - failed_count:失败的文件数\n - success_files:成功导入的文件列表,每个包含:\n - file_path:文件路径\n - doc_name:文档名称\n - chunk_count:chunk数量\n - failed_files:失败的文件列表,每个包含:\n - file_path:文件路径\n - error:错误信息", + "en": "Import documents into the currently selected knowledge base (supports concurrent import of multiple files). Supports TXT, DOCX, and DOC formats. Documents will be parsed, split into chunks, and vectors will be generated asynchronously in batch and stored in the database. Multiple documents are processed concurrently to improve import efficiency. 
If the document name already exists, a timestamp will be automatically added to avoid conflicts.\n\nParameters:\n- file_paths: List of file paths (absolute paths), supports 1~n files (required)\n- chunk_size: Chunk size in tokens (optional, defaults to knowledge base's chunk_size)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful (true if any file was successfully imported)\n- message: String, describing the operation result (includes counts of successful and failed imports)\n- data: Dictionary containing import results\n - total: Total number of files\n - success_count: Number of successfully imported files\n - failed_count: Number of failed files\n - success_files: List of successfully imported files, each containing:\n - file_path: File path\n - doc_name: Document name\n - chunk_count: Number of chunks\n - failed_files: List of failed files, each containing:\n - file_path: File path\n - error: Error message" + }, + "search": { + "zh": "在当前选中的知识库中进行混合检索。结合关键词检索(FTS5)和向量检索(sqlite-vec),使用加权方式合并结果(关键词权重0.3,向量权重0.7),去重后使用Jaccard相似度重排序,返回最相关的top-k个结果。\n\n参数说明:\n- query:查询文本(必填)\n- top_k:返回数量(可选,默认从配置读取,通常为5)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述检索结果\n- data:字典,包含检索结果\n - chunks:chunk列表,每个chunk包含:\n - id:chunk ID\n - doc_id:文档ID\n - content:chunk内容\n - tokens:token数量\n - chunk_index:chunk索引\n - doc_name:文档名称\n - score:综合检索分数\n - count:结果数量", + "en": "Perform hybrid search in the currently selected knowledge base. Combines keyword search (FTS5) and vector search (sqlite-vec), merges results using weighted approach (keyword weight 0.3, vector weight 0.7), deduplicates, reranks using Jaccard similarity, and returns the top-k most relevant results.\n\nParameters:\n- query: Query text (required)\n- top_k: Number of results to return (optional, default from config, usually 5)\n\nReturn value:\n- success: Boolean, indicating whether the search was successful\n- message: String, describing the search result\n- data: Dictionary containing search results\n - chunks: List of chunks, each containing:\n - id: Chunk ID\n - doc_id: Document ID\n - content: Chunk content\n - tokens: Number of tokens\n - chunk_index: Chunk index\n - doc_name: Document name\n - score: Combined search score\n - count: Number of results" + }, + "list_documents": { + "zh": "查看当前选中的知识库下的所有文档列表。返回文档的详细信息。\n\n参数说明:\n无参数\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含文档列表\n - documents:文档列表,每个文档包含:\n - id:文档ID\n - name:文档名称\n - file_path:文件路径\n - file_type:文件类型\n - chunk_size:chunk大小\n - created_at:创建时间\n - updated_at:更新时间\n - count:文档数量", + "en": "List all documents in the currently selected knowledge base. Returns detailed information about the documents.\n\nParameters:\nNo parameters\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing document list\n - documents: List of documents, each containing:\n - id: Document ID\n - name: Document name\n - file_path: File path\n - file_type: File type\n - chunk_size: Chunk size\n - created_at: Creation time\n - updated_at: Update time\n - count: Number of documents" + }, + "delete_document": { + "zh": "删除当前选中的知识库下的指定文档。删除文档会级联删除该文档的所有chunks。\n\n参数说明:\n- doc_name:文档名称(必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含删除结果\n - doc_name:已删除的文档名称", + "en": "Delete the specified document from the currently selected knowledge base. 
Deleting a document will cascade delete all chunks of that document.\n\nParameters:\n- doc_name: Document name (required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing deletion results\n - doc_name: Deleted document name" + }, + "update_document": { + "zh": "修改文档的chunk_size并重新解析文档。会删除原有的chunks,使用新的chunk_size重新切分文档,并异步批量生成新的向量。\n\n参数说明:\n- doc_name:文档名称(必填)\n- chunk_size:新的chunk大小,单位token(必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含修改结果\n - doc_id:文档ID\n - doc_name:文档名称\n - chunk_count:新的chunk数量\n - chunk_size:新的chunk大小", + "en": "Update the document's chunk_size and re-parse the document. Will delete existing chunks, re-split the document using the new chunk_size, and asynchronously generate new vectors in batch.\n\nParameters:\n- doc_name: Document name (required)\n- chunk_size: New chunk size in tokens (required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing update results\n - doc_id: Document ID\n - doc_name: Document name\n - chunk_count: New number of chunks\n - chunk_size: New chunk size" + }, + "export_database": { + "zh": "导出整个kb.db数据库文件到指定路径。\n\n参数说明:\n- export_path:导出路径(绝对路径,必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含导出结果\n - source_path:源数据库路径\n - export_path:导出路径", + "en": "Export the entire kb.db database file to the specified path.\n\nParameters:\n- export_path: Export path (absolute path, required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing export results\n - source_path: Source database path\n - export_path: Export path" + }, + "import_database": { + "zh": "导入一个.db数据库文件,将其中的内容合并到kb.db中。导入时会自动处理重名冲突,为知识库和文档名称添加时间戳。\n\n参数说明:\n- source_db_path:源数据库文件路径(绝对路径,必填)\n\n返回值:\n- success:布尔值,表示是否成功\n- message:字符串,描述操作结果\n- data:字典,包含导入结果\n - source_path:源数据库路径\n - imported_kb_count:导入的知识库数量\n - imported_doc_count:导入的文档数量", + "en": "Import a .db database file and merge its contents into kb.db. 
Import will automatically handle name conflicts by adding timestamps to knowledge base and document names.\n\nParameters:\n- source_db_path: Source database file path (absolute path, required)\n\nReturn value:\n- success: Boolean, indicating whether the operation was successful\n- message: String, describing the operation result\n- data: Dictionary containing import results\n - source_path: Source database path\n - imported_kb_count: Number of imported knowledge bases\n - imported_doc_count: Number of imported documents"
+    }
+  }
+}
diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/deps.toml b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/deps.toml
new file mode 100644
index 0000000000000000000000000000000000000000..1b44d05199626d5da47504a2e7105b183c657db9
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/deps.toml
@@ -0,0 +1,15 @@
+[system]
+
+[pip]
+tiktoken = ">=0.8.0"
+python-docx = ">=1.1.0"
+chardet = ">=5.2.0"
+jieba = ">=0.42.1"
+aiohttp = ">=3.9.0"
+sqlite-vec = ">=0.1.6"
+sqlalchemy = ">=2.0.0"
+mcp = ">=0.1.0"
+fastapi = ">=0.100.0"
+uvicorn = ">=0.23.0"
+python-multipart = ">=0.0.6"
+PyMuPDF = ">=1.23.0"
diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/rag_config.json b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/rag_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..4bed2cec5be31e6c8db495827792d17a48335847
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/rag_config.json
@@ -0,0 +1,20 @@
+{
+  "embedding": {
+    "type": "openai",
+    "api_key": "",
+    "endpoint": "https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings",
+    "model_name": "text-embedding-v4",
+    "timeout": 30,
+    "vector_dimension": 1024
+  },
+  "token": {
+    "model": "gpt-4",
+    "max_tokens": 8192,
+    "default_chunk_size": 1024
+  },
+  "search": {
+    "default_top_k": 5,
+    "max_top_k": 100
+  }
+}
+
diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/requirements.txt b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0d6b3aebc80788ba362d45dfd785aa1c7b0d1c6e
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/requirements.txt
@@ -0,0 +1,12 @@
+tiktoken>=0.8.0
+python-docx>=1.1.0
+chardet>=5.2.0
+jieba>=0.42.1
+aiohttp>=3.9.0
+sqlite-vec>=0.1.6
+sqlalchemy>=2.0.0
+mcp>=0.1.0
+fastapi>=0.100.0
+uvicorn>=0.23.0
+python-multipart>=0.0.6
+PyMuPDF>=1.23.0
diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/tool.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..06f0a20e3de42d5ec333623aef7668bf9376e825
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/tool.py
@@ -0,0 +1,601 @@
+import os
+import sys
+import uuid
+import shutil
+import logging
+import asyncio
+from typing import Optional, Dict, Any, List
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+if current_dir not in sys.path:
+    sys.path.append(current_dir)
+
+from base.manager.database_manager import Database
+from base.manager.document_manager import DocumentManager, import_document as _import_document, update_document as _update_document
+from base.config import get_default_top_k
+from base.models import KnowledgeBase
+from base.search.weighted_keyword_and_vector_search import weighted_keyword_and_vector_search
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+_db_instance: Optional[Database] = None
+_db_path = os.path.join(current_dir, "database", "kb.db")
+_current_kb_id: Optional[str] = None
+
+
+def _get_db() -> Database:
+    """获取数据库实例(固定使用kb.db)"""
+    global _db_instance
+    if _db_instance is None:
+        db_dir = os.path.dirname(_db_path)
+        if not os.path.exists(db_dir):
+            os.makedirs(db_dir, exist_ok=True)
+        _db_instance = Database(_db_path)
+    return _db_instance
+
+
+def _ensure_active_kb(result: Dict[str, Any]) -> Optional[str]:
+    """确保已选择知识库"""
+    if not _current_kb_id:
+        result["message"] = "请先选择知识库"
+        return None
+    return _current_kb_id
+
+
+def create_knowledge_base(
+    kb_name: str,
+    chunk_size: int,
+    embedding_model: Optional[str] = None,
+    embedding_endpoint: Optional[str] = None,
+    embedding_api_key: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    新增知识库
+
+    :param kb_name: 知识库名称
+    :param chunk_size: chunk 大小(token 数)
+    :param embedding_model: 向量化模型名称(可选)
+    :param embedding_endpoint: 向量化服务端点(可选)
+    :param embedding_api_key: 向量化服务 API Key(可选)
+    :return: 创建结果
+    """
+    result = {
+        "success": False,
+        "message": "",
+        "data": {}
+    }
+
+    try:
+        db = _get_db()
+        session = db.get_session()
+        try:
+            # 检查知识库名称是否已存在
+            existing_kb = db.get_knowledge_base(kb_name)
+            if existing_kb:
+                result["message"] = f"知识库 '{kb_name}' 已存在"
+                return result
+
+            kb_id = str(uuid.uuid4())
+            if db.add_knowledge_base(kb_id, kb_name, chunk_size,
+                                     embedding_model, embedding_endpoint, embedding_api_key):
+                result["success"] = True
+                result["message"] = f"成功创建知识库: {kb_name}"
+                result["data"] = {
+                    "kb_id": kb_id,
+                    "kb_name": kb_name,
+                    "chunk_size": chunk_size
+                }
+            else:
+                result["message"] = "创建知识库失败"
+        finally:
+            session.close()
+    except Exception as e:
+        logger.exception(f"[create_knowledge_base] 创建知识库失败: {e}")
+        result["message"] = "创建知识库失败"
+
+    return result
+
+
+def delete_knowledge_base(kb_name: str) -> Dict[str, Any]:
+    """
+    删除知识库
+
+    :param kb_name: 知识库名称
+    :return: 删除结果
+    """
+    result = {
+        "success": False,
+        "message": "",
+        "data": {}
+    }
+
+    try:
+        db = _get_db()
+        kb = db.get_knowledge_base(kb_name)
+        if not kb:
+            result["message"] = f"知识库 '{kb_name}' 不存在"
+            return result
+
+        # 检查是否是当前使用的知识库
+        global _current_kb_id
+        if _current_kb_id == kb.id:
+            result["message"] = "不能删除当前正在使用的知识库"
+            return result
+
+        if db.delete_knowledge_base(kb.id):
+            result["success"] = True
+            result["message"] = f"成功删除知识库: {kb_name}"
+            result["data"] = {"kb_name": kb_name}
+        else:
+            result["message"] = "删除知识库失败"
+    except Exception as e:
+        logger.exception(f"[delete_knowledge_base] 删除知识库失败: {e}")
+        result["message"] = "删除知识库失败"
+
+    return result
+
+
+def list_knowledge_bases() -> Dict[str, Any]:
+    """
+    查看知识库列表
+
+    :return: 知识库列表
+    """
+    result = {
+        "success": False,
+        "message": "",
+        "data": {}
+    }
+
+    try:
+        db = _get_db()
+        kbs = db.list_knowledge_bases()
+
+        knowledge_bases = []
+        global _current_kb_id
+        for kb in kbs:
+            knowledge_bases.append({
+                "id": kb.id,
+                "name": kb.name,
+                "chunk_size": kb.chunk_size,
+                "embedding_model": kb.embedding_model,
+                "created_at": kb.created_at.isoformat() if kb.created_at else None,
+                "is_current": _current_kb_id == kb.id
+            })
+
+        result["success"] = True
+        result["message"] = f"找到 {len(knowledge_bases)} 个知识库"
+        result["data"] = {
+            "knowledge_bases": knowledge_bases,
+            "count": len(knowledge_bases),
+            "current_kb_id": _current_kb_id
+        }
+    except Exception as e:
logger.exception(f"[list_knowledge_bases] 获取知识库列表失败: {e}") + result["message"] = "获取知识库列表失败" + + return result + + +def select_knowledge_base(kb_name: str) -> Dict[str, Any]: + """ + 选择知识库 + + :param kb_name: 知识库名称 + :return: 选择结果 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + try: + db = _get_db() + kb = db.get_knowledge_base(kb_name) + if not kb: + result["message"] = f"知识库 '{kb_name}' 不存在" + return result + + global _current_kb_id + _current_kb_id = kb.id + + session = db.get_session() + try: + manager = DocumentManager(session) + docs = manager.list_documents_by_kb(kb.id) + doc_count = len(docs) + finally: + session.close() + + result["success"] = True + result["message"] = f"成功选择知识库,共 {doc_count} 个文档" + result["data"] = { + "kb_id": kb.id, + "kb_name": kb.name, + "document_count": doc_count + } + except Exception as e: + logger.exception(f"[select_knowledge_base] 选择知识库失败: {e}") + result["message"] = "选择知识库失败" + + return result + + +async def import_document(file_paths: List[str], chunk_size: Optional[int] = None) -> Dict[str, Any]: + """ + 上传文档到当前知识库(异步,支持多文件并发导入) + + :param file_paths: 文件路径列表(绝对路径),支持1~n个文件 + :param chunk_size: chunk 大小(token 数,可选,默认使用知识库的chunk_size) + :return: 导入结果 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + try: + kb_id = _ensure_active_kb(result) + if not kb_id: + return result + + if not file_paths: + result["message"] = "文件路径列表为空" + return result + + # 验证文件路径是否存在 + invalid_paths = [path for path in file_paths if not os.path.exists(path)] + if invalid_paths: + result["message"] = f"以下文件路径不存在: {', '.join(invalid_paths)}" + return result + + db = _get_db() + # 先获取知识库信息 + session = db.get_session() + try: + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if not kb: + result["message"] = "知识库不存在" + return result + + if chunk_size is None: + chunk_size = kb.chunk_size + finally: + session.close() + + # 并发处理多个文件,每个文件使用独立的 session + async def import_single_file(file_path: str): + """为单个文件创建独立的 session 并导入""" + file_session = db.get_session() + try: + return await _import_document(file_session, kb_id, file_path, chunk_size) + finally: + file_session.close() + + tasks = [ + import_single_file(file_path) + for file_path in file_paths + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计结果 + success_count = 0 + failed_count = 0 + success_files = [] + failed_files = [] + + for i, res in enumerate(results): + file_path = file_paths[i] + if isinstance(res, Exception): + failed_count += 1 + failed_files.append({ + "file_path": file_path, + "error": str(res) + }) + logger.exception(f"[import_document] 导入文件失败: {file_path}, 错误: {res}") + else: + success, message, data = res + if success: + success_count += 1 + success_files.append({ + "file_path": file_path, + "doc_name": data.get("doc_name") if data else os.path.basename(file_path), + "chunk_count": data.get("chunk_count", 0) if data else 0 + }) + else: + failed_count += 1 + failed_files.append({ + "file_path": file_path, + "error": message + }) + + result["success"] = success_count > 0 + result["message"] = f"成功导入 {success_count} 个文档,失败 {failed_count} 个" + result["data"] = { + "total": len(file_paths), + "success_count": success_count, + "failed_count": failed_count, + "success_files": success_files, + "failed_files": failed_files + } + except Exception as e: + logger.exception(f"[import_document] 导入文档失败: {e}") + result["message"] = f"导入文档失败: {str(e)}" + + return result + + +async def search(query: str, top_k: Optional[int] = 
None) -> Dict[str, Any]: + """ + 在当前知识库中查询(异步) + + :param query: 查询文本 + :param top_k: 返回数量(可选,默认从配置读取) + :return: 检索结果 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + if top_k is None: + top_k = get_default_top_k() + + kb_id = _ensure_active_kb(result) + if not kb_id: + return result + + weight_keyword = 0.3 + weight_vector = 0.7 + + try: + db = _get_db() + session = db.get_session() + try: + # 获取当前知识库的所有文档ID + manager = DocumentManager(session) + docs = manager.list_documents_by_kb(kb_id) + doc_ids = [doc.id for doc in docs] + + if not doc_ids: + result["message"] = "当前知识库中没有文档" + result["data"] = {"chunks": []} + return result + + conn = session.connection() + chunks = await weighted_keyword_and_vector_search( + conn, query, top_k, weight_keyword, weight_vector, doc_ids + ) + finally: + session.close() + + if not chunks: + result["message"] = "未找到相关结果" + result["data"] = {"chunks": []} + return result + + result["success"] = True + result["message"] = f"找到 {len(chunks)} 个相关结果" + result["data"] = { + "chunks": chunks, + "count": len(chunks) + } + except Exception as e: + logger.exception(f"[search] 搜索失败: {e}") + result["message"] = "搜索失败" + + return result + + +def list_documents() -> Dict[str, Any]: + """ + 查看当前知识库下的文档列表 + + :return: 文档列表 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + try: + kb_id = _ensure_active_kb(result) + if not kb_id: + return result + + db = _get_db() + session = db.get_session() + try: + manager = DocumentManager(session) + docs = manager.list_documents_by_kb(kb_id) + finally: + session.close() + + documents = [] + for doc in docs: + documents.append({ + "id": doc.id, + "name": doc.name, + "file_path": doc.file_path, + "file_type": doc.file_type, + "chunk_size": doc.chunk_size, + "created_at": doc.created_at.isoformat() if doc.created_at else None, + "updated_at": doc.updated_at.isoformat() if doc.updated_at else None + }) + + result["success"] = True + result["message"] = f"找到 {len(documents)} 个文档" + result["data"] = { + "documents": documents, + "count": len(documents) + } + except Exception as e: + logger.exception(f"[list_documents] 获取文档列表失败: {e}") + result["message"] = "获取文档列表失败" + + return result + + +def delete_document(doc_name: str) -> Dict[str, Any]: + """ + 删除当前知识库下的文档 + + :param doc_name: 文档名称 + :return: 删除结果 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + try: + kb_id = _ensure_active_kb(result) + if not kb_id: + return result + + db = _get_db() + session = db.get_session() + try: + manager = DocumentManager(session) + if manager.delete_document(kb_id, doc_name): + result["success"] = True + result["message"] = f"成功删除文档: {doc_name}" + result["data"] = {"doc_name": doc_name} + else: + result["message"] = f"文档 '{doc_name}' 不存在或删除失败" + finally: + session.close() + except Exception as e: + logger.exception(f"[delete_document] 删除文档失败: {e}") + result["message"] = "删除文档失败" + + return result + + +async def update_document(doc_name: str, chunk_size: int) -> Dict[str, Any]: + """ + 修改文档的chunk_size并重新解析(异步) + + :param doc_name: 文档名称 + :param chunk_size: 新的chunk大小 + :return: 修改结果 + """ + result = { + "success": False, + "message": "", + "data": {} + } + + try: + kb_id = _ensure_active_kb(result) + if not kb_id: + return result + + db = _get_db() + session = db.get_session() + try: + success, message, data = await _update_document(session, kb_id, doc_name, chunk_size) + result["success"] = success + result["message"] = message + result["data"] = data or {} + finally: + 
+            session.close()
+    except Exception as e:
+        logger.exception(f"[update_document] 修改文档失败: {e}")
+        result["message"] = "修改文档失败"
+
+    return result
+
+
+def export_database(export_path: str) -> Dict[str, Any]:
+    """
+    导出整个kb.db数据库文件
+
+    :param export_path: 导出路径(绝对路径)
+    :return: 导出结果
+    """
+    result = {
+        "success": False,
+        "message": "",
+        "data": {}
+    }
+
+    try:
+        if not os.path.exists(_db_path):
+            result["message"] = "数据库文件不存在"
+            return result
+
+        if not export_path:
+            result["message"] = "导出路径不能为空"
+            return result
+
+        # 确保导出路径以 .db 结尾
+        if not export_path.endswith(('.db', '.sqlite', '.sqlite3')):
+            export_path += '.db'
+
+        # 确保目标目录存在
+        export_dir = os.path.dirname(export_path)
+        if export_dir and not os.path.exists(export_dir):
+            os.makedirs(export_dir, exist_ok=True)
+
+        shutil.copy2(_db_path, export_path)
+
+        result["success"] = True
+        result["message"] = f"成功导出数据库到: {export_path}"
+        result["data"] = {
+            "source_path": _db_path,
+            "export_path": export_path
+        }
+    except Exception as e:
+        logger.exception(f"[export_database] 导出数据库失败: {e}")
+        result["message"] = f"导出数据库失败: {str(e)}"
+
+    return result
+
+
+def import_database(source_db_path: str) -> Dict[str, Any]:
+    """
+    导入一个.db数据库文件,将其中的内容合并到kb.db中
+
+    :param source_db_path: 源数据库文件路径(绝对路径)
+    :return: 导入结果
+    """
+    result = {
+        "success": False,
+        "message": "",
+        "data": {}
+    }
+
+    try:
+        if not source_db_path:
+            result["message"] = "源数据库路径不能为空"
+            return result
+
+        if not os.path.exists(source_db_path):
+            result["message"] = f"源数据库文件不存在: {source_db_path}"
+            return result
+
+        db = _get_db()
+        imported_kb_count, imported_doc_count = db.import_database(source_db_path)
+
+        result["success"] = True
+        result["message"] = f"成功导入,共 {imported_kb_count} 个知识库,{imported_doc_count} 个文档"
+        result["data"] = {
+            "source_path": source_db_path,
+            "imported_kb_count": imported_kb_count,
+            "imported_doc_count": imported_doc_count
+        }
+    except Exception as e:
+        logger.exception(f"[import_database] 导入数据库失败: {e}")
+        result["message"] = f"导入数据库失败: {str(e)}"
+
+    return result
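
Usage sketch: the tools in tool.py are plain module-level functions that all return the same {success, message, data} envelope, so they can be smoke-tested without the MCP transport. This is a minimal sketch under stated assumptions, not part of the patch: it assumes tool.py is imported from its own directory (mcp_tools/rag_tools) so its base.* imports resolve, that rag_config.json points at a reachable embedding endpoint, and that the knowledge-base name "ops-docs" and the document path are hypothetical examples.

import asyncio

import tool  # assumes the working directory is mcp_tools/rag_tools so tool.py and base/ resolve

# Knowledge-base lifecycle: create a KB, then make it the active one.
print(tool.create_knowledge_base("ops-docs", chunk_size=512)["message"])
print(tool.select_knowledge_base("ops-docs")["message"])


async def demo() -> None:
    # Import one document (hypothetical path); chunk_size falls back to the KB's own value.
    imported = await tool.import_document(["/tmp/openEuler-tuning-guide.docx"])
    print(imported["message"])

    # Hybrid keyword + vector retrieval over the active KB.
    hits = await tool.search("如何调整内核参数", top_k=3)
    for chunk in hits["data"].get("chunks", []):
        print(chunk["doc_name"], chunk["score"], chunk["content"][:80])


asyncio.run(demo())

Because every tool returns the same envelope, a caller only needs to branch on result["success"] and surface result["message"] on failure.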