""" 聊天 API 路由模块 定义聊天相关的 API 路由,包括对话、会话管理等接口。 """ from __future__ import annotations import json import os from datetime import datetime, timezone from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import JSONResponse from langchain_core.messages import message_to_dict from langchain.agents import create_agent from sse_starlette import EventSourceResponse, ServerSentEvent from core.config import settings from core.llm_catalog import ( build_chat_model_for_completion, coerce_model_id, deepseek_api_model_by_reasoner_setting, list_llm_options_payload, normalize_provider, resolve_to_api_model, validate_request_can_use_provider, ) from core.database import get_db_pool, get_checkpointer from core.mcp_client import get_mcp_client from core.dependencies import get_current_user, get_moderation_service from core.exceptions import ModerationError from models.user import User from models.moderation import ModerationDecision from models.chat import ( ChatRequest, DeleteThreadRequest, ChatThreadListResponse, ChatThreadDetailResponse, ) from prompt.prompt import ( get_translate_instructions, get_text2video_instructions, get_text2img_instructions, get_text2poster_instructions, get_research_instructions ) from services.enterprise_service import EnterpriseService from tools.tools import ( get_current_time, internet_search, text_to_image, text_to_video, text_to_poster, create_rag_retrieve_tool, create_kb_rag_retrieve_tool, create_knowledge_graph_neo4j_search_tool, create_knowledge_graph_rag_retrieve_tool, ) from services.chat_thread_service import ( create_or_update_chat_thread, delete_chat_thread, get_user_chat_threads, get_chat_thread_detail, get_chat_thread_detail_v2, # V2版本:基于 chat_messages 表 check_thread_has_files, check_knowledge_base_has_files, get_knowledge_graph_tool_flags, ) from services.chat_message_file_service import ChatMessageFileService from services.chat_message_service import ChatMessageService # 
# ChatMessageService (imported above) is the newly added message-persistence service.
from utils.helpers import BaseResponse
from logger.logging import get_logger

logger = get_logger(__name__)

# Router instance shared by every endpoint in this module.
chat_router = APIRouter(prefix="/api", tags=["聊天接口"])

# Upper bound (characters) on the file text injected into the system prompt per file.
_MAX_INJECTED_CONTENT_LENGTH = 8000
# File-name suffixes that get the image icon in an injected file section.
_IMAGE_SUFFIXES = ('png', 'jpg', 'jpeg', 'bmp')
# Display names for translation-mode language codes.
_TRANSLATE_LANGUAGE_NAMES = {
    'auto': '自动检测',
    'zh': '中文(简体)',
    'zh-TW': '中文(繁體)',
    'en': '英文',
    'ja': '日文',
    'ko': '韩文',
    'fr': '法文',
    'de': '德文',
    'es': '西班牙文',
    'ru': '俄文',
}
# Thread files created in the last 5 minutes that are not yet linked to any message.
# NOTE(review): if chat_message_file.file_id can be NULL, `NOT IN` matches no rows;
# `NOT EXISTS` would be safer — confirm against the schema.
_UNLINKED_FILES_SQL = """
    SELECT id FROM chat_thread_file
    WHERE thread_id = $1
      AND is_deleted = FALSE
      AND created_at > NOW() - INTERVAL '5 minutes'
      AND id NOT IN (
          SELECT DISTINCT file_id FROM chat_message_file WHERE thread_id = $1
      )
    ORDER BY created_at DESC
    LIMIT 10
"""


def _row_flag(row, column: str) -> bool:
    """Return a nullable boolean column of a DB row as ``bool``; a missing row or NULL gives False."""
    return bool(row[column]) if row and row[column] is not None else False


def _error_response(msg: str) -> JSONResponse:
    """Build the unified 400 error body ``{"code": 400, "msg": msg, "data": None}``."""
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"code": 400, "msg": msg, "data": None},
    )


def _get_client_ip(http_request: Request) -> Optional[str]:
    """Best-effort client IP: first X-Forwarded-For hop, else X-Real-IP, else the socket peer."""
    client_ip = http_request.client.host if http_request.client else None
    # NOTE(review): both headers are client-controlled unless a trusted reverse proxy
    # overwrites them; confirm the deployment does so before relying on this value.
    if "x-forwarded-for" in http_request.headers:
        client_ip = http_request.headers["x-forwarded-for"].split(",")[0].strip()
    elif "x-real-ip" in http_request.headers:
        client_ip = http_request.headers["x-real-ip"]
    return client_ip


# GET /api/chat/llm-options — providers and logical model ids for the frontend dropdown
# (no secrets). `current_user` is unused in the body but enforces authentication.
@chat_router.get("/chat/llm-options", summary="聊天可选大模型列表")
async def chat_llm_options(current_user: User = Depends(get_current_user)):
    """返回前端下拉所需的提供方与模型逻辑 id(不含密钥)。"""
    return list_llm_options_payload()


