# huoyan-enterprise/backend/api/chat_file.py
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that can be confused with other (look-alike)
# characters. If that is intentional the warning can be ignored.

"""
聊天文件相关 API 路由模块
处理聊天对话中的文件上传、列表查询和删除功能。
"""
import os
import time
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks, Query
from utils.helpers import BaseResponse
from logger.logging import get_logger
from core.dependencies import get_current_user
from models.user import User
from core.database import get_db_pool
from services.chat_thread_file_service import ChatThreadFileService
from services.vector_service import get_vector_service
from services.oss_service import get_oss_service
from models.chat_thread_file import (
ChatThreadFileUploadResponse,
ChatThreadFileListResponse
)
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Router instance for the chat-file endpoints; every route registered on it
# is served under the "/api" prefix.
chat_file_router = APIRouter(prefix="/api", tags=["聊天文件接口"])
# File types summarized with the vision model instead of the text summarizer.
_IMAGE_FILE_TYPES = frozenset({'png', 'jpg', 'jpeg', 'bmp'})
# Maximum number of OCR characters embedded into an image summary.
_MAX_OCR_SUMMARY_LENGTH = 2000
# Maximum number of document characters handed to the text summarizer
# (keeps the request within the model's token budget).
_MAX_SUMMARY_INPUT_LENGTH = 10000


def _join_chunk_text(chunks) -> str:
    """Concatenate the text (second tuple element) of every chunk, blank-line separated."""
    return "\n\n".join(chunk[1] for chunk in chunks)


def _truncate_ocr_text(ocr_content: str) -> str:
    """Return *ocr_content* cut to the OCR summary limit, marking the cut when one happens."""
    if len(ocr_content) <= _MAX_OCR_SUMMARY_LENGTH:
        return ocr_content
    return ocr_content[:_MAX_OCR_SUMMARY_LENGTH] + "...(文字内容较长,已截断)"


def _vision_url_for(oss_service, file_path: str, oss_object_name: str) -> str:
    """Return a signed URL (valid one hour) for the object, falling back to *file_path*.

    Signing is best-effort: any failure is logged and the original URL is used.
    """
    try:
        signed_url = oss_service.get_signed_url(oss_object_name, expires=3600)
    except Exception as e:
        logger.warning(f"生成签名 URL 时出错,使用原始 URL: {e}")
        return file_path
    if signed_url:
        logger.info("🔐 已生成签名 URL 供视觉模型访问有效期1小时")
        return signed_url
    logger.warning("生成签名 URL 失败,尝试使用原始 URL")
    return file_path


async def _summarize_image_file(oss_service, file_id: int, file_path: str, oss_object_name: str, chunks):
    """Build the summary of an image: vision description plus OCR text when available.

    Returns:
        The summary text, or None when neither the vision model nor OCR produced anything.
    """
    from services.vision_service import VisionService

    logger.info(f"🎨 使用视觉模型为图片 {file_id} 生成摘要")
    vision_image_url = _vision_url_for(oss_service, file_path, oss_object_name)

    # The prompt asks for both the text visible in the image and the scene itself.
    vision_prompt = "请详细描述这张图片的内容。特别注意1) 完整提取图片中的所有文字内容标题、正文、数据、数字等2) 描述图片的视觉场景人物、动作、背景等。用100-200字详细说明。"
    logger.info(f"🤖 调用视觉模型prompt: {vision_prompt}")
    vision_description = await VisionService.get_image_description(
        image_url=vision_image_url,
        prompt=vision_prompt
    )

    if not vision_description:
        # Fallback: the vision model returned nothing, use OCR text alone.
        logger.warning("⚠️ 视觉模型返回为空降级使用OCR文字")
        ocr_content = _join_chunk_text(chunks)
        if not ocr_content:
            return None
        return f"【图片文字内容】\n{_truncate_ocr_text(ocr_content)}"

    separator = '=' * 60
    logger.info("✅ 视觉模型返回结果:")
    logger.info(separator)
    logger.info(f"图片URL: {file_path}")
    logger.info(f"描述内容: {vision_description}")
    logger.info(f"描述长度: {len(vision_description)} 字符")
    logger.info(separator)

    ocr_content = _join_chunk_text(chunks)
    logger.info("📝 组合视觉描述和OCR文字:")
    logger.info(f" - 视觉描述: {len(vision_description)} 字符")
    logger.info(f" - OCR文字: {len(ocr_content)} 字符")

    # Only combine when OCR produced a meaningful amount of text (> 10 chars).
    if ocr_content and len(ocr_content.strip()) > 10:
        ocr_summary = _truncate_ocr_text(ocr_content)
        summary_text = f"【视觉内容】{vision_description}\n\n【图片文字内容】\n{ocr_summary}"
        logger.info("✅ 使用视觉模型+OCR 生成图片摘要")
        logger.info(f" - OCR原始: {len(ocr_content)} 字符")
        logger.info(f" - OCR摘要: {len(ocr_summary)} 字符")
        logger.info(f" - 最终摘要: {len(summary_text)} 字符")
        return summary_text

    summary_text = f"【视觉内容】{vision_description}"
    logger.info("✅ 使用视觉模型生成图片摘要OCR文字不足仅使用视觉描述")
    logger.info(f" - 最终摘要: {len(summary_text)} 字符")
    return summary_text


