huoyan-enterprise/backend/api/kb_router.py

205 lines
7.2 KiB
Python

"""
知识库 API 路由模块
处理知识库的 CRUD 操作。
"""
import os
import shutil
from pathlib import Path
import asyncpg
from fastapi import APIRouter, Depends, HTTPException, status, Query
from core.dependencies import get_db, get_current_user
from core.exceptions import NotFoundError, BadRequestError
from models.user import User
from models.knowledge_base import (
KnowledgeBaseCreate,
KnowledgeBaseUpdate,
KnowledgeBaseResponse,
KnowledgeBaseListResponse
)
from services.knowledge_base_service import KnowledgeBaseService
from services.knowledge_base_file_service import KnowledgeBaseFileService
from services.vector_service import get_vector_service
from services.oss_service import get_oss_service
from utils.helpers import BaseResponse
from logger.logging import get_logger
logger = get_logger(__name__)

# Router collecting every knowledge-base endpoint; all routes below are
# registered under the "/api/knowledge-base" prefix.
kb_router = APIRouter(
    prefix="/api/knowledge-base",
    tags=["知识库"],
)

# Local root directory for uploaded files. The path is relative, so it is
# resolved against the process working directory; per-KB subdirectories are
# named "kb_<id>".
UPLOAD_DIR: str = "./uploads"
@kb_router.post("", response_model=BaseResponse, summary="创建知识库")
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Create a knowledge base on behalf of the current user.

    The new record is passed through ``enrich_kb_for_response`` and returned
    as a dumped ``KnowledgeBaseResponse`` inside a ``BaseResponse``.

    Raises:
        BadRequestError: when a ``ValueError`` is raised while creating,
            enriching or serializing the knowledge base. The original error
            is chained as ``__cause__``, so it stays visible in tracebacks.
    """
    try:
        kb = await KnowledgeBaseService.create_knowledge_base(conn, current_user, kb_data)
        payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user)
        return BaseResponse(
            code=200,
            msg="创建知识库成功",
            data=KnowledgeBaseResponse(**payload).model_dump(),
        )
    except ValueError as e:
        # NOTE(review): pydantic's ValidationError subclasses ValueError, so a
        # response-model validation failure is also reported as a 400 here.
        raise BadRequestError(str(e)) from e
@kb_router.get("", response_model=BaseResponse, summary="获取知识库列表")
async def get_knowledge_bases(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return one page of the knowledge bases visible to the current user.

    ``page`` is 1-based and ``page_size`` is limited to 1..100 by the query
    validators. Each returned row is converted to a ``KnowledgeBaseResponse``.
    The page is wrapped in a ``KnowledgeBaseListResponse`` together with the
    total count reported by the service.
    """
    rows, total_count = await KnowledgeBaseService.list_visible_knowledge_bases(
        conn, current_user, page, page_size
    )
    page_items = [KnowledgeBaseResponse(**dict(row)) for row in rows]
    listing = KnowledgeBaseListResponse(total=total_count, items=page_items)
    return BaseResponse(code=200, msg="获取知识库列表成功", data=listing.model_dump())
@kb_router.get("/{kb_id}", response_model=BaseResponse, summary="获取知识库详情")
async def get_knowledge_base(
    kb_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return the enriched detail view of a single knowledge base.

    Raises:
        NotFoundError: when the service returns nothing for ``kb_id`` and the
            current user.
    """
    record = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not record:
        raise NotFoundError("知识库")
    enriched = await KnowledgeBaseService.enrich_kb_for_response(conn, record, current_user)
    detail = KnowledgeBaseResponse(**enriched)
    return BaseResponse(code=200, msg="获取知识库详情成功", data=detail.model_dump())
@kb_router.put("/{kb_id}", response_model=BaseResponse, summary="更新知识库")
async def update_knowledge_base(
    kb_id: int,
    kb_data: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Apply ``kb_data`` to knowledge base ``kb_id`` and return the enriched result.

    Raises:
        NotFoundError: when the update service returns nothing for ``kb_id``.
        BadRequestError: when a ``ValueError`` is raised while updating,
            enriching or serializing. The original error is chained as
            ``__cause__``.
    """
    try:
        kb = await KnowledgeBaseService.update_knowledge_base(conn, kb_id, current_user, kb_data)
        if not kb:
            raise NotFoundError("知识库")
        payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user)
        return BaseResponse(
            code=200,
            msg="更新知识库成功",
            data=KnowledgeBaseResponse(**payload).model_dump(),
        )
    except ValueError as e:
        # Chain explicitly so the underlying validation failure is preserved
        # as the direct cause rather than as implicit context.
        raise BadRequestError(str(e)) from e
