huoyan-enterprise/backend/services/chat_thread_service.py

778 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天会话服务模块
提供聊天会话的 CRUD 操作和业务逻辑。
"""
import copy
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, messages_to_dict
from core.database import get_db_pool, get_checkpointer
from core.graph_metadata import (
chat_thread_kg_column_sql,
chat_thread_kg_select_fragment_sql,
chat_thread_llm_select_fragment_sql,
chat_threads_has_ip_column,
chat_threads_has_kg_column,
chat_threads_has_llm_columns,
graph_table_sql,
)
from core.permissions import can_view_graph
from models.graph_metadata import GraphRecord
from models.user import User
from core.exceptions import NotFoundError, ForbiddenError, BadRequestError, InternalError
from models.chat import (
ChatThreadItem,
ChatThreadListResponse,
ChatThreadDetailResponse,
)
from services.chat_thread_file_service import ChatThreadFileService
from services.chat_message_file_service import ChatMessageFileService
from services.knowledge_graph_service import KnowledgeGraphService
from services.oss_service import get_oss_service
from utils.checkpoint_helper import rebuild_full_message_history
from logger.logging import get_logger
logger = get_logger(__name__)
async def create_or_update_chat_thread(
    thread_id: str,
    user_id: int,
    query: str,
    knowledge_base_id: Optional[int] = None,
    knowledge_graph_id: Optional[int] = None,
    ip: Optional[str] = None,
    llm_provider: Optional[str] = None,
    llm_model: Optional[str] = None,
) -> None:
    """Create the ``chat_threads`` row for a thread, or bump an existing one.

    If a row with ``thread_id`` already exists, its ``message_count`` is
    incremented and ``knowledge_base_id`` (plus the knowledge-graph column,
    when the schema has one) is overwritten with the values passed in.
    Otherwise a new row is inserted with ``message_count = 1`` and a title
    made of the first 10 characters of ``query``.

    Any exception is logged and swallowed on purpose so that a bookkeeping
    failure never breaks the chat flow that calls this function.

    Args:
        thread_id: Conversation thread identifier.
        user_id: Owner of the thread (only written on insert).
        query: The user's query text; stored as ``first_query`` on insert.
        knowledge_base_id: Optional knowledge base id to store on the row.
        knowledge_graph_id: Optional knowledge graph id; only written when
            ``chat_threads_has_kg_column()`` is true.
        ip: Optional client IP; only written on insert and only when
            ``chat_threads_has_ip_column()`` is true.
        llm_provider: Optional LLM provider chosen for this message.
        llm_model: Optional LLM model id chosen for this message.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        try:
            existing = await conn.fetchrow(
                "SELECT id, message_count FROM chat_threads WHERE thread_id = $1",
                thread_id,
            )
            has_kg = chat_threads_has_kg_column()

            if existing:
                # Build the SET list once instead of duplicating the whole
                # statement per schema variant; $1 is always thread_id.
                params: List[Any] = [thread_id, knowledge_base_id]
                assignments = [
                    "message_count = message_count + 1",
                    "knowledge_base_id = $2",
                ]
                if has_kg:
                    params.append(knowledge_graph_id)
                    assignments.append(
                        f"{chat_thread_kg_column_sql()} = ${len(params)}"
                    )
                assignments.append("updated_at = CURRENT_TIMESTAMP")
                await conn.execute(
                    f"""
                    UPDATE chat_threads
                    SET {', '.join(assignments)}
                    WHERE thread_id = $1
                    """,
                    *params,
                )
                logger.info(
                    f"更新会话记录: thread_id={thread_id}, 消息数={existing['message_count'] + 1}, "
                    f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}"
                )
            else:
                # The title is simply the first 10 characters of the query.
                title = query[:10]
                has_ip = chat_threads_has_ip_column()

                # Optional columns are appended in a fixed order (kg, then ip)
                # so placeholders stay aligned with the parameter list.
                columns = [
                    "thread_id",
                    "user_id",
                    "title",
                    "first_query",
                    "message_count",
                    "knowledge_base_id",
                ]
                placeholders = ["$1", "$2", "$3", "$4", "1", "$5"]
                params = [thread_id, user_id, title, query, knowledge_base_id]
                if has_kg:
                    params.append(knowledge_graph_id)
                    columns.append(chat_thread_kg_column_sql())
                    placeholders.append(f"${len(params)}")
                if has_ip:
                    params.append(ip)
                    columns.append("ip")
                    placeholders.append(f"${len(params)}")

                await conn.execute(
                    f"""
                    INSERT INTO chat_threads ({', '.join(columns)})
                    VALUES ({', '.join(placeholders)})
                    """,
                    *params,
                )
                logger.info(
                    f"创建新会话记录: thread_id={thread_id}, user_id={user_id}, title={title}, "
                    f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}, ip={ip}"
                )

            # NOTE(review): applied to both new and existing threads; the
            # nesting of this statement should be confirmed against the
            # original file's indentation.
            if chat_threads_has_llm_columns() and llm_provider and llm_model:
                await conn.execute(
                    """
                    UPDATE chat_threads
                    SET llm_provider = $2, llm_model = $3, updated_at = CURRENT_TIMESTAMP
                    WHERE thread_id = $1
                    """,
                    thread_id,
                    llm_provider,
                    llm_model,
                )
        except Exception as e:
            # Deliberately best-effort: log with traceback, never re-raise,
            # so the main chat flow is not affected.
            logger.exception(f"记录会话到 chat_threads 失败(会导致会话列表为空): {e}")
