850 lines
37 KiB
Python
850 lines
37 KiB
Python
"""
|
||
聊天文件相关 API 路由模块
|
||
|
||
处理聊天对话中的文件上传、列表查询和删除功能。
|
||
"""
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks, Query
|
||
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
from core.dependencies import get_current_user
|
||
from models.user import User
|
||
from core.database import get_db_pool
|
||
from services.chat_thread_file_service import ChatThreadFileService
|
||
from services.vector_service import get_vector_service
|
||
from services.oss_service import get_oss_service
|
||
from models.chat_thread_file import (
|
||
ChatThreadFileUploadResponse,
|
||
ChatThreadFileListResponse
|
||
)
|
||
|
||
# Module-level logger shared by the routes and the background task in this file.
logger = get_logger(__name__)

# Router exposing the chat-thread file endpoints, all mounted under "/api".
chat_file_router = APIRouter(
    tags=["聊天文件接口"],
    prefix="/api",
)
|
||
|
||
|
||
# File types (extension without the dot) summarized with the vision model rather than the text summarizer.
_CHAT_IMAGE_TYPES = frozenset({'png', 'jpg', 'jpeg', 'bmp'})
# Maximum number of OCR characters embedded in an image summary.
_MAX_OCR_SUMMARY_LENGTH = 2000
# Maximum number of document characters sent to the summary model (roughly 3000-4000 tokens).
_MAX_SUMMARY_INPUT_LENGTH = 10000
# Lifetime, in seconds, of the signed URLs handed to the vision model.
_VISION_SIGNED_URL_EXPIRES = 3600


def _join_chunk_text(chunks) -> str:
    """Join the text element (index 1) of every chunk tuple, separated by blank lines.

    Chunks come from ``result.chunks`` of the vector service; the caller unpacks
    them as ``(chunk_index, content, metadata, vector_id)``.
    """
    return "\n\n".join(chunk[1] for chunk in chunks)


def _truncate_ocr_text(text: str, max_length: int = _MAX_OCR_SUMMARY_LENGTH) -> str:
    """Return ``text`` unchanged if it fits in ``max_length`` characters, else a cut copy with a truncation marker."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + "...(文字内容较长,已截断)"


async def _set_file_status(pool, file_id: int, new_status: str, chunk_count: int) -> None:
    """Update a file's processing status on a short-lived pooled connection."""
    async with pool.acquire() as conn:
        await ChatThreadFileService.update_file_status(conn, file_id, new_status, chunk_count)


def _delete_temp_oss_object(oss_service, object_name: str) -> None:
    """Best-effort removal of a temporary OSS object; failures are logged, never raised."""
    try:
        oss_service.delete_file(object_name)
    except Exception as e:
        logger.warning(f"删除 OSS 临时文件失败: {object_name}, 错误: {e}")


def _vision_url_for(oss_service, file_path: str, thread_id: str) -> str:
    """Return a URL the vision model can fetch for ``file_path``.

    For http(s) URLs a signed URL is generated so private OSS buckets are
    readable; on any failure (or for non-URL paths) ``file_path`` itself is
    returned.
    """
    if not file_path.startswith(('http://', 'https://')):
        return file_path
    try:
        oss_object_name = oss_service.extract_object_name_from_url(file_path, thread_id=thread_id)
        if not oss_object_name:
            logger.warning("无法从 OSS URL 提取对象名称,使用原始 URL")
            return file_path
        signed_url = oss_service.get_signed_url(oss_object_name, expires=_VISION_SIGNED_URL_EXPIRES)
        if signed_url:
            logger.info("🔐 已生成签名 URL 供视觉模型访问(有效期1小时)")
            return signed_url
        logger.warning("生成签名 URL 失败,尝试使用原始 URL")
    except Exception as e:
        logger.warning(f"生成签名 URL 时出错,使用原始 URL: {e}")
    return file_path