async def _describe_docx_images(oss_service, thread_id: str, image_paths) -> list:
    """Describe each image extracted from a DOCX with the vision model.

    Every image is uploaded to a temporary OSS object, described, and then both
    the temporary object and the local image file are removed. A failure on one
    image is logged and does not stop the remaining images.

    Returns:
        A list of "[图片N] description" strings for the images that were described.
    """
    from services.vision_service import VisionService

    descriptions = []
    for idx, img_path in enumerate(image_paths, 1):
        try:
            if not os.path.exists(img_path):
                logger.warning(f"图片文件不存在: {img_path}")
                continue
            with open(img_path, 'rb') as f:
                img_content = f.read()

            img_filename = f"docx_image_{idx}_{int(time.time())}.png"
            img_oss_name = f"thread_{thread_id}/temp/{img_filename}"
            img_url = oss_service.upload_file_from_bytes(img_content, img_oss_name, img_filename)
            if not img_url:
                continue

            try:
                signed_url = oss_service.get_signed_url(img_oss_name, expires=3600)
                vision_url = signed_url if signed_url else img_url
                vision_desc = await VisionService.get_image_description(
                    image_url=vision_url,
                    prompt="请详细描述这张图片的内容包括1) 图片中的所有文字内容如标题、正文、数据等2) 图片的视觉场景人物、动作、环境等。用100-200字详细描述。"
                )
                if vision_desc:
                    descriptions.append(f"[图片{idx}] {vision_desc}")
                    logger.info(f"✅ DOCX 图片 {idx} 视觉分析完成")
            finally:
                # Always remove the temporary OSS object, even when the vision
                # call raised, so failed analyses do not leave objects behind.
                try:
                    oss_service.delete_file(img_oss_name)
                except Exception as e:
                    logger.warning(f"删除临时 OSS 文件失败: {img_oss_name}, 错误: {e}")
        except Exception as e:
            logger.warning(f"处理 DOCX 图片 {idx} 失败: {e}")
        finally:
            # Best-effort removal of the local extracted image.
            try:
                if os.path.exists(img_path):
                    os.remove(img_path)
            except Exception as e:
                logger.warning(f"删除本地临时图片失败: {img_path}, 错误: {e}")
    return descriptions


async def _summarize_document_file(oss_service, file_id: int, thread_id: str, file_type: str, result):
    """Build the summary of a non-image file with the text summary service.

    For DOCX files that carry extracted images, vision descriptions of those
    images are appended to the summary.

    Returns:
        The summary text, or a falsy value when no summary could be produced.
    """
    from services.summary_service import SummaryService
    from langchain_core.documents import Document

    file_content = _join_chunk_text(result.chunks)

    docx_image_descriptions = []
    if file_type.lower() == 'docx' and result.extracted_image_paths:
        logger.info(f"📸 DOCX 包含 {len(result.extracted_image_paths)} 张图片,使用视觉模型分析")
        docx_image_descriptions = await _describe_docx_images(
            oss_service, thread_id, result.extracted_image_paths
        )

    # Cap the summarizer input to stay within the token budget.
    if len(file_content) > _MAX_SUMMARY_INPUT_LENGTH:
        file_content = file_content[:_MAX_SUMMARY_INPUT_LENGTH] + "..."
    logger.info(f"正在为文件 {file_id} 生成摘要,内容长度: {len(file_content)} 字符")

    docs = [Document(page_content=file_content)]
    summary_text = await SummaryService.generate_file_summary(docs, max_docs=1)

    if docx_image_descriptions:
        image_summary = "\n\n文档图片内容:\n" + "\n".join(docx_image_descriptions)
        summary_text = summary_text + image_summary if summary_text else image_summary
        logger.info(f"✅ 已将 {len(docx_image_descriptions)} 张图片的视觉描述加入摘要")
    return summary_text


