huoyan-enterprise/backend/api/chat_router.py

1168 lines
51 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天 API 路由模块
定义聊天相关的 API 路由,包括对话、会话管理等接口。
"""
from __future__ import annotations
import json
import os
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from langchain_core.messages import message_to_dict
from langchain.agents import create_agent
from sse_starlette import EventSourceResponse, ServerSentEvent
from core.config import settings
from core.llm_catalog import (
build_chat_model_for_completion,
coerce_model_id,
deepseek_api_model_by_reasoner_setting,
list_llm_options_payload,
normalize_provider,
resolve_to_api_model,
validate_request_can_use_provider,
)
from core.database import get_db_pool, get_checkpointer
from core.mcp_client import get_mcp_client
from core.dependencies import get_current_user, get_moderation_service
from core.exceptions import ModerationError
from models.user import User
from models.moderation import ModerationDecision
from models.chat import (
ChatRequest,
DeleteThreadRequest,
ChatThreadListResponse,
ChatThreadDetailResponse,
)
from prompt.prompt import (
get_translate_instructions,
get_text2video_instructions,
get_text2img_instructions,
get_text2poster_instructions,
get_research_instructions
)
from services.enterprise_service import EnterpriseService
from tools.tools import (
get_current_time,
internet_search,
text_to_image,
text_to_video,
text_to_poster,
create_rag_retrieve_tool,
create_kb_rag_retrieve_tool,
create_knowledge_graph_neo4j_search_tool,
create_knowledge_graph_rag_retrieve_tool,
)
from services.chat_thread_service import (
create_or_update_chat_thread,
delete_chat_thread,
get_user_chat_threads,
get_chat_thread_detail,
get_chat_thread_detail_v2, # V2版本基于 chat_messages 表
check_thread_has_files,
check_knowledge_base_has_files,
get_knowledge_graph_tool_flags,
)
from services.chat_message_file_service import ChatMessageFileService
from services.chat_message_service import ChatMessageService # 新增:消息保存服务
from utils.helpers import BaseResponse
from logger.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Router instance for the chat endpoints; every route registered on it is served under the "/api" prefix.
chat_router = APIRouter(prefix="/api", tags=["聊天接口"])
@chat_router.get("/chat/llm-options", summary="聊天可选大模型列表")
async def chat_llm_options(current_user: User = Depends(get_current_user)):
    """Return the providers and logical model ids the front-end drop-down needs.

    Per the original note the payload does not contain API keys. ``current_user``
    is not read here; the dependency only forces the caller to be authenticated.
    """
    llm_options = list_llm_options_payload()
    return llm_options
@chat_router.post("/chat/completion", summary="聊天接口主入口")
async def chat_completion(
    request: ChatRequest,
    http_request: Request,
    current_user: User = Depends(get_current_user),
    moderation_service = Depends(get_moderation_service)
):
    """
    Main chat entry point (authentication required).

    Steps, in order: resolve the client IP, run text moderation on the query,
    resolve the LLM provider/model, upsert the chat-thread row, then stream the
    agent output as server-sent events. After the stream ends, recently uploaded
    files are linked to the user message and the turn is dual-written to the
    chat_messages table.

    Args:
        request: Chat request payload (JSON body).
        http_request: Raw HTTP request, used to determine the client IP.
        current_user: The authenticated user.
        moderation_service: Content-moderation service (dependency-injected).

    Returns:
        EventSourceResponse streaming JSON-encoded message chunks followed by a
        final ``[DONE]`` event; or a JSONResponse with HTTP 400 when the query
        is rejected by moderation or the provider/model selection is invalid.
    """
    # NOTE: JSONResponse is deliberately taken from the module-level import.
    # Importing it again anywhere inside this function would make the name
    # function-local, and every `return JSONResponse(...)` reached without
    # passing through that import would raise UnboundLocalError instead of
    # producing the intended 400 response.

    # Client IP: socket peer first, then proxy headers take precedence.
    # NOTE(review): X-Forwarded-For / X-Real-IP are client-controllable unless a
    # trusted proxy overwrites them — confirm the deployment does so.
    client_ip = http_request.client.host if http_request.client else None
    if "x-forwarded-for" in http_request.headers:
        client_ip = http_request.headers["x-forwarded-for"].split(",")[0].strip()
    elif "x-real-ip" in http_request.headers:
        client_ip = http_request.headers["x-real-ip"]

    logger.info(
        f"用户 {current_user.username} (ID: {current_user.id}) 发起聊天请求thread_id: {request.thread_id}, "
        f"text2img: {request.text2img}, text2video: {request.text2video}, text2poster: {request.text2poster}, "
        f"translate: {request.translate}, knowledge_base_id: {request.knowledge_base_id}, "
        f"knowledge_graph_id: {request.knowledge_graph_id}, "
        f"llm_provider: {request.llm_provider}, llm_model: {request.llm_model}, ip={client_ip}"
    )

    # ============ Content moderation (runs before any AI processing) ============
    try:
        # Unique request id so the moderation call can be traced in the logs.
        import uuid
        moderation_request_id = str(uuid.uuid4())
        logger.info(
            f"开始内容审核 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"message_length: {len(request.query)}"
        )
        moderation_result = await moderation_service.moderate_text(
            text=request.query,
            request_id=moderation_request_id
        )
        logger.info(
            f"审核完成 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"decision: {moderation_result.decision.value}, "
            f"labels: {[label.label for label in moderation_result.labels]}"
        )

        # BLOCK decision: reject the message.
        if moderation_result.decision == ModerationDecision.BLOCK:
            logger.warning(
                f"内容被阻止 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {[label.label for label in moderation_result.labels]}"
            )
            # Unified error envelope (same shape as the image-moderation response).
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "code": 400,
                    "msg": "您的消息包含不当内容,无法处理。",
                    "data": None
                }
            )

        # REVIEW decision: allowed or rejected depending on configuration.
        if moderation_result.decision == ModerationDecision.REVIEW:
            logger.info(
                f"内容需要复审 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {[label.label for label in moderation_result.labels]}"
            )
            from core.config import get_settings
            settings_obj = get_settings()
            if settings_obj.moderation_review_action == "block":
                logger.warning(
                    f"复审内容被阻止(配置策略)- user_id: {current_user.id}, "
                    f"request_id: {moderation_request_id}"
                )
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={
                        "code": 400,
                        "msg": "您的消息需要人工复审,暂时无法处理。",
                        "data": None
                    }
                )
            else:
                logger.info(
                    f"复审内容允许通过(配置策略)- user_id: {current_user.id}, "
                    f"request_id: {moderation_request_id}"
                )

        # PASS (or REVIEW allowed by policy): continue with the normal flow.
        logger.info(
            f"内容审核通过 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}"
        )
    except HTTPException:
        # Let HTTP errors raised during moderation propagate unchanged.
        raise
    except ModerationError as e:
        # Moderation service failure: degrade gracefully — log and let the message through.
        logger.error(
            f"审核服务错误(降级模式)- user_id: {current_user.id}, "
            f"error: {e.message}, "
            f"original_error: {str(e.original_error) if e.original_error else None}"
        )
    except Exception as e:
        # Unexpected failure: same fail-open policy — log and let the message through.
        logger.error(
            f"审核过程未知错误(降级模式)- user_id: {current_user.id}, "
            f"error_type: {type(e).__name__}, "
            f"error: {str(e)}"
        )
    # ============ End of content moderation ============

    # Resolve provider / model; configuration problems are reported as HTTP 400.
    llm_provider = normalize_provider(request.llm_provider)
    llm_model_key = coerce_model_id(llm_provider, request.llm_model)
    cfg_err = validate_request_can_use_provider(llm_provider)
    if cfg_err:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"code": 400, "msg": cfg_err, "data": None},
        )
    try:
        api_model = resolve_to_api_model(llm_provider, llm_model_key)
    except ValueError as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"code": 400, "msg": str(e), "data": None},
        )

    try:
        # The user's "deep thinking" flag lives in user_list.is_reasoner.
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            reasoner_row = await conn.fetchrow(
                "SELECT is_reasoner FROM user_list WHERE id = $1",
                current_user.id,
            )
        user_is_reasoner = bool(reasoner_row["is_reasoner"]) if reasoner_row and reasoner_row["is_reasoner"] is not None else False
        if llm_provider == "deepseek":
            # For DeepSeek the API model is chosen by the reasoner flag, not by the request.
            api_model = deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        model = build_chat_model_for_completion(
            llm_provider,
            api_model,
            enable_thinking=user_is_reasoner,
            logical_llm_id=llm_model_key,
        )
        # NOTE(review): brace placeholders with positional args assume a
        # loguru-style logger — confirm what get_logger returns.
        logger.debug(
            "chat_completion 模型: provider={} req_llm_model={} api_model={} user_is_reasoner={}",
            llm_provider,
            llm_model_key,
            api_model,
            user_is_reasoner,
        )
    except ValueError as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"code": 400, "msg": str(e), "data": None},
        )

    # DeepSeek: the persisted model is the actual API model derived from
    # is_reasoner; the llm_model sent in the request is ignored.
    thread_llm_model = (
        deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        if llm_provider == "deepseek"
        else llm_model_key
    )

    # Record the thread in the database (knowledge_base_id is passed through as-is,
    # so a request without one stores null).
    await create_or_update_chat_thread(
        thread_id=request.thread_id,
        user_id=current_user.id,
        query=request.query,
        knowledge_base_id=request.knowledge_base_id,
        knowledge_graph_id=request.knowledge_graph_id,
        ip=client_ip,
        llm_provider=llm_provider,
        llm_model=thread_llm_model,
    )

    config = {"configurable": {"thread_id": request.thread_id}, "recursion_limit": 30}
    checkpointer = await get_checkpointer()

    # Translate / text-to-image / video / poster agents are built up front.
    # The plain chat agent is built inside generate(), because its system prompt
    # depends on the file and knowledge-base context gathered there.
    use_early_agent = bool(
        request.translate or request.text2video or request.text2img or request.text2poster
    )
    agent_early = None
    if use_early_agent:
        agent_early = await _create_agent_for_request(
            request=request,
            current_user=current_user,
            model=model,
            checkpointer=checkpointer,
            config=config,
            llm_provider=llm_provider,
            api_model=api_model,
            user_is_reasoner=user_is_reasoner,
        )

    async def generate():
        """Yield SSE payloads: JSON message chunks, an optional error object, then "[DONE]"."""
        try:
            pool = await get_db_pool()
            unlinked_files = []
            system_context_sections: list[str] = []
            file_list_for_intent: list = []

            async with pool.acquire() as conn:
                # Files uploaded to this thread in the last 5 minutes that are not
                # yet linked to any message.
                rows = await conn.fetch(
                    """
                    SELECT id FROM chat_thread_file
                    WHERE thread_id = $1
                    AND is_deleted = FALSE
                    AND created_at > NOW() - INTERVAL '5 minutes'
                    AND id NOT IN (
                    SELECT DISTINCT file_id FROM chat_message_file WHERE thread_id = $1
                    )
                    ORDER BY created_at DESC
                    LIMIT 10
                    """,
                    request.thread_id
                )
                unlinked_files = [row['id'] for row in rows]

                if not use_early_agent:
                    from services.chat_thread_file_service import ChatThreadFileService
                    from services.rag_intent_service import get_rag_intent_service
                    file_summaries_raw = await ChatThreadFileService.get_recent_files_with_summary(
                        conn, request.thread_id, limit=5
                    )
                    # Resolve each summarized file back to its row id by file name.
                    for file_info in file_summaries_raw:
                        row = await conn.fetchrow(
                            "SELECT id FROM chat_thread_file WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE",
                            request.thread_id, file_info['file_name']
                        )
                        if row:
                            file_list_for_intent.append({
                                "file_id": row['id'],
                                "file_name": file_info['file_name'],
                                "summary": file_info['summary']
                            })

            # ---- Thread files: decide per file whether to inject content or hint retrieval ----
            if not use_early_agent and file_list_for_intent:
                intent_service = await get_rag_intent_service()
                intents = await intent_service.judge_intent(
                    query=request.query,
                    file_list=file_list_for_intent
                )
                logger.info(f"💡 意图判断结果: {len(intents)} 个文件相关")
                if intents:
                    summary_files = [i for i in intents if i.question_type == "summary"]
                    search_files = [i for i in intents if i.question_type == "search"]
                    if summary_files:
                        logger.info(f"📄 检测到 {len(summary_files)} 个文件需要完整内容")
                        from services.chat_thread_file_service import ChatThreadFileService
                        summary_prefix = "\n\n" + "="*60 + "\n"
                        summary_prefix += "📎 **已为您准备的文件完整内容**\n"
                        summary_prefix += "="*60 + "\n"
                        summary_prefix += "⚠️ 重要:以下是文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n"
                        async with pool.acquire() as conn:
                            for idx, intent in enumerate(summary_files, 1):
                                logger.info(f"🔍 正在从数据库获取文件完整内容: file_id={intent.file_id}, file_name={intent.file_name}")
                                all_chunks = await ChatThreadFileService.get_file_chunks_from_db(
                                    conn, intent.file_id
                                )
                                logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                                if all_chunks:
                                    logger.info(f"📝 Chunks详情:")
                                    for i, chunk in enumerate(all_chunks[:3]):
                                        logger.info(f" - Chunk {i}: index={chunk.get('chunk_index')}, 内容长度={len(chunk.get('content', ''))}, 有摘要={bool(chunk.get('summary'))}")
                                    if len(all_chunks) > 3:
                                        logger.info(f" - ... 还有 {len(all_chunks) - 3} 个chunks")
                                else:
                                    logger.warning(f"⚠️ 文件 {intent.file_name} (ID: {intent.file_id}) 未返回任何chunks")
                                if all_chunks:
                                    full_content = "\n\n".join([chunk['content'] for chunk in all_chunks])
                                    file_summary = all_chunks[0].get('summary', '')
                                    file_type_icon = "🖼️" if intent.file_name.lower().endswith(('png', 'jpg', 'jpeg', 'bmp')) else "📄"
                                    summary_prefix += f"{file_type_icon} 文件 {idx}: {intent.file_name}\n"
                                    # NOTE(review): ''*60 is an empty string, so this adds only a
                                    # newline; a separator character may have been lost — confirm.
                                    summary_prefix += f"{''*60}\n"
                                    if file_summary:
                                        summary_prefix += f"【文件摘要】\n{file_summary}\n\n"
                                    # Cap the injected content per file.
                                    max_content_length = 8000
                                    if len(full_content) > max_content_length:
                                        full_content = full_content[:max_content_length] + "\n...(内容过长,已截断)"
                                    summary_prefix += f"【完整内容】\n{full_content}\n"
                                    summary_prefix += f"{''*60}\n\n"
                                    logger.info(f"✅ 已注入文件 {intent.file_name} (摘要{len(file_summary)}字 + 内容{len(full_content)}字)")
                        summary_prefix += "="*60 + "\n"
                        system_context_sections.append(summary_prefix)
                    elif search_files:
                        logger.info(f"🔍 检测到 {len(search_files)} 个文件需要向量检索")
                        search_hint = "\n\n" + "="*60 + "\n"
                        search_hint += "📎 **重要提示**\n"
                        search_hint += "="*60 + "\n"
                        search_hint += "⚠️ 用户刚刚上传了以下文件,你的回答必须基于这些文件内容:\n\n"
                        for idx, intent in enumerate(search_files, 1):
                            search_hint += f"{idx}. 📄 {intent.file_name}\n"
                        search_hint += "\n**请务必使用检索工具查询文件内容,不要使用你的训练数据!**\n"
                        search_hint += "="*60 + "\n"
                        system_context_sections.append(search_hint)
                    else:
                        # Unknown intent type: fall back to injecting the summaries only.
                        logger.warning("⚠️ 意图类型未识别,降级为注入摘要")
                        summary_prefix = "\n\n📎 **文件摘要**\n"
                        for file_info in file_list_for_intent:
                            summary_prefix += f"{file_info['file_name']}\n{file_info['summary']}\n\n"
                        system_context_sections.append(summary_prefix)
                else:
                    logger.info(" 未识别到相关文件,不注入内容")

            # ---- Knowledge-base files (only when no knowledge graph is bound) ----
            if not use_early_agent and request.knowledge_base_id and not request.knowledge_graph_id:
                from services.knowledge_base_file_service import KnowledgeBaseFileService
                kb_files_raw = []
                async with pool.acquire() as conn:
                    kb_files_raw = await KnowledgeBaseFileService.get_recent_files_with_summary(
                        conn, request.knowledge_base_id, limit=5
                    )
                if kb_files_raw:
                    logger.info(f"📚 知识库 {request.knowledge_base_id} 检测到 {len(kb_files_raw)} 个最近文件")
                    intent_service = await get_rag_intent_service()
                    kb_intents = await intent_service.judge_intent(
                        query=request.query,
                        file_list=kb_files_raw
                    )
                    logger.info(f"💡 知识库意图判断结果: {len(kb_intents)} 个文件相关")
                    if kb_intents:
                        summary_files = [i for i in kb_intents if i.question_type == "summary"]
                        search_files = [i for i in kb_intents if i.question_type == "search"]
                        if summary_files:
                            logger.info(f"📄 知识库检测到 {len(summary_files)} 个文件需要完整内容")
                            kb_prefix = "\n\n" + "="*60 + "\n"
                            kb_prefix += "📚 **知识库文件完整内容**\n"
                            kb_prefix += "="*60 + "\n"
                            kb_prefix += "⚠️ 重要:以下是知识库文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n"
                            async with pool.acquire() as conn:
                                for idx, intent in enumerate(summary_files, 1):
                                    logger.info(f"🔍 正在从数据库获取知识库文件: file_id={intent.file_id}, file_name={intent.file_name}")
                                    all_chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(
                                        conn, intent.file_id
                                    )
                                    logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                                    if all_chunks:
                                        full_content = "\n\n".join([chunk['content'] for chunk in all_chunks])
                                        file_summary = all_chunks[0].get('summary', '')
                                        file_type_icon = "🖼️" if intent.file_name.lower().endswith(('png', 'jpg', 'jpeg', 'bmp')) else "📄"
                                        kb_prefix += f"{file_type_icon} 文件 {idx}: {intent.file_name}\n"
                                        kb_prefix += f"{''*60}\n"
                                        if file_summary:
                                            kb_prefix += f"【文件摘要】\n{file_summary}\n\n"
                                        max_content_length = 8000
                                        if len(full_content) > max_content_length:
                                            full_content = full_content[:max_content_length] + "\n...(内容过长,已截断)"
                                        kb_prefix += f"【完整内容】\n{full_content}\n"
                                        kb_prefix += f"{''*60}\n\n"
                                        logger.info(f"✅ 已注入知识库文件 {intent.file_name} (摘要{len(file_summary)}字 + 内容{len(full_content)}字)")
                            kb_prefix += "="*60 + "\n"
                            system_context_sections.append(kb_prefix)
                        elif search_files:
                            logger.info(f"🔍 知识库检测到 {len(search_files)} 个文件需要向量检索")
                            kb_hint = "\n\n" + "="*60 + "\n"
                            kb_hint += "📚 **知识库检索提示**\n"
                            kb_hint += "="*60 + "\n"
                            kb_hint += "⚠️ 用户的知识库包含以下文件,你的回答必须基于这些文件内容:\n\n"
                            for idx, intent in enumerate(search_files, 1):
                                kb_hint += f"{idx}. 📄 {intent.file_name}\n"
                            kb_hint += "\n**请务必使用知识库检索工具查询文件内容,不要使用你的训练数据!**\n"
                            kb_hint += "="*60 + "\n"
                            system_context_sections.append(kb_hint)

            # Everything gathered above is attached to the system prompt, not to the user message.
            extra_system_context = "\n\n".join(
                s.strip() for s in system_context_sections if s and str(s).strip()
            )

            if use_early_agent:
                active_agent = agent_early
            else:
                active_agent = await _create_agent_for_request(
                    request=request,
                    current_user=current_user,
                    model=model,
                    checkpointer=checkpointer,
                    config=config,
                    llm_provider=llm_provider,
                    api_model=api_model,
                    user_is_reasoner=user_is_reasoner,
                    extra_system_context=extra_system_context.strip() if extra_system_context.strip() else None,
                )

            try:
                async for event_data in active_agent.astream(
                    {
                        "messages": [
                            {
                                "role": "user",
                                "content": request.query,
                            }
                        ]
                    },
                    stream_mode="messages",
                    config=config
                ):
                    message, metadata = event_data
                    try:
                        message_dict = message_to_dict(message)
                    except Exception as e:
                        # Fallback: build a minimal flat dict by hand.
                        logger.warning(f"使用 message_to_dict 序列化失败: {e}")
                        message_dict = {
                            "content": message.content if hasattr(message, 'content') else str(message),
                            "type": message.__class__.__name__ if hasattr(message, '__class__') else "unknown"
                        }
                        if hasattr(message, 'additional_kwargs'):
                            message_dict['additional_kwargs'] = message.additional_kwargs or {}
                        if hasattr(message, 'response_metadata'):
                            message_dict['response_metadata'] = message.response_metadata or {}
                        if hasattr(message, 'id'):
                            message_dict['id'] = message.id
                    serializable_data = {
                        "message": message_dict,
                        "metadata": metadata
                    }
                    yield json.dumps(serializable_data, ensure_ascii=False, default=str)
            except IndexError as e:
                # Workaround for a chunk-alignment problem seen with some LLM +
                # streaming tool-call combinations (originally kept for ChatTongyi).
                if "list index out of range" in str(e):
                    error_msg = "抱歉,处理您的请求时遇到了技术问题(流式工具调用错误)。请稍后重试,或尝试简化您的问题。"
                    logger.error(f"LangChain 流式工具调用 bug: {e}")
                    yield json.dumps({
                        "error": error_msg,
                        "message": {
                            "type": "AIMessageChunk",
                            "data": {
                                "content": error_msg,
                                "additional_kwargs": {},
                                "response_metadata": {}
                            }
                        }
                    }, ensure_ascii=False)
                else:
                    raise

            # Link the recently uploaded files to the user message of this turn.
            if unlinked_files:
                await _associate_files_to_message(
                    request.thread_id,
                    unlinked_files,
                    checkpointer,
                    config,
                    pool
                )

            # [V2] Dual-write the turn into the chat_messages table.
            try:
                await _save_messages_to_chat_messages_table(
                    thread_id=request.thread_id,
                    user_query=request.query,
                    user_message_content=request.query,
                    has_files=len(unlinked_files) > 0,
                    checkpointer=checkpointer,
                    config=config,
                    pool=pool,
                    system_attached_context=extra_system_context.strip() if extra_system_context.strip() else None,
                )
            except Exception as save_err:
                # Must not affect the main flow; log only.
                logger.error(f"[V2] 保存消息到 chat_messages 表失败: {save_err}")
        except Exception as e:
            logger.exception(f"聊天接口错误: {e}")
            yield json.dumps({"error": str(e)}, ensure_ascii=False)
        # Always terminate the stream with the sentinel event.
        yield "[DONE]"

    return EventSourceResponse(
        generate(),
        ping_message_factory=lambda: ServerSentEvent(data=f"ping - {datetime.now(timezone.utc)}")
    )
def _wrap_mcp_tool_safe(tool):
    """
    Patch an MCP tool so that failures come back as an error string.

    The tool's ``_arun`` is replaced by a coroutine that awaits the original
    and, on any ``Exception`` (network unreachable, timeout, ...), logs a
    warning and returns a description of the failure instead of raising.

    Rationale carried over from the original note: the agent can then still
    record a ToolMessage for the failed call, which keeps the checkpointed
    history complete and avoids a later LLM 400 BadRequest caused by a
    tool call without its paired reply.

    Args:
        tool: Object exposing an awaitable ``_arun`` and a ``name`` attribute.

    Returns:
        The same ``tool`` object, patched in place.
    """
    inner_arun = tool._arun

    async def safe_arun(*args, **kwargs):
        try:
            outcome = await inner_arun(*args, **kwargs)
        except Exception as exc:
            error_msg = f"工具 [{tool.name}] 调用失败: {exc}"
            logger.warning(f"MCP 工具异常已捕获,返回错误字符串以维持对话完整性: {error_msg}")
            return error_msg
        return outcome

    tool._arun = safe_arun
    return tool
async def _create_agent_for_request(
    request: ChatRequest,
    current_user: User,
    model,
    checkpointer,
    config: dict,
    llm_provider: str,
    api_model: str,
    user_is_reasoner: bool,
    extra_system_context: Optional[str] = None,
):
    """
    Build the agent that matches the request mode.

    Mode priority: translate, text-to-video, text-to-image, text-to-poster,
    otherwise plain chat. The special modes get a fixed small tool set; plain
    chat gets the MCP tools plus retrieval / search tools selected from the
    thread files, knowledge base, knowledge graph and user settings.

    Args:
        request: Chat request; its mode flags and ids drive the selection.
        current_user: Authenticated user (enterprise id, settings lookup).
        model: Chat model instance handed to ``create_agent``.
        checkpointer: Checkpointer handed to ``create_agent``.
        config: Run config. Mutated in place: ``recursion_limit`` is raised to
            60 when ``user_is_reasoner`` is true (plain chat only).
        llm_provider: Provider name; only used for logging here.
        api_model: API model name; only used for logging here.
        user_is_reasoner: Whether the user enabled deep-thinking mode.
        extra_system_context: Optional reference material appended to the
            plain-chat system prompt.

    Returns:
        The agent returned by ``create_agent``.
    """
    ai_display_name = await EnterpriseService.resolve_ai_display_name(current_user.enterprise_id)

    # Translation mode
    if request.translate:
        logger.info("使用翻译模式")
        language_names = {
            'auto': '自动检测', 'zh': '中文(简体)', 'zh-TW': '中文(繁體)',
            'en': '英文', 'ja': '日文', 'ko': '韩文', 'fr': '法文',
            'de': '德文', 'es': '西班牙文', 'ru': '俄文'
        }
        # Unknown codes fall back to auto-detect (source) and English (target).
        from_lang = language_names.get(request.from_language, '自动检测')
        target_lang = language_names.get(request.target_language, '英文')
        return create_agent(
            model=model,
            tools=[get_current_time],
            system_prompt=get_translate_instructions(from_lang, target_lang, ai_display_name),
            checkpointer=checkpointer
        )

    # Text-to-video mode
    if request.text2video:
        logger.info("使用文生视频模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_video],
            system_prompt=get_text2video_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # Text-to-image mode
    if request.text2img:
        logger.info("使用文生图模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_image],
            system_prompt=get_text2img_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # Creative poster mode
    if request.text2poster:
        logger.info("使用创意海报生成模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_poster],
            system_prompt=get_text2poster_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # Plain chat mode
    mcp_client = await get_mcp_client()
    mcp_tools = await mcp_client.get_tools()
    # mcp_tools = [_wrap_mcp_tool_safe(t) for t in mcp_tools]
    logger.info(f"成功加载 {len(mcp_tools)} 个 MCP 工具")

    # is_reasoner was already read by chat_completion and arrives as
    # user_is_reasoner; only the web-search switch is looked up here.
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        user_row = await conn.fetchrow(
            "SELECT is_search FROM user_list WHERE id = $1",
            current_user.id,
        )
    user_is_search = bool(user_row["is_search"]) if user_row and user_row["is_search"] is not None else False
    use_reasoner_mode = user_is_reasoner

    # Which retrieval sources are available for this request?
    has_files = await check_thread_has_files(request.thread_id)
    has_kb_files = False
    kb_id = request.knowledge_base_id
    if kb_id and not request.knowledge_graph_id:
        has_kb_files = await check_knowledge_base_has_files(kb_id, current_user.id)
    kg_flags = {"has_rag": False, "neo4j_graph_id": None}
    if request.knowledge_graph_id:
        kg_flags = await get_knowledge_graph_tool_flags(current_user, request.knowledge_graph_id)
    has_novel_rag = bool(kg_flags.get("has_rag"))
    has_kg_neo4j = bool(kg_flags.get("neo4j_graph_id"))

    # System prompt; a bound graph may offer text RAG and/or Neo4j relation search.
    has_kg_any = has_novel_rag or has_kg_neo4j
    research_instructions = get_research_instructions(
        has_files=has_files,
        has_kb_files=has_kb_files,
        use_reasoner_mode=use_reasoner_mode,
        has_knowledge_graph=has_kg_any,
        has_knowledge_graph_neo4j=has_kg_neo4j,
        ai_display_name=ai_display_name,
    )

    # Tool list. Each optional tool is prepended, so later additions end up
    # earlier in the list (original note: Neo4j relation search is meant to take
    # priority for "who is who" style questions).
    all_tools = [get_current_time, *mcp_tools]
    if has_files:
        all_tools = [create_rag_retrieve_tool(request.thread_id)] + all_tools
    if request.knowledge_graph_id and has_kg_neo4j:
        all_tools = [
            create_knowledge_graph_neo4j_search_tool(kg_flags["neo4j_graph_id"])
        ] + all_tools
    if request.knowledge_graph_id and has_novel_rag:
        all_tools = [create_knowledge_graph_rag_retrieve_tool(request.knowledge_graph_id)] + all_tools
    elif kb_id and has_kb_files:
        all_tools = [create_kb_rag_retrieve_tool(kb_id)] + all_tools
    if user_is_search:
        all_tools = [internet_search] + all_tools
        logger.info(f"用户 {current_user.username} 启用了联网搜索功能")

    # Deep-thinking mode: raise the recursion limit. The underlying model was
    # already configured by the caller according to is_reasoner.
    if use_reasoner_mode:
        config["recursion_limit"] = 60
        logger.info(
            f"用户 {current_user.username} 启用深度思考模式(提供方={llm_provider}, api_model={api_model}"
        )

    # Diagnostic output goes through the module logger instead of stdout.
    logger.debug(f"model: {model}")

    merged_system_prompt = research_instructions.rstrip()
    if extra_system_context and extra_system_context.strip():
        merged_system_prompt = (
            merged_system_prompt
            + "\n\n【本轮系统提供的参考上下文(用户原话仅在用户消息中;请优先依据此处与工具检索结果作答)】\n"
            + extra_system_context.strip()
        )
    return create_agent(
        model=model,
        tools=all_tools,
        system_prompt=merged_system_prompt,
        checkpointer=checkpointer,
    )
async def _associate_files_to_message(
    thread_id: str,
    unlinked_files: list,
    checkpointer,
    config: dict,
    pool
):
    """
    Link not-yet-associated files to the latest user message of the thread.

    Polls the checkpointer up to 5 times (0.2 s pause before each attempt),
    finds the index of the last ``human`` message in the checkpoint and
    creates one association per file id. Failures are logged, never raised.

    Args:
        thread_id: Chat thread id.
        unlinked_files: File ids to associate.
        checkpointer: Object providing ``aget_tuple(config)``.
        config: Run config identifying the thread.
        pool: Database pool providing ``acquire()``.
    """
    import asyncio

    snapshot = None
    for attempt_no in range(1, 6):
        await asyncio.sleep(0.2)
        try:
            snapshot = await checkpointer.aget_tuple(config)
        except Exception as fetch_err:
            logger.debug(f"获取 checkpoint 失败(尝试 {attempt_no}/5: {fetch_err}")
            continue
        if snapshot:
            break

    if not snapshot:
        logger.warning("无法获取最新的 checkpoint文件关联失败")
        return

    try:
        checkpoint_id = snapshot.config["configurable"]["checkpoint_id"]
        state = snapshot.checkpoint
        if "channel_values" not in state or "messages" not in state["channel_values"]:
            return
        history = state["channel_values"]["messages"]

        # Index of the most recent human message, scanning from the end.
        user_message_index = next(
            (
                pos
                for pos in reversed(range(len(history)))
                if getattr(history[pos], 'type', None) == "human"
            ),
            None,
        )
        if user_message_index is None:
            return

        async with pool.acquire() as conn:
            for file_id in unlinked_files:
                try:
                    await ChatMessageFileService.create_message_file_association(
                        conn,
                        thread_id,
                        checkpoint_id,
                        user_message_index,
                        file_id
                    )
                    logger.info(f"文件 {file_id} 已关联到消息")
                except Exception as link_err:
                    # One failing file must not stop the remaining ones.
                    logger.warning(f"关联文件到消息失败: {link_err}")
    except Exception as outer_err:
        logger.warning(f"关联文件到消息失败: {outer_err}")
async def _save_messages_to_chat_messages_table(
    thread_id: str,
    user_query: str,
    user_message_content: str,
    has_files: bool,
    checkpointer,
    config: dict,
    pool,
    system_attached_context: Optional[str] = None,
):
    """
    Persist the current turn into the chat_messages table (dual-write).

    Core of the V2 storage: the user's original question and every following
    message of the turn (AI tool calls, tool results, final AI reply) are
    copied from the latest checkpoint into chat_messages, so reads do not need
    to parse the checkpoint JSONB. All errors are logged and swallowed.

    Args:
        thread_id: Chat thread id.
        user_query: The user's original question.
        user_message_content: User message content as stored in the checkpoint
            (may equal ``user_query``).
        has_files: Whether files were attached to this turn.
        checkpointer: Object providing ``aget_tuple(config)``.
        config: Run config identifying the thread.
        pool: Database pool providing ``acquire()``.
        system_attached_context: Reference context that was appended to the
            system prompt (not part of the user's words); stored as
            ``injected_content`` of the user message.
    """
    import asyncio

    try:
        # Wait for the checkpoint: up to 5 attempts, 0.2 s pause before each.
        snapshot = None
        for attempt_no in range(1, 6):
            await asyncio.sleep(0.2)
            try:
                snapshot = await checkpointer.aget_tuple(config)
            except Exception as fetch_err:
                logger.debug(f"获取 checkpoint 失败(尝试 {attempt_no}/5: {fetch_err}")
                continue
            if snapshot:
                break
        if not snapshot:
            logger.warning("无法获取最新的 checkpoint消息保存失败")
            return

        checkpoint_id = snapshot.config["configurable"]["checkpoint_id"]
        state = snapshot.checkpoint
        if "channel_values" not in state or "messages" not in state["channel_values"]:
            return
        messages = state["channel_values"]["messages"]
        total_messages = len(messages)
        if total_messages < 2:
            logger.warning(f"[V2] 消息数量不足,跳过保存: messages_count={total_messages}")
            return

        # The turn starts at the last human message; everything from there on
        # (user message -> AI tool calls -> tool results -> final AI reply) is saved.
        last_user_index = next(
            (
                pos
                for pos in reversed(range(total_messages))
                if getattr(messages[pos], 'type', None) == "human"
            ),
            None,
        )
        if last_user_index is None:
            logger.warning(f"[V2] 未找到用户消息,跳过保存")
            return
        messages_to_save = messages[last_user_index:]

        async with pool.acquire() as conn:
            # Fast path: if the last message of the checkpoint is already stored,
            # the whole turn is treated as saved.
            final_index = total_messages - 1
            already_saved = await conn.fetchval(
                """
                SELECT id FROM chat_messages
                WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3
                """,
                thread_id, checkpoint_id, final_index
            )
            if already_saved:
                logger.debug(f"[V2] 消息已存在,跳过保存: thread_id={thread_id}, checkpoint_id={checkpoint_id}")
                return

            for actual_index, msg in enumerate(messages_to_save, start=last_user_index):
                if not hasattr(msg, 'type'):
                    continue
                msg_type = msg.type
                msg_content = getattr(msg, 'content', '') or ''

                # Per-message duplicate check.
                row_id = await conn.fetchval(
                    "SELECT id FROM chat_messages WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3",
                    thread_id, checkpoint_id, actual_index
                )
                if row_id:
                    logger.debug(f"[V2] 消息已存在: index={actual_index}, type={msg_type}")
                    continue

                if msg_type == "human":
                    if actual_index != last_user_index:
                        # A second human message inside one turn is not expected;
                        # store it as-is and flag it.
                        await ChatMessageService.save_user_message(
                            conn,
                            thread_id=thread_id,
                            checkpoint_id=checkpoint_id,
                            message_index=actual_index,
                            content=msg_content,
                            injected_content=None,
                            has_files=False,
                            metadata={}
                        )
                        logger.warning(f"⚠️ [V2] 非预期的用户消息位置: index={actual_index}")
                        continue

                    # First message of the turn: content is the original question,
                    # system-attached material is stored separately.
                    if system_attached_context and system_attached_context.strip():
                        injected_content = system_attached_context.strip()
                    elif user_message_content != user_query:
                        injected_content = user_message_content
                    else:
                        injected_content = None
                    await ChatMessageService.save_user_message(
                        conn,
                        thread_id=thread_id,
                        checkpoint_id=checkpoint_id,
                        message_index=actual_index,
                        content=user_query,
                        injected_content=injected_content,
                        has_files=has_files,
                        metadata={}
                    )
                    logger.info(f"✅ [V2] 保存用户消息: index={actual_index}, content='{user_query[:50]}...', has_files={has_files}")
                elif msg_type == "ai":
                    extra_kwargs = getattr(msg, 'additional_kwargs', None)
                    if extra_kwargs:
                        reasoning_content = extra_kwargs.get("reasoning_content", "") or ""
                    else:
                        reasoning_content = ""

                    response_meta = getattr(msg, 'response_metadata', None)
                    if response_meta:
                        token_usage = response_meta.get('token_usage', {})
                        finish_reason = response_meta.get('finish_reason', '')
                    else:
                        token_usage = {}
                        finish_reason = ""

                    await ChatMessageService.save_assistant_message(
                        conn,
                        thread_id=thread_id,
                        checkpoint_id=checkpoint_id,
                        message_index=actual_index,
                        content=msg_content,
                        metadata={
                            'token_usage': token_usage,
                            'finish_reason': finish_reason,
                            'reasoning_content': reasoning_content
                        }
                    )
                    logger.info(f"✅ [V2] 保存AI消息: index={actual_index}, content_length={len(msg_content)}")
                elif msg_type == "tool":
                    tool_name = getattr(msg, 'name', '') or ''
                    await ChatMessageService.save_tool_message(
                        conn,
                        thread_id=thread_id,
                        checkpoint_id=checkpoint_id,
                        message_index=actual_index,
                        content=msg_content,
                        name=tool_name,
                        # Also kept in metadata for compatibility.
                        metadata={'tool_name': tool_name}
                    )
                    logger.info(f"✅ [V2] 保存工具消息: index={actual_index}, tool={tool_name}")

        logger.info(f"✅ [V2] 消息保存完成: thread_id={thread_id}, checkpoint_id={checkpoint_id}, saved_count={len(messages_to_save)}")
    except Exception as e:
        # Swallowed on purpose so the main chat flow is never affected.
        logger.error(f"[V2] 保存消息到 chat_messages 表失败: {e}")
@chat_router.get("/chat/threads", summary="获取会话列表", response_model=ChatThreadListResponse)
async def get_threads(
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user)
):
    """Return one page of the current user's chat threads.

    Raises:
        HTTPException: 400 when ``page`` is below 1 or ``page_size`` is outside
            1-100; 500 when the lookup fails.
    """
    logger.info(f"用户 {current_user.username} 查询会话列表: page={page}")

    if page < 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="页码必须大于 0")
    if not 1 <= page_size <= 100:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="每页数量必须在 1-100 之间")

    try:
        thread_page = await get_user_chat_threads(
            user_id=current_user.id,
            page=page,
            page_size=page_size,
        )
    except Exception as e:
        logger.error(f"查询会话列表失败: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="查询会话列表失败")
    return thread_page
@chat_router.get("/chat/thread/{thread_id}", summary="获取会话明细", response_model=ChatThreadDetailResponse)
async def get_thread_detail(
    thread_id: str,
    current_user: User = Depends(get_current_user)
):
    """Return the message detail of one chat thread for the current user."""
    logger.info(f"用户 {current_user.username} 查询会话明细: thread_id={thread_id}")
    detail = await get_chat_thread_detail(
        thread_id=thread_id,
        user_id=current_user.id,
    )
    return detail
@chat_router.delete("/chat/thread", summary="删除会话")
async def delete_thread(
    request: DeleteThreadRequest,
    current_user: User = Depends(get_current_user)
):
    """Delete a chat thread (soft delete, per the original note).

    Raises:
        HTTPException: re-raised unchanged when the service raises one;
            otherwise 500 on any other failure.
    """
    target_thread_id = request.thread_id
    logger.info(f"用户 {current_user.username} 请求删除会话: {target_thread_id}")
    try:
        await delete_chat_thread(thread_id=target_thread_id, user_id=current_user.id)
        return BaseResponse(
            code=200,
            msg="会话删除成功",
            data={"thread_id": target_thread_id},
        )
    except Exception as e:
        # HTTP errors from the service layer pass through untouched.
        if isinstance(e, HTTPException):
            raise
        logger.error(f"删除会话失败: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除会话失败")
# ============ V2 版本路由(基于 chat_messages 表) ============
@chat_router.get("/chat/thread/{thread_id}/v2", summary="获取会话明细 V2推荐", response_model=ChatThreadDetailResponse)
async def get_thread_detail_v2(
    thread_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Return the message detail of one chat thread (V2).

    **V2 advantages** (as stated by the original authors)
    - Faster queries: reads the chat_messages table directly, no checkpoint JSONB parsing
    - The user's original question is separated from injected content
      (the returned content holds only the original question)
    - Supports full-text search and statistics
    - Clearer data structure

    **Note**: message dual-write must be enabled before this endpoint is useful.
    """
    logger.info(f"用户 {current_user.username} 查询会话明细 V2: thread_id={thread_id}")
    detail = await get_chat_thread_detail_v2(
        thread_id=thread_id,
        user_id=current_user.id,
    )
    return detail
# ============ 兼容旧路由(保持向后兼容) ============
# 这些路由已迁移到 /api/user/ 前缀下,但为了兼容旧客户端保留
from services.user_setting_service import UserSettingService
from models.chat import (
SearchSettingResponse,
UpdateSearchSettingRequest,
ReasonerSettingResponse,
UpdateReasonerSettingRequest,
)
@chat_router.get("/chat/search-setting", summary="获取用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def get_search_setting_compat(current_user: User = Depends(get_current_user)):
    """Return the current user's web-search setting (moved to /api/user/search-setting)."""
    return SearchSettingResponse(
        is_search=await UserSettingService.get_search_setting(current_user.id)
    )
@chat_router.put("/chat/search-setting", summary="更新用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def update_search_setting_compat(
    request: UpdateSearchSettingRequest,
    current_user: User = Depends(get_current_user)
):
    """Update the current user's web-search setting (moved to /api/user/search-setting)."""
    stored_value = await UserSettingService.update_search_setting(
        current_user.id,
        request.is_search,
    )
    return SearchSettingResponse(is_search=stored_value)
@chat_router.get("/chat/reasoner-setting", summary="获取用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def get_reasoner_setting_compat(current_user: User = Depends(get_current_user)):
    """Return the current user's deep-thinking setting (moved to /api/user/reasoner-setting)."""
    return ReasonerSettingResponse(
        is_reasoner=await UserSettingService.get_reasoner_setting(current_user.id)
    )
@chat_router.put("/chat/reasoner-setting", summary="更新用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def update_reasoner_setting_compat(
    request: UpdateReasonerSettingRequest,
    current_user: User = Depends(get_current_user)
):
    """Update the current user's deep-thinking setting (moved to /api/user/reasoner-setting)."""
    stored_value = await UserSettingService.update_reasoner_setting(
        current_user.id,
        request.is_reasoner,
    )
    return ReasonerSettingResponse(is_reasoner=stored_value)