async def _build_image_summary(oss_service, file_path: str, thread_id: str, file_id: int, result):
    """Summarize an image: vision-model description plus (truncated) OCR text.

    Falls back to OCR text only when the vision model returns nothing.

    Returns:
        The summary string, or None when neither source produced text.
    """
    from services.vision_service import VisionService

    logger.info(f"🎨 使用视觉模型为图片 {file_id} 生成摘要")
    vision_image_url = _vision_url_for(oss_service, file_path, thread_id)

    # Prompt stresses extracting any text shown in the image.
    vision_prompt = "请详细描述这张图片的内容。特别注意:1) 完整提取图片中的所有文字内容(标题、正文、数据、数字等);2) 描述图片的视觉场景(人物、动作、背景等)。用100-200字详细说明。"
    logger.info(f"🤖 调用视觉模型,prompt: {vision_prompt}")

    vision_description = await VisionService.get_image_description(
        image_url=vision_image_url,
        prompt=vision_prompt
    )
    ocr_content = _join_chunk_text(result.chunks)

    if not vision_description:
        logger.warning("⚠️ 视觉模型返回为空,降级使用OCR文字")
        if ocr_content:
            return f"【图片文字内容】\n{_truncate_ocr_text(ocr_content)}"
        return None

    logger.info("✅ 视觉模型返回结果:")
    logger.info(f"{'='*60}")
    logger.info(f"图片URL: {file_path}")
    logger.info(f"描述内容: {vision_description}")
    logger.info(f"描述长度: {len(vision_description)} 字符")
    logger.info(f"{'='*60}")

    logger.info("📝 组合视觉描述和OCR文字:")
    logger.info(f" - 视觉描述: {len(vision_description)} 字符")
    logger.info(f" - OCR文字: {len(ocr_content)} 字符")

    # Only merge OCR text when there is a meaningful amount of it (> 10 non-blank chars).
    if ocr_content and len(ocr_content.strip()) > 10:
        ocr_summary = _truncate_ocr_text(ocr_content)
        summary_text = f"【视觉内容】{vision_description}\n\n【图片文字内容】\n{ocr_summary}"
        logger.info("✅ 使用视觉模型+OCR 生成图片摘要")
        logger.info(f" - OCR原始: {len(ocr_content)} 字符")
        logger.info(f" - OCR摘要: {len(ocr_summary)} 字符")
        logger.info(f" - 最终摘要: {len(summary_text)} 字符")
        return summary_text

    summary_text = f"【视觉内容】{vision_description}"
    logger.info("✅ 使用视觉模型生成图片摘要(OCR文字不足,仅使用视觉描述)")
    logger.info(f" - 最终摘要: {len(summary_text)} 字符")
    return summary_text


