huoyan-enterprise/backend/services/chat_thread_file_service.py

526 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天对话文件服务
"""
import os
import json
from typing import Optional, List, Tuple
from pathlib import Path
import asyncpg
from datetime import datetime
from models.chat_thread_file import ChatThreadFile, ChatThreadChunk
from logger.logging import get_logger
logger = get_logger(__name__)
class ChatThreadFileService:
"""聊天对话文件服务类"""
@staticmethod
async def create_file_record(
    conn: asyncpg.Connection,
    thread_id: str,
    user_id: int,
    file_name: str,
    file_path: str,
    file_size: int,
    file_type: str = "pdf"
) -> ChatThreadFile:
    """
    Insert a new file row for a chat thread with status 'processing'.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread the file belongs to.
        user_id: ID of the user who owns the file.
        file_name: File name; must be unique among the thread's non-deleted files.
        file_path: Storage path of the file.
        file_size: File size (presumably in bytes; confirm against the caller).
        file_type: File type, defaults to "pdf".

    Returns:
        ChatThreadFile: The record built from the inserted row.

    Raises:
        ValueError: A non-deleted file with the same name already exists in
            the thread (propagated unchanged).
        Exception: Any other failure, wrapped with the original error set as
            its cause.
    """
    try:
        # Reject duplicate names within the same thread; soft-deleted rows
        # are ignored by the is_deleted filter.
        existing = await conn.fetchrow(
            """
SELECT id FROM chat_thread_file
WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE
""",
            thread_id, file_name
        )
        if existing:
            raise ValueError(f"文件 '{file_name}' 已存在于该对话中")

        # Insert the row and read back the full record in one statement.
        row = await conn.fetchrow(
            """
INSERT INTO chat_thread_file
(thread_id, user_id, file_name, file_path, file_size, file_type, status)
VALUES ($1, $2, $3, $4, $5, $6, 'processing')
RETURNING id, thread_id, user_id, file_name, file_path, file_size,
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
""",
            thread_id, user_id, file_name, file_path, file_size, file_type
        )
        logger.info(f"创建文件记录: {file_name}, thread_id: {thread_id}")
        return ChatThreadFile(**dict(row))
    except ValueError:
        # Let ValueError (e.g. the duplicate-name error above) reach the
        # caller unchanged.
        raise
    except Exception as e:
        logger.error(f"创建文件记录失败: {e}")
        # Chain explicitly so the underlying database error stays attached.
        raise Exception(f"创建文件记录失败: {str(e)}") from e
@staticmethod
async def update_file_status(
    conn: asyncpg.Connection,
    file_id: int,
    status: str,
    chunk_count: int = 0
) -> bool:
    """
    Write a new status and chunk count onto a file row.

    Args:
        conn: Database connection.
        file_id: ID of the file row to update.
        status: New status value (processing/completed/failed).
        chunk_count: Number of chunks to record, defaults to 0.

    Returns:
        bool: True only when the statement reports exactly one updated row;
            False otherwise, including when the statement raises (the error
            is logged, not propagated).
    """
    update_sql = """
UPDATE chat_thread_file
SET status = $1, chunk_count = $2
WHERE id = $3
"""
    try:
        command_status = await conn.execute(
            update_sql, status, chunk_count, file_id
        )
    except Exception as e:
        logger.error(f"更新文件状态失败: {e}")
        return False
    # The command status string carries the number of affected rows.
    return command_status == "UPDATE 1"
@staticmethod
async def save_chunks(
    conn: asyncpg.Connection,
    file_id: int,
    thread_id: str,
    chunks: List[Tuple[int, str, dict, str]],
    summary: Optional[str] = None
) -> int:
    """
    Bulk-insert the document chunks of a file.

    Args:
        conn: Database connection.
        file_id: ID of the file the chunks belong to.
        thread_id: ID of the chat thread.
        chunks: Chunk tuples of the form
            (chunk_index, content, metadata, vector_id).
        summary: Optional file summary; the same value is stored on every
            chunk row.

    Returns:
        int: Number of chunks saved (0 when ``chunks`` is empty).

    Raises:
        Exception: If serialization or the insert fails; the original error
            is set as the cause.
    """
    # Nothing to insert: skip the database round trip entirely.
    if not chunks:
        return 0

    try:
        # Metadata is serialized to a JSON string; the summary is repeated on
        # every row so each chunk can be retrieved with it independently.
        records = [
            (file_id, thread_id, chunk_index, content, json.dumps(metadata), vector_id, summary)
            for chunk_index, content, metadata, vector_id in chunks
        ]
        await conn.executemany(
            """
