"""
聊天消息服务(基于 chat_messages 表)

用于保存和查询用户原始消息、AI 响应和工具消息,替代从 checkpoint 中解析
"""
import json
from typing import List, Dict, Any, Optional

import asyncpg

from logger.logging import get_logger

# Module-level logger; ChatMessageService reports successful writes/queries
# with logger.info and failures with logger.error.
logger = get_logger(__name__)
class ChatMessageService:
    """Chat message service backed by the ``chat_messages`` table.

    Every method is a ``@staticmethod`` that works on an ``asyncpg``
    connection supplied by the caller; the class keeps no state of its own.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
        """Serialize ``metadata`` to JSON text; ``None`` or an empty dict becomes ``None`` (SQL NULL)."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    def _decode_metadata(raw: Any) -> Any:
        """Convert a stored ``metadata`` column value back into a Python object.

        NULL / empty values become ``{}``. Text (``str``/``bytes``) is parsed as
        JSON. Any other value, e.g. a dict produced by a JSON codec registered
        on the connection, is returned unchanged instead of making
        ``json.loads`` raise ``TypeError``.
        """
        if not raw:
            return {}
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    @staticmethod
    def _format_timestamp(value: Any) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when the column is NULL."""
        return value.isoformat() if value else None

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    @staticmethod
    async def save_user_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        injected_content: Optional[str] = None,
        has_files: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Insert a user message into ``chat_messages``, or update it if it exists.

        When a row already exists for ``(checkpoint_id, message_index)``, its
        ``content``, ``injected_content``, ``has_files`` and ``metadata`` are
        overwritten. ``thread_id`` and ``role`` of that row are left as they were.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: The user's original question.
            injected_content: Full content handed to the AI (including file contents).
            has_files: Whether files are attached to the message.
            metadata: Extra information, stored as JSON (empty/None is stored as NULL).

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content, injected_content, has_files, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    injected_content = EXCLUDED.injected_content,
                    has_files = EXCLUDED.has_files,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'user',
                content,
                injected_content,
                has_files,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
            logger.info(f"✅ 保存用户消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
            return message_id
        except Exception as e:
            logger.error(f"保存用户消息失败: {e}")
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"保存用户消息失败: {e}") from e

    @staticmethod
    async def save_assistant_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Insert an AI response into ``chat_messages``, or update it if it exists.

        When a row already exists for ``(checkpoint_id, message_index)``, only
        its ``content`` and ``metadata`` are overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: AI response text.
            metadata: Extra information (token usage, model name, reasoning
                content, ...), stored as JSON (empty/None is stored as NULL).

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content, metadata)
                VALUES ($1, $2, $3, $4, $5, $6)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'assistant',
                content,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
            logger.info(f"✅ 保存AI消息: message_id={message_id}, thread_id={thread_id}, index={message_index}")
            return message_id
        except Exception as e:
            logger.error(f"保存AI消息失败: {e}")
            raise Exception(f"保存AI消息失败: {e}") from e

    @staticmethod
    async def save_tool_message(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        content: str,
        name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> int:
        """Insert a tool message into ``chat_messages``, or update it if it exists.

        When a row already exists for ``(checkpoint_id, message_index)``, its
        ``content``, ``name`` and ``metadata`` are overwritten.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message.
            content: Tool message content.
            name: Tool name (e.g. text_to_poster, internet_search).
            metadata: Extra information (tool arguments, ...), stored as JSON
                (empty/None is stored as NULL).

        Returns:
            int: ID of the inserted or updated row.

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_messages
                    (thread_id, checkpoint_id, message_index, role, content, name, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                ON CONFLICT (checkpoint_id, message_index)
                DO UPDATE SET
                    content = EXCLUDED.content,
                    name = EXCLUDED.name,
                    metadata = EXCLUDED.metadata
                RETURNING id
                """,
                thread_id,
                checkpoint_id,
                message_index,
                'tool',
                content,
                name,
                ChatMessageService._encode_metadata(metadata),
            )
            message_id = row['id']
            logger.info(f"✅ 保存工具消息: message_id={message_id}, thread_id={thread_id}, index={message_index}, tool_name={name}")
            return message_id
        except Exception as e:
            logger.error(f"保存工具消息失败: {e}")
            raise Exception(f"保存工具消息失败: {e}") from e

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    @staticmethod
    async def get_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str,
        limit: Optional[int] = None,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """List a thread's messages in ascending ``message_index`` order.

        Historical data may contain several rows with the same
        ``message_index`` in one thread. Only one row per index is returned: the
        newest by ``created_at`` (rows without a timestamp lose), with the
        highest ``id`` breaking ties.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            limit: Maximum number of messages; ``None`` or ``0`` means no limit.
            offset: Number of (deduplicated) messages to skip, for pagination.
                Honoured whether or not ``limit`` is set; ``None`` or values
                <= 0 skip nothing.

        Returns:
            List[Dict]: One dict per message with keys ``id``, ``thread_id``,
            ``checkpoint_id``, ``message_index``, ``role``, ``content``,
            ``injected_content``, ``has_files``, ``name``, ``metadata`` (decoded,
            ``{}`` when NULL) and ``created_at`` (ISO string or ``None``).

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            # DISTINCT ON keeps the first row of each message_index group in ORDER BY
            # order, i.e. the newest one. NULLS LAST stops an undated row from
            # winning (DESC sorts NULLs first by default); id DESC makes ties
            # deterministic.
            query = """
                SELECT DISTINCT ON (message_index)
                    id, thread_id, checkpoint_id, message_index, role,
                    content, injected_content, has_files, name, metadata, created_at
                FROM chat_messages
                WHERE thread_id = $1
                ORDER BY message_index ASC, created_at DESC NULLS LAST, id DESC
            """
            params: List[Any] = [thread_id]
            skip = offset if offset and offset > 0 else 0

            # LIMIT/OFFSET are evaluated after DISTINCT ON, so pagination counts
            # deduplicated messages and the output stays ordered by message_index.
            if limit:
                query += " LIMIT $2 OFFSET $3"
                params.extend([limit, skip])
            elif skip:
                # Without a limit the offset used to be silently ignored.
                query += " OFFSET $2"
                params.append(skip)

            rows = await conn.fetch(query, *params)

            messages = [
                {
                    'id': row['id'],
                    'thread_id': row['thread_id'],
                    'checkpoint_id': row['checkpoint_id'],
                    'message_index': row['message_index'],
                    'role': row['role'],
                    'content': row['content'],
                    'injected_content': row['injected_content'],
                    'has_files': row['has_files'],
                    'name': row['name'],  # tool name (for role == 'tool')
                    'metadata': ChatMessageService._decode_metadata(row['metadata']),
                    'created_at': ChatMessageService._format_timestamp(row['created_at']),
                }
                for row in rows
            ]

            logger.info(f"查询会话消息: thread_id={thread_id}, 消息数量={len(messages)}")
            return messages
        except Exception as e:
            logger.error(f"查询会话消息失败: {e}")
            raise Exception(f"查询会话消息失败: {e}") from e

    @staticmethod
    async def get_message_count(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """Return the number of messages in a thread.

        Counts distinct ``message_index`` values, matching the deduplication
        done by ``get_messages_by_thread``. Duplicate historical rows therefore
        do not inflate pagination totals.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Number of messages. Returns 0 on any error; this is best-effort,
            and the error is logged rather than raised.
        """
        try:
            count = await conn.fetchval(
                "SELECT COUNT(DISTINCT message_index) FROM chat_messages WHERE thread_id = $1",
                thread_id
            )
            return count or 0
        except Exception as e:
            logger.error(f"获取消息总数失败: {e}")
            return 0

    @staticmethod
    async def search_messages(
        conn: asyncpg.Connection,
        thread_id: str,
        keyword: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Full-text search a thread's messages.

        ``keyword`` is treated as plain text. It is passed through
        ``plainto_tsquery``, so multi-word input and punctuation no longer raise
        a tsquery syntax error. All words must match. Results are ordered by
        ``ts_rank`` and then by descending ``message_index``.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            keyword: Search text; blank/None returns an empty list without querying.
            limit: Maximum number of results.

        Returns:
            List[Dict]: Matching messages with keys ``id``, ``thread_id``,
            ``checkpoint_id``, ``message_index``, ``role``, ``content``,
            ``has_files``, ``metadata``, ``created_at`` and ``rank`` (float).

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            if not keyword or not keyword.strip():
                # An empty text-search query matches nothing anyway.
                return []

            # The tsquery is computed once in FROM and reused for both the match
            # and the rank. The to_tsvector expression is kept unchanged so any
            # matching expression index still applies.
            rows = await conn.fetch(
                """
                SELECT
                    id, thread_id, checkpoint_id, message_index, role,
                    content, has_files, metadata, created_at,
                    ts_rank(to_tsvector('simple', content), query) AS rank
                FROM chat_messages, plainto_tsquery('simple', $2) AS query
                WHERE thread_id = $1
                  AND to_tsvector('simple', content) @@ query
                ORDER BY rank DESC, message_index DESC
                LIMIT $3
                """,
                thread_id,
                keyword,
                limit
            )

            messages = [
                {
                    'id': row['id'],
                    'thread_id': row['thread_id'],
                    'checkpoint_id': row['checkpoint_id'],
                    'message_index': row['message_index'],
                    'role': row['role'],
                    'content': row['content'],
                    'has_files': row['has_files'],
                    'metadata': ChatMessageService._decode_metadata(row['metadata']),
                    'created_at': ChatMessageService._format_timestamp(row['created_at']),
                    'rank': float(row['rank']),
                }
                for row in rows
            ]

            logger.info(f"搜索会话消息: thread_id={thread_id}, 关键词={keyword}, 匹配数量={len(messages)}")
            return messages
        except Exception as e:
            logger.error(f"搜索会话消息失败: {e}")
            raise Exception(f"搜索会话消息失败: {e}") from e

    # ------------------------------------------------------------------
    # Deletes
    # ------------------------------------------------------------------

    @staticmethod
    async def delete_messages_by_thread(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """Delete every message row of a thread, duplicates included.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Number of deleted rows, parsed from the command status string
            returned by ``conn.execute`` (its last whitespace-separated token).

        Raises:
            Exception: Wraps any failure; the original error is chained as ``__cause__``.
        """
        try:
            result = await conn.execute(
                "DELETE FROM chat_messages WHERE thread_id = $1",
                thread_id
            )
            deleted_count = int(result.split()[-1]) if result else 0
            logger.info(f"删除会话消息: thread_id={thread_id}, 数量={deleted_count}")
            return deleted_count
        except Exception as e:
            logger.error(f"删除会话消息失败: {e}")
            raise Exception(f"删除会话消息失败: {e}") from e