async def delete_chat_thread(thread_id: str, user_id: int) -> bool:
    """Soft-delete a chat thread after cleaning up its files.

    Steps, in order: verify the thread exists, belongs to ``user_id`` and is
    not already deleted; delete the thread's message-file associations;
    try to delete each of the thread's files from OSS (failures are only
    logged); finally set ``is_deleted = TRUE`` on the ``chat_threads`` row.

    Args:
        thread_id: Conversation thread identifier.
        user_id: Id of the requesting user, used for the ownership check.

    Returns:
        ``True`` once the row has been marked as deleted.

    Raises:
        NotFoundError: No row exists for ``thread_id``.
        ForbiddenError: The thread belongs to a different user.
        BadRequestError: The thread is already marked as deleted.
    """
    db_pool = await get_db_pool()
    async with db_pool.acquire() as conn:
        thread_row = await conn.fetchrow(
            """
            SELECT id, user_id, is_deleted
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )
        if not thread_row:
            raise NotFoundError("会话")
        if thread_row['user_id'] != user_id:
            raise ForbiddenError("无权限删除该会话")
        if thread_row['is_deleted']:
            raise BadRequestError("会话已被删除")

        # Message-to-file association rows are removed outright.
        await ChatMessageFileService.delete_thread_associations(conn, thread_id)

        thread_files = await ChatThreadFileService.get_all_files_by_thread(conn, thread_id)
        logger.info(f"会话 {thread_id} 共有 {len(thread_files)} 个文件需要删除")

        removed = 0
        oss = get_oss_service()
        for item in thread_files:
            # A failure on one file must not stop the remaining deletions.
            try:
                if not oss.enabled:
                    logger.warning("OSS 服务未启用,无法删除物理文件")
                    continue
                if not item.file_path.startswith(('http://', 'https://')):
                    logger.warning(f"文件路径不是 OSS URL 格式: {item.file_path}")
                    continue
                object_name = oss.extract_object_name_from_url(
                    item.file_path, thread_id=thread_id
                )
                if object_name and oss.delete_file(object_name):
                    removed += 1
                    logger.debug(f"删除 OSS 文件: {object_name}")
                else:
                    logger.warning(f"无法删除 OSS 文件: {item.file_path}")
            except Exception as exc:
                logger.warning(f"删除物理文件失败 {item.file_path}: {exc}")
        logger.info(f"已删除 {removed} 个物理文件")

        # The thread row itself is only flagged, not removed.
        await conn.execute(
            """
            UPDATE chat_threads
            SET is_deleted = TRUE,
            updated_at = CURRENT_TIMESTAMP
            WHERE thread_id = $1
            """,
            thread_id,
        )
        logger.info(f"删除会话成功: thread_id={thread_id}, user_id={user_id}")
        return True
async def get_user_chat_threads(
    user_id: int,
    page: int = 1,
    page_size: int = 20
) -> ChatThreadListResponse:
    """Return one page of a user's chat threads, newest update first.

    Only threads that are not soft-deleted and have ``message_count > 0``
    are counted and listed.

    Args:
        user_id: Owner whose threads are listed.
        page: 1-based page number.
        page_size: Number of items per page.

    Returns:
        ChatThreadListResponse with the total count, paging info and items.

    Raises:
        BadRequestError: ``page`` or ``page_size`` is smaller than 1. Without
            this check such values reach the database as a negative OFFSET /
            non-positive LIMIT, or divide by zero when computing the page
            count.
    """
    if page < 1 or page_size < 1:
        raise BadRequestError("分页参数无效: page 和 page_size 必须大于 0")

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        offset = (page - 1) * page_size

        total_row = await conn.fetchrow(
            """
            SELECT COUNT(*) as total
            FROM chat_threads
            WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0
            """,
            user_id,
        )
        total = total_row['total']

        # Ceiling division; yields 0 for an empty result because page_size >= 1.
        total_pages = (total + page_size - 1) // page_size

        # The fragment is expected to expose the column as knowledge_graph_id,
        # which is the key read from each row below.
        kg_sel = chat_thread_kg_select_fragment_sql()
        rows = await conn.fetch(
            f"""
            SELECT id, thread_id, title, first_query, message_count, knowledge_base_id, {kg_sel}, created_at, updated_at
            FROM chat_threads
            WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0
            ORDER BY updated_at DESC
            LIMIT $2 OFFSET $3
            """,
            user_id,
            page_size,
            offset,
        )

        items = [
            ChatThreadItem(
                id=row['id'],
                thread_id=row['thread_id'],
                title=row['title'],
                first_query=row['first_query'],
                message_count=row['message_count'],
                knowledge_base_id=row['knowledge_base_id'],
                knowledge_graph_id=row['knowledge_graph_id'],
                created_at=row['created_at'],
                updated_at=row['updated_at'],
            )
            for row in rows
        ]
        logger.info(f"查询用户会话列表: user_id={user_id}, page={page}, total={total}")
        return ChatThreadListResponse(
            total=total,
            page=page,
            page_size=page_size,
            total_pages=total_pages,
            items=items,
        )
def _split_reasoning_messages(raw_messages):
    """Split AI messages that carry both an answer and reasoning text.

    An ``AIMessage`` with non-empty ``content`` and a non-blank
    ``additional_kwargs["reasoning_content"]`` becomes two messages: first a
    copy holding only the reasoning, then a copy holding only the content.
    All other messages are passed through unchanged.

    Returns:
        A ``(processed_messages, index_map)`` pair, where ``index_map`` maps
        the original index of every non-AI message to its index in
        ``processed_messages`` (used to re-attach files after the split).
    """
    processed: List[Any] = []
    index_map: Dict[int, int] = {}
    for original_idx, msg in enumerate(raw_messages):
        if not isinstance(msg, AIMessage):
            index_map[original_idx] = len(processed)
            processed.append(msg)
            continue

        content = getattr(msg, 'content', "") or ""
        extra = getattr(msg, 'additional_kwargs', None) or {}
        reasoning = extra.get("reasoning_content", "") or ""

        # content may be a list of content blocks rather than a plain string;
        # calling .strip() on it unconditionally would raise.
        has_content = bool(content.strip()) if isinstance(content, str) else bool(content)
        has_reasoning = isinstance(reasoning, str) and bool(reasoning.strip())
        if not (has_content and has_reasoning):
            processed.append(msg)
            continue

        # Deep copies so the checkpoint's own message object is not mutated.
        reasoning_msg = copy.deepcopy(msg)
        reasoning_msg.content = ""
        reasoning_msg.additional_kwargs["reasoning_content"] = reasoning
        processed.append(reasoning_msg)

        content_msg = copy.deepcopy(msg)
        content_msg.content = content
        content_msg.additional_kwargs["reasoning_content"] = ""
        processed.append(content_msg)
    return processed, index_map


async def _fill_missing_file_urls(conn, message_files_map) -> None:
    """Backfill ``file_url`` from ``chat_thread_file.file_path`` where it is missing.

    Mutates the file dicts inside ``message_files_map`` in place.
    """
    missing_ids = {
        file_info['file_id']
        for files_list in message_files_map.values()
        for file_info in files_list
        if not file_info.get('file_url')
    }
    if not missing_ids:
        return

    rows = await conn.fetch(
        """
        SELECT id, file_path FROM chat_thread_file
        WHERE id = ANY($1::int[]) AND is_deleted = FALSE
        """,
        list(missing_ids),
    )
    file_url_map = {row['id']: row['file_path'] for row in rows}
    for files_list in message_files_map.values():
        for file_info in files_list:
            if file_info['file_id'] in file_url_map:
                file_info['file_url'] = file_url_map[file_info['file_id']]


async def get_chat_thread_detail(thread_id: str, user_id: int) -> ChatThreadDetailResponse:
    """Return a thread's messages, read from the newest checkpoint.

    The thread row is validated first (exists, owned by ``user_id``, not
    deleted). Messages are then taken from the most recent checkpoint of the
    thread, AI messages with both answer and reasoning are split in two, and
    file associations are attached to the non-AI messages. When the
    checkpoint yields no messages but the database indicates that messages
    exist, the call is delegated to ``get_chat_thread_detail_v2``.

    Args:
        thread_id: Conversation thread identifier.
        user_id: Id of the requesting user, used for the ownership check.

    Returns:
        ChatThreadDetailResponse for the thread.

    Raises:
        NotFoundError: The thread does not exist or is soft-deleted.
        ForbiddenError: The thread belongs to a different user.
        InternalError: Any failure while loading or assembling the messages.
    """
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        kg_sel = chat_thread_kg_select_fragment_sql()
        llm_sel = chat_thread_llm_select_fragment_sql()
        thread_info = await conn.fetchrow(
            f"""
            SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel}
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )
        if not thread_info:
            raise NotFoundError("会话")
        if thread_info['user_id'] != user_id:
            raise ForbiddenError("无权限访问该会话")
        if thread_info['is_deleted']:
            raise NotFoundError("会话已被删除")

    checkpointer = await get_checkpointer()
    try:
        # Only the newest checkpoint is used (the first one yielded), so ask
        # the saver for a single entry instead of loading the whole history.
        checkpoints = [
            checkpoint
            async for checkpoint in checkpointer.alist(
                {"configurable": {"thread_id": thread_id}},
                limit=1,
            )
        ]
        messages_list = []
        if checkpoints:
            latest_checkpoint = checkpoints[0]
            checkpoint_data = latest_checkpoint.checkpoint
            checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]

            async with pool.acquire() as conn:
                message_files_map = await ChatMessageFileService.get_all_files_by_thread(
                    conn, thread_id, checkpoint_id
                )
                # Fetched for diagnostics only; the result is just logged.
                unlinked_files = await ChatMessageFileService.get_unlinked_files(
                    conn, thread_id
                )
                logger.info(f"查询到 {len(unlinked_files)} 个未关联的文件: {[f['file_name'] for f in unlinked_files]}")
                logger.info(f"查询到文件关联映射: {message_files_map}")
                await _fill_missing_file_urls(conn, message_files_map)

            if "channel_values" in checkpoint_data and "messages" in checkpoint_data["channel_values"]:
                raw_messages = checkpoint_data["channel_values"]["messages"]
                processed_messages, original_to_processed_index = _split_reasoning_messages(
                    raw_messages
                )
                messages_list = messages_to_dict(processed_messages)

                for msg_dict in messages_list:
                    msg_dict['files'] = []
                # message_files_map is keyed by the original message index, so
                # translate to the post-split index before attaching files.
                for original_idx, processed_idx in original_to_processed_index.items():
                    files = message_files_map.get(original_idx, [])
                    if files and processed_idx < len(messages_list):
                        messages_list[processed_idx]['files'] = files
                        logger.info(f"消息索引 {processed_idx} 关联了 {len(files)} 个文件: {[f.get('file_name') for f in files]}")

        # No messages from the checkpoint: trust the database and fall back to
        # the chat_messages table when it shows that messages exist.
        if not messages_list:
            db_count = thread_info['message_count'] or 0
            use_v2 = False
            if db_count > 0:
                use_v2 = True
                logger.info(
                    f"V1 checkpoint 无可用消息但 chat_threads.message_count={db_count},回退 V2(chat_messages): thread_id={thread_id}"
                )
            else:
                async with pool.acquire() as conn:
                    v2cnt = await conn.fetchval(
                        "SELECT COUNT(*)::int FROM chat_messages WHERE thread_id = $1",
                        thread_id,
                    )
                if v2cnt and v2cnt > 0:
                    use_v2 = True
                    logger.info(
                        f"V1 checkpoint 无消息message_count=0但 chat_messages 有 {v2cnt} 条,回退 V2: thread_id={thread_id}"
                    )
            if use_v2:
                return await get_chat_thread_detail_v2(thread_id, user_id)

        return ChatThreadDetailResponse(
            thread_id=thread_id,
            title=thread_info['title'],
            knowledge_base_id=thread_info['knowledge_base_id'],
            knowledge_graph_id=thread_info['knowledge_graph_id'],
            llm_provider=thread_info['llm_provider'],
            llm_model=thread_info['llm_model'],
            messages=messages_list,
        )
    except Exception as e:
        logger.error(f"查询会话明细失败: {e}")
        raise InternalError(f"查询会话明细失败: {str(e)}") from e