INSERT INTO chat_thread_chunk
(file_id, thread_id, chunk_index, content, metadata, vector_id, summary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
            records
        )
        logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else ''}")
        return len(chunks)
    except Exception as e:
        logger.error(f"保存文档块失败: {e}")
        # Chain explicitly so the underlying error stays attached.
        raise Exception(f"保存文档块失败: {str(e)}") from e
@staticmethod
async def get_file_by_id(
    conn: asyncpg.Connection,
    file_id: int,
    user_id: int
) -> Optional[ChatThreadFile]:
    """
    Fetch a single non-deleted file owned by the given user.

    Args:
        conn: Database connection.
        file_id: ID of the file.
        user_id: ID of the user; the row must belong to this user, which
            serves as the permission check.

    Returns:
        Optional[ChatThreadFile]: The file, or None when no matching
            non-deleted row exists for this user.

    Raises:
        Exception: If the query or model construction fails; the original
            error is set as the cause.
    """
    try:
        row = await conn.fetchrow(
            """
SELECT id, thread_id, user_id, file_name, file_path, file_size,
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
FROM chat_thread_file
WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
""",
            file_id, user_id
        )
        if row:
            return ChatThreadFile(**dict(row))
        return None
    except Exception as e:
        logger.error(f"获取文件失败: {e}")
        # Chain explicitly so the underlying error stays attached.
        raise Exception(f"获取文件失败: {str(e)}") from e
@staticmethod
async def get_recent_files_with_summary(
    conn: asyncpg.Connection,
    thread_id: str,
    limit: int = 10
) -> List[dict]:
    """
    List the most recently created files of a thread together with their summary.

    Only non-deleted files with status 'completed' whose chunk 0 carries a
    non-empty summary are returned, newest first. No time window is applied.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread.
        limit: Maximum number of files to return, defaults to 10.

    Returns:
        List[dict]: Entries shaped like
            {"file_name": ..., "file_type": ..., "summary": ...}.
            An empty list is returned when the query fails (the error is
            logged, not propagated).
    """
    summary_sql = """
SELECT
f.file_name,
f.file_type,
c.summary
FROM chat_thread_file f
LEFT JOIN chat_thread_chunk c ON f.id = c.file_id AND c.chunk_index = 0
WHERE f.thread_id = $1
AND f.is_deleted = FALSE
AND f.status = 'completed'
AND c.summary IS NOT NULL
AND c.summary != ''
ORDER BY f.created_at DESC
LIMIT $2
"""
    try:
        records = await conn.fetch(summary_sql, thread_id, limit)
        # Copy just the three selected columns into plain dicts.
        return [
            {
                "file_name": record['file_name'],
                "file_type": record['file_type'],
                "summary": record['summary']
            }
            for record in records
        ]
    except Exception as e:
        logger.error(f"获取文件摘要失败: {e}")
        return []