async def process_chat_file_background(
    file_id: int,
    file_path: str,
    thread_id: str,
    file_type: str
):
    """
    Background task: download an uploaded chat file, vectorize it and summarize it.

    Steps: download the OSS object to a local temporary file, split/vectorize it
    through the vector service, generate a summary (vision model for image
    types, text summarizer otherwise), persist the chunks with the summary, and
    finally mark the file record "completed". Any failure marks the record
    "failed". The downloaded temporary file is always removed.

    Args:
        file_id: ID of the file record.
        file_path: OSS URL of the uploaded file (must start with http:// or https://).
        thread_id: Chat thread ID the file belongs to.
        file_type: File type label; 'png'/'jpg'/'jpeg'/'bmp' are treated as
            images and 'docx' additionally gets embedded-image analysis.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        local_file_path = None
        try:
            logger.info(f"开始后台处理聊天文件 ID: {file_id}, thread_id: {thread_id}, 路径: {file_path}")

            # file_path is an OSS URL; the object is downloaded to a local temp file first.
            oss_service = get_oss_service()
            if not oss_service.enabled:
                logger.error("OSS 服务未启用")
                await ChatThreadFileService.update_file_status(conn, file_id, "failed", 0)
                return
            if not file_path.startswith(('http://', 'https://')):
                logger.error(f"无效的文件路径格式(应为 OSS URL: {file_path}")
                await ChatThreadFileService.update_file_status(conn, file_id, "failed", 0)
                return
            logger.info(f"检测到 OSS URL开始下载文件: {file_path}")

            oss_object_name = oss_service.extract_object_name_from_url(file_path, thread_id=thread_id)
            if not oss_object_name:
                logger.error(f"无法从 OSS URL 提取对象名称: {file_path}")
                await ChatThreadFileService.update_file_status(conn, file_id, "failed", 0)
                return

            local_file_path = oss_service.download_file(oss_object_name)
            if not local_file_path:
                logger.error("从 OSS 下载文件失败")
                await ChatThreadFileService.update_file_status(conn, file_id, "failed", 0)
                return
            logger.info(f"文件下载成功: {local_file_path}")

            # Split and vectorize; the original OSS URL is passed along as the source.
            vector_service = get_vector_service()
            result = await vector_service.process_chat_thread_file(
                local_file_path,
                thread_id,
                file_type,
                file_id=file_id,
                source_url=file_path
            )
            if not result.success:
                logger.warning(f"聊天文件处理失败 ID: {file_id}, 原因: {result.error_message}")
                await ChatThreadFileService.update_file_status(conn, file_id, "failed", 0)
                return

            # Summary generation is optional: a failure here is logged and the
            # chunks are still saved without a summary.
            summary_text = None
            try:
                if file_type.lower() in _IMAGE_FILE_TYPES:
                    summary_text = await _summarize_image_file(
                        oss_service, file_id, file_path, oss_object_name, result.chunks
                    )
                else:
                    summary_text = await _summarize_document_file(
                        oss_service, file_id, thread_id, file_type, result
                    )
                if summary_text:
                    separator = '=' * 60
                    logger.info(f"📝 文件 {file_id} 摘要生成成功:")
                    logger.info(separator)
                    logger.info(f"摘要内容: {summary_text}")
                    logger.info(separator)
                else:
                    logger.warning(f"文件 {file_id} 摘要生成失败,返回为空")
            except Exception as e:
                logger.error(f"生成文件摘要失败: {e}")

            # Attach file_id (used to filter at retrieval time), chunk_index
            # (used for ordering) and the summary to every chunk's metadata.
            enhanced_chunks = []
            for chunk_index, content, metadata, vector_id in result.chunks:
                enhanced_metadata = metadata.copy()
                enhanced_metadata['file_id'] = file_id
                enhanced_metadata['chunk_index'] = chunk_index
                if summary_text:
                    enhanced_metadata['file_summary'] = summary_text
                enhanced_chunks.append((chunk_index, content, enhanced_metadata, vector_id))

            await ChatThreadFileService.save_chunks(
                conn, file_id, thread_id, enhanced_chunks, summary=summary_text
            )

            # Mirror the summary into the vector store metadata; a failure is non-fatal.
            if summary_text:
                success = vector_service.update_file_summary_in_vectors(
                    thread_id=thread_id,
                    file_id=file_id,
                    summary=summary_text
                )
                if success:
                    logger.info("✅ ChromaDB metadata 已同步 summary")
                else:
                    logger.warning("⚠️ ChromaDB metadata 同步 summary 失败,但不影响主流程")

            await ChatThreadFileService.update_file_status(
                conn, file_id, "completed", result.chunk_count
            )
            logger.info(f"聊天文件处理完成 ID: {file_id}, 块数: {result.chunk_count}, 摘要: {'已生成' if summary_text else '未生成'}")
        except Exception as e:
            logger.error(f"后台处理聊天文件异常 ID: {file_id}, 错误: {e}")
            # Marking the record failed may itself fail (e.g. the connection is
            # the thing that broke); log that instead of masking the original error.
            try:
                await ChatThreadFileService.update_file_status(
                    conn, file_id, "failed", 0
                )
            except Exception as status_error:
                logger.error(f"更新文件失败状态时出错 ID: {file_id}, 错误: {status_error}")
        finally:
            # Remove the temporary download regardless of the outcome.
            if local_file_path and os.path.exists(local_file_path):
                try:
                    os.remove(local_file_path)
                    logger.debug(f"已删除临时文件: {local_file_path}")
                except Exception as e:
                    logger.warning(f"删除临时文件失败: {e}")
# Accepted upload extensions mapped to the file_type stored with the record.
_CHAT_FILE_TYPES = {
    '.pdf': 'pdf',
    '.docx': 'docx',
    '.xlsx': 'xlsx',
    '.xls': 'xls',
    '.txt': 'txt',
    '.png': 'png',
    '.jpg': 'jpg',
    '.jpeg': 'jpeg',
    '.bmp': 'bmp',
}
# File types that go through image moderation before a record is created.
_MODERATED_IMAGE_TYPES = frozenset({'png', 'jpg', 'jpeg', 'bmp'})
# Upload size limit: 15MB.
_MAX_CHAT_FILE_SIZE = 15 * 1024 * 1024


def _discard_oss_object(oss_service, oss_object_name: str) -> None:
    """Best-effort removal of an uploaded OSS object; a failure is only logged."""
    try:
        oss_service.delete_file(oss_object_name)
    except Exception as cleanup_error:
        logger.warning(f"清理 OSS 文件失败: {oss_object_name}, 错误: {cleanup_error}")


@chat_file_router.post("/chat/thread/{thread_id}/upload", response_model=BaseResponse, summary="上传文件到聊天对话")
async def upload_chat_file(
    thread_id: str,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a file to a chat thread and schedule its vectorization.

    The thread is created automatically when it does not exist yet. The file is
    validated (name, extension, size), uploaded to OSS, moderated when it is an
    image, recorded in the database, and handed to a background task. If
    moderation or record creation fails, the uploaded OSS object is removed so
    no orphan is left behind.

    Args:
        thread_id: Chat thread ID.
        background_tasks: FastAPI background task registry.
        file: The uploaded file.
        current_user: The authenticated user.

    Returns:
        BaseResponse: Contains the created file record's information.

    Raises:
        HTTPException: 400 for a missing file name, unsupported type, oversized
            file, rejected image or a ValueError raised while creating the
            record; 403 when the thread belongs to another user; 503 when OSS
            or the moderation service is unavailable; 500 otherwise.
    """
    try:
        # Verify the thread belongs to the current user, creating it if missing.
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            thread_info = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not thread_info:
                logger.info(f"会话不存在,自动创建会话记录: thread_id={thread_id}, user_id={current_user.id}")
                try:
                    await conn.execute(
                        """
                        INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count)
                        VALUES ($1, $2, $3, $4, 0)
                        """,
                        thread_id,
                        current_user.id,
                        "新对话",
                        ""
                    )
                    logger.info(f"成功创建新会话: thread_id={thread_id}")
                    # Re-read the row that was just inserted.
                    thread_info = await conn.fetchrow(
                        """
                        SELECT id, user_id FROM chat_threads
                        WHERE thread_id = $1 AND is_deleted = FALSE
                        """,
                        thread_id
                    )
                except Exception as e:
                    logger.error(f"创建会话记录失败: {e}")
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"创建会话失败: {str(e)}"
                    )
            if thread_info['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

        logger.info(f"📤 开始上传文件到聊天 {thread_id}: {file.filename}, 用户: {current_user.username}")

        # The client controls the file name: reject a missing one and strip any
        # directory components so it cannot escape the thread's OSS prefix.
        safe_filename = Path((file.filename or "").replace("\\", "/")).name
        if not safe_filename:
            logger.warning(f"❌ 上传请求缺少有效文件名: thread_id={thread_id}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="文件名不能为空"
            )

        # Check the file type.
        file_ext = Path(safe_filename).suffix.lower()
        if file_ext not in _CHAT_FILE_TYPES:
            logger.warning(f"❌ 不支持的文件类型: {file_ext}, 文件: {file.filename}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"不支持的文件类型: {file_ext},支持的类型: {', '.join(_CHAT_FILE_TYPES)}"
            )
        file_type = _CHAT_FILE_TYPES[file_ext]
        logger.info(f"📋 文件类型识别: {file_ext} -> {file_type}")

        # Read the content and enforce the size limit.
        content = await file.read()
        file_size = len(content)
        file_size_mb = file_size / (1024 * 1024)
        if file_size > _MAX_CHAT_FILE_SIZE:
            logger.warning(f"❌ 文件大小超限: {file_size_mb:.2f}MB (最大 15MB), 文件: {file.filename}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"文件大小超过限制,当前: {file_size_mb:.2f}MB最大允许: 15MB"
            )
        logger.info(f"✅ 文件大小验证通过: {file_size_mb:.2f}MB ({file_size} bytes)")

        # A millisecond timestamp prefix keeps object names unique per thread.
        timestamp = int(time.time() * 1000)
        unique_filename = f"{timestamp}_{safe_filename}"
        oss_object_name = f"thread_{thread_id}/{unique_filename}"

        oss_service = get_oss_service()
        if not oss_service.enabled:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="OSS 服务未启用,无法上传文件"
            )

        logger.info(f"☁️ 上传文件到 OSS: {oss_object_name}")
        file_url = oss_service.upload_file_from_bytes(
            content,
            oss_object_name,
            file.filename
        )
        if not file_url:
            logger.error(f"❌ OSS 上传失败: thread_id={thread_id}, filename={file.filename}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="文件上传到 OSS 失败"
            )
        # The OSS URL is what gets stored as the file path.
        file_path = file_url
        logger.info(f"✅ 文件已上传到 OSS: {file_url}")

        # From here until the record exists, any failure must remove the
        # uploaded object; the single cleanup point is the except below.
        try:
            # Image moderation happens before the file record is created.
            if file_type in _MODERATED_IMAGE_TYPES:
                from core.dependencies import get_moderation_service
                from core.config import settings
                from core.exceptions import ModerationError
                from models.moderation import ModerationDecision
                moderation_service = await get_moderation_service()
                if moderation_service and settings.moderation_enabled:
                    try:
                        logger.info(f"🔍 开始图片审核: {file.filename}")
                        result = await moderation_service.moderate_image(
                            image_source=file_url,
                            source_type="url",
                            request_id=f"chat_file_{timestamp}"
                        )
                        if result.decision == ModerationDecision.BLOCK:
                            logger.warning(
                                f"❌ 图片审核不通过: {file.filename}, "
                                f"原因: {result.message}, "
                                f"标签: {[label.label for label in result.labels]}"
                            )
                            raise HTTPException(
                                status_code=status.HTTP_400_BAD_REQUEST,
                                detail=result.message or "图片包含不当内容,无法上传"
                            )
                        logger.info(
                            f"✅ 图片审核通过: {file.filename}, "
                            f"决策: {result.decision.value}"
                        )
                    except ModerationError as e:
                        logger.error(f"❌ 图片审核服务错误: {e}")
                        raise HTTPException(
                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="图片审核服务暂时不可用,请稍后重试"
                        ) from e

            # Create the file record; file_path holds the OSS URL.
            logger.info(f"📝 创建文件记录: {file.filename}")
            async with pool.acquire() as conn:
                file_record = await ChatThreadFileService.create_file_record(
                    conn,
                    thread_id,
                    current_user.id,
                    file.filename,
                    file_path,
                    file_size,
                    file_type
                )
        except Exception:
            _discard_oss_object(oss_service, oss_object_name)
            raise
        logger.info(f"✅ 文件记录已创建: ID={file_record.id}, 状态={file_record.status}")

        # Vectorization runs in the background after the response is sent.
        logger.info(f"🚀 添加后台向量化任务: file_id={file_record.id}, type={file_type}")
        background_tasks.add_task(
            process_chat_file_background,
            file_record.id,
            file_path,
            thread_id,
            file_type
        )

        # Note: the file is not linked to a message at upload time; it is
        # associated with the user's next message when that message is sent.
        return BaseResponse(
            code=200,
            msg="文件上传成功,正在处理中",
            data=ChatThreadFileUploadResponse(
                id=file_record.id,
                file_name=file_record.file_name,
                file_size=file_record.file_size,
                status=file_record.status,
                chunk_count=file_record.chunk_count,
                created_at=file_record.created_at,
                file_url=file_url
            ).dict()
        )
    except HTTPException:
        raise
    except ValueError as e:
        # Business-rule failures such as a duplicate file name.
        logger.warning(f"文件上传验证失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"上传文件失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"上传文件失败: {str(e)}"
        )