async def _moderate_user_query(
    request: ChatRequest,
    current_user: User,
    moderation_service,
) -> Optional[JSONResponse]:
    """Moderate the user's query before it reaches the LLM.

    Args:
        request: The chat request; ``request.query`` is the moderated text.
        current_user: Authenticated user (used for logging only).
        moderation_service: Service exposing ``moderate_text(text=..., request_id=...)``.

    Returns:
        A 400 ``JSONResponse`` when the query is rejected (decision BLOCK, or REVIEW
        while ``moderation_review_action == "block"``); otherwise ``None``.

    Moderation-service errors are logged and swallowed on purpose (degraded /
    fail-open mode), so chat keeps working when moderation is unavailable.
    """
    try:
        import uuid

        # Unique id so this moderation call can be traced across log lines.
        moderation_request_id = str(uuid.uuid4())
        logger.info(
            f"开始内容审核 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"message_length: {len(request.query)}"
        )

        moderation_result = await moderation_service.moderate_text(
            text=request.query,
            request_id=moderation_request_id,
        )
        labels = [label.label for label in moderation_result.labels]
        logger.info(
            f"审核完成 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"decision: {moderation_result.decision.value}, "
            f"labels: {labels}"
        )

        # BLOCK: reject outright.
        if moderation_result.decision == ModerationDecision.BLOCK:
            logger.warning(
                f"内容被阻止 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {labels}"
            )
            return _error_response("您的消息包含不当内容,无法处理。")

        # REVIEW: the configured policy decides between rejecting and allowing.
        if moderation_result.decision == ModerationDecision.REVIEW:
            logger.info(
                f"内容需要复审 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {labels}"
            )
            # Imported inside the try so a configuration failure falls into the
            # fail-open handler below.
            from core.config import get_settings

            if get_settings().moderation_review_action == "block":
                logger.warning(
                    f"复审内容被阻止(配置策略)- user_id: {current_user.id}, "
                    f"request_id: {moderation_request_id}"
                )
                return _error_response("您的消息需要人工复审,暂时无法处理。")
            logger.info(
                f"复审内容允许通过(配置策略)- user_id: {current_user.id}, "
                f"request_id: {moderation_request_id}"
            )

        # PASS (or REVIEW allowed by policy): continue normally.
        logger.info(
            f"内容审核通过 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}"
        )
    except HTTPException:
        # Never downgrade an explicit HTTP error into fail-open.
        raise
    except ModerationError as e:
        # Moderation service failure: degraded mode, log and let the message through.
        logger.error(
            f"审核服务错误(降级模式)- user_id: {current_user.id}, "
            f"error: {e.message}, "
            f"original_error: {str(e.original_error) if e.original_error else None}"
        )
    except Exception as e:
        # Unexpected failure: degraded mode, log and let the message through.
        logger.error(
            f"审核过程未知错误(降级模式)- user_id: {current_user.id}, "
            f"error_type: {type(e).__name__}, "
            f"error: {str(e)}"
        )
    return None


