"""
Knowledge processing API routes.

Handles processing tasks over knowledge-base files, such as merging,
comparing and summarizing files.
"""
import traceback
from typing import Optional

import asyncpg
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException, Query

from core.dependencies import get_db, get_current_user
from core.database import get_db_pool
from core.exceptions import NotFoundError, BadRequestError
from models.user import User
from models.knowledge_processing import (
    KnowledgeProcessingTask,
    TaskCreateRequest,
    TaskResponse,
    TaskListResponse,
    TaskStatusResponse,
    TaskStatus,
)
from services.knowledge_base_service import KnowledgeBaseService
from services.knowledge_processing_service import (
    KnowledgeProcessingService,
    KnowledgeProcessingExecutor,
)
from utils.helpers import BaseResponse
from logger.logging import get_logger

logger = get_logger(__name__)

# Router for knowledge processing endpoints (mounted under /api/knowledge-base).
kb_processing_router = APIRouter(prefix="/api/knowledge-base", tags=["知识加工"])


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _to_task_response(task) -> dict:
    """Serialize a task object into the ``TaskResponse`` payload dict.

    ``task`` is whatever ``KnowledgeProcessingService`` returns; it must expose
    the attributes read below, with ``task_type``/``status`` as enums (their
    ``.value`` is what the API returns).
    """
    return TaskResponse(
        id=task.id,
        task_name=task.task_name,
        instruction=task.instruction,
        file_ids=task.file_ids,
        task_type=task.task_type.value,
        status=task.status.value,
        result=task.result,
        result_file_url=task.result_file_url,
        error_message=task.error_message,
        created_at=task.created_at,
        updated_at=task.updated_at,
        started_at=task.started_at,
        completed_at=task.completed_at,
    ).dict()


async def _ensure_knowledge_base(
    conn: asyncpg.Connection, kb_id: int, current_user: User
):
    """Return the knowledge base, or raise ``NotFoundError`` if the lookup is empty.

    The lookup receives ``current_user``; presumably it also enforces access
    rights (verify in ``KnowledgeBaseService``).
    """
    kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not kb:
        raise NotFoundError("知识库")
    return kb


async def _get_task_in_kb(
    conn: asyncpg.Connection, kb_id: int, task_id: int, user_id: int
):
    """Fetch a task owned by ``user_id`` that belongs to knowledge base ``kb_id``.

    Raises:
        NotFoundError: the task does not exist for this user, or it belongs to
            a different knowledge base than the one in the URL.
    """
    task = await KnowledgeProcessingService.get_task_by_id(conn, task_id, user_id)
    # Also check the KB id so a task cannot be reached through another KB's URL.
    if not task or task.knowledge_base_id != kb_id:
        raise NotFoundError("任务")
    return task


# ---------------------------------------------------------------------------
# Background worker
# ---------------------------------------------------------------------------

async def process_task_background(task_id: int):
    """Run a single knowledge processing task outside the request cycle.

    Scheduled via FastAPI ``BackgroundTasks`` from ``create_processing_task``.
    It acquires its own connection from the pool rather than using the
    request-scoped one. Steps:

    1. Load the task row.
    2. Mark it PROCESSING.
    3. Run ``KnowledgeProcessingExecutor.process_task``.
    4. Store the outcome as COMPLETED (with result and file URL) or FAILED
       (with error message).

    Exceptions raised while processing are logged and, when possible,
    persisted as a FAILED status instead of propagating.

    Args:
        task_id: ID of the row in ``knowledge_processing_task``.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        try:
            logger.info(f"开始后台处理知识加工任务 ID: {task_id}")

            # Load the task row.
            task = await conn.fetchrow(
                """
                SELECT id, user_id, knowledge_base_id, task_name, instruction, file_ids,
                       task_type, status, result, result_file_url, error_message,
                       created_at, updated_at, started_at, completed_at
                FROM knowledge_processing_task
                WHERE id = $1
                """,
                task_id
            )

            if not task:
                logger.error(f"任务 {task_id} 不存在")
                return

            task_obj = KnowledgeProcessingTask(**dict(task))

            # Mark the task as in progress before running the (possibly long) executor.
            await KnowledgeProcessingService.update_task_status(
                conn, task_id, TaskStatus.PROCESSING
            )

            # Run the task.
            success, result, error_message, result_file_url = (
                await KnowledgeProcessingExecutor.process_task(conn, task_obj)
            )

            # Persist the outcome.
            if success:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.COMPLETED,
                    result=result, result_file_url=result_file_url
                )
                logger.info(f"任务 {task_id} 处理成功,文件链接: {result_file_url}")
            else:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED,
                    error_message=error_message
                )
                logger.error(f"任务 {task_id} 处理失败: {error_message}")

        except Exception as e:
            # Top-level boundary of a background job: nothing upstream can
            # handle the error, so log it and try to record the failure.
            logger.error(f"后台处理任务异常 ID: {task_id}, 错误: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            try:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED,
                    error_message=str(e)
                )
            except Exception as update_error:
                logger.error(f"更新任务状态失败: {update_error}")


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@kb_processing_router.post(
    "/{kb_id}/processing/tasks",
    response_model=BaseResponse,
    summary="创建知识加工任务",
    # Explicit description keeps the published OpenAPI text; FastAPI would
    # otherwise derive it from the (English) docstring.
    description=(
        "创建知识加工任务\n\n"
        "用户可以选择知识库中的一个或多个文件,输入加工指令,系统将异步处理任务。\n\n"
        "支持的任务类型:\n"
        "- merge: 合并文件\n"
        "- compare: 对比文件\n"
        "- summary: 总结文件\n"
        "- custom: 自定义指令"
    ),
)
async def create_processing_task(
    kb_id: int,
    task_data: TaskCreateRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Create a knowledge processing task and schedule it in the background.

    The user selects one or more files of the knowledge base plus an
    instruction. The task is stored first; the actual processing is run
    after the response by ``process_task_background``.

    Supported task types: merge, compare, summary, custom.

    Raises:
        NotFoundError: the knowledge base lookup returned nothing.
        BadRequestError: the service raised ``ValueError``, or task creation
            failed unexpectedly.
    """
    try:
        await _ensure_knowledge_base(conn, kb_id, current_user)

        task = await KnowledgeProcessingService.create_task(
            conn, current_user.id, kb_id, task_data
        )

        logger.info(f"添加后台加工任务: task_id={task.id}, type={task.task_type}")
        background_tasks.add_task(process_task_background, task.id)

        return BaseResponse(
            code=200,
            msg="任务创建成功,正在处理中",
            data=_to_task_response(task)
        )
    except (NotFoundError, BadRequestError, HTTPException):
        # Already mapped to a proper HTTP error. Without this clause the
        # catch-all below would turn a missing knowledge base (404) into a 400.
        raise
    except ValueError as e:
        raise BadRequestError(str(e)) from e
    except Exception as e:
        logger.error(f"创建知识加工任务失败: {e}")
        raise BadRequestError(f"创建任务失败: {str(e)}") from e


