205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
"""
知识库 API 路由模块

处理知识库的 CRUD 操作。
"""
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import asyncpg
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
|
|
from core.dependencies import get_db, get_current_user
|
|
from core.exceptions import NotFoundError, BadRequestError
|
|
from models.user import User
|
|
from models.knowledge_base import (
|
|
KnowledgeBaseCreate,
|
|
KnowledgeBaseUpdate,
|
|
KnowledgeBaseResponse,
|
|
KnowledgeBaseListResponse
|
|
)
|
|
from services.knowledge_base_service import KnowledgeBaseService
|
|
from services.knowledge_base_file_service import KnowledgeBaseFileService
|
|
from services.vector_service import get_vector_service
|
|
from services.oss_service import get_oss_service
|
|
from utils.helpers import BaseResponse
|
|
from logger.logging import get_logger
|
|
|
|
logger = get_logger(__name__)

# Router that groups every knowledge-base endpoint under one URL prefix/tag.
kb_router = APIRouter(tags=["知识库"], prefix="/api/knowledge-base")

# Directory for uploaded knowledge-base files. Relative path, so it resolves
# against the process working directory.
UPLOAD_DIR = "./uploads"
|
|
|
|
|
|
@kb_router.post("", response_model=BaseResponse, summary="创建知识库")
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Create a knowledge base on behalf of the current user.

    Args:
        kb_data: Request body parsed as ``KnowledgeBaseCreate``.
        current_user: Authenticated user (injected dependency).
        conn: Database connection (injected dependency).

    Returns:
        ``BaseResponse`` whose ``data`` is the dumped ``KnowledgeBaseResponse``
        built from ``KnowledgeBaseService.enrich_kb_for_response``.

    Raises:
        BadRequestError: If ``KnowledgeBaseService.create_knowledge_base``
            raises ``ValueError``.
    """
    try:
        kb = await KnowledgeBaseService.create_knowledge_base(conn, current_user, kb_data)
    except ValueError as e:
        # Only the service's ValueErrors are client errors. Pydantic's
        # ValidationError also subclasses ValueError, so building the response
        # model stays outside this try. Otherwise a server-side schema mismatch
        # would be reported as a 400 with its internals leaked.
        raise BadRequestError(str(e)) from e

    payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user)
    return BaseResponse(
        code=200,
        msg="创建知识库成功",
        data=KnowledgeBaseResponse(**payload).model_dump(),
    )
|
|
|
|
|
|
@kb_router.get("", response_model=BaseResponse, summary="获取知识库列表")
async def get_knowledge_bases(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return one page of knowledge bases for the current user.

    Args:
        page: 1-based page number (validated ``>= 1``).
        page_size: Items per page (validated to ``1..100``, default 20).
        current_user: Authenticated user (injected dependency).
        conn: Database connection (injected dependency).

    Returns:
        ``BaseResponse`` whose ``data`` is a dumped ``KnowledgeBaseListResponse``
        carrying the total count and the rows from
        ``KnowledgeBaseService.list_visible_knowledge_bases`` for this page.
    """
    rows, total = await KnowledgeBaseService.list_visible_knowledge_bases(
        conn, current_user, page, page_size
    )
    listing = KnowledgeBaseListResponse(
        total=total,
        items=[KnowledgeBaseResponse(**dict(row)) for row in rows],
    )
    return BaseResponse(code=200, msg="获取知识库列表成功", data=listing.model_dump())
