huoyan-enterprise/backend/services/knowledge_base_file_service.py

558 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识库文件服务
"""
import os
import json
from typing import Optional, List, Tuple
from pathlib import Path
import asyncpg
from datetime import datetime
from models.knowledge_base_file import KnowledgeBaseFile, KnowledgeBaseChunk
from logger.logging import get_logger
logger = get_logger(__name__)
class KnowledgeBaseFileService:
"""知识库文件服务类"""
@staticmethod
async def create_file_record(
    conn: asyncpg.Connection,
    knowledge_base_id: int,
    user_id: int,
    file_name: str,
    file_path: str,
    file_size: int,
    file_type: str = "pdf"
) -> KnowledgeBaseFile:
    """
    Create a file record for a knowledge base, initially in 'processing' status.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the owning knowledge base.
        user_id: ID of the uploading user.
        file_name: File name; must be unique among the KB's non-deleted files.
        file_path: Storage path of the file.
        file_size: File size.
        file_type: File type, "pdf" by default.

    Returns:
        KnowledgeBaseFile: The newly inserted record.

    Raises:
        ValueError: A non-deleted file with the same name already exists in
            this knowledge base.
        Exception: Any other failure, with the underlying error chained as
            ``__cause__``.
    """
    try:
        # Reject duplicate names among the KB's live (non-deleted) files.
        # NOTE(review): check-then-insert is not atomic; two concurrent uploads
        # of the same name can both pass unless a matching unique index exists
        # on the table — confirm against the schema.
        existing = await conn.fetchrow(
            """
            SELECT id FROM knowledge_base_file
            WHERE knowledge_base_id = $1 AND file_name = $2 AND is_deleted = FALSE
            """,
            knowledge_base_id, file_name
        )
        if existing:
            raise ValueError(f"文件 '{file_name}' 已存在于该知识库中")
        row = await conn.fetchrow(
            """
            INSERT INTO knowledge_base_file
            (knowledge_base_id, user_id, file_name, file_path, file_size, file_type, status)
            VALUES ($1, $2, $3, $4, $5, $6, 'processing')
            RETURNING id, knowledge_base_id, user_id, file_name, file_path, file_size,
            file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
            """,
            knowledge_base_id, user_id, file_name, file_path, file_size, file_type
        )
        logger.info(f"创建文件记录: {file_name}, 知识库 ID: {knowledge_base_id}")
        return KnowledgeBaseFile(**dict(row))
    except ValueError:
        # Duplicate-name errors are meant for the caller; propagate unchanged.
        raise
    except Exception as e:
        logger.error(f"创建文件记录失败: {e}")
        # Chain explicitly so the original database error is kept as __cause__.
        raise Exception(f"创建文件记录失败: {str(e)}") from e
@staticmethod
async def update_file_status(
    conn: asyncpg.Connection,
    file_id: int,
    status: str,
    chunk_count: int = 0
) -> bool:
    """
    Set a file record's status and chunk count.

    Args:
        conn: Database connection.
        file_id: ID of the knowledge_base_file row to update.
        status: New status value (processing / completed / failed).
        chunk_count: Value written to chunk_count. It is always written, so
            omitting it resets the column to 0.

    Returns:
        bool: True when exactly one row was updated. False when no row
        matched or the query raised (the error is logged, not propagated).
    """
    update_sql = """
        UPDATE knowledge_base_file
        SET status = $1, chunk_count = $2
        WHERE id = $3
    """
    try:
        command_tag = await conn.execute(update_sql, status, chunk_count, file_id)
    except Exception as e:
        logger.error(f"更新文件状态失败: {e}")
        return False
    return command_tag == "UPDATE 1"
@staticmethod
async def save_chunks(
    conn: asyncpg.Connection,
    file_id: int,
    knowledge_base_id: int,
    chunks: List[Tuple[int, str, dict, str]],
    summary: Optional[str] = None
) -> int:
    """
    Batch-insert the document chunks of a file.

    Args:
        conn: Database connection.
        file_id: ID of the file the chunks belong to.
        knowledge_base_id: ID of the knowledge base.
        chunks: Items of the form (chunk_index, content, metadata, vector_id).
            metadata is serialized with json.dumps before insertion.
        summary: Optional file summary, stored on every chunk row.

    Returns:
        int: Number of chunks written (0 for an empty ``chunks``).

    Raises:
        Exception: If the insert fails, with the underlying error chained as
            ``__cause__``.
    """
    if not chunks:
        # Nothing to insert; skip the database round trip entirely.
        return 0
    try:
        # Every chunk row carries the summary so each chunk can be retrieved
        # on its own with the file-level context attached.
        records = [
            (file_id, knowledge_base_id, chunk_index, content, json.dumps(metadata), vector_id, summary)
            for chunk_index, content, metadata, vector_id in chunks
        ]
        await conn.executemany(
            """
            INSERT INTO knowledge_base_chunk
            (file_id, knowledge_base_id, chunk_index, content, metadata, vector_id, summary)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            """,
            records
        )
        # The previous log printed an empty string when there was no summary;
        # say "无" (none) explicitly.
        logger.info(f"保存 {len(chunks)} 个文档块,文件 ID: {file_id}, 摘要: {'已保存' if summary else '无'}")
        return len(chunks)
    except Exception as e:
        logger.error(f"保存文档块失败: {e}")
        raise Exception(f"保存文档块失败: {str(e)}") from e
@staticmethod
async def get_file_by_id(
    conn: asyncpg.Connection,
    file_id: int,
    user_id: int,
) -> Optional[KnowledgeBaseFile]:
    """
    Fetch a single non-deleted file record by its ID.

    Ownership is not checked here; permission checks are done by the route
    layer. ``user_id`` is kept only for compatibility with older call sites
    and is not used in the query.

    Returns:
        The matching KnowledgeBaseFile, or None when no live row exists or
        when the lookup fails (failures are logged and swallowed).
    """
    select_sql = """
        SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size,
        file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
        FROM knowledge_base_file
        WHERE id = $1 AND is_deleted = FALSE
    """
    try:
        record = await conn.fetchrow(select_sql, file_id)
        return None if record is None else KnowledgeBaseFile(**dict(record))
    except Exception as e:
        logger.error(f"获取文件失败: {e}")
        return None
@staticmethod
async def get_files_by_kb(
    conn: asyncpg.Connection,
    knowledge_base_id: int,
    user_id: int,
    page: int = 1,
    page_size: int = 20,
) -> Tuple[List[dict], int]:
    """
    List one page of a knowledge base's files, newest first, as raw dicts.

    The caller (route-layer ``_check_kb_access``) must verify KB access
    beforehand. File visibility follows KB visibility: anyone who can see the
    KB sees every file in it, so colleagues in the same department can view
    each other's uploads. ``user_id`` is kept for compatibility with older
    call sites and is not used as a filter.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the knowledge base.
        user_id: Unused; kept for backward compatibility.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.

    Returns:
        (rows, total). rows holds the page's row dicts, each with an
        ``uploader_name`` column (the uploader's display name, or the username
        when the display name is blank or NULL). total is the number of
        non-deleted files in the KB.

    Raises:
        Exception: If either query fails, with the underlying error chained as
            ``__cause__``.
    """
    try:
        # Clamp pagination input: PostgreSQL rejects negative LIMIT / OFFSET.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size
        total = await conn.fetchval(
            """
            SELECT COUNT(*) FROM knowledge_base_file
            WHERE knowledge_base_id = $1 AND is_deleted = FALSE
            """,
            knowledge_base_id,
        )
        rows = await conn.fetch(
            """
            SELECT f.id, f.knowledge_base_id, f.user_id, f.file_name, f.file_path,
            f.file_size, f.file_type, f.status, f.chunk_count,
            f.created_at, f.updated_at, f.is_deleted, f.deleted_at,
            COALESCE(NULLIF(TRIM(u.display_name),''), u.username) AS uploader_name
            FROM knowledge_base_file f
            LEFT JOIN user_list u ON u.id = f.user_id
            WHERE f.knowledge_base_id = $1 AND f.is_deleted = FALSE
            ORDER BY f.created_at DESC
            LIMIT $2 OFFSET $3
            """,
            knowledge_base_id, page_size, offset,
        )
        return [dict(r) for r in rows], int(total or 0)
    except Exception as e:
        logger.error(f"获取文件列表失败: {e}")
        raise Exception(f"获取文件列表失败: {str(e)}") from e
@staticmethod
async def get_file_vector_ids(
    conn: asyncpg.Connection,
    file_id: int
) -> List[str]:
    """
    Collect the vector IDs of all chunks belonging to a file.

    Args:
        conn: Database connection.
        file_id: ID of the file.

    Returns:
        List[str]: Non-empty vector IDs of the file's chunks. The list is
        empty when there are none or when the query fails (the error is
        logged, not raised).
    """
    try:
        fetched = await conn.fetch(
            """
            SELECT vector_id FROM knowledge_base_chunk
            WHERE file_id = $1 AND vector_id IS NOT NULL
            """,
            file_id
        )
        # filter(None, ...) drops falsy IDs such as empty strings.
        return list(filter(None, (record['vector_id'] for record in fetched)))
    except Exception as e:
        logger.error(f"获取文件向量 ID 失败: {e}")
        return []
@staticmethod
async def get_all_files_by_kb(
    conn: asyncpg.Connection,
    knowledge_base_id: int
) -> List[KnowledgeBaseFile]:
    """
    Fetch every file record of a knowledge base, including soft-deleted ones.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the knowledge base.

    Returns:
        List[KnowledgeBaseFile]: All matching records, in no particular order
        (the query has no ORDER BY). The list is empty if the query fails
        (the error is logged, not raised).
    """
    select_sql = """
        SELECT id, knowledge_base_id, user_id, file_name, file_path, file_size,
        file_type, status, chunk_count, created_at, updated_at, is_deleted, deleted_at
        FROM knowledge_base_file
        WHERE knowledge_base_id = $1
    """
    try:
        fetched = await conn.fetch(select_sql, knowledge_base_id)
        files = [KnowledgeBaseFile(**dict(record)) for record in fetched]
    except Exception as e:
        logger.error(f"获取知识库所有文件失败: {e}")
        return []
    return files
@staticmethod
async def get_kb_all_vector_ids(
    conn: asyncpg.Connection,
    knowledge_base_id: int
) -> List[str]:
    """
    Collect the vector IDs of every chunk in a knowledge base.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the knowledge base.

    Returns:
        List[str]: Non-empty vector IDs. The list is empty when there are none
        or when the query fails (the error is logged, not raised).
    """
    try:
        fetched = await conn.fetch(
            """
            SELECT vector_id FROM knowledge_base_chunk
            WHERE knowledge_base_id = $1 AND vector_id IS NOT NULL
            """,
            knowledge_base_id
        )
        # filter(None, ...) drops falsy IDs such as empty strings.
        return list(filter(None, (record['vector_id'] for record in fetched)))
    except Exception as e:
        logger.error(f"获取知识库向量 ID 失败: {e}")
        return []
@staticmethod
async def delete_file_chunks(
    conn: asyncpg.Connection,
    file_id: int
) -> int:
    """
    Hard-delete every document chunk that belongs to a file.

    Args:
        conn: Database connection.
        file_id: ID of the file whose chunks are removed.

    Returns:
        int: Number of rows removed, parsed from the command status tag
        returned by ``execute``. Returns 0 when the tag does not start with
        "DELETE" or when the statement fails (the error is logged, not raised).
    """
    try:
        status_tag = await conn.execute(
            """
            DELETE FROM knowledge_base_chunk
            WHERE file_id = $1
            """,
            file_id
        )
        deleted_count = 0
        if status_tag.startswith("DELETE"):
            # The affected-row count is the last whitespace-separated token.
            deleted_count = int(status_tag.split()[-1])
        logger.info(f"删除文件 {file_id}{deleted_count} 个文档块")
        return deleted_count
    except Exception as e:
        logger.error(f"删除文档块失败: {e}")
        return 0
@staticmethod
async def delete_kb_all_chunks(
    conn: asyncpg.Connection,
    knowledge_base_id: int
) -> int:
    """
    Hard-delete every document chunk of a knowledge base.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the knowledge base.

    Returns:
        int: Number of rows removed, parsed from the command status tag
        returned by ``execute``. Returns 0 when the tag does not start with
        "DELETE" or when the statement fails (the error is logged, not raised).
    """
    try:
        status_tag = await conn.execute(
            """
            DELETE FROM knowledge_base_chunk
            WHERE knowledge_base_id = $1
            """,
            knowledge_base_id
        )
        deleted_count = 0
        if status_tag.startswith("DELETE"):
            # The affected-row count is the last whitespace-separated token.
            deleted_count = int(status_tag.split()[-1])
        logger.info(f"删除知识库 {knowledge_base_id}{deleted_count} 个文档块")
        return deleted_count
    except Exception as e:
        logger.error(f"删除知识库文档块失败: {e}")
        return 0
@staticmethod
async def get_recent_files_with_summary(
    conn: asyncpg.Connection,
    knowledge_base_id: int,
    limit: int = 5
) -> List[dict]:
    """
    Return a knowledge base's most recently uploaded completed files with summaries.

    No time window is applied. The result holds the newest ``limit`` files
    that are not deleted and have status 'completed', newest first.

    Args:
        conn: Database connection.
        knowledge_base_id: ID of the knowledge base.
        limit: Maximum number of files to return.

    Returns:
        List[dict]: [{"file_id": int, "file_name": str, "summary": str}, ...].
        ``summary`` is the first non-NULL chunk summary of the file (by
        chunk_index), or "" when none exists. Returns [] if the query fails
        (the error is logged, not raised).
    """
    try:
        # DISTINCT ON (kbf.id) would force ORDER BY kbf.id first, so LIMIT
        # would keep the lowest-ID (oldest) files rather than the newest.
        # A LATERAL subquery picks one summary per file instead, leaving the
        # outer query free to sort by upload time. It also avoids joining
        # every chunk of every file.
        rows = await conn.fetch(
            """
            SELECT kbf.id, kbf.file_name, s.summary
            FROM knowledge_base_file kbf
            LEFT JOIN LATERAL (
                SELECT kbc.summary
                FROM knowledge_base_chunk kbc
                WHERE kbc.file_id = kbf.id AND kbc.summary IS NOT NULL
                ORDER BY kbc.chunk_index
                LIMIT 1
            ) s ON TRUE
            WHERE kbf.knowledge_base_id = $1
              AND kbf.is_deleted = FALSE
              AND kbf.status = 'completed'
            ORDER BY kbf.created_at DESC, kbf.id DESC
            LIMIT $2
            """,
            knowledge_base_id, limit
        )
        result = [
            {
                "file_id": row['id'],
                "file_name": row['file_name'],
                "summary": row['summary'] or ""
            }
            for row in rows
        ]
        logger.info(f"获取知识库 {knowledge_base_id}{len(result)} 个文件及摘要(无时间限制)")
        return result
    except Exception as e:
        logger.error(f"获取知识库文件摘要失败: {e}")
        return []
@staticmethod
async def get_file_chunks_from_db(
    conn: asyncpg.Connection,
    file_id: int
) -> List[dict]:
    """
    Load all chunks of a file from PostgreSQL, including their summaries.

    This is used to inject a file's complete content into the AI context.

    Args:
        conn: Database connection.
        file_id: ID of the file.

    Returns:
        List[dict]: [{"chunk_index": int, "content": str, "summary": str}, ...]
        ordered by chunk_index. A NULL summary becomes "". Returns [] if the
        query fails (the error is logged, not raised).
    """
    select_sql = """
        SELECT chunk_index, content, summary
        FROM knowledge_base_chunk
        WHERE file_id = $1
        ORDER BY chunk_index
    """
    try:
        fetched = await conn.fetch(select_sql, file_id)
        chunks = [
            {
                "chunk_index": record['chunk_index'],
                "content": record['content'],
                "summary": record['summary'] or '',
            }
            for record in fetched
        ]
        logger.info(f"从数据库获取知识库文件chunks: file_id={file_id}, chunks数量={len(chunks)}")
    except Exception as e:
        logger.error(f"从数据库获取知识库文件chunks失败: {e}")
        return []
    return chunks
@staticmethod
async def delete_file(
    conn: asyncpg.Connection,
    file_id: int,
    user_id: int,
    bypass_owner_check: bool = False,
) -> Tuple[bool, List[str]]:
    """
    Soft-delete a file and hard-delete all of its document chunks.

    Both changes run in a single transaction (a savepoint when ``conn`` is
    already inside one). The file row and its chunks are therefore never left
    out of sync: either both change or neither does.

    Args:
        conn: Database connection.
        file_id: ID of the file to delete.
        user_id: Requesting user. The file must be owned by this user unless
            ``bypass_owner_check`` is True.
        bypass_owner_check: Skip the ownership check (admin/leader path). The
            caller must already have done the can_delete_file permission check
            in the route layer.

    Returns:
        On success, (True, vector_ids), where vector_ids are the non-empty
        vector IDs of the removed chunks (for cleanup in the vector store).
        Returns (False, []) when no matching live file exists, or on any
        error; errors are logged and the transaction is rolled back.
    """
    try:
        async with conn.transaction():
            # Soft-delete first. The conditional UPDATE checks existence and
            # ownership and claims the row in one step. A concurrent delete
            # therefore cannot make us drop chunks and then report failure.
            if bypass_owner_check:
                file_record = await conn.fetchrow(
                    """
                    UPDATE knowledge_base_file
                    SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
                    WHERE id = $1 AND is_deleted = FALSE
                    RETURNING file_name
                    """,
                    file_id,
                )
            else:
                file_record = await conn.fetchrow(
                    """
                    UPDATE knowledge_base_file
                    SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
                    WHERE id = $1 AND user_id = $2 AND is_deleted = FALSE
                    RETURNING file_name
                    """,
                    file_id, user_id,
                )
            if file_record is None:
                return False, []
            # Delete the chunks and collect their vector IDs in one statement,
            # so the returned IDs are exactly those of the removed chunks.
            chunk_rows = await conn.fetch(
                """
                DELETE FROM knowledge_base_chunk
                WHERE file_id = $1
                RETURNING vector_id
                """,
                file_id,
            )
        vector_ids = [row['vector_id'] for row in chunk_rows if row['vector_id']]
        logger.info(
            f"删除文件 ID: {file_id}, 文件名: {file_record['file_name']}, "
            f"文档块数: {len(chunk_rows)}, 向量数: {len(vector_ids)}"
        )
        return True, vector_ids
    except Exception as e:
        logger.error(f"删除文件失败: {e}")
        return False, []