async def _describe_docx_images(oss_service, thread_id: str, file_id: int, image_paths) -> list:
    """Describe each image extracted from a DOCX with the vision model.

    Each local image is uploaded to a temporary OSS key, described via a
    signed URL, and then both the OSS object and the local file are removed,
    including when the vision call fails.

    Returns:
        A list of ``"[图片N] description"`` strings for the images that produced a description.
    """
    from services.vision_service import VisionService

    logger.info(f"📸 DOCX 包含 {len(image_paths)} 张图片,使用视觉模型分析")
    descriptions = []

    for idx, img_path in enumerate(image_paths, 1):
        uploaded_oss_name = None
        try:
            if not os.path.exists(img_path):
                logger.warning(f"图片文件不存在: {img_path}")
                continue

            with open(img_path, 'rb') as f:
                img_content = f.read()

            # file_id + millisecond timestamp keep temp keys unique when several files
            # of the same thread are processed concurrently.
            img_filename = f"docx_image_{file_id}_{idx}_{int(time.time() * 1000)}.png"
            img_oss_name = f"thread_{thread_id}/temp/{img_filename}"
            img_url = oss_service.upload_file_from_bytes(img_content, img_oss_name, img_filename)
            if not img_url:
                continue
            uploaded_oss_name = img_oss_name

            signed_url = oss_service.get_signed_url(img_oss_name, expires=_VISION_SIGNED_URL_EXPIRES)
            vision_desc = await VisionService.get_image_description(
                image_url=signed_url if signed_url else img_url,
                prompt="请详细描述这张图片的内容,包括:1) 图片中的所有文字内容(如标题、正文、数据等);2) 图片的视觉场景(人物、动作、环境等)。用100-200字详细描述。"
            )
            if vision_desc:
                descriptions.append(f"[图片{idx}] {vision_desc}")
                logger.info(f"✅ DOCX 图片 {idx} 视觉分析完成")

        except Exception as e:
            logger.warning(f"处理 DOCX 图片 {idx} 失败: {e}")
        finally:
            # Temporary OSS copy is removed even when the vision call raised.
            if uploaded_oss_name:
                _delete_temp_oss_object(oss_service, uploaded_oss_name)
            try:
                if os.path.exists(img_path):
                    os.remove(img_path)
            except Exception as e:
                logger.warning(f"删除本地临时图片失败: {img_path}, 错误: {e}")

    return descriptions


async def _build_document_summary(oss_service, file_type: str, thread_id: str, file_id: int, result):
    """Summarize a non-image file with the text summary service.

    For DOCX files with extracted images, vision-model descriptions of those
    images are appended to the summary.

    Returns:
        The summary string, or a falsy value when the summary service returned nothing
        and there were no image descriptions.
    """
    from services.summary_service import SummaryService
    from langchain_core.documents import Document

    file_content = _join_chunk_text(result.chunks)

    docx_image_descriptions = []
    if file_type.lower() == 'docx' and result.extracted_image_paths:
        docx_image_descriptions = await _describe_docx_images(
            oss_service, thread_id, file_id, result.extracted_image_paths
        )

    # Cap the input to avoid exceeding the model's token limit.
    if len(file_content) > _MAX_SUMMARY_INPUT_LENGTH:
        file_content = file_content[:_MAX_SUMMARY_INPUT_LENGTH] + "..."

    logger.info(f"正在为文件 {file_id} 生成摘要,内容长度: {len(file_content)} 字符")
    summary_text = await SummaryService.generate_file_summary(
        [Document(page_content=file_content)], max_docs=1
    )

    if docx_image_descriptions:
        image_summary = "\n\n文档图片内容:\n" + "\n".join(docx_image_descriptions)
        summary_text = summary_text + image_summary if summary_text else image_summary
        logger.info(f"✅ 已将 {len(docx_image_descriptions)} 张图片的视觉描述加入摘要")

    if summary_text:
        logger.info(f"📝 文件 {file_id} 摘要生成成功:")
        logger.info(f"{'='*60}")
        logger.info(f"摘要内容: {summary_text}")
        logger.info(f"{'='*60}")
    else:
        logger.warning(f"文件 {file_id} 摘要生成失败,返回为空")

    return summary_text


