318 lines
11 KiB
Python
318 lines
11 KiB
Python
"""
|
||
知识加工 API 路由模块
|
||
|
||
处理知识库文件的加工任务,包括合并、对比、总结等功能。
|
||
"""
|
||
from typing import Optional
|
||
import asyncpg
|
||
from fastapi import APIRouter, Depends, BackgroundTasks, Query
|
||
|
||
from core.dependencies import get_db, get_current_user
|
||
from core.database import get_db_pool
|
||
from core.exceptions import NotFoundError, BadRequestError
|
||
from models.user import User
|
||
from models.knowledge_processing import (
|
||
TaskCreateRequest,
|
||
TaskResponse,
|
||
TaskListResponse,
|
||
TaskStatusResponse,
|
||
TaskStatus
|
||
)
|
||
from services.knowledge_base_service import KnowledgeBaseService
|
||
from services.knowledge_processing_service import (
|
||
KnowledgeProcessingService,
|
||
KnowledgeProcessingExecutor
|
||
)
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# Create the knowledge-processing router
|
||
kb_processing_router = APIRouter(prefix="/api/knowledge-base", tags=["知识加工"])
|
||
|
||
|
||
async def process_task_background(task_id: int):
    """
    Execute one knowledge-processing task as a background job.

    Takes a connection from the pool returned by ``get_db_pool()``, loads the
    task row, marks it as processing, passes it to
    ``KnowledgeProcessingExecutor.process_task`` and stores the final status.
    Any exception is logged and the task is marked as failed; a failure of
    that last status update is only logged.

    Args:
        task_id: Id of the ``knowledge_processing_task`` row to process.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as conn:
        try:
            logger.info(f"开始后台处理知识加工任务 ID: {task_id}")

            # Load the task row
            select_task_sql = """
                SELECT id, user_id, knowledge_base_id, task_name, instruction, file_ids,
                       task_type, status, result, result_file_url, error_message,
                       created_at, updated_at, started_at, completed_at
                FROM knowledge_processing_task
                WHERE id = $1
                """
            row = await conn.fetchrow(select_task_sql, task_id)

            if not row:
                logger.error(f"任务 {task_id} 不存在")
                return

            from models.knowledge_processing import KnowledgeProcessingTask
            task_model = KnowledgeProcessingTask(**dict(row))

            # Flag the task as in progress before running it
            await KnowledgeProcessingService.update_task_status(
                conn, task_id, TaskStatus.PROCESSING
            )

            # Run the task
            outcome = await KnowledgeProcessingExecutor.process_task(conn, task_model)
            ok, output, failure_reason, output_url = outcome

            # Persist the final status
            if not ok:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED, error_message=failure_reason
                )
                logger.error(f"任务 {task_id} 处理失败: {failure_reason}")
            else:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.COMPLETED,
                    result=output, result_file_url=output_url
                )
                logger.info(f"任务 {task_id} 处理成功,文件链接: {output_url}")

        except Exception as exc:
            logger.error(f"后台处理任务异常 ID: {task_id}, 错误: {exc}")
            import traceback
            stack_text = traceback.format_exc()
            logger.error(f"错误堆栈: {stack_text}")

            # Best effort: record the failure on the task itself
            try:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED, error_message=str(exc)
                )
            except Exception as status_exc:
                logger.error(f"更新任务状态失败: {status_exc}")
|
||
|
||
|
||
@kb_processing_router.post("/{kb_id}/processing/tasks", response_model=BaseResponse, summary="创建知识加工任务")
async def create_processing_task(
    kb_id: int,
    task_data: TaskCreateRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Create a knowledge-processing task and schedule it for background execution.

    The user selects one or more files of the knowledge base and provides a
    processing instruction; the task is then processed asynchronously.

    Supported task types:
    - merge: merge files
    - compare: compare files
    - summary: summarize files
    - custom: custom instruction

    Args:
        kb_id: Id of the knowledge base the task belongs to.
        task_data: Request body describing the task to create.
        background_tasks: FastAPI background-task registry used to schedule
            ``process_task_background``.
        current_user: Authenticated user resolved by ``get_current_user``.
        conn: Database connection resolved by ``get_db``.

    Returns:
        A ``BaseResponse`` whose ``data`` is the created task serialized
        through ``TaskResponse``.

    Raises:
        NotFoundError: The knowledge-base lookup returned nothing.
        BadRequestError: Creating or scheduling the task failed.
    """
    # The lookup lives outside the try block on purpose: inside it, the
    # NotFoundError below was caught by the generic ``except Exception``
    # handler and re-raised as BadRequestError, unlike the sibling endpoints.
    kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not kb:
        raise NotFoundError("知识库")

    try:
        # Create the task
        task = await KnowledgeProcessingService.create_task(
            conn, current_user.id, kb_id, task_data
        )

        # Schedule background processing
        logger.info(f"添加后台加工任务: task_id={task.id}, type={task.task_type}")
        background_tasks.add_task(process_task_background, task.id)

        return BaseResponse(
            code=200,
            msg="任务创建成功,正在处理中",
            data=TaskResponse(
                id=task.id,
                task_name=task.task_name,
                instruction=task.instruction,
                file_ids=task.file_ids,
                task_type=task.task_type.value,
                status=task.status.value,
                result=task.result,
                result_file_url=task.result_file_url,
                error_message=task.error_message,
                created_at=task.created_at,
                updated_at=task.updated_at,
                started_at=task.started_at,
                completed_at=task.completed_at
            ).dict()
        )

    except ValueError as e:
        raise BadRequestError(str(e)) from e
    except Exception as e:
        logger.error(f"创建知识加工任务失败: {e}")
        raise BadRequestError(f"创建任务失败: {str(e)}") from e
