huoyan-enterprise/backend/services/knowledge_graph_service.py

195 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识图谱元数据:列表/详情与知识库一致的可见性与 RBAC。
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import asyncpg
from core.graph_metadata import graph_table_sql
from core.permissions import can_manage_graph, can_view_graph, get_managed_dept_ids
from models.graph_metadata import GraphRecord
from models.user import User
from logger.logging import get_logger
# Module-level logger for this service, obtained from the project's logging helper.
logger = get_logger(__name__)
class KnowledgeGraphService:
    """Knowledge-graph metadata queries (list / detail) with RBAC-based visibility.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    """

    # Allowed values for a graph's ``visibility`` column.
    _VISIBILITIES: Tuple[str, ...] = ("private", "department", "enterprise")

    @staticmethod
    def _validate_visibility(v: str) -> str:
        """Return *v* unchanged if it is an allowed visibility value.

        Raises:
            ValueError: if *v* is not ``private``, ``department`` or ``enterprise``.
        """
        if v not in KnowledgeGraphService._VISIBILITIES:
            raise ValueError("visibility 必须是 private、department 或 enterprise")
        return v

    @staticmethod
    def _row_to_graph_record(row: Dict[str, Any]) -> GraphRecord:
        """Build the ``GraphRecord`` consumed by the permission helpers.

        ``id`` and ``user_id`` are required and coerced with ``int()``, so a
        missing key raises ``KeyError`` and a NULL value raises ``TypeError``.
        A missing or empty ``visibility`` falls back to ``"private"``.
        """
        return GraphRecord(
            id=int(row["id"]),
            user_id=int(row["user_id"]),
            enterprise_id=row.get("enterprise_id"),
            department_id=row.get("department_id"),
            creator_id=row.get("creator_id"),
            visibility=(row.get("visibility") or "private"),
        )

    @staticmethod
    def _is_owner(
        creator_id: Optional[int],
        owner_user_id: Any,
        viewer_id: Optional[int],
    ) -> bool:
        """Return True when *viewer_id* owns the graph.

        ``creator_id`` takes precedence when set. Rows without a creator fall
        back to the legacy ``user_id`` column. A ``None`` viewer never owns.
        """
        if viewer_id is None:
            return False
        if creator_id is not None:
            return creator_id == viewer_id
        return owner_user_id is not None and int(owner_user_id) == viewer_id

    @staticmethod
    def _normalize_paging(page: int, page_size: int) -> Tuple[int, int]:
        """Clamp 1-based *page* and *page_size* to at least 1; return ``(limit, offset)``.

        Postgres rejects a negative LIMIT or OFFSET, so out-of-range input from
        callers is clamped here instead of failing in the database.
        """
        size = max(page_size, 1)
        return size, (max(page, 1) - 1) * size

    @staticmethod
    async def enrich_graph_for_response(
        conn: asyncpg.Connection,
        raw: Dict[str, Any],
        viewer: User,
    ) -> Dict[str, Any]:
        """Return a copy of *raw* with display and permission fields added.

        Adds the following fields:

        - ``creator_username``
        - ``creator_display_name`` (the trimmed display name, falling back to
          the username when blank)
        - ``department_name``
        - ``is_mine``
        - ``can_manage``

        The three joined fields are ``None`` when no graph row matches
        ``raw["id"]``. *raw* itself is not mutated.
        """
        data = dict(raw)
        t = graph_table_sql()
        row = await conn.fetchrow(
            f"""
            SELECT u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM {t} g
            LEFT JOIN user_list u ON u.id = g.creator_id
            LEFT JOIN department d ON d.id = g.department_id
            WHERE g.id = $1
            """,
            raw.get("id"),
        )
        for key in ("creator_username", "creator_display_name", "department_name"):
            data[key] = row[key] if row else None
        gr = KnowledgeGraphService._row_to_graph_record(data)
        data["is_mine"] = KnowledgeGraphService._is_owner(
            gr.creator_id, data.get("user_id"), viewer.id
        )
        data["can_manage"] = can_manage_graph(viewer, gr)
        return data

    @staticmethod
    async def list_visible_graphs(
        conn: asyncpg.Connection,
        user: User,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Return ``(items, total)`` for the graphs *user* may see, newest first.

        A graph is listed when it belongs to the user's enterprise and at least
        one of these holds:

        - the user is ``admin``;
        - the user created it;
        - the user is a ``leader`` and the graph's department is in
          ``get_managed_dept_ids``;
        - it is ``department``-visible and in the user's own department;
        - it is ``enterprise``-visible.

        Users without an enterprise get ``([], 0)``. *page* and *page_size*
        below 1 are clamped to 1. Each item also carries ``is_mine`` and
        ``can_manage``.
        """
        t = graph_table_sql()
        enterprise_id = user.enterprise_id
        if enterprise_id is None:
            return [], 0
        limit, offset = KnowledgeGraphService._normalize_paging(page, page_size)
        role = user.role or "employee"
        dept_id = user.department_id
        uid = user.id
        # A leader also sees graphs of their own department and all descendant
        # departments, matching the knowledge-base listing.
        managed_dept_ids: List[int] = []
        if role == "leader" and dept_id is not None:
            managed_dept_ids = await get_managed_dept_ids(conn, user)
        # NOTE(review): ownership below is only ``g.creator_id = $3``, whereas
        # ``is_mine`` falls back to ``user_id`` when ``creator_id`` is NULL.
        # A private graph with NULL creator_id is therefore not listed for its
        # ``user_id`` owner (unless admin). Confirm against ``can_view_graph``
        # before widening this.
        where_sql = """
            g.enterprise_id = $1
            AND (
                $2::text = 'admin'
                OR g.creator_id = $3
                OR ($2::text = 'leader' AND g.department_id IS NOT NULL AND g.department_id = ANY($4::int[]))
                OR (g.visibility = 'department' AND g.department_id IS NOT NULL AND g.department_id = $5)
                OR (g.visibility = 'enterprise')
            )
        """
        # $1..$5 are shared by the count and page queries.
        params = (enterprise_id, role, uid, managed_dept_ids, dept_id)
        total = await conn.fetchval(
            f"""
            SELECT COUNT(*) FROM {t} g
            WHERE {where_sql}
            """,
            *params,
        )
        # ``g.id`` breaks created_at ties so LIMIT/OFFSET pages are stable.
        rows = await conn.fetch(
            f"""
            SELECT g.id, g.user_id, g.enterprise_id, g.department_id, g.creator_id, g.visibility,
                   g.name, g.description, g.csv_file_name, g.node_count, g.edge_count, g.neo4j_graph_id,
                   g.graph_type, g.build_status, g.build_error, g.rag_chunk_count,
                   g.created_at, g.updated_at,
                   u.username AS creator_username,
                   COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name,
                   d.name AS department_name
            FROM {t} g
            LEFT JOIN user_list u ON u.id = g.creator_id
            LEFT JOIN department d ON d.id = g.department_id
            WHERE {where_sql}
            ORDER BY g.created_at DESC, g.id DESC
            LIMIT $6 OFFSET $7
            """,
            *params,
            limit,
            offset,
        )
        items: List[Dict[str, Any]] = []
        for r in rows:
            item = dict(r)
            gr = KnowledgeGraphService._row_to_graph_record(item)
            item["is_mine"] = KnowledgeGraphService._is_owner(
                gr.creator_id, item.get("user_id"), uid
            )
            item["can_manage"] = can_manage_graph(user, gr)
            items.append(item)
        return items, int(total or 0)

    @staticmethod
    async def fetch_graph_by_id(conn: asyncpg.Connection, graph_pk: int) -> Optional[Dict[str, Any]]:
        """Return the raw graph row with primary key *graph_pk* as a dict, or None.

        No permission check is done here; use ``get_graph_for_viewer`` for that.
        """
        t = graph_table_sql()
        row = await conn.fetchrow(
            f"""
            SELECT * FROM {t}
            WHERE id = $1
            """,
            graph_pk,
        )
        return dict(row) if row else None

    @staticmethod
    async def get_graph_for_viewer(
        conn: asyncpg.Connection,
        graph_pk: int,
        user: User,
    ) -> Optional[Dict[str, Any]]:
        """Return the enriched graph *graph_pk* if *user* may view it, else None.

        Returns None in each of these cases:

        - the row does not exist;
        - the row cannot be converted into a ``GraphRecord``;
        - ``can_view_graph`` denies access.
        """
        raw = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_pk)
        if raw is None:
            return None
        try:
            gr = KnowledgeGraphService._row_to_graph_record(raw)
        except Exception as exc:
            # Best-effort: a malformed row is reported as "not found", but it
            # is logged so data problems do not stay invisible.
            logger.warning(
                f"Malformed knowledge graph row (id={graph_pk}); treating as not found: {exc!r}"
            )
            return None
        if not await can_view_graph(conn, user, gr):
            return None
        return await KnowledgeGraphService.enrich_graph_for_response(conn, raw, user)