|
|
|
|
|
|
@kb_router.get("/{kb_id}", response_model=BaseResponse, summary="获取知识库详情")
async def get_knowledge_base(
    kb_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return the details of a single knowledge base.

    Args:
        kb_id: Knowledge base ID from the URL path.
        current_user: Authenticated user (injected dependency).
        conn: Database connection (injected dependency).

    Returns:
        ``BaseResponse`` whose ``data`` is the dumped ``KnowledgeBaseResponse``.

    Raises:
        NotFoundError: If ``get_knowledge_base_by_id`` yields nothing for this user.
    """
    record = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not record:
        raise NotFoundError("知识库")

    enriched = await KnowledgeBaseService.enrich_kb_for_response(conn, record, current_user)
    detail = KnowledgeBaseResponse(**enriched)
    return BaseResponse(code=200, msg="获取知识库详情成功", data=detail.model_dump())
|
|
|
|
|
|
@kb_router.put("/{kb_id}", response_model=BaseResponse, summary="更新知识库")
async def update_knowledge_base(
    kb_id: int,
    kb_data: KnowledgeBaseUpdate,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Update a knowledge base and return its refreshed representation.

    Args:
        kb_id: Knowledge base ID from the URL path.
        kb_data: Request body parsed as ``KnowledgeBaseUpdate``.
        current_user: Authenticated user (injected dependency).
        conn: Database connection (injected dependency).

    Returns:
        ``BaseResponse`` whose ``data`` is the dumped ``KnowledgeBaseResponse``.

    Raises:
        BadRequestError: If ``KnowledgeBaseService.update_knowledge_base``
            raises ``ValueError``.
        NotFoundError: If the service returns nothing (KB missing for this user).
    """
    try:
        kb = await KnowledgeBaseService.update_knowledge_base(conn, kb_id, current_user, kb_data)
    except ValueError as e:
        # Keep the try limited to the service call. Pydantic's ValidationError
        # subclasses ValueError, so response-model failures must not be
        # turned into a 400 here.
        raise BadRequestError(str(e)) from e

    if not kb:
        raise NotFoundError("知识库")

    payload = await KnowledgeBaseService.enrich_kb_for_response(conn, kb, current_user)
    return BaseResponse(
        code=200,
        msg="更新知识库成功",
        data=KnowledgeBaseResponse(**payload).model_dump(),
    )
|
|
|
|
|
|
def _delete_physical_files(all_files, kb_id: int) -> int:
    """Best-effort removal of each file's stored object; return how many were removed.

    When OSS is enabled and ``file_path`` is an http(s) URL, the object is
    deleted from OSS. Otherwise ``file_path`` is removed from the local
    filesystem if it exists. Per-file failures are logged and skipped.
    """
    oss_service = get_oss_service()
    deleted = 0
    for kb_file in all_files:
        try:
            if oss_service.enabled and kb_file.file_path.startswith(("http://", "https://")):
                oss_object_name = oss_service.extract_object_name_from_url(kb_file.file_path, kb_id)
                if oss_object_name and oss_service.delete_file(oss_object_name):
                    deleted += 1
                    logger.debug(f"删除 OSS 文件: {oss_object_name}")
            elif os.path.exists(kb_file.file_path):
                os.remove(kb_file.file_path)
                deleted += 1
                logger.debug(f"删除本地文件: {kb_file.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败 {kb_file.file_path}: {e}")
    return deleted


def _delete_vector_data(kb_id: int, vector_ids: list) -> None:
    """Best-effort deletion of the KB's vectors, then of its vector collection.

    Each step has its own try so a failure in one does not skip the other;
    failures are logged as warnings.
    """
    if vector_ids:
        try:
            vector_service = get_vector_service()
            vector_service.delete_vectors_by_ids(kb_id, vector_ids)
            logger.info(f"已删除知识库 {kb_id} 的 {len(vector_ids)} 个向量")
        except Exception as e:
            logger.warning(f"删除向量库中的向量失败: {e}")

    try:
        vector_service = get_vector_service()
        vector_service.delete_collection(kb_id)
        logger.info(f"已删除知识库 {kb_id} 的向量库集合")
    except Exception as e:
        logger.warning(f"删除向量库集合失败: {e}")


def _delete_kb_directory(kb_id: int) -> None:
    """Best-effort recursive removal of ``UPLOAD_DIR/kb_<kb_id>``; failures are logged."""
    try:
        kb_dir = Path(UPLOAD_DIR) / f"kb_{kb_id}"
        if kb_dir.exists():
            shutil.rmtree(kb_dir)
            logger.info(f"已删除知识库目录: {kb_dir}")
    except Exception as e:
        logger.warning(f"删除知识库目录失败: {e}")


@kb_router.delete("/{kb_id}", response_model=BaseResponse, summary="删除知识库")
async def delete_knowledge_base(
    kb_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Soft-delete a knowledge base and clean up its files, chunks and vectors.

    Ordering matters:
      1. Read the file list and vector IDs.
      2. Delete chunks and soft-delete the KB inside one DB transaction. If the
         soft delete reports failure, the transaction rolls back and nothing is
         destroyed. Previously all data was wiped before this check could fail.
      3. Only then run irreversible external cleanup (stored files, vectors,
         vector collection, local upload directory). This step is best-effort:
         failures are logged, not raised.

    Args:
        kb_id: Knowledge base ID from the URL path.
        current_user: Authenticated user (injected dependency).
        conn: Database connection (injected dependency).

    Returns:
        ``BaseResponse`` with counts of file records, chunks and vector IDs.

    Raises:
        NotFoundError: If the KB is not found for this user, or the soft delete
            returns a falsy result (presumably missing permission; depends on
            ``KnowledgeBaseService.delete_knowledge_base``).
    """
    kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not kb:
        raise NotFoundError("知识库")

    # Collect everything external cleanup needs before mutating the database.
    all_files = await KnowledgeBaseFileService.get_all_files_by_kb(conn, kb_id)
    logger.info(f"知识库 {kb_id} 共有 {len(all_files)} 个文件需要删除")
    vector_ids = await KnowledgeBaseFileService.get_kb_all_vector_ids(conn, kb_id) or []

    # Raising inside this block rolls back the chunk deletion. If the
    # connection is already inside a transaction, asyncpg uses a savepoint.
    async with conn.transaction():
        deleted_chunks = await KnowledgeBaseFileService.delete_kb_all_chunks(conn, kb_id)
        success = await KnowledgeBaseService.delete_knowledge_base(conn, kb_id, current_user)
        if not success:
            raise NotFoundError("知识库")
    logger.info(f"已删除知识库 {kb_id} 的 {deleted_chunks} 个文档块")

    # Irreversible external cleanup, only after the DB changes succeeded.
    # NOTE(review): these helpers make synchronous calls (file/OSS/vector
    # operations) that block the event loop; consider offloading to a thread.
    deleted_files_count = _delete_physical_files(all_files, kb_id)
    logger.info(f"已删除 {deleted_files_count} 个物理文件")
    _delete_vector_data(kb_id, vector_ids)
    _delete_kb_directory(kb_id)

    return BaseResponse(
        code=200,
        msg=f"删除知识库成功,已删除 {len(all_files)} 个文件、{deleted_chunks} 个文档块和 {len(vector_ids)} 个向量",
        data={
            "id": kb_id,
            "files_deleted": len(all_files),
            "chunks_deleted": deleted_chunks,
            "vectors_deleted": len(vector_ids),
        },
    )
|