async def check_thread_has_files(thread_id: str) -> bool:
    """Tell whether a thread has at least one completed, non-deleted file.

    Args:
        thread_id: Conversation thread identifier.

    Returns:
        ``True`` when ``chat_thread_file`` holds a matching row with
        ``status = 'completed'``; ``False`` otherwise, including when the
        lookup itself fails (the error is logged, not raised).
    """
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as conn:
            completed = await conn.fetchval(
                """
                SELECT COUNT(*) FROM chat_thread_file
                WHERE thread_id = $1 AND is_deleted = FALSE AND status = 'completed'
                """,
                thread_id,
            )
            return completed > 0
    except Exception as exc:
        logger.error(f"检查会话文件失败: {exc}")
        return False
async def check_knowledge_base_has_files(knowledge_base_id: int, user_id: int) -> bool:
    """Tell whether a knowledge base has at least one completed, non-deleted file.

    Args:
        knowledge_base_id: Knowledge base identifier.
        user_id: Id of the requesting user.

    Returns:
        ``True`` when ``knowledge_base_file`` holds a matching row with
        ``status = 'completed'``; ``False`` otherwise, including when the
        lookup itself fails (the error is logged, not raised).
    """
    # NOTE(review): user_id is accepted but not used in the query, so no
    # ownership check happens here — confirm callers enforce access.
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as conn:
            completed = await conn.fetchval(
                """
                SELECT COUNT(*) FROM knowledge_base_file
                WHERE knowledge_base_id = $1
                AND is_deleted = FALSE
                AND status = 'completed'
                """,
                knowledge_base_id,
            )
            return completed > 0
    except Exception as exc:
        logger.error(f"检查知识库文件失败: {exc}")
        return False
