350 lines
13 KiB
Python
350 lines
13 KiB
Python
"""
|
||
Neo4j:知识图谱(:Person + RELATION,按 graph_id 隔离)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from neo4j import GraphDatabase
|
||
|
||
from core.config import settings
|
||
|
||
# Module-level logger named after this module (stdlib logging: use %-style args).
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _get_driver():
    """Build a new Neo4j driver from the configured URI and credentials.

    Every call constructs a fresh driver; the caller is responsible for
    calling ``close()`` on it when done.
    """
    credentials = (settings.neo4j_user, settings.neo4j_password)
    return GraphDatabase.driver(settings.neo4j_uri, auth=credentials)
|
||
|
||
|
||
def check_neo4j_health() -> dict[str, Any]:
    """Probe Neo4j connectivity and count the stored :Person nodes.

    Returns:
        ``{"status": "ok", "person_nodes": <int>}`` on success, or
        ``{"status": "degraded", "error": <message>}`` if creating the driver,
        connecting, querying or closing fails. Never raises.
    """
    try:
        driver = _get_driver()
        try:
            driver.verify_connectivity()
            with driver.session() as session:
                person_n = session.run("MATCH (n:Person) RETURN count(n) AS c").single()["c"]
        finally:
            # Close on every path; previously a failed connect/query leaked the driver.
            driver.close()
        return {"status": "ok", "person_nodes": int(person_n)}
    except Exception as e:
        # Best-effort health probe: report "degraded" instead of propagating.
        # stdlib logging interpolates args with %-style placeholders, not "{}".
        logger.warning("Neo4j health check failed: %s", e)
        return {"status": "degraded", "error": str(e)}
|
||
|
||
|
||
# ----- 知识图谱(文本抽取):Person 节点 + RELATION 边(按 graph_id 隔离) -----
|
||
|
||
|
||
def _import_knowledge_graph_batch(tx, batch: list[dict], graph_id: str):
    """Upsert one batch of triplets inside a write transaction.

    ``tx`` is the transaction passed in by ``Session.execute_write``. For each
    row, the subject and object :Person nodes are MERGEd (scoped by
    ``graph_id``). A RELATION edge carrying ``type``, ``note`` and ``graph_id``
    is then MERGEd between them. Returns ``None``.
    """
    cypher = """
        UNWIND $rows AS row
        MERGE (a:Person {name: row.subject, graph_id: $graph_id})
        MERGE (b:Person {name: row.object, graph_id: $graph_id})
        MERGE (a)-[r:RELATION {type: row.relation_type, note: row.note, graph_id: $graph_id}]->(b)
    """
    tx.run(cypher, rows=batch, graph_id=graph_id)
|
||
|
||
|
||
def _as_clean_text(value: Any) -> str:
    """Return ``value`` as a stripped string; falsy values (``None``, ``""``) become ``""``."""
    return str(value or "").strip()


def _normalize_triplet_rows(rows: list[dict] | None) -> list[dict]:
    """Clean raw triplet dicts into the row shape consumed by the batch import.

    ``relation_type`` falls back to the ``relation`` key, then to "相关"; it is
    truncated to 120 characters and ``note`` to 500. Rows lacking a subject or
    an object, or whose subject equals the object, are dropped.
    """
    norm: list[dict] = []
    for r in rows or []:
        s = _as_clean_text(r.get("subject"))
        o = _as_clean_text(r.get("object"))
        if not s or not o or s == o:
            continue
        rel = _as_clean_text(r.get("relation_type") or r.get("relation")) or "相关"
        note = _as_clean_text(r.get("note"))
        norm.append({"subject": s, "object": o, "relation_type": rel[:120], "note": note[:500]})
    return norm


def import_knowledge_graph_triplets(rows: list[dict], graph_id: str) -> dict[str, Any]:
    """Import entity-relation triplets into Neo4j, replacing the graph under ``graph_id``.

    Each row must contain ``subject``, ``relation_type`` (or ``relation``) and
    ``object``; ``note`` is optional. Existing :Person nodes with this
    ``graph_id`` (and their relationships) are always deleted first. If no row
    survives normalization, the graph is therefore left empty and zero counts
    are returned (useful for material meant only for vector retrieval).

    Returns:
        ``{"graph_id", "node_count", "edge_count", "rows"}``, where ``rows`` is
        the number of normalized triplets written.
    """
    norm = _normalize_triplet_rows(rows)
    batch_size = 80

    driver = _get_driver()
    try:
        # Inside the try so the driver is closed even when the server is unreachable.
        driver.verify_connectivity()
        with driver.session() as session:
            # NOTE: the delete and each batch commit in separate transactions, so a
            # failure mid-import leaves this graph_id partially rebuilt.
            session.run(
                "MATCH (n:Person {graph_id: $graph_id}) DETACH DELETE n",
                graph_id=graph_id,
            )
            if not norm:
                return {
                    "graph_id": graph_id,
                    "node_count": 0,
                    "edge_count": 0,
                    "rows": 0,
                }
            for i in range(0, len(norm), batch_size):
                session.execute_write(_import_knowledge_graph_batch, norm[i : i + batch_size], graph_id)

            node_count = session.run(
                "MATCH (n:Person {graph_id: $graph_id}) RETURN count(n) AS c",
                graph_id=graph_id,
            ).single()["c"]
            edge_count = session.run(
                "MATCH ()-[r:RELATION {graph_id: $graph_id}]->() RETURN count(r) AS c",
                graph_id=graph_id,
            ).single()["c"]
    finally:
        driver.close()

    return {
        "graph_id": graph_id,
        "node_count": int(node_count),
        "edge_count": int(edge_count),
        "rows": len(norm),
    }
|
||
|
||
|
||
def delete_knowledge_graph(graph_id: str) -> None:
    """Delete every :Person node tagged with ``graph_id`` together with its relationships."""
    cypher = "MATCH (n:Person {graph_id: $graph_id}) DETACH DELETE n"
    neo = _get_driver()
    try:
        with neo.session() as sess:
            sess.run(cypher, graph_id=graph_id)
    finally:
        neo.close()
|
||
|
||
|
||
def _knowledge_graph_node_color() -> str:
    """Return the fixed hex color stored in every node element's ``color`` field."""
    node_color = "#5BB5A2"
    return node_color