# POST /api/chat/completion — main chat entry point (authentication required).
# Pipeline: moderation -> provider/model resolution -> thread bookkeeping ->
# Server-Sent-Events stream of agent message chunks terminated by "[DONE]".
# Early rejections are returned as unified 400 JSON bodies instead of a stream.
@chat_router.post("/chat/completion", summary="聊天接口主入口")
async def chat_completion(
    request: ChatRequest,
    http_request: Request,
    current_user: User = Depends(get_current_user),
    moderation_service=Depends(get_moderation_service),
):
    """
    聊天接口主入口(需要认证)

    Args:
        request: 聊天请求数据(JSON 格式)
        http_request: HTTP 请求对象(用于获取客户端 IP)
        current_user: 当前登录用户
        moderation_service: 内容审核服务(依赖注入)

    Returns:
        EventSourceResponse: 服务器发送事件流式响应
    """
    client_ip = _get_client_ip(http_request)

    logger.info(
        f"用户 {current_user.username} (ID: {current_user.id}) 发起聊天请求,thread_id: {request.thread_id}, "
        f"text2img: {request.text2img}, text2video: {request.text2video}, text2poster: {request.text2poster}, "
        f"translate: {request.translate}, knowledge_base_id: {request.knowledge_base_id}, "
        f"knowledge_graph_id: {request.knowledge_graph_id}, "
        f"llm_provider: {request.llm_provider}, llm_model: {request.llm_model}, ip={client_ip}"
    )

    # ---- Content moderation, before any LLM work ----
    rejection = await _moderate_user_query(request, current_user, moderation_service)
    if rejection is not None:
        return rejection

    # ---- Resolve provider and model ----
    llm_provider = normalize_provider(request.llm_provider)
    llm_model_key = coerce_model_id(llm_provider, request.llm_model)
    cfg_err = validate_request_can_use_provider(llm_provider)
    if cfg_err:
        return _error_response(cfg_err)
    try:
        api_model = resolve_to_api_model(llm_provider, llm_model_key)
    except ValueError as e:
        return _error_response(str(e))

    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            reasoner_row = await conn.fetchrow(
                "SELECT is_reasoner FROM user_list WHERE id = $1",
                current_user.id,
            )
        user_is_reasoner = _row_flag(reasoner_row, "is_reasoner")
        if llm_provider == "deepseek":
            # DeepSeek: the user's is_reasoner setting picks the concrete API model.
            api_model = deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        model = build_chat_model_for_completion(
            llm_provider,
            api_model,
            enable_thinking=user_is_reasoner,
            logical_llm_id=llm_model_key,
        )
        # NOTE(review): "{}" placeholders with positional args assume a loguru-style
        # logger behind get_logger; stdlib logging would expect "%s". Confirm.
        logger.debug(
            "chat_completion 模型: provider={} req_llm_model={} api_model={} user_is_reasoner={}",
            llm_provider, llm_model_key, api_model, user_is_reasoner,
        )
    except ValueError as e:
        return _error_response(str(e))

    # DeepSeek: both persistence and the call use the API model chosen by is_reasoner;
    # the requested llm_model is ignored.
    thread_llm_model = (
        deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        if llm_provider == "deepseek"
        else llm_model_key
    )

    # Record the thread; knowledge_base_id is passed as sent, so an absent id is stored as null.
    await create_or_update_chat_thread(
        thread_id=request.thread_id,
        user_id=current_user.id,
        query=request.query,
        knowledge_base_id=request.knowledge_base_id,
        knowledge_graph_id=request.knowledge_graph_id,
        ip=client_ip,
        llm_provider=llm_provider,
        llm_model=thread_llm_model,
    )

    config = {"configurable": {"thread_id": request.thread_id}, "recursion_limit": 30}
    checkpointer = await get_checkpointer()

    # Translate / text-to-X modes need no file or knowledge-base context, so their agent
    # is built up front; plain chat builds its agent inside generate() once that context
    # is known.
    use_early_agent = bool(
        request.translate or request.text2video or request.text2img or request.text2poster
    )
    agent_early = None
    if use_early_agent:
        agent_early = await _create_agent_for_request(
            request=request,
            current_user=current_user,
            model=model,
            checkpointer=checkpointer,
            config=config,
            llm_provider=llm_provider,
            api_model=api_model,
            user_is_reasoner=user_is_reasoner,
        )

    async def generate():
        """Yield JSON-serialised agent message chunks, then the "[DONE]" sentinel.

        Any error is logged and yielded as ``{"error": ...}`` before "[DONE]".
        """
        try:
            pool = await get_db_pool()
            system_context_sections: list[str] = []
            file_list_for_intent: list = []
            async with pool.acquire() as conn:
                rows = await conn.fetch(_UNLINKED_FILES_SQL, request.thread_id)
                unlinked_files = [row['id'] for row in rows]

                if not use_early_agent:
                    from services.chat_thread_file_service import ChatThreadFileService

                    file_summaries_raw = await ChatThreadFileService.get_recent_files_with_summary(
                        conn, request.thread_id, limit=5
                    )
                    # Map each summarised file back to its chat_thread_file id by name.
                    for file_info in file_summaries_raw:
                        row = await conn.fetchrow(
                            "SELECT id FROM chat_thread_file WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE",
                            request.thread_id,
                            file_info['file_name'],
                        )
                        if row:
                            file_list_for_intent.append({
                                "file_id": row['id'],
                                "file_name": file_info['file_name'],
                                "summary": file_info['summary'],
                            })

            if not use_early_agent and file_list_for_intent:
                section = await _build_thread_file_context(pool, request.query, file_list_for_intent)
                if section:
                    system_context_sections.append(section)

            # Knowledge-base context applies only when no knowledge graph is bound.
            if not use_early_agent and request.knowledge_base_id and not request.knowledge_graph_id:
                section = await _build_knowledge_base_context(
                    pool, request.query, request.knowledge_base_id
                )
                if section:
                    system_context_sections.append(section)

            extra_system_context = "\n\n".join(
                s.strip() for s in system_context_sections if s and str(s).strip()
            )
            extra_context_arg = extra_system_context.strip() or None

            if use_early_agent:
                active_agent = agent_early
            else:
                active_agent = await _create_agent_for_request(
                    request=request,
                    current_user=current_user,
                    model=model,
                    checkpointer=checkpointer,
                    config=config,
                    llm_provider=llm_provider,
                    api_model=api_model,
                    user_is_reasoner=user_is_reasoner,
                    extra_system_context=extra_context_arg,
                )

            try:
                async for event_data in active_agent.astream(
                    {"messages": [{"role": "user", "content": request.query}]},
                    stream_mode="messages",
                    config=config,
                ):
                    message, metadata = event_data
                    try:
                        message_dict = message_to_dict(message)
                    except Exception as e:
                        # Fallback: hand-build a JSON-friendly dict from the common attributes.
                        logger.warning(f"使用 message_to_dict 序列化失败: {e}")
                        message_dict = {
                            "content": message.content if hasattr(message, 'content') else str(message),
                            "type": message.__class__.__name__ if hasattr(message, '__class__') else "unknown",
                        }
                        if hasattr(message, 'additional_kwargs'):
                            message_dict['additional_kwargs'] = message.additional_kwargs or {}
                        if hasattr(message, 'response_metadata'):
                            message_dict['response_metadata'] = message.response_metadata or {}
                        if hasattr(message, 'id'):
                            message_dict['id'] = message.id

                    serializable_data = {"message": message_dict, "metadata": metadata}
                    yield json.dumps(serializable_data, ensure_ascii=False, default=str)
            except IndexError as e:
                # Workaround for chunk-alignment failures in some LLM + tool streaming
                # combinations (kept from the ChatTongyi era): surface a friendly message
                # shaped like an AIMessageChunk instead of failing the stream.
                if "list index out of range" in str(e):
                    error_msg = "抱歉,处理您的请求时遇到了技术问题(流式工具调用错误)。请稍后重试,或尝试简化您的问题。"
                    logger.error(f"LangChain 流式工具调用 bug: {e}")
                    yield json.dumps({
                        "error": error_msg,
                        "message": {
                            "type": "AIMessageChunk",
                            "data": {
                                "content": error_msg,
                                "additional_kwargs": {},
                                "response_metadata": {},
                            },
                        },
                    }, ensure_ascii=False)
                else:
                    raise

            # Link freshly uploaded files to this turn's user message.
            if unlinked_files:
                await _associate_files_to_message(
                    request.thread_id, unlinked_files, checkpointer, config, pool
                )

            # [V2] Dual-write the turn into chat_messages; failures must not break the stream.
            try:
                await _save_messages_to_chat_messages_table(
                    thread_id=request.thread_id,
                    user_query=request.query,
                    user_message_content=request.query,
                    has_files=len(unlinked_files) > 0,
                    checkpointer=checkpointer,
                    config=config,
                    pool=pool,
                    system_attached_context=extra_context_arg,
                )
            except Exception as save_err:
                logger.error(f"[V2] 保存消息到 chat_messages 表失败: {save_err}")

        except Exception as e:
            logger.exception(f"聊天接口错误: {e}")
            yield json.dumps({"error": str(e)}, ensure_ascii=False)

        yield "[DONE]"

    return EventSourceResponse(
        generate(),
        ping_message_factory=lambda: ServerSentEvent(data=f"ping - {datetime.now(timezone.utc)}"),
    )


