""" 聊天消息文件关联服务 """ from typing import Optional, List import asyncpg from datetime import datetime from logger.logging import get_logger logger = get_logger(__name__) class ChatMessageFileService: """聊天消息文件关联服务类""" @staticmethod async def create_message_file_association( conn: asyncpg.Connection, thread_id: str, checkpoint_id: str, message_index: int, file_id: int ) -> int: """ 创建消息和文件的关联关系 Args: conn: 数据库连接 thread_id: 会话线程 ID checkpoint_id: checkpoint ID message_index: 消息在 messages 列表中的索引 file_id: 文件 ID Returns: int: 关联记录 ID """ try: row = await conn.fetchrow( """ INSERT INTO chat_message_file (thread_id, checkpoint_id, message_index, file_id) VALUES ($1, $2, $3, $4) ON CONFLICT (checkpoint_id, message_index, file_id) DO NOTHING RETURNING id """, thread_id, checkpoint_id, message_index, file_id ) if row: logger.info(f"创建消息文件关联: thread_id={thread_id}, checkpoint_id={checkpoint_id}, message_index={message_index}, file_id={file_id}") return row['id'] return None except Exception as e: logger.error(f"创建消息文件关联失败: {e}") raise Exception(f"创建消息文件关联失败: {str(e)}") @staticmethod async def get_files_by_message( conn: asyncpg.Connection, checkpoint_id: str, message_index: int ) -> List[dict]: """ 获取消息关联的文件列表 Args: conn: 数据库连接 checkpoint_id: checkpoint ID message_index: 消息索引 Returns: List[dict]: 文件信息列表 """ try: rows = await conn.fetch( """ SELECT cmf.id, cmf.file_id, ctf.file_name, ctf.file_size, ctf.file_type, ctf.status, ctf.created_at FROM chat_message_file cmf INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id WHERE cmf.checkpoint_id = $1 AND cmf.message_index = $2 AND ctf.is_deleted = FALSE ORDER BY cmf.created_at ASC """, checkpoint_id, message_index ) return [dict(row) for row in rows] except Exception as e: logger.error(f"获取消息文件列表失败: {e}") return [] @staticmethod async def get_files_by_checkpoint( conn: asyncpg.Connection, checkpoint_id: str ) -> dict: """ 获取 checkpoint 中所有消息关联的文件 Args: conn: 数据库连接 checkpoint_id: checkpoint ID Returns: dict: {message_index: [file_info, ...], ...} """ try: rows = await conn.fetch( """ SELECT cmf.message_index, cmf.file_id, ctf.file_name, ctf.file_size, ctf.file_type, ctf.status, ctf.created_at FROM chat_message_file cmf INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id WHERE cmf.checkpoint_id = $1 AND ctf.is_deleted = FALSE ORDER BY cmf.message_index ASC, cmf.created_at ASC """, checkpoint_id ) # 按 message_index 分组 result = {} for row in rows: message_index = row['message_index'] if message_index not in result: result[message_index] = [] result[message_index].append({ 'file_id': row['file_id'], 'file_name': row['file_name'], 'file_size': row['file_size'], 'file_type': row['file_type'], 'status': row['status'], 'created_at': row['created_at'].isoformat() if row['created_at'] else None }) return result except Exception as e: logger.error(f"获取 checkpoint 文件列表失败: {e}") return {} @staticmethod async def get_all_files_by_thread( conn: asyncpg.Connection, thread_id: str, latest_checkpoint_id: str ) -> dict: """ 获取该 thread_id 下所有 checkpoint 的文件关联,并映射到最新 checkpoint 的消息索引 由于文件可能在不同的 checkpoint 中关联,但最新的 checkpoint 包含所有历史消息, 所以需要查询所有 checkpoint 的文件关联,然后根据 checkpoint_id 匹配 Args: conn: 数据库连接 thread_id: 会话线程 ID latest_checkpoint_id: 最新的 checkpoint ID Returns: dict: {message_index: [file_info, ...], ...} 其中 message_index 是相对于最新 checkpoint 的 """ try: # 查询该 thread_id 下的所有文件关联,包含 file_path (file_url) rows = await conn.fetch( """ SELECT cmf.checkpoint_id, cmf.message_index, cmf.file_id, ctf.file_name, ctf.file_size, ctf.file_type, ctf.file_path, ctf.status, ctf.created_at FROM chat_message_file cmf INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id WHERE cmf.thread_id = $1 AND ctf.is_deleted = FALSE ORDER BY cmf.checkpoint_id ASC, cmf.message_index ASC, cmf.created_at ASC """, thread_id ) # 按 checkpoint_id 和 message_index 分组 # 由于 LangGraph 的 checkpoint 是累积的,所有 checkpoint 的 message_index 应该都是相对于同一个消息列表的 # 所以我们可以直接使用 message_index result = {} for row in rows: checkpoint_id = row['checkpoint_id'] message_index = row['message_index'] file_id = row['file_id'] file_name = row['file_name'] logger.debug(f"文件关联: checkpoint_id={checkpoint_id}, message_index={message_index}, file_id={file_id}, file_name={file_name}") if message_index not in result: result[message_index] = [] result[message_index].append({ 'file_id': row['file_id'], 'file_name': row['file_name'], 'file_size': row['file_size'], 'file_type': row['file_type'], 'file_url': row['file_path'], # file_path 存储的是 OSS URL,作为 file_url 返回 'status': row['status'], 'created_at': row['created_at'].isoformat() if row['created_at'] else None }) logger.info(f"查询到文件关联映射: {result}") return result except Exception as e: logger.error(f"获取 thread 所有文件关联失败: {e}") return {} @staticmethod async def delete_message_file_association( conn: asyncpg.Connection, checkpoint_id: str, message_index: int, file_id: int ) -> bool: """ 删除消息和文件的关联关系 Args: conn: 数据库连接 checkpoint_id: checkpoint ID message_index: 消息索引 file_id: 文件 ID Returns: bool: 是否删除成功 """ try: result = await conn.execute( """ DELETE FROM chat_message_file WHERE checkpoint_id = $1 AND message_index = $2 AND file_id = $3 """, checkpoint_id, message_index, file_id ) return result == "DELETE 1" except Exception as e: logger.error(f"删除消息文件关联失败: {e}") return False @staticmethod async def delete_thread_associations( conn: asyncpg.Connection, thread_id: str ) -> int: """ 删除会话的所有消息文件关联(用于删除会话时清理) Args: conn: 数据库连接 thread_id: 会话线程 ID Returns: int: 删除的记录数 """ try: result = await conn.execute( """ DELETE FROM chat_message_file WHERE thread_id = $1 """, thread_id ) deleted_count = int(result.split()[-1]) if result else 0 logger.info(f"删除会话 {thread_id} 的 {deleted_count} 条消息文件关联") return deleted_count except Exception as e: logger.error(f"删除消息文件关联失败: {e}") return 0 @staticmethod async def get_unlinked_files( conn: asyncpg.Connection, thread_id: str ) -> List[dict]: """ 获取会话中未关联到消息的文件列表(通过关联查询) 这些文件上传了但还没有关联到任何消息,需要在历史消息中显示 Args: conn: 数据库连接 thread_id: 会话线程 ID Returns: List[dict]: 未关联的文件信息列表,按创建时间升序排列 """ try: rows = await conn.fetch( """ SELECT ctf.id as file_id, ctf.file_name, ctf.file_size, ctf.file_type, ctf.file_path, ctf.status, ctf.created_at FROM chat_thread_file ctf WHERE ctf.thread_id = $1 AND ctf.is_deleted = FALSE AND ctf.id NOT IN ( SELECT DISTINCT cmf.file_id FROM chat_message_file cmf WHERE cmf.thread_id = $1 ) ORDER BY ctf.created_at ASC """, thread_id ) return [ { 'file_id': row['file_id'], 'file_name': row['file_name'], 'file_size': row['file_size'], 'file_type': row['file_type'], 'file_url': row['file_path'], # file_path 存储的是 OSS URL,作为 file_url 返回 'status': row['status'], 'created_at': row['created_at'].isoformat() if row['created_at'] else None } for row in rows ] except Exception as e: logger.error(f"获取未关联文件列表失败: {e}") return []