722 lines
31 KiB
Python
722 lines
31 KiB
Python
"""
知识库文件 API 路由模块

Knowledge-base file API routes: file/URL upload, listing, detail,
processing status, deletion and semantic search.

处理知识库文件的上传、列表、详情、删除和搜索功能。
"""
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
from urllib.parse import urlparse
|
||
|
||
import asyncpg
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, BackgroundTasks
|
||
from pydantic import BaseModel, Field
|
||
|
||
from core.config import settings
|
||
from core.dependencies import get_db, get_current_user
|
||
from core.database import get_db_pool
|
||
from core.exceptions import NotFoundError, BadRequestError
|
||
from core.permissions import can_delete_file, can_upload_to_kb
|
||
from models.user import User
|
||
from models.knowledge_base_file import FileUploadResponse, FileListResponse, KnowledgeBaseFile
|
||
from services.knowledge_base_service import KnowledgeBaseService
|
||
from services.knowledge_base_file_service import KnowledgeBaseFileService
|
||
from services.audit_service import AuditService
|
||
from services.vector_service import get_vector_service
|
||
from services.oss_service import get_oss_service
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)

# Router for the knowledge-base file endpoints defined in this module.
kb_file_router = APIRouter(prefix="/api/knowledge-base", tags=["知识库文件"])

# Local storage root: uploads land here when OSS is disabled or an OSS upload
# fails. Created at import time so handlers can write into it immediately.
UPLOAD_DIR = "./uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||
|
||
# Supported upload extensions (lower-case, with leading dot) mapped to the
# file_type stored on the file record and passed to the vectorization pipeline.
FILE_TYPE_MAP = {
    '.pdf': 'pdf',
    '.docx': 'docx',
    '.xlsx': 'xlsx',
    '.xls': 'xls',
    '.csv': 'csv',
    '.txt': 'txt',
    '.png': 'png',
    '.jpg': 'jpg',
    '.jpeg': 'jpeg',
    '.bmp': 'bmp',
}

# Derived from FILE_TYPE_MAP so the two tables can never drift apart: the upload
# endpoint checks membership here and then indexes FILE_TYPE_MAP directly, so
# an extension present here but missing there would raise KeyError.
SUPPORTED_EXTENSIONS = set(FILE_TYPE_MAP)
|
||
|
||
|
||
class UrlUploadRequest(BaseModel):
    """Request body of the URL-upload endpoint: a single web page address."""

    # Non-empty page URL; the http(s) scheme is validated by the endpoint itself.
    url: str = Field(..., min_length=1, description="网页 URL")
|
||
|
||
|
||
async def _check_kb_access(conn: asyncpg.Connection, kb_id: int, user: User):
    """Fetch knowledge base ``kb_id`` as seen by ``user``, or raise.

    Visibility (presumably the enterprise-edition scoping the original comment
    mentions) is decided by ``KnowledgeBaseService.get_knowledge_base_by_id``;
    a falsy result from it is treated as "not accessible".

    Returns:
        The knowledge base object returned by the service.

    Raises:
        NotFoundError: when the service returns nothing for this user.
    """
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, user)
    if knowledge_base:
        return knowledge_base
    raise NotFoundError("知识库")
|
||
|
||
|
||
# File types whose summary is generated from a vision-model description.
_KB_IMAGE_FILE_TYPES = frozenset({'png', 'jpg', 'jpeg', 'bmp'})
# Maximum characters of extracted content sent to the summary LLM (~3000-4000 tokens).
_KB_SUMMARY_MAX_CHARS = 10000


def _join_chunk_text(chunks) -> str:
    """Join the text of processed chunks, separated by blank lines.

    Each chunk is unpacked as a 4-tuple whose second item is the text, which is
    how processing results' ``chunks`` are consumed throughout this module.
    """
    return "\n\n".join(content for _, content, _, _ in chunks)


def _truncate_for_summary(text: str, limit: int = _KB_SUMMARY_MAX_CHARS) -> str:
    """Return ``text`` cut to ``limit`` characters, appending "..." when cut."""
    return text[:limit] + "..." if len(text) > limit else text


async def _mark_file_failed(conn, file_id: int) -> None:
    """Set the file's status to "failed" (chunk count 0) without raising.

    Used on background-task error paths: an exception here would escape the
    task and hide the original failure, so it is logged instead.
    """
    try:
        await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
    except Exception as e:
        logger.error(f"更新文件状态为 failed 失败 ID: {file_id}, 错误: {e}")