@chat_file_router.get("/chat/thread/{thread_id}/files", response_model=BaseResponse, summary="获取聊天对话文件列表")
async def get_chat_thread_files(
    thread_id: str,
    page: int = Query(1, ge=1, description="页码,从 1 开始"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量,最大 100"),
    current_user: User = Depends(get_current_user)
):
    """
    Return one page of the files attached to a chat thread.

    Args:
        thread_id: Chat thread ID.
        page: 1-based page number.
        page_size: Number of items per page (at most 100).
        current_user: The authenticated user.

    Returns:
        BaseResponse: Carries the total count and the page of file entries.

    Raises:
        HTTPException: 404 when the thread does not exist, 403 when it belongs
            to another user, 500 on any unexpected error.
    """
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as conn:
            # Ownership check: the thread must exist and belong to the caller.
            owner_row = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not owner_row:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="会话不存在"
                )
            if owner_row['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

            records, total = await ChatThreadFileService.get_files_by_thread(
                conn, thread_id, current_user.id, page, page_size
            )

            # Serialize each record; file_path is exposed as file_url because
            # it stores the OSS URL.
            items = []
            for record in records:
                entry = ChatThreadFileUploadResponse(
                    id=record.id,
                    file_name=record.file_name,
                    file_size=record.file_size,
                    status=record.status,
                    chunk_count=record.chunk_count,
                    created_at=record.created_at,
                    file_url=record.file_path
                )
                items.append(entry.dict())

            listing = ChatThreadFileListResponse(total=total, items=items)
            return BaseResponse(
                code=200,
                msg="获取文件列表成功",
                data=listing.dict()
            )
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"获取文件列表失败: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取文件列表失败"
        )