@staticmethod
async def get_files_by_thread(
    conn: asyncpg.Connection,
    thread_id: str,
    user_id: int,
    page: int = 1,
    page_size: int = 20
) -> Tuple[List[ChatThreadFile], int]:
    """
    Return one page of a thread's non-deleted files plus the total count.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread.
        user_id: ID of the user; only rows owned by this user are returned,
            which serves as the permission check.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Number of files per page; negative values are treated
            as 0.

    Returns:
        Tuple[List[ChatThreadFile], int]: (files on the requested page,
            newest first; total number of matching files).

    Raises:
        Exception: If a query or model construction fails; the original
            error is set as the cause.
    """
    try:
        # Clamp out-of-range paging input: a page below 1 would yield a
        # negative OFFSET and a negative page size a negative LIMIT, both of
        # which PostgreSQL rejects.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        # Total number of matching files, independent of paging.
        total = await conn.fetchval(
            """
SELECT COUNT(*) FROM chat_thread_file
WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE
""",
            thread_id, user_id
        )

        # The requested page, newest first.
        rows = await conn.fetch(
            """
SELECT id, thread_id, user_id, file_name, file_path, file_size,
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
FROM chat_thread_file
WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE
ORDER BY created_at DESC
LIMIT $3 OFFSET $4
""",
            thread_id, user_id, page_size, offset
        )
        files = [ChatThreadFile(**dict(row)) for row in rows]
        return files, total
    except Exception as e:
        logger.error(f"获取文件列表失败: {e}")
        # Chain explicitly so the underlying error stays attached.
        raise Exception(f"获取文件列表失败: {str(e)}") from e