def _delete_physical_files(files, kb_id: int) -> int:
    """Delete the stored object behind each file record, best effort.

    A record whose ``file_path`` is an http(s) URL is removed through the OSS
    service, but only when OSS is enabled. Any other existing path is removed
    from local disk with ``os.remove``. Per-file failures are logged and
    skipped.

    Args:
        files: File records exposing ``file_path``, as returned by
            ``KnowledgeBaseFileService.get_all_files_by_kb``.
        kb_id: Knowledge-base id, passed to the OSS URL parser.

    Returns:
        The number of physical files actually removed.
    """
    oss_service = get_oss_service()
    removed = 0
    for file in files:
        try:
            if oss_service.enabled and file.file_path.startswith(("http://", "https://")):
                oss_object_name = oss_service.extract_object_name_from_url(file.file_path, kb_id)
                if oss_object_name and oss_service.delete_file(oss_object_name):
                    removed += 1
                    logger.debug(f"删除 OSS 文件: {oss_object_name}")
            elif os.path.exists(file.file_path):
                os.remove(file.file_path)
                removed += 1
                logger.debug(f"删除本地文件: {file.file_path}")
        except Exception as e:
            # Best effort: one bad path must not abort the whole KB deletion.
            logger.warning(f"删除物理文件失败 {file.file_path}: {e}")
    return removed


def _delete_vector_data(kb_id: int, vector_ids) -> None:
    """Remove the KB's vectors (if any), then its vector collection, best effort.

    Each step is attempted independently; failures are logged as warnings.
    """
    if vector_ids:
        try:
            vector_service = get_vector_service()
            vector_service.delete_vectors_by_ids(kb_id, vector_ids)
            logger.info(f"已删除知识库 {kb_id}{len(vector_ids)} 个向量")
        except Exception as e:
            logger.warning(f"删除向量库中的向量失败: {e}")
    try:
        vector_service = get_vector_service()
        vector_service.delete_collection(kb_id)
        logger.info(f"已删除知识库 {kb_id} 的向量库集合")
    except Exception as e:
        logger.warning(f"删除向量库集合失败: {e}")


def _remove_kb_upload_dir(kb_id: int) -> None:
    """Recursively remove the local directory ``UPLOAD_DIR/kb_<id>``, best effort."""
    try:
        kb_dir = Path(UPLOAD_DIR) / f"kb_{kb_id}"
        if kb_dir.exists():
            shutil.rmtree(kb_dir)
            logger.info(f"已删除知识库目录: {kb_dir}")
    except Exception as e:
        logger.warning(f"删除知识库目录失败: {e}")


@kb_router.delete("/{kb_id}", response_model=BaseResponse, summary="删除知识库")
async def delete_knowledge_base(
    kb_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Soft-delete a knowledge base and clean up everything stored for it.

    Cleanup covers:
    - physical files (OSS or local);
    - document chunks;
    - vectors and the vector collection;
    - the local ``kb_<id>`` upload directory.

    Storage and vector cleanup is best effort, and failures are logged.
    Chunk deletion is not wrapped, so its errors propagate.

    Raises:
        NotFoundError: when the KB is not returned for this user, or when the
            soft delete reports failure. Both checks run before any
            irreversible cleanup.
    """
    kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not kb:
        raise NotFoundError("知识库")

    # Snapshot what must be cleaned up while the rows still exist.
    all_files = await KnowledgeBaseFileService.get_all_files_by_kb(conn, kb_id)
    logger.info(f"知识库 {kb_id} 共有 {len(all_files)} 个文件需要删除")
    vector_ids = await KnowledgeBaseFileService.get_kb_all_vector_ids(conn, kb_id)

    # Soft-delete before destroying anything. A falsy result raises
    # NotFoundError, which previously happened only after files, chunks and
    # vectors were already gone.
    success = await KnowledgeBaseService.delete_knowledge_base(conn, kb_id, current_user)
    if not success:
        raise NotFoundError("知识库")

    deleted_files_count = _delete_physical_files(all_files, kb_id)
    logger.info(f"已删除 {deleted_files_count} 个物理文件")

    # NOTE(review): assumes delete_kb_all_chunks does not skip soft-deleted KBs.
    deleted_chunks = await KnowledgeBaseFileService.delete_kb_all_chunks(conn, kb_id)
    logger.info(f"已删除知识库 {kb_id}{deleted_chunks} 个文档块")

    _delete_vector_data(kb_id, vector_ids)
    _remove_kb_upload_dir(kb_id)

    return BaseResponse(
        code=200,
        msg=f"删除知识库成功,已删除 {len(all_files)} 个文件、{deleted_chunks} 个文档块和 {len(vector_ids)} 个向量",
        data={
            "id": kb_id,
            "files_deleted": len(all_files),
            "chunks_deleted": deleted_chunks,
            "vectors_deleted": len(vector_ids),
        },
    )