def _render_file_section(idx: int, file_name: str, all_chunks: list) -> tuple[str, int, int]:
    """Render one file's "full content" section for system-prompt injection.

    Chunk contents are joined with blank lines, and the first chunk's summary is used
    as the file summary. NULL/missing values count as empty strings. Content longer than
    ``_MAX_INJECTED_CONTENT_LENGTH`` characters is truncated with a marker.

    Args:
        idx: 1-based position of the file in the section list.
        file_name: File name shown in the header; its suffix picks the icon.
        all_chunks: Non-empty sequence of chunk mappings supporting ``.get``.

    Returns:
        ``(section_text, summary_length, content_length)``; the content length is
        measured after truncation.
    """
    full_content = "\n\n".join(chunk.get('content') or '' for chunk in all_chunks)
    file_summary = all_chunks[0].get('summary') or ''
    file_type_icon = "🖼️" if file_name.lower().endswith(_IMAGE_SUFFIXES) else "📄"

    section = f"{file_type_icon} 文件 {idx}: {file_name}\n"
    section += f"{'─'*60}\n"
    if file_summary:
        section += f"【文件摘要】\n{file_summary}\n\n"
    if len(full_content) > _MAX_INJECTED_CONTENT_LENGTH:
        full_content = full_content[:_MAX_INJECTED_CONTENT_LENGTH] + "\n...(内容过长,已截断)"
    section += f"【完整内容】\n{full_content}\n"
    section += f"{'─'*60}\n\n"
    return section, len(file_summary), len(full_content)


async def _build_thread_file_context(pool, query: str, file_list_for_intent: list) -> Optional[str]:
    """Build the system-context section for files recently uploaded to the thread.

    The RAG intent service decides which files ``query`` concerns and how:
    - if any file is classed "summary", those files' (truncated) full text is injected;
    - else if any is classed "search", a hint tells the model to use the retrieval tool;
    - else (intents with unrecognised types) the stored summaries of all candidate
      files are injected as a fallback.

    Args:
        pool: DB pool used to read file chunks.
        query: The user's question.
        file_list_for_intent: Dicts with "file_id", "file_name" and "summary".

    Returns:
        The section text, or None when no file is judged relevant.
    """
    from services.chat_thread_file_service import ChatThreadFileService
    from services.rag_intent_service import get_rag_intent_service

    intent_service = await get_rag_intent_service()
    intents = await intent_service.judge_intent(query=query, file_list=file_list_for_intent)
    logger.info(f"💡 意图判断结果: {len(intents)} 个文件相关")
    if not intents:
        logger.info("ℹ️ 未识别到相关文件,不注入内容")
        return None

    summary_files = [i for i in intents if i.question_type == "summary"]
    search_files = [i for i in intents if i.question_type == "search"]

    if summary_files:
        logger.info(f"📄 检测到 {len(summary_files)} 个文件需要完整内容")
        summary_prefix = "\n\n" + "=" * 60 + "\n"
        summary_prefix += "📎 **已为您准备的文件完整内容**\n"
        summary_prefix += "=" * 60 + "\n"
        summary_prefix += "⚠️ 重要:以下是文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n"

        async with pool.acquire() as conn:
            for idx, intent in enumerate(summary_files, 1):
                logger.info(f"🔍 正在从数据库获取文件完整内容: file_id={intent.file_id}, file_name={intent.file_name}")
                all_chunks = await ChatThreadFileService.get_file_chunks_from_db(conn, intent.file_id)
                logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                if not all_chunks:
                    logger.warning(f"⚠️ 文件 {intent.file_name} (ID: {intent.file_id}) 未返回任何chunks!")
                    continue

                # Diagnostic preview of the first three chunks.
                logger.info("📝 Chunks详情:")
                for i, chunk in enumerate(all_chunks[:3]):
                    logger.info(f" - Chunk {i}: index={chunk.get('chunk_index')}, 内容长度={len(chunk.get('content') or '')}, 有摘要={bool(chunk.get('summary'))}")
                if len(all_chunks) > 3:
                    logger.info(f" - ... 还有 {len(all_chunks) - 3} 个chunks")

                section, summary_len, content_len = _render_file_section(idx, intent.file_name, all_chunks)
                summary_prefix += section
                logger.info(f"✅ 已注入文件 {intent.file_name} (摘要{summary_len}字 + 内容{content_len}字)")

        summary_prefix += "=" * 60 + "\n"
        return summary_prefix

    if search_files:
        logger.info(f"🔍 检测到 {len(search_files)} 个文件需要向量检索")
        search_hint = "\n\n" + "=" * 60 + "\n"
        search_hint += "📎 **重要提示**\n"
        search_hint += "=" * 60 + "\n"
        search_hint += "⚠️ 用户刚刚上传了以下文件,你的回答必须基于这些文件内容:\n\n"
        for idx, intent in enumerate(search_files, 1):
            search_hint += f"{idx}. 📄 {intent.file_name}\n"
        search_hint += "\n**请务必使用检索工具查询文件内容,不要使用你的训练数据!**\n"
        search_hint += "=" * 60 + "\n"
        return search_hint

    logger.warning("⚠️ 意图类型未识别,降级为注入摘要")
    summary_prefix = "\n\n📎 **文件摘要**:\n"
    for file_info in file_list_for_intent:
        summary_prefix += f"【{file_info['file_name']}】\n{file_info['summary']}\n\n"
    return summary_prefix


