"""
Chat message service (backed by the ``chat_messages`` table).

Saves and queries original user messages and AI responses, replacing the
earlier approach of parsing them out of checkpoints.
"""

import json
from typing import List, Dict, Any, Optional
import asyncpg
from logger.logging import get_logger

logger = get_logger(__name__)


class ChatMessageService:
    """Stateless helpers that read and write rows of ``chat_messages``.

    Every public method is a ``staticmethod`` that receives the asyncpg
    connection to operate on; the class itself holds no state.
    """

    @staticmethod
    def _encode_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
        """Serialize ``metadata`` to JSON text; ``None``/empty becomes ``None``."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    def _decode_metadata(raw: Any) -> Any:
        """Turn a stored ``metadata`` value back into a Python object.

        A falsy value yields ``{}``. Text is parsed as JSON. Anything else is
        assumed to have been decoded already (for example by a JSON codec
        registered on the connection) and is returned unchanged.
        """
        if not raw:
            return {}
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    @staticmethod
    def _format_timestamp(value: Any) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when ``value`` is falsy."""
        return value.isoformat() if value else None

    @staticmethod
    async def save_user_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        injected_content: Optional[str] = None,
        has_files: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Upsert a user message into the ``chat_messages`` table.

        An existing row with the same ``(checkpoint_id, message_index)`` has
        its content, injected content, file flag and metadata overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: The user's original question.
            injected_content: Full content injected for the AI (including
                file contents).
            has_files: Whether files are associated with the message.
            metadata: Extra information, stored as JSON.

        Returns:
            int: The message ID.

        Raises:
            Exception: If the statement fails; the original error is chained
                as the cause.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                (thread_id, checkpoint_id, message_index, role, content, injected_content, has_files, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    injected_content = EXCLUDED.injected_content,
                    has_files = EXCLUDED.has_files,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'user',
                content,
                injected_content,
                has_files,
                ChatMessageService._encode_metadata(metadata)
            )

            message_id = row['id']
            logger.info(f"✅ 保存用户消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
            return message_id

        except Exception as e:
            logger.error(f"保存用户消息失败: {e}")
            raise Exception(f"保存用户消息失败: {str(e)}") from e

    @staticmethod
    async def save_assistant_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Upsert an AI response message into the ``chat_messages`` table.

        An existing row with the same ``(checkpoint_id, message_index)`` has
        its content and metadata overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: The AI response content.
            metadata: Extra information (token usage, model name, reasoning
                content, ...), stored as JSON.

        Returns:
            int: The message ID.

        Raises:
            Exception: If the statement fails; the original error is chained
                as the cause.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                (thread_id, checkpoint_id, message_index, role, content, metadata)
                VALUES ($1, $2, $3, $4, $5, $6)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'assistant',
                content,
                ChatMessageService._encode_metadata(metadata)
            )

            message_id = row['id']
            logger.info(f"✅ 保存AI消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
            return message_id

        except Exception as e:
            logger.error(f"保存AI消息失败: {e}")
            raise Exception(f"保存AI消息失败: {str(e)}") from e

    @staticmethod
    async def save_tool_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Upsert a tool message into the ``chat_messages`` table.

        An existing row with the same ``(checkpoint_id, message_index)`` has
        its content, tool name and metadata overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: The tool message content.
            name: Tool name (e.g. text_to_poster, internet_search).
            metadata: Extra information (tool arguments, ...), stored as JSON.

        Returns:
            int: The message ID.

        Raises:
            Exception: If the statement fails; the original error is chained
                as the cause.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                (thread_id, checkpoint_id, message_index, role, content, name, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    name = EXCLUDED.name,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'tool',
                content,
                name,
                ChatMessageService._encode_metadata(metadata)
            )

            message_id = row['id']
            logger.info(f"✅ 保存工具消息: message_id={message_id}, thread_id={thread_id}, index={message_index}, tool_name={name}")
            return message_id

        except Exception as e:
            logger.error(f"保存工具消息失败: {e}")
            raise Exception(f"保存工具消息失败: {str(e)}") from e

    @staticmethod
    async def get_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str,
        limit: Optional[int] = None,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Fetch the messages of a conversation, ordered by ``message_index``.

        Rows are de-duplicated with ``DISTINCT ON (message_index)``, keeping
        the most recently created row per index, so duplicated historical
        rows are returned only once.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            limit: Maximum number of messages; ``None`` or ``0`` means no
                limit.
            offset: Number of de-duplicated messages to skip (pagination).
                Applied whether or not ``limit`` is given.

        Returns:
            List[Dict]: The messages.

        Raises:
            Exception: If the query or metadata decoding fails; the original
                error is chained as the cause.
        """
        try:
            # A falsy limit keeps its historical meaning of "no limit".
            # PostgreSQL treats LIMIT NULL as LIMIT ALL, so a single query
            # serves both cases and the offset is always honoured.
            effective_limit = limit if limit else None

            rows = await conn.fetch(
                """
                SELECT * FROM (
                    SELECT DISTINCT ON (message_index)
                        id, thread_id, checkpoint_id, message_index, role,
                        content, injected_content, has_files, name, metadata, created_at
                    FROM chat_messages
                    WHERE thread_id = $1
                    ORDER BY message_index ASC, created_at DESC
                ) AS deduplicated
                ORDER BY message_index ASC
                LIMIT $2 OFFSET $3
                """,
                thread_id,
                effective_limit,
                offset
            )

            messages = [
                {
                    'id': row['id'],
                    'thread_id': row['thread_id'],
                    'checkpoint_id': row['checkpoint_id'],
                    'message_index': row['message_index'],
                    'role': row['role'],
                    'content': row['content'],
                    'injected_content': row['injected_content'],
                    'has_files': row['has_files'],
                    'name': row['name'],  # tool name (for tool-role messages)
                    'metadata': ChatMessageService._decode_metadata(row['metadata']),
                    'created_at': ChatMessageService._format_timestamp(row['created_at'])
                }
                for row in rows
            ]

            logger.info(f"查询会话消息: thread_id={thread_id}, 消息数量={len(messages)}")
            return messages

        except Exception as e:
            logger.error(f"查询会话消息失败: {e}")
            raise Exception(f"查询会话消息失败: {str(e)}") from e

    @staticmethod
    async def get_message_count(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """
        Count the messages of a conversation.

        Counts distinct ``message_index`` values so the total agrees with
        the de-duplicated list returned by ``get_messages_by_thread``.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: The number of messages; ``0`` if the query fails (the
            failure is logged, not raised).
        """
        try:
            count = await conn.fetchval(
                "SELECT COUNT(DISTINCT message_index) FROM chat_messages WHERE thread_id = $1",
                thread_id
            )
            return count or 0
        except Exception as e:
            # Deliberately best-effort: callers get 0 instead of an error.
            logger.error(f"获取消息总数失败: {e}")
            return 0

    @staticmethod
    async def search_messages(
        conn: asyncpg.Connection,
        thread_id: str,
        keyword: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """
        Full-text search over the messages of a conversation.

        The keyword is treated as plain text (``plainto_tsquery``), so
        spaces and punctuation in user input cannot cause a tsquery syntax
        error. Results are ordered by rank, then by ``message_index``
        descending.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            keyword: Search keyword(s); a blank keyword matches nothing.
            limit: Maximum number of results.

        Returns:
            List[Dict]: The matching messages, each with a float ``rank``.

        Raises:
            Exception: If the query or metadata decoding fails; the original
                error is chained as the cause.
        """
        # A blank keyword yields an empty tsquery that matches no row, so
        # skip the database round trip.
        if not keyword or not keyword.strip():
            return []

        try:
            rows = await conn.fetch(
                """
                SELECT id, thread_id, checkpoint_id, message_index, role,
                       content, has_files, metadata, created_at,
                       ts_rank(to_tsvector('simple', content), plainto_tsquery('simple', $2)) as rank
                FROM chat_messages
                WHERE thread_id = $1
                  AND to_tsvector('simple', content) @@ plainto_tsquery('simple', $2)
                ORDER BY rank DESC, message_index DESC
                LIMIT $3
                """,
                thread_id,
                keyword,
                limit
            )

            messages = [
                {
                    'id': row['id'],
                    'thread_id': row['thread_id'],
                    'checkpoint_id': row['checkpoint_id'],
                    'message_index': row['message_index'],
                    'role': row['role'],
                    'content': row['content'],
                    'has_files': row['has_files'],
                    'metadata': ChatMessageService._decode_metadata(row['metadata']),
                    'created_at': ChatMessageService._format_timestamp(row['created_at']),
                    'rank': float(row['rank'])
                }
                for row in rows
            ]

            logger.info(f"搜索会话消息: thread_id={thread_id}, 关键词={keyword}, 匹配数量={len(messages)}")
            return messages

        except Exception as e:
            logger.error(f"搜索会话消息失败: {e}")
            raise Exception(f"搜索会话消息失败: {str(e)}") from e

    @staticmethod
    async def delete_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """
        Delete every message of a conversation.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: The number of deleted messages.

        Raises:
            Exception: If the statement fails; the original error is chained
                as the cause.
        """
        try:
            result = await conn.execute(
                "DELETE FROM chat_messages WHERE thread_id = $1",
                thread_id
            )

            # The status string ends with the affected row count,
            # e.g. "DELETE 5".
            deleted_count = int(result.split()[-1]) if result else 0
            logger.info(f"删除会话消息: thread_id={thread_id}, 数量={deleted_count}")
            return deleted_count

        except Exception as e:
            logger.error(f"删除会话消息失败: {e}")
            raise Exception(f"删除会话消息失败: {str(e)}") from e