# huoyan-enterprise/backend/services/knowledge_base_service.py
#
# Web-viewer metadata captured with this file: 361 lines, 12 KiB, Python.
# The viewer warned that the file contains Unicode characters that might be
# confused with other (ambiguous) characters; if that is intentional, the
# warning can safely be ignored.

"""
知识库服务
"""
from typing import Any, Dict, List, Optional, Tuple
import asyncpg
from core.permissions import can_manage_kb, can_view_kb, get_managed_dept_ids
from models.knowledge_base import KnowledgeBase, KnowledgeBaseCreate, KnowledgeBaseUpdate
from models.user import User
from logger.logging import get_logger
# Module-level logger created through the project's ``get_logger`` helper, keyed by module name.
logger = get_logger(__name__)
def _kb_model_dump(kb: KnowledgeBase) -> Dict[str, Any]:
    """Convert a knowledge-base model into a plain dict.

    Prefers ``model_dump()`` and falls back to ``.dict()`` when the model has
    no ``model_dump`` attribute (presumably Pydantic v2 vs. v1 APIs).
    """
    if not hasattr(kb, "model_dump"):
        # Older model API: only ``.dict()`` is available.
        return kb.dict()
    return kb.model_dump()
# Column list shared by every query that returns a full ``knowledge_base`` row;
# callers ``.strip()`` it before interpolating it into SQL.
_KB_FIELDS = (
    "\n"
    "id, user_id, enterprise_id, department_id, creator_id, visibility,\n"
    "name, description, created_at, updated_at, is_deleted, deleted_at\n"
)
class KnowledgeBaseService:
    """Knowledge-base service: create, read, list, update and soft-delete.

    Every method is static and works on an ``asyncpg`` connection supplied by
    the caller; the class keeps no state of its own.
    """

    # Values accepted for ``knowledge_base.visibility``.
    _VISIBILITIES = ("private", "department", "enterprise")

    @staticmethod
    def _is_owned_by(creator_id: Any, owner_user_id: Any, viewer_id: Any) -> bool:
        """Return True when ``viewer_id`` counts as the knowledge base's creator.

        ``creator_id`` decides when it is set; when it is NULL the row's
        ``user_id`` is used instead (presumably rows written before
        ``creator_id`` existed -- confirm against the schema history).
        A viewer without an id never owns anything.
        """
        if viewer_id is None:
            return False
        if creator_id is not None:
            return bool(creator_id == viewer_id)
        return bool(owner_user_id == viewer_id)

    @staticmethod
    async def enrich_kb_for_response(
        conn: asyncpg.Connection,
        kb: KnowledgeBase,
        viewer: User,
    ) -> Dict[str, Any]:
        """Serialize ``kb`` for an API response, adding display information.

        Adds ``creator_username``, ``creator_display_name`` (the trimmed
        ``display_name``, or ``username`` when that is NULL/blank),
        ``department_name`` -- all three ``None`` if the knowledge-base row is
        not found -- and ``is_mine`` (whether ``viewer`` created it).
        """
        data = _kb_model_dump(kb)
        row = await conn.fetchrow(
            """
            SELECT u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM knowledge_base kb
            LEFT JOIN user_list u ON u.id = kb.creator_id
            LEFT JOIN department d ON d.id = kb.department_id
            WHERE kb.id = $1
            """,
            kb.id,
        )
        for key in ("creator_username", "creator_display_name", "department_name"):
            data[key] = row[key] if row else None
        data["is_mine"] = KnowledgeBaseService._is_owned_by(
            kb.creator_id, kb.user_id, viewer.id
        )
        return data

    @staticmethod
    def _validate_visibility(v: str) -> str:
        """Return ``v`` unchanged if it is a supported visibility value.

        Raises:
            ValueError: if ``v`` is not ``private``, ``department`` or ``enterprise``.
        """
        if v not in KnowledgeBaseService._VISIBILITIES:
            raise ValueError("visibility 必须是 private、department 或 enterprise")
        return v

    @staticmethod
    async def _fetch_deleted_same_name(
        conn: asyncpg.Connection,
        user_id: Any,
        name: str,
    ) -> Optional[asyncpg.Record]:
        """Return the ``id`` row of ``user_id``'s soft-deleted KB called ``name``, if any."""
        return await conn.fetchrow(
            """
            SELECT id FROM knowledge_base
            WHERE user_id = $1 AND name = $2 AND is_deleted = TRUE
            """,
            user_id,
            name,
        )

    @staticmethod
    def _is_user_name_conflict(exc: Exception) -> bool:
        """Decide whether a unique violation came from the per-user name constraint.

        Checks asyncpg's structured ``constraint_name`` first, then falls back
        to matching the error message text.
        """
        constraint = getattr(exc, "constraint_name", None) or ""
        message = str(exc)
        return (
            constraint == "uk_user_knowledge_base_name"
            or "uk_user_knowledge_base_name" in message
            or "user_id" in message.lower()
        )

    @staticmethod
    async def create_knowledge_base(
        conn: asyncpg.Connection,
        user: User,
        kb_data: KnowledgeBaseCreate,
    ) -> KnowledgeBase:
        """Create a knowledge base for ``user``.

        The row stores the user's enterprise and department, the user as both
        ``user_id`` and ``creator_id``, and the requested visibility. If the
        user has a soft-deleted knowledge base with the same name, that row is
        hard-deleted first so the name can be reused.

        Raises:
            ValueError: invalid visibility, user not linked to an enterprise,
                or the name is already used by one of the user's knowledge bases.
            Exception: any other failure, wrapped with a "创建知识库失败"
                message; the original error is chained as ``__cause__``.
        """
        user_id = user.id
        vis = KnowledgeBaseService._validate_visibility(kb_data.visibility)
        enterprise_id = user.enterprise_id
        if enterprise_id is None:
            raise ValueError("用户未关联企业,无法创建知识库")
        try:
            existing = await conn.fetchrow(
                """
                SELECT id FROM knowledge_base
                WHERE user_id = $1 AND name = $2 AND is_deleted = FALSE
                """,
                user_id,
                kb_data.name,
            )
            if existing:
                raise ValueError(f"知识库名称 '{kb_data.name}' 已存在")
            deleted_existing = await KnowledgeBaseService._fetch_deleted_same_name(
                conn, user_id, kb_data.name
            )
            if deleted_existing:
                # NOTE(review): this DELETE and the INSERT below are separate
                # statements; unless the caller wraps this call in a transaction,
                # a failed INSERT leaves the soft-deleted row already removed.
                logger.info(f"发现已删除的同名知识库 ID: {deleted_existing['id']},将彻底删除")
                await conn.execute(
                    """
                    DELETE FROM knowledge_base
                    WHERE id = $1 AND user_id = $2 AND is_deleted = TRUE
                    """,
                    deleted_existing["id"],
                    user_id,
                )
            row = await conn.fetchrow(
                f"""
                INSERT INTO knowledge_base (
                    user_id, enterprise_id, department_id, creator_id, visibility,
                    name, description
                )
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                RETURNING {_KB_FIELDS.strip()}
                """,
                user_id,
                enterprise_id,
                user.department_id,
                user_id,
                vis,
                kb_data.name,
                kb_data.description,
            )
            logger.info(f"用户 {user_id} 创建知识库: {kb_data.name}")
            return KnowledgeBase(**dict(row))
        except ValueError:
            raise
        except asyncpg.UniqueViolationError as e:
            if KnowledgeBaseService._is_user_name_conflict(e):
                # A concurrent insert won the race. If ``conn`` is inside a
                # transaction, the failed INSERT aborted it and this lookup will
                # fail as well; fall back to the generic "already exists" error
                # rather than leaking that secondary database error.
                try:
                    deleted_kb = await KnowledgeBaseService._fetch_deleted_same_name(
                        conn, user_id, kb_data.name
                    )
                except asyncpg.PostgresError:
                    deleted_kb = None
                if deleted_kb:
                    raise ValueError(
                        f"知识库名称 '{kb_data.name}' 已被使用(已删除)。"
                        "请先彻底删除已删除的知识库,或使用其他名称。"
                    ) from e
                raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") from e
            logger.error(f"创建知识库时发生唯一约束冲突: {e}")
            raise Exception("创建知识库失败: 唯一约束冲突") from e
        except Exception as e:
            logger.error(f"创建知识库失败: {e}")
            raise Exception(f"创建知识库失败: {str(e)}") from e

    @staticmethod
    async def fetch_knowledge_base_by_id(
        conn: asyncpg.Connection,
        kb_id: int,
    ) -> Optional[KnowledgeBase]:
        """Load a non-deleted knowledge base by primary key (no permission check)."""
        row = await conn.fetchrow(
            f"""
            SELECT {_KB_FIELDS.strip()}
            FROM knowledge_base
            WHERE id = $1 AND is_deleted = FALSE
            """,
            kb_id,
        )
        return KnowledgeBase(**dict(row)) if row else None

    @staticmethod
    async def get_knowledge_base_by_id(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
    ) -> Optional[KnowledgeBase]:
        """Load a knowledge base if it exists and ``can_view_kb`` allows ``user`` to see it.

        Returns ``None`` both when the row is missing/deleted and when access
        is denied, so callers cannot distinguish the two cases.
        """
        kb = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if kb is None or not await can_view_kb(conn, user, kb):
            return None
        return kb

    @staticmethod
    async def list_visible_knowledge_bases(
        conn: asyncpg.Connection,
        user: User,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """List the knowledge bases ``user`` may see, newest first, one page at a time.

        Filtering happens in SQL and is limited to the user's enterprise. A
        non-deleted row is visible when any of these holds:

        * the user's role is ``admin``;
        * the user is the row's ``creator_id``;
        * the user is a ``leader`` and the row's department is among the IDs
          returned by ``get_managed_dept_ids``;
        * the row has ``department`` visibility and the user's department;
        * the row has ``enterprise`` visibility.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: rows per page; negative values are treated as 0.

        Returns:
            ``(items, total)``. Each item holds the knowledge-base columns plus
            ``creator_username``, ``creator_display_name``, ``department_name``
            and ``is_mine``; ``total`` counts all visible rows, not just this
            page. A user without an enterprise gets ``([], 0)``.
        """
        # PostgreSQL rejects a negative LIMIT or OFFSET, so clamp paging input.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size
        enterprise_id = user.enterprise_id
        if enterprise_id is None:
            return [], 0
        role = user.role or "employee"
        dept_id = user.department_id
        uid = user.id
        # Leaders filter by their own department plus all descendant departments.
        managed_dept_ids: List[int] = []
        if role == "leader" and dept_id is not None:
            managed_dept_ids = await get_managed_dept_ids(conn, user)
        # Parameters: $1 enterprise, $2 role, $3 user id, $4 managed dept ids,
        # $5 user's department.
        # NOTE(review): ``is_mine`` treats rows with a NULL creator_id as owned
        # via user_id, but this filter only matches ``kb.creator_id = $3``, so
        # such a private row is not listed for its owner. Left unchanged so the
        # list stays consistent with ``can_view_kb``; confirm the intended rule.
        where_sql = """
            kb.is_deleted = FALSE
            AND kb.enterprise_id = $1
            AND (
                $2::text = 'admin'
                OR kb.creator_id = $3
                OR ($2::text = 'leader' AND kb.department_id IS NOT NULL AND kb.department_id = ANY($4::int[]))
                OR (kb.visibility = 'department' AND kb.department_id IS NOT NULL AND kb.department_id = $5)
                OR (kb.visibility = 'enterprise')
            )
        """
        filter_args = (enterprise_id, role, uid, managed_dept_ids, dept_id)
        total = await conn.fetchval(
            f"""
            SELECT COUNT(*) FROM knowledge_base kb
            WHERE {where_sql}
            """,
            *filter_args,
        )
        rows = await conn.fetch(
            f"""
            SELECT kb.id, kb.user_id, kb.enterprise_id, kb.department_id, kb.creator_id, kb.visibility,
                   kb.name, kb.description, kb.created_at, kb.updated_at,
                   u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM knowledge_base kb
            LEFT JOIN user_list u ON u.id = kb.creator_id
            LEFT JOIN department d ON d.id = kb.department_id
            WHERE {where_sql}
            ORDER BY kb.created_at DESC
            LIMIT $6 OFFSET $7
            """,
            *filter_args,
            page_size,
            offset,
        )
        items: List[Dict[str, Any]] = []
        for r in rows:
            item = dict(r)
            item["is_mine"] = KnowledgeBaseService._is_owned_by(
                item.get("creator_id"), item.get("user_id"), uid
            )
            items.append(item)
        return items, int(total or 0)

    @staticmethod
    async def update_knowledge_base(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
        kb_data: KnowledgeBaseUpdate,
    ) -> Optional[KnowledgeBase]:
        """Apply the non-None fields of ``kb_data`` to knowledge base ``kb_id``.

        Permission is decided by ``can_manage_kb`` (documented as: the creator
        or an enterprise admin). ``updated_at`` is refreshed whenever at least
        one field changes.

        Returns:
            The updated knowledge base; the unchanged one when ``kb_data`` sets
            no fields; ``None`` when the row is missing/deleted or ``user`` may
            not manage it.

        Raises:
            ValueError: invalid visibility, or the new name is already used by
                another live knowledge base of the same owner.
        """
        existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if existing is None or not can_manage_kb(user, existing):
            return None

        update_fields: List[str] = []
        params: List[Any] = []

        def _set(column: str, value: Any) -> None:
            # Placeholder numbers follow the order in which values are appended.
            params.append(value)
            update_fields.append(f"{column} = ${len(params)}")

        if kb_data.name is not None:
            conflict = await conn.fetchrow(
                """
                SELECT id FROM knowledge_base
                WHERE user_id = $1 AND name = $2 AND id != $3 AND is_deleted = FALSE
                """,
                existing.user_id,
                kb_data.name,
                kb_id,
            )
            if conflict:
                raise ValueError(f"知识库名称 '{kb_data.name}' 已存在")
            _set("name", kb_data.name)
        if kb_data.description is not None:
            _set("description", kb_data.description)
        if kb_data.visibility is not None:
            KnowledgeBaseService._validate_visibility(kb_data.visibility)
            _set("visibility", kb_data.visibility)
        if not update_fields:
            return existing

        # Keep updated_at in step with the edit; if a DB trigger also maintains
        # it, this is redundant but harmless.
        update_fields.append("updated_at = CURRENT_TIMESTAMP")
        params.append(kb_id)
        query = f"""
            UPDATE knowledge_base
            SET {', '.join(update_fields)}
            WHERE id = ${len(params)} AND is_deleted = FALSE
            RETURNING {_KB_FIELDS.strip()}
        """
        try:
            row = await conn.fetchrow(query, *params)
        except asyncpg.UniqueViolationError as e:
            # A concurrent rename can slip past the pre-check above; report it
            # the same way as the pre-check does.
            if kb_data.name is None:
                raise
            raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") from e
        if row is None:
            return None
        logger.info(f"用户 {user.id} 更新知识库 {kb_id}")
        return KnowledgeBase(**dict(row))

    @staticmethod
    async def delete_knowledge_base(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
    ) -> bool:
        """Soft-delete knowledge base ``kb_id`` by setting ``is_deleted``/``deleted_at``.

        Permission is decided by ``can_manage_kb`` (documented as: the creator
        or an enterprise admin).

        Returns:
            True if exactly one live row was marked deleted; False when the row
            is missing, already deleted, or ``user`` may not manage it.
        """
        existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if existing is None or not can_manage_kb(user, existing):
            return False
        result = await conn.execute(
            """
            UPDATE knowledge_base
            SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
            WHERE id = $1 AND is_deleted = FALSE
            """,
            kb_id,
        )
        # asyncpg returns the command status tag, e.g. "UPDATE 1".
        if result != "UPDATE 1":
            return False
        logger.info(f"用户 {user.id} 删除知识库 {kb_id}")
        return True