huoyan-enterprise/backend/api/kb_file_router.py

675 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识库文件 API 路由模块
处理知识库文件的上传、列表、详情、删除和搜索功能。
"""
import os
import time
from pathlib import Path
from urllib.parse import urlparse
import asyncpg
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, BackgroundTasks
from pydantic import BaseModel, Field
from core.config import settings
from core.dependencies import get_db, get_current_user
from core.database import get_db_pool
from core.exceptions import NotFoundError, BadRequestError
from models.user import User
from models.knowledge_base_file import FileUploadResponse, FileListResponse
from services.knowledge_base_service import KnowledgeBaseService
from services.knowledge_base_file_service import KnowledgeBaseFileService
from services.vector_service import get_vector_service
from services.oss_service import get_oss_service
from utils.helpers import BaseResponse
from logger.logging import get_logger
# Module-level logger shared by every handler and background task below.
logger = get_logger(__name__)

# Router collecting all knowledge-base file endpoints under a common prefix.
kb_file_router = APIRouter(prefix="/api/knowledge-base", tags=["知识库文件"])

# Local upload directory; it is created at import time if it does not exist.
UPLOAD_DIR = "./uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Maps each accepted extension (lower-case, leading dot) to its file type name.
FILE_TYPE_MAP = {
    f".{kind}": kind
    for kind in ('pdf', 'docx', 'xlsx', 'xls', 'csv', 'txt', 'png', 'jpg', 'jpeg', 'bmp')
}

# The supported extensions are exactly the keys of FILE_TYPE_MAP.
SUPPORTED_EXTENSIONS = set(FILE_TYPE_MAP)
class UrlUploadRequest(BaseModel):
    """Request body for adding a web page to a knowledge base by URL."""

    # Address of the web page; an empty string is rejected by min_length=1.
    url: str = Field(..., min_length=1, description="网页 URL")
async def _check_kb_access(conn: asyncpg.Connection, kb_id: int, user: User):
    """Return knowledge base ``kb_id`` as looked up on behalf of ``user``.

    The lookup is delegated to
    ``KnowledgeBaseService.get_knowledge_base_by_id``; a falsy result raises
    ``NotFoundError``. Per the original note this serves as the access
    (enterprise visibility) check for the endpoints in this module.
    """
    knowledge_base = await KnowledgeBaseService.get_knowledge_base_by_id(conn, kb_id, user)
    if knowledge_base:
        return knowledge_base
    raise NotFoundError("知识库")
async def process_file_background(
    file_id: int,
    file_path: str,
    knowledge_base_id: int,
    file_type: str = "pdf"
):
    """
    Background task: vectorize a stored file, summarize it and persist the result.

    Steps:
      1. If OSS is enabled and ``file_path`` is an http(s) URL, download the
         object to a temporary local file (removed again in ``finally``).
      2. Run ``vector_service.process_document`` on the resulting path.
      3. Generate a summary (best effort; on failure the summary stays None).
      4. Save the chunks, push the summary to the vector store through
         ``update_kb_file_summary_in_vectors`` and mark the file ``completed``.

    Any failure marks the file ``failed`` with a chunk count of 0.

    Args:
        file_id: ID of the file record.
        file_path: Stored location, either an OSS URL or a local path.
        knowledge_base_id: ID of the owning knowledge base.
        file_type: File type name such as ``"pdf"`` or ``"png"``.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        local_file_path = None
        try:
            logger.info(f"开始后台处理文件 ID: {file_id}, 路径: {file_path}, 类型: {file_type}")
            oss_service = get_oss_service()
            if oss_service.enabled and file_path.startswith(('http://', 'https://')):
                logger.info(f"检测到 OSS URL开始下载文件: {file_path}")
                oss_object_name = oss_service.extract_object_name_from_url(file_path, knowledge_base_id)
                if not oss_object_name:
                    logger.error(f"无法从 OSS URL 提取对象名称: {file_path}")
                    await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
                    return
                local_file_path = oss_service.download_file(oss_object_name)
                if not local_file_path:
                    logger.error("从 OSS 下载文件失败")
                    await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
                    return
                logger.info(f"文件下载成功: {local_file_path}")
                actual_file_path = local_file_path
            else:
                actual_file_path = file_path

            # source_url always carries the stored location (the OSS URL when
            # the file lives in OSS), not the temporary download path.
            vector_service = get_vector_service()
            result = await vector_service.process_document(
                actual_file_path,
                knowledge_base_id,
                file_type,
                file_id=file_id,
                source_url=file_path
            )
            if not result.success:
                logger.warning(f"文件处理失败 ID: {file_id}, 原因: {result.error_message}")
                await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
                return

            # Summary generation is best effort and never aborts the task.
            summary_text = await _generate_file_summary(
                oss_service, file_id, file_path, knowledge_base_id, file_type, result.chunks
            )

            await KnowledgeBaseFileService.save_chunks(
                conn, file_id, knowledge_base_id, result.chunks, summary=summary_text
            )

            # Keep the vector store metadata in sync with the stored summary.
            if summary_text:
                success = vector_service.update_kb_file_summary_in_vectors(
                    knowledge_base_id=knowledge_base_id,
                    file_id=file_id,
                    summary=summary_text
                )
                if success:
                    logger.info("✅ ChromaDB metadata 已同步 summary")
                else:
                    logger.warning("⚠️ ChromaDB metadata 同步 summary 失败，但不影响主流程")

            await KnowledgeBaseFileService.update_file_status(conn, file_id, "completed", result.chunk_count)
            logger.info(f"文件处理完成 ID: {file_id}, 类型: {file_type}, 块数: {result.chunk_count}, 摘要: {'已生成' if summary_text else '未生成'}")
        except Exception as e:
            import traceback
            logger.error(f"后台处理文件异常 ID: {file_id}, 类型: {file_type}, 错误: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            await _mark_file_failed(conn, file_id)
        finally:
            # Only the temporary OSS download is removed; a file stored
            # locally by the upload endpoint is never touched here.
            if local_file_path and os.path.exists(local_file_path):
                try:
                    os.remove(local_file_path)
                    logger.debug(f"已删除临时文件: {local_file_path}")
                except Exception as e:
                    logger.warning(f"删除临时文件失败: {e}")


async def _mark_file_failed(conn, file_id: int) -> None:
    """Set the file status to ``failed`` without letting the update itself raise.

    Used from the task's last-resort exception handler, where a second
    exception (for example a broken connection) would otherwise escape the
    background task.
    """
    try:
        await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
    except Exception as status_error:
        logger.error(f"更新文件状态为 failed 失败 ID: {file_id}, 错误: {status_error}")


def _join_chunk_text(chunks) -> str:
    """Join the text of all chunks (second item of each 4-tuple) with blank lines."""
    return "\n\n".join([content for _, content, _, _ in chunks])


def _resolve_vision_image_url(oss_service, file_path: str, knowledge_base_id: int) -> str:
    """Return a URL for the vision model, preferring a signed OSS URL valid for 3600 s.

    Falls back to ``file_path`` when it is not an http(s) URL, when the object
    name cannot be extracted, or when signing fails or raises.
    """
    if not file_path.startswith(('http://', 'https://')):
        return file_path
    try:
        oss_object_name = oss_service.extract_object_name_from_url(file_path, knowledge_base_id)
        if not oss_object_name:
            logger.warning("无法从 OSS URL 提取对象名称，使用原始 URL")
            return file_path
        signed_url = oss_service.get_signed_url(oss_object_name, expires=3600)
        if signed_url:
            logger.info("🔐 已生成签名 URL 供视觉模型访问有效期1小时")
            return signed_url
        logger.warning("生成签名 URL 失败，尝试使用原始 URL")
    except Exception as e:
        logger.warning(f"生成签名 URL 时出错，使用原始 URL: {e}")
    return file_path


async def _build_image_summary_content(
    oss_service, file_id: int, file_path: str, knowledge_base_id: int, chunks
) -> str:
    """Build the text to summarize for an image: vision description plus OCR text.

    When the vision model returns nothing, only the OCR text from the chunks
    is used.
    """
    from services.vision_service import VisionService

    logger.info(f"🎨 使用视觉模型为知识库图片 {file_id} 生成描述")
    vision_image_url = _resolve_vision_image_url(oss_service, file_path, knowledge_base_id)
    vision_prompt = "详细描述图片的内容，包括主要元素、颜色、布局、文字信息等。回答需要详细且准确。"
    vision_description = await VisionService.get_image_description(
        image_url=vision_image_url,
        prompt=vision_prompt
    )
    if not vision_description:
        logger.warning("⚠️ 视觉模型未返回描述，降级使用 OCR 文字生成摘要")
        return _join_chunk_text(chunks)

    logger.info("✅ 视觉模型返回结果:")
    logger.info(f"{'='*60}")
    logger.info(f"图片URL: {file_path}")
    logger.info(f"描述内容: {vision_description}")
    logger.info(f"描述长度: {len(vision_description)} 字符")
    logger.info(f"{'='*60}")

    ocr_text = _join_chunk_text(chunks)
    if ocr_text.strip():
        combined_content = f"【图片内容描述】\n{vision_description}\n\n【图片文字识别OCR\n{ocr_text}"
    else:
        combined_content = f"【图片内容描述】\n{vision_description}"
    logger.info(f"📝 组合内容长度: {len(combined_content)} 字符")
    logger.info(f" - 视觉描述: {len(vision_description)} 字符")
    logger.info(f" - OCR文字: {len(ocr_text)} 字符")
    return combined_content


async def _generate_file_summary(
    oss_service, file_id: int, file_path: str, knowledge_base_id: int, file_type: str, chunks
):
    """Return a summary for the processed file, or None when none could be produced.

    Image types (png/jpg/jpeg/bmp) are summarized from a vision description
    combined with OCR text; every other type from the chunk text capped at
    10000 characters. All exceptions are logged and swallowed so that a
    summary failure does not fail the processing task.
    """
    summary_text = None
    try:
        # Imported lazily inside the try so that an import failure is treated
        # like any other summary failure.
        from services.summary_service import SummaryService
        from langchain_core.documents import Document

        if file_type.lower() in {'png', 'jpg', 'jpeg', 'bmp'}:
            content = await _build_image_summary_content(
                oss_service, file_id, file_path, knowledge_base_id, chunks
            )
        else:
            content = _join_chunk_text(chunks)
            # Cap the text to stay within the summarizer's input limit.
            max_content_length = 10000
            if len(content) > max_content_length:
                content = content[:max_content_length] + "..."
            logger.info(f"正在为文件 {file_id} 生成摘要，内容长度: {len(content)} 字符")

        docs = [Document(page_content=content)]
        summary_text = await SummaryService.generate_file_summary(docs, max_docs=1)

        if summary_text:
            logger.info(f"📝 文件 {file_id} 摘要生成成功:")
            logger.info(f"{'='*60}")
            logger.info(f"摘要内容: {summary_text}")
            logger.info(f"{'='*60}")
        else:
            logger.warning(f"文件 {file_id} 摘要生成失败，返回为空")
    except Exception as e:
        import traceback
        logger.error(f"生成文件摘要失败: {e}")
        logger.error(f"错误堆栈: {traceback.format_exc()}")
    return summary_text
async def process_url_background(file_id: int, url: str, knowledge_base_id: int):
    """
    Background task: vectorize a web page and persist its chunks and summary.

    Runs ``vector_service.process_url``, generates a best-effort summary from
    the chunk text (capped at 10000 characters), saves the chunks and marks
    the file ``completed``. Any failure marks the file ``failed`` with a chunk
    count of 0.

    Args:
        file_id: ID of the file record created for the URL.
        url: Address of the web page to process.
        knowledge_base_id: ID of the owning knowledge base.
    """
    import traceback

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        try:
            logger.info(f"开始后台处理 URL ID: {file_id}, URL: {url}")
            vector_service = get_vector_service()
            result = await vector_service.process_url(url, knowledge_base_id)
            if not result.success:
                logger.warning(f"URL 处理失败 ID: {file_id}, 原因: {result.error_message}")
                await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
                return

            # Summary generation is best effort and never aborts the task.
            summary_text = None
            try:
                from services.summary_service import SummaryService
                from langchain_core.documents import Document

                file_content = "\n\n".join([content for _, content, _, _ in result.chunks])
                # Cap the text to stay within the summarizer's input limit.
                max_content_length = 10000
                if len(file_content) > max_content_length:
                    file_content = file_content[:max_content_length] + "..."
                logger.info(f"正在为 URL {file_id} 生成摘要，内容长度: {len(file_content)} 字符")
                docs = [Document(page_content=file_content)]
                summary_text = await SummaryService.generate_file_summary(docs, max_docs=1)
                if summary_text:
                    logger.info(f"📝 URL {file_id} 摘要生成成功")
                else:
                    logger.warning(f"URL {file_id} 摘要生成失败，返回为空")
            except Exception as e:
                logger.error(f"生成 URL 摘要失败: {e}")
                logger.error(f"错误堆栈: {traceback.format_exc()}")

            await KnowledgeBaseFileService.save_chunks(
                conn, file_id, knowledge_base_id, result.chunks, summary=summary_text
            )

            # The summary is intentionally not pushed to the vector store
            # metadata: process_url is called without a file_id, and the
            # original note says URL entries do not support file_id yet.
            await KnowledgeBaseFileService.update_file_status(conn, file_id, "completed", result.chunk_count)
            logger.info(f"URL 处理完成 ID: {file_id}, 块数: {result.chunk_count}, 摘要: {'已生成' if summary_text else '未生成'}")
        except Exception as e:
            logger.error(f"后台处理 URL 异常 ID: {file_id}, 错误: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            # The status update may fail for the same reason the task failed
            # (for example a broken connection); do not let it escape the task.
            try:
                await KnowledgeBaseFileService.update_file_status(conn, file_id, "failed", 0)
            except Exception as status_error:
                logger.error(f"更新文件状态为 failed 失败 ID: {file_id}, 错误: {status_error}")
@kb_file_router.post("/{kb_id}/upload", response_model=BaseResponse, summary="上传文件到知识库")
async def upload_file(
    kb_id: int,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Store an uploaded file, create its record and schedule vectorization.

    The file is validated (extension, 15MB size limit), stored in OSS when
    enabled (falling back to the local upload directory), moderated when it
    is an image stored in OSS, recorded in the database and handed to
    ``process_file_background``.

    Raises:
        NotFoundError: propagated unchanged from ``_check_kb_access``.
        BadRequestError: unsupported type, oversized file, rejected or
            unavailable image moderation, a ``ValueError`` from lower layers
            (e.g. duplicate file name per the original note), or any other
            unexpected failure.
    """
    try:
        logger.info(f"📤 开始上传文件到知识库 {kb_id}: {file.filename}, 用户: {current_user.username}")
        await _check_kb_access(conn, kb_id, current_user)

        # The file name is client-controlled: keep only its final path
        # component so it cannot inject directories (or "..") into the local
        # path or the OSS object key. Backslashes are normalized first so
        # Windows-style paths are reduced as well. A missing name becomes ""
        # and is rejected by the extension check below.
        safe_name = Path((file.filename or "").replace("\\", "/")).name

        file_ext = Path(safe_name).suffix.lower()
        if file_ext not in SUPPORTED_EXTENSIONS:
            logger.warning(f"❌ 不支持的文件类型: {file_ext}, 文件: {file.filename}")
            # sorted() keeps the message stable; set iteration order is not.
            raise BadRequestError(f"不支持的文件类型: {file_ext}，支持的类型: {', '.join(sorted(SUPPORTED_EXTENSIONS))}")
        file_type = FILE_TYPE_MAP[file_ext]
        logger.info(f"📋 文件类型识别: {file_ext} -> {file_type}")

        content = await file.read()
        file_size = len(content)
        file_size_mb = file_size / (1024 * 1024)

        MAX_FILE_SIZE = 15 * 1024 * 1024  # 15MB
        if file_size > MAX_FILE_SIZE:
            logger.warning(f"❌ 文件大小超限: {file_size_mb:.2f}MB (最大 15MB), 文件: {file.filename}")
            raise BadRequestError(f"文件大小超过限制，当前: {file_size_mb:.2f}MB最大允许: 15MB")
        logger.info(f"✅ 文件大小验证通过: {file_size_mb:.2f}MB ({file_size} bytes)")

        # A millisecond timestamp prefix keeps stored names unique.
        timestamp = int(time.time() * 1000)
        unique_filename = f"{timestamp}_{safe_name}"
        oss_object_name = f"kb_{kb_id}/{unique_filename}"

        oss_service = get_oss_service()
        file_path = None
        file_url = None
        logger.info(f"☁️ 开始上传文件OSS 状态: {'已启用' if oss_service.enabled else '未启用'}")
        if oss_service.enabled:
            logger.info(f"☁️ 上传文件到 OSS: {oss_object_name}")
            file_url = oss_service.upload_file_from_bytes(content, oss_object_name, safe_name)
            if file_url:
                file_path = file_url
                logger.info(f"✅ 文件已上传到 OSS: {file_url}")

                # Images are moderated before any record is created.
                # NOTE(review): moderation only runs for files stored in OSS;
                # an image saved through the local fallback is not moderated.
                if file_type in ['png', 'jpg', 'jpeg', 'bmp']:
                    from core.dependencies import get_moderation_service
                    from core.exceptions import ModerationError
                    from models.moderation import ModerationDecision

                    moderation_service = await get_moderation_service()
                    if moderation_service and settings.moderation_enabled:
                        try:
                            logger.info(f"🔍 开始图片审核: {file.filename}")
                            result = await moderation_service.moderate_image(
                                image_source=file_url,
                                source_type="url",
                                request_id=f"kb_file_{timestamp}"
                            )
                            if result.decision == ModerationDecision.BLOCK:
                                # Rejected content must not stay in OSS.
                                oss_service.delete_file(oss_object_name)
                                logger.warning(
                                    f"❌ 图片审核不通过: {file.filename}, "
                                    f"原因: {result.message}, "
                                    f"标签: {[label.label for label in result.labels]}"
                                )
                                raise BadRequestError(
                                    result.message or "图片包含不当内容，无法上传"
                                )
                            logger.info(
                                f"✅ 图片审核通过: {file.filename}, "
                                f"决策: {result.decision.value}"
                            )
                        except ModerationError as e:
                            # Moderation is unavailable: drop the object and refuse the upload.
                            oss_service.delete_file(oss_object_name)
                            logger.error(f"❌ 图片审核服务错误: {e}")
                            raise BadRequestError("图片审核服务暂时不可用，请稍后重试")
            else:
                logger.warning("⚠️ OSS 上传失败，回退到本地存储")

        if not file_path:
            kb_dir = Path(UPLOAD_DIR) / f"kb_{kb_id}"
            kb_dir.mkdir(parents=True, exist_ok=True)
            local_path = kb_dir / unique_filename
            with open(local_path, "wb") as f:
                f.write(content)
            file_path = str(local_path)
            logger.info(f"💾 文件已保存到本地: {file_path}")

        logger.info(f"📝 创建文件记录: {file.filename}")
        try:
            file_record = await KnowledgeBaseFileService.create_file_record(
                conn, kb_id, current_user.id, safe_name, file_path, file_size, file_type
            )
        except Exception:
            # No record references the stored bytes, so remove them instead
            # of leaving an orphan behind, then report the original error.
            _discard_stored_upload(oss_service, oss_object_name, file_url, file_path)
            raise
        logger.info(f"✅ 文件记录已创建: ID={file_record.id}, 状态={file_record.status}")

        logger.info(f"🚀 添加后台向量化任务: file_id={file_record.id}, type={file_type}")
        background_tasks.add_task(process_file_background, file_record.id, file_path, kb_id, file_type)
        return BaseResponse(
            code=200,
            msg="文件上传成功，正在处理中",
            data=FileUploadResponse(
                id=file_record.id,
                file_name=file_record.file_name,
                file_size=file_record.file_size,
                status=file_record.status,
                chunk_count=file_record.chunk_count,
                created_at=file_record.created_at,
                file_url=file_url or file_path
            ).dict()
        )
    except (BadRequestError, NotFoundError):
        # Already the right client-facing errors. NotFoundError is listed so
        # a missing knowledge base is not rewritten into a 400 by the
        # catch-all below, matching the other endpoints in this module.
        raise
    except ValueError as e:
        # Business validation errors such as a duplicate file name.
        logger.warning(f"文件上传验证失败: {e}")
        raise BadRequestError(str(e))
    except Exception as e:
        logger.error(f"上传文件失败: {e}")
        raise BadRequestError(f"上传文件失败: {str(e)}")


def _discard_stored_upload(oss_service, oss_object_name, file_url, local_path) -> None:
    """Best-effort removal of a freshly stored upload that ended up without a record.

    Deletes the OSS object when the upload went to OSS (``file_url`` set),
    otherwise the local file. Failures are only logged.
    """
    try:
        if file_url:
            oss_service.delete_file(oss_object_name)
        elif local_path and os.path.exists(local_path):
            os.remove(local_path)
    except Exception as cleanup_error:
        logger.warning(f"清理已上传文件失败: {cleanup_error}")
@kb_file_router.post("/{kb_id}/upload-url", response_model=BaseResponse, summary="上传 URL 到知识库")
async def upload_url(
    kb_id: int,
    background_tasks: BackgroundTasks,
    request: UrlUploadRequest,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Register a web page URL in a knowledge base and schedule its vectorization.

    The record's file name is derived from the URL host and path (slashes
    replaced by underscores, capped at 200 characters) plus a ``.url`` suffix.

    Raises:
        NotFoundError: propagated from ``_check_kb_access``.
        BadRequestError: the URL is not http(s), cannot be parsed, has no
            host, or the record is rejected with a ``ValueError``.
    """
    await _check_kb_access(conn, kb_id, current_user)

    url = request.url.strip()
    if not url.startswith(('http://', 'https://')):
        raise BadRequestError("URL 格式不正确，必须以 http:// 或 https:// 开头")

    # urlparse raises ValueError for a malformed netloc (e.g. an unbalanced
    # "[" in an IPv6 literal); report that as a client error, not a 500.
    try:
        parsed_url = urlparse(url)
    except ValueError:
        raise BadRequestError("URL 格式不正确，无法解析")
    # Inputs such as "http://" pass the prefix check but name no host.
    if not parsed_url.hostname:
        raise BadRequestError("URL 格式不正确，缺少主机名")

    # NOTE(review): the URL is fetched server-side by the background task;
    # confirm that vector_service.process_url refuses internal/private
    # addresses (SSRF), since nothing here restricts the target host.

    file_name = f"{parsed_url.netloc}{parsed_url.path}".replace('/', '_')[:200]
    if not file_name:
        file_name = "webpage"
    file_name = f"{file_name}.url"

    # The URL itself is stored as the record's path; size is 0, type "url".
    try:
        file_record = await KnowledgeBaseFileService.create_file_record(
            conn, kb_id, current_user.id, file_name, url, 0, "url"
        )
    except ValueError as e:
        # Same translation as upload_file, whose original note attributes
        # ValueError to business errors such as a duplicate file name.
        logger.warning(f"URL 上传验证失败: {e}")
        raise BadRequestError(str(e))
    logger.info(f"URL 已记录: {url}, 文件 ID: {file_record.id}")

    background_tasks.add_task(process_url_background, file_record.id, url, kb_id)
    return BaseResponse(
        code=200,
        msg="URL 上传成功，正在处理中",
        data=FileUploadResponse(
            id=file_record.id,
            file_name=file_record.file_name,
            file_size=file_record.file_size,
            status=file_record.status,
            chunk_count=file_record.chunk_count,
            created_at=file_record.created_at
        ).dict()
    )
@kb_file_router.get("/{kb_id}/files", response_model=BaseResponse, summary="获取知识库文件列表")
async def get_knowledge_base_files(
    kb_id: int,
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return one page of the files of a knowledge base together with the total count."""
    await _check_kb_access(conn, kb_id, current_user)

    records, total = await KnowledgeBaseFileService.get_files_by_kb(
        conn, kb_id, current_user.id, page, page_size
    )

    # Convert every record into the plain-dict form of FileUploadResponse.
    items = []
    for record in records:
        entry = FileUploadResponse(
            id=record.id,
            file_name=record.file_name,
            file_size=record.file_size,
            status=record.status,
            chunk_count=record.chunk_count,
            created_at=record.created_at,
            file_url=record.file_path
        )
        items.append(entry.dict())

    listing = FileListResponse(total=total, items=items)
    return BaseResponse(code=200, msg="获取文件列表成功", data=listing.dict())
@kb_file_router.get("/{kb_id}/files/{file_id}", response_model=BaseResponse, summary="获取文件详情")
async def get_file_detail(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Return the details of one file; raise NotFoundError if it is missing or in another knowledge base."""
    await _check_kb_access(conn, kb_id, current_user)

    record = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    if not record:
        raise NotFoundError("文件")
    # A file that exists but belongs to a different knowledge base is
    # reported exactly like a missing one.
    if record.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    detail = FileUploadResponse(
        id=record.id,
        file_name=record.file_name,
        file_size=record.file_size,
        status=record.status,
        chunk_count=record.chunk_count,
        created_at=record.created_at,
        file_url=record.file_path
    )
    return BaseResponse(code=200, msg="获取文件详情成功", data=detail.dict())
@kb_file_router.get("/{kb_id}/files/{file_id}/status", response_model=BaseResponse, summary="查询文件处理状态")
async def get_file_processing_status(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """
    Return the processing status of a knowledge base file (meant for polling).

    The response data contains:
        - id: file ID
        - file_name: file name
        - file_type: file type
        - status: processing / completed / failed
        - chunk_count: number of processed document chunks
        - created_at: creation time in ISO format, or None
        - updated_at: last update time in ISO format, or None
    """
    await _check_kb_access(conn, kb_id, current_user)

    record = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    if not record:
        raise NotFoundError("文件")
    # A file from another knowledge base is reported like a missing one.
    if record.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    payload = {
        "id": record.id,
        "file_name": record.file_name,
        "file_type": record.file_type,
        "status": record.status,
        "chunk_count": record.chunk_count,
    }
    # Timestamps may be unset; they are serialized only when present.
    payload["created_at"] = record.created_at.isoformat() if record.created_at else None
    payload["updated_at"] = record.updated_at.isoformat() if record.updated_at else None

    return BaseResponse(code=200, msg="获取文件状态成功", data=payload)
@kb_file_router.delete("/{kb_id}/files/{file_id}", response_model=BaseResponse, summary="删除文件")
async def delete_file(
    kb_id: int,
    file_id: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Delete a file record, then best-effort remove its vectors and stored bytes.

    Only the record deletion is mandatory: failures while deleting vectors or
    the physical file are logged as warnings and do not fail the request.

    Raises:
        NotFoundError: the knowledge base or file is missing, the file is in
            another knowledge base, or the record deletion reports failure.
    """
    await _check_kb_access(conn, kb_id, current_user)

    file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, current_user.id)
    if not file or file.knowledge_base_id != kb_id:
        raise NotFoundError("文件")

    success, vector_ids = await KnowledgeBaseFileService.delete_file(conn, file_id, current_user.id)
    if not success:
        raise NotFoundError("文件")

    if vector_ids:
        try:
            vector_service = get_vector_service()
            vector_service.delete_vectors_by_ids(kb_id, vector_ids)
            logger.info(f"已删除 {len(vector_ids)} 个向量")
        except Exception as e:
            logger.warning(f"删除向量库中的向量失败: {e}")

    # Records created by the URL upload endpoint keep the remote page address
    # in file_path and have no stored file. Skip physical cleanup for them so
    # a foreign http(s) URL is never turned into an OSS object name and
    # deleted. NOTE(review): assumes such records carry file_type "url", as
    # passed to create_file_record by upload_url; confirm against the service.
    if file.file_type != "url":
        try:
            oss_service = get_oss_service()
            if oss_service.enabled and file.file_path.startswith(('http://', 'https://')):
                oss_object_name = oss_service.extract_object_name_from_url(file.file_path, kb_id)
                if oss_object_name:
                    oss_service.delete_file(oss_object_name)
                    logger.info(f"已删除 OSS 文件: {oss_object_name}")
            elif os.path.exists(file.file_path):
                os.remove(file.file_path)
                logger.info(f"已删除本地文件: {file.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败: {e}")

    return BaseResponse(code=200, msg="删除文件成功", data={"id": file_id})
@kb_file_router.post("/{kb_id}/search", response_model=BaseResponse, summary="在知识库中搜索")
async def search_in_knowledge_base(
    kb_id: int,
    query: str = Query(..., description="搜索查询"),
    k: int = Query(5, ge=1, le=20, description="返回结果数量"),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db)
):
    """Run a similarity search for ``query`` in one knowledge base and return up to ``k`` results."""
    await _check_kb_access(conn, kb_id, current_user)

    matches = get_vector_service().search_similar(kb_id, query, k)
    payload = {"query": query, "results": matches, "count": len(matches)}
    return BaseResponse(code=200, msg="搜索成功", data=payload)