|
||
|
||
|
||
def _kg_node_element(name: str, color: str) -> dict:
    """Build the element for one :Person node; ``degree`` is filled in afterwards."""
    return {
        "data": {
            "id": name,
            "label": name,
            "name": name,
            "color": color,
            "degree": 0,
        }
    }


def _kg_edge_element(source: str, target: str, rel) -> dict:
    """Build the element for one RELATION edge from ``source`` to ``target``."""
    return {
        "data": {
            "id": f"{source}->{target}",
            "source": source,
            "target": target,
            "label": (rel.get("type") or "")[:100],
            "type": rel.get("type", ""),
            "note": rel.get("note", ""),
        }
    }


def _collect_kg_elements(records, color: str) -> tuple[list[dict], set[str]]:
    """Turn ``(a, rel, b)`` query records into de-duplicated node/edge elements.

    Each record must expose the keys ``"a"``, ``"rel"`` and ``"b"``. Every node
    name is emitted once, and every directed ``(source, target)`` pair once (the
    first relationship seen for a pair wins). Elements keep first-seen order.

    Returns:
        ``(elements, node_names)``, where ``node_names`` is the set of emitted node ids.
    """
    elements: list[dict] = []
    seen_nodes: set[str] = set()
    seen_edges: set[tuple[str, str]] = set()
    for record in records:
        a, rel, b = record["a"], record["rel"], record["b"]
        aid, bid = a["name"], b["name"]
        for nid in (aid, bid):
            if nid not in seen_nodes:
                seen_nodes.add(nid)
                elements.append(_kg_node_element(nid, color))
        if (aid, bid) not in seen_edges:
            seen_edges.add((aid, bid))
            elements.append(_kg_edge_element(aid, bid, rel))
    return elements, seen_nodes


def _annotate_kg_degrees(session, graph_id: str, elements: list[dict], names: set[str]) -> None:
    """Set ``degree`` on node elements to the node's RELATION count within ``graph_id``.

    The count covers every stored RELATION edge of the node, not only the edges
    present in ``elements``. Edge elements (those with ``source``) are left alone.
    """
    if not names:
        return
    degree_result = session.run(
        """
        MATCH (s:Person {graph_id: $graph_id})
        WHERE s.name IN $names
        OPTIONAL MATCH (s)-[r:RELATION {graph_id: $graph_id}]-()
        WITH s.name AS name, count(r) AS degree
        RETURN name, degree
        """,
        graph_id=graph_id,
        names=list(names),
    )
    degree_map = {row["name"]: row["degree"] for row in degree_result}
    for el in elements:
        data = el["data"]
        if "source" not in data and data["id"] in degree_map:
            data["degree"] = degree_map[data["id"]]


def _kg_hops_bound(hops: int) -> int:
    """Return ``hops`` as an int >= 1 for the variable-length range ``*1..N``.

    The bound is spliced into the query text (variable-length bounds must be
    literals in Cypher). Forcing it through ``int()`` keeps arbitrary strings out
    of the query, and the lower bound keeps the range valid.
    """
    return max(1, int(hops))