def _resolve_vision_image_url(oss_service, file_path: str, knowledge_base_id: int) -> str:
    """Return a URL the vision model can fetch for ``file_path``.

    For http(s) paths, try to produce a signed OSS URL valid for one hour (so
    private buckets work); on any failure, fall back to the original path.
    Non-URL paths are returned unchanged.
    """
    if not file_path.startswith(('http://', 'https://')):
        return file_path
    try:
        oss_object_name = oss_service.extract_object_name_from_url(file_path, knowledge_base_id)
        if not oss_object_name:
            logger.warning("无法从 OSS URL 提取对象名称,使用原始 URL")
            return file_path
        signed_url = oss_service.get_signed_url(oss_object_name, expires=3600)
        if not signed_url:
            logger.warning("生成签名 URL 失败,尝试使用原始 URL")
            return file_path
        logger.info("🔐 已生成签名 URL 供视觉模型访问(有效期1小时)")
        return signed_url
    except Exception as e:
        logger.warning(f"生成签名 URL 时出错,使用原始 URL: {e}")
        return file_path


async def _generate_kb_file_summary(
    file_id: int,
    file_path: str,
    knowledge_base_id: int,
    file_type: str,
    chunks,
    oss_service,
):
    """Best-effort summary of a processed file; returns None instead of raising.

    Images: a vision-model description, plus any OCR text found in ``chunks``,
    is summarized. If the vision model returns nothing, the OCR text alone is used.
    Other files: the concatenated chunk text is summarized.

    The input is capped at ``_KB_SUMMARY_MAX_CHARS`` for images as well, since
    OCR text can be arbitrarily long. The LLM call is skipped when there is no
    text at all.
    """
    try:
        from services.summary_service import SummaryService
        from langchain_core.documents import Document

        chunk_text = _join_chunk_text(chunks)

        if file_type.lower() in _KB_IMAGE_FILE_TYPES:
            from services.vision_service import VisionService

            logger.info(f"🎨 使用视觉模型为知识库图片 {file_id} 生成描述")
            vision_image_url = _resolve_vision_image_url(oss_service, file_path, knowledge_base_id)
            vision_prompt = "详细描述图片的内容,包括主要元素、颜色、布局、文字信息等。回答需要详细且准确。"
            vision_description = await VisionService.get_image_description(
                image_url=vision_image_url,
                prompt=vision_prompt,
            )

            if vision_description:
                logger.info("✅ 视觉模型返回结果:")
                logger.info('=' * 60)
                logger.info(f"图片URL: {file_path}")
                logger.info(f"描述内容: {vision_description}")
                logger.info(f"描述长度: {len(vision_description)} 字符")
                logger.info('=' * 60)

                # Combine the vision description with OCR text (when there is any).
                if chunk_text.strip():
                    summary_source = (
                        f"【图片内容描述】\n{vision_description}\n\n【图片文字识别(OCR)】\n{chunk_text}"
                    )
                else:
                    summary_source = f"【图片内容描述】\n{vision_description}"

                logger.info(f"📝 组合内容长度: {len(summary_source)} 字符")
                logger.info(f" - 视觉描述: {len(vision_description)} 字符")
                logger.info(f" - OCR文字: {len(chunk_text)} 字符")
            else:
                logger.warning("⚠️ 视觉模型未返回描述,降级使用 OCR 文字生成摘要")
                summary_source = chunk_text
        else:
            summary_source = chunk_text

        # Keep the LLM input bounded regardless of file type.
        summary_source = _truncate_for_summary(summary_source)
        if not summary_source.strip():
            logger.warning(f"文件 {file_id} 无可提取的文本内容,跳过摘要生成")
            return None

        logger.info(f"正在为文件 {file_id} 生成摘要,内容长度: {len(summary_source)} 字符")
        summary_text = await SummaryService.generate_file_summary(
            [Document(page_content=summary_source)], max_docs=1
        )

        if summary_text:
            logger.info(f"📝 文件 {file_id} 摘要生成成功:")
            logger.info('=' * 60)
            logger.info(f"摘要内容: {summary_text}")
            logger.info('=' * 60)
        else:
            logger.warning(f"文件 {file_id} 摘要生成失败,返回为空")
        return summary_text

    except Exception as e:
        # Summary failure must not affect the main flow.
        import traceback
        logger.error(f"生成文件摘要失败: {e}")
        logger.error(f"错误堆栈: {traceback.format_exc()}")
        return None