@staticmethod
async def get_all_files_by_thread(
    conn: asyncpg.Connection,
    thread_id: str
) -> List[ChatThreadFile]:
    """
    Return every non-deleted file of a thread, regardless of owner.

    Intended for cleanup when a whole thread is deleted; unlike the paged
    listing, no user filter is applied.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread.

    Returns:
        List[ChatThreadFile]: All non-deleted files of the thread.

    Raises:
        Exception: If the query or model construction fails; the original
            error is set as the cause.
    """
    try:
        rows = await conn.fetch(
            """
SELECT id, thread_id, user_id, file_name, file_path, file_size,
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
FROM chat_thread_file
WHERE thread_id = $1 AND is_deleted = FALSE
""",
            thread_id
        )
        return [ChatThreadFile(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"获取文件列表失败: {e}")
        # Chain explicitly so the underlying error stays attached.
        raise Exception(f"获取文件列表失败: {str(e)}") from e
@staticmethod
async def get_thread_all_vector_ids(
    conn: asyncpg.Connection,
    thread_id: str
) -> List[str]:
    """
    Collect the vector IDs of every chunk stored for a thread.

    Used when the corresponding vectors need to be deleted.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread.

    Returns:
        List[str]: Non-empty vector IDs. An empty list is returned when the
            query fails (the error is logged, not propagated).
    """
    vector_sql = """
SELECT vector_id FROM chat_thread_chunk
WHERE thread_id = $1 AND vector_id IS NOT NULL
"""
    try:
        records = await conn.fetch(vector_sql, thread_id)
        vector_ids = []
        for record in records:
            vector_id = record['vector_id']
            # The query already excludes NULLs; this also drops empty values.
            if vector_id:
                vector_ids.append(vector_id)
        return vector_ids
    except Exception as e:
        logger.error(f"获取向量 ID 列表失败: {e}")
        return []
@staticmethod
async def get_file_vector_ids(
    conn: asyncpg.Connection,
    file_id: int
) -> List[str]:
    """
    Collect the vector IDs of every chunk stored for a file.

    Args:
        conn: Database connection.
        file_id: ID of the file.

    Returns:
        List[str]: Non-empty vector IDs. An empty list is returned when the
            query fails (the error is logged, not propagated).
    """
    vector_sql = """
SELECT vector_id FROM chat_thread_chunk
WHERE file_id = $1 AND vector_id IS NOT NULL
"""
    try:
        records = await conn.fetch(vector_sql, file_id)
        vector_ids = []
        for record in records:
            vector_id = record['vector_id']
            # The query already excludes NULLs; this also drops empty values.
            if vector_id:
                vector_ids.append(vector_id)
        return vector_ids
    except Exception as e:
        logger.error(f"获取向量 ID 列表失败: {e}")
        return []
@staticmethod
async def get_file_chunks_from_db(
    conn: asyncpg.Connection,
    file_id: int
) -> List[dict]:
    """
    Load all chunks of a file, including their summary, ordered by chunk index.

    Used to inject the complete file content into the AI context.

    Args:
        conn: Database connection.
        file_id: ID of the file.

    Returns:
        List[dict]: Entries shaped like
            {"chunk_index": int, "content": str, "summary": str}; a missing
            summary is returned as an empty string. An empty list is
            returned when the query fails (the error is logged, not
            propagated).
    """
    chunk_sql = """
SELECT chunk_index, content, summary
FROM chat_thread_chunk
WHERE file_id = $1
ORDER BY chunk_index
"""
    try:
        records = await conn.fetch(chunk_sql, file_id)
        chunks = []
        for record in records:
            chunks.append({
                "chunk_index": record['chunk_index'],
                "content": record['content'],
                # Normalize a NULL summary to an empty string.
                "summary": record['summary'] or ''
            })
        logger.info(f"从数据库获取文件chunks: file_id={file_id}, chunks数量={len(chunks)}")
        return chunks
    except Exception as e:
        logger.error(f"从数据库获取文件chunks失败: {e}")
        return []
@staticmethod
async def delete_file(
    conn: asyncpg.Connection,
    file_id: int,
    user_id: int
) -> Tuple[bool, List[str]]:
    """
    Soft-delete a file, remove its chunks, and return its vector IDs.

    The file row is only flagged as deleted, while its chunk rows are
    physically removed. Both writes run inside one transaction so a failure
    cannot leave a deleted file with leftover chunks.

    Args:
        conn: Database connection.
        file_id: ID of the file to delete.
        user_id: ID of the user; the file must belong to this user, which
            serves as the permission check.

    Returns:
        Tuple[bool, List[str]]: (True, vector IDs of the removed chunks) on
            success; (False, []) when no non-deleted file with this ID
            belongs to the user.

    Raises:
        Exception: If any statement fails; the original error is set as the
            cause.
    """
    try:
        # Verify existence and ownership first, so nothing else is queried
        # for a file the caller is not allowed to delete.
        existing = await conn.fetchrow(
            """
SELECT id FROM chat_thread_file
WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
""",
            file_id, user_id
        )
        if not existing:
            return False, []

        # Read the vector IDs before the chunk rows that hold them are removed.
        vector_ids = await ChatThreadFileService.get_file_vector_ids(conn, file_id)

        # Apply both writes atomically; if the connection is already inside
        # a transaction block this nests as a savepoint.
        async with conn.transaction():
            # Soft-delete the file row.
            await conn.execute(
                """
UPDATE chat_thread_file
SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
WHERE id = $1
""",
                file_id
            )
            # Physically delete the chunks, since the file itself is gone.
            await conn.execute(
                """
DELETE FROM chat_thread_chunk
WHERE file_id = $1
""",
                file_id
            )
        logger.info(f"删除文件成功: file_id={file_id}, 向量数={len(vector_ids)}")
        return True, vector_ids
    except Exception as e:
        logger.error(f"删除文件失败: {e}")
        # Chain explicitly so the underlying error stays attached.
        raise Exception(f"删除文件失败: {str(e)}") from e
@staticmethod
async def delete_thread_all_chunks(
    conn: asyncpg.Connection,
    thread_id: str
) -> int:
    """
    Physically delete every chunk stored for a thread.

    Used for cleanup when a whole thread is deleted.

    Args:
        conn: Database connection.
        thread_id: ID of the chat thread.

    Returns:
        int: Number of deleted chunk rows; 0 when the statement fails or its
            status cannot be parsed (the error is logged, not propagated).
    """
    try:
        result = await conn.execute(
            """
DELETE FROM chat_thread_chunk
WHERE thread_id = $1
""",
            thread_id
        )
        # The command status has the form "DELETE <n>"; the last token is
        # the number of affected rows.
        deleted_count = int(result.split()[-1]) if result else 0
        # Keep the thread ID and the count separated so the log line is
        # unambiguous.
        logger.info(f"删除会话 {thread_id} 的 {deleted_count} 个文档块")
        return deleted_count
    except Exception as e:
        logger.error(f"删除文档块失败: {e}")
        return 0