""" 知识库 API 路由模块 处理知识库的 CRUD 操作。 """ import os import shutil from pathlib import Path import asyncpg from fastapi import APIRouter, Depends, HTTPException, status, Query from core.dependencies import get_db, get_current_user from core.exceptions import NotFoundError, BadRequestError from models.user import User from models.knowledge_base import ( KnowledgeBaseCreate, KnowledgeBaseUpdate, KnowledgeBaseResponse, KnowledgeBaseListResponse ) from services.knowledge_base_service import KnowledgeBaseService from services.knowledge_base_file_service import KnowledgeBaseFileService from services.vector_service import get_vector_service from services.oss_service import get_oss_service from utils.helpers import BaseResponse from logger.logging import get_logger logger = get_logger(__name__) # 创建知识库路由 kb_router = APIRouter(prefix="/api/knowledge-base", tags=["知识库"]) # 文件上传目录 UPLOAD_DIR = "./uploads" @kb_router.post("", response_model=BaseResponse, summary="创建知识库") async def create_knowledge_base( kb_data: KnowledgeBaseCreate, current_user: User = Depends(get_current_user), conn: asyncpg.Connection = Depends(get_db) ): """创建知识库""" try: kb = await KnowledgeBaseService.create_knowledge_base(conn, current_user, kb_data) payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user) return BaseResponse( code=200, msg="创建知识库成功", data=KnowledgeBaseResponse(**payload).model_dump(), ) except ValueError as e: raise BadRequestError(str(e)) @kb_router.get("", response_model=BaseResponse, summary="获取知识库列表") async def get_knowledge_bases( page: int = Query(1, ge=1, description="页码"), page_size: int = Query(20, ge=1, le=100, description="每页数量"), current_user: User = Depends(get_current_user), conn: asyncpg.Connection = Depends(get_db) ): """获取当前用户的知识库列表""" knowledge_bases, total = await KnowledgeBaseService.list_visible_knowledge_bases( conn, current_user, page, page_size ) items = [KnowledgeBaseResponse(**dict(r)) for r in knowledge_bases] return BaseResponse( code=200, msg="获取知识库列表成功", data=KnowledgeBaseListResponse(total=total, items=items).model_dump(), ) @kb_router.get("/{kb_id}", response_model=BaseResponse, summary="获取知识库详情") async def get_knowledge_base( kb_id: int, current_user: User = Depends(get_current_user), conn: asyncpg.Connection = Depends(get_db) ): """获取知识库详情""" kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user) if not kb: raise NotFoundError("知识库") payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user) return BaseResponse( code=200, msg="获取知识库详情成功", data=KnowledgeBaseResponse(**payload).model_dump(), ) @kb_router.put("/{kb_id}", response_model=BaseResponse, summary="更新知识库") async def update_knowledge_base( kb_id: int, kb_data: KnowledgeBaseUpdate, current_user: User = Depends(get_current_user), conn: asyncpg.Connection = Depends(get_db) ): """更新知识库""" try: kb = await KnowledgeBaseService.update_knowledge_base(conn, kb_id, current_user, kb_data) if not kb: raise NotFoundError("知识库") payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user) return BaseResponse( code=200, msg="更新知识库成功", data=KnowledgeBaseResponse(**payload).model_dump(), ) except ValueError as e: raise BadRequestError(str(e)) @kb_router.delete("/{kb_id}", response_model=BaseResponse, summary="删除知识库") async def delete_knowledge_base( kb_id: int, current_user: User = Depends(get_current_user), conn: asyncpg.Connection = Depends(get_db) ): """ 删除知识库(软删除) 同时删除知识库的所有文件、向量和物理文件 """ # 检查知识库是否存在 kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user) if not kb: raise NotFoundError("知识库") # 1. 获取知识库的所有文件 all_files = await KnowledgeBaseFileService.get_all_files_by_kb(conn, kb_id) logger.info(f"知识库 {kb_id} 共有 {len(all_files)} 个文件需要删除") # 2. 删除所有物理文件 deleted_files_count = 0 oss_service = get_oss_service() for file in all_files: try: if oss_service.enabled and file.file_path.startswith(('http://', 'https://')): oss_object_name = oss_service.extract_object_name_from_url(file.file_path, kb_id) if oss_object_name and oss_service.delete_file(oss_object_name): deleted_files_count += 1 logger.debug(f"删除 OSS 文件: {oss_object_name}") elif os.path.exists(file.file_path): os.remove(file.file_path) deleted_files_count += 1 logger.debug(f"删除本地文件: {file.file_path}") except Exception as e: logger.warning(f"删除物理文件失败 {file.file_path}: {e}") logger.info(f"已删除 {deleted_files_count} 个物理文件") # 3. 获取所有向量 ID vector_ids = await KnowledgeBaseFileService.get_kb_all_vector_ids(conn, kb_id) # 4. 删除文档块 deleted_chunks = await KnowledgeBaseFileService.delete_kb_all_chunks(conn, kb_id) logger.info(f"已删除知识库 {kb_id} 的 {deleted_chunks} 个文档块") # 5. 删除向量 if vector_ids: try: vector_service = get_vector_service() vector_service.delete_vectors_by_ids(kb_id, vector_ids) logger.info(f"已删除知识库 {kb_id} 的 {len(vector_ids)} 个向量") except Exception as e: logger.warning(f"删除向量库中的向量失败: {e}") # 6. 删除向量库集合 try: vector_service = get_vector_service() vector_service.delete_collection(kb_id) logger.info(f"已删除知识库 {kb_id} 的向量库集合") except Exception as e: logger.warning(f"删除向量库集合失败: {e}") # 7. 删除知识库目录 try: kb_dir = Path(UPLOAD_DIR) / f"kb_{kb_id}" if kb_dir.exists(): shutil.rmtree(kb_dir) logger.info(f"已删除知识库目录: {kb_dir}") except Exception as e: logger.warning(f"删除知识库目录失败: {e}") # 8. 软删除知识库 success = await KnowledgeBaseService.delete_knowledge_base(conn, kb_id, current_user) if not success: raise NotFoundError("知识库") return BaseResponse( code=200, msg=f"删除知识库成功,已删除 {len(all_files)} 个文件、{deleted_chunks} 个文档块和 {len(vector_ids)} 个向量", data={ "id": kb_id, "files_deleted": len(all_files), "chunks_deleted": deleted_chunks, "vectors_deleted": len(vector_ids) } )