async def process_file_background(
    file_id: int,
    file_path: str,
    knowledge_base_id: int,
    file_type: str = "pdf"
):
    """Background task: vectorize an uploaded file, summarize it, and persist the results.

    Steps:
    1. If OSS is enabled and ``file_path`` is an http(s) URL, download the object
       to a temporary local file.
    2. Run ``vector_service.process_document`` on the local path, passing the
       original ``file_path`` as ``source_url``.
    3. Generate a best-effort summary.
    4. Save the chunks, with the summary, to the database.
    5. Mirror the summary into the vector-store metadata.
    6. Set the status to "completed" with the chunk count.

    Any failure sets the status to "failed". A temporary download is always
    removed afterwards.

    Args:
        file_id: File record ID.
        file_path: Stored location; either an OSS URL or a local path.
        knowledge_base_id: Knowledge base ID.
        file_type: File type from ``FILE_TYPE_MAP`` (e.g. "pdf", "png").
    """
    pool = await get_db_pool()
    # NOTE(review): this pooled connection is held for the whole run, including
    # the OSS download and LLM calls; consider acquiring it only around DB writes.
    async with pool.acquire() as conn:
        local_file_path = None
        try:
            logger.info(f"开始后台处理文件 ID: {file_id}, 路径: {file_path}, 类型: {file_type}")

            oss_service = get_oss_service()
            if oss_service.enabled and file_path.startswith(('http://', 'https://')):
                logger.info(f"检测到 OSS URL,开始下载文件: {file_path}")
                oss_object_name = oss_service.extract_object_name_from_url(file_path, knowledge_base_id)
                if not oss_object_name:
                    logger.error(f"无法从 OSS URL 提取对象名称: {file_path}")
                    await _mark_file_failed(conn, file_id)
                    return

                local_file_path = oss_service.download_file(oss_object_name)
                if not local_file_path:
                    logger.error("从 OSS 下载文件失败")
                    await _mark_file_failed(conn, file_id)
                    return

                logger.info(f"文件下载成功: {local_file_path}")
                actual_file_path = local_file_path
            else:
                actual_file_path = file_path

            # Process the document; the original OSS URL is kept as source_url.
            vector_service = get_vector_service()
            result = await vector_service.process_document(
                actual_file_path,
                knowledge_base_id,
                file_type,
                file_id=file_id,
                source_url=file_path,
            )

            if not result.success:
                logger.warning(f"文件处理失败 ID: {file_id}, 原因: {result.error_message}")
                await _mark_file_failed(conn, file_id)
                return

            summary_text = await _generate_kb_file_summary(
                file_id, file_path, knowledge_base_id, file_type, result.chunks, oss_service
            )

            await KnowledgeBaseFileService.save_chunks(
                conn, file_id, knowledge_base_id, result.chunks, summary=summary_text
            )

            if summary_text:
                # Mirror the summary into the vector-store (ChromaDB) metadata.
                # Best-effort: the chunks are already saved, so a sync error must
                # not flip an otherwise successful run to "failed".
                try:
                    synced = vector_service.update_kb_file_summary_in_vectors(
                        knowledge_base_id=knowledge_base_id,
                        file_id=file_id,
                        summary=summary_text,
                    )
                except Exception as e:
                    logger.warning(f"⚠️ ChromaDB metadata 同步 summary 异常: {e}")
                    synced = False
                if synced:
                    logger.info("✅ ChromaDB metadata 已同步 summary")
                else:
                    logger.warning("⚠️ ChromaDB metadata 同步 summary 失败,但不影响主流程")

            await KnowledgeBaseFileService.update_file_status(conn, file_id, "completed", result.chunk_count)

            logger.info(
                f"文件处理完成 ID: {file_id}, 类型: {file_type}, 块数: {result.chunk_count}, "
                f"摘要: {'已生成' if summary_text else '未生成'}"
            )

        except Exception as e:
            import traceback
            logger.error(f"后台处理文件异常 ID: {file_id}, 类型: {file_type}, 错误: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            await _mark_file_failed(conn, file_id)
        finally:
            # Remove the temporary copy downloaded from OSS, if any.
            if local_file_path and os.path.exists(local_file_path):
                try:
                    os.remove(local_file_path)
                    logger.debug(f"已删除临时文件: {local_file_path}")
                except Exception as e:
                    logger.warning(f"删除临时文件失败: {e}")
|
||
|
||
|
||
async def process_url_background(file_id: int, url: str, knowledge_base_id: int):
    """Background task: vectorize a web page URL, summarize it, and persist the results.

    On success, saves the chunks (with a best-effort summary) and sets the status
    to "completed" with the chunk count. Any failure sets the status to "failed".
    Summary errors never fail the task.

    Args:
        file_id: File record ID created for the URL.
        url: The http(s) page URL.
        knowledge_base_id: Knowledge base ID.
    """
    import traceback

    async def _mark_url_failed(conn) -> None:
        # Best-effort status update: raising here would escape the background
        # task and hide the original error.
        try:
            await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
        except Exception as status_err:
            logger.error(f"更新 URL 状态为 failed 失败 ID: {file_id}, 错误: {status_err}")

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        try:
            logger.info(f"开始后台处理 URL ID: {file_id}, URL: {url}")

            vector_service = get_vector_service()
            result = await vector_service.process_url(url, knowledge_base_id)

            if not result.success:
                logger.warning(f"URL 处理失败 ID: {file_id}, 原因: {result.error_message}")
                await _mark_url_failed(conn)
                return

            # Generate the summary (best-effort).
            summary_text = None
            try:
                from services.summary_service import SummaryService
                from langchain_core.documents import Document

                file_content = "\n\n".join(content for _, content, _, _ in result.chunks)

                # Keep the LLM input bounded.
                max_content_length = 10000
                if len(file_content) > max_content_length:
                    file_content = file_content[:max_content_length] + "..."

                if file_content.strip():
                    logger.info(f"正在为 URL {file_id} 生成摘要,内容长度: {len(file_content)} 字符")
                    summary_text = await SummaryService.generate_file_summary(
                        [Document(page_content=file_content)], max_docs=1
                    )
                    if summary_text:
                        logger.info(f"📝 URL {file_id} 摘要生成成功")
                    else:
                        logger.warning(f"URL {file_id} 摘要生成失败,返回为空")
                else:
                    logger.warning(f"URL {file_id} 无可提取的文本内容,跳过摘要生成")

            except Exception as e:
                logger.error(f"生成 URL 摘要失败: {e}")
                logger.error(f"错误堆栈: {traceback.format_exc()}")

            await KnowledgeBaseFileService.save_chunks(
                conn, file_id, knowledge_base_id, result.chunks, summary=summary_text
            )

            # NOTE: the summary is not mirrored into the vector-store metadata for
            # URL sources, because process_url is not given a file_id here.

            await KnowledgeBaseFileService.update_file_status(conn, file_id, "completed", result.chunk_count)

            logger.info(
                f"URL 处理完成 ID: {file_id}, 块数: {result.chunk_count}, "
                f"摘要: {'已生成' if summary_text else '未生成'}"
            )

        except Exception as e:
            logger.error(f"后台处理 URL 异常 ID: {file_id}, 错误: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            await _mark_url_failed(conn)
|
||
|
||
|
||
def _discard_unrecorded_upload(oss_object_name, local_path) -> None:
    """Best-effort removal of storage written by a failed upload request.

    Only called for files that no file record references yet, so removing them
    cannot break an existing record. Errors are logged rather than raised so the
    original failure is what reaches the client.
    """
    if oss_object_name:
        try:
            get_oss_service().delete_file(oss_object_name)
            logger.info(f"已清理未入库的 OSS 文件: {oss_object_name}")
        except Exception as e:
            logger.warning(f"清理 OSS 文件失败: {oss_object_name}, 错误: {e}")
    if local_path:
        try:
            Path(local_path).unlink(missing_ok=True)
            logger.info(f"已清理未入库的本地文件: {local_path}")
        except Exception as e:
            logger.warning(f"清理本地文件失败: {local_path}, 错误: {e}")


@kb_file_router.post("/{kb_id}/upload", response_model=BaseResponse, summary="上传文件到知识库")
async def upload_file(
    kb_id: int,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Upload a file to a knowledge base and schedule background vectorization.

    Flow:
    1. Check access and upload permission.
    2. Validate the name, extension, and size (max 15MB, non-empty).
    3. Store the file in OSS, falling back to the local UPLOAD_DIR.
    4. Moderate images (OSS uploads only).
    5. Create the file record and write an upload audit entry.
    6. Enqueue ``process_file_background``.

    If the request fails before the file record exists, the stored copy is
    removed so nothing is orphaned.

    Raises:
        NotFoundError: the knowledge base is not accessible (re-raised unchanged).
        BadRequestError: permission disabled, invalid file, moderation rejection
            or outage, service ValueError (e.g. duplicate name), or any other
            unexpected failure.
    """
    MAX_FILE_SIZE = 15 * 1024 * 1024  # 15MB

    # Storage written by this request but not yet referenced by a file record.
    pending_oss_object = None
    pending_local_path = None

    try:
        logger.info(f"📤 开始上传文件到知识库 {kb_id}: {file.filename}, 用户: {current_user.username}")

        kb = await _check_kb_access(conn, kb_id, current_user)

        # Upload permission (allow_kb_upload switch).
        if not can_upload_to_kb(current_user, kb):
            raise BadRequestError("您的上传权限已被关闭,请联系部门领导或管理员")

        if not file.filename:
            raise BadRequestError("文件名不能为空")
        # Client-supplied names may contain directory components (either
        # separator); only the final component is used for storage keys/paths.
        safe_filename = Path(file.filename.replace("\\", "/")).name

        file_ext = Path(safe_filename).suffix.lower()
        if file_ext not in SUPPORTED_EXTENSIONS:
            logger.warning(f"❌ 不支持的文件类型: {file_ext}, 文件: {file.filename}")
            raise BadRequestError(
                f"不支持的文件类型: {file_ext},支持的类型: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
            )

        file_type = FILE_TYPE_MAP[file_ext]
        logger.info(f"📋 文件类型识别: {file_ext} -> {file_type}")

        # Read at most one byte past the limit so oversized uploads are rejected
        # without loading the whole body into memory.
        content = await file.read(MAX_FILE_SIZE + 1)
        if len(content) > MAX_FILE_SIZE:
            # Report the real size when the framework exposes it.
            reported_size = getattr(file, "size", None) or len(content)
            reported_mb = reported_size / (1024 * 1024)
            logger.warning(f"❌ 文件大小超限: {reported_mb:.2f}MB (最大 15MB), 文件: {file.filename}")
            raise BadRequestError(f"文件大小超过限制,当前: {reported_mb:.2f}MB,最大允许: 15MB")

        file_size = len(content)
        if file_size == 0:
            raise BadRequestError("文件内容为空")
        file_size_mb = file_size / (1024 * 1024)
        logger.info(f"✅ 文件大小验证通过: {file_size_mb:.2f}MB ({file_size} bytes)")

        # Unique storage name.
        timestamp = int(time.time() * 1000)
        unique_filename = f"{timestamp}_{safe_filename}"
        oss_object_name = f"kb_{kb_id}/{unique_filename}"

        oss_service = get_oss_service()
        file_path = None
        file_url = None

        logger.info(f"☁️ 开始上传文件,OSS 状态: {'已启用' if oss_service.enabled else '未启用'}")

        if oss_service.enabled:
            logger.info(f"☁️ 上传文件到 OSS: {oss_object_name}")
            file_url = oss_service.upload_file_from_bytes(content, oss_object_name, file.filename)
            if file_url:
                file_path = file_url
                pending_oss_object = oss_object_name
                logger.info(f"✅ 文件已上传到 OSS: {file_url}")

                # Image moderation happens before the file record is created.
                # On rejection or error, the OSS object is removed by the handlers below.
                if file_type in ('png', 'jpg', 'jpeg', 'bmp'):
                    from core.dependencies import get_moderation_service
                    from core.exceptions import ModerationError
                    from models.moderation import ModerationDecision

                    moderation_service = await get_moderation_service()
                    if moderation_service and settings.moderation_enabled:
                        logger.info(f"🔍 开始图片审核: {file.filename}")
                        try:
                            moderation = await moderation_service.moderate_image(
                                image_source=file_url,
                                source_type="url",
                                request_id=f"kb_file_{timestamp}"
                            )
                        except ModerationError as e:
                            logger.error(f"❌ 图片审核服务错误: {e}")
                            raise BadRequestError("图片审核服务暂时不可用,请稍后重试") from e

                        if moderation.decision == ModerationDecision.BLOCK:
                            logger.warning(
                                f"❌ 图片审核不通过: {file.filename}, "
                                f"原因: {moderation.message}, "
                                f"标签: {[label.label for label in moderation.labels]}"
                            )
                            raise BadRequestError(
                                moderation.message or "图片包含不当内容,无法上传"
                            )

                        logger.info(
                            f"✅ 图片审核通过: {file.filename}, "
                            f"决策: {moderation.decision.value}"
                        )
            else:
                logger.warning("⚠️ OSS 上传失败,回退到本地存储")
                # NOTE(review): images stored locally skip moderation, because the
                # moderation call above is only made with an OSS URL.

        if not file_path:
            kb_dir = Path(UPLOAD_DIR) / f"kb_{kb_id}"
            kb_dir.mkdir(parents=True, exist_ok=True)
            local_path = kb_dir / unique_filename
            # Registered before writing so that a partial write is also cleaned up.
            pending_local_path = local_path
            with open(local_path, "wb") as f:
                f.write(content)
            file_path = str(local_path)
            logger.info(f"💾 文件已保存到本地: {file_path}")

        logger.info(f"📝 创建文件记录: {file.filename}")
        file_record = await KnowledgeBaseFileService.create_file_record(
            conn, kb_id, current_user.id, file.filename, file_path, file_size, file_type
        )
        # The stored file is now referenced by a record and must never be discarded.
        pending_oss_object = None
        pending_local_path = None
        logger.info(f"✅ 文件记录已创建: ID={file_record.id}, 状态={file_record.status}")

        # Audit log: upload.
        await AuditService.write(
            conn,
            enterprise_id=current_user.enterprise_id or 0,
            actor_id=current_user.id,
            action="upload",
            department_id=current_user.department_id,
            kb_id=kb_id,
            file_id=file_record.id,
            metadata={"file_name": file.filename, "file_size": file_size, "file_type": file_type},
        )

        logger.info(f"🚀 添加后台向量化任务: file_id={file_record.id}, type={file_type}")
        background_tasks.add_task(process_file_background, file_record.id, file_path, kb_id, file_type)

        return BaseResponse(
            code=200,
            msg="文件上传成功,正在处理中",
            data=FileUploadResponse(
                id=file_record.id,
                file_name=file_record.file_name,
                file_size=file_record.file_size,
                status=file_record.status,
                chunk_count=file_record.chunk_count,
                created_at=file_record.created_at,
                file_url=file_url or file_path
            ).dict()
        )

    except (BadRequestError, NotFoundError, HTTPException):
        # Domain/HTTP errors keep their own status code (a missing knowledge
        # base stays a NotFoundError instead of being turned into a 400).
        _discard_unrecorded_upload(pending_oss_object, pending_local_path)
        raise
    except ValueError as e:
        # Business validation errors such as a duplicate file name.
        _discard_unrecorded_upload(pending_oss_object, pending_local_path)
        logger.warning(f"文件上传验证失败: {e}")
        raise BadRequestError(str(e)) from e
    except Exception as e:
        import traceback
        _discard_unrecorded_upload(pending_oss_object, pending_local_path)
        logger.error(f"上传文件失败: {e}")
        logger.error(f"错误堆栈: {traceback.format_exc()}")
        raise BadRequestError(f"上传文件失败: {str(e)}") from e
|
||
|
||
|
||
@kb_file_router.post("/{kb_id}/upload-url", response_model=BaseResponse, summary="上传 URL 到知识库")
async def upload_url(
    kb_id: int,
    background_tasks: BackgroundTasks,
    request: UrlUploadRequest,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Register a web page URL in a knowledge base and schedule its vectorization.

    The file record stores the URL itself as ``file_path``, with size 0 and type
    "url". Its display name is derived from the host and path.

    Raises:
        NotFoundError: the knowledge base is not accessible.
        BadRequestError: upload permission disabled, malformed URL, or a
            business validation ValueError from record creation (e.g. a
            duplicate name).
    """
    kb = await _check_kb_access(conn, kb_id, current_user)

    if not can_upload_to_kb(current_user, kb):
        raise BadRequestError("您的上传权限已被关闭,请联系部门领导或管理员")

    url = request.url.strip()
    if not url.startswith(('http://', 'https://')):
        raise BadRequestError("URL 格式不正确,必须以 http:// 或 https:// 开头")

    parsed_url = urlparse(url)
    # A bare "http://" passes the prefix check but has no host to fetch.
    if not parsed_url.netloc:
        raise BadRequestError("URL 格式不正确,缺少主机名")

    # NOTE(review): the URL is presumably fetched server-side by process_url;
    # nothing here blocks private/internal addresses (SSRF). Consider an
    # allowlist or a resolved-IP check.

    # Display name: host + path with "/" replaced, capped at 200 characters.
    file_name = f"{parsed_url.netloc}{parsed_url.path}".replace('/', '_')[:200] or "webpage"
    file_name = f"{file_name}.url"

    try:
        file_record = await KnowledgeBaseFileService.create_file_record(
            conn, kb_id, current_user.id, file_name, url, 0, "url"
        )
    except ValueError as e:
        # Business validation errors (e.g. a duplicate file name) become a 400
        # instead of an unhandled 500.
        logger.warning(f"URL 上传验证失败: {e}")
        raise BadRequestError(str(e)) from e

    logger.info(f"URL 已记录: {url}, 文件 ID: {file_record.id}")
    background_tasks.add_task(process_url_background, file_record.id, url, kb_id)

    return BaseResponse(
        code=200,
        msg="URL 上传成功,正在处理中",
        data=FileUploadResponse(
            id=file_record.id,
            file_name=file_record.file_name,
            file_size=file_record.file_size,
            status=file_record.status,
            chunk_count=file_record.chunk_count,
            created_at=file_record.created_at,
            file_url=url
        ).dict()
    )
|
||
|
||
|
||
@kb_file_router.get("/{kb_id}/files", response_model=BaseResponse, summary="获取知识库文件列表")
async def get_knowledge_base_files(
    kb_id: int,
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return one page of the files in knowledge base ``kb_id``.

    Access is checked first; ``_check_kb_access`` raises NotFoundError when the
    knowledge base is not visible. Each item exposes the stored ``file_path`` as
    ``file_url`` and the optional ``uploader_name`` column. ``is_mine`` flags
    files whose ``user_id`` equals the caller's ID.
    """
    await _check_kb_access(conn, kb_id, current_user)

    rows, total_count = await KnowledgeBaseFileService.get_files_by_kb(
        conn, kb_id, current_user.id, page, page_size
    )

    serialized = []
    for row in rows:
        created = row.get("created_at")
        serialized.append({
            "id": row["id"],
            "file_name": row["file_name"],
            "file_size": row["file_size"],
            "file_type": row["file_type"],
            "status": row["status"],
            "chunk_count": row["chunk_count"],
            "created_at": created.isoformat() if created else None,
            "file_url": row["file_path"],
            "uploader_name": row.get("uploader_name"),
            "is_mine": row["user_id"] == current_user.id,
        })

    payload = {"total": total_count, "items": serialized}
    return BaseResponse(code=200, msg="获取文件列表成功", data=payload)
|
||
|
||
|
||
@kb_file_router.get("/{kb_id}/files/{file_id}", response_model=BaseResponse, summary="获取文件详情")
async def get_file_detail(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return the details of file ``file_id`` in knowledge base ``kb_id``.

    Raises:
        NotFoundError: the knowledge base is not accessible, the file lookup
            returns nothing, or the file belongs to a different knowledge base.
    """
    await _check_kb_access(conn, kb_id, current_user)

    kb_file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    # A file from another knowledge base is reported as missing, not exposed.
    if not kb_file or kb_file.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    detail = FileUploadResponse(
        id=kb_file.id,
        file_name=kb_file.file_name,
        file_size=kb_file.file_size,
        status=kb_file.status,
        chunk_count=kb_file.chunk_count,
        created_at=kb_file.created_at,
        file_url=kb_file.file_path,
    )
    return BaseResponse(code=200, msg="获取文件详情成功", data=detail.dict())
|
||
|
||
|
||
@kb_file_router.get("/{kb_id}/files/{file_id}/status", response_model=BaseResponse, summary="查询文件处理状态")
async def get_file_processing_status(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Report the processing status of a knowledge-base file (polled by the frontend).

    Returns a dict with:
    - ``id``, ``file_name``, ``file_type``
    - ``status``: "processing", "completed" or "failed"
    - ``chunk_count``: number of processed document chunks
    - ``created_at`` / ``updated_at``: ISO-8601 strings, or None when unset

    Raises:
        NotFoundError: the knowledge base is not accessible, or the file is
            missing or belongs to another knowledge base.
    """
    def _iso(moment):
        return moment.isoformat() if moment else None

    await _check_kb_access(conn, kb_id, current_user)

    kb_file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    if not kb_file or kb_file.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    status_payload = {
        "id": kb_file.id,
        "file_name": kb_file.file_name,
        "file_type": kb_file.file_type,
        "status": kb_file.status,
        "chunk_count": kb_file.chunk_count,
        "created_at": _iso(kb_file.created_at),
        "updated_at": _iso(kb_file.updated_at),
    }
    return BaseResponse(code=200, msg="获取文件状态成功", data=status_payload)
|
||
|
||
|
||
@kb_file_router.delete("/{kb_id}/files/{file_id}", response_model=BaseResponse, summary="删除文件")
async def delete_file(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Delete a knowledge-base file (allowed for its uploader, a leader in scope, or an admin).

    Order:
    1. Delete the DB record; the service returns the file's vector IDs.
    2. Write an audit entry.
    3. Delete the vectors (best-effort).
    4. Delete the stored file in OSS or on local disk (best-effort).

    URL-type records have no stored copy, so step 4 is skipped for them.

    Raises:
        NotFoundError: knowledge base or file not accessible, or deletion found nothing.
        HTTPException(403): ``can_delete_file`` denied the caller.
    """
    kb = await _check_kb_access(conn, kb_id, current_user)

    kb_file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    if not kb_file or kb_file.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    # can_delete_file decides owner / leader / admin permission.
    if not await can_delete_file(conn, current_user, kb_file, kb):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权限删除该文件,仅文件上传者、部门领导或管理员可删除",
        )

    # Admins/leaders deleting someone else's file bypass the service's owner check.
    bypass = current_user.role in ("admin", "leader") and kb_file.user_id != current_user.id
    success, vector_ids = await KnowledgeBaseFileService.delete_file(
        conn, file_id, current_user.id, bypass_owner_check=bypass
    )
    if not success:
        raise NotFoundError("文件")

    # Audit log.
    await AuditService.write(
        conn,
        enterprise_id=current_user.enterprise_id or 0,
        actor_id=current_user.id,
        action="delete",
        target_user_id=kb_file.user_id if kb_file.user_id != current_user.id else None,
        department_id=current_user.department_id,
        kb_id=kb_id,
        file_id=file_id,
        metadata={"file_name": kb_file.file_name, "by_role": current_user.role},
    )

    # Delete the vectors (best-effort; the record is already gone).
    if vector_ids:
        try:
            vector_service = get_vector_service()
            vector_service.delete_vectors_by_ids(kb_id, vector_ids)
            logger.info(f"已删除 {len(vector_ids)} 个向量")
        except Exception as e:
            logger.warning(f"删除向量库中的向量失败: {e}")

    # Delete the stored file. URL records keep the external web page address
    # in file_path: there is nothing stored to remove, and that address must
    # not be handed to the OSS object-name extractor / delete call.
    if kb_file.file_type != "url" and kb_file.file_path:
        try:
            oss_service = get_oss_service()
            if oss_service.enabled and kb_file.file_path.startswith(('http://', 'https://')):
                oss_object_name = oss_service.extract_object_name_from_url(kb_file.file_path, kb_id)
                if oss_object_name:
                    oss_service.delete_file(oss_object_name)
                    logger.info(f"已删除 OSS 文件: {oss_object_name}")
            elif os.path.exists(kb_file.file_path):
                os.remove(kb_file.file_path)
                logger.info(f"已删除本地文件: {kb_file.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败: {e}")

    return BaseResponse(code=200, msg="删除文件成功", data={"id": file_id})
|
||
|
||
|
||
@kb_file_router.post("/{kb_id}/search", response_model=BaseResponse, summary="在知识库中搜索")
async def search_in_knowledge_base(
    kb_id: int,
    query: str = Query(..., description="搜索查询"),
    k: int = Query(5, ge=1, le=20, description="返回结果数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Run a semantic search over knowledge base ``kb_id``.

    The query is stripped of surrounding whitespace. Blank queries are rejected
    rather than sent to the vector store. Returns the searched query, the
    results from ``vector_service.search_similar``, and their count.

    Raises:
        NotFoundError: the knowledge base is not accessible.
        BadRequestError: the query is empty or whitespace-only.
    """
    await _check_kb_access(conn, kb_id, current_user)

    query_text = query.strip()
    if not query_text:
        raise BadRequestError("搜索查询不能为空")

    vector_service = get_vector_service()
    results = vector_service.search_similar(kb_id, query_text, k)

    return BaseResponse(
        code=200,
        msg="搜索成功",
        data={"query": query_text, "results": results, "count": len(results)}
    )
|