""" 聊天会话服务模块 提供聊天会话的 CRUD 操作和业务逻辑。 """ import copy from typing import Any, Dict, List, Optional from langchain_core.messages import AIMessage, messages_to_dict from core.database import get_db_pool, get_checkpointer from core.graph_metadata import ( chat_thread_kg_column_sql, chat_thread_kg_select_fragment_sql, chat_thread_llm_select_fragment_sql, chat_threads_has_ip_column, chat_threads_has_kg_column, chat_threads_has_llm_columns, graph_table_sql, ) from core.permissions import can_view_graph from models.graph_metadata import GraphRecord from models.user import User from core.exceptions import NotFoundError, ForbiddenError, BadRequestError, InternalError from models.chat import ( ChatThreadItem, ChatThreadListResponse, ChatThreadDetailResponse, ) from services.chat_thread_file_service import ChatThreadFileService from services.chat_message_file_service import ChatMessageFileService from services.knowledge_graph_service import KnowledgeGraphService from services.oss_service import get_oss_service from utils.checkpoint_helper import rebuild_full_message_history from logger.logging import get_logger logger = get_logger(__name__) async def create_or_update_chat_thread( thread_id: str, user_id: int, query: str, knowledge_base_id: Optional[int] = None, knowledge_graph_id: Optional[int] = None, ip: Optional[str] = None, llm_provider: Optional[str] = None, llm_model: Optional[str] = None, ) -> None: """ 创建或更新聊天会话记录 Args: thread_id: 会话线程 ID user_id: 用户 ID query: 用户查询内容 knowledge_base_id: 知识库 ID(可选) knowledge_graph_id: 知识图谱 ID(可选,对应 graphs.id) ip: 用户 IP 地址(可选) llm_provider: 本次消息选用的提供方 tongyi/deepseek(可选,需库存在 llm 列) llm_model: 本次消息选用的模型逻辑 id(可选) """ pool = await get_db_pool() async with pool.acquire() as conn: try: # 检查该 thread_id 是否已存在 existing = await conn.fetchrow( "SELECT id, message_count FROM chat_threads WHERE thread_id = $1", thread_id ) if existing: # 已存在,更新消息计数、知识库ID和更新时间 if chat_threads_has_kg_column(): kg_col = chat_thread_kg_column_sql() await conn.execute( f""" UPDATE chat_threads SET message_count = message_count + 1, knowledge_base_id = $2, {kg_col} = $3, updated_at = CURRENT_TIMESTAMP WHERE thread_id = $1 """, thread_id, knowledge_base_id, knowledge_graph_id, ) else: await conn.execute( """ UPDATE chat_threads SET message_count = message_count + 1, knowledge_base_id = $2, updated_at = CURRENT_TIMESTAMP WHERE thread_id = $1 """, thread_id, knowledge_base_id, ) logger.info( f"更新会话记录: thread_id={thread_id}, 消息数={existing['message_count'] + 1}, " f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}" ) else: # 不存在,创建新记录 # 取查询内容的前 10 个字作为标题 title = query[:10] if len(query) <= 10 else query[:10] has_kg = chat_threads_has_kg_column() has_ip = chat_threads_has_ip_column() if has_kg: kg_col = chat_thread_kg_column_sql() if has_kg and has_ip: await conn.execute( f""" INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count, knowledge_base_id, {kg_col}, ip) VALUES ($1, $2, $3, $4, 1, $5, $6, $7) """, thread_id, user_id, title, query, knowledge_base_id, knowledge_graph_id, ip, ) elif has_kg and not has_ip: await conn.execute( f""" INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count, knowledge_base_id, {kg_col}) VALUES ($1, $2, $3, $4, 1, $5, $6) """, thread_id, user_id, title, query, knowledge_base_id, knowledge_graph_id, ) elif not has_kg and has_ip: await conn.execute( """ INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count, knowledge_base_id, ip) VALUES ($1, $2, $3, $4, 1, $5, $6) """, thread_id, user_id, title, query, knowledge_base_id, ip, ) else: await conn.execute( """ INSERT INTO chat_threads (thread_id, user_id, title, first_query, message_count, knowledge_base_id) VALUES ($1, $2, $3, $4, 1, $5) """, thread_id, user_id, title, query, knowledge_base_id, ) logger.info( f"创建新会话记录: thread_id={thread_id}, user_id={user_id}, title={title}, " f"knowledge_base_id={knowledge_base_id}, knowledge_graph_id={knowledge_graph_id}, ip={ip}" ) if chat_threads_has_llm_columns() and llm_provider and llm_model: await conn.execute( """ UPDATE chat_threads SET llm_provider = $2, llm_model = $3, updated_at = CURRENT_TIMESTAMP WHERE thread_id = $1 """, thread_id, llm_provider, llm_model, ) except Exception as e: logger.exception("记录会话到 chat_threads 失败(会导致会话列表为空): {}", e) # 不抛出异常,避免影响主流程 pass async def delete_chat_thread(thread_id: str, user_id: int) -> bool: """ 删除聊天会话(软删除) Args: thread_id: 会话线程 ID user_id: 用户 ID(用于权限验证) Returns: bool: 是否删除成功 Raises: HTTPException: 会话不存在或无权限删除 """ pool = await get_db_pool() async with pool.acquire() as conn: # 先检查会话是否存在且属于该用户 existing = await conn.fetchrow( """ SELECT id, user_id, is_deleted FROM chat_threads WHERE thread_id = $1 """, thread_id ) if not existing: raise NotFoundError("会话") if existing['user_id'] != user_id: raise ForbiddenError("无权限删除该会话") if existing['is_deleted']: raise BadRequestError("会话已被删除") # 删除消息文件关联(物理删除) await ChatMessageFileService.delete_thread_associations(conn, thread_id) # 获取会话的所有文件,删除 OSS 文件 all_files = await ChatThreadFileService.get_all_files_by_thread(conn, thread_id) logger.info(f"会话 {thread_id} 共有 {len(all_files)} 个文件需要删除") # 删除所有物理文件(OSS) deleted_files_count = 0 oss_service = get_oss_service() for file in all_files: try: if not oss_service.enabled: logger.warning("OSS 服务未启用,无法删除物理文件") elif file.file_path.startswith(('http://', 'https://')): # 是 OSS URL,删除 OSS 上的文件 oss_object_name = oss_service.extract_object_name_from_url(file.file_path, thread_id=thread_id) if oss_object_name and oss_service.delete_file(oss_object_name): deleted_files_count += 1 logger.debug(f"删除 OSS 文件: {oss_object_name}") else: logger.warning(f"无法删除 OSS 文件: {file.file_path}") else: logger.warning(f"文件路径不是 OSS URL 格式: {file.file_path}") except Exception as e: logger.warning(f"删除物理文件失败 {file.file_path}: {e}") logger.info(f"已删除 {deleted_files_count} 个物理文件") # 执行软删除 await conn.execute( """ UPDATE chat_threads SET is_deleted = TRUE, updated_at = CURRENT_TIMESTAMP WHERE thread_id = $1 """, thread_id ) logger.info(f"删除会话成功: thread_id={thread_id}, user_id={user_id}") return True async def get_user_chat_threads( user_id: int, page: int = 1, page_size: int = 20 ) -> ChatThreadListResponse: """ 获取用户的会话列表(分页) Args: user_id: 用户 ID page: 页码(从 1 开始) page_size: 每页数量 Returns: ChatThreadListResponse: 会话列表响应 """ pool = await get_db_pool() async with pool.acquire() as conn: # 计算偏移量 offset = (page - 1) * page_size # 查询总数(只统计未删除的且有消息的) total_row = await conn.fetchrow( """ SELECT COUNT(*) as total FROM chat_threads WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0 """, user_id ) total = total_row['total'] # 计算总页数 total_pages = (total + page_size - 1) // page_size if total > 0 else 0 # 查询会话列表(按更新时间倒序,只查询有消息的会话) kg_sel = chat_thread_kg_select_fragment_sql() rows = await conn.fetch( f""" SELECT id, thread_id, title, first_query, message_count, knowledge_base_id, {kg_sel}, created_at, updated_at FROM chat_threads WHERE user_id = $1 AND is_deleted = FALSE AND message_count > 0 ORDER BY updated_at DESC LIMIT $2 OFFSET $3 """, user_id, page_size, offset ) # 转换为模型列表 items = [ ChatThreadItem( id=row['id'], thread_id=row['thread_id'], title=row['title'], first_query=row['first_query'], message_count=row['message_count'], knowledge_base_id=row['knowledge_base_id'], knowledge_graph_id=row['knowledge_graph_id'], created_at=row['created_at'], updated_at=row['updated_at'] ) for row in rows ] logger.info(f"查询用户会话列表: user_id={user_id}, page={page}, total={total}") return ChatThreadListResponse( total=total, page=page, page_size=page_size, total_pages=total_pages, items=items ) async def get_chat_thread_detail(thread_id: str, user_id: int) -> ChatThreadDetailResponse: """ 获取会话的聊天明细 Args: thread_id: 会话线程 ID user_id: 用户 ID(用于权限验证) Returns: ChatThreadDetailResponse: 会话明细响应 Raises: HTTPException: 会话不存在或无权限访问 """ # 先验证会话是否存在且属于该用户 pool = await get_db_pool() async with pool.acquire() as conn: kg_sel = chat_thread_kg_select_fragment_sql() llm_sel = chat_thread_llm_select_fragment_sql() thread_info = await conn.fetchrow( f""" SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel} FROM chat_threads WHERE thread_id = $1 """, thread_id ) if not thread_info: raise NotFoundError("会话") if thread_info['user_id'] != user_id: raise ForbiddenError("无权限访问该会话") if thread_info['is_deleted']: raise NotFoundError("会话已被删除") # 使用 checkpointer 查询会话消息 checkpointer = await get_checkpointer() try: # 获取该 thread_id 的所有 checkpoint checkpoints = [ checkpoint async for checkpoint in checkpointer.alist( {"configurable": {"thread_id": thread_id}} ) ] messages_list = [] if checkpoints: # 获取最新的 checkpoint(第一个) latest_checkpoint = checkpoints[0] checkpoint_data = latest_checkpoint.checkpoint checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"] # 通过关联查询获取该 thread_id 下所有 checkpoint 的文件关联 async with pool.acquire() as conn: message_files_map = await ChatMessageFileService.get_all_files_by_thread( conn, thread_id, checkpoint_id ) # 通过关联查询获取未关联到消息的文件 unlinked_files = await ChatMessageFileService.get_unlinked_files( conn, thread_id ) logger.info(f"查询到 {len(unlinked_files)} 个未关联的文件: {[f['file_name'] for f in unlinked_files]}") logger.info(f"查询到文件关联映射: {message_files_map}") # 确保所有文件都有 file_url 字段 file_ids_to_query = set() for files_list in message_files_map.values(): for file_info in files_list: if 'file_url' not in file_info or not file_info['file_url']: file_ids_to_query.add(file_info['file_id']) # 批量查询 file_url if file_ids_to_query: file_url_map = {} rows = await conn.fetch( """ SELECT id, file_path FROM chat_thread_file WHERE id = ANY($1::int[]) AND is_deleted = FALSE """, list(file_ids_to_query) ) for row in rows: file_url_map[row['id']] = row['file_path'] # 更新文件信息中的 file_url for files_list in message_files_map.values(): for file_info in files_list: if file_info['file_id'] in file_url_map: file_info['file_url'] = file_url_map[file_info['file_id']] # 提取消息列表 if "channel_values" in checkpoint_data and "messages" in checkpoint_data["channel_values"]: raw_messages = checkpoint_data["channel_values"]["messages"] # 处理同时包含 content 和 reasoning_content 的 AI 消息 processed_messages = [] original_to_processed_index = {} processed_idx = 0 for original_idx, msg in enumerate(raw_messages): if isinstance(msg, AIMessage): content = getattr(msg, 'content', "") or "" reasoning_content = "" if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs: reasoning_content = msg.additional_kwargs.get("reasoning_content", "") or "" if content.strip() and reasoning_content.strip(): # 创建第一个消息:只有 reasoning_content reasoning_msg = copy.deepcopy(msg) reasoning_msg.content = "" if not reasoning_msg.additional_kwargs: reasoning_msg.additional_kwargs = {} reasoning_msg.additional_kwargs["reasoning_content"] = reasoning_content processed_messages.append(reasoning_msg) processed_idx += 1 # 创建第二个消息:只有 content content_msg = copy.deepcopy(msg) content_msg.content = content if not content_msg.additional_kwargs: content_msg.additional_kwargs = {} content_msg.additional_kwargs["reasoning_content"] = "" processed_messages.append(content_msg) processed_idx += 1 else: processed_messages.append(msg) processed_idx += 1 else: processed_messages.append(msg) original_to_processed_index[original_idx] = processed_idx processed_idx += 1 raw_messages = processed_messages messages_list = messages_to_dict(raw_messages) # 将文件关联信息添加到 human 消息中 for msg_dict in messages_list: msg_dict['files'] = [] for original_idx, processed_idx in original_to_processed_index.items(): files = message_files_map.get(original_idx, []) if files and processed_idx < len(messages_list): messages_list[processed_idx]['files'] = files logger.info(f"消息索引 {processed_idx} 关联了 {len(files)} 个文件: {[f.get('file_name') for f in files]}") # checkpoint 无消息时:优先相信 DB —— 常为 checkpoint 缺失/过期而 chat_messages 仍有双写记录 if not messages_list: db_count = thread_info['message_count'] or 0 use_v2 = False if db_count > 0: use_v2 = True logger.info( f"V1 checkpoint 无可用消息但 chat_threads.message_count={db_count},回退 V2(chat_messages): thread_id={thread_id}" ) else: async with pool.acquire() as conn: v2cnt = await conn.fetchval( "SELECT COUNT(*)::int FROM chat_messages WHERE thread_id = $1", thread_id, ) if v2cnt and v2cnt > 0: use_v2 = True logger.info( f"V1 checkpoint 无消息(message_count=0)但 chat_messages 有 {v2cnt} 条,回退 V2: thread_id={thread_id}" ) if use_v2: return await get_chat_thread_detail_v2(thread_id, user_id) return ChatThreadDetailResponse( thread_id=thread_id, title=thread_info['title'], knowledge_base_id=thread_info['knowledge_base_id'], knowledge_graph_id=thread_info['knowledge_graph_id'], llm_provider=thread_info['llm_provider'], llm_model=thread_info['llm_model'], messages=messages_list ) except Exception as e: logger.error(f"查询会话明细失败: {e}") raise InternalError(f"查询会话明细失败: {str(e)}") async def check_thread_has_files(thread_id: str) -> bool: """ 检查会话是否有已完成的文件 Args: thread_id: 会话线程 ID Returns: bool: 是否有已完成的文件 """ try: pool = await get_db_pool() async with pool.acquire() as conn: count = await conn.fetchval( """ SELECT COUNT(*) FROM chat_thread_file WHERE thread_id = $1 AND is_deleted = FALSE AND status = 'completed' """, thread_id ) return count > 0 except Exception as e: logger.error(f"检查会话文件失败: {e}") return False async def check_knowledge_base_has_files(knowledge_base_id: int, user_id: int) -> bool: """ 检查知识库是否有已完成的文件 Args: knowledge_base_id: 知识库 ID user_id: 用户 ID Returns: bool: 是否有已完成的文件 """ try: pool = await get_db_pool() async with pool.acquire() as conn: count = await conn.fetchval( """ SELECT COUNT(*) FROM knowledge_base_file WHERE knowledge_base_id = $1 AND is_deleted = FALSE AND status = 'completed' """, knowledge_base_id, ) return count > 0 except Exception as e: logger.error(f"检查知识库文件失败: {e}") return False async def check_knowledge_graph_has_rag(knowledge_graph_id: int, user: User) -> bool: """检查知识图谱是否存在且当前用户可见、已构建完成且已向量化。""" try: pool = await get_db_pool() async with pool.acquire() as conn: raw = await KnowledgeGraphService.fetch_graph_by_id(conn, knowledge_graph_id) if not raw: return False gr = GraphRecord( id=int(raw["id"]), user_id=int(raw["user_id"]), enterprise_id=raw.get("enterprise_id"), department_id=raw.get("department_id"), creator_id=raw.get("creator_id"), visibility=raw.get("visibility") or "private", ) if not await can_view_graph(conn, user, gr): return False return ( raw.get("build_status") == "completed" and (raw.get("rag_chunk_count") or 0) > 0 ) except Exception as e: logger.error(f"检查知识图谱 RAG 失败: {e}") return False async def get_knowledge_graph_tool_flags(user: User, graph_id: int) -> Dict[str, Any]: """ 一次查询当前知识图谱可对聊天挂载哪些能力: - has_rag: 正文已向量化,可用资料片段检索; - neo4j_graph_id: 构建完成且存在 Neo4j 子图 ID 时,可用实体关系查询。 """ out: Dict[str, Any] = {"has_rag": False, "neo4j_graph_id": None} try: pool = await get_db_pool() async with pool.acquire() as conn: raw = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_id) if not raw: return out gr = GraphRecord( id=int(raw["id"]), user_id=int(raw["user_id"]), enterprise_id=raw.get("enterprise_id"), department_id=raw.get("department_id"), creator_id=raw.get("creator_id"), visibility=raw.get("visibility") or "private", ) if not await can_view_graph(conn, user, gr): return out if raw.get("build_status") != "completed": return out neo = raw.get("neo4j_graph_id") out["neo4j_graph_id"] = neo if neo else None out["has_rag"] = (raw.get("rag_chunk_count") or 0) > 0 return out except Exception as e: logger.error(f"查询知识图谱工具标志失败: {e}") return out # ==================================== # V2 版本:基于 chat_messages 表查询 # ==================================== async def get_chat_thread_detail_v2(thread_id: str, user_id: int) -> ChatThreadDetailResponse: """ 获取会话的聊天明细(V2版本:基于 chat_messages 表) **优势**: - 查询速度更快(直接SQL查询,无需解析JSONB) - 用户原始问题和注入内容分离 - 支持全文搜索、统计分析 Args: thread_id: 会话线程 ID user_id: 用户 ID(用于权限验证) Returns: ChatThreadDetailResponse: 会话明细响应 Raises: HTTPException: 会话不存在或无权限访问 """ from services.chat_message_service import ChatMessageService # 验证会话是否存在且属于该用户 pool = await get_db_pool() async with pool.acquire() as conn: kg_sel = chat_thread_kg_select_fragment_sql() llm_sel = chat_thread_llm_select_fragment_sql() thread_info = await conn.fetchrow( f""" SELECT id, thread_id, user_id, title, message_count, knowledge_base_id, {kg_sel}, is_deleted, {llm_sel} FROM chat_threads WHERE thread_id = $1 """, thread_id ) if not thread_info: raise NotFoundError("会话") if thread_info['user_id'] != user_id: raise ForbiddenError("无权限访问该会话") if thread_info['is_deleted']: raise NotFoundError("会话已被删除") # 从 chat_messages 表查询消息列表 try: messages = await ChatMessageService.get_messages_by_thread(conn, thread_id) # 获取文件关联信息(复用原有逻辑) if messages: # 获取最新的 checkpoint_id latest_checkpoint_id = messages[-1]['checkpoint_id'] if messages else None if latest_checkpoint_id: message_files_map = await ChatMessageFileService.get_all_files_by_thread( conn, thread_id, latest_checkpoint_id ) else: message_files_map = {} else: message_files_map = {} # 组装消息列表(转换为 LangChain 格式) # 类型映射:数据库存储 → 前端显示 role_type_mapping = { 'user': 'human', 'assistant': 'ai', 'tool': 'tool' } messages_list = [] for msg in messages: # 提取 metadata metadata = msg.get('metadata', {}) # 映射类型(保持向后兼容) db_role = msg['role'] display_type = role_type_mapping.get(db_role, db_role) # 构建消息数据结构 msg_dict = { 'type': display_type, 'data': { 'content': msg['content'], 'type': display_type, 'additional_kwargs': {}, 'response_metadata': {}, 'id': msg['checkpoint_id'] }, 'files': message_files_map.get(msg['message_index'], []) } # 添加 name 字段(用于工具消息) if msg['role'] == 'tool' and msg.get('name'): msg_dict['data']['name'] = msg['name'] # 添加额外信息到 data 中 if msg['role'] == 'assistant' and metadata: # AI 消息:添加 token 使用量、模型名称等 if 'token_usage' in metadata: msg_dict['data']['response_metadata']['token_usage'] = metadata['token_usage'] if 'model' in metadata: msg_dict['data']['response_metadata']['model'] = metadata['model'] if 'finish_reason' in metadata: msg_dict['data']['response_metadata']['finish_reason'] = metadata['finish_reason'] if 'reasoning_content' in metadata: msg_dict['data']['additional_kwargs']['reasoning_content'] = metadata['reasoning_content'] messages_list.append(msg_dict) logger.info(f"✅ V2查询会话明细: thread_id={thread_id}, 消息数量={len(messages_list)}") return ChatThreadDetailResponse( thread_id=thread_id, title=thread_info['title'], knowledge_base_id=thread_info['knowledge_base_id'], knowledge_graph_id=thread_info['knowledge_graph_id'], llm_provider=thread_info['llm_provider'], llm_model=thread_info['llm_model'], messages=messages_list ) except Exception as e: logger.error(f"V2查询会话明细失败: {e}") raise InternalError(f"查询会话明细失败: {str(e)}")