355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""
|
||
聊天消息文件关联服务
|
||
"""
|
||
from typing import Optional, List
|
||
import asyncpg
|
||
from datetime import datetime
|
||
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class ChatMessageFileService:
    """Service for linking chat messages to files uploaded in a thread.

    Association rows live in ``chat_message_file`` and point at files in
    ``chat_thread_file``. A message is identified by ``(checkpoint_id,
    message_index)`` within a thread. All methods are static and operate on
    an explicitly supplied asyncpg connection.
    """

    @staticmethod
    def _isoformat(value) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when *value* is falsy."""
        return value.isoformat() if value else None

    @staticmethod
    def _to_file_info(row, include_url: bool = False) -> dict:
        """Convert a joined association/file row into the file-info dict returned to callers.

        Args:
            row: Mapping-like row (e.g. an asyncpg Record, or a dict) with keys
                ``file_id``, ``file_name``, ``file_size``, ``file_type``,
                ``status``, ``created_at`` and, when *include_url* is True,
                ``file_path``.
            include_url: Whether to expose ``file_path`` under the ``file_url`` key.

        Returns:
            dict: File info with ``created_at`` ISO-formatted (or None). Key order
            is file_id, file_name, file_size, file_type, [file_url], status,
            created_at.
        """
        info = {
            'file_id': row['file_id'],
            'file_name': row['file_name'],
            'file_size': row['file_size'],
            'file_type': row['file_type'],
        }
        if include_url:
            # file_path stores the OSS URL; it is returned to callers as file_url
            info['file_url'] = row['file_path']
        info['status'] = row['status']
        info['created_at'] = ChatMessageFileService._isoformat(row['created_at'])
        return info

    @staticmethod
    def _group_by_message_index(rows, include_url: bool = False, dedupe: bool = False) -> dict:
        """Group rows into ``{message_index: [file_info, ...]}``, preserving row order.

        Args:
            rows: Iterable of mapping-like rows that carry ``message_index`` plus
                the keys required by ``_to_file_info``.
            include_url: Forwarded to ``_to_file_info``.
            dedupe: When True, keep only the first row seen for each
                ``(message_index, file_id)`` pair.

        Returns:
            dict: Mapping of message index to the list of file-info dicts.
        """
        grouped = {}
        seen = set()
        for row in rows:
            message_index = row['message_index']
            if dedupe:
                key = (message_index, row['file_id'])
                if key in seen:
                    continue
                seen.add(key)
            grouped.setdefault(message_index, []).append(
                ChatMessageFileService._to_file_info(row, include_url=include_url)
            )
        return grouped

    @staticmethod
    async def create_message_file_association(
        conn: "asyncpg.Connection",
        thread_id: str,
        checkpoint_id: str,
        message_index: int,
        file_id: int
    ) -> Optional[int]:
        """Create an association between a message and a file.

        The insert uses ``ON CONFLICT (checkpoint_id, message_index, file_id)
        DO NOTHING``, so re-inserting an existing association is not an error.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            checkpoint_id: Checkpoint ID.
            message_index: Index of the message within the messages list.
            file_id: File ID.

        Returns:
            Optional[int]: ID of the newly created association row, or None if
            the association already existed (the suppressed insert returns no row).

        Raises:
            Exception: If the insert fails; the underlying error is chained as
                ``__cause__``.
        """
        try:
            row = await conn.fetchrow(
                """
                INSERT INTO chat_message_file
                    (thread_id, checkpoint_id, message_index, file_id)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (checkpoint_id, message_index, file_id) DO NOTHING
                RETURNING id
                """,
                thread_id, checkpoint_id, message_index, file_id
            )
        except Exception as e:
            logger.error(f"创建消息文件关联失败: {e}")
            # Chain the original error so the traceback is not lost.
            raise Exception(f"创建消息文件关联失败: {str(e)}") from e

        if row is None:
            # Conflict: association already present, nothing inserted.
            return None

        logger.info(f"创建消息文件关联: thread_id={thread_id}, checkpoint_id={checkpoint_id}, message_index={message_index}, file_id={file_id}")
        return row['id']

    @staticmethod
    async def get_files_by_message(
        conn: "asyncpg.Connection",
        checkpoint_id: str,
        message_index: int
    ) -> List[dict]:
        """Return the non-deleted files associated with one message.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.

        Returns:
            List[dict]: One dict per file (raw column values, including the
            association ``id``), ordered by association creation time. Returns
            an empty list if the query fails (the error is logged).
        """
        try:
            rows = await conn.fetch(
                """
                SELECT
                    cmf.id,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.checkpoint_id = $1 AND cmf.message_index = $2
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.created_at ASC
                """,
                checkpoint_id, message_index
            )
            return [dict(row) for row in rows]

        except Exception as e:
            logger.error(f"获取消息文件列表失败: {e}")
            return []

    @staticmethod
    async def get_files_by_checkpoint(
        conn: "asyncpg.Connection",
        checkpoint_id: str
    ) -> dict:
        """Return the non-deleted files of every message in a checkpoint.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.

        Returns:
            dict: ``{message_index: [file_info, ...], ...}`` where each file_info
            has file_id, file_name, file_size, file_type, status and an
            ISO-formatted created_at. Returns ``{}`` if the query fails (the
            error is logged).
        """
        try:
            rows = await conn.fetch(
                """
                SELECT
                    cmf.message_index,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.checkpoint_id = $1
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.message_index ASC, cmf.created_at ASC
                """,
                checkpoint_id
            )
            return ChatMessageFileService._group_by_message_index(rows)

        except Exception as e:
            logger.error(f"获取 checkpoint 文件列表失败: {e}")
            return {}

    @staticmethod
    async def get_all_files_by_thread(
        conn: "asyncpg.Connection",
        thread_id: str,
        latest_checkpoint_id: str
    ) -> dict:
        """Return the file associations of every checkpoint in a thread, keyed by message index.

        Files may be associated under different checkpoints. Per this service's
        assumption, checkpoints are cumulative, so every checkpoint's
        ``message_index`` refers to the same message list and can be used
        directly. Because the same ``(message_index, file_id)`` pair may be
        recorded under several checkpoints, duplicates are collapsed and the
        first occurrence is kept (ordered by checkpoint_id, message_index,
        created_at).

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.
            latest_checkpoint_id: ID of the latest checkpoint. NOTE: currently
                unused by the query; kept for API compatibility.

        Returns:
            dict: ``{message_index: [file_info, ...], ...}`` where each file_info
            has file_id, file_name, file_size, file_type, file_url (from
            ``file_path``), status and an ISO-formatted created_at. Returns
            ``{}`` if the query fails (the error is logged).
        """
        try:
            # All associations for the thread, including file_path (exposed as file_url).
            rows = await conn.fetch(
                """
                SELECT
                    cmf.checkpoint_id,
                    cmf.message_index,
                    cmf.file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.file_path,
                    ctf.status,
                    ctf.created_at
                FROM chat_message_file cmf
                INNER JOIN chat_thread_file ctf ON cmf.file_id = ctf.id
                WHERE cmf.thread_id = $1
                    AND ctf.is_deleted = FALSE
                ORDER BY cmf.checkpoint_id ASC, cmf.message_index ASC, cmf.created_at ASC
                """,
                thread_id
            )

            for row in rows:
                logger.debug(
                    f"文件关联: checkpoint_id={row['checkpoint_id']}, message_index={row['message_index']}, "
                    f"file_id={row['file_id']}, file_name={row['file_name']}"
                )

            result = ChatMessageFileService._group_by_message_index(
                rows, include_url=True, dedupe=True
            )

            # DEBUG rather than INFO: the mapping contains every file URL of the thread.
            logger.debug(f"查询到文件关联映射: {result}")
            return result

        except Exception as e:
            logger.error(f"获取 thread 所有文件关联失败: {e}")
            return {}

    @staticmethod
    async def delete_message_file_association(
        conn: "asyncpg.Connection",
        checkpoint_id: str,
        message_index: int,
        file_id: int
    ) -> bool:
        """Delete a single message/file association.

        Args:
            conn: Database connection.
            checkpoint_id: Checkpoint ID.
            message_index: Message index.
            file_id: File ID.

        Returns:
            bool: True if exactly one row was deleted; False otherwise, including
            when the statement fails (the error is logged).
        """
        try:
            result = await conn.execute(
                """
                DELETE FROM chat_message_file
                WHERE checkpoint_id = $1 AND message_index = $2 AND file_id = $3
                """,
                checkpoint_id, message_index, file_id
            )
            # asyncpg returns the command status tag, e.g. "DELETE 1".
            return result == "DELETE 1"

        except Exception as e:
            logger.error(f"删除消息文件关联失败: {e}")
            return False

    @staticmethod
    async def delete_thread_associations(
        conn: "asyncpg.Connection",
        thread_id: str
    ) -> int:
        """Delete all message/file associations of a thread (cleanup when a thread is deleted).

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            int: Number of deleted rows, parsed from the command status tag;
            0 if the statement fails (the error is logged).
        """
        try:
            result = await conn.execute(
                """
                DELETE FROM chat_message_file
                WHERE thread_id = $1
                """,
                thread_id
            )
            # Status tag looks like "DELETE <count>".
            deleted_count = int(result.split()[-1]) if result else 0
            logger.info(f"删除会话 {thread_id} 的 {deleted_count} 条消息文件关联")
            return deleted_count

        except Exception as e:
            logger.error(f"删除消息文件关联失败: {e}")
            return 0

    @staticmethod
    async def get_unlinked_files(
        conn: "asyncpg.Connection",
        thread_id: str
    ) -> List[dict]:
        """Return the thread's non-deleted files that are not linked to any message.

        These files were uploaded but never associated with a message; callers
        display them in the message history.

        ``NOT EXISTS`` is used instead of ``NOT IN (subquery)``: with ``NOT IN``
        a single NULL ``file_id`` in the subquery would make the predicate
        unknown for every row and silently return nothing.

        Args:
            conn: Database connection.
            thread_id: Conversation thread ID.

        Returns:
            List[dict]: Unlinked file infos (file_id, file_name, file_size,
            file_type, file_url, status, ISO-formatted created_at) in ascending
            creation order; an empty list if the query fails (the error is logged).
        """
        try:
            rows = await conn.fetch(
                """
                SELECT
                    ctf.id as file_id,
                    ctf.file_name,
                    ctf.file_size,
                    ctf.file_type,
                    ctf.file_path,
                    ctf.status,
                    ctf.created_at
                FROM chat_thread_file ctf
                WHERE ctf.thread_id = $1
                    AND ctf.is_deleted = FALSE
                    AND NOT EXISTS (
                        SELECT 1
                        FROM chat_message_file cmf
                        WHERE cmf.thread_id = $1
                            AND cmf.file_id = ctf.id
                    )
                ORDER BY ctf.created_at ASC
                """,
                thread_id
            )
            return [
                ChatMessageFileService._to_file_info(row, include_url=True)
                for row in rows
            ]

        except Exception as e:
            logger.error(f"获取未关联文件列表失败: {e}")
            return []