async def process_chat_file_background(
    file_id: int,
    file_path: str,
    thread_id: str,
    file_type: str
):
    """Background task: download, vectorize and summarize an uploaded chat file.

    Steps:
    1. Download the OSS object behind ``file_path`` to a local temp file.
    2. Split/vectorize it via the vector service.
    3. Build a summary: vision model (+ OCR) for images, text summarizer
       (+ DOCX image descriptions) otherwise.
    4. Store chunks with ``file_id``, ``chunk_index`` and the summary in
       their metadata, sync the summary into the vector store, and mark the
       file ``completed``.

    Any failure marks the file ``failed`` (chunk count 0). Summary failures
    are logged and do not fail the file. The local temp file is always
    removed. DB connections are acquired only around individual DB calls, so
    long downloads and model calls do not pin a pool connection.

    Args:
        file_id: file record ID.
        file_path: OSS URL of the uploaded file.
        thread_id: chat thread ID.
        file_type: file type derived from the extension (e.g. "pdf", "docx", "png").
    """
    pool = await get_db_pool()
    local_file_path = None
    try:
        logger.info(f"开始后台处理聊天文件 ID: {file_id}, thread_id: {thread_id}, 路径: {file_path}")

        # file_path is an OSS URL; it must be downloaded locally before processing.
        oss_service = get_oss_service()
        if not oss_service.enabled:
            logger.error("OSS 服务未启用")
            await _set_file_status(pool, file_id, "failed", 0)
            return

        if not file_path.startswith(('http://', 'https://')):
            logger.error(f"无效的文件路径格式(应为 OSS URL): {file_path}")
            await _set_file_status(pool, file_id, "failed", 0)
            return

        logger.info(f"检测到 OSS URL,开始下载文件: {file_path}")
        oss_object_name = oss_service.extract_object_name_from_url(file_path, thread_id=thread_id)
        if not oss_object_name:
            logger.error(f"无法从 OSS URL 提取对象名称: {file_path}")
            await _set_file_status(pool, file_id, "failed", 0)
            return

        local_file_path = oss_service.download_file(oss_object_name)
        if not local_file_path:
            logger.error("从 OSS 下载文件失败")
            await _set_file_status(pool, file_id, "failed", 0)
            return
        logger.info(f"文件下载成功: {local_file_path}")

        vector_service = get_vector_service()
        # The original OSS URL is passed along as the chunks' source.
        result = await vector_service.process_chat_thread_file(
            local_file_path,
            thread_id,
            file_type,
            file_id=file_id,
            source_url=file_path
        )
        if not result.success:
            logger.warning(f"聊天文件处理失败 ID: {file_id}, 原因: {result.error_message}")
            await _set_file_status(pool, file_id, "failed", 0)
            return

        # A summary failure must not fail the whole file.
        summary_text = None
        try:
            if file_type.lower() in _CHAT_IMAGE_TYPES:
                summary_text = await _build_image_summary(oss_service, file_path, thread_id, file_id, result)
            else:
                summary_text = await _build_document_summary(oss_service, file_type, thread_id, file_id, result)
        except Exception as e:
            logger.error(f"生成文件摘要失败: {e}")

        # file_id enables filtering at retrieval time; chunk_index enables ordering.
        enhanced_chunks = []
        for chunk_index, content, metadata, vector_id in result.chunks:
            enhanced_metadata = metadata.copy()
            enhanced_metadata['file_id'] = file_id
            enhanced_metadata['chunk_index'] = chunk_index
            if summary_text:
                enhanced_metadata['file_summary'] = summary_text
            enhanced_chunks.append((chunk_index, content, enhanced_metadata, vector_id))

        async with pool.acquire() as conn:
            await ChatThreadFileService.save_chunks(
                conn, file_id, thread_id, enhanced_chunks, summary=summary_text
            )

        # Keep the vector store's summary metadata in sync with the database.
        if summary_text:
            success = vector_service.update_file_summary_in_vectors(
                thread_id=thread_id,
                file_id=file_id,
                summary=summary_text
            )
            if success:
                logger.info("✅ ChromaDB metadata 已同步 summary")
            else:
                logger.warning("⚠️ ChromaDB metadata 同步 summary 失败,但不影响主流程")

        await _set_file_status(pool, file_id, "completed", result.chunk_count)
        logger.info(f"聊天文件处理完成 ID: {file_id}, 块数: {result.chunk_count}, 摘要: {'已生成' if summary_text else '未生成'}")

    except Exception as e:
        logger.error(f"后台处理聊天文件异常 ID: {file_id}, 错误: {e}")
        # Guarded so a DB failure here does not escape the background task unlogged.
        try:
            await _set_file_status(pool, file_id, "failed", 0)
        except Exception as status_err:
            logger.error(f"更新文件状态为 failed 失败 ID: {file_id}, 错误: {status_err}")
    finally:
        if local_file_path and os.path.exists(local_file_path):
            try:
                os.remove(local_file_path)
                logger.debug(f"已删除临时文件: {local_file_path}")
            except Exception as e:
                logger.warning(f"删除临时文件失败: {e}")
