361 lines
12 KiB
Python
361 lines
12 KiB
Python
"""
知识库服务 — knowledge base service.
"""
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
import asyncpg
|
||
|
||
from core.permissions import can_manage_kb, can_view_kb, get_managed_dept_ids
|
||
from models.knowledge_base import KnowledgeBase, KnowledgeBaseCreate, KnowledgeBaseUpdate
|
||
from models.user import User
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger obtained from the project's ``logger.logging`` helper.
logger = get_logger(__name__)
|
||
|
||
|
||
def _kb_model_dump(kb: KnowledgeBase) -> Dict[str, Any]:
    """Serialize a knowledge base model into a plain dict.

    Uses ``model_dump()`` when the object provides it and falls back to
    ``dict()`` otherwise (presumably Pydantic v2 vs. v1 models — confirm).
    """
    if not hasattr(kb, "model_dump"):
        return kb.dict()
    return kb.model_dump()
|
||
|
||
_KB_FIELDS = """
|
||
id, user_id, enterprise_id, department_id, creator_id, visibility,
|
||
name, description, created_at, updated_at, is_deleted, deleted_at
|
||
"""
|
||
|
||
|
||
class KnowledgeBaseService:
    """Knowledge base service.

    Every method is static and takes an ``asyncpg`` connection explicitly.
    Visibility is one of ``private``, ``department`` or ``enterprise``.
    Read paths filter by enterprise, visibility and role; update/delete
    require ``can_manage_kb``.
    """

    @staticmethod
    def _is_owned_by(creator_id: Any, owner_user_id: Any, viewer_id: Any) -> bool:
        """Return True when ``viewer_id`` is the creator of a knowledge base.

        ``creator_id`` decides when it is set. When it is NULL, ownership
        falls back to the row's ``user_id``. A ``None`` viewer owns nothing.
        """
        if viewer_id is None:
            return False
        if creator_id is not None:
            return bool(creator_id == viewer_id)
        return bool(owner_user_id == viewer_id)

    @staticmethod
    async def enrich_kb_for_response(
        conn: asyncpg.Connection,
        kb: KnowledgeBase,
        viewer: User,
    ) -> Dict[str, Any]:
        """Add creator/department names and an ``is_mine`` flag for API output.

        Returns the serialized model plus:

        - ``creator_username``
        - ``creator_display_name``: the trimmed display name, or the username
          when it is blank.
        - ``department_name``
        - ``is_mine``

        The three name fields are None when the row is not found.
        """
        data = _kb_model_dump(kb)
        row = await conn.fetchrow(
            """
            SELECT u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM knowledge_base kb
            LEFT JOIN user_list u ON u.id = kb.creator_id
            LEFT JOIN department d ON d.id = kb.department_id
            WHERE kb.id = $1
            """,
            kb.id,
        )
        for key in ("creator_username", "creator_display_name", "department_name"):
            data[key] = row[key] if row else None
        data["is_mine"] = KnowledgeBaseService._is_owned_by(
            kb.creator_id, kb.user_id, viewer.id
        )
        return data

    @staticmethod
    def _validate_visibility(v: str) -> str:
        """Return ``v`` if it is a supported visibility value, else raise ValueError."""
        if v not in ("private", "department", "enterprise"):
            raise ValueError("visibility 必须是 private、department 或 enterprise")
        return v

    @staticmethod
    async def create_knowledge_base(
        conn: asyncpg.Connection,
        user: User,
        kb_data: KnowledgeBaseCreate,
    ) -> KnowledgeBase:
        """Create a knowledge base for ``user``.

        Records the user's enterprise and department, the creator and the
        requested visibility. A soft-deleted knowledge base of the same user
        with the same name is hard-deleted first so the name can be reused.

        Raises:
            ValueError: invalid visibility, user without an enterprise, or a
                duplicate name.
            Exception: any other failure, chained to the underlying error.
        """
        user_id = user.id
        vis = KnowledgeBaseService._validate_visibility(kb_data.visibility)
        enterprise_id = user.enterprise_id
        if enterprise_id is None:
            raise ValueError("用户未关联企业,无法创建知识库")

        try:
            # Run check/delete/insert atomically (a savepoint when the caller
            # already holds a transaction). The hard delete only sticks if the
            # insert succeeds, and a failed statement is rolled back before
            # the handlers below issue their own queries on ``conn``.
            async with conn.transaction():
                existing = await conn.fetchrow(
                    """
                    SELECT id FROM knowledge_base
                    WHERE user_id = $1 AND name = $2 AND is_deleted = FALSE
                    """,
                    user_id,
                    kb_data.name,
                )
                if existing:
                    raise ValueError(f"知识库名称 '{kb_data.name}' 已存在")

                deleted_existing = await conn.fetchrow(
                    """
                    SELECT id FROM knowledge_base
                    WHERE user_id = $1 AND name = $2 AND is_deleted = TRUE
                    """,
                    user_id,
                    kb_data.name,
                )
                if deleted_existing:
                    # Presumably the (user_id, name) unique constraint also
                    # covers soft-deleted rows, so the old row must go.
                    # TODO confirm against the schema.
                    logger.info(f"发现已删除的同名知识库 ID: {deleted_existing['id']},将彻底删除")
                    await conn.execute(
                        """
                        DELETE FROM knowledge_base
                        WHERE id = $1 AND user_id = $2 AND is_deleted = TRUE
                        """,
                        deleted_existing["id"],
                        user_id,
                    )

                row = await conn.fetchrow(
                    f"""
                    INSERT INTO knowledge_base (
                        user_id, enterprise_id, department_id, creator_id, visibility,
                        name, description
                    )
                    VALUES ($1, $2, $3, $4, $5, $6, $7)
                    RETURNING {_KB_FIELDS.strip()}
                    """,
                    user_id,
                    enterprise_id,
                    user.department_id,
                    user_id,
                    vis,
                    kb_data.name,
                    kb_data.description,
                )

            logger.info(f"用户 {user_id} 创建知识库: {kb_data.name}")
            return KnowledgeBase(**dict(row))
        except ValueError:
            raise
        except asyncpg.UniqueViolationError as e:
            error_msg = str(e)
            if "uk_user_knowledge_base_name" in error_msg or "user_id" in error_msg.lower():
                deleted_kb = await conn.fetchrow(
                    """
                    SELECT id FROM knowledge_base
                    WHERE user_id = $1 AND name = $2 AND is_deleted = TRUE
                    """,
                    user_id,
                    kb_data.name,
                )
                if deleted_kb:
                    raise ValueError(
                        f"知识库名称 '{kb_data.name}' 已被使用(已删除)。"
                        "请先彻底删除已删除的知识库,或使用其他名称。"
                    ) from e
                raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") from e
            logger.error(f"创建知识库时发生唯一约束冲突: {e}")
            raise Exception("创建知识库失败: 唯一约束冲突") from e
        except Exception as e:
            logger.error(f"创建知识库失败: {e}")
            raise Exception(f"创建知识库失败: {str(e)}") from e

    @staticmethod
    async def fetch_knowledge_base_by_id(
        conn: asyncpg.Connection,
        kb_id: int,
    ) -> Optional[KnowledgeBase]:
        """Load a non-deleted knowledge base by primary key (no permission check)."""
        row = await conn.fetchrow(
            f"""
            SELECT {_KB_FIELDS.strip()}
            FROM knowledge_base
            WHERE id = $1 AND is_deleted = FALSE
            """,
            kb_id,
        )
        if row:
            return KnowledgeBase(**dict(row))
        return None

    @staticmethod
    async def get_knowledge_base_by_id(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
    ) -> Optional[KnowledgeBase]:
        """Load a knowledge base if ``user`` may view it, per ``can_view_kb``.

        Returns None when it does not exist or is not visible to ``user``.
        """
        kb = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if kb is None:
            return None
        if not await can_view_kb(conn, user, kb):
            return None
        return kb

    @staticmethod
    async def list_visible_knowledge_bases(
        conn: asyncpg.Connection,
        user: User,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """List knowledge bases visible to ``user``, with creator/department names.

        Filtering happens in SQL. Within the user's enterprise, a row is
        visible when any of these holds:

        - the user is an admin;
        - the user created it;
        - the user is a leader and its department is among the managed
          departments;
        - it is department-visible and in the user's department;
        - it is enterprise-visible.

        Returns:
            ``(items, total)``. Each item dict carries the row columns plus
            ``creator_username``, ``creator_display_name``,
            ``department_name`` and ``is_mine``.
        """
        # Negative page / page_size would produce a negative OFFSET / LIMIT,
        # which PostgreSQL rejects.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        enterprise_id = user.enterprise_id
        if enterprise_id is None:
            return [], 0

        role = user.role or "employee"
        dept_id = user.department_id
        uid = user.id

        # Leaders see their own department and all descendant departments.
        managed_dept_ids: List[int] = []
        if role == "leader" and dept_id is not None:
            managed_dept_ids = await get_managed_dept_ids(conn, user)

        # NOTE(review): rows with creator_id NULL count as the viewer's own in
        # ``is_mine`` (via user_id) but are not matched by ``kb.creator_id = $3``
        # here. Confirm against can_view_kb whether they should be listed.
        where_sql = """
            kb.is_deleted = FALSE
            AND kb.enterprise_id = $1
            AND (
                $2::text = 'admin'
                OR kb.creator_id = $3
                OR ($2::text = 'leader' AND kb.department_id IS NOT NULL AND kb.department_id = ANY($4::int[]))
                OR (kb.visibility = 'department' AND kb.department_id IS NOT NULL AND kb.department_id = $5)
                OR (kb.visibility = 'enterprise')
            )
        """
        filter_args = (enterprise_id, role, uid, managed_dept_ids, dept_id)

        total = await conn.fetchval(
            f"""
            SELECT COUNT(*) FROM knowledge_base kb
            WHERE {where_sql}
            """,
            *filter_args,
        )

        # kb.id breaks ties on created_at so pages are stable.
        rows = await conn.fetch(
            f"""
            SELECT kb.id, kb.user_id, kb.enterprise_id, kb.department_id, kb.creator_id, kb.visibility,
                   kb.name, kb.description, kb.created_at, kb.updated_at,
                   u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM knowledge_base kb
            LEFT JOIN user_list u ON u.id = kb.creator_id
            LEFT JOIN department d ON d.id = kb.department_id
            WHERE {where_sql}
            ORDER BY kb.created_at DESC, kb.id DESC
            LIMIT $6 OFFSET $7
            """,
            *filter_args,
            page_size,
            offset,
        )

        items: List[Dict[str, Any]] = []
        for record in rows:
            item = dict(record)
            item["is_mine"] = KnowledgeBaseService._is_owned_by(
                item.get("creator_id"), item.get("user_id"), uid
            )
            items.append(item)
        return items, int(total or 0)

    @staticmethod
    async def update_knowledge_base(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
        kb_data: KnowledgeBaseUpdate,
    ) -> Optional[KnowledgeBase]:
        """Update a knowledge base (creator or enterprise admin only, per ``can_manage_kb``).

        Fields left as ``None`` in ``kb_data`` are not changed.

        Returns:
            The updated knowledge base. The unchanged one when no field was
            supplied. ``None`` when it does not exist or ``user`` may not
            manage it.

        Raises:
            ValueError: invalid visibility or a duplicate name.
        """
        existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if existing is None or not can_manage_kb(user, existing):
            return None

        # Validate input before issuing any further queries.
        if kb_data.visibility is not None:
            KnowledgeBaseService._validate_visibility(kb_data.visibility)

        assignments: List[str] = []
        params: List[Any] = []

        def add_assignment(column: str, value: Any) -> None:
            # Placeholder numbers follow the order in which values are appended.
            params.append(value)
            assignments.append(f"{column} = ${len(params)}")

        if kb_data.name is not None:
            conflict = await conn.fetchrow(
                """
                SELECT id FROM knowledge_base
                WHERE user_id = $1 AND name = $2 AND id != $3 AND is_deleted = FALSE
                """,
                existing.user_id,
                kb_data.name,
                kb_id,
            )
            if conflict:
                raise ValueError(f"知识库名称 '{kb_data.name}' 已存在")
            add_assignment("name", kb_data.name)

        if kb_data.description is not None:
            add_assignment("description", kb_data.description)

        if kb_data.visibility is not None:
            add_assignment("visibility", kb_data.visibility)

        if not assignments:
            return existing

        params.append(kb_id)
        query = f"""
            UPDATE knowledge_base
            SET {', '.join(assignments)}
            WHERE id = ${len(params)} AND is_deleted = FALSE
            RETURNING {_KB_FIELDS.strip()}
        """
        try:
            # Savepoint so a constraint violation does not abort the caller's
            # transaction before we translate it below.
            async with conn.transaction():
                row = await conn.fetchrow(query, *params)
        except asyncpg.UniqueViolationError as e:
            # The pre-check above ignores soft-deleted rows, while the unique
            # constraint presumably does not (see create_knowledge_base).
            # Report it the same way instead of leaking a driver error.
            if kb_data.name is None:
                raise
            raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") from e

        if row:
            logger.info(f"用户 {user.id} 更新知识库 {kb_id}")
            return KnowledgeBase(**dict(row))
        return None

    @staticmethod
    async def delete_knowledge_base(
        conn: asyncpg.Connection,
        kb_id: int,
        user: User,
    ) -> bool:
        """Soft-delete a knowledge base (creator or enterprise admin only, per ``can_manage_kb``).

        Returns True when exactly one row was marked deleted, else False.
        """
        existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id)
        if existing is None or not can_manage_kb(user, existing):
            return False

        result = await conn.execute(
            """
            UPDATE knowledge_base
            SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP
            WHERE id = $1 AND is_deleted = FALSE
            """,
            kb_id,
        )
        # ``execute`` returns the command status tag, e.g. "UPDATE 1".
        if result != "UPDATE 1":
            return False
        logger.info(f"用户 {user.id} 删除知识库 {kb_id}")
        return True
|