huoyan-enterprise/backend/services/chat_message_service.py

370 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天消息服务(基于 chat_messages 表)
用于保存和查询用户原始消息和 AI 响应(替代从 checkpoint 中解析)
"""
import json
from typing import List, Dict, Any, Optional
import asyncpg
from logger.logging import get_logger
# Module-level logger; ChatMessageService uses it to record successful writes/reads and failures.
logger = get_logger(__name__)
class ChatMessageService:
    """Persistence helpers for the ``chat_messages`` table.

    Every method is a ``@staticmethod`` that receives an ``asyncpg`` connection
    from the caller; none of them acquires, closes or commits the connection.
    Write helpers upsert on ``(checkpoint_id, message_index)``. Read helpers
    return plain dicts with ``metadata`` decoded to a dict and ``created_at``
    rendered as an ISO-8601 string.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
        """Serialize metadata to JSON text; ``None``/empty metadata becomes SQL NULL."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    def _decode_metadata(raw: Any) -> Dict[str, Any]:
        """Decode a ``metadata`` column value into a dict (``{}`` when NULL/empty).

        Accepts JSON text (what the rest of this class writes and expects to
        read back) as well as an already-decoded mapping, e.g. if a JSON type
        codec has been registered on the connection.
        """
        if not raw:
            return {}
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    @staticmethod
    def _isoformat(value: Any) -> Optional[str]:
        """Render a timestamp column as ISO-8601, or ``None`` when it is NULL."""
        return value.isoformat() if value else None

    @staticmethod
    def _row_to_message(row: Any) -> Dict[str, Any]:
        """Convert a full ``chat_messages`` row into the dict shape returned to callers."""
        return {
            'id': row['id'],
            'thread_id': row['thread_id'],
            'checkpoint_id': row['checkpoint_id'],
            'message_index': row['message_index'],
            'role': row['role'],
            'content': row['content'],
            'injected_content': row['injected_content'],
            'has_files': row['has_files'],
            'name': row['name'],  # tool name (set for role == 'tool' messages)
            'metadata': ChatMessageService._decode_metadata(row['metadata']),
            'created_at': ChatMessageService._isoformat(row['created_at']),
        }

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    @staticmethod
    async def save_user_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        injected_content: Optional[str] = None,
        has_files: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Upsert a user message (role ``'user'``) into ``chat_messages``.

        On a ``(checkpoint_id, message_index)`` conflict the existing row's
        content, injected_content, has_files and metadata are overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.
            content: The user's original question.
            injected_content: Full content sent to the AI (including file contents).
            has_files: Whether files are attached to the message.
            metadata: Extra information; stored as JSON, or NULL when empty.

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content,
                     injected_content, has_files, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    injected_content = EXCLUDED.injected_content,
                    has_files = EXCLUDED.has_files,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'user',
                content,
                injected_content,
                has_files,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
        except Exception as e:
            logger.error(f"保存用户消息失败: {e}")
            raise Exception(f"保存用户消息失败: {str(e)}") from e
        logger.info(f"✅ 保存用户消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
        return message_id

    @staticmethod
    async def save_assistant_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Upsert an AI response (role ``'assistant'``) into ``chat_messages``.

        On a ``(checkpoint_id, message_index)`` conflict the existing row's
        content and metadata are overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.
            content: AI response text.
            metadata: Extra information (e.g. token usage, model name,
                reasoning content); stored as JSON, or NULL when empty.

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content, metadata)
                VALUES ($1, $2, $3, $4, $5, $6)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'assistant',
                content,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
        except Exception as e:
            logger.error(f"保存AI消息失败: {e}")
            raise Exception(f"保存AI消息失败: {str(e)}") from e
        logger.info(f"✅ 保存AI消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
        return message_id

    @staticmethod
    async def save_tool_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Upsert a tool message (role ``'tool'``) into ``chat_messages``.

        On a ``(checkpoint_id, message_index)`` conflict the existing row's
        content, name and metadata are overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.
            content: Tool message content.
            name: Tool name (e.g. text_to_poster, internet_search).
            metadata: Extra information (e.g. tool arguments); stored as JSON,
                or NULL when empty.

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content, name, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    name = EXCLUDED.name,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'tool',
                content,
                name,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
        except Exception as e:
            logger.error(f"保存工具消息失败: {e}")
            raise Exception(f"保存工具消息失败: {str(e)}") from e
        logger.info(f"✅ 保存工具消息: message_id={message_id}, thread_id={thread_id}, index={message_index}, tool_name={name}")
        return message_id

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    @staticmethod
    async def get_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str,
        limit: Optional[int] = None,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return a thread's messages ordered by ``message_index`` ascending.

        Rows are de-duplicated per ``message_index``, keeping the most recently
        created row, so duplicate rows in historical data are tolerated.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            limit: Maximum number of messages to return. ``None`` (or any
                falsy value such as 0) means no limit.
            offset: Number of de-duplicated messages to skip, for pagination.
                Applied whether or not ``limit`` is given; ``None`` or a
                negative value is treated as 0.

        Returns:
            List[Dict]: Message dicts (see ``_row_to_message`` for keys).

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        # DISTINCT ON keeps the first row of each message_index group under the
        # ORDER BY, i.e. the newest one. NULLS LAST stops a NULL created_at from
        # winning, and id DESC breaks ties so the chosen row is deterministic.
        # LIMIT/OFFSET apply after de-duplication; PostgreSQL treats LIMIT NULL
        # as "no limit", so one statement serves both paged and unpaged calls.
        query = """
            SELECT DISTINCT ON (message_index)
                id, thread_id, checkpoint_id, message_index, role,
                content, injected_content, has_files, name, metadata, created_at
            FROM chat_messages
            WHERE thread_id = $1
            ORDER BY message_index ASC, created_at DESC NULLS LAST, id DESC
            LIMIT $2 OFFSET $3
        """
        effective_limit = limit if limit else None
        effective_offset = max(offset or 0, 0)
        try:
            rows = await conn.fetch(query, thread_id, effective_limit, effective_offset)
            messages = [ChatMessageService._row_to_message(row) for row in rows]
        except Exception as e:
            logger.error(f"查询会话消息失败: {e}")
            raise Exception(f"查询会话消息失败: {str(e)}") from e
        logger.info(f"查询会话消息: thread_id={thread_id}, 消息数量={len(messages)}")
        return messages

    @staticmethod
    async def get_message_count(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """Return the number of messages in a thread.

        Counts distinct ``message_index`` values so the total matches what
        ``get_messages_by_thread`` returns after de-duplication, which keeps
        pagination consistent.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Message count. Returns 0 on any database error (best-effort;
            the error is logged, not raised).
        """
        try:
            count = await conn.fetchval(
                "SELECT COUNT(DISTINCT message_index) FROM chat_messages WHERE thread_id = $1",
                thread_id
            )
        except Exception as e:
            logger.error(f"获取消息总数失败: {e}")
            return 0
        return count or 0

    @staticmethod
    async def search_messages(
        conn: asyncpg.Connection,
        thread_id: str,
        keyword: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Full-text search within a thread's messages.

        ``keyword`` is treated as plain text via ``plainto_tsquery``, so user
        input with spaces, quotes or punctuation cannot cause a tsquery syntax
        error; all of its words must match. Results are ordered by ``ts_rank``
        descending, then ``message_index`` descending.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            keyword: Search text. ``None``, empty or whitespace-only returns ``[]``.
            limit: Maximum number of results.

        Returns:
            List[Dict]: Matching messages, each including a float ``rank``.

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        # An empty query contains no lexemes and can never match; skip the round trip.
        if not keyword or not keyword.strip():
            return []
        try:
            rows = await conn.fetch(
                """
                SELECT
                    id, thread_id, checkpoint_id, message_index, role,
                    content, has_files, metadata, created_at,
                    ts_rank(to_tsvector('simple', content), plainto_tsquery('simple', $2)) AS rank
                FROM chat_messages
                WHERE thread_id = $1
                  AND to_tsvector('simple', content) @@ plainto_tsquery('simple', $2)
                ORDER BY rank DESC, message_index DESC
                LIMIT $3
                """,
                thread_id,
                keyword,
                limit
            )
            messages = [
                {
                    'id': row['id'],
                    'thread_id': row['thread_id'],
                    'checkpoint_id': row['checkpoint_id'],
                    'message_index': row['message_index'],
                    'role': row['role'],
                    'content': row['content'],
                    'has_files': row['has_files'],
                    'metadata': ChatMessageService._decode_metadata(row['metadata']),
                    'created_at': ChatMessageService._isoformat(row['created_at']),
                    'rank': float(row['rank']),
                }
                for row in rows
            ]
        except Exception as e:
            logger.error(f"搜索会话消息失败: {e}")
            raise Exception(f"搜索会话消息失败: {str(e)}") from e
        logger.info(f"搜索会话消息: thread_id={thread_id}, 关键词={keyword}, 匹配数量={len(messages)}")
        return messages

    # ------------------------------------------------------------------
    # Deletes
    # ------------------------------------------------------------------

    @staticmethod
    async def delete_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """Delete every message row belonging to a thread.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Number of rows deleted, parsed from the command status
            string returned by ``conn.execute`` (its last token).

        Raises:
            Exception: On any database error; the original error is chained
                as ``__cause__``.
        """
        try:
            result = await conn.execute(
                "DELETE FROM chat_messages WHERE thread_id = $1",
                thread_id
            )
            deleted_count = int(result.split()[-1]) if result else 0
        except Exception as e:
            logger.error(f"删除会话消息失败: {e}")
            raise Exception(f"删除会话消息失败: {str(e)}") from e
        logger.info(f"删除会话消息: thread_id={thread_id}, 数量={deleted_count}")
        return deleted_count