""" 知识库文件服务 """ import os import json from typing import Optional, List, Tuple from pathlib import Path import asyncpg from datetime import datetime from models.knowledge_base_file import KnowledgeBaseFile, KnowledgeBaseChunk from logger.logging import get_logger logger = get_logger(__name__) class KnowledgeBaseFileService: """知识库文件服务类""" @staticmethod async def create_file_record( conn: asyncpg.Connection, knowledge_base_id: int, user_id: int, file_name: str, file_path: str, file_size: int, file_type: str = "pdf" ) -> KnowledgeBaseFile: """ 创建文件记录 Args: conn: 数据库连接 knowledge_base_id: 知识库 ID user_id: 用户 ID file_name: 文件名 file_path: 文件路径 file_size: 文件大小 file_type: 文件类型 Returns: KnowledgeBaseFile: 创建的文件记录 """ try: # 检查文件名是否已存在 existing = await conn.fetchrow( """ SELECT id FROM knowledge_base_file WHERE knowledge_base_id = $1 AND file_name = $2 AND is_deleted = FALSE """, knowledge_base_id, file_name ) if existing: raise ValueError(f"文件 '{file_name}' 已存在于该知识库中") # 插入文件记录 row = await conn.fetchrow( """ INSERT INTO knowledge_base_file (knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status) VALUES ($1, $2, $3, $4, $5, $6, 'processing') RETURNING id, knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at """, knowledge_base_id, user_id, file_name, file_path, file_size, file_type ) logger.info(f"创建文件记录: {file_name}, 知识库 ID: {knowledge_base_id}") return KnowledgeBaseFile(**dict(row)) except ValueError: raise except Exception as e: logger.error(f"创建文件记录失败: {e}") raise Exception(f"创建文件记录失败: {str(e)}") @staticmethod async def update_file_status( conn: asyncpg.Connection, file_id: int, status: str, chunk_count: int = 0 ) -> bool: """ 更新文件状态 Args: conn: 数据库连接 file_id: 文件 ID status: 状态(processing/completed/failed) chunk_count: 分块数量 Returns: bool: 是否更新成功 """ try: result = await conn.execute( """ UPDATE knowledge_base_file SET status = $1, chunk_count = $2 WHERE id = $3 """, status, chunk_count, file_id ) return result == "UPDATE 1" except Exception as e: logger.error(f"更新文件状态失败: {e}") return False @staticmethod async def save_chunks( conn: asyncpg.Connection, file_id: int, knowledge_base_id: int, chunks: List[Tuple[int, str, dict, str]], summary: Optional[str] = None ) -> int: """ 批量保存文档块 Args: conn: 数据库连接 file_id: 文件 ID knowledge_base_id: 知识库 ID chunks: 文档块列表 [(chunk_index, content, metadata, vector_id), ...] summary: 文件摘要(可选) Returns: int: 保存的块数量 """ try: # 批量插入(每个chunk都保存summary,便于独立检索) records = [ (file_id, knowledge_base_id, chunk_index, content, json.dumps(metadata), vector_id, summary) for chunk_index, content, metadata, vector_id in chunks ] await conn.executemany( """ INSERT INTO knowledge_base_chunk (file_id, knowledge_base_id, chunk_index, content, metadata, vector_id, summary) VALUES ($1, $2, $3, $4, $5, $6, $7) """, records ) logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else '无'}") return len(chunks) except Exception as e: logger.error(f"保存文档块失败: {e}") raise Exception(f"保存文档块失败: {str(e)}") @staticmethod async def get_file_by_id( conn: asyncpg.Connection, file_id: int, user_id: int ) -> Optional[KnowledgeBaseFile]: """ 根据 ID 获取文件 Args: conn: 数据库连接 file_id: 文件 ID user_id: 用户 ID Returns: Optional[KnowledgeBaseFile]: 文件对象 """ try: row = await conn.fetchrow( """ SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM knowledge_base_file WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE """, file_id, user_id ) if row: return KnowledgeBaseFile(**dict(row)) return None except Exception as e: logger.error(f"获取文件失败: {e}") return None @staticmethod async def get_files_by_kb( conn: asyncpg.Connection, knowledge_base_id: int, user_id: int, page: int = 1, page_size: int = 20 ) -> Tuple[List[KnowledgeBaseFile], int]: """ 获取知识库的文件列表 Args: conn: 数据库连接 knowledge_base_id: 知识库 ID user_id: 用户 ID page: 页码 page_size: 每页数量 Returns: Tuple[List[KnowledgeBaseFile], int]: (文件列表, 总数量) """ try: offset = (page - 1) * page_size # 获取总数 total = await conn.fetchval( """ SELECT COUNT(*) FROM knowledge_base_file WHERE knowledge_base_id = $1 AND user_id = $2 AND is_deleted = FALSE """, knowledge_base_id, user_id ) # 获取列表 rows = await conn.fetch( """ SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM knowledge_base_file WHERE knowledge_base_id = $1 AND user_id = $2 AND is_deleted = FALSE ORDER BY created_at DESC LIMIT $3 OFFSET $4 """, knowledge_base_id, user_id, page_size, offset ) files = [KnowledgeBaseFile(**dict(row)) for row in rows] return files, total except Exception as e: logger.error(f"获取文件列表失败: {e}") raise Exception(f"获取文件列表失败: {str(e)}") @staticmethod async def get_file_vector_ids( conn: asyncpg.Connection, file_id: int ) -> List[str]: """ 获取文件的所有向量 ID Args: conn: 数据库连接 file_id: 文件 ID Returns: List[str]: 向量 ID 列表 """ try: rows = await conn.fetch( """ SELECT vector_id FROM knowledge_base_chunk WHERE file_id = $1 AND vector_id IS NOT NULL """, file_id ) return [row['vector_id'] for row in rows if row['vector_id']] except Exception as e: logger.error(f"获取文件向量 ID 失败: {e}") return [] @staticmethod async def get_all_files_by_kb( conn: asyncpg.Connection, knowledge_base_id: int ) -> List[KnowledgeBaseFile]: """ 获取知识库的所有文件(包括已删除的) Args: conn: 数据库连接 knowledge_base_id: 知识库 ID Returns: List[KnowledgeBaseFile]: 文件列表 """ try: rows = await conn.fetch( """ SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM knowledge_base_file WHERE knowledge_base_id = $1 """, knowledge_base_id ) return [KnowledgeBaseFile(**dict(row)) for row in rows] except Exception as e: logger.error(f"获取知识库所有文件失败: {e}") return [] @staticmethod async def get_kb_all_vector_ids( conn: asyncpg.Connection, knowledge_base_id: int ) -> List[str]: """ 获取知识库的所有向量 ID Args: conn: 数据库连接 knowledge_base_id: 知识库 ID Returns: List[str]: 向量 ID 列表 """ try: rows = await conn.fetch( """ SELECT vector_id FROM knowledge_base_chunk WHERE knowledge_base_id = $1 AND vector_id IS NOT NULL """, knowledge_base_id ) return [row['vector_id'] for row in rows if row['vector_id']] except Exception as e: logger.error(f"获取知识库向量 ID 失败: {e}") return [] @staticmethod async def delete_file_chunks( conn: asyncpg.Connection, file_id: int ) -> int: """ 删除文件的所有文档块 Args: conn: 数据库连接 file_id: 文件 ID Returns: int: 删除的块数量 """ try: result = await conn.execute( """ DELETE FROM knowledge_base_chunk WHERE file_id = $1 """, file_id ) # 解析删除的行数 deleted_count = int(result.split()[-1]) if result.startswith("DELETE") else 0 logger.info(f"删除文件 {file_id} 的 {deleted_count} 个文档块") return deleted_count except Exception as e: logger.error(f"删除文档块失败: {e}") return 0 @staticmethod async def delete_kb_all_chunks( conn: asyncpg.Connection, knowledge_base_id: int ) -> int: """ 删除知识库的所有文档块 Args: conn: 数据库连接 knowledge_base_id: 知识库 ID Returns: int: 删除的块数量 """ try: result = await conn.execute( """ DELETE FROM knowledge_base_chunk WHERE knowledge_base_id = $1 """, knowledge_base_id ) # 解析删除的行数 deleted_count = int(result.split()[-1]) if result.startswith("DELETE") else 0 logger.info(f"删除知识库 {knowledge_base_id} 的 {deleted_count} 个文档块") return deleted_count except Exception as e: logger.error(f"删除知识库文档块失败: {e}") return 0 @staticmethod async def get_recent_files_with_summary( conn: asyncpg.Connection, knowledge_base_id: int, limit: int = 5 ) -> List[dict]: """ 获取知识库中最近上传的文件及其摘要(无时间限制) Args: conn: 数据库连接 knowledge_base_id: 知识库 ID limit: 返回文件数量 Returns: List[dict]: 文件列表 [{"file_name": str, "summary": str}] """ try: rows = await conn.fetch( """ SELECT DISTINCT ON (kbf.id) kbf.id, kbf.file_name, kbc.summary FROM knowledge_base_file kbf LEFT JOIN knowledge_base_chunk kbc ON kbf.id = kbc.file_id WHERE kbf.knowledge_base_id = $1 AND kbf.is_deleted = FALSE AND kbf.status = 'completed' ORDER BY kbf.id, kbf.created_at DESC LIMIT $2 """, knowledge_base_id, limit ) result = [ { "file_id": row['id'], "file_name": row['file_name'], "summary": row['summary'] or "" } for row in rows ] logger.info(f"获取知识库 {knowledge_base_id} 的 {len(result)} 个文件及摘要(无时间限制)") return result except Exception as e: logger.error(f"获取知识库文件摘要失败: {e}") return [] @staticmethod async def get_file_chunks_from_db( conn: asyncpg.Connection, file_id: int ) -> List[dict]: """ 从 PostgreSQL 获取文件的所有 chunks(包括 summary) 用于注入完整内容到 AI 上下文 Args: conn: 数据库连接 file_id: 文件 ID Returns: List[dict]: [{"content": str, "summary": str, "chunk_index": int}] """ try: rows = await conn.fetch( """ SELECT chunk_index, content, summary FROM knowledge_base_chunk WHERE file_id = $1 ORDER BY chunk_index """, file_id ) chunks = [ { "chunk_index": row['chunk_index'], "content": row['content'], "summary": row['summary'] or '' } for row in rows ] logger.info(f"从数据库获取知识库文件chunks: file_id={file_id}, chunks数量={len(chunks)}") return chunks except Exception as e: logger.error(f"从数据库获取知识库文件chunks失败: {e}") return [] @staticmethod async def delete_file( conn: asyncpg.Connection, file_id: int, user_id: int ) -> Tuple[bool, List[str]]: """ 删除文件(软删除) 同时删除文件的所有文档块 Args: conn: 数据库连接 file_id: 文件 ID user_id: 用户 ID Returns: Tuple[bool, List[str]]: (是否删除成功, 向量 ID 列表) """ try: # 先检查文件是否存在且属于该用户 file_record = await conn.fetchrow( """ SELECT id, knowledge_base_id, file_name FROM knowledge_base_file WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE """, file_id, user_id ) if not file_record: return False, [] # 获取文件的向量 ID 列表(在删除 chunk 之前获取) vector_ids = await KnowledgeBaseFileService.get_file_vector_ids(conn, file_id) # 删除文件的所有文档块(物理删除) deleted_chunks = await KnowledgeBaseFileService.delete_file_chunks(conn, file_id) # 执行软删除文件记录 result = await conn.execute( """ UPDATE knowledge_base_file SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE """, file_id, user_id ) if result == "UPDATE 1": logger.info( f"删除文件 ID: {file_id}, 文件名: {file_record['file_name']}, " f"文档块数: {deleted_chunks}, 向量数: {len(vector_ids)}" ) return True, vector_ids return False, [] except Exception as e: logger.error(f"删除文件失败: {e}") return False, []