@chat_file_router.get("/chat/thread/{thread_id}/files/{file_id}/status", response_model=BaseResponse, summary="查询文件处理状态")
async def get_file_processing_status(
    thread_id: str,
    file_id: int,
    current_user: User = Depends(get_current_user)
):
    """
    Report the processing status of a file (intended for front-end polling).

    Args:
        thread_id: Chat thread ID.
        file_id: File ID.
        current_user: The authenticated user.

    Returns:
        BaseResponse: data holds id, file_name, file_type, status, chunk_count,
        and created_at / updated_at as ISO-8601 strings (None when unset).

    Raises:
        HTTPException: 404 when the thread or the file does not exist (or the
            file belongs to a different thread), 403 when the thread belongs to
            another user, 500 on any unexpected error.
    """
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as conn:
            # Ownership check: the thread must exist and belong to the caller.
            owner_row = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not owner_row:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="会话不存在"
                )
            if owner_row['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

            record = await ChatThreadFileService.get_file_by_id(
                conn, file_id, current_user.id
            )
            # A file that lives in a different thread is reported as missing.
            if not record or record.thread_id != thread_id:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="文件不存在"
                )

            created_at = record.created_at.isoformat() if record.created_at else None
            updated_at = record.updated_at.isoformat() if record.updated_at else None
            payload = {
                "id": record.id,
                "file_name": record.file_name,
                "file_type": record.file_type,
                "status": record.status,
                "chunk_count": record.chunk_count,
                "created_at": created_at,
                "updated_at": updated_at,
            }
            return BaseResponse(
                code=200,
                msg="获取文件状态成功",
                data=payload
            )
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"获取文件状态失败: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取文件状态失败"
        )