async def _build_knowledge_base_context(pool, query: str, knowledge_base_id) -> Optional[str]:
    """Build the system-context section for the bound knowledge base's recent files.

    Same policy as the thread-file context: "summary" intents inject the files'
    (truncated) full text, otherwise "search" intents inject a retrieval hint.
    Unlike the thread variant there is no fallback for unrecognised intent types.

    Args:
        pool: DB pool used to read knowledge-base files and chunks.
        query: The user's question.
        knowledge_base_id: Id of the knowledge base bound to the request.

    Returns:
        The section text, or None when nothing should be injected.
    """
    from services.knowledge_base_file_service import KnowledgeBaseFileService
    from services.rag_intent_service import get_rag_intent_service

    async with pool.acquire() as conn:
        kb_files_raw = await KnowledgeBaseFileService.get_recent_files_with_summary(
            conn, knowledge_base_id, limit=5
        )
    if not kb_files_raw:
        return None

    logger.info(f"📚 知识库 {knowledge_base_id} 检测到 {len(kb_files_raw)} 个最近文件")
    intent_service = await get_rag_intent_service()
    kb_intents = await intent_service.judge_intent(query=query, file_list=kb_files_raw)
    logger.info(f"💡 知识库意图判断结果: {len(kb_intents)} 个文件相关")
    if not kb_intents:
        return None

    summary_files = [i for i in kb_intents if i.question_type == "summary"]
    search_files = [i for i in kb_intents if i.question_type == "search"]

    if summary_files:
        logger.info(f"📄 知识库检测到 {len(summary_files)} 个文件需要完整内容")
        kb_prefix = "\n\n" + "=" * 60 + "\n"
        kb_prefix += "📚 **知识库文件完整内容**\n"
        kb_prefix += "=" * 60 + "\n"
        kb_prefix += "⚠️ 重要:以下是知识库文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n"

        async with pool.acquire() as conn:
            for idx, intent in enumerate(summary_files, 1):
                logger.info(f"🔍 正在从数据库获取知识库文件: file_id={intent.file_id}, file_name={intent.file_name}")
                all_chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(conn, intent.file_id)
                logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                if all_chunks:
                    section, summary_len, content_len = _render_file_section(idx, intent.file_name, all_chunks)
                    kb_prefix += section
                    logger.info(f"✅ 已注入知识库文件 {intent.file_name} (摘要{summary_len}字 + 内容{content_len}字)")

        kb_prefix += "=" * 60 + "\n"
        return kb_prefix

    if search_files:
        logger.info(f"🔍 知识库检测到 {len(search_files)} 个文件需要向量检索")
        kb_hint = "\n\n" + "=" * 60 + "\n"
        kb_hint += "📚 **知识库检索提示**\n"
        kb_hint += "=" * 60 + "\n"
        kb_hint += "⚠️ 用户的知识库包含以下文件,你的回答必须基于这些文件内容:\n\n"
        for idx, intent in enumerate(search_files, 1):
            kb_hint += f"{idx}. 📄 {intent.file_name}\n"
        kb_hint += "\n**请务必使用知识库检索工具查询文件内容,不要使用你的训练数据!**\n"
        kb_hint += "=" * 60 + "\n"
        return kb_hint

    return None


def _wrap_mcp_tool_safe(tool):
    """Make an MCP tool return failures as an error string instead of raising.

    Third-party failures (network unreachable, timeouts, ...) are caught so the agent
    can still write the matching ToolMessage. That keeps the checkpointed history
    complete; otherwise the next turn can hit an LLM 400 BadRequest because a
    tool_call has no paired reply.

    The tool's ``_arun`` is replaced in place and the same object is returned.
    NOTE: the call site in ``_create_agent_for_request`` is currently commented out.
    """
    original_arun = tool._arun

    async def safe_arun(*args, **kwargs):
        try:
            return await original_arun(*args, **kwargs)
        except Exception as exc:
            error_msg = f"工具 [{tool.name}] 调用失败: {exc}"
            logger.warning(f"MCP 工具异常已捕获,返回错误字符串以维持对话完整性: {error_msg}")
            return error_msg

    tool._arun = safe_arun
    return tool


