778 lines
30 KiB
Python
778 lines
30 KiB
Python
"""
聊天会话服务模块

提供聊天会话的 CRUD 操作和业务逻辑。
"""
|
||
import copy
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from langchain_core.messages import AIMessage, messages_to_dict
|
||
|
||
from core.database import get_db_pool, get_checkpointer
|
||
from core.graph_metadata import (
|
||
chat_thread_kg_column_sql,
|
||
chat_thread_kg_select_fragment_sql,
|
||
chat_thread_llm_select_fragment_sql,
|
||
chat_threads_has_ip_column,
|
||
chat_threads_has_kg_column,
|
||
chat_threads_has_llm_columns,
|
||
graph_table_sql,
|
||
)
|
||
from core.permissions import can_view_graph
|
||
from models.graph_metadata import GraphRecord
|
||
from models.user import User
|
||
from core.exceptions import NotFoundError, ForbiddenError, BadRequestError, InternalError
|
||
from models.chat import (
|
||
ChatThreadItem,
|
||
ChatThreadListResponse,
|
||
ChatThreadDetailResponse,
|
||
)
|
||
from services.chat_thread_file_service import ChatThreadFileService
|
||
from services.chat_message_file_service import ChatMessageFileService
|
||
from services.knowledge_graph_service import KnowledgeGraphService
|
||
from services.oss_service import get_oss_service
|
||
from utils.checkpoint_helper import rebuild_full_message_history
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger obtained from the project's logging helper, created with this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
async def _increment_chat_thread(
    conn: Any,
    thread_id: str,
    knowledge_base_id: Optional[int],
    knowledge_graph_id: Optional[int],
) -> None:
    """Bump message_count of an existing thread and overwrite its knowledge-base/graph columns."""
    set_clauses = ["message_count = message_count + 1", "knowledge_base_id = $2"]
    params: List[Any] = [thread_id, knowledge_base_id]
    # The knowledge-graph column is optional in some schemas; only write it when it exists.
    if chat_threads_has_kg_column():
        params.append(knowledge_graph_id)
        set_clauses.append(f"{chat_thread_kg_column_sql()} = ${len(params)}")
    set_clauses.append("updated_at = CURRENT_TIMESTAMP")
    await conn.execute(
        f"UPDATE chat_threads SET {', '.join(set_clauses)} WHERE thread_id = $1",
        *params,
    )


async def _insert_chat_thread(
    conn: Any,
    thread_id: str,
    user_id: int,
    title: str,
    query: str,
    knowledge_base_id: Optional[int],
    knowledge_graph_id: Optional[int],
    ip: Optional[str],
) -> None:
    """Insert a new chat_threads row, including the optional kg/ip columns only when the schema has them."""
    columns = ["thread_id", "user_id", "title", "first_query", "knowledge_base_id"]
    params: List[Any] = [thread_id, user_id, title, query, knowledge_base_id]
    if chat_threads_has_kg_column():
        columns.append(chat_thread_kg_column_sql())
        params.append(knowledge_graph_id)
    if chat_threads_has_ip_column():
        columns.append("ip")
        params.append(ip)
    # Column names come only from fixed strings / the schema helper; all values are bound parameters.
    placeholders = ", ".join(f"${i}" for i in range(1, len(params) + 1))
    # message_count starts at the literal 1 for a freshly created thread.
    await conn.execute(
        f"INSERT INTO chat_threads ({', '.join(columns)}, message_count) "
        f"VALUES ({placeholders}, 1)",
        *params,
    )