|
||
|
||
|
||
@kb_processing_router.get("/{kb_id}/processing/tasks", response_model=BaseResponse, summary="获取知识加工任务列表")
async def get_processing_tasks(
    kb_id: int,
    status: Optional[str] = Query(None, description="任务状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    List the processing tasks of a knowledge base.

    Delegates to ``KnowledgeProcessingService.get_user_tasks`` with the
    current user's id, the knowledge-base id, the optional status filter and
    the paging parameters, then wraps the result in a ``TaskListResponse``.

    Raises:
        NotFoundError: The knowledge-base lookup returned nothing.
    """
    # The knowledge base is looked up with the current user; a falsy result
    # is reported as "not found"
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Fetch the requested page together with the total count
    task_page, total_count = await KnowledgeProcessingService.get_user_tasks(
        conn, current_user.id, kb_id, status, page, page_size
    )

    # Serialize every task through the response model
    serialized = []
    for entry in task_page:
        entry_model = TaskResponse(
            id=entry.id,
            task_name=entry.task_name,
            instruction=entry.instruction,
            file_ids=entry.file_ids,
            task_type=entry.task_type.value,
            status=entry.status.value,
            result=entry.result,
            result_file_url=entry.result_file_url,
            error_message=entry.error_message,
            created_at=entry.created_at,
            updated_at=entry.updated_at,
            started_at=entry.started_at,
            completed_at=entry.completed_at
        )
        serialized.append(entry_model.dict())

    listing = TaskListResponse(total=total_count, items=serialized)
    return BaseResponse(code=200, msg="获取任务列表成功", data=listing.dict())
|
||
|
||
|
||
@kb_processing_router.get("/{kb_id}/processing/tasks/{task_id}", response_model=BaseResponse, summary="获取任务详情")
async def get_task_detail(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Return the details of a single knowledge-processing task.

    Raises:
        NotFoundError: The knowledge-base lookup returned nothing, the task
            lookup returned nothing, or the task's ``knowledge_base_id``
            differs from ``kb_id``.
    """
    # Knowledge-base lookup for the current user
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Task lookup, scoped by the current user's id
    record = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not record:
        raise NotFoundError("任务")
    # A task attached to a different knowledge base is treated as missing
    if record.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    fields = {
        "id": record.id,
        "task_name": record.task_name,
        "instruction": record.instruction,
        "file_ids": record.file_ids,
        "task_type": record.task_type.value,
        "status": record.status.value,
        "result": record.result,
        "result_file_url": record.result_file_url,
        "error_message": record.error_message,
        "created_at": record.created_at,
        "updated_at": record.updated_at,
        "started_at": record.started_at,
        "completed_at": record.completed_at,
    }
    detail = TaskResponse(**fields)
    return BaseResponse(code=200, msg="获取任务详情成功", data=detail.dict())
|
||
|
||
|
||
@kb_processing_router.get("/{kb_id}/processing/tasks/{task_id}/status", response_model=BaseResponse, summary="查询任务处理状态")
async def get_task_status(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Report the processing status of a knowledge-processing task (intended for
    front-end polling).

    Returns:
        A ``BaseResponse`` whose ``data`` carries these task fields:
        - id: task id
        - status: the task's status value
        - result: processing result
        - result_file_url: link to the result file
        - error_message: error information
        - updated_at: last update time
        - started_at: start time
        - completed_at: completion time

    NOTE(review): the original notes list the status values as pending /
    processing / completed / failed and say ``result`` is only returned when
    completed and ``error_message`` only when failed. This handler copies the
    fields unconditionally, so confirm that against the service and model.

    Raises:
        NotFoundError: The knowledge-base lookup returned nothing, the task
            lookup returned nothing, or the task's ``knowledge_base_id``
            differs from ``kb_id``.
    """
    # Knowledge-base lookup for the current user
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Task lookup, scoped by the current user's id
    record = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not record:
        raise NotFoundError("任务")
    # A task attached to a different knowledge base is treated as missing
    if record.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    snapshot = TaskStatusResponse(
        id=record.id,
        status=record.status.value,
        result=record.result,
        result_file_url=record.result_file_url,
        error_message=record.error_message,
        updated_at=record.updated_at,
        started_at=record.started_at,
        completed_at=record.completed_at
    )
    return BaseResponse(code=200, msg="获取任务状态成功", data=snapshot.dict())
|
||
|
||
|
||
@kb_processing_router.delete("/{kb_id}/processing/tasks/{task_id}", response_model=BaseResponse, summary="删除任务")
async def delete_task(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Delete a knowledge-processing task.

    Raises:
        NotFoundError: The knowledge-base lookup returned nothing, the task
            lookup returned nothing, the task's ``knowledge_base_id`` differs
            from ``kb_id``, or the service reported that nothing was deleted.
    """
    # Knowledge-base lookup for the current user
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Task lookup, scoped by the current user's id
    record = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not record:
        raise NotFoundError("任务")
    # A task attached to a different knowledge base is treated as missing
    if record.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    # Perform the deletion; a falsy result is reported as "not found"
    deleted = await KnowledgeProcessingService.delete_task(conn, task_id, current_user.id)
    if deleted:
        return BaseResponse(code=200, msg="删除任务成功", data={"id": task_id})

    raise NotFoundError("任务")
|