526 lines
16 KiB
Python
526 lines
16 KiB
Python
"""
|
||
聊天对话文件服务
|
||
"""
|
||
import os
|
||
import json
|
||
from typing import Optional, List, Tuple
|
||
from pathlib import Path
|
||
import asyncpg
|
||
from datetime import datetime
|
||
|
||
from models.chat_thread_file import ChatThreadFile, ChatThreadChunk
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
class ChatThreadFileService:
|
||
"""聊天对话文件服务类"""
|
||
|
||
@staticmethod
|
||
async def create_file_record(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str,
|
||
user_id: int,
|
||
file_name: str,
|
||
file_path: str,
|
||
file_size: int,
|
||
file_type: str = "pdf"
|
||
) -> ChatThreadFile:
|
||
"""
|
||
创建文件记录
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
user_id: 用户 ID
|
||
file_name: 文件名
|
||
file_path: 文件路径
|
||
file_size: 文件大小
|
||
file_type: 文件类型
|
||
|
||
Returns:
|
||
ChatThreadFile: 创建的文件记录
|
||
"""
|
||
try:
|
||
# 检查文件名是否已存在(同一 thread_id 下)
|
||
existing = await conn.fetchrow(
|
||
"""
|
||
SELECT id FROM chat_thread_file
|
||
WHERE thread_id = $1 AND file_name = $2 AND is_deleted = FALSE
|
||
""",
|
||
thread_id, file_name
|
||
)
|
||
|
||
if existing:
|
||
raise ValueError(f"文件 '{file_name}' 已存在于该对话中")
|
||
|
||
# 插入文件记录
|
||
row = await conn.fetchrow(
|
||
"""
|
||
INSERT INTO chat_thread_file
|
||
(thread_id, user_id, file_name, file_path, file_size, file_type, status)
|
||
VALUES ($1, $2, $3, $4, $5, $6, 'processing')
|
||
RETURNING id, thread_id, user_id, file_name, file_path, file_size,
|
||
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
|
||
""",
|
||
thread_id, user_id, file_name, file_path, file_size, file_type
|
||
)
|
||
|
||
logger.info(f"创建文件记录: {file_name}, thread_id: {thread_id}")
|
||
return ChatThreadFile(**dict(row))
|
||
|
||
except ValueError:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"创建文件记录失败: {e}")
|
||
raise Exception(f"创建文件记录失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def update_file_status(
|
||
conn: asyncpg.Connection,
|
||
file_id: int,
|
||
status: str,
|
||
chunk_count: int = 0
|
||
) -> bool:
|
||
"""
|
||
更新文件状态
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
status: 状态(processing/completed/failed)
|
||
chunk_count: 分块数量
|
||
|
||
Returns:
|
||
bool: 是否更新成功
|
||
"""
|
||
try:
|
||
result = await conn.execute(
|
||
"""
|
||
UPDATE chat_thread_file
|
||
SET status = $1, chunk_count = $2
|
||
WHERE id = $3
|
||
""",
|
||
status, chunk_count, file_id
|
||
)
|
||
|
||
return result == "UPDATE 1"
|
||
|
||
except Exception as e:
|
||
logger.error(f"更新文件状态失败: {e}")
|
||
return False
|
||
|
||
@staticmethod
|
||
async def save_chunks(
|
||
conn: asyncpg.Connection,
|
||
file_id: int,
|
||
thread_id: str,
|
||
chunks: List[Tuple[int, str, dict, str]],
|
||
summary: Optional[str] = None
|
||
) -> int:
|
||
"""
|
||
批量保存文档块
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
thread_id: 会话线程 ID
|
||
chunks: 文档块列表 [(chunk_index, content, metadata, vector_id), ...]
|
||
summary: 文件摘要(可选)
|
||
|
||
Returns:
|
||
int: 保存的块数量
|
||
"""
|
||
try:
|
||
# 批量插入(每个chunk都保存summary,便于独立检索)
|
||
records = [
|
||
(file_id, thread_id, chunk_index, content, json.dumps(metadata), vector_id, summary)
|
||
for chunk_index, content, metadata, vector_id in chunks
|
||
]
|
||
|
||
await conn.executemany(
|
||
"""
|
||
INSERT INTO chat_thread_chunk
|
||
(file_id, thread_id, chunk_index, content, metadata, vector_id, summary)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||
""",
|
||
records
|
||
)
|
||
|
||
logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else '无'}")
|
||
return len(chunks)
|
||
|
||
except Exception as e:
|
||
logger.error(f"保存文档块失败: {e}")
|
||
raise Exception(f"保存文档块失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def get_file_by_id(
|
||
conn: asyncpg.Connection,
|
||
file_id: int,
|
||
user_id: int
|
||
) -> Optional[ChatThreadFile]:
|
||
"""
|
||
根据 ID 获取文件
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
user_id: 用户 ID(用于权限验证)
|
||
|
||
Returns:
|
||
Optional[ChatThreadFile]: 文件对象,如果不存在则返回 None
|
||
"""
|
||
try:
|
||
row = await conn.fetchrow(
|
||
"""
|
||
SELECT id, thread_id, user_id, file_name, file_path, file_size,
|
||
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
|
||
FROM chat_thread_file
|
||
WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
|
||
""",
|
||
file_id, user_id
|
||
)
|
||
|
||
if row:
|
||
return ChatThreadFile(**dict(row))
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文件失败: {e}")
|
||
raise Exception(f"获取文件失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def get_recent_files_with_summary(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str,
|
||
limit: int = 10
|
||
) -> List[dict]:
|
||
"""
|
||
获取会话中最近上传的文件及其摘要(无时间限制)
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
limit: 限制返回数量
|
||
|
||
Returns:
|
||
List[dict]: 文件列表,包含摘要信息 [{"file_name": "xxx", "file_type": "png", "summary": "xxx"}, ...]
|
||
"""
|
||
try:
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT
|
||
f.file_name,
|
||
f.file_type,
|
||
c.summary
|
||
FROM chat_thread_file f
|
||
LEFT JOIN chat_thread_chunk c ON f.id = c.file_id AND c.chunk_index = 0
|
||
WHERE f.thread_id = $1
|
||
AND f.is_deleted = FALSE
|
||
AND f.status = 'completed'
|
||
AND c.summary IS NOT NULL
|
||
AND c.summary != ''
|
||
ORDER BY f.created_at DESC
|
||
LIMIT $2
|
||
""",
|
||
thread_id, limit
|
||
)
|
||
|
||
result = []
|
||
for row in rows:
|
||
result.append({
|
||
"file_name": row['file_name'],
|
||
"file_type": row['file_type'],
|
||
"summary": row['summary']
|
||
})
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文件摘要失败: {e}")
|
||
return []
|
||
|
||
@staticmethod
|
||
async def get_files_by_thread(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str,
|
||
user_id: int,
|
||
page: int = 1,
|
||
page_size: int = 20
|
||
) -> Tuple[List[ChatThreadFile], int]:
|
||
"""
|
||
获取会话的文件列表
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
user_id: 用户 ID(用于权限验证)
|
||
page: 页码(从 1 开始)
|
||
page_size: 每页数量
|
||
|
||
Returns:
|
||
Tuple[List[ChatThreadFile], int]: (文件列表, 总数量)
|
||
"""
|
||
try:
|
||
# 计算偏移量
|
||
offset = (page - 1) * page_size
|
||
|
||
# 获取总数
|
||
total = await conn.fetchval(
|
||
"""
|
||
SELECT COUNT(*) FROM chat_thread_file
|
||
WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE
|
||
""",
|
||
thread_id, user_id
|
||
)
|
||
|
||
# 获取列表
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT id, thread_id, user_id, file_name, file_path, file_size,
|
||
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
|
||
FROM chat_thread_file
|
||
WHERE thread_id = $1 AND user_id = $2 AND is_deleted = FALSE
|
||
ORDER BY created_at DESC
|
||
LIMIT $3 OFFSET $4
|
||
""",
|
||
thread_id, user_id, page_size, offset
|
||
)
|
||
|
||
files = [ChatThreadFile(**dict(row)) for row in rows]
|
||
return files, total
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文件列表失败: {e}")
|
||
raise Exception(f"获取文件列表失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def get_all_files_by_thread(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str
|
||
) -> List[ChatThreadFile]:
|
||
"""
|
||
获取会话的所有文件(用于删除会话时清理)
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
|
||
Returns:
|
||
List[ChatThreadFile]: 文件列表
|
||
"""
|
||
try:
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT id, thread_id, user_id, file_name, file_path, file_size,
|
||
file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
|
||
FROM chat_thread_file
|
||
WHERE thread_id = $1 AND is_deleted = FALSE
|
||
""",
|
||
thread_id
|
||
)
|
||
|
||
return [ChatThreadFile(**dict(row)) for row in rows]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取文件列表失败: {e}")
|
||
raise Exception(f"获取文件列表失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def get_thread_all_vector_ids(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str
|
||
) -> List[str]:
|
||
"""
|
||
获取会话的所有向量 ID(用于删除向量)
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
|
||
Returns:
|
||
List[str]: 向量 ID 列表
|
||
"""
|
||
try:
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT vector_id FROM chat_thread_chunk
|
||
WHERE thread_id = $1 AND vector_id IS NOT NULL
|
||
""",
|
||
thread_id
|
||
)
|
||
|
||
return [row['vector_id'] for row in rows if row['vector_id']]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取向量 ID 列表失败: {e}")
|
||
return []
|
||
|
||
@staticmethod
|
||
async def get_file_vector_ids(
|
||
conn: asyncpg.Connection,
|
||
file_id: int
|
||
) -> List[str]:
|
||
"""
|
||
获取文件的所有向量 ID
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
|
||
Returns:
|
||
List[str]: 向量 ID 列表
|
||
"""
|
||
try:
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT vector_id FROM chat_thread_chunk
|
||
WHERE file_id = $1 AND vector_id IS NOT NULL
|
||
""",
|
||
file_id
|
||
)
|
||
|
||
return [row['vector_id'] for row in rows if row['vector_id']]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取向量 ID 列表失败: {e}")
|
||
return []
|
||
|
||
@staticmethod
|
||
async def get_file_chunks_from_db(
|
||
conn: asyncpg.Connection,
|
||
file_id: int
|
||
) -> List[dict]:
|
||
"""
|
||
从 PostgreSQL 获取文件的所有 chunks(包括 summary)
|
||
用于注入完整内容到 AI 上下文
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
|
||
Returns:
|
||
List[dict]: [{"content": str, "summary": str, "chunk_index": int}]
|
||
"""
|
||
try:
|
||
rows = await conn.fetch(
|
||
"""
|
||
SELECT chunk_index, content, summary
|
||
FROM chat_thread_chunk
|
||
WHERE file_id = $1
|
||
ORDER BY chunk_index
|
||
""",
|
||
file_id
|
||
)
|
||
|
||
chunks = [
|
||
{
|
||
"chunk_index": row['chunk_index'],
|
||
"content": row['content'],
|
||
"summary": row['summary'] or ''
|
||
}
|
||
for row in rows
|
||
]
|
||
|
||
logger.info(f"从数据库获取文件chunks: file_id={file_id}, chunks数量={len(chunks)}")
|
||
return chunks
|
||
|
||
except Exception as e:
|
||
logger.error(f"从数据库获取文件chunks失败: {e}")
|
||
return []
|
||
|
||
@staticmethod
|
||
async def delete_file(
|
||
conn: asyncpg.Connection,
|
||
file_id: int,
|
||
user_id: int
|
||
) -> Tuple[bool, List[str]]:
|
||
"""
|
||
删除文件(软删除),同时返回向量 ID 列表
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
file_id: 文件 ID
|
||
user_id: 用户 ID(用于权限验证)
|
||
|
||
Returns:
|
||
Tuple[bool, List[str]]: (是否删除成功, 向量 ID 列表)
|
||
"""
|
||
try:
|
||
# 先获取向量 ID 列表
|
||
vector_ids = await ChatThreadFileService.get_file_vector_ids(conn, file_id)
|
||
|
||
# 检查文件是否存在且属于该用户
|
||
existing = await conn.fetchrow(
|
||
"""
|
||
SELECT id FROM chat_thread_file
|
||
WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
|
||
""",
|
||
file_id, user_id
|
||
)
|
||
|
||
if not existing:
|
||
return False, []
|
||
|
||
# 软删除文件
|
||
await conn.execute(
|
||
"""
|
||
UPDATE chat_thread_file
|
||
SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
|
||
WHERE id = $1
|
||
""",
|
||
file_id
|
||
)
|
||
|
||
# 删除文档块(物理删除,因为文件已删除)
|
||
await conn.execute(
|
||
"""
|
||
DELETE FROM chat_thread_chunk
|
||
WHERE file_id = $1
|
||
""",
|
||
file_id
|
||
)
|
||
|
||
logger.info(f"删除文件成功: file_id={file_id}, 向量数={len(vector_ids)}")
|
||
return True, vector_ids
|
||
|
||
except Exception as e:
|
||
logger.error(f"删除文件失败: {e}")
|
||
raise Exception(f"删除文件失败: {str(e)}")
|
||
|
||
@staticmethod
|
||
async def delete_thread_all_chunks(
|
||
conn: asyncpg.Connection,
|
||
thread_id: str
|
||
) -> int:
|
||
"""
|
||
删除会话的所有文档块(用于删除会话时清理)
|
||
|
||
Args:
|
||
conn: 数据库连接
|
||
thread_id: 会话线程 ID
|
||
|
||
Returns:
|
||
int: 删除的块数量
|
||
"""
|
||
try:
|
||
result = await conn.execute(
|
||
"""
|
||
DELETE FROM chat_thread_chunk
|
||
WHERE thread_id = $1
|
||
""",
|
||
thread_id
|
||
)
|
||
|
||
# 解析删除的行数
|
||
deleted_count = int(result.split()[-1]) if result else 0
|
||
logger.info(f"删除会话 {thread_id} 的 {deleted_count} 个文档块")
|
||
return deleted_count
|
||
|
||
except Exception as e:
|
||
logger.error(f"删除文档块失败: {e}")
|
||
return 0
|
||
|