@kb_processing_router.get(
    "/{kb_id}/processing/tasks",
    response_model=BaseResponse,
    summary="获取知识加工任务列表",
    description="获取知识库的加工任务列表",
)
async def get_processing_tasks(
    kb_id: int,
    status: Optional[str] = Query(None, description="任务状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """List the current user's processing tasks for a knowledge base (paginated).

    ``status`` is forwarded unvalidated to the service as an optional filter.

    Raises:
        NotFoundError: the knowledge base lookup returned nothing.
    """
    await _ensure_knowledge_base(conn, kb_id, current_user)

    tasks, total = await KnowledgeProcessingService.get_user_tasks(
        conn, current_user.id, kb_id, status, page, page_size
    )

    return BaseResponse(
        code=200,
        msg="获取任务列表成功",
        data=TaskListResponse(
            total=total,
            items=[_to_task_response(task) for task in tasks]
        ).dict()
    )


@kb_processing_router.get(
    "/{kb_id}/processing/tasks/{task_id}",
    response_model=BaseResponse,
    summary="获取任务详情",
    description="获取知识加工任务详情",
)
async def get_task_detail(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return full details of one processing task.

    Raises:
        NotFoundError: the knowledge base or the task was not found, or the
            task belongs to a different knowledge base.
    """
    await _ensure_knowledge_base(conn, kb_id, current_user)
    task = await _get_task_in_kb(conn, kb_id, task_id, current_user.id)

    return BaseResponse(
        code=200,
        msg="获取任务详情成功",
        data=_to_task_response(task)
    )


@kb_processing_router.get(
    "/{kb_id}/processing/tasks/{task_id}/status",
    response_model=BaseResponse,
    summary="查询任务处理状态",
    description=(
        "查询知识加工任务的处理状态(用于前端轮询)\n\n"
        "Returns:\n"
        "- id: 任务ID\n"
        "- status: pending(待处理)/ processing(处理中)/ completed(已完成)/ failed(失败)\n"
        "- result: 处理结果(仅在completed时返回)\n"
        "- error_message: 错误信息(仅在failed时返回)\n"
        "- updated_at: 更新时间\n"
        "- started_at: 开始时间\n"
        "- completed_at: 完成时间"
    ),
)
async def get_task_status(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return the lightweight status of a task (intended for frontend polling).

    Response fields:
      - id, status (pending / processing / completed / failed)
      - result, result_file_url: outputs once completed
      - error_message: set when failed
      - updated_at, started_at, completed_at: timestamps

    Raises:
        NotFoundError: the knowledge base or the task was not found, or the
            task belongs to a different knowledge base.
    """
    await _ensure_knowledge_base(conn, kb_id, current_user)
    task = await _get_task_in_kb(conn, kb_id, task_id, current_user.id)

    return BaseResponse(
        code=200,
        msg="获取任务状态成功",
        data=TaskStatusResponse(
            id=task.id,
            status=task.status.value,
            result=task.result,
            result_file_url=task.result_file_url,
            error_message=task.error_message,
            updated_at=task.updated_at,
            started_at=task.started_at,
            completed_at=task.completed_at
        ).dict()
    )


@kb_processing_router.delete(
    "/{kb_id}/processing/tasks/{task_id}",
    response_model=BaseResponse,
    summary="删除任务",
    description="删除知识加工任务",
)
async def delete_task(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Delete a processing task owned by the current user.

    Raises:
        NotFoundError: the knowledge base or the task was not found, the task
            belongs to another knowledge base, or the delete affected nothing.
    """
    await _ensure_knowledge_base(conn, kb_id, current_user)
    await _get_task_in_kb(conn, kb_id, task_id, current_user.id)

    success = await KnowledgeProcessingService.delete_task(conn, task_id, current_user.id)
    if not success:
        # Row vanished between the lookup and the delete (e.g. concurrent delete).
        raise NotFoundError("任务")

    return BaseResponse(code=200, msg="删除任务成功", data={"id": task_id})