|
||
|
||
|
||
# Supported upload extensions mapped to the file_type stored on the file record.
_CHAT_UPLOAD_FILE_TYPES = {
    '.pdf': 'pdf',
    '.docx': 'docx',
    '.xlsx': 'xlsx',
    '.xls': 'xls',
    '.txt': 'txt',
    '.png': 'png',
    '.jpg': 'jpg',
    '.jpeg': 'jpeg',
    '.bmp': 'bmp',
}
# File types that must pass image moderation before a record is created.
_CHAT_UPLOAD_IMAGE_TYPES = frozenset({'png', 'jpg', 'jpeg', 'bmp'})
# Maximum accepted upload size in bytes (15MB).
_CHAT_UPLOAD_MAX_FILE_SIZE = 15 * 1024 * 1024


def _sanitize_upload_filename(raw_name) -> str:
    """Return only the final path component of a client-supplied filename.

    Both ``/`` and ``\\`` are treated as separators so a name like
    ``..\\..\\x.png`` cannot inject extra segments into the OSS object key.
    Returns an empty string when the client sent no filename.
    """
    if not raw_name:
        return ""
    return os.path.basename(raw_name.replace("\\", "/"))


def _discard_uploaded_object(oss_service, object_name: str) -> None:
    """Best-effort deletion of an uploaded OSS object that will not be recorded; errors are only logged."""
    try:
        oss_service.delete_file(object_name)
    except Exception as e:
        logger.warning(f"清理 OSS 文件失败: {object_name}, 错误: {e}")


async def _moderate_uploaded_image(file_url: str, file_name: str, timestamp: int) -> None:
    """Run image moderation on an uploaded OSS URL when moderation is enabled.

    Does not delete the OSS object; the caller is responsible for cleanup.

    Raises:
        HTTPException: 400 when the image is blocked, 503 when the moderation service errors.
    """
    from core.dependencies import get_moderation_service
    from core.config import settings
    from core.exceptions import ModerationError
    from models.moderation import ModerationDecision

    moderation_service = await get_moderation_service()
    if not moderation_service or not settings.moderation_enabled:
        return

    try:
        logger.info(f"🔍 开始图片审核: {file_name}")
        result = await moderation_service.moderate_image(
            image_source=file_url,
            source_type="url",
            request_id=f"chat_file_{timestamp}"
        )
    except ModerationError as e:
        logger.error(f"❌ 图片审核服务错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="图片审核服务暂时不可用,请稍后重试"
        ) from e

    if result.decision == ModerationDecision.BLOCK:
        logger.warning(
            f"❌ 图片审核不通过: {file_name}, "
            f"原因: {result.message}, "
            f"标签: {[label.label for label in result.labels]}"
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=result.message or "图片包含不当内容,无法上传"
        )

    logger.info(
        f"✅ 图片审核通过: {file_name}, "
        f"决策: {result.decision.value}"
    )