@chat_file_router.delete("/chat/thread/{thread_id}/files/{file_id}", response_model=BaseResponse, summary="删除聊天对话文件")
async def delete_chat_thread_file(
    thread_id: str,
    file_id: int,
    current_user: User = Depends(get_current_user)
):
    """
    Delete a file from a chat thread.

    The file record is deleted through ChatThreadFileService.delete_file, then
    the file's vectors and its OSS object are removed on a best-effort basis:
    a failure in either cleanup step is logged and does not fail the request.

    Args:
        thread_id: Chat thread ID.
        file_id: File ID.
        current_user: The authenticated user.

    Returns:
        BaseResponse: data holds the file id and the number of vector IDs
        reported for the deleted file.

    Raises:
        HTTPException: 404 when the thread or the file does not exist (or the
            file belongs to a different thread), 403 when the thread belongs to
            another user, 500 on any unexpected error.
    """
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            # Verify the thread exists and belongs to the current user.
            thread_info = await conn.fetchrow(
                """
                SELECT id, user_id FROM chat_threads
                WHERE thread_id = $1 AND is_deleted = FALSE
                """,
                thread_id
            )
            if not thread_info:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="会话不存在"
                )
            if thread_info['user_id'] != current_user.id:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="无权限访问该会话"
                )

            file = await ChatThreadFileService.get_file_by_id(
                conn, file_id, current_user.id
            )
            if not file or file.thread_id != thread_id:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="文件不存在"
                )

            # Delete the record and collect the IDs of its vectors.
            success, vector_ids = await ChatThreadFileService.delete_file(
                conn, file_id, current_user.id
            )
            if not success:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="文件不存在"
                )

        # Normalize a missing ID list so len() below cannot fail after the
        # record has already been deleted.
        # NOTE(review): assumes delete_file may return None/empty for files
        # without vectors — confirm against the service.
        vector_ids = vector_ids or []

        # External cleanup runs after the connection went back to the pool,
        # so slow vector-store / OSS calls do not hold a database connection.
        if vector_ids:
            try:
                vector_service = get_vector_service()
                vector_service.delete_thread_vectors(thread_id, vector_ids)
                logger.info(f"已删除 {len(vector_ids)} 个向量")
            except Exception as e:
                logger.warning(f"删除向量库中的向量失败: {e}")

        # Remove the stored object from OSS.
        try:
            oss_service = get_oss_service()
            if not oss_service.enabled:
                logger.warning("OSS 服务未启用,无法删除物理文件")
            elif file.file_path.startswith(('http://', 'https://')):
                oss_object_name = oss_service.extract_object_name_from_url(file.file_path, thread_id=thread_id)
                if oss_object_name:
                    oss_service.delete_file(oss_object_name)
                    logger.info(f"已删除 OSS 文件: {oss_object_name}")
                else:
                    logger.warning(f"无法从 OSS URL 提取对象名称: {file.file_path}")
            else:
                logger.warning(f"文件路径不是 OSS URL 格式: {file.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败: {e}")

        return BaseResponse(
            code=200,
            msg="删除文件成功",
            data={"id": file_id, "vector_count": len(vector_ids)}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除文件失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除文件失败"
        )