# Listing metadata (code-viewer residue): 1168 lines, 51 KiB, Python
"""
Chat API routing module.

Defines the chat-related API routes, including the conversation endpoint
and chat-session (thread) management endpoints.
"""
|
||
from __future__ import annotations

import asyncio
import functools
import json
import os
import uuid
from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from langchain_core.messages import message_to_dict
from langchain.agents import create_agent
from sse_starlette import EventSourceResponse, ServerSentEvent

from core.config import settings
from core.llm_catalog import (
    build_chat_model_for_completion,
    coerce_model_id,
    deepseek_api_model_by_reasoner_setting,
    list_llm_options_payload,
    normalize_provider,
    resolve_to_api_model,
    validate_request_can_use_provider,
)
from core.database import get_db_pool, get_checkpointer
from core.mcp_client import get_mcp_client
from core.dependencies import get_current_user, get_moderation_service
from core.exceptions import ModerationError
from models.user import User
from models.moderation import ModerationDecision
from models.chat import (
    ChatRequest,
    DeleteThreadRequest,
    ChatThreadListResponse,
    ChatThreadDetailResponse,
)
from prompt.prompt import (
    get_translate_instructions,
    get_text2video_instructions,
    get_text2img_instructions,
    get_text2poster_instructions,
    get_research_instructions
)
from services.enterprise_service import EnterpriseService
from tools.tools import (
    get_current_time,
    internet_search,
    text_to_image,
    text_to_video,
    text_to_poster,
    create_rag_retrieve_tool,
    create_kb_rag_retrieve_tool,
    create_knowledge_graph_neo4j_search_tool,
    create_knowledge_graph_rag_retrieve_tool,
)
from services.chat_thread_service import (
    create_or_update_chat_thread,
    delete_chat_thread,
    get_user_chat_threads,
    get_chat_thread_detail,
    get_chat_thread_detail_v2,  # V2: backed by the chat_messages table
    check_thread_has_files,
    check_knowledge_base_has_files,
    get_knowledge_graph_tool_flags,
)
from services.chat_message_file_service import ChatMessageFileService
from services.chat_message_service import ChatMessageService  # message persistence service
from utils.helpers import BaseResponse
from logger.logging import get_logger
|
||
|
||
# Module-level logger shared by every handler and helper in this file.
logger = get_logger(__name__)

# Router instance; all chat endpoints are mounted under the /api prefix.
chat_router = APIRouter(prefix="/api", tags=["聊天接口"])
|
||
|
||
|
||
@chat_router.get("/chat/llm-options", summary="聊天可选大模型列表")
async def chat_llm_options(current_user: User = Depends(get_current_user)):
    """Return the provider / logical-model-id options for the front-end dropdown.

    The payload comes straight from ``list_llm_options_payload()``; per the
    original note it contains no secrets. ``current_user`` is only declared so
    that the authentication dependency runs for this route.
    """
    return list_llm_options_payload()
|
||
|
||
|
||
# Cap (in characters) on the per-file content injected into the system prompt.
_MAX_INJECTED_CONTENT_CHARS = 8000
# File-name suffixes that get the image icon in injected file headers.
_IMAGE_NAME_SUFFIXES = ('png', 'jpg', 'jpeg', 'bmp')
# Horizontal rules used when rendering injected context sections.
_SECTION_RULE = "=" * 60
_FILE_RULE = "─" * 60


def _bad_request(msg: str) -> JSONResponse:
    """Build the uniform 400 error envelope (``code`` / ``msg`` / ``data``)."""
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"code": 400, "msg": msg, "data": None},
    )


def _resolve_client_ip(http_request: Request) -> Optional[str]:
    """Return the caller's IP, preferring proxy headers over the socket peer.

    Order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then the
    peer address. Empty header values are skipped so that a blank
    ``X-Forwarded-For`` no longer replaces a known peer address with ``""``.
    NOTE(review): these headers are client-controlled unless a trusted proxy
    overwrites them; treat the result as informational only.
    """
    peer_ip = http_request.client.host if http_request.client else None
    forwarded = http_request.headers.get("x-forwarded-for", "").split(",")[0].strip()
    if forwarded:
        return forwarded
    real_ip = http_request.headers.get("x-real-ip", "").strip()
    return real_ip or peer_ip


def _render_file_block(idx: int, file_name: str, chunks: list) -> tuple[str, int, int]:
    """Render one file's summary + (truncated) content for prompt injection.

    Args:
        idx: 1-based position of the file in the rendered section.
        file_name: Display name; its suffix selects the icon.
        chunks: Non-empty list of chunk mappings with a ``content`` key and an
            optional ``summary`` key (the first chunk's summary is used).

    Returns:
        ``(text, summary_length, content_length)`` where ``content_length`` is
        measured after truncation, matching what is actually injected.
    """
    # ``or ''`` guards against NULL columns: the original called len() on a
    # possibly-None summary, raising TypeError and aborting the whole stream.
    full_content = "\n\n".join((chunk['content'] or '') for chunk in chunks)
    file_summary = chunks[0].get('summary') or ''

    icon = "🖼️" if file_name.lower().endswith(_IMAGE_NAME_SUFFIXES) else "📄"
    if len(full_content) > _MAX_INJECTED_CONTENT_CHARS:
        full_content = full_content[:_MAX_INJECTED_CONTENT_CHARS] + "\n...(内容过长,已截断)"

    parts = [f"{icon} 文件 {idx}: {file_name}\n", f"{_FILE_RULE}\n"]
    if file_summary:
        parts.append(f"【文件摘要】\n{file_summary}\n\n")
    parts.append(f"【完整内容】\n{full_content}\n")
    parts.append(f"{_FILE_RULE}\n\n")
    return "".join(parts), len(file_summary), len(full_content)


