huoyan-enterprise/backend/services/chat_message_file_service.py

355 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Chat message / file association service.
"""
from typing import Optional, List
import asyncpg
from datetime import datetime
from logger.logging import get_logger
# Module-level logger obtained from the project's get_logger helper, keyed by
# this module's name; the service methods below log through it.
logger = get_logger(__name__)
class ChatMessageFileService:
    """Service for the associations between chat messages and uploaded files.

    Every method is a static coroutine that takes an open ``asyncpg``
    connection and works on the ``chat_message_file`` association table,
    joining ``chat_thread_file`` when file metadata is needed.
    """

    @staticmethod
    def _row_to_file_info(row, include_url: bool = False) -> dict:
        """Convert a joined result row into the file-info dict returned to callers.

        Args:
            row: Mapping-like record exposing ``file_id``, ``file_name``,
                ``file_size``, ``file_type``, ``status`` and ``created_at``
                (plus ``file_path`` when ``include_url`` is true).
            include_url: When true, expose ``file_path`` under the ``file_url`` key.

        Returns:
            dict: File information with ``created_at`` rendered via
            ``isoformat()`` (or ``None`` when the value is falsy).
        """
        info = {
            'file_id': row['file_id'],
            'file_name': row['file_name'],
            'file_size': row['file_size'],
            'file_type': row['file_type'],
        }
        if include_url:
            # The file_path column is returned to callers as file_url.
            info['file_url'] = row['file_path']
        info['status'] = row['status']
        created_at = row['created_at']
        info['created_at'] = created_at.isoformat() if created_at else None
        return info

    @staticmethod
    async def create_message_file_association(
        conn: asyncpg.Connection,
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        file_id: int
    ) -> Optional[int]:
        """Create an association between a message and a file.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message in the messages list.
            file_id: File ID.

        Returns:
            Optional[int]: ID of the new association row, or ``None`` when the
            insert was skipped because the same
            (checkpoint_id, message_index, file_id) association already exists.

        Raises:
            Exception: If the insert fails; the original error is chained as
                the cause.
        """
        try:
            # ON CONFLICT DO NOTHING makes RETURNING yield no row for duplicates.
            row = await conn.fetchrow(
                """
                INSERT INTO chat_message_file
                (thread_id, checkpoint_id, message_index, file_id)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (checkpoint_id, message_index, file_id) DO NOTHING
                RETURNING id
                """,
                thread_id, checkpoint_id, message_index, file_id
            )
            if row:
                logger.info(f"创建消息文件关联: thread_id={thread_id}, checkpoint_id={checkpoint_id}, message_index={message_index}, file_id={file_id}")
                return row['id']
            return None
        except Exception as e:
            logger.error(f"创建消息文件关联失败: {e}")
            # Keep the generic exception type callers already catch, but chain
            # the original error so its traceback is not lost.
            raise Exception(f"创建消息文件关联失败: {e}") from e

    @staticmethod
    async def get_files_by_message(
        conn: asyncpg.Connection,
        checkpoint_id: str,
        message_index: int
    ) -> List[dict]:
        """Return the files associated with a single message.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.

        Returns:
            List[dict]: One dict per non-deleted file (raw column values,
            including the association ``id``), oldest association first.
            An empty list is returned when the query fails.
        """
        try:
            rows = await conn.fetch(
                """
                SELECT
                    cmf.id,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.checkpoint_id = $1 AND cmf.message_index = $2
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.created_at ASC
                """,
                checkpoint_id, message_index
            )
            return [dict(row) for row in rows]
        except Exception as e:
            # Best-effort read: log and fall back to "no files".
            logger.error(f"获取消息文件列表失败: {e}")
            return []

    @staticmethod
    async def get_files_by_checkpoint(
        conn: asyncpg.Connection,
        checkpoint_id: str
    ) -> dict:
        """Return the files of every message in a checkpoint.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.

        Returns:
            dict: ``{message_index: [file_info, ...], ...}``; an empty dict is
            returned when the query fails.
        """
        try:
            rows = await conn.fetch(
                """
                SELECT
                    cmf.message_index,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.checkpoint_id = $1
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.message_index ASC, cmf.created_at ASC
                """,
                checkpoint_id
            )
            # Group by message_index.
            result = {}
            for row in rows:
                result.setdefault(row['message_index'], []).append(
                    ChatMessageFileService._row_to_file_info(row)
                )
            return result
        except Exception as e:
            logger.error(f"获取 checkpoint 文件列表失败: {e}")
            return {}

    @staticmethod
    async def get_all_files_by_thread(
        conn: asyncpg.Connection,
        thread_id: str,
        latest_checkpoint_id: str
    ) -> dict:
        """Return the file associations of every checkpoint in a thread.

        Files may have been associated under different checkpoints; the
        associations of all checkpoints of the thread are fetched and merged
        by ``message_index``.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            latest_checkpoint_id: ID of the latest checkpoint. Currently not
                used by the query; kept for interface compatibility.

        Returns:
            dict: ``{message_index: [file_info, ...], ...}`` where each
            ``file_info`` also carries ``file_url``; an empty dict is returned
            when the query fails.
        """
        try:
            # Fetch every association of the thread, including file_path (file_url).
            rows = await conn.fetch(
                """
                SELECT
                    cmf.checkpoint_id,
                    cmf.message_index,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.file_path,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.thread_id = $1
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.checkpoint_id ASC, cmf.message_index ASC, cmf.created_at ASC
                """,
                thread_id
            )
            # Per the original author's note: LangGraph checkpoints are
            # cumulative, so message_index values of all checkpoints are
            # expected to refer to the same message list and can be used
            # directly as the grouping key.
            result = {}
            for row in rows:
                checkpoint_id = row['checkpoint_id']
                message_index = row['message_index']
                file_id = row['file_id']
                file_name = row['file_name']
                logger.debug(f"文件关联: checkpoint_id={checkpoint_id}, message_index={message_index}, file_id={file_id}, file_name={file_name}")
                result.setdefault(message_index, []).append(
                    ChatMessageFileService._row_to_file_info(row, include_url=True)
                )
            logger.info(f"查询到文件关联映射: {result}")
            return result
        except Exception as e:
            logger.error(f"获取 thread 所有文件关联失败: {e}")
            return {}

    @staticmethod
    async def delete_message_file_association(
        conn: asyncpg.Connection,
        checkpoint_id: str,
        message_index: int,
        file_id: int
    ) -> bool:
        """Delete the association between a message and a file.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.
            file_id: File ID.

        Returns:
            bool: ``True`` when exactly one row was deleted; ``False``
            otherwise, including when the statement fails.
        """
        try:
            status = await conn.execute(
                """
                DELETE FROM chat_message_file
                WHERE checkpoint_id = $1 AND message_index = $2 AND file_id = $3
                """,
                checkpoint_id, message_index, file_id
            )
            # execute() returns the command status string, e.g. "DELETE 1".
            return status == "DELETE 1"
        except Exception as e:
            logger.error(f"删除消息文件关联失败: {e}")
            return False

    @staticmethod
    async def delete_thread_associations(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> int:
        """Delete every message/file association of a thread.

        Used to clean up when a conversation is deleted.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Number of deleted rows; ``0`` when the statement fails.
        """
        try:
            status = await conn.execute(
                """
                DELETE FROM chat_message_file
                WHERE thread_id = $1
                """,
                thread_id
            )
            # The status string looks like "DELETE <count>"; the count is its last token.
            deleted_count = int(status.split()[-1]) if status else 0
            # Separator restored between the thread ID and the count, which were
            # previously fused together in the log line.
            logger.info(f"删除会话 {thread_id} 的 {deleted_count} 条消息文件关联")
            return deleted_count
        except Exception as e:
            logger.error(f"删除消息文件关联失败: {e}")
            return 0

    @staticmethod
    async def get_unlinked_files(
        conn: asyncpg.Connection,
        thread_id: str
    ) -> List[dict]:
        """Return the files of a thread that are not associated with any message.

        These files were uploaded but never linked to a message, and still
        need to be shown in the message history.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            List[dict]: File-info dicts (including ``file_url``) ordered by
            creation time ascending; an empty list when the query fails.
        """
        try:
            # NOT EXISTS instead of NOT IN: a NULL file_id in the subquery
            # would make NOT IN filter out every row.
            rows = await conn.fetch(
                """
                SELECT
                    ctf.id as file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.file_path,
                    ctf.status,
                    ctf.created_at
                FROM chat_thread_file ctf
                WHERE ctf.thread_id = $1
                    AND ctf.is_deleted = FALSE
                    AND NOT EXISTS (
                        SELECT 1
                        FROM chat_message_file cmf
                        WHERE cmf.thread_id = $1
                            AND cmf.file_id = ctf.id
                    )
                ORDER BY ctf.created_at ASC
                """,
                thread_id
            )
            return [
                ChatMessageFileService._row_to_file_info(row, include_url=True)
                for row in rows
            ]
        except Exception as e:
            logger.error(f"获取未关联文件列表失败: {e}")
            return []