async def _create_agent_for_request(
    request: ChatRequest,
    current_user: User,
    model,
    checkpointer,
    config: dict,
    llm_provider: str,
    api_model: str,
    user_is_reasoner: bool,
    extra_system_context: Optional[str] = None,
):
    """Build the LangChain agent matching the request's mode.

    Translate, text-to-video, text-to-image and poster modes (checked in that order)
    get a fixed tool set and their dedicated prompt. Plain chat gets:
    - the MCP tools;
    - retrieval tools for thread files / knowledge base / knowledge graph;
    - internet search if the user enabled it;
    - the research prompt, with ``extra_system_context`` appended.

    Args:
        request: The chat request (mode flags, thread / knowledge ids, languages).
        current_user: Authenticated user; its id and enterprise id are used for lookups.
        model: Chat model already configured by the caller.
        checkpointer: LangGraph checkpointer shared with the stream.
        config: Run config; in reasoner mode its "recursion_limit" is raised to 60 in place.
        llm_provider: Normalised provider name (logging only).
        api_model: Concrete API model name (logging only).
        user_is_reasoner: The user's deep-thinking setting, read by the caller.
        extra_system_context: Optional per-turn reference material for the system prompt.

    Returns:
        The agent returned by ``create_agent``.
    """
    ai_display_name = await EnterpriseService.resolve_ai_display_name(current_user.enterprise_id)

    if request.translate:
        logger.info("使用翻译模式")
        from_lang = _TRANSLATE_LANGUAGE_NAMES.get(request.from_language, '自动检测')
        target_lang = _TRANSLATE_LANGUAGE_NAMES.get(request.target_language, '英文')
        return create_agent(
            model=model,
            tools=[get_current_time],
            system_prompt=get_translate_instructions(from_lang, target_lang, ai_display_name),
            checkpointer=checkpointer,
        )

    if request.text2video:
        logger.info("使用文生视频模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_video],
            system_prompt=get_text2video_instructions(ai_display_name),
            checkpointer=checkpointer,
        )

    if request.text2img:
        logger.info("使用文生图模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_image],
            system_prompt=get_text2img_instructions(ai_display_name),
            checkpointer=checkpointer,
        )

    if request.text2poster:
        logger.info("使用创意海报生成模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_poster],
            system_prompt=get_text2poster_instructions(ai_display_name),
            checkpointer=checkpointer,
        )

    # ---- Plain chat mode ----
    mcp_client = await get_mcp_client()
    mcp_tools = await mcp_client.get_tools()
    # Intentionally disabled toggle: wrap MCP tools so their failures become ToolMessages.
    # mcp_tools = [_wrap_mcp_tool_safe(t) for t in mcp_tools]
    logger.info(f"成功加载 {len(mcp_tools)} 个 MCP 工具")

    # Internet-search preference. Reasoner mode comes from user_list.is_reasoner,
    # which the caller already read and passed in as user_is_reasoner.
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        user_row = await conn.fetchrow(
            "SELECT is_search FROM user_list WHERE id = $1",
            current_user.id,
        )
    user_is_search = _row_flag(user_row, "is_search")
    use_reasoner_mode = user_is_reasoner

    # Which retrieval sources are available.
    has_files = await check_thread_has_files(request.thread_id)
    has_kb_files = False
    kb_id = request.knowledge_base_id
    if kb_id and not request.knowledge_graph_id:
        has_kb_files = await check_knowledge_base_has_files(kb_id, current_user.id)

    kg_flags = {"has_rag": False, "neo4j_graph_id": None}
    if request.knowledge_graph_id:
        kg_flags = await get_knowledge_graph_tool_flags(current_user, request.knowledge_graph_id)
    has_novel_rag = bool(kg_flags.get("has_rag"))
    has_kg_neo4j = bool(kg_flags.get("neo4j_graph_id"))

    # System prompt: a bound graph counts if it has text RAG and/or Neo4j relations.
    has_kg_any = has_novel_rag or has_kg_neo4j
    research_instructions = get_research_instructions(
        has_files=has_files,
        has_kb_files=has_kb_files,
        use_reasoner_mode=use_reasoner_mode,
        has_knowledge_graph=has_kg_any,
        has_knowledge_graph_neo4j=has_kg_neo4j,
        ai_display_name=ai_display_name,
    )

    # Tool list: every optional tool is *prepended*, so later additions come first.
    # NOTE(review): the stated intent is "Neo4j relation search before text RAG" (so
    # "who is who" questions hit the graph first), but the graph-RAG tool is prepended
    # after the Neo4j tool and therefore precedes it. Confirm which order is wanted.
    all_tools = [get_current_time, *mcp_tools]
    if has_files:
        all_tools = [create_rag_retrieve_tool(request.thread_id)] + all_tools
    if request.knowledge_graph_id and has_kg_neo4j:
        all_tools = [create_knowledge_graph_neo4j_search_tool(kg_flags["neo4j_graph_id"])] + all_tools
    if request.knowledge_graph_id and has_novel_rag:
        all_tools = [create_knowledge_graph_rag_retrieve_tool(request.knowledge_graph_id)] + all_tools
    elif kb_id and has_kb_files:
        all_tools = [create_kb_rag_retrieve_tool(kb_id)] + all_tools
    if user_is_search:
        all_tools = [internet_search] + all_tools
        logger.info(f"用户 {current_user.username} 启用了联网搜索功能")

    # Deep-thinking mode: allow more agent steps. The model itself was already configured
    # by the caller according to is_reasoner.
    if use_reasoner_mode:
        config["recursion_limit"] = 60
        logger.info(
            f"用户 {current_user.username} 启用深度思考模式(提供方={llm_provider}, api_model={api_model})"
        )
    logger.debug(f"model: {model}")

    merged_system_prompt = research_instructions.rstrip()
    if extra_system_context and extra_system_context.strip():
        merged_system_prompt = (
            merged_system_prompt
            + "\n\n【本轮系统提供的参考上下文(用户原话仅在用户消息中;请优先依据此处与工具检索结果作答)】\n"
            + extra_system_context.strip()
        )

    return create_agent(
        model=model,
        tools=all_tools,
        system_prompt=merged_system_prompt,
        checkpointer=checkpointer,
    )


async def _wait_for_latest_checkpoint(checkpointer, config: dict, attempts: int = 5, delay: float = 0.2):
    """Poll ``checkpointer.aget_tuple(config)`` until it yields a checkpoint tuple.

    Sleeps ``delay`` seconds before each of at most ``attempts`` tries (presumably to
    let the checkpoint written by the stream land first). Lookup errors are logged at
    debug level and retried.

    Returns:
        The first truthy checkpoint tuple, or None if every attempt was empty or failed.
    """
    import asyncio

    for attempt in range(attempts):
        await asyncio.sleep(delay)
        try:
            latest_checkpoint = await checkpointer.aget_tuple(config)
        except Exception as e:
            logger.debug(f"获取 checkpoint 失败(尝试 {attempt + 1}/{attempts}): {e}")
            continue
        if latest_checkpoint:
            return latest_checkpoint
    return None


def _find_last_human_index(messages) -> Optional[int]:
    """Return the index of the last message whose ``type`` is "human", or None."""
    for idx in range(len(messages) - 1, -1, -1):
        if getattr(messages[idx], 'type', None) == "human":
            return idx
    return None


