559 lines
18 KiB
Python
559 lines
18 KiB
Python
"""
知识库文件服务 (knowledge base file service): persistence of knowledge-base
file records and their document chunks.
"""
|
||
import os
|
||
import json
|
||
from typing import Optional, List, Tuple
|
||
from pathlib import Path
|
||
import asyncpg
|
||
from datetime import datetime
|
||
|
||
from models.knowledge_base_file import KnowledgeBaseFile, KnowledgeBaseChunk
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)  # module-level logger used by KnowledgeBaseFileService below
|
||
|
||
|
||
class KnowledgeBaseFileService:
    """Service for knowledge-base file records and their document chunks.

    All methods are static and operate on a caller-supplied
    ``asyncpg.Connection``. File rows live in ``knowledge_base_file`` and are
    soft-deleted (``is_deleted`` / ``deleted_at``); chunk rows live in
    ``knowledge_base_chunk`` and are physically deleted.
    """

    @staticmethod
    async def create_file_record(
        conn: asyncpg.Connection,
        knowledge_base_id: int,
        user_id: int,
        file_name: str,
        file_path: str,
        file_size: int,
        file_type: str = "pdf"
    ) -> KnowledgeBaseFile:
        """
        Create a file record with status 'processing'.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID
            user_id: user ID
            file_name: file name; must not clash with another non-deleted file
                in the same knowledge base
            file_path: file path
            file_size: file size
            file_type: file type (default "pdf")

        Returns:
            KnowledgeBaseFile: the newly created record

        Raises:
            ValueError: a non-deleted file with the same name already exists in
                this knowledge base
            Exception: any other failure (the original error is chained as the cause)
        """
        try:
            # Reject duplicate names among this knowledge base's non-deleted files.
            # NOTE(review): check-then-insert is not atomic; two concurrent uploads
            # of the same name can both pass this check unless the schema has a
            # matching unique constraint — confirm.
            existing = await conn.fetchrow(
                """
                SELECT id FROM knowledge_base_file
                WHERE knowledge_base_id = $1 AND file_name = $2 AND is_deleted = FALSE
                """,
                knowledge_base_id, file_name
            )

            if existing:
                raise ValueError(f"文件 '{file_name}' 已存在于该知识库中")

            row = await conn.fetchrow(
                """
                INSERT INTO knowledge_base_file
                (knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status)
                VALUES ($1, $2, $3, $4, $5, $6, 'processing')
                RETURNING id, knowledge_base_id, user_id, file_name, file_path, file_size,
                          file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
                """,
                knowledge_base_id, user_id, file_name, file_path, file_size, file_type
            )

            logger.info(f"创建文件记录: {file_name}, 知识库 ID: {knowledge_base_id}")
            return KnowledgeBaseFile(**dict(row))

        except ValueError:
            raise
        except Exception as e:
            logger.error(f"创建文件记录失败: {e}")
            raise Exception(f"创建文件记录失败: {str(e)}") from e

    @staticmethod
    async def update_file_status(
        conn: asyncpg.Connection,
        file_id: int,
        status: str,
        chunk_count: int = 0
    ) -> bool:
        """
        Update a file's status and chunk count.

        Args:
            conn: database connection
            file_id: file ID
            status: new status (processing / completed / failed)
            chunk_count: number of chunks (written unconditionally, default 0)

        Returns:
            bool: True if exactly one row was updated; False otherwise or on error
        """
        try:
            result = await conn.execute(
                """
                UPDATE knowledge_base_file
                SET status = $1, chunk_count = $2
                WHERE id = $3
                """,
                status, chunk_count, file_id
            )

            return result == "UPDATE 1"

        except Exception as e:
            logger.error(f"更新文件状态失败: {e}")
            return False

    @staticmethod
    async def save_chunks(
        conn: asyncpg.Connection,
        file_id: int,
        knowledge_base_id: int,
        chunks: List[Tuple[int, str, dict, str]],
        summary: Optional[str] = None
    ) -> int:
        """
        Bulk-insert document chunks for a file.

        Args:
            conn: database connection
            file_id: file ID
            knowledge_base_id: knowledge base ID
            chunks: list of (chunk_index, content, metadata, vector_id) tuples;
                metadata is serialized with json.dumps
            summary: optional file summary, stored on every chunk

        Returns:
            int: number of chunks saved

        Raises:
            Exception: on any failure (the original error is chained as the cause)
        """
        try:
            # Every chunk carries the file summary so each row is self-contained
            # for retrieval.
            records = [
                (file_id, knowledge_base_id, chunk_index, content, json.dumps(metadata), vector_id, summary)
                for chunk_index, content, metadata, vector_id in chunks
            ]

            await conn.executemany(
                """
                INSERT INTO knowledge_base_chunk
                (file_id, knowledge_base_id, chunk_index, content, metadata, vector_id, summary)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                """,
                records
            )

            logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else '无'}")
            return len(chunks)

        except Exception as e:
            logger.error(f"保存文档块失败: {e}")
            raise Exception(f"保存文档块失败: {str(e)}") from e

    @staticmethod
    async def get_file_by_id(
        conn: asyncpg.Connection,
        file_id: int,
        user_id: int
    ) -> Optional[KnowledgeBaseFile]:
        """
        Fetch a non-deleted file owned by the given user.

        Args:
            conn: database connection
            file_id: file ID
            user_id: user ID (ownership filter)

        Returns:
            Optional[KnowledgeBaseFile]: the file, or None if not found or on error
        """
        try:
            row = await conn.fetchrow(
                """
                SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size,
                       file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
                FROM knowledge_base_file
                WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
                """,
                file_id, user_id
            )

            if row:
                return KnowledgeBaseFile(**dict(row))
            return None

        except Exception as e:
            logger.error(f"获取文件失败: {e}")
            return None

    @staticmethod
    async def get_files_by_kb(
        conn: asyncpg.Connection,
        knowledge_base_id: int,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[KnowledgeBaseFile], int]:
        """
        List a user's non-deleted files in a knowledge base, newest first, paginated.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID
            user_id: user ID
            page: 1-based page number (values below 1 are treated as 1)
            page_size: page size (negative values are treated as 0)

        Returns:
            Tuple[List[KnowledgeBaseFile], int]: (files on this page, total count)

        Raises:
            Exception: on any failure (the original error is chained as the cause)
        """
        try:
            # PostgreSQL rejects negative LIMIT/OFFSET, so clamp out-of-range input.
            page = max(page, 1)
            page_size = max(page_size, 0)
            offset = (page - 1) * page_size

            total = await conn.fetchval(
                """
                SELECT COUNT(*) FROM knowledge_base_file
                WHERE knowledge_base_id = $1 AND user_id = $2 AND is_deleted = FALSE
                """,
                knowledge_base_id, user_id
            )

            # `id DESC` breaks ties on created_at so page boundaries are stable.
            rows = await conn.fetch(
                """
                SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size,
                       file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
                FROM knowledge_base_file
                WHERE knowledge_base_id = $1 AND user_id = $2 AND is_deleted = FALSE
                ORDER BY created_at DESC, id DESC
                LIMIT $3 OFFSET $4
                """,
                knowledge_base_id, user_id, page_size, offset
            )

            files = [KnowledgeBaseFile(**dict(row)) for row in rows]
            return files, total

        except Exception as e:
            logger.error(f"获取文件列表失败: {e}")
            raise Exception(f"获取文件列表失败: {str(e)}") from e

    @staticmethod
    async def _select_file_vector_ids(
        conn: asyncpg.Connection,
        file_id: int
    ) -> List[str]:
        """Return the non-empty vector IDs of a file's chunks; errors propagate."""
        rows = await conn.fetch(
            """
            SELECT vector_id FROM knowledge_base_chunk
            WHERE file_id = $1 AND vector_id IS NOT NULL
            """,
            file_id
        )
        return [row['vector_id'] for row in rows if row['vector_id']]

    @staticmethod
    async def get_file_vector_ids(
        conn: asyncpg.Connection,
        file_id: int
    ) -> List[str]:
        """
        Get all vector IDs of a file's chunks (best effort).

        Args:
            conn: database connection
            file_id: file ID

        Returns:
            List[str]: non-empty vector IDs; [] on error
        """
        try:
            return await KnowledgeBaseFileService._select_file_vector_ids(conn, file_id)

        except Exception as e:
            logger.error(f"获取文件向量 ID 失败: {e}")
            return []

    @staticmethod
    async def get_all_files_by_kb(
        conn: asyncpg.Connection,
        knowledge_base_id: int
    ) -> List[KnowledgeBaseFile]:
        """
        Get every file of a knowledge base, including soft-deleted ones.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID

        Returns:
            List[KnowledgeBaseFile]: all files; [] on error
        """
        try:
            rows = await conn.fetch(
                """
                SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size,
                       file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
                FROM knowledge_base_file
                WHERE knowledge_base_id = $1
                """,
                knowledge_base_id
            )

            return [KnowledgeBaseFile(**dict(row)) for row in rows]

        except Exception as e:
            logger.error(f"获取知识库所有文件失败: {e}")
            return []

    @staticmethod
    async def get_kb_all_vector_ids(
        conn: asyncpg.Connection,
        knowledge_base_id: int
    ) -> List[str]:
        """
        Get all chunk vector IDs of a knowledge base.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID

        Returns:
            List[str]: non-empty vector IDs; [] on error
        """
        try:
            rows = await conn.fetch(
                """
                SELECT vector_id FROM knowledge_base_chunk
                WHERE knowledge_base_id = $1 AND vector_id IS NOT NULL
                """,
                knowledge_base_id
            )

            return [row['vector_id'] for row in rows if row['vector_id']]

        except Exception as e:
            logger.error(f"获取知识库向量 ID 失败: {e}")
            return []

    @staticmethod
    async def _delete_file_chunks_strict(
        conn: asyncpg.Connection,
        file_id: int
    ) -> int:
        """Physically delete a file's chunks and return the row count; errors propagate."""
        result = await conn.execute(
            """
            DELETE FROM knowledge_base_chunk
            WHERE file_id = $1
            """,
            file_id
        )

        # asyncpg returns the command tag, e.g. "DELETE 3".
        deleted_count = int(result.split()[-1]) if result.startswith("DELETE") else 0
        logger.info(f"删除文件 {file_id} 的 {deleted_count} 个文档块")
        return deleted_count

    @staticmethod
    async def delete_file_chunks(
        conn: asyncpg.Connection,
        file_id: int
    ) -> int:
        """
        Physically delete all chunks of a file (best effort).

        Args:
            conn: database connection
            file_id: file ID

        Returns:
            int: number of chunks deleted; 0 on error
        """
        try:
            return await KnowledgeBaseFileService._delete_file_chunks_strict(conn, file_id)

        except Exception as e:
            logger.error(f"删除文档块失败: {e}")
            return 0

    @staticmethod
    async def delete_kb_all_chunks(
        conn: asyncpg.Connection,
        knowledge_base_id: int
    ) -> int:
        """
        Physically delete all chunks of a knowledge base.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID

        Returns:
            int: number of chunks deleted; 0 on error
        """
        try:
            result = await conn.execute(
                """
                DELETE FROM knowledge_base_chunk
                WHERE knowledge_base_id = $1
                """,
                knowledge_base_id
            )

            # asyncpg returns the command tag, e.g. "DELETE 3".
            deleted_count = int(result.split()[-1]) if result.startswith("DELETE") else 0
            logger.info(f"删除知识库 {knowledge_base_id} 的 {deleted_count} 个文档块")
            return deleted_count

        except Exception as e:
            logger.error(f"删除知识库文档块失败: {e}")
            return 0

    @staticmethod
    async def get_recent_files_with_summary(
        conn: asyncpg.Connection,
        knowledge_base_id: int,
        limit: int = 5
    ) -> List[dict]:
        """
        Get the most recently created completed files of a knowledge base, with summaries.

        Only non-deleted files with status 'completed' are considered; there is
        no time window.

        Args:
            conn: database connection
            knowledge_base_id: knowledge base ID
            limit: maximum number of files to return

        Returns:
            List[dict]: newest first, each
                {"file_id": int, "file_name": str, "summary": str};
                summary is "" when no chunk has one. [] on error.
        """
        try:
            # Order by creation time (newest first) before applying LIMIT. The
            # summary comes from a correlated subquery: save_chunks stores the same
            # summary on every chunk, so one non-NULL value per file suffices, and
            # files without chunks still appear (summary NULL -> "").
            rows = await conn.fetch(
                """
                SELECT
                    kbf.id,
                    kbf.file_name,
                    (
                        SELECT kbc.summary
                        FROM knowledge_base_chunk kbc
                        WHERE kbc.file_id = kbf.id AND kbc.summary IS NOT NULL
                        ORDER BY kbc.chunk_index
                        LIMIT 1
                    ) AS summary
                FROM knowledge_base_file kbf
                WHERE kbf.knowledge_base_id = $1
                  AND kbf.is_deleted = FALSE
                  AND kbf.status = 'completed'
                ORDER BY kbf.created_at DESC, kbf.id DESC
                LIMIT $2
                """,
                knowledge_base_id, limit
            )

            result = [
                {
                    "file_id": row['id'],
                    "file_name": row['file_name'],
                    "summary": row['summary'] or ""
                }
                for row in rows
            ]

            logger.info(f"获取知识库 {knowledge_base_id} 的 {len(result)} 个文件及摘要(无时间限制)")
            return result

        except Exception as e:
            logger.error(f"获取知识库文件摘要失败: {e}")
            return []

    @staticmethod
    async def get_file_chunks_from_db(
        conn: asyncpg.Connection,
        file_id: int
    ) -> List[dict]:
        """
        Fetch all chunks of a file from PostgreSQL, ordered by chunk_index.

        Used to inject a file's full content into the AI context.

        Args:
            conn: database connection
            file_id: file ID

        Returns:
            List[dict]: [{"chunk_index": int, "content": str, "summary": str}];
                summary is "" when NULL. [] on error.
        """
        try:
            rows = await conn.fetch(
                """
                SELECT chunk_index, content, summary
                FROM knowledge_base_chunk
                WHERE file_id = $1
                ORDER BY chunk_index
                """,
                file_id
            )

            chunks = [
                {
                    "chunk_index": row['chunk_index'],
                    "content": row['content'],
                    "summary": row['summary'] or ''
                }
                for row in rows
            ]

            logger.info(f"从数据库获取知识库文件chunks: file_id={file_id}, chunks数量={len(chunks)}")
            return chunks

        except Exception as e:
            logger.error(f"从数据库获取知识库文件chunks失败: {e}")
            return []

    @staticmethod
    async def delete_file(
        conn: asyncpg.Connection,
        file_id: int,
        user_id: int
    ) -> Tuple[bool, List[str]]:
        """
        Soft-delete a file and physically delete all of its chunks, atomically.

        The soft delete, vector-ID lookup and chunk deletion run in one
        transaction (a savepoint if the caller already has one open), so a
        failure part-way leaves both the file and its chunks untouched.

        Args:
            conn: database connection
            file_id: file ID
            user_id: user ID (ownership filter)

        Returns:
            Tuple[bool, List[str]]: (True, vector IDs of the removed chunks) on
                success; (False, []) if the file does not exist, is not owned by
                the user, is already deleted, or on error
        """
        try:
            async with conn.transaction():
                # Soft-delete first: the WHERE clause is also the existence /
                # ownership check, so there is no window between check and update.
                file_record = await conn.fetchrow(
                    """
                    UPDATE knowledge_base_file
                    SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
                    WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
                    RETURNING id, knowledge_base_id, file_name
                    """,
                    file_id, user_id
                )

                if file_record is None:
                    return False, []

                # Read the vector IDs before the chunk rows holding them are removed.
                # The strict helpers let errors propagate so the transaction rolls back.
                vector_ids = await KnowledgeBaseFileService._select_file_vector_ids(conn, file_id)
                deleted_chunks = await KnowledgeBaseFileService._delete_file_chunks_strict(conn, file_id)

            logger.info(
                f"删除文件 ID: {file_id}, 文件名: {file_record['file_name']}, "
                f"文档块数: {deleted_chunks}, 向量数: {len(vector_ids)}"
            )
            return True, vector_ids

        except Exception as e:
            logger.error(f"删除文件失败: {e}")
            return False, []
|
||
|