@chat_file_router.post("/chat/thread/{thread_id}/upload", response_model=BaseResponse, summary="上传文件到聊天对话")
async def upload_chat_file(
    thread_id: str,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user)
):
    """Upload a file into a chat thread and schedule its vectorization.

    The thread is created on the fly when it does not exist. The file is
    validated (extension, size), stored in OSS, moderated if it is an image,
    recorded in the database, and handed to ``process_chat_file_background``.
    If moderation or record creation fails, the uploaded OSS object is
    deleted so it is not orphaned.

    Args:
        thread_id: chat thread ID.
        background_tasks: FastAPI background task queue.
        file: the uploaded file.
        current_user: the authenticated user.

    Returns:
        BaseResponse: the new file record, including its OSS URL.

    Raises:
        HTTPException: 400 (unsupported type, too large, blocked image, ValueError from
            record creation), 403 (thread owned by another user), 503 (OSS or moderation
            unavailable), 500 (anything else).
    """
    try:
        # Verify the thread belongs to the current user; create it if missing.
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            thread_info = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )

            if not thread_info:
                logger.info(f"会话不存在,自动创建会话记录: thread_id={thread_id}, user_id={current_user.id}")
                try:
                    await conn.execute(
                        """
                        INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count)
                        VALUES ($1, $2, $3, $4, 0)
                        """,
                        thread_id,
                        current_user.id,
                        "新对话",
                        ""
                    )
                    logger.info(f"成功创建新会话: thread_id={thread_id}")
                    thread_info = await conn.fetchrow(
                        """
                        SELECT id, user_id FROM chat_threads
                        WHERE thread_id = $1 AND is_deleted = FALSE
                        """,
                        thread_id
                    )
                except Exception as e:
                    logger.error(f"创建会话记录失败: {e}")
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"创建会话失败: {str(e)}"
                    )

            # Defensive: the re-query returned no row; fail cleanly instead of a TypeError below.
            if not thread_info:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="创建会话失败"
                )

            if thread_info['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

        # Client-supplied name: strip directory components before it reaches the OSS key.
        file_name = _sanitize_upload_filename(file.filename)
        logger.info(f"📤 开始上传文件到聊天 {thread_id}: {file_name}, 用户: {current_user.username}")

        file_ext = Path(file_name).suffix.lower()
        file_type = _CHAT_UPLOAD_FILE_TYPES.get(file_ext)
        if file_type is None:
            logger.warning(f"❌ 不支持的文件类型: {file_ext}, 文件: {file_name}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"不支持的文件类型: {file_ext},支持的类型: {', '.join(_CHAT_UPLOAD_FILE_TYPES)}"
            )
        logger.info(f"📋 文件类型识别: {file_ext} -> {file_type}")

        # Read at most one byte past the limit so oversized uploads are never fully buffered in memory.
        content = await file.read(_CHAT_UPLOAD_MAX_FILE_SIZE + 1)
        file_size = len(content)
        if file_size > _CHAT_UPLOAD_MAX_FILE_SIZE:
            # Only a prefix was read; report the full size when the framework recorded it.
            reported_size = getattr(file, "size", None) or file_size
            reported_mb = reported_size / (1024 * 1024)
            logger.warning(f"❌ 文件大小超限: {reported_mb:.2f}MB (最大 15MB), 文件: {file_name}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"文件大小超过限制,当前: {reported_mb:.2f}MB,最大允许: 15MB"
            )
        file_size_mb = file_size / (1024 * 1024)
        logger.info(f"✅ 文件大小验证通过: {file_size_mb:.2f}MB ({file_size} bytes)")

        # Millisecond timestamp prefix keeps object keys unique per upload.
        timestamp = int(time.time() * 1000)
        oss_object_name = f"thread_{thread_id}/{timestamp}_{file_name}"

        oss_service = get_oss_service()
        if not oss_service.enabled:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="OSS 服务未启用,无法上传文件"
            )

        logger.info(f"☁️ 上传文件到 OSS: {oss_object_name}")
        file_url = oss_service.upload_file_from_bytes(content, oss_object_name, file_name)
        if not file_url:
            logger.error(f"❌ OSS 上传失败: thread_id={thread_id}, filename={file_name}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="文件上传到 OSS 失败"
            )
        logger.info(f"✅ 文件已上传到 OSS: {file_url}")

        try:
            # Images are moderated before any record references them.
            if file_type in _CHAT_UPLOAD_IMAGE_TYPES:
                await _moderate_uploaded_image(file_url, file_name, timestamp)

            # The record stores the OSS URL as file_path.
            logger.info(f"📝 创建文件记录: {file_name}")
            async with pool.acquire() as conn:
                file_record = await ChatThreadFileService.create_file_record(
                    conn,
                    thread_id,
                    current_user.id,
                    file_name,
                    file_url,
                    file_size,
                    file_type
                )
        except Exception:
            # Nothing references the object yet; remove it so rejected/failed uploads do not leak storage.
            _discard_uploaded_object(oss_service, oss_object_name)
            raise
        logger.info(f"✅ 文件记录已创建: ID={file_record.id}, 状态={file_record.status}")

        logger.info(f"🚀 添加后台向量化任务: file_id={file_record.id}, type={file_type}")
        background_tasks.add_task(
            process_chat_file_background,
            file_record.id,
            file_url,
            thread_id,
            file_type
        )

        # Files are not linked to a message here; they are attached to the user's next
        # message so they display next to it.
        return BaseResponse(
            code=200,
            msg="文件上传成功,正在处理中",
            data=ChatThreadFileUploadResponse(
                id=file_record.id,
                file_name=file_record.file_name,
                file_size=file_record.file_size,
                status=file_record.status,
                chunk_count=file_record.chunk_count,
                created_at=file_record.created_at,
                file_url=file_url
            ).dict()
        )

    except HTTPException:
        raise
    except ValueError as e:
        # Business validation errors such as a duplicate file name.
        logger.warning(f"文件上传验证失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"上传文件失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"上传文件失败: {str(e)}"
        )