async def _associate_files_to_message(
    thread_id: str,
    unlinked_files: list,
    checkpointer,
    config: dict,
    pool,
):
    """Link not-yet-associated thread files to the latest user message in the checkpoint.

    Best effort: every failure is logged as a warning and never raised.
    """
    latest_checkpoint = await _wait_for_latest_checkpoint(checkpointer, config)
    if not latest_checkpoint:
        logger.warning("无法获取最新的 checkpoint,文件关联失败")
        return

    try:
        checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]
        checkpoint_data = latest_checkpoint.checkpoint
        if "channel_values" not in checkpoint_data or "messages" not in checkpoint_data["channel_values"]:
            return

        messages = checkpoint_data["channel_values"]["messages"]
        user_message_index = _find_last_human_index(messages)
        if user_message_index is None:
            return

        async with pool.acquire() as conn:
            for file_id in unlinked_files:
                try:
                    await ChatMessageFileService.create_message_file_association(
                        conn, thread_id, checkpoint_id, user_message_index, file_id
                    )
                    logger.info(f"文件 {file_id} 已关联到消息")
                except Exception as e:
                    logger.warning(f"关联文件到消息失败: {e}")
    except Exception as e:
        logger.warning(f"关联文件到消息失败: {e}")


async def _save_messages_to_chat_messages_table(
    thread_id: str,
    user_query: str,
    user_message_content: str,
    has_files: bool,
    checkpointer,
    config: dict,
    pool,
    system_attached_context: Optional[str] = None,
):
    """Persist the latest conversation turn to ``chat_messages`` (dual write next to the checkpoint).

    Reads the newest checkpoint and takes every message from the last human message
    onward: user message -> AI tool calls -> tool results -> final AI reply. Each one
    not already stored for (thread_id, checkpoint_id, message_index) is saved, so the
    V2 detail API can query plain rows instead of parsing checkpoint JSONB.

    Args:
        thread_id: Conversation thread id.
        user_query: The user's original question (stored as the user message content).
        user_message_content: User message as stored in the checkpoint (may equal user_query).
        has_files: Whether files were linked to this user message.
        checkpointer: Checkpoint saver to read from.
        config: Run config identifying the thread.
        pool: DB connection pool.
        system_attached_context: Reference context attached to the system prompt, not part
            of the user's words; stored as ``injected_content`` when present.

    Never raises: all errors are logged.
    """
    try:
        latest_checkpoint = await _wait_for_latest_checkpoint(checkpointer, config)
        if not latest_checkpoint:
            logger.warning("无法获取最新的 checkpoint,消息保存失败")
            return

        checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]
        checkpoint_data = latest_checkpoint.checkpoint
        if "channel_values" not in checkpoint_data or "messages" not in checkpoint_data["channel_values"]:
            return

        messages = checkpoint_data["channel_values"]["messages"]
        if len(messages) < 2:
            logger.warning(f"[V2] 消息数量不足,跳过保存: messages_count={len(messages)}")
            return

        last_user_index = _find_last_human_index(messages)
        if last_user_index is None:
            logger.warning("[V2] 未找到用户消息,跳过保存")
            return

        messages_to_save = messages[last_user_index:]

        async with pool.acquire() as conn:
            # Whole-turn short-circuit: if the final message is already stored, skip.
            last_ai_index = len(messages) - 1
            existing = await conn.fetchval(
                """
                SELECT id FROM chat_messages
                WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3
                """,
                thread_id, checkpoint_id, last_ai_index,
            )
            if existing:
                logger.debug(f"[V2] 消息已存在,跳过保存: thread_id={thread_id}, checkpoint_id={checkpoint_id}")
                return

            for relative_idx, msg in enumerate(messages_to_save):
                actual_index = last_user_index + relative_idx
                if not hasattr(msg, 'type'):
                    continue
                msg_type = msg.type
                msg_content = getattr(msg, 'content', '') or ''

                # Per-message idempotency check (handles partially saved turns).
                existing = await conn.fetchval(
                    "SELECT id FROM chat_messages WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3",
                    thread_id, checkpoint_id, actual_index,
                )
                if existing:
                    logger.debug(f"[V2] 消息已存在: index={actual_index}, type={msg_type}")
                    continue

                if msg_type == "human":
                    if relative_idx == 0:
                        # This turn's user message: content is the original question;
                        # system-attached material is stored separately.
                        if system_attached_context and system_attached_context.strip():
                            injected_content = system_attached_context.strip()
                        else:
                            injected_content = (
                                user_message_content if user_message_content != user_query else None
                            )
                        await ChatMessageService.save_user_message(
                            conn,
                            thread_id=thread_id,
                            checkpoint_id=checkpoint_id,
                            message_index=actual_index,
                            content=user_query,
                            injected_content=injected_content,
                            has_files=has_files,
                            metadata={},
                        )
                        logger.info(f"✅ [V2] 保存用户消息: index={actual_index}, content='{user_query[:50]}...', has_files={has_files}")
                    else:
                        # A second human message within one turn is not expected.
                        await ChatMessageService.save_user_message(
                            conn,
                            thread_id=thread_id,
                            checkpoint_id=checkpoint_id,
                            message_index=actual_index,
                            content=msg_content,
                            injected_content=None,
                            has_files=False,
                            metadata={},
                        )
                        logger.warning(f"⚠️ [V2] 非预期的用户消息位置: index={actual_index}")

                elif msg_type == "ai":
                    reasoning_content = ""
                    if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs:
                        reasoning_content = msg.additional_kwargs.get("reasoning_content", "") or ""
                    token_usage = {}
                    finish_reason = ""
                    if hasattr(msg, 'response_metadata') and msg.response_metadata:
                        token_usage = msg.response_metadata.get('token_usage', {})
                        finish_reason = msg.response_metadata.get('finish_reason', '')
                    await ChatMessageService.save_assistant_message(
                        conn,
                        thread_id=thread_id,
                        checkpoint_id=checkpoint_id,
                        message_index=actual_index,
                        content=msg_content,
                        metadata={
                            'token_usage': token_usage,
                            'finish_reason': finish_reason,
                            'reasoning_content': reasoning_content,
                        },
                    )
                    logger.info(f"✅ [V2] 保存AI消息: index={actual_index}, content_length={len(msg_content)}")

                elif msg_type == "tool":
                    tool_name = getattr(msg, 'name', '') or ''
                    await ChatMessageService.save_tool_message(
                        conn,
                        thread_id=thread_id,
                        checkpoint_id=checkpoint_id,
                        message_index=actual_index,
                        content=msg_content,
                        name=tool_name,
                        # Also kept in metadata for compatibility with older readers.
                        metadata={'tool_name': tool_name},
                    )
                    logger.info(f"✅ [V2] 保存工具消息: index={actual_index}, tool={tool_name}")

            logger.info(f"✅ [V2] 消息保存完成: thread_id={thread_id}, checkpoint_id={checkpoint_id}, saved_count={len(messages_to_save)}")
    except Exception as e:
        # Best effort: never let persistence errors reach the chat stream.
        logger.error(f"[V2] 保存消息到 chat_messages 表失败: {e}")