async def check_knowledge_graph_has_rag(knowledge_graph_id: int, user: User) -> bool:
    """Tell whether a knowledge graph can be used for RAG by ``user``.

    Delegates to ``get_knowledge_graph_tool_flags`` so that the existence,
    visibility, build-status and chunk-count rules live in one place.

    Args:
        knowledge_graph_id: Knowledge graph identifier.
        user: The requesting user, used for the visibility check.

    Returns:
        ``True`` only when the graph exists, is visible to ``user``, has
        ``build_status == "completed"`` and a positive ``rag_chunk_count``.
        ``False`` in every other case, including lookup errors (which the
        delegate logs and swallows).
    """
    flags = await get_knowledge_graph_tool_flags(user, knowledge_graph_id)
    return bool(flags["has_rag"])
async def get_knowledge_graph_tool_flags(user: User, graph_id: int) -> Dict[str, Any]:
    """Report which chat capabilities a knowledge graph offers to ``user``.

    Args:
        user: The requesting user, used for the visibility check.
        graph_id: Knowledge graph identifier.

    Returns:
        A dict with two keys:

        - ``has_rag``: ``True`` when the graph's ``rag_chunk_count`` is positive.
        - ``neo4j_graph_id``: the graph's ``neo4j_graph_id`` when truthy,
          otherwise ``None``.

        Both stay at their defaults (``False`` / ``None``) when the graph is
        missing, not visible to ``user``, not built (``build_status`` other
        than ``"completed"``), or when the lookup fails (logged, not raised).
    """
    flags: Dict[str, Any] = {"has_rag": False, "neo4j_graph_id": None}
    try:
        db_pool = await get_db_pool()
        async with db_pool.acquire() as conn:
            row = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_id)
            if not row:
                return flags

            record = GraphRecord(
                id=int(row["id"]),
                user_id=int(row["user_id"]),
                enterprise_id=row.get("enterprise_id"),
                department_id=row.get("department_id"),
                creator_id=row.get("creator_id"),
                visibility=row.get("visibility") or "private",
            )
            visible = await can_view_graph(conn, user, record)
            if not visible or row.get("build_status") != "completed":
                return flags

            flags["neo4j_graph_id"] = row.get("neo4j_graph_id") or None
            flags["has_rag"] = (row.get("rag_chunk_count") or 0) > 0
            return flags
    except Exception as exc:
        logger.error(f"查询知识图谱工具标志失败: {exc}")
        return flags