def get_knowledge_graph_data(graph_id: str, limit: int = 200) -> list[dict]:
    """Return display elements for a sample of the graph stored under ``graph_id``.

    At most ``min(limit * 3, 1000)`` RELATION rows are read. Nodes and
    ``(source, target)`` pairs are de-duplicated, so the element count is not
    exactly ``limit``.

    Returns:
        A flat list of ``{"data": {...}}`` dicts: node elements (id, label,
        name, color, degree) and edge elements (id, source, target, label,
        type, note).
    """
    driver = _get_driver()
    try:
        with driver.session() as session:
            result = session.run(
                """
                MATCH (a:Person {graph_id: $graph_id})-[r:RELATION {graph_id: $graph_id}]->(b:Person)
                RETURN a, r AS rel, b
                LIMIT $limit
                """,
                graph_id=graph_id,
                # Clamp at 0: a negative LIMIT is rejected by Cypher.
                limit=max(0, min(limit * 3, 1000)),
            )
            elements, names = _collect_kg_elements(result, _knowledge_graph_node_color())
            _annotate_kg_degrees(session, graph_id, elements, names)
    finally:
        driver.close()
    return elements


def search_knowledge_graph(graph_id: str, keyword: str, hops: int = 1) -> dict[str, Any]:
    """Find nodes whose name contains ``keyword`` and return their neighbourhood.

    Matching is a case-insensitive substring test on ``name``, using at most 20
    seed nodes. Paths of 1..``hops`` RELATION edges are expanded from the seeds,
    and at most 500 distinct relationships are returned.

    Returns:
        ``{"elements": [...], "seeds": [...]}``. When nothing matches, returns
        ``{"elements": [], "seeds": [], "message": "未找到匹配实体"}``.

    Raises:
        ValueError/TypeError: if ``hops`` cannot be converted to ``int``.
    """
    max_hops = _kg_hops_bound(hops)
    driver = _get_driver()
    try:
        with driver.session() as session:
            seed_result = session.run(
                """
                MATCH (n:Person {graph_id: $graph_id})
                WHERE toLower(n.name) CONTAINS toLower($keyword)
                RETURN n.name AS name
                LIMIT 20
                """,
                graph_id=graph_id,
                keyword=keyword.strip(),
            )
            seed_names = [r["name"] for r in seed_result if r["name"]]
            if not seed_names:
                return {"elements": [], "seeds": [], "message": "未找到匹配实体"}

            # DISTINCT before LIMIT: overlapping paths repeat the same edges, and
            # without it the limit would be spent on duplicates.
            result = session.run(
                f"""
                MATCH path = (start:Person {{graph_id: $graph_id}})-[:RELATION*1..{max_hops}]-(end:Person {{graph_id: $graph_id}})
                WHERE start.name IN $seeds
                UNWIND relationships(path) AS rel
                WITH DISTINCT rel
                WITH startNode(rel) AS a, endNode(rel) AS b, rel
                WHERE a.graph_id = $graph_id AND b.graph_id = $graph_id
                RETURN a, rel, b
                LIMIT 500
                """,
                graph_id=graph_id,
                seeds=seed_names,
            )
            elements, names = _collect_kg_elements(result, _knowledge_graph_node_color())
            _annotate_kg_degrees(session, graph_id, elements, names)
    finally:
        driver.close()
    return {"elements": elements, "seeds": seed_names}


def expand_knowledge_graph_node(graph_id: str, node_name: str, hops: int = 1) -> list[dict]:
    """Return the 1..``hops``-edge neighbourhood of the node named ``node_name``.

    At most 300 distinct relationships are returned, in the same element format
    as ``get_knowledge_graph_data``.

    Raises:
        ValueError/TypeError: if ``hops`` cannot be converted to ``int``.
    """
    max_hops = _kg_hops_bound(hops)
    driver = _get_driver()
    try:
        with driver.session() as session:
            # DISTINCT before LIMIT so repeated edges from overlapping paths
            # do not consume the limit.
            result = session.run(
                f"""
                MATCH path = (start:Person {{name: $node, graph_id: $graph_id}})-[:RELATION*1..{max_hops}]-(end:Person {{graph_id: $graph_id}})
                UNWIND relationships(path) AS rel
                WITH DISTINCT rel
                RETURN startNode(rel) AS a, rel, endNode(rel) AS b
                LIMIT 300
                """,
                node=node_name.strip(),
                graph_id=graph_id,
            )
            elements, names = _collect_kg_elements(result, _knowledge_graph_node_color())
            _annotate_kg_degrees(session, graph_id, elements, names)
    finally:
        driver.close()
    return elements
|