|
||
|
||
|
||
@chat_file_router.get("/chat/thread/{thread_id}/files", response_model=BaseResponse, summary="获取聊天对话文件列表")
async def get_chat_thread_files(
    thread_id: str,
    page: int = Query(1, ge=1, description="页码,从 1 开始"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量,最大 100"),
    current_user: User = Depends(get_current_user)
):
    """Return one page of the files attached to a chat thread owned by the caller.

    Args:
        thread_id: chat thread ID.
        page: 1-based page number.
        page_size: items per page (1-100).
        current_user: the authenticated user.

    Returns:
        BaseResponse whose ``data`` holds ``total`` and ``items``; each item's
        ``file_url`` is the stored ``file_path`` (the OSS URL).

    Raises:
        HTTPException: 404 if the thread is missing or deleted, 403 if it belongs
            to another user, 500 on any other failure.
    """
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as db:
            owner_row = await db.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not owner_row:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")
            if owner_row['user_id'] != current_user.id:
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限访问该会话")

            page_files, total = await ChatThreadFileService.get_files_by_thread(
                db, thread_id, current_user.id, page, page_size
            )

        items = []
        for record in page_files:
            entry = ChatThreadFileUploadResponse(
                id=record.id,
                file_name=record.file_name,
                file_size=record.file_size,
                status=record.status,
                chunk_count=record.chunk_count,
                created_at=record.created_at,
                file_url=record.file_path,
            )
            items.append(entry.dict())

        listing = ChatThreadFileListResponse(total=total, items=items)
        return BaseResponse(code=200, msg="获取文件列表成功", data=listing.dict())

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取文件列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取文件列表失败"
        )
