huoyan-enterprise/backend/api/kb_processing_router.py

318 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识加工 API 路由模块
处理知识库文件的加工任务,包括合并、对比、总结等功能。
"""
from typing import Optional
import asyncpg
from fastapi import APIRouter, Depends, BackgroundTasks, Query
from core.dependencies import get_db, get_current_user
from core.database import get_db_pool
from core.exceptions import NotFoundError, BadRequestError
from models.user import User
from models.knowledge_processing import (
TaskCreateRequest,
TaskResponse,
TaskListResponse,
TaskStatusResponse,
TaskStatus
)
from services.knowledge_base_service import KnowledgeBaseService
from services.knowledge_processing_service import (
KnowledgeProcessingService,
KnowledgeProcessingExecutor
)
from utils.helpers import BaseResponse
from logger.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Router for the knowledge-processing endpoints; every route below is
# registered under the /api/knowledge-base prefix.
kb_processing_router = APIRouter(prefix="/api/knowledge-base", tags=["知识加工"])
async def process_task_background(task_id: int):
    """
    Background job: execute one knowledge-processing task.

    Acquires a connection from the pool returned by ``get_db_pool()``, loads
    the task row, marks it PROCESSING, hands it to
    ``KnowledgeProcessingExecutor.process_task`` and then records COMPLETED
    or FAILED depending on the reported outcome. Any exception raised along
    the way is logged and the task is marked FAILED on a best-effort basis.

    Args:
        task_id: ID of the row in ``knowledge_processing_task`` to process.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as conn:
        try:
            logger.info(f"开始后台处理知识加工任务 ID: {task_id}")

            # Load the task row.
            row = await conn.fetchrow(
                """
                SELECT id, user_id, knowledge_base_id, task_name, instruction, file_ids,
                task_type, status, result, result_file_url, error_message,
                created_at, updated_at, started_at, completed_at
                FROM knowledge_processing_task
                WHERE id = $1
                """,
                task_id
            )
            if not row:
                logger.error(f"任务 {task_id} 不存在")
                return

            from models.knowledge_processing import KnowledgeProcessingTask

            task_model = KnowledgeProcessingTask(**dict(row))

            # Flag the task as in progress before the executor runs.
            await KnowledgeProcessingService.update_task_status(
                conn, task_id, TaskStatus.PROCESSING
            )

            # Run the task; the executor returns a 4-tuple describing the outcome.
            outcome = await KnowledgeProcessingExecutor.process_task(conn, task_model)
            ok, output, failure_reason, file_url = outcome

            # Persist the final status.
            if not ok:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED, error_message=failure_reason
                )
                logger.error(f"任务 {task_id} 处理失败: {failure_reason}")
                return

            await KnowledgeProcessingService.update_task_status(
                conn, task_id, TaskStatus.COMPLETED,
                result=output, result_file_url=file_url
            )
            logger.info(f"任务 {task_id} 处理成功,文件链接: {file_url}")
        except Exception as exc:
            import traceback

            logger.error(f"后台处理任务异常 ID: {task_id}, 错误: {exc}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            # Best effort: mark the task as failed; if even that fails, only log it.
            try:
                await KnowledgeProcessingService.update_task_status(
                    conn, task_id, TaskStatus.FAILED, error_message=str(exc)
                )
            except Exception as status_exc:
                logger.error(f"更新任务状态失败: {status_exc}")
@kb_processing_router.post("/{kb_id}/processing/tasks", response_model=BaseResponse, summary="创建知识加工任务")
async def create_processing_task(
    kb_id: int,
    task_data: TaskCreateRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Create a knowledge-processing task.

    The user selects one or more files of the knowledge base and supplies a
    processing instruction; the task is stored and then scheduled as a
    background task, so the response is returned before processing finishes.

    Supported task types (as documented for this endpoint):
    - merge: merge files
    - compare: compare files
    - summary: summarize files
    - custom: custom instruction

    Args:
        kb_id: ID of the knowledge base the task belongs to.
        task_data: Request body describing the task to create.
        background_tasks: FastAPI background-task queue used to schedule processing.
        current_user: Authenticated user (injected).
        conn: Database connection (injected).

    Returns:
        BaseResponse whose ``data`` is the serialized ``TaskResponse`` of the
        newly created task.

    Raises:
        NotFoundError: The knowledge base lookup returned nothing.
        BadRequestError: Task creation raised ``ValueError`` or failed for
            any other reason.
    """
    # Verify the knowledge base exists. This check is deliberately outside the
    # try block below: inside it, the generic ``except Exception`` handler
    # would swallow NotFoundError and re-raise it as BadRequestError.
    kb = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not kb:
        raise NotFoundError("知识库")

    try:
        # Create the task record.
        task = await KnowledgeProcessingService.create_task(
            conn, current_user.id, kb_id, task_data
        )
        # Schedule the background processing job.
        logger.info(f"添加后台加工任务: task_id={task.id}, type={task.task_type}")
        background_tasks.add_task(process_task_background, task.id)
        return BaseResponse(
            code=200,
            msg="任务创建成功,正在处理中",
            data=TaskResponse(
                id=task.id,
                task_name=task.task_name,
                instruction=task.instruction,
                file_ids=task.file_ids,
                task_type=task.task_type.value,
                status=task.status.value,
                result=task.result,
                result_file_url=task.result_file_url,
                error_message=task.error_message,
                created_at=task.created_at,
                updated_at=task.updated_at,
                started_at=task.started_at,
                completed_at=task.completed_at
            ).dict()
        )
    except (NotFoundError, BadRequestError):
        # Project exceptions raised by the service layer already carry the
        # intended meaning; pass them through untouched.
        raise
    except ValueError as e:
        raise BadRequestError(str(e)) from e
    except Exception as e:
        logger.error(f"创建知识加工任务失败: {e}")
        raise BadRequestError(f"创建任务失败: {str(e)}") from e
@kb_processing_router.get("/{kb_id}/processing/tasks", response_model=BaseResponse, summary="获取知识加工任务列表")
async def get_processing_tasks(
    kb_id: int,
    status: Optional[str] = Query(None, description="任务状态筛选"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    List the processing tasks of a knowledge base for the current user.

    Args:
        kb_id: ID of the knowledge base.
        status: Optional status filter, forwarded to the service as-is.
        page: 1-based page number.
        page_size: Number of items per page (1-100).
        current_user: Authenticated user (injected).
        conn: Database connection (injected).

    Returns:
        BaseResponse whose ``data`` is a serialized ``TaskListResponse``.

    Raises:
        NotFoundError: The knowledge base lookup returned nothing.
    """
    # The knowledge base must exist before its tasks are listed.
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Fetch the requested page together with the total count.
    task_page, total_count = await KnowledgeProcessingService.get_user_tasks(
        conn, current_user.id, kb_id, status, page, page_size
    )

    # Serialize each task into a plain dict.
    serialized = []
    for entry in task_page:
        response_item = TaskResponse(
            id=entry.id,
            task_name=entry.task_name,
            instruction=entry.instruction,
            file_ids=entry.file_ids,
            task_type=entry.task_type.value,
            status=entry.status.value,
            result=entry.result,
            result_file_url=entry.result_file_url,
            error_message=entry.error_message,
            created_at=entry.created_at,
            updated_at=entry.updated_at,
            started_at=entry.started_at,
            completed_at=entry.completed_at
        )
        serialized.append(response_item.dict())

    listing = TaskListResponse(total=total_count, items=serialized)
    return BaseResponse(
        code=200,
        msg="获取任务列表成功",
        data=listing.dict()
    )
@kb_processing_router.get("/{kb_id}/processing/tasks/{task_id}", response_model=BaseResponse, summary="获取任务详情")
async def get_task_detail(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Return the full details of one knowledge-processing task.

    Args:
        kb_id: ID of the knowledge base the task must belong to.
        task_id: ID of the task.
        current_user: Authenticated user (injected).
        conn: Database connection (injected).

    Returns:
        BaseResponse whose ``data`` is the serialized ``TaskResponse``.

    Raises:
        NotFoundError: The knowledge base lookup returned nothing, the task
            lookup returned nothing, or the task belongs to a different
            knowledge base.
    """
    # The knowledge base must exist.
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Look the task up for the current user.
    found = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not found:
        raise NotFoundError("任务")
    # A task attached to another knowledge base is reported as missing.
    if found.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    detail = TaskResponse(
        id=found.id,
        task_name=found.task_name,
        instruction=found.instruction,
        file_ids=found.file_ids,
        task_type=found.task_type.value,
        status=found.status.value,
        result=found.result,
        result_file_url=found.result_file_url,
        error_message=found.error_message,
        created_at=found.created_at,
        updated_at=found.updated_at,
        started_at=found.started_at,
        completed_at=found.completed_at
    )
    return BaseResponse(
        code=200,
        msg="获取任务详情成功",
        data=detail.dict()
    )
@kb_processing_router.get("/{kb_id}/processing/tasks/{task_id}/status", response_model=BaseResponse, summary="查询任务处理状态")
async def get_task_status(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Report the processing status of a task (intended for front-end polling).

    Args:
        kb_id: ID of the knowledge base the task must belong to.
        task_id: ID of the task.
        current_user: Authenticated user (injected).
        conn: Database connection (injected).

    Returns:
        BaseResponse whose ``data`` is a serialized ``TaskStatusResponse``:
        - id: task ID
        - status: pending / processing / completed / failed
        - result: processing result (expected to be set once completed)
        - result_file_url: link to the result file, if any
        - error_message: error details (expected to be set once failed)
        - updated_at: last update time
        - started_at: start time
        - completed_at: completion time

    Raises:
        NotFoundError: The knowledge base lookup returned nothing, the task
            lookup returned nothing, or the task belongs to a different
            knowledge base.
    """
    # The knowledge base must exist.
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Look the task up for the current user.
    found = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not found:
        raise NotFoundError("任务")
    # A task attached to another knowledge base is reported as missing.
    if found.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    snapshot = TaskStatusResponse(
        id=found.id,
        status=found.status.value,
        result=found.result,
        result_file_url=found.result_file_url,
        error_message=found.error_message,
        updated_at=found.updated_at,
        started_at=found.started_at,
        completed_at=found.completed_at
    )
    return BaseResponse(
        code=200,
        msg="获取任务状态成功",
        data=snapshot.dict()
    )
@kb_processing_router.delete("/{kb_id}/processing/tasks/{task_id}", response_model=BaseResponse, summary="删除任务")
async def delete_task(
    kb_id: int,
    task_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Delete a knowledge-processing task.

    Args:
        kb_id: ID of the knowledge base the task must belong to.
        task_id: ID of the task to delete.
        current_user: Authenticated user (injected).
        conn: Database connection (injected).

    Returns:
        BaseResponse whose ``data`` is ``{"id": task_id}``.

    Raises:
        NotFoundError: The knowledge base lookup returned nothing, the task
            lookup returned nothing, the task belongs to a different
            knowledge base, or the service reported that the delete did not
            succeed.
    """
    # The knowledge base must exist.
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, current_user)
    if not knowledge_base:
        raise NotFoundError("知识库")

    # Look the task up for the current user.
    found = await KnowledgeProcessingService.get_task_by_id(conn, task_id, current_user.id)
    if not found:
        raise NotFoundError("任务")
    # A task attached to another knowledge base is reported as missing.
    if found.knowledge_base_id != kb_id:
        raise NotFoundError("任务")

    # Delete the task; a falsy result from the service is reported as missing.
    removed = await KnowledgeProcessingService.delete_task(conn, task_id, current_user.id)
    if removed:
        return BaseResponse(code=200, msg="删除任务成功", data={"id": task_id})
    raise NotFoundError("任务")