# ====================================
# V2 版本:基于 chat_messages 表查询
# ====================================
async def get_chat_thread_detail_v2(thread_id: str, user_id: int) -> ChatThreadDetailResponse:
    """Return a thread's messages, read from the ``chat_messages`` table (V2).

    The thread row is validated first (exists, owned by ``user_id``, not
    deleted). Messages are then loaded through ``ChatMessageService`` and
    converted to the same dict layout the checkpoint-based variant returns
    (``type`` / ``data`` / ``files``), with stored roles mapped to display
    types (``user`` -> ``human``, ``assistant`` -> ``ai``).

    Args:
        thread_id: Conversation thread identifier.
        user_id: Id of the requesting user, used for the ownership check.

    Returns:
        ChatThreadDetailResponse for the thread.

    Raises:
        NotFoundError: The thread does not exist or is soft-deleted.
        ForbiddenError: The thread belongs to a different user.
        InternalError: Any failure while loading or assembling the messages.
    """
    # Imported at call time; presumably to avoid a circular import — confirm.
    from services.chat_message_service import ChatMessageService

    pool = await get_db_pool()
    async with pool.acquire() as conn:
        kg_sel = chat_thread_kg_select_fragment_sql()
        llm_sel = chat_thread_llm_select_fragment_sql()
        thread_info = await conn.fetchrow(
            f"""
            SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel}
            FROM chat_threads
            WHERE thread_id = $1
            """,
            thread_id,
        )
        if not thread_info:
            raise NotFoundError("会话")
        if thread_info['user_id'] != user_id:
            raise ForbiddenError("无权限访问该会话")
        if thread_info['is_deleted']:
            raise NotFoundError("会话已被删除")

        try:
            # Treat a missing result the same as "no messages".
            messages = await ChatMessageService.get_messages_by_thread(conn, thread_id) or []

            # File associations are looked up against the checkpoint id of the
            # last message; without one there is nothing to attach.
            message_files_map = {}
            latest_checkpoint_id = messages[-1]['checkpoint_id'] if messages else None
            if latest_checkpoint_id:
                message_files_map = await ChatMessageFileService.get_all_files_by_thread(
                    conn, thread_id, latest_checkpoint_id
                )

            # Stored role -> type name used in the response.
            role_type_mapping = {
                'user': 'human',
                'assistant': 'ai',
                'tool': 'tool',
            }
            messages_list = []
            for msg in messages:
                metadata = msg.get('metadata') or {}
                db_role = msg['role']
                # Unknown roles are passed through unchanged.
                display_type = role_type_mapping.get(db_role, db_role)

                response_metadata = {}
                additional_kwargs = {}
                if db_role == 'assistant' and metadata:
                    # Copy only the known assistant fields out of metadata.
                    for key in ('token_usage', 'model', 'finish_reason'):
                        if key in metadata:
                            response_metadata[key] = metadata[key]
                    if 'reasoning_content' in metadata:
                        additional_kwargs['reasoning_content'] = metadata['reasoning_content']

                data = {
                    'content': msg['content'],
                    'type': display_type,
                    'additional_kwargs': additional_kwargs,
                    'response_metadata': response_metadata,
                    'id': msg['checkpoint_id'],
                }
                # Tool messages additionally carry the tool name when present.
                if db_role == 'tool' and msg.get('name'):
                    data['name'] = msg['name']

                messages_list.append({
                    'type': display_type,
                    'data': data,
                    'files': message_files_map.get(msg['message_index'], []),
                })

            logger.info(f"✅ V2查询会话明细: thread_id={thread_id}, 消息数量={len(messages_list)}")
            return ChatThreadDetailResponse(
                thread_id=thread_id,
                title=thread_info['title'],
                knowledge_base_id=thread_info['knowledge_base_id'],
                knowledge_graph_id=thread_info['knowledge_graph_id'],
                llm_provider=thread_info['llm_provider'],
                llm_model=thread_info['llm_model'],
                messages=messages_list,
            )
        except Exception as e:
            logger.error(f"V2查询会话明细失败: {e}")
            raise InternalError(f"查询会话明细失败: {str(e)}") from e