async def create_or_update_chat_thread(
    thread_id: str,
    user_id: int,
    query: str,
    knowledge_base_id: Optional[int] = None,
    knowledge_graph_id: Optional[int] = None,
    ip: Optional[str] = None,
    llm_provider: Optional[str] = None,
    llm_model: Optional[str] = None,
) -> None:
    """
    Create or update the chat_threads record for a conversation.

    If a row for ``thread_id`` exists, its message_count is incremented and
    knowledge_base_id (plus the knowledge-graph column, when the schema has
    it) is overwritten with the values passed in. Otherwise a new row is
    inserted with message_count = 1, first_query = ``query`` and a title made
    of the first 10 characters of ``query``. When the llm columns exist and
    both ``llm_provider`` and ``llm_model`` are given, they are stored too.

    This is best-effort bookkeeping: any exception is logged and swallowed so
    the main chat flow is not interrupted.

    Args:
        thread_id: Conversation thread ID.
        user_id: Owner user ID (written only on insert).
        query: User query text.
        knowledge_base_id: Knowledge base ID (optional).
        knowledge_graph_id: Knowledge graph ID (optional, graphs.id).
        ip: Client IP address (optional, written only on insert when the column exists).
        llm_provider: Provider selected for this message, e.g. tongyi/deepseek
            (optional; requires the llm columns).
        llm_model: Logical model ID selected for this message (optional).
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        try:
            existing = await conn.fetchrow(
                "SELECT id, message_count FROM chat_threads WHERE thread_id = $1",
                thread_id,
            )

            if existing:
                await _increment_chat_thread(conn, thread_id, knowledge_base_id, knowledge_graph_id)
                logger.info(
                    f"更新会话记录: thread_id={thread_id}, 消息数={existing['message_count'] + 1}, "
                    f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}"
                )
            else:
                # Slicing already handles queries shorter than 10 characters.
                title = query[:10]
                await _insert_chat_thread(
                    conn,
                    thread_id,
                    user_id,
                    title,
                    query,
                    knowledge_base_id,
                    knowledge_graph_id,
                    ip,
                )
                logger.info(
                    f"创建新会话记录: thread_id={thread_id}, user_id={user_id}, title={title}, "
                    f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}, ip={ip}"
                )

            if chat_threads_has_llm_columns() and llm_provider and llm_model:
                await conn.execute(
                    """
                    UPDATE chat_threads
                    SET llm_provider = $2, llm_model = $3, updated_at = CURRENT_TIMESTAMP
                    WHERE thread_id = $1
                    """,
                    thread_id,
                    llm_provider,
                    llm_model,
                )

        except Exception as e:
            # Deliberately not re-raised: a bookkeeping failure must not break the chat request.
            logger.exception(f"记录会话到 chat_threads 失败(会导致会话列表为空): {e}")
|
||
|
||
|
||
def _delete_thread_oss_files(thread_files: List[Any], thread_id: str) -> int:
    """Best-effort removal of the OSS objects behind ``thread_files``; returns how many were deleted."""
    if not thread_files:
        return 0
    try:
        oss_service = get_oss_service()
    except Exception as e:
        logger.warning(f"删除物理文件失败 {thread_id}: {e}")
        return 0
    # Checked once: when OSS is disabled nothing can be deleted, so warn a single time.
    if not oss_service.enabled:
        logger.warning("OSS 服务未启用,无法删除物理文件")
        return 0

    deleted_count = 0
    for thread_file in thread_files:
        try:
            if not thread_file.file_path.startswith(('http://', 'https://')):
                logger.warning(f"文件路径不是 OSS URL 格式: {thread_file.file_path}")
                continue
            oss_object_name = oss_service.extract_object_name_from_url(
                thread_file.file_path, thread_id=thread_id
            )
            if oss_object_name and oss_service.delete_file(oss_object_name):
                deleted_count += 1
                logger.debug(f"删除 OSS 文件: {oss_object_name}")
            else:
                logger.warning(f"无法删除 OSS 文件: {thread_file.file_path}")
        except Exception as e:
            logger.warning(f"删除物理文件失败 {thread_file.file_path}: {e}")
    return deleted_count


async def delete_chat_thread(thread_id: str, user_id: int) -> bool:
    """
    Soft-delete a chat thread owned by ``user_id`` and clean up its files.

    The database changes run in one transaction. They remove the message/file
    associations, read the thread's file list and flag the thread as
    is_deleted. A failure part-way therefore leaves the thread untouched. The
    OSS objects are removed only after that transaction commits, on a
    best-effort basis: failures are logged and do not undo the deletion.

    Args:
        thread_id: Conversation thread ID.
        user_id: Requesting user ID (must own the thread).

    Returns:
        bool: True when the thread was soft-deleted.

    Raises:
        NotFoundError: The thread does not exist.
        ForbiddenError: The thread belongs to another user.
        BadRequestError: The thread is already soft-deleted.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        existing = await conn.fetchrow(
            """
            SELECT id, user_id, is_deleted
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )

        if not existing:
            raise NotFoundError("会话")
        if existing['user_id'] != user_id:
            raise ForbiddenError("无权限删除该会话")
        if existing['is_deleted']:
            raise BadRequestError("会话已被删除")

        async with conn.transaction():
            # Physically remove message <-> file associations.
            await ChatMessageFileService.delete_thread_associations(conn, thread_id)
            # Capture the file list before committing so OSS cleanup knows what to remove.
            all_files = await ChatThreadFileService.get_all_files_by_thread(conn, thread_id)
            await conn.execute(
                """
                UPDATE chat_threads
                SET is_deleted = TRUE,
                    updated_at = CURRENT_TIMESTAMP
                WHERE thread_id = $1
                """,
                thread_id,
            )

    # OSS calls happen after the connection is released; they never roll back the soft delete.
    logger.info(f"会话 {thread_id} 共有 {len(all_files)} 个文件需要删除")
    deleted_files_count = _delete_thread_oss_files(all_files, thread_id)
    logger.info(f"已删除 {deleted_files_count} 个物理文件")

    logger.info(f"删除会话成功: thread_id={thread_id}, user_id={user_id}")
    return True
|
||
|
||
|
||
async def get_user_chat_threads(
    user_id: int,
    page: int = 1,
    page_size: int = 20
) -> ChatThreadListResponse:
    """
    Return one page of a user's non-deleted threads that contain at least one message.

    Threads are ordered by updated_at, newest first.

    Args:
        user_id: User ID.
        page: 1-based page number.
        page_size: Number of items per page.

    Returns:
        ChatThreadListResponse: Page of threads plus total count and total pages.

    Raises:
        BadRequestError: ``page`` or ``page_size`` is smaller than 1. Without
            this check the values would produce a negative OFFSET or a division by zero.
    """
    if page < 1 or page_size < 1:
        raise BadRequestError("page 和 page_size 必须为正整数")

    offset = (page - 1) * page_size
    kg_sel = chat_thread_kg_select_fragment_sql()

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        # Count only non-deleted threads that have messages.
        total = await conn.fetchval(
            """
            SELECT COUNT(*) as total
            FROM chat_threads
            WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0
            """,
            user_id,
        )

        rows = await conn.fetch(
            f"""
            SELECT id, thread_id, title, first_query, message_count, knowledge_base_id, {kg_sel}, created_at, updated_at
            FROM chat_threads
            WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0
            ORDER BY updated_at DESC
            LIMIT $2 OFFSET $3
            """,
            user_id,
            page_size,
            offset,
        )

    # Ceiling division; yields 0 when total is 0.
    total_pages = (total + page_size - 1) // page_size

    items = [
        ChatThreadItem(
            id=row['id'],
            thread_id=row['thread_id'],
            title=row['title'],
            first_query=row['first_query'],
            message_count=row['message_count'],
            knowledge_base_id=row['knowledge_base_id'],
            knowledge_graph_id=row['knowledge_graph_id'],
            created_at=row['created_at'],
            updated_at=row['updated_at'],
        )
        for row in rows
    ]

    logger.info(f"查询用户会话列表: user_id={user_id}, page={page}, total={total}")

    return ChatThreadListResponse(
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
        items=items,
    )
|
||
|
||
|
||
def _split_reasoning_messages(raw_messages: List[Any]) -> tuple:
    """
    Expand AI messages that carry both visible content and reasoning_content into two messages.

    The reasoning-only copy comes first (content cleared), followed by the
    content-only copy (reasoning_content cleared). Every other message is kept as-is.

    Returns:
        (processed_messages, original_to_processed_index). The mapping covers
        only non-AI messages and maps each one's index in ``raw_messages`` to
        its index in processed_messages.
    """
    processed: List[Any] = []
    index_map: Dict[int, int] = {}
    for original_idx, msg in enumerate(raw_messages):
        if not isinstance(msg, AIMessage):
            index_map[original_idx] = len(processed)
            processed.append(msg)
            continue

        content = msg.content or ""
        reasoning = (msg.additional_kwargs or {}).get("reasoning_content") or ""
        # content can be a list of content blocks (multimodal); only plain strings are split.
        should_split = (
            isinstance(content, str)
            and isinstance(reasoning, str)
            and content.strip()
            and reasoning.strip()
        )
        if not should_split:
            processed.append(msg)
            continue

        # additional_kwargs is a non-empty dict here because reasoning was read from it.
        reasoning_msg = copy.deepcopy(msg)
        reasoning_msg.content = ""
        reasoning_msg.additional_kwargs["reasoning_content"] = reasoning

        content_msg = copy.deepcopy(msg)
        content_msg.additional_kwargs["reasoning_content"] = ""

        processed.extend((reasoning_msg, content_msg))
    return processed, index_map


async def _fill_missing_file_urls(conn: Any, message_files_map: Dict[Any, List[Dict[str, Any]]]) -> None:
    """Fill ``file_url`` in place from chat_thread_file.file_path for attachments that lack one."""
    missing_ids = {
        file_info['file_id']
        for files_list in message_files_map.values()
        for file_info in files_list
        if not file_info.get('file_url')
    }
    if not missing_ids:
        return

    rows = await conn.fetch(
        """
        SELECT id, file_path FROM chat_thread_file
        WHERE id = ANY($1::int[]) AND is_deleted = FALSE
        """,
        list(missing_ids),
    )
    url_by_id = {row['id']: row['file_path'] for row in rows}
    for files_list in message_files_map.values():
        for file_info in files_list:
            if file_info['file_id'] in url_by_id:
                file_info['file_url'] = url_by_id[file_info['file_id']]


async def get_chat_thread_detail(thread_id: str, user_id: int) -> ChatThreadDetailResponse:
    """
    Return the message history of a thread owned by ``user_id``, read from its latest LangGraph checkpoint.

    AI messages with both visible content and reasoning_content are split into
    two messages. File attachments are attached to non-AI messages using their
    original index in the checkpoint. If the checkpoint yields no messages,
    the V2 reader (``get_chat_thread_detail_v2``) is used instead when either
    chat_threads.message_count > 0 or chat_messages has rows for the thread.

    Args:
        thread_id: Conversation thread ID.
        user_id: Requesting user ID (must own the thread).

    Returns:
        ChatThreadDetailResponse: Thread metadata and messages.

    Raises:
        NotFoundError: The thread does not exist or is soft-deleted.
        ForbiddenError: The thread belongs to another user.
        InternalError: Any unexpected failure while reading messages.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        kg_sel = chat_thread_kg_select_fragment_sql()
        llm_sel = chat_thread_llm_select_fragment_sql()
        thread_info = await conn.fetchrow(
            f"""
            SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel}
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )

    if not thread_info:
        raise NotFoundError("会话")
    if thread_info['user_id'] != user_id:
        raise ForbiddenError("无权限访问该会话")
    if thread_info['is_deleted']:
        raise NotFoundError("会话已被删除")

    checkpointer = await get_checkpointer()

    try:
        # Only the most recent checkpoint is used, so fetch just one instead of the whole history.
        checkpoints = [
            checkpoint async for checkpoint in checkpointer.alist(
                {"configurable": {"thread_id": thread_id}},
                limit=1,
            )
        ]

        messages_list: List[Dict[str, Any]] = []

        if checkpoints:
            latest_checkpoint = checkpoints[0]
            checkpoint_data = latest_checkpoint.checkpoint
            checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]

            async with pool.acquire() as conn:
                message_files_map = await ChatMessageFileService.get_all_files_by_thread(
                    conn, thread_id, checkpoint_id
                )
                unlinked_files = await ChatMessageFileService.get_unlinked_files(conn, thread_id)
                logger.info(f"查询到 {len(unlinked_files)} 个未关联的文件: {[f['file_name'] for f in unlinked_files]}")
                logger.info(f"查询到文件关联映射: {message_files_map}")
                await _fill_missing_file_urls(conn, message_files_map)

            if "channel_values" in checkpoint_data and "messages" in checkpoint_data["channel_values"]:
                processed_messages, original_to_processed_index = _split_reasoning_messages(
                    checkpoint_data["channel_values"]["messages"]
                )
                messages_list = messages_to_dict(processed_messages)

                for msg_dict in messages_list:
                    msg_dict['files'] = []

                for original_idx, processed_idx in original_to_processed_index.items():
                    files = message_files_map.get(original_idx, [])
                    if files and processed_idx < len(messages_list):
                        messages_list[processed_idx]['files'] = files
                        logger.info(f"消息索引 {processed_idx} 关联了 {len(files)} 个文件: {[f.get('file_name') for f in files]}")

        # No checkpoint messages: trust the DB, since the checkpoint may be missing/expired
        # while chat_messages still holds the dual-written rows.
        if not messages_list:
            db_count = thread_info['message_count'] or 0
            if db_count > 0:
                logger.info(
                    f"V1 checkpoint 无可用消息但 chat_threads.message_count={db_count},回退 V2(chat_messages): thread_id={thread_id}"
                )
                return await get_chat_thread_detail_v2(thread_id, user_id)

            async with pool.acquire() as conn:
                v2cnt = await conn.fetchval(
                    "SELECT COUNT(*)::int FROM chat_messages WHERE thread_id = $1",
                    thread_id,
                )
            if v2cnt and v2cnt > 0:
                logger.info(
                    f"V1 checkpoint 无消息(message_count=0)但 chat_messages 有 {v2cnt} 条,回退 V2: thread_id={thread_id}"
                )
                return await get_chat_thread_detail_v2(thread_id, user_id)

        return ChatThreadDetailResponse(
            thread_id=thread_id,
            title=thread_info['title'],
            knowledge_base_id=thread_info['knowledge_base_id'],
            knowledge_graph_id=thread_info['knowledge_graph_id'],
            llm_provider=thread_info['llm_provider'],
            llm_model=thread_info['llm_model'],
            messages=messages_list,
        )

    except (NotFoundError, ForbiddenError, BadRequestError, InternalError):
        # Domain errors (e.g. from the V2 fallback) already carry the right status; don't re-wrap them.
        raise
    except Exception as e:
        logger.exception(f"查询会话明细失败: {e}")
        raise InternalError(f"查询会话明细失败: {str(e)}") from e
|
||
|
||
|
||
async def check_thread_has_files(thread_id: str) -> bool:
    """
    Report whether a thread has at least one non-deleted file with status 'completed'.

    Uses EXISTS so the database can stop at the first matching row instead of
    counting every match.

    Args:
        thread_id: Conversation thread ID.

    Returns:
        bool: True if such a file exists. False otherwise, and also False on any
        error (the error is logged).
    """
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            has_files = await conn.fetchval(
                """
                SELECT EXISTS (
                    SELECT 1 FROM chat_thread_file
                    WHERE thread_id = $1 AND is_deleted = FALSE AND status = 'completed'
                )
                """,
                thread_id,
            )
            return bool(has_files)
    except Exception as e:
        logger.error(f"检查会话文件失败: {e}")
        return False
|
||
|
||
|
||
async def check_knowledge_base_has_files(knowledge_base_id: int, user_id: int) -> bool:
    """
    Report whether a knowledge base has at least one non-deleted file with status 'completed'.

    Uses EXISTS so the database can stop at the first matching row.

    NOTE(review): ``user_id`` is accepted but not used in the query. No
    ownership/visibility check happens here, so callers presumably must
    authorize access to ``knowledge_base_id`` themselves. Confirm before relying on it.

    Args:
        knowledge_base_id: Knowledge base ID.
        user_id: User ID (currently unused).

    Returns:
        bool: True if such a file exists. False otherwise, and also False on any
        error (the error is logged).
    """
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            has_files = await conn.fetchval(
                """
                SELECT EXISTS (
                    SELECT 1 FROM knowledge_base_file
                    WHERE knowledge_base_id = $1
                      AND is_deleted = FALSE
                      AND status = 'completed'
                )
                """,
                knowledge_base_id,
            )
            return bool(has_files)
    except Exception as e:
        logger.error(f"检查知识库文件失败: {e}")
        return False
|
||
|
||
|
||
async def check_knowledge_graph_has_rag(knowledge_graph_id: int, user: User) -> bool:
    """
    Report whether a knowledge graph is usable for RAG by ``user``.

    The graph must exist, be visible to ``user``, have build_status
    'completed' and have a positive rag_chunk_count.

    Delegates to ``get_knowledge_graph_tool_flags``, which performs the same
    lookup, visibility check and build-status check. This keeps the two
    permission paths from drifting apart. That function logs any error and
    reports has_rag as False.

    Args:
        knowledge_graph_id: Graph ID (graphs.id).
        user: Requesting user, used for the visibility check.

    Returns:
        bool: True only when every condition above holds.
    """
    flags = await get_knowledge_graph_tool_flags(user, knowledge_graph_id)
    return bool(flags["has_rag"])
|
||
|
||
|
||
async def get_knowledge_graph_tool_flags(user: User, graph_id: int) -> Dict[str, Any]:
    """
    Determine in one lookup which chat tools a knowledge graph can provide.

    Returned keys:
      - has_rag: True when the graph's text has been vectorized (rag_chunk_count > 0),
        enabling passage retrieval.
      - neo4j_graph_id: the Neo4j subgraph ID when the build has completed and
        an ID is set (enables entity/relation queries); otherwise None.

    Both flags stay at their defaults (False / None) when the graph is missing,
    not visible to ``user``, or not yet built. On any error the
    error is logged and the flags gathered so far are returned.
    """
    flags: Dict[str, Any] = {"has_rag": False, "neo4j_graph_id": None}
    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            graph_row = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_id)
            if not graph_row:
                return flags

            record = GraphRecord(
                id=int(graph_row["id"]),
                user_id=int(graph_row["user_id"]),
                enterprise_id=graph_row.get("enterprise_id"),
                department_id=graph_row.get("department_id"),
                creator_id=graph_row.get("creator_id"),
                visibility=graph_row.get("visibility") or "private",
            )
            # Visibility is checked first; build status is only read for visible graphs.
            if not can_view_graph(user, record) or graph_row.get("build_status") != "completed":
                return flags

            flags["neo4j_graph_id"] = graph_row.get("neo4j_graph_id") or None
            flags["has_rag"] = (graph_row.get("rag_chunk_count") or 0) > 0
            return flags
    except Exception as exc:
        logger.error(f"查询知识图谱工具标志失败: {exc}")
        return flags
|
||
|
||
|
||
# ====================================
# V2 版本:基于 chat_messages 表查询
# ====================================
|
||
|
||
# Role stored in chat_messages -> message type expected by the frontend (LangChain-style names).
_V2_ROLE_TO_DISPLAY_TYPE = {
    'user': 'human',
    'assistant': 'ai',
    'tool': 'tool',
}

# Assistant metadata keys copied into data.response_metadata, in this order.
_V2_RESPONSE_METADATA_KEYS = ('token_usage', 'model', 'finish_reason')


def _coerce_message_metadata(raw: Any) -> Any:
    """
    Normalize a chat_messages metadata value.

    None or other falsy values become {}. Non-string values are returned
    unchanged. Strings are decoded as JSON; a non-dict or undecodable result
    becomes {}.

    NOTE(review): jsonb may arrive as str depending on the pool's type codecs.
    The decode is defensive. Confirm what ChatMessageService returns.
    """
    if isinstance(raw, str):
        import json
        try:
            decoded = json.loads(raw) if raw else {}
        except ValueError:
            return {}
        return decoded if isinstance(decoded, dict) else {}
    return raw or {}


def _chat_message_row_to_dict(msg: Any, message_files_map: Dict[Any, Any]) -> Dict[str, Any]:
    """Convert one chat_messages row into the LangChain-style dict returned to the frontend."""
    db_role = msg['role']
    display_type = _V2_ROLE_TO_DISPLAY_TYPE.get(db_role, db_role)
    data: Dict[str, Any] = {
        'content': msg['content'],
        'type': display_type,
        'additional_kwargs': {},
        'response_metadata': {},
        'id': msg['checkpoint_id'],
    }

    # Tool messages carry the tool name.
    if db_role == 'tool' and msg.get('name'):
        data['name'] = msg['name']

    metadata = _coerce_message_metadata(msg.get('metadata'))
    if db_role == 'assistant' and metadata:
        for key in _V2_RESPONSE_METADATA_KEYS:
            if key in metadata:
                data['response_metadata'][key] = metadata[key]
        if 'reasoning_content' in metadata:
            data['additional_kwargs']['reasoning_content'] = metadata['reasoning_content']

    return {
        'type': display_type,
        'data': data,
        'files': message_files_map.get(msg['message_index'], []),
    }


async def get_chat_thread_detail_v2(thread_id: str, user_id: int) -> ChatThreadDetailResponse:
    """
    Return the message history of a thread owned by ``user_id``, read from the chat_messages table.

    Compared with the checkpoint-based reader, this issues plain SQL instead
    of parsing checkpoint JSONB. File attachments are looked up with the
    checkpoint_id of the last message and matched to messages by message_index.

    Args:
        thread_id: Conversation thread ID.
        user_id: Requesting user ID (must own the thread).

    Returns:
        ChatThreadDetailResponse: Thread metadata and messages.

    Raises:
        NotFoundError: The thread does not exist or is soft-deleted.
        ForbiddenError: The thread belongs to another user.
        InternalError: Any failure while reading or converting messages.
    """
    # Imported locally, as in the original; presumably to avoid an import cycle.
    from services.chat_message_service import ChatMessageService

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        kg_sel = chat_thread_kg_select_fragment_sql()
        llm_sel = chat_thread_llm_select_fragment_sql()
        thread_info = await conn.fetchrow(
            f"""
            SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel}
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )

        if not thread_info:
            raise NotFoundError("会话")
        if thread_info['user_id'] != user_id:
            raise ForbiddenError("无权限访问该会话")
        if thread_info['is_deleted']:
            raise NotFoundError("会话已被删除")

        try:
            messages = await ChatMessageService.get_messages_by_thread(conn, thread_id)

            latest_checkpoint_id = messages[-1]['checkpoint_id'] if messages else None
            if latest_checkpoint_id:
                message_files_map = await ChatMessageFileService.get_all_files_by_thread(
                    conn, thread_id, latest_checkpoint_id
                )
            else:
                message_files_map = {}

            messages_list = [_chat_message_row_to_dict(msg, message_files_map) for msg in messages]

            logger.info(f"✅ V2查询会话明细: thread_id={thread_id}, 消息数量={len(messages_list)}")

            return ChatThreadDetailResponse(
                thread_id=thread_id,
                title=thread_info['title'],
                knowledge_base_id=thread_info['knowledge_base_id'],
                knowledge_graph_id=thread_info['knowledge_graph_id'],
                llm_provider=thread_info['llm_provider'],
                llm_model=thread_info['llm_model'],
                messages=messages_list,
            )

        except Exception as e:
            logger.exception(f"V2查询会话明细失败: {e}")
            raise InternalError(f"查询会话明细失败: {str(e)}") from e
|