# GET /api/chat/threads — paginated thread list (page >= 1, 1 <= page_size <= 100).
@chat_router.get("/chat/threads", summary="获取会话列表", response_model=ChatThreadListResponse)
async def get_threads(
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user),
):
    """获取用户的会话列表(分页)"""
    logger.info(f"用户 {current_user.username} 查询会话列表: page={page}")
    if page < 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="页码必须大于 0")
    if page_size < 1 or page_size > 100:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="每页数量必须在 1-100 之间")
    try:
        return await get_user_chat_threads(user_id=current_user.id, page=page, page_size=page_size)
    except HTTPException:
        # Preserve deliberate HTTP errors from the service layer (same as delete_thread).
        raise
    except Exception as e:
        logger.error(f"查询会话列表失败: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="查询会话列表失败") from e


# GET /api/chat/thread/{thread_id} — thread detail (checkpoint-based V1).
@chat_router.get("/chat/thread/{thread_id}", summary="获取会话明细", response_model=ChatThreadDetailResponse)
async def get_thread_detail(
    thread_id: str,
    current_user: User = Depends(get_current_user),
):
    """获取会话的聊天明细"""
    logger.info(f"用户 {current_user.username} 查询会话明细: thread_id={thread_id}")
    return await get_chat_thread_detail(thread_id=thread_id, user_id=current_user.id)


# DELETE /api/chat/thread — soft-delete a thread.
@chat_router.delete("/chat/thread", summary="删除会话")
async def delete_thread(
    request: DeleteThreadRequest,
    current_user: User = Depends(get_current_user),
):
    """删除聊天会话(软删除)"""
    logger.info(f"用户 {current_user.username} 请求删除会话: {request.thread_id}")
    try:
        await delete_chat_thread(thread_id=request.thread_id, user_id=current_user.id)
        return BaseResponse(code=200, msg="会话删除成功", data={"thread_id": request.thread_id})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除会话失败: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除会话失败") from e


# ============ V2 routes (backed by the chat_messages table) ============

# GET /api/chat/thread/{thread_id}/v2 — thread detail read from chat_messages rows.
@chat_router.get("/chat/thread/{thread_id}/v2", summary="获取会话明细 V2(推荐)", response_model=ChatThreadDetailResponse)
async def get_thread_detail_v2(
    thread_id: str,
    current_user: User = Depends(get_current_user),
):
    """
    获取会话的聊天明细(V2版本)

    **V2 优势**:
    - ✅ 查询速度更快(直接从 chat_messages 表查询,无需解析 checkpoint JSONB)
    - ✅ 用户原始问题和注入内容分离(返回的 content 只包含用户原始问题)
    - ✅ 支持全文搜索、统计分析
    - ✅ 数据结构更清晰

    **注意**:需要先启用消息双写功能,才能使用此接口
    """
    logger.info(f"用户 {current_user.username} 查询会话明细 V2: thread_id={thread_id}")
    return await get_chat_thread_detail_v2(thread_id=thread_id, user_id=current_user.id)


# ============ Legacy compatibility routes ============
# These endpoints moved under the /api/user/ prefix; they stay here for old clients.
# The imports must precede the decorators below, which reference the response models.
from services.user_setting_service import UserSettingService
from models.chat import (
    SearchSettingResponse,
    UpdateSearchSettingRequest,
    ReasonerSettingResponse,
    UpdateReasonerSettingRequest,
)


@chat_router.get("/chat/search-setting", summary="获取用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def get_search_setting_compat(current_user: User = Depends(get_current_user)):
    """获取当前用户的联网搜索设置(已迁移到 /api/user/search-setting)"""
    is_search = await UserSettingService.get_search_setting(current_user.id)
    return SearchSettingResponse(is_search=is_search)


@chat_router.put("/chat/search-setting", summary="更新用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def update_search_setting_compat(
    request: UpdateSearchSettingRequest,
    current_user: User = Depends(get_current_user),
):
    """更新当前用户的联网搜索设置(已迁移到 /api/user/search-setting)"""
    is_search = await UserSettingService.update_search_setting(current_user.id, request.is_search)
    return SearchSettingResponse(is_search=is_search)


@chat_router.get("/chat/reasoner-setting", summary="获取用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def get_reasoner_setting_compat(current_user: User = Depends(get_current_user)):
    """获取当前用户的深度思考设置(已迁移到 /api/user/reasoner-setting)"""
    is_reasoner = await UserSettingService.get_reasoner_setting(current_user.id)
    return ReasonerSettingResponse(is_reasoner=is_reasoner)


@chat_router.put("/chat/reasoner-setting", summary="更新用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def update_reasoner_setting_compat(
    request: UpdateReasonerSettingRequest,
    current_user: User = Depends(get_current_user),
):
    """更新当前用户的深度思考设置(已迁移到 /api/user/reasoner-setting)"""
    is_reasoner = await UserSettingService.update_reasoner_setting(current_user.id, request.is_reasoner)
    return ReasonerSettingResponse(is_reasoner=is_reasoner)