248 lines
7.9 KiB
Python
248 lines
7.9 KiB
Python
"""
|
||
聊天标题相关 API 路由模块
|
||
|
||
处理会话标题的生成和重命名功能。
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
from core.dependencies import get_current_user
|
||
from models.user import User
|
||
from core.database import get_db_pool
|
||
from models.chat import (
|
||
GenerateTitleRequest,
|
||
GenerateTitleResponse,
|
||
RenameThreadRequest
|
||
)
|
||
from core.llm_catalog import build_chat_model
|
||
|
||
# 获取日志记录器
|
||
logger = get_logger(__name__)
|
||
|
||
# 创建路由实例
|
||
chat_title_router = APIRouter(prefix="/api", tags=["聊天标题接口"])
|
||
|
||
|
||
@chat_title_router.put("/chat/thread/{thread_id}/rename", summary="重命名会话", response_model=BaseResponse)
|
||
async def rename_thread(
|
||
thread_id: str,
|
||
request: RenameThreadRequest,
|
||
current_user: User = Depends(get_current_user)
|
||
):
|
||
"""
|
||
重命名聊天会话
|
||
|
||
Args:
|
||
thread_id: 会话线程 ID(路径参数)
|
||
request: 重命名请求数据(包含新标题)
|
||
current_user: 当前登录用户
|
||
|
||
Returns:
|
||
BaseResponse: 重命名结果
|
||
|
||
Raises:
|
||
HTTPException: 会话不存在、无权限或会话已删除
|
||
"""
|
||
logger.info(f"用户 {current_user.username} (ID: {current_user.id}) 请求重命名会话: {thread_id}, 新标题: {request.title}")
|
||
|
||
pool = await get_db_pool()
|
||
async with pool.acquire() as conn:
|
||
# 检查会话是否存在且属于该用户
|
||
thread_info = await conn.fetchrow(
|
||
"""
|
||
SELECT id, user_id, is_deleted
|
||
FROM chat_threads
|
||
WHERE thread_id = $1
|
||
""",
|
||
thread_id
|
||
)
|
||
|
||
if not thread_info:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="会话不存在"
|
||
)
|
||
|
||
if thread_info['user_id'] != current_user.id:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="无权限访问该会话"
|
||
)
|
||
|
||
if thread_info['is_deleted']:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="会话已被删除"
|
||
)
|
||
|
||
# 更新标题
|
||
await conn.execute(
|
||
"""
|
||
UPDATE chat_threads
|
||
SET title = $1,
|
||
updated_at = CURRENT_TIMESTAMP
|
||
WHERE thread_id = $2
|
||
""",
|
||
request.title,
|
||
thread_id
|
||
)
|
||
|
||
logger.info(f"成功重命名会话: thread_id={thread_id}, 新标题='{request.title}'")
|
||
|
||
return BaseResponse(
|
||
code=200,
|
||
msg="重命名成功",
|
||
data={"thread_id": thread_id, "title": request.title}
|
||
)
|
||
|
||
|
||
@chat_title_router.post("/chat/generate-title", summary="生成会话标题", response_model=GenerateTitleResponse)
|
||
async def generate_title(
|
||
request: GenerateTitleRequest,
|
||
current_user: User = Depends(get_current_user)
|
||
):
|
||
"""
|
||
根据用户的查询内容生成简洁的会话标题
|
||
|
||
Args:
|
||
request: 生成标题请求数据(包含 thread_id 和用户查询内容)
|
||
current_user: 当前登录用户
|
||
|
||
Returns:
|
||
GenerateTitleResponse: 生成的标题
|
||
|
||
Raises:
|
||
HTTPException: 会话不存在、无权限或会话已删除
|
||
"""
|
||
logger.info(f"用户 {current_user.username} (ID: {current_user.id}) 请求生成标题,thread_id: {request.thread_id}, query: {request.query[:50]}...")
|
||
|
||
pool = await get_db_pool()
|
||
async with pool.acquire() as conn:
|
||
# 检查会话是否存在且属于该用户
|
||
thread_info = await conn.fetchrow(
|
||
"""
|
||
SELECT id, user_id, is_deleted
|
||
FROM chat_threads
|
||
WHERE thread_id = $1
|
||
""",
|
||
request.thread_id
|
||
)
|
||
|
||
if not thread_info:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="会话不存在"
|
||
)
|
||
|
||
if thread_info['user_id'] != current_user.id:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="无权限访问该会话"
|
||
)
|
||
|
||
if thread_info['is_deleted']:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="会话已被删除"
|
||
)
|
||
|
||
try:
|
||
# 标题生成默认走 DeepSeek(短文本、低温度)
|
||
from langchain_deepseek import ChatDeepSeek
|
||
import os
|
||
model = ChatDeepSeek(
|
||
model="deepseek-chat",
|
||
api_key=os.getenv("DEEPSEEK_API_KEY"),
|
||
base_url=os.getenv("DEEPSEEK_BASE_URL"),
|
||
streaming=False,
|
||
temperature=0.1,
|
||
)
|
||
|
||
# 创建专门用于生成标题的 system prompt
|
||
system_message = """你是一个专业的标题生成助手。你的任务是根据用户的问题生成一个简洁的标题。
|
||
|
||
严格要求:
|
||
1. 只返回标题文本,不要有任何其他内容
|
||
2. 标题长度:2-10个汉字
|
||
3. 不要包含标点符号
|
||
4. 不要有引号、冒号等任何符号
|
||
5. 直接返回标题,不要解释
|
||
|
||
示例:
|
||
用户:"今天苏州天气怎么样啊" -> 苏州天气
|
||
用户:"请帮我写一个Python爬虫" -> Python爬虫
|
||
用户:"如何学习机器学习" -> 机器学习入门"""
|
||
|
||
# 构造消息
|
||
messages = [
|
||
{"role": "system", "content": system_message},
|
||
{"role": "user", "content": f"请为以下问题生成标题:{request.query}"}
|
||
]
|
||
|
||
# 调用模型生成标题(非流式)
|
||
response = await model.ainvoke(messages)
|
||
|
||
# 提取生成的标题
|
||
title = response.content.strip()
|
||
|
||
# 清理标题:移除可能的引号、标点符号等
|
||
title = title.strip('"\'""''「」『』【】《》::。,,、!!??')
|
||
|
||
# 如果生成失败或标题为空,使用默认逻辑
|
||
if not title or len(title) < 2:
|
||
title = request.query[:10] if len(request.query) <= 10 else request.query[:10]
|
||
logger.warning(f"AI 生成标题失败或过短,使用默认逻辑: {title}")
|
||
|
||
# 确保标题长度合理(最多 20 个字符)
|
||
if len(title) > 20:
|
||
title = title[:20]
|
||
|
||
# 更新数据库中的标题
|
||
async with pool.acquire() as conn:
|
||
await conn.execute(
|
||
"""
|
||
UPDATE chat_threads
|
||
SET title = $1,
|
||
updated_at = CURRENT_TIMESTAMP
|
||
WHERE thread_id = $2
|
||
""",
|
||
title,
|
||
request.thread_id
|
||
)
|
||
|
||
logger.info(f"成功生成并更新标题: '{title}', thread_id: {request.thread_id}")
|
||
|
||
return GenerateTitleResponse(
|
||
title=title,
|
||
original_query=request.query
|
||
)
|
||
|
||
except HTTPException:
|
||
# 重新抛出 HTTP 异常
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"生成标题失败: {e}")
|
||
# 降级处理:使用简单的截取逻辑
|
||
fallback_title = request.query[:10] if len(request.query) <= 10 else request.query[:10]
|
||
logger.info(f"使用降级标题: {fallback_title}")
|
||
|
||
# 更新数据库中的标题
|
||
async with pool.acquire() as conn:
|
||
await conn.execute(
|
||
"""
|
||
UPDATE chat_threads
|
||
SET title = $1,
|
||
updated_at = CURRENT_TIMESTAMP
|
||
WHERE thread_id = $2
|
||
""",
|
||
fallback_title,
|
||
request.thread_id
|
||
)
|
||
|
||
return GenerateTitleResponse(
|
||
title=fallback_title,
|
||
original_query=request.query
|
||
)
|
||
|