|
||
|
||
|
||
@chat_file_router.get("/chat/thread/{thread_id}/files/{file_id}/status", response_model=BaseResponse, summary="查询文件处理状态")
async def get_file_processing_status(
    thread_id: str,
    file_id: int,
    current_user: User = Depends(get_current_user)
):
    """Report the processing state of one file in a thread (polled by the front end).

    Args:
        thread_id: chat thread ID.
        file_id: file record ID.
        current_user: the authenticated user.

    Returns:
        BaseResponse whose ``data`` has ``id``, ``file_name``, ``file_type``,
        ``status`` (processing / completed / failed), ``chunk_count`` and the
        ISO-formatted ``created_at`` / ``updated_at`` (None when unset).

    Raises:
        HTTPException: 404 if the thread or file is missing (or the file belongs
            to another thread), 403 if the thread belongs to another user,
            500 on any other failure.
    """
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as db:
            owner_row = await db.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not owner_row:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")
            if owner_row['user_id'] != current_user.id:
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限访问该会话")

            file_info = await ChatThreadFileService.get_file_by_id(db, file_id, current_user.id)

        belongs_to_thread = bool(file_info) and file_info.thread_id == thread_id
        if not belongs_to_thread:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文件不存在")

        created, updated = file_info.created_at, file_info.updated_at
        payload = {
            "id": file_info.id,
            "file_name": file_info.file_name,
            "file_type": file_info.file_type,
            "status": file_info.status,
            "chunk_count": file_info.chunk_count,
            "created_at": created.isoformat() if created else None,
            "updated_at": updated.isoformat() if updated else None,
        }
        return BaseResponse(code=200, msg="获取文件状态成功", data=payload)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取文件状态失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取文件状态失败"
        )
|
||
|
||
|
||
@chat_file_router.delete("/chat/thread/{thread_id}/files/{file_id}", response_model=BaseResponse, summary="删除聊天对话文件")
async def delete_chat_thread_file(
    thread_id: str,
    file_id: int,
    current_user: User = Depends(get_current_user)
):
    """Delete a file from a chat thread.

    The DB record is soft-deleted first; then, best-effort, its vectors are
    removed from the vector store and its OSS object is deleted. Cleanup
    failures are logged but do not fail the request. The pooled DB
    connection is released before the external cleanup calls.

    Args:
        thread_id: chat thread ID.
        file_id: file record ID.
        current_user: the authenticated user.

    Returns:
        BaseResponse whose ``data`` has ``id`` and ``vector_count`` (number of
        vector IDs returned by the soft delete).

    Raises:
        HTTPException: 404 if the thread or file is missing (or the file belongs
            to another thread), 403 if the thread belongs to another user,
            500 on any other failure.
    """
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            thread_info = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not thread_info:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="会话不存在"
                )
            if thread_info['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

            file_info = await ChatThreadFileService.get_file_by_id(conn, file_id, current_user.id)
            if not file_info or file_info.thread_id != thread_id:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="文件不存在"
                )

            # Soft delete; the service also returns the vector IDs to purge.
            success, vector_ids = await ChatThreadFileService.delete_file(conn, file_id, current_user.id)
            if not success:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="文件不存在"
                )

        # Normalize so the response's len() cannot fail after a successful delete.
        vector_ids = vector_ids or []

        if vector_ids:
            try:
                vector_service = get_vector_service()
                vector_service.delete_thread_vectors(thread_id, vector_ids)
                logger.info(f"已删除 {len(vector_ids)} 个向量")
            except Exception as e:
                logger.warning(f"删除向量库中的向量失败: {e}")

        # Remove the physical file from OSS (best effort).
        try:
            oss_service = get_oss_service()
            if not oss_service.enabled:
                logger.warning("OSS 服务未启用,无法删除物理文件")
            elif file_info.file_path.startswith(('http://', 'https://')):
                oss_object_name = oss_service.extract_object_name_from_url(file_info.file_path, thread_id=thread_id)
                if oss_object_name:
                    oss_service.delete_file(oss_object_name)
                    logger.info(f"已删除 OSS 文件: {oss_object_name}")
                else:
                    logger.warning(f"无法从 OSS URL 提取对象名称: {file_info.file_path}")
            else:
                logger.warning(f"文件路径不是 OSS URL 格式: {file_info.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败: {e}")

        return BaseResponse(
            code=200,
            msg="删除文件成功",
            data={"id": file_id, "vector_count": len(vector_ids)}
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除文件失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除文件失败"
        )
|
||
|