huoyan-enterprise/backend/api/chat_title.py

248 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天标题相关 API 路由模块
处理会话标题的生成和重命名功能。
"""
from fastapi import APIRouter, Depends, HTTPException, status
from utils.helpers import BaseResponse
from logger.logging import get_logger
from core.dependencies import get_current_user
from models.user import User
from core.database import get_db_pool
from models.chat import (
GenerateTitleRequest,
GenerateTitleResponse,
RenameThreadRequest
)
from core.llm_catalog import build_chat_model
# 获取日志记录器
logger = get_logger(__name__)
# 创建路由实例
chat_title_router = APIRouter(prefix="/api", tags=["聊天标题接口"])
@chat_title_router.put("/chat/thread/{thread_id}/rename", summary="重命名会话", response_model=BaseResponse)
async def rename_thread(
thread_id: str,
request: RenameThreadRequest,
current_user: User = Depends(get_current_user)
):
"""
重命名聊天会话
Args:
thread_id: 会话线程 ID路径参数
request: 重命名请求数据(包含新标题)
current_user: 当前登录用户
Returns:
BaseResponse: 重命名结果
Raises:
HTTPException: 会话不存在、无权限或会话已删除
"""
logger.info(f"用户 {current_user.username} (ID: {current_user.id}) 请求重命名会话: {thread_id}, 新标题: {request.title}")
pool = await get_db_pool()
async with pool.acquire() as conn:
# 检查会话是否存在且属于该用户
thread_info = await conn.fetchrow(
"""
SELECT id, user_id, is_deleted
FROM chat_threads
WHERE thread_id = $1
""",
thread_id
)
if not thread_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="会话不存在"
)
if thread_info['user_id'] != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权限访问该会话"
)
if thread_info['is_deleted']:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="会话已被删除"
)
# 更新标题
await conn.execute(
"""
UPDATE chat_threads
SET title = $1,
updated_at = CURRENT_TIMESTAMP
WHERE thread_id = $2
""",
request.title,
thread_id
)
logger.info(f"成功重命名会话: thread_id={thread_id}, 新标题='{request.title}'")
return BaseResponse(
code=200,
msg="重命名成功",
data={"thread_id": thread_id, "title": request.title}
)
@chat_title_router.post("/chat/generate-title", summary="生成会话标题", response_model=GenerateTitleResponse)
async def generate_title(
request: GenerateTitleRequest,
current_user: User = Depends(get_current_user)
):
"""
根据用户的查询内容生成简洁的会话标题
Args:
request: 生成标题请求数据(包含 thread_id 和用户查询内容)
current_user: 当前登录用户
Returns:
GenerateTitleResponse: 生成的标题
Raises:
HTTPException: 会话不存在、无权限或会话已删除
"""
logger.info(f"用户 {current_user.username} (ID: {current_user.id}) 请求生成标题thread_id: {request.thread_id}, query: {request.query[:50]}...")
pool = await get_db_pool()
async with pool.acquire() as conn:
# 检查会话是否存在且属于该用户
thread_info = await conn.fetchrow(
"""
SELECT id, user_id, is_deleted
FROM chat_threads
WHERE thread_id = $1
""",
request.thread_id
)
if not thread_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="会话不存在"
)
if thread_info['user_id'] != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权限访问该会话"
)
if thread_info['is_deleted']:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="会话已被删除"
)
try:
# 标题生成默认走 DeepSeek短文本、低温度
from langchain_deepseek import ChatDeepSeek
import os
model = ChatDeepSeek(
model="deepseek-chat",
api_key=os.getenv("DEEPSEEK_API_KEY"),
base_url=os.getenv("DEEPSEEK_BASE_URL"),
streaming=False,
temperature=0.1,
)
# 创建专门用于生成标题的 system prompt
system_message = """你是一个专业的标题生成助手。你的任务是根据用户的问题生成一个简洁的标题。
严格要求:
1. 只返回标题文本,不要有任何其他内容
2. 标题长度2-10个汉字
3. 不要包含标点符号
4. 不要有引号、冒号等任何符号
5. 直接返回标题,不要解释
示例:
用户:"今天苏州天气怎么样啊" -> 苏州天气
用户:"请帮我写一个Python爬虫" -> Python爬虫
用户:"如何学习机器学习" -> 机器学习入门"""
# 构造消息
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": f"请为以下问题生成标题:{request.query}"}
]
# 调用模型生成标题(非流式)
response = await model.ainvoke(messages)
# 提取生成的标题
title = response.content.strip()
# 清理标题:移除可能的引号、标点符号等
title = title.strip('"\'""''「」『』【】《》::。,,、!!?')
# 如果生成失败或标题为空,使用默认逻辑
if not title or len(title) < 2:
title = request.query[:10] if len(request.query) <= 10 else request.query[:10]
logger.warning(f"AI 生成标题失败或过短,使用默认逻辑: {title}")
# 确保标题长度合理(最多 20 个字符)
if len(title) > 20:
title = title[:20]
# 更新数据库中的标题
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE chat_threads
SET title = $1,
updated_at = CURRENT_TIMESTAMP
WHERE thread_id = $2
""",
title,
request.thread_id
)
logger.info(f"成功生成并更新标题: '{title}', thread_id: {request.thread_id}")
return GenerateTitleResponse(
title=title,
original_query=request.query
)
except HTTPException:
# 重新抛出 HTTP 异常
raise
except Exception as e:
logger.error(f"生成标题失败: {e}")
# 降级处理:使用简单的截取逻辑
fallback_title = request.query[:10] if len(request.query) <= 10 else request.query[:10]
logger.info(f"使用降级标题: {fallback_title}")
# 更新数据库中的标题
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE chat_threads
SET title = $1,
updated_at = CURRENT_TIMESTAMP
WHERE thread_id = $2
""",
fallback_title,
request.thread_id
)
return GenerateTitleResponse(
title=fallback_title,
original_query=request.query
)