""" 聊天对话文件服务 """ import os import json from typing import Optional, List, Tuple from pathlib import Path import asyncpg from datetime import datetime from models.chat_thread_file import ChatThreadFile, ChatThreadChunk from logger.logging import get_logger logger = get_logger(__name__) class ChatThreadFileService: """聊天对话文件服务类""" @staticmethod async def create_file_record( conn: asyncpg.Connection, thread_id: str, user_id: int, file_name: str, file_path: str, file_size: int, file_type: str = "pdf" ) -> ChatThreadFile: """ 创建文件记录 Args: conn: 数据库连接 thread_id: 会话线程 ID user_id: 用户 ID file_name: 文件名 file_path: 文件路径 file_size: 文件大小 file_type: 文件类型 Returns: ChatThreadFile: 创建的文件记录 """ try: # 检查文件名是否已存在(同一 thread_id 下) existing = await conn.fetchrow( """ SELECT id FROM chat_thread_file WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE """, thread_id, file_name ) if existing: raise ValueError(f"文件 '{file_name}' 已存在于该对话中") # 插入文件记录 row = await conn.fetchrow( """ INSERT INTO chat_thread_file (thread_id, user_id, file_name, file_path, file_size, file_type, status) VALUES ($1, $2, $3, $4, $5, $6, 'processing') RETURNING id, thread_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at """, thread_id, user_id, file_name, file_path, file_size, file_type ) logger.info(f"创建文件记录: {file_name}, thread_id: {thread_id}") return ChatThreadFile(**dict(row)) except ValueError: raise except Exception as e: logger.error(f"创建文件记录失败: {e}") raise Exception(f"创建文件记录失败: {str(e)}") @staticmethod async def update_file_status( conn: asyncpg.Connection, file_id: int, status: str, chunk_count: int = 0 ) -> bool: """ 更新文件状态 Args: conn: 数据库连接 file_id: 文件 ID status: 状态(processing/completed/failed) chunk_count: 分块数量 Returns: bool: 是否更新成功 """ try: result = await conn.execute( """ UPDATE chat_thread_file SET status = $1, chunk_count = $2 WHERE id = $3 """, status, chunk_count, file_id ) return result == "UPDATE 1" except Exception as e: logger.error(f"更新文件状态失败: {e}") return False @staticmethod async def save_chunks( conn: asyncpg.Connection, file_id: int, thread_id: str, chunks: List[Tuple[int, str, dict, str]], summary: Optional[str] = None ) -> int: """ 批量保存文档块 Args: conn: 数据库连接 file_id: 文件 ID thread_id: 会话线程 ID chunks: 文档块列表 [(chunk_index, content, metadata, vector_id), ...] summary: 文件摘要(可选) Returns: int: 保存的块数量 """ try: # 批量插入(每个chunk都保存summary,便于独立检索) records = [ (file_id, thread_id, chunk_index, content, json.dumps(metadata), vector_id, summary) for chunk_index, content, metadata, vector_id in chunks ] await conn.executemany( """ INSERT INTO chat_thread_chunk (file_id, thread_id, chunk_index, content, metadata, vector_id, summary) VALUES ($1, $2, $3, $4, $5, $6, $7) """, records ) logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else '无'}") return len(chunks) except Exception as e: logger.error(f"保存文档块失败: {e}") raise Exception(f"保存文档块失败: {str(e)}") @staticmethod async def get_file_by_id( conn: asyncpg.Connection, file_id: int, user_id: int ) -> Optional[ChatThreadFile]: """ 根据 ID 获取文件 Args: conn: 数据库连接 file_id: 文件 ID user_id: 用户 ID(用于权限验证) Returns: Optional[ChatThreadFile]: 文件对象,如果不存在则返回 None """ try: row = await conn.fetchrow( """ SELECT id, thread_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM chat_thread_file WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE """, file_id, user_id ) if row: return ChatThreadFile(**dict(row)) return None except Exception as e: logger.error(f"获取文件失败: {e}") raise Exception(f"获取文件失败: {str(e)}") @staticmethod async def get_recent_files_with_summary( conn: asyncpg.Connection, thread_id: str, limit: int = 10 ) -> List[dict]: """ 获取会话中最近上传的文件及其摘要(无时间限制) Args: conn: 数据库连接 thread_id: 会话线程 ID limit: 限制返回数量 Returns: List[dict]: 文件列表,包含摘要信息 [{"file_name": "xxx", "file_type": "png", "summary": "xxx"}, ...] """ try: rows = await conn.fetch( """ SELECT f.file_name, f.file_type, c.summary FROM chat_thread_file f LEFT JOIN chat_thread_chunk c ON f.id = c.file_id AND c.chunk_index = 0 WHERE f.thread_id = $1 AND f.is_deleted = FALSE AND f.status = 'completed' AND c.summary IS NOT NULL AND c.summary != '' ORDER BY f.created_at DESC LIMIT $2 """, thread_id, limit ) result = [] for row in rows: result.append({ "file_name": row['file_name'], "file_type": row['file_type'], "summary": row['summary'] }) return result except Exception as e: logger.error(f"获取文件摘要失败: {e}") return [] @staticmethod async def get_files_by_thread( conn: asyncpg.Connection, thread_id: str, user_id: int, page: int = 1, page_size: int = 20 ) -> Tuple[List[ChatThreadFile], int]: """ 获取会话的文件列表 Args: conn: 数据库连接 thread_id: 会话线程 ID user_id: 用户 ID(用于权限验证) page: 页码(从 1 开始) page_size: 每页数量 Returns: Tuple[List[ChatThreadFile], int]: (文件列表, 总数量) """ try: # 计算偏移量 offset = (page - 1) * page_size # 获取总数 total = await conn.fetchval( """ SELECT COUNT(*) FROM chat_thread_file WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE """, thread_id, user_id ) # 获取列表 rows = await conn.fetch( """ SELECT id, thread_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM chat_thread_file WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE ORDER BY created_at DESC LIMIT $3 OFFSET $4 """, thread_id, user_id, page_size, offset ) files = [ChatThreadFile(**dict(row)) for row in rows] return files, total except Exception as e: logger.error(f"获取文件列表失败: {e}") raise Exception(f"获取文件列表失败: {str(e)}") @staticmethod async def get_all_files_by_thread( conn: asyncpg.Connection, thread_id: str ) -> List[ChatThreadFile]: """ 获取会话的所有文件(用于删除会话时清理) Args: conn: 数据库连接 thread_id: 会话线程 ID Returns: List[ChatThreadFile]: 文件列表 """ try: rows = await conn.fetch( """ SELECT id, thread_id, user_id, file_name, file_path, file_size, file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at FROM chat_thread_file WHERE thread_id = $1 AND is_deleted = FALSE """, thread_id ) return [ChatThreadFile(**dict(row)) for row in rows] except Exception as e: logger.error(f"获取文件列表失败: {e}") raise Exception(f"获取文件列表失败: {str(e)}") @staticmethod async def get_thread_all_vector_ids( conn: asyncpg.Connection, thread_id: str ) -> List[str]: """ 获取会话的所有向量 ID(用于删除向量) Args: conn: 数据库连接 thread_id: 会话线程 ID Returns: List[str]: 向量 ID 列表 """ try: rows = await conn.fetch( """ SELECT vector_id FROM chat_thread_chunk WHERE thread_id = $1 AND vector_id IS NOT NULL """, thread_id ) return [row['vector_id'] for row in rows if row['vector_id']] except Exception as e: logger.error(f"获取向量 ID 列表失败: {e}") return [] @staticmethod async def get_file_vector_ids( conn: asyncpg.Connection, file_id: int ) -> List[str]: """ 获取文件的所有向量 ID Args: conn: 数据库连接 file_id: 文件 ID Returns: List[str]: 向量 ID 列表 """ try: rows = await conn.fetch( """ SELECT vector_id FROM chat_thread_chunk WHERE file_id = $1 AND vector_id IS NOT NULL """, file_id ) return [row['vector_id'] for row in rows if row['vector_id']] except Exception as e: logger.error(f"获取向量 ID 列表失败: {e}") return [] @staticmethod async def get_file_chunks_from_db( conn: asyncpg.Connection, file_id: int ) -> List[dict]: """ 从 PostgreSQL 获取文件的所有 chunks(包括 summary) 用于注入完整内容到 AI 上下文 Args: conn: 数据库连接 file_id: 文件 ID Returns: List[dict]: [{"content": str, "summary": str, "chunk_index": int}] """ try: rows = await conn.fetch( """ SELECT chunk_index, content, summary FROM chat_thread_chunk WHERE file_id = $1 ORDER BY chunk_index """, file_id ) chunks = [ { "chunk_index": row['chunk_index'], "content": row['content'], "summary": row['summary'] or '' } for row in rows ] logger.info(f"从数据库获取文件chunks: file_id={file_id}, chunks数量={len(chunks)}") return chunks except Exception as e: logger.error(f"从数据库获取文件chunks失败: {e}") return [] @staticmethod async def delete_file( conn: asyncpg.Connection, file_id: int, user_id: int ) -> Tuple[bool, List[str]]: """ 删除文件(软删除),同时返回向量 ID 列表 Args: conn: 数据库连接 file_id: 文件 ID user_id: 用户 ID(用于权限验证) Returns: Tuple[bool, List[str]]: (是否删除成功, 向量 ID 列表) """ try: # 先获取向量 ID 列表 vector_ids = await ChatThreadFileService.get_file_vector_ids(conn, file_id) # 检查文件是否存在且属于该用户 existing = await conn.fetchrow( """ SELECT id FROM chat_thread_file WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE """, file_id, user_id ) if not existing: return False, [] # 软删除文件 await conn.execute( """ UPDATE chat_thread_file SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = $1 """, file_id ) # 删除文档块(物理删除,因为文件已删除) await conn.execute( """ DELETE FROM chat_thread_chunk WHERE file_id = $1 """, file_id ) logger.info(f"删除文件成功: file_id={file_id}, 向量数={len(vector_ids)}") return True, vector_ids except Exception as e: logger.error(f"删除文件失败: {e}") raise Exception(f"删除文件失败: {str(e)}") @staticmethod async def delete_thread_all_chunks( conn: asyncpg.Connection, thread_id: str ) -> int: """ 删除会话的所有文档块(用于删除会话时清理) Args: conn: 数据库连接 thread_id: 会话线程 ID Returns: int: 删除的块数量 """ try: result = await conn.execute( """ DELETE FROM chat_thread_chunk WHERE thread_id = $1 """, thread_id ) # 解析删除的行数 deleted_count = int(result.split()[-1]) if result else 0 logger.info(f"删除会话 {thread_id} 的 {deleted_count} 个文档块") return deleted_count except Exception as e: logger.error(f"删除文档块失败: {e}") return 0