@chat_router.post("/chat/completion", summary="聊天接口主入口")
async def chat_completion(
    request: ChatRequest,
    http_request: Request,
    current_user: User = Depends(get_current_user),
    moderation_service=Depends(get_moderation_service)
):
    """
    Main chat endpoint (authentication required).

    Flow: moderate the user text, resolve provider/model, record the thread,
    then stream the agent's output as server-sent events. Thread-file and
    knowledge-base context is gathered inside the stream generator and
    attached to the system prompt, not to the user message.

    Args:
        request: Chat request payload (JSON).
        http_request: Raw HTTP request, used to derive the client IP.
        current_user: The authenticated user.
        moderation_service: Content-moderation service (dependency-injected).

    Returns:
        EventSourceResponse streaming JSON-encoded message chunks followed by
        ``[DONE]``; or a JSONResponse with HTTP 400 when moderation blocks the
        text or the provider/model selection is invalid.
    """
    client_ip = _resolve_client_ip(http_request)

    logger.info(
        f"用户 {current_user.username} (ID: {current_user.id}) 发起聊天请求,thread_id: {request.thread_id}, "
        f"text2img: {request.text2img}, text2video: {request.text2video}, text2poster: {request.text2poster}, "
        f"translate: {request.translate}, knowledge_base_id: {request.knowledge_base_id}, "
        f"knowledge_graph_id: {request.knowledge_graph_id}, "
        f"llm_provider: {request.llm_provider}, llm_model: {request.llm_model}, ip={client_ip}"
    )

    # ============ Content moderation (runs before any AI processing) ============
    try:
        # Unique id so the moderation call can be traced across log lines.
        moderation_request_id = str(uuid.uuid4())

        logger.info(
            f"开始内容审核 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"message_length: {len(request.query)}"
        )

        moderation_result = await moderation_service.moderate_text(
            text=request.query,
            request_id=moderation_request_id
        )
        label_names = [label.label for label in moderation_result.labels]

        logger.info(
            f"审核完成 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}, "
            f"decision: {moderation_result.decision.value}, "
            f"labels: {label_names}"
        )

        # BLOCK: reject outright, using the same envelope as image moderation.
        if moderation_result.decision == ModerationDecision.BLOCK:
            logger.warning(
                f"内容被阻止 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {label_names}"
            )
            return _bad_request("您的消息包含不当内容,无法处理。")

        # REVIEW: allow or reject depending on configuration.
        if moderation_result.decision == ModerationDecision.REVIEW:
            logger.info(
                f"内容需要复审 - user_id: {current_user.id}, "
                f"username: {current_user.username}, "
                f"request_id: {moderation_request_id}, "
                f"labels: {label_names}"
            )

            # Kept function-local as in the original (only needed on this path).
            from core.config import get_settings
            settings_obj = get_settings()

            if settings_obj.moderation_review_action == "block":
                logger.warning(
                    f"复审内容被阻止(配置策略)- user_id: {current_user.id}, "
                    f"request_id: {moderation_request_id}"
                )
                return _bad_request("您的消息需要人工复审,暂时无法处理。")

            logger.info(
                f"复审内容允许通过(配置策略)- user_id: {current_user.id}, "
                f"request_id: {moderation_request_id}"
            )

        # PASS (or REVIEW allowed by policy): continue with the normal flow.
        logger.info(
            f"内容审核通过 - user_id: {current_user.id}, "
            f"request_id: {moderation_request_id}"
        )

    except HTTPException:
        # Let HTTP errors raised by the moderation layer reach the client.
        raise

    except ModerationError as e:
        # Degraded mode: the moderation service failed; log it and let the
        # message through rather than failing the chat (deliberate fail-open).
        logger.error(
            f"审核服务错误(降级模式)- user_id: {current_user.id}, "
            f"error: {e.message}, "
            f"original_error: {str(e.original_error) if e.original_error else None}"
        )

    except Exception as e:
        # Degraded mode for unexpected failures as well (deliberate fail-open).
        logger.error(
            f"审核过程未知错误(降级模式)- user_id: {current_user.id}, "
            f"error_type: {type(e).__name__}, "
            f"error: {str(e)}"
        )
    # ============ End of content moderation ============

    # Resolve provider / model; configuration problems are reported as 400s.
    llm_provider = normalize_provider(request.llm_provider)
    llm_model_key = coerce_model_id(llm_provider, request.llm_model)
    cfg_err = validate_request_can_use_provider(llm_provider)
    if cfg_err:
        return _bad_request(cfg_err)
    try:
        api_model = resolve_to_api_model(llm_provider, llm_model_key)
    except ValueError as e:
        return _bad_request(str(e))

    try:
        pool = await get_db_pool()
        async with pool.acquire() as conn:
            reasoner_row = await conn.fetchrow(
                "SELECT is_reasoner FROM user_list WHERE id = $1",
                current_user.id,
            )
        # Missing row or NULL column both mean "deep thinking off".
        user_is_reasoner = bool(reasoner_row and reasoner_row["is_reasoner"])
        if llm_provider == "deepseek":
            # DeepSeek: the user's is_reasoner flag picks the API model.
            api_model = deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        model = build_chat_model_for_completion(
            llm_provider,
            api_model,
            enable_thinking=user_is_reasoner,
            logical_llm_id=llm_model_key,
        )
        # f-string like every other log call in this file, so the message does
        # not depend on the logger supporting "{}"-style positional formatting.
        logger.debug(
            f"chat_completion 模型: provider={llm_provider} req_llm_model={llm_model_key} "
            f"api_model={api_model} user_is_reasoner={user_is_reasoner}"
        )
    except ValueError as e:
        return _bad_request(str(e))

    # DeepSeek: both persistence and invocation use the API model chosen by
    # is_reasoner; the llm_model sent in the request is ignored.
    thread_llm_model = (
        deepseek_api_model_by_reasoner_setting(user_is_reasoner=user_is_reasoner)
        if llm_provider == "deepseek"
        else llm_model_key
    )

    # Record the thread (knowledge_base_id is passed through as-is, so an
    # absent id is stored as null per the original note).
    await create_or_update_chat_thread(
        thread_id=request.thread_id,
        user_id=current_user.id,
        query=request.query,
        knowledge_base_id=request.knowledge_base_id,
        knowledge_graph_id=request.knowledge_graph_id,
        ip=client_ip,
        llm_provider=llm_provider,
        llm_model=thread_llm_model,
    )

    # NOTE: _create_agent_for_request may raise recursion_limit on this dict.
    config = {"configurable": {"thread_id": request.thread_id}, "recursion_limit": 30}
    checkpointer = await get_checkpointer()

    # Translate / text-to-X modes build their agent up front. Ordinary chat
    # builds it inside generate(), once file and knowledge-base context is known.
    use_early_agent = bool(
        request.translate or request.text2video or request.text2img or request.text2poster
    )
    agent_early = None
    if use_early_agent:
        agent_early = await _create_agent_for_request(
            request=request,
            current_user=current_user,
            model=model,
            checkpointer=checkpointer,
            config=config,
            llm_provider=llm_provider,
            api_model=api_model,
            user_is_reasoner=user_is_reasoner,
        )

    async def generate():
        """Yield SSE payloads: JSON message chunks, then the ``[DONE]`` marker."""
        try:
            pool = await get_db_pool()
            unlinked_files = []
            system_context_sections: list[str] = []
            file_list_for_intent: list = []

            async with pool.acquire() as conn:
                # Files uploaded to this thread in the last 5 minutes that are
                # not yet linked to any message.
                rows = await conn.fetch(
                    """
                    SELECT id FROM chat_thread_file
                    WHERE thread_id = $1
                      AND is_deleted = FALSE
                      AND created_at > NOW() - INTERVAL '5 minutes'
                      AND id NOT IN (
                          SELECT DISTINCT file_id FROM chat_message_file WHERE thread_id = $1
                      )
                    ORDER BY created_at DESC
                    LIMIT 10
                    """,
                    request.thread_id
                )
                unlinked_files = [row['id'] for row in rows]

                if not use_early_agent:
                    # Function-local imports kept from the original.
                    from services.chat_thread_file_service import ChatThreadFileService
                    from services.rag_intent_service import get_rag_intent_service

                    file_summaries_raw = await ChatThreadFileService.get_recent_files_with_summary(
                        conn, request.thread_id, limit=5
                    )
                    # Map each summarized file name back to its row id
                    # (one lookup per file; at most 5 files).
                    for file_info in file_summaries_raw:
                        row = await conn.fetchrow(
                            "SELECT id FROM chat_thread_file WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE",
                            request.thread_id, file_info['file_name']
                        )
                        if row:
                            file_list_for_intent.append({
                                "file_id": row['id'],
                                "file_name": file_info['file_name'],
                                "summary": file_info['summary']
                            })

            # ---- Thread files: decide per file between full content and search hint ----
            if not use_early_agent and file_list_for_intent:
                intent_service = await get_rag_intent_service()
                intents = await intent_service.judge_intent(
                    query=request.query,
                    file_list=file_list_for_intent
                )

                logger.info(f"💡 意图判断结果: {len(intents)} 个文件相关")

                if intents:
                    summary_files = [i for i in intents if i.question_type == "summary"]
                    search_files = [i for i in intents if i.question_type == "search"]

                    if summary_files:
                        logger.info(f"📄 检测到 {len(summary_files)} 个文件需要完整内容")

                        section_parts = [
                            "\n\n" + _SECTION_RULE + "\n",
                            "📎 **已为您准备的文件完整内容**\n",
                            _SECTION_RULE + "\n",
                            "⚠️ 重要:以下是文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n",
                        ]

                        async with pool.acquire() as conn:
                            for idx, intent in enumerate(summary_files, 1):
                                logger.info(f"🔍 正在从数据库获取文件完整内容: file_id={intent.file_id}, file_name={intent.file_name}")

                                all_chunks = await ChatThreadFileService.get_file_chunks_from_db(
                                    conn, intent.file_id
                                )

                                logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                                if not all_chunks:
                                    logger.warning(f"⚠️ 文件 {intent.file_name} (ID: {intent.file_id}) 未返回任何chunks!")
                                    continue

                                logger.info("📝 Chunks详情:")
                                for i, chunk in enumerate(all_chunks[:3]):
                                    logger.info(f"  - Chunk {i}: index={chunk.get('chunk_index')}, 内容长度={len(chunk.get('content') or '')}, 有摘要={bool(chunk.get('summary'))}")
                                if len(all_chunks) > 3:
                                    logger.info(f"  - ... 还有 {len(all_chunks) - 3} 个chunks")

                                block_text, summary_len, content_len = _render_file_block(
                                    idx, intent.file_name, all_chunks
                                )
                                section_parts.append(block_text)
                                logger.info(f"✅ 已注入文件 {intent.file_name} (摘要{summary_len}字 + 内容{content_len}字)")

                        section_parts.append(_SECTION_RULE + "\n")
                        system_context_sections.append("".join(section_parts))

                    elif search_files:
                        logger.info(f"🔍 检测到 {len(search_files)} 个文件需要向量检索")

                        hint_parts = [
                            "\n\n" + _SECTION_RULE + "\n",
                            "📎 **重要提示**\n",
                            _SECTION_RULE + "\n",
                            "⚠️ 用户刚刚上传了以下文件,你的回答必须基于这些文件内容:\n\n",
                        ]
                        hint_parts.extend(
                            f"{idx}. 📄 {intent.file_name}\n"
                            for idx, intent in enumerate(search_files, 1)
                        )
                        hint_parts.append("\n**请务必使用检索工具查询文件内容,不要使用你的训练数据!**\n")
                        hint_parts.append(_SECTION_RULE + "\n")
                        system_context_sections.append("".join(hint_parts))

                    else:
                        # Unknown intent type: fall back to injecting summaries only.
                        logger.warning("⚠️ 意图类型未识别,降级为注入摘要")
                        summary_prefix = "\n\n📎 **文件摘要**:\n"
                        for file_info in file_list_for_intent:
                            summary_prefix += f"【{file_info['file_name']}】\n{file_info['summary']}\n\n"
                        system_context_sections.append(summary_prefix)
                else:
                    logger.info("ℹ️ 未识别到相关文件,不注入内容")

            # ---- Knowledge-base files (skipped when a knowledge graph is bound) ----
            if not use_early_agent and request.knowledge_base_id and not request.knowledge_graph_id:
                from services.knowledge_base_file_service import KnowledgeBaseFileService

                kb_files_raw = []
                async with pool.acquire() as conn:
                    kb_files_raw = await KnowledgeBaseFileService.get_recent_files_with_summary(
                        conn, request.knowledge_base_id, limit=5
                    )

                if kb_files_raw:
                    logger.info(f"📚 知识库 {request.knowledge_base_id} 检测到 {len(kb_files_raw)} 个最近文件")

                    intent_service = await get_rag_intent_service()
                    kb_intents = await intent_service.judge_intent(
                        query=request.query,
                        file_list=kb_files_raw
                    )

                    logger.info(f"💡 知识库意图判断结果: {len(kb_intents)} 个文件相关")

                    if kb_intents:
                        summary_files = [i for i in kb_intents if i.question_type == "summary"]
                        search_files = [i for i in kb_intents if i.question_type == "search"]

                        if summary_files:
                            logger.info(f"📄 知识库检测到 {len(summary_files)} 个文件需要完整内容")

                            kb_parts = [
                                "\n\n" + _SECTION_RULE + "\n",
                                "📚 **知识库文件完整内容**\n",
                                _SECTION_RULE + "\n",
                                "⚠️ 重要:以下是知识库文件的完整关键信息,请直接使用这些信息回答用户问题。\n\n",
                            ]

                            async with pool.acquire() as conn:
                                for idx, intent in enumerate(summary_files, 1):
                                    logger.info(f"🔍 正在从数据库获取知识库文件: file_id={intent.file_id}, file_name={intent.file_name}")

                                    all_chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(
                                        conn, intent.file_id
                                    )

                                    logger.info(f"📊 获取结果: 返回了 {len(all_chunks)} 个chunks")
                                    if not all_chunks:
                                        continue

                                    block_text, summary_len, content_len = _render_file_block(
                                        idx, intent.file_name, all_chunks
                                    )
                                    kb_parts.append(block_text)
                                    logger.info(f"✅ 已注入知识库文件 {intent.file_name} (摘要{summary_len}字 + 内容{content_len}字)")

                            kb_parts.append(_SECTION_RULE + "\n")
                            system_context_sections.append("".join(kb_parts))

                        elif search_files:
                            logger.info(f"🔍 知识库检测到 {len(search_files)} 个文件需要向量检索")

                            kb_hint_parts = [
                                "\n\n" + _SECTION_RULE + "\n",
                                "📚 **知识库检索提示**\n",
                                _SECTION_RULE + "\n",
                                "⚠️ 用户的知识库包含以下文件,你的回答必须基于这些文件内容:\n\n",
                            ]
                            kb_hint_parts.extend(
                                f"{idx}. 📄 {intent.file_name}\n"
                                for idx, intent in enumerate(search_files, 1)
                            )
                            kb_hint_parts.append("\n**请务必使用知识库检索工具查询文件内容,不要使用你的训练数据!**\n")
                            kb_hint_parts.append(_SECTION_RULE + "\n")
                            system_context_sections.append("".join(kb_hint_parts))

            # Sections are stripped before joining, so the joined text has no
            # outer whitespace and "empty" collapses cleanly to None.
            extra_system_context = "\n\n".join(
                s.strip() for s in system_context_sections if s and str(s).strip()
            )
            attached_context = extra_system_context or None

            if use_early_agent:
                active_agent = agent_early
            else:
                active_agent = await _create_agent_for_request(
                    request=request,
                    current_user=current_user,
                    model=model,
                    checkpointer=checkpointer,
                    config=config,
                    llm_provider=llm_provider,
                    api_model=api_model,
                    user_is_reasoner=user_is_reasoner,
                    extra_system_context=attached_context,
                )

            try:
                async for event_data in active_agent.astream(
                    {
                        "messages": [
                            {
                                "role": "user",
                                "content": request.query,
                            }
                        ]
                    },
                    stream_mode="messages",
                    config=config
                ):
                    message, metadata = event_data

                    try:
                        message_dict = message_to_dict(message)
                    except Exception as e:
                        # Fall back to a hand-built dict when the standard
                        # serializer rejects this message object.
                        logger.warning(f"使用 message_to_dict 序列化失败: {e}")
                        message_dict = {
                            "content": message.content if hasattr(message, 'content') else str(message),
                            "type": message.__class__.__name__ if hasattr(message, '__class__') else "unknown"
                        }
                        if hasattr(message, 'additional_kwargs'):
                            message_dict['additional_kwargs'] = message.additional_kwargs or {}
                        if hasattr(message, 'response_metadata'):
                            message_dict['response_metadata'] = message.response_metadata or {}
                        if hasattr(message, 'id'):
                            message_dict['id'] = message.id

                    serializable_data = {
                        "message": message_dict,
                        "metadata": metadata
                    }

                    yield json.dumps(serializable_data, ensure_ascii=False, default=str)
            except IndexError as e:
                # Works around a chunk-alignment problem seen with some
                # LLM + streamed-tool-call combinations (originally kept for
                # the ChatTongyi case); anything else is re-raised.
                if "list index out of range" in str(e):
                    error_msg = "抱歉,处理您的请求时遇到了技术问题(流式工具调用错误)。请稍后重试,或尝试简化您的问题。"
                    logger.error(f"LangChain 流式工具调用 bug: {e}")
                    yield json.dumps({
                        "error": error_msg,
                        "message": {
                            "type": "AIMessageChunk",
                            "data": {
                                "content": error_msg,
                                "additional_kwargs": {},
                                "response_metadata": {}
                            }
                        }
                    }, ensure_ascii=False)
                else:
                    raise

            # Link freshly uploaded files to the user message just processed.
            if unlinked_files:
                await _associate_files_to_message(
                    request.thread_id,
                    unlinked_files,
                    checkpointer,
                    config,
                    pool
                )

            # [V2] Dual-write: also persist the turn into the chat_messages table.
            try:
                await _save_messages_to_chat_messages_table(
                    thread_id=request.thread_id,
                    user_query=request.query,
                    user_message_content=request.query,
                    has_files=len(unlinked_files) > 0,
                    checkpointer=checkpointer,
                    config=config,
                    pool=pool,
                    system_attached_context=attached_context,
                )
            except Exception as save_err:
                # Best-effort: a failed save must not break the main flow.
                logger.error(f"[V2] 保存消息到 chat_messages 表失败: {save_err}")

        except Exception as e:
            logger.exception(f"聊天接口错误: {e}")
            yield json.dumps({"error": str(e)}, ensure_ascii=False)

        # Always terminate the stream so the client can stop listening.
        yield "[DONE]"

    return EventSourceResponse(
        generate(),
        ping_message_factory=lambda: ServerSentEvent(data=f"ping - {datetime.now(timezone.utc)}")
    )
|
||
|
||
|
||
def _wrap_mcp_tool_safe(tool):
|
||
"""
|
||
包装 MCP 工具,确保第三方服务异常(网络不可达、超时等)被捕获并以
|
||
错误字符串返回,而非向上抛出。
|
||
|
||
这样 LangChain Agent 在工具调用失败时仍能写入对应的 ToolMessage,
|
||
保持 checkpoint 中消息历史的完整性,避免下次对话因 tool_calls 缺少
|
||
配对回复而触发 LLM 400 BadRequest 错误。
|
||
"""
|
||
original_arun = tool._arun
|
||
|
||
async def safe_arun(*args, **kwargs):
|
||
try:
|
||
return await original_arun(*args, **kwargs)
|
||
except Exception as exc:
|
||
error_msg = f"工具 [{tool.name}] 调用失败: {exc}"
|
||
logger.warning(f"MCP 工具异常已捕获,返回错误字符串以维持对话完整性: {error_msg}")
|
||
return error_msg
|
||
|
||
tool._arun = safe_arun
|
||
return tool
|
||
|
||
|
||
async def _create_agent_for_request(
    request: ChatRequest,
    current_user: User,
    model,
    checkpointer,
    config: dict,
    llm_provider: str,
    api_model: str,
    user_is_reasoner: bool,
    extra_system_context: Optional[str] = None,
):
    """
    Build the agent that matches the request's mode.

    Translate, text-to-video, text-to-image and text-to-poster requests get a
    small fixed tool set. Every other request gets the general chat agent,
    whose tools depend on thread files, knowledge base / knowledge graph
    bindings, MCP tools and the user's web-search setting.

    Args:
        request: The chat request (mode flags, thread and knowledge ids).
        current_user: The authenticated user.
        model: Chat model instance already built by the caller.
        checkpointer: Checkpointer handed to ``create_agent``.
        config: Run config; ``recursion_limit`` is raised to 60 IN PLACE when
            deep-thinking mode is on (general chat path only), so the caller's
            later use of the same dict sees the new limit.
        llm_provider: Normalized provider name (used for logging here).
        api_model: Concrete API model name (used for logging here).
        user_is_reasoner: Whether the user enabled deep-thinking mode.
        extra_system_context: Optional reference context appended to the
            system prompt (general chat path only).

    Returns:
        The agent returned by ``create_agent``.
    """
    ai_display_name = await EnterpriseService.resolve_ai_display_name(current_user.enterprise_id)

    # Translation mode
    if request.translate:
        logger.info("使用翻译模式")
        language_names = {
            'auto': '自动检测', 'zh': '中文(简体)', 'zh-TW': '中文(繁體)',
            'en': '英文', 'ja': '日文', 'ko': '韩文', 'fr': '法文',
            'de': '德文', 'es': '西班牙文', 'ru': '俄文'
        }
        # Unknown codes fall back to auto-detect (source) / English (target).
        from_lang = language_names.get(request.from_language, '自动检测')
        target_lang = language_names.get(request.target_language, '英文')

        return create_agent(
            model=model,
            tools=[get_current_time],
            system_prompt=get_translate_instructions(from_lang, target_lang, ai_display_name),
            checkpointer=checkpointer
        )

    # Text-to-video mode
    if request.text2video:
        logger.info("使用文生视频模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_video],
            system_prompt=get_text2video_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # Text-to-image mode
    if request.text2img:
        logger.info("使用文生图模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_image],
            system_prompt=get_text2img_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # Creative poster mode
    if request.text2poster:
        logger.info("使用创意海报生成模式")
        return create_agent(
            model=model,
            tools=[get_current_time, text_to_poster],
            system_prompt=get_text2poster_instructions(ai_display_name),
            checkpointer=checkpointer
        )

    # ---- General chat mode ----
    mcp_client = await get_mcp_client()
    mcp_tools = await mcp_client.get_tools()
    # Wrapping with _wrap_mcp_tool_safe was disabled in the original:
    # mcp_tools = [_wrap_mcp_tool_safe(t) for t in mcp_tools]
    logger.info(f"成功加载 {len(mcp_tools)} 个 MCP 工具")

    # is_reasoner was already read by chat_completion and arrives as
    # user_is_reasoner; only the web-search flag is fetched here.
    pool = await get_db_pool()
    async with pool.acquire() as conn:
        user_row = await conn.fetchrow(
            "SELECT is_search FROM user_list WHERE id = $1",
            current_user.id,
        )
    # Missing row or NULL column both mean "web search off".
    user_is_search = bool(user_row and user_row["is_search"])

    use_reasoner_mode = user_is_reasoner

    # What retrieval sources are available for this request?
    has_files = await check_thread_has_files(request.thread_id)
    has_kb_files = False
    kb_id = request.knowledge_base_id
    # A bound knowledge graph takes precedence over the knowledge base.
    if kb_id and not request.knowledge_graph_id:
        has_kb_files = await check_knowledge_base_has_files(kb_id, current_user.id)

    kg_flags = {"has_rag": False, "neo4j_graph_id": None}
    if request.knowledge_graph_id:
        kg_flags = await get_knowledge_graph_tool_flags(current_user, request.knowledge_graph_id)
    has_novel_rag = bool(kg_flags.get("has_rag"))
    has_kg_neo4j = bool(kg_flags.get("neo4j_graph_id"))

    # System prompt: a bound graph may offer body-text RAG and/or Neo4j
    # relation queries.
    has_kg_any = has_novel_rag or has_kg_neo4j
    research_instructions = get_research_instructions(
        has_files=has_files,
        has_kb_files=has_kb_files,
        use_reasoner_mode=use_reasoner_mode,
        has_knowledge_graph=has_kg_any,
        has_knowledge_graph_neo4j=has_kg_neo4j,
        ai_display_name=ai_display_name,
    )

    # Tool list is built by successive prepends, so the final order is:
    # [internet_search] [graph RAG | KB RAG] [Neo4j] [thread-file RAG]
    # get_current_time, MCP tools...
    # NOTE(review): the original comment says Neo4j relation queries should
    # precede body-text RAG, yet the graph RAG tool ends up in front of the
    # Neo4j tool. Order is preserved as-is; confirm which one is intended.
    all_tools = [get_current_time, *mcp_tools]
    if has_files:
        all_tools = [create_rag_retrieve_tool(request.thread_id)] + all_tools
    if request.knowledge_graph_id and has_kg_neo4j:
        all_tools = [
            create_knowledge_graph_neo4j_search_tool(kg_flags["neo4j_graph_id"])
        ] + all_tools
    if request.knowledge_graph_id and has_novel_rag:
        all_tools = [create_knowledge_graph_rag_retrieve_tool(request.knowledge_graph_id)] + all_tools
    elif kb_id and has_kb_files:
        all_tools = [create_kb_rag_retrieve_tool(kb_id)] + all_tools
    if user_is_search:
        all_tools = [internet_search] + all_tools
        logger.info(f"用户 {current_user.username} 启用了联网搜索功能")

    # Deep-thinking mode: raise the recursion limit. The underlying model was
    # already configured for is_reasoner by the caller.
    if use_reasoner_mode:
        config["recursion_limit"] = 60
        logger.info(
            f"用户 {current_user.username} 启用深度思考模式(提供方={llm_provider}, api_model={api_model})"
        )
        # Was a bare print(); routed through the logger at debug level instead.
        logger.debug(f"model: {model}")

    merged_system_prompt = research_instructions.rstrip()
    if extra_system_context and extra_system_context.strip():
        merged_system_prompt = (
            merged_system_prompt
            + "\n\n【本轮系统提供的参考上下文(用户原话仅在用户消息中;请优先依据此处与工具检索结果作答)】\n"
            + extra_system_context.strip()
        )

    return create_agent(
        model=model,
        tools=all_tools,
        system_prompt=merged_system_prompt,
        checkpointer=checkpointer,
    )
|
||
|
||
|
||
async def _associate_files_to_message(
|
||
thread_id: str,
|
||
unlinked_files: list,
|
||
checkpointer,
|
||
config: dict,
|
||
pool
|
||
):
|
||
"""
|
||
将未关联的文件关联到用户消息
|
||
"""
|
||
import asyncio
|
||
|
||
latest_checkpoint = None
|
||
for attempt in range(5):
|
||
await asyncio.sleep(0.2)
|
||
try:
|
||
latest_checkpoint = await checkpointer.aget_tuple(config)
|
||
if latest_checkpoint:
|
||
break
|
||
except Exception as e:
|
||
logger.debug(f"获取 checkpoint 失败(尝试 {attempt + 1}/5): {e}")
|
||
|
||
if not latest_checkpoint:
|
||
logger.warning("无法获取最新的 checkpoint,文件关联失败")
|
||
return
|
||
|
||
try:
|
||
checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]
|
||
checkpoint_data = latest_checkpoint.checkpoint
|
||
|
||
if "channel_values" not in checkpoint_data or "messages" not in checkpoint_data["channel_values"]:
|
||
return
|
||
|
||
messages = checkpoint_data["channel_values"]["messages"]
|
||
user_message_index = None
|
||
|
||
for idx in range(len(messages) - 1, -1, -1):
|
||
msg = messages[idx]
|
||
if hasattr(msg, 'type') and msg.type == "human":
|
||
user_message_index = idx
|
||
break
|
||
|
||
if user_message_index is not None:
|
||
async with pool.acquire() as conn:
|
||
for file_id in unlinked_files:
|
||
try:
|
||
await ChatMessageFileService.create_message_file_association(
|
||
conn,
|
||
thread_id,
|
||
checkpoint_id,
|
||
user_message_index,
|
||
file_id
|
||
)
|
||
logger.info(f"文件 {file_id} 已关联到消息")
|
||
except Exception as e:
|
||
logger.warning(f"关联文件到消息失败: {e}")
|
||
except Exception as e:
|
||
logger.warning(f"关联文件到消息失败: {e}")
|
||
|
||
|
||
async def _save_messages_to_chat_messages_table(
|
||
thread_id: str,
|
||
user_query: str,
|
||
user_message_content: str,
|
||
has_files: bool,
|
||
checkpointer,
|
||
config: dict,
|
||
pool,
|
||
system_attached_context: Optional[str] = None,
|
||
):
|
||
"""
|
||
将消息保存到 chat_messages 表(双写策略)
|
||
|
||
这是V2版本的核心:将用户原始问题和AI响应保存到独立的 chat_messages 表,
|
||
便于快速查询和分析,避免解析 checkpoint JSONB。
|
||
|
||
Args:
|
||
thread_id: 会话线程 ID
|
||
user_query: 用户原始问题
|
||
user_message_content: 用户消息内容(与 checkpoint 中 human 一致;可与 user_query 相同)
|
||
system_attached_context: 附在系统提示上的参考上下文(不入用户原话),写入 injected_content
|
||
has_files: 是否关联了文件
|
||
checkpointer: checkpoint 管理器
|
||
config: 配置
|
||
pool: 数据库连接池
|
||
"""
|
||
import asyncio
|
||
|
||
try:
|
||
# 等待 checkpoint 更新
|
||
latest_checkpoint = None
|
||
for attempt in range(5):
|
||
await asyncio.sleep(0.2)
|
||
try:
|
||
latest_checkpoint = await checkpointer.aget_tuple(config)
|
||
if latest_checkpoint:
|
||
break
|
||
except Exception as e:
|
||
logger.debug(f"获取 checkpoint 失败(尝试 {attempt + 1}/5): {e}")
|
||
|
||
if not latest_checkpoint:
|
||
logger.warning("无法获取最新的 checkpoint,消息保存失败")
|
||
return
|
||
|
||
checkpoint_id = latest_checkpoint.config["configurable"]["checkpoint_id"]
|
||
checkpoint_data = latest_checkpoint.checkpoint
|
||
|
||
if "channel_values" not in checkpoint_data or "messages" not in checkpoint_data["channel_values"]:
|
||
return
|
||
|
||
messages = checkpoint_data["channel_values"]["messages"]
|
||
|
||
if len(messages) < 2:
|
||
logger.warning(f"[V2] 消息数量不足,跳过保存: messages_count={len(messages)}")
|
||
return
|
||
|
||
# 🔥 核心修复:找到最后一个用户消息,保存从该消息开始的所有后续消息
|
||
# 这样可以包含完整的对话轮次:用户消息 → AI调用工具 → 工具结果 → AI最终回复
|
||
last_user_index = None
|
||
for idx in range(len(messages) - 1, -1, -1):
|
||
msg = messages[idx]
|
||
if hasattr(msg, 'type') and msg.type == "human":
|
||
last_user_index = idx
|
||
break
|
||
|
||
if last_user_index is None:
|
||
logger.warning(f"[V2] 未找到用户消息,跳过保存")
|
||
return
|
||
|
||
# 获取从最后一个用户消息开始的所有消息
|
||
messages_to_save = messages[last_user_index:]
|
||
|
||
async with pool.acquire() as conn:
|
||
# 检查是否已经保存过(通过检查最后一条 AI 消息)
|
||
last_ai_index = len(messages) - 1
|
||
existing = await conn.fetchval(
|
||
"""
|
||
SELECT id FROM chat_messages
|
||
WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3
|
||
""",
|
||
thread_id, checkpoint_id, last_ai_index
|
||
)
|
||
|
||
if existing:
|
||
logger.debug(f"[V2] 消息已存在,跳过保存: thread_id={thread_id}, checkpoint_id={checkpoint_id}")
|
||
return
|
||
|
||
# 保存当前轮次的所有消息
|
||
for relative_idx, msg in enumerate(messages_to_save):
|
||
actual_index = last_user_index + relative_idx
|
||
|
||
if not hasattr(msg, 'type'):
|
||
continue
|
||
|
||
msg_type = msg.type
|
||
msg_content = getattr(msg, 'content', '') or ''
|
||
|
||
# 检查单条消息是否已存在
|
||
existing = await conn.fetchval(
|
||
"SELECT id FROM chat_messages WHERE thread_id = $1 AND checkpoint_id = $2 AND message_index = $3",
|
||
thread_id, checkpoint_id, actual_index
|
||
)
|
||
|
||
if existing:
|
||
logger.debug(f"[V2] 消息已存在: index={actual_index}, type={msg_type}")
|
||
continue
|
||
|
||
if msg_type == "human":
|
||
# 保存用户消息(只在第一条消息时使用传入的参数)
|
||
if relative_idx == 0:
|
||
# 当前轮用户消息:正文为原始问题;系统附加材料单独落库
|
||
if system_attached_context and system_attached_context.strip():
|
||
injected_content = system_attached_context.strip()
|
||
else:
|
||
injected_content = (
|
||
user_message_content if user_message_content != user_query else None
|
||
)
|
||
await ChatMessageService.save_user_message(
|
||
conn,
|
||
thread_id=thread_id,
|
||
checkpoint_id=checkpoint_id,
|
||
message_index=actual_index,
|
||
content=user_query, # 用户原始问题
|
||
injected_content=injected_content, # 注入的完整内容(如果有)
|
||
has_files=has_files,
|
||
metadata={}
|
||
)
|
||
logger.info(f"✅ [V2] 保存用户消息: index={actual_index}, content='{user_query[:50]}...', has_files={has_files}")
|
||
else:
|
||
# 理论上不应该出现第二条是 human 的情况
|
||
await ChatMessageService.save_user_message(
|
||
conn,
|
||
thread_id=thread_id,
|
||
checkpoint_id=checkpoint_id,
|
||
message_index=actual_index,
|
||
content=msg_content,
|
||
injected_content=None,
|
||
has_files=False,
|
||
metadata={}
|
||
)
|
||
logger.warning(f"⚠️ [V2] 非预期的用户消息位置: index={actual_index}")
|
||
|
||
elif msg_type == "ai":
|
||
# 保存 AI 消息
|
||
reasoning_content = ""
|
||
if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs:
|
||
reasoning_content = msg.additional_kwargs.get("reasoning_content", "") or ""
|
||
|
||
token_usage = {}
|
||
finish_reason = ""
|
||
if hasattr(msg, 'response_metadata') and msg.response_metadata:
|
||
token_usage = msg.response_metadata.get('token_usage', {})
|
||
finish_reason = msg.response_metadata.get('finish_reason', '')
|
||
|
||
await ChatMessageService.save_assistant_message(
|
||
conn,
|
||
thread_id=thread_id,
|
||
checkpoint_id=checkpoint_id,
|
||
message_index=actual_index,
|
||
content=msg_content,
|
||
metadata={
|
||
'token_usage': token_usage,
|
||
'finish_reason': finish_reason,
|
||
'reasoning_content': reasoning_content
|
||
}
|
||
)
|
||
logger.info(f"✅ [V2] 保存AI消息: index={actual_index}, content_length={len(msg_content)}")
|
||
|
||
elif msg_type == "tool":
|
||
# 保存工具消息
|
||
tool_name = getattr(msg, 'name', '') or ''
|
||
await ChatMessageService.save_tool_message(
|
||
conn,
|
||
thread_id=thread_id,
|
||
checkpoint_id=checkpoint_id,
|
||
message_index=actual_index,
|
||
content=msg_content,
|
||
name=tool_name, # 工具名称(如 text_to_poster, internet_search 等)
|
||
metadata={'tool_name': tool_name} # 保留在 metadata 中以兼容
|
||
)
|
||
logger.info(f"✅ [V2] 保存工具消息: index={actual_index}, tool={tool_name}")
|
||
|
||
logger.info(f"✅ [V2] 消息保存完成: thread_id={thread_id}, checkpoint_id={checkpoint_id}, saved_count={len(messages_to_save)}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"[V2] 保存消息到 chat_messages 表失败: {e}")
|
||
# 不抛出异常,避免影响主流程
|
||
|
||
|
||
@chat_router.get("/chat/threads", summary="获取会话列表", response_model=ChatThreadListResponse)
async def get_threads(
    page: int = 1,
    page_size: int = 20,
    current_user: User = Depends(get_current_user)
):
    """Return one page of the current user's chat threads.

    Args:
        page: Page number; must be >= 1, otherwise HTTP 400 is raised.
        page_size: Threads per page; must be within 1-100, otherwise HTTP 400 is raised.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        The result of ``get_user_chat_threads`` for this user and page
        (declared response model: ``ChatThreadListResponse``).

    Raises:
        HTTPException: 400 for invalid paging arguments. An ``HTTPException``
            raised by the service call is passed through unchanged; any other
            failure is logged and reported as 500.
    """
    logger.info(f"用户 {current_user.username} 查询会话列表: page={page}")

    # Validate paging arguments before touching the service layer.
    if page < 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="页码必须大于 0")
    if page_size < 1 or page_size > 100:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="每页数量必须在 1-100 之间")

    try:
        return await get_user_chat_threads(user_id=current_user.id, page=page, page_size=page_size)
    except HTTPException:
        # A deliberate HTTP error from the service must keep its own status
        # code instead of being rewritten as a generic 500 below (same
        # handling as delete_thread).
        raise
    except Exception as e:
        logger.error(f"查询会话列表失败: {e}")
        # Chain the cause so the original failure stays visible in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="查询会话列表失败",
        ) from e
|
||
|
||
|
||
@chat_router.get("/chat/thread/{thread_id}", summary="获取会话明细", response_model=ChatThreadDetailResponse)
async def get_thread_detail(
    thread_id: str,
    current_user: User = Depends(get_current_user),
):
    """Return the chat detail of one thread for the current user.

    Args:
        thread_id: Identifier of the thread, taken from the URL path.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        The result of ``get_chat_thread_detail`` for this thread and user
        (declared response model: ``ChatThreadDetailResponse``).
    """
    logger.info(f"用户 {current_user.username} 查询会话明细: thread_id={thread_id}")

    # No local error handling: whatever the service raises propagates as-is.
    detail = await get_chat_thread_detail(
        thread_id=thread_id,
        user_id=current_user.id,
    )
    return detail
|
||
|
||
|
||
@chat_router.delete("/chat/thread", summary="删除会话")
async def delete_thread(
    request: DeleteThreadRequest,
    current_user: User = Depends(get_current_user),
):
    """Delete a chat thread on behalf of the current user.

    The original author describes this as a soft delete; the actual deletion
    semantics live in ``delete_chat_thread`` and are not visible here.

    Args:
        request: Request body carrying the ``thread_id`` to delete.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        A ``BaseResponse`` with code 200 echoing the deleted ``thread_id``.

    Raises:
        HTTPException: Passed through unchanged when raised inside the try
            block; any other failure is logged and reported as 500.
    """
    logger.info(f"用户 {current_user.username} 请求删除会话: {request.thread_id}")

    try:
        await delete_chat_thread(thread_id=request.thread_id, user_id=current_user.id)
        payload = {"thread_id": request.thread_id}
        return BaseResponse(code=200, msg="会话删除成功", data=payload)
    except Exception as e:
        # Deliberate HTTP errors keep their own status code.
        if isinstance(e, HTTPException):
            raise
        logger.error(f"删除会话失败: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除会话失败")
|
||
|
||
|
||
# ============ V2 版本路由(基于 chat_messages 表) ============
|
||
|
||
@chat_router.get("/chat/thread/{thread_id}/v2", summary="获取会话明细 V2(推荐)", response_model=ChatThreadDetailResponse)
async def get_thread_detail_v2(
    thread_id: str,
    current_user: User = Depends(get_current_user),
):
    """
    Return the chat detail of one thread (V2).

    **V2 advantages** (as stated by the original author):
    - Faster queries (reads the chat_messages table directly, without parsing checkpoint JSONB)
    - The user's original question is kept separate from injected content (the returned content holds only the original question)
    - Supports full-text search and statistical analysis
    - Clearer data structure

    **Note**: message dual-write must be enabled before this endpoint can be used.
    """
    # NOTE(review): the advantages listed above depend on
    # get_chat_thread_detail_v2, which is defined elsewhere — confirm there.
    logger.info(f"用户 {current_user.username} 查询会话明细 V2: thread_id={thread_id}")

    detail = await get_chat_thread_detail_v2(
        thread_id=thread_id,
        user_id=current_user.id,
    )
    return detail
|
||
|
||
|
||
# ============ 兼容旧路由(保持向后兼容) ============
|
||
# 这些路由已迁移到 /api/user/ 前缀下,但为了兼容旧客户端保留
|
||
|
||
from services.user_setting_service import UserSettingService
|
||
from models.chat import (
|
||
SearchSettingResponse,
|
||
UpdateSearchSettingRequest,
|
||
ReasonerSettingResponse,
|
||
UpdateReasonerSettingRequest,
|
||
)
|
||
|
||
|
||
@chat_router.get("/chat/search-setting", summary="获取用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def get_search_setting_compat(current_user: User = Depends(get_current_user)):
    """Return the current user's web-search setting.

    Deprecated compatibility route; per the original author it has moved to
    /api/user/search-setting.
    """
    return SearchSettingResponse(
        is_search=await UserSettingService.get_search_setting(current_user.id)
    )
|
||
|
||
|
||
@chat_router.put("/chat/search-setting", summary="更新用户联网搜索设置(兼容)", response_model=SearchSettingResponse, deprecated=True)
async def update_search_setting_compat(
    request: UpdateSearchSettingRequest,
    current_user: User = Depends(get_current_user),
):
    """Update the current user's web-search setting.

    Deprecated compatibility route; per the original author it has moved to
    /api/user/search-setting.

    Args:
        request: Request body carrying the new ``is_search`` value.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        ``SearchSettingResponse`` built from the value returned by
        ``UserSettingService.update_search_setting``.
    """
    stored_value = await UserSettingService.update_search_setting(
        current_user.id,
        request.is_search,
    )
    return SearchSettingResponse(is_search=stored_value)
|
||
|
||
|
||
@chat_router.get("/chat/reasoner-setting", summary="获取用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def get_reasoner_setting_compat(current_user: User = Depends(get_current_user)):
    """Return the current user's deep-thinking (reasoner) setting.

    Deprecated compatibility route; per the original author it has moved to
    /api/user/reasoner-setting.
    """
    return ReasonerSettingResponse(
        is_reasoner=await UserSettingService.get_reasoner_setting(current_user.id)
    )
|
||
|
||
|
||
@chat_router.put("/chat/reasoner-setting", summary="更新用户深度思考设置(兼容)", response_model=ReasonerSettingResponse, deprecated=True)
async def update_reasoner_setting_compat(
    request: UpdateReasonerSettingRequest,
    current_user: User = Depends(get_current_user),
):
    """Update the current user's deep-thinking (reasoner) setting.

    Deprecated compatibility route; per the original author it has moved to
    /api/user/reasoner-setting.

    Args:
        request: Request body carrying the new ``is_reasoner`` value.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        ``ReasonerSettingResponse`` built from the value returned by
        ``UserSettingService.update_reasoner_setting``.
    """
    stored_value = await UserSettingService.update_reasoner_setting(
        current_user.id,
        request.is_reasoner,
    )
    return ReasonerSettingResponse(is_reasoner=stored_value)
|