# File listing metadata: 350 lines, 13 KiB, Python
"""
Knowledge graph API: upload source documents -> asynchronously extract entities and relations -> Neo4j + vector retrieval
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
import asyncpg
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Query, UploadFile
|
||
|
||
from core.config import settings
|
||
from core.database import get_db_pool
|
||
from core.dependencies import get_current_user, get_db
|
||
from core.graph_metadata import graph_table_sql
|
||
from core.permissions import can_manage_graph, can_view_graph
|
||
from models.graph_metadata import GraphRecord
|
||
from models.user import User
|
||
from services.knowledge_graph_service import KnowledgeGraphService
|
||
from services import neo4j_service
|
||
from services.novel_kg_service import (
|
||
extract_and_import_knowledge_graph,
|
||
extract_knowledge_document_text,
|
||
)
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Router that carries every endpoint in this module under one URL prefix.
knowledge_graph_router = APIRouter(
    prefix="/api/knowledge-graph",
    tags=["知识图谱"],
)

# Largest accepted upload, in bytes (15 MiB).
MAX_UPLOAD_BYTES = 15 * 1024 * 1024
|
||
|
||
|
||
async def _knowledge_graph_build_task(record_id: int, neo4j_gid: str, text: str) -> None:
    """Build one knowledge graph in the background and record the outcome.

    Steps, in order:

    1. Mark the metadata row as ``processing``.
    2. Run ``extract_and_import_knowledge_graph`` on the text.
    3. Index the text through the vector service. This step is best effort:
       a failure is logged and the build still completes with zero chunks.
    4. Mark the row ``completed`` and store node / edge / chunk counts.

    If any step other than (3) raises, the partially written Neo4j and vector
    data is removed (best effort) and the row is marked ``failed`` with the
    error text truncated to 4000 characters.

    Args:
        record_id: Primary key of the graph metadata row.
        neo4j_gid: Graph identifier passed to the Neo4j import / delete calls.
        text: Document text the graph is built from.
    """
    pool = await get_db_pool()
    try:
        async with pool.acquire() as conn:
            t = graph_table_sql()
            await conn.execute(
                f"""
                UPDATE {t}
                SET build_status = 'processing', build_error = NULL, updated_at = CURRENT_TIMESTAMP
                WHERE id = $1
                """,
                record_id,
            )

        # The connection is released before extraction so that a pooled
        # connection is not held for the whole duration of the build.
        stats = await extract_and_import_knowledge_graph(text, neo4j_gid)

        rag_chunks = 0
        try:
            from services.vector_service import get_vector_service

            def _index():
                vs = get_vector_service()
                return vs.index_knowledge_graph_text(record_id, text)

            # Indexing is synchronous; run it in a worker thread so the event
            # loop stays responsive.
            rag_chunks = await asyncio.to_thread(_index)
        except Exception as rag_err:
            logger.warning("知识图谱向量化失败(仍可查看关系图): {}", rag_err)

        async with pool.acquire() as conn:
            t = graph_table_sql()
            await conn.execute(
                f"""
                UPDATE {t}
                SET build_status = 'completed',
                    node_count = $2,
                    edge_count = $3,
                    rag_chunk_count = $4,
                    build_error = NULL,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = $1
                """,
                record_id,
                stats["node_count"],
                stats["edge_count"],
                rag_chunks,
            )
        logger.info(
            "知识图谱构建完成 id={} neo4j={} rag_chunks={}",
            record_id,
            neo4j_gid,
            rag_chunks,
        )
    except Exception as build_err:
        logger.exception("知识图谱构建失败 id={}", record_id)
        await _cleanup_failed_build(record_id, neo4j_gid)
        try:
            async with pool.acquire() as conn:
                t = graph_table_sql()
                await conn.execute(
                    f"""
                    UPDATE {t}
                    SET build_status = 'failed',
                        build_error = $2,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE id = $1
                    """,
                    record_id,
                    str(build_err)[:4000],
                )
        except Exception:
            logger.exception("写入构建失败状态时出错")


async def _cleanup_failed_build(record_id: int, neo4j_gid: str) -> None:
    """Best-effort removal of partial Neo4j and vector data after a failed build.

    Each step is attempted independently. A failure is logged as a warning
    instead of being silently discarded, and never propagates to the caller.
    """
    try:
        await asyncio.to_thread(neo4j_service.delete_knowledge_graph, neo4j_gid)
    except Exception as cleanup_err:
        logger.warning("构建失败后清理 Neo4j 数据出错 id={}: {}", record_id, cleanup_err)

    try:
        from services.vector_service import get_vector_service

        def _drop_collection():
            return get_vector_service().delete_knowledge_graph_collection(record_id)

        await asyncio.to_thread(_drop_collection)
    except Exception as cleanup_err:
        logger.warning("构建失败后清理向量库出错 id={}: {}", record_id, cleanup_err)
|
||
|
||
|
||
@knowledge_graph_router.post("", response_model=BaseResponse, summary="上传资料文件并创建知识图谱")
async def create_knowledge_graph(
    background_tasks: BackgroundTasks,
    name: str = Query(..., description="图谱名称"),
    description: Optional[str] = Query(None, description="图谱描述"),
    visibility: str = Query("private", description="private | department | enterprise"),
    file: UploadFile = File(
        ...,
        description="支持 .txt / .pdf / .docx / 图片;扫描件可走 OCR 与通义视觉提取",
    ),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Accept an uploaded document and schedule a knowledge-graph build.

    The handler validates the request, extracts text from the upload, inserts
    a metadata row with ``build_status = 'pending'`` and queues
    ``_knowledge_graph_build_task`` as a background task. The response is
    returned before the build runs.

    Args:
        background_tasks: FastAPI background task queue.
        name: Graph name; must not be blank after stripping whitespace.
        description: Optional graph description.
        visibility: Visibility value, validated by ``KnowledgeGraphService``.
        file: Uploaded document, at most ``MAX_UPLOAD_BYTES`` bytes.
        current_user: Authenticated user creating the graph.
        conn: Database connection for the metadata insert.

    Returns:
        ``BaseResponse`` whose ``data`` is the enriched metadata row.

    Raises:
        HTTPException: 503 when no DeepSeek API key is configured; 400 for a
            missing file name, blank graph name, invalid visibility, a user
            without an enterprise, an oversized file, or a text-extraction
            error; 500 when the metadata insert fails.
    """
    if not settings.deepseek_api_key:
        raise HTTPException(status_code=503, detail="服务端未配置 DEEPSEEK_API_KEY,无法抽取实体关系")

    if not file.filename:
        raise HTTPException(status_code=400, detail="请上传文件")

    # Cheap request validation runs before the upload is read or parsed, so an
    # invalid request never pays for text extraction.
    graph_name = name.strip()
    if not graph_name:
        raise HTTPException(status_code=400, detail="图谱名称不能为空")

    try:
        vis = KnowledgeGraphService._validate_visibility(visibility)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    enterprise_id = current_user.enterprise_id
    if enterprise_id is None:
        raise HTTPException(status_code=400, detail="用户未关联企业,无法创建知识图谱")

    # Read one byte past the limit: enough to detect an oversized upload
    # without loading an arbitrarily large file into memory.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="文件过大,请控制在 15MB 以内")

    try:
        text = await extract_knowledge_document_text(file.filename, raw)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    graph_id = str(uuid.uuid4())
    # The stored file name is capped at 255 characters.
    safe_name = file.filename[:255]

    try:
        t = graph_table_sql()
        row = await conn.fetchrow(
            f"""
            INSERT INTO {t} (
                user_id, enterprise_id, department_id, creator_id, visibility,
                name, description, csv_file_name,
                node_count, edge_count, neo4j_graph_id,
                build_status
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 0, 0, $9, 'pending')
            RETURNING *
            """,
            current_user.id,
            enterprise_id,
            current_user.department_id,
            current_user.id,
            vis,
            graph_name,
            description,
            safe_name,
            graph_id,
        )
    except Exception as e:
        logger.exception("保存知识图谱元数据失败")
        raise HTTPException(status_code=500, detail=f"创建图谱记录失败:{e}") from e

    record_id = row["id"]
    background_tasks.add_task(_knowledge_graph_build_task, record_id, graph_id, text)

    enriched = await KnowledgeGraphService.enrich_graph_for_response(conn, dict(row), current_user)
    return BaseResponse(
        code=200,
        msg="已接收文本,正在后台抽取关系并写入图谱",
        data=enriched,
    )
|
||
|
||
|
||
@knowledge_graph_router.get("", response_model=BaseResponse, summary="获取知识图谱列表")
async def list_knowledge_graphs(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return one page of the knowledge graphs visible to the current user.

    Args:
        page: 1-based page number.
        size: Page size, between 1 and 100.
        current_user: Authenticated user whose visible graphs are listed.
        conn: Database connection handed to the service layer.

    Returns:
        ``BaseResponse`` whose ``data`` holds ``items``, ``total``, ``page``
        and ``size``.
    """
    visible_items, total_count = await KnowledgeGraphService.list_visible_graphs(
        conn, current_user, page, size
    )
    payload = dict(items=visible_items, total=total_count, page=page, size=size)
    return BaseResponse(code=200, msg="success", data=payload)
|
||
|
||
|
||
@knowledge_graph_router.get("/{graph_pk}/info", response_model=BaseResponse, summary="获取知识图谱详情")
async def get_knowledge_graph_info(
    graph_pk: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return the details of one knowledge graph as seen by the current user.

    Raises:
        HTTPException: 404 when the service returns nothing for this graph
            and viewer.
    """
    graph_info = await KnowledgeGraphService.get_graph_for_viewer(conn, graph_pk, current_user)
    if graph_info:
        return BaseResponse(code=200, msg="success", data=graph_info)
    raise HTTPException(status_code=404, detail="图谱不存在或无权访问")
|
||
|
||
|
||
@knowledge_graph_router.delete("/{graph_pk}", response_model=BaseResponse, summary="删除知识图谱")
async def delete_knowledge_graph_ep(
    graph_pk: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Delete a knowledge graph: Neo4j data, vector collection, then metadata.

    The Neo4j and vector deletions are best effort. A failure in either one
    is logged and the metadata row is still removed.

    Raises:
        HTTPException: 404 when the graph does not exist; 403 when the
            current user may not manage it.
    """
    raw = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_pk)
    if not raw:
        raise HTTPException(status_code=404, detail="图谱不存在或无权访问")

    gr = GraphRecord(
        id=int(raw["id"]),
        user_id=int(raw["user_id"]),
        enterprise_id=raw.get("enterprise_id"),
        department_id=raw.get("department_id"),
        creator_id=raw.get("creator_id"),
        visibility=raw.get("visibility") or "private",
    )
    # NOTE(review): can_manage_graph is called without await, while
    # can_view_graph is awaited elsewhere in this file. Confirm it is a plain
    # function; an un-awaited coroutine would always be truthy here.
    if not can_manage_graph(current_user, gr):
        raise HTTPException(status_code=403, detail="无权删除该知识图谱")

    neo4j_gid = raw["neo4j_graph_id"]

    # Both external deletions are synchronous; run them in worker threads so
    # the event loop is not blocked while they complete.
    try:
        await asyncio.to_thread(neo4j_service.delete_knowledge_graph, neo4j_gid)
    except Exception as e:
        logger.warning("删除 Neo4j 知识图谱数据失败(继续删元数据): {}", e)

    try:
        from services.vector_service import get_vector_service

        def _drop_collection():
            return get_vector_service().delete_knowledge_graph_collection(graph_pk)

        await asyncio.to_thread(_drop_collection)
    except Exception as e:
        logger.warning("删除知识图谱向量库失败: {}", e)

    t = graph_table_sql()
    await conn.execute(
        f"DELETE FROM {t} WHERE id = $1",
        graph_pk,
    )
    return BaseResponse(code=200, msg="图谱已删除")
|
||
|
||
|
||
async def _fetch_graph_or_404(conn: asyncpg.Connection, graph_pk: int, user: User):
    """Load a graph row and ensure ``user`` may view it.

    Args:
        conn: Database connection used for the lookup and the permission check.
        graph_pk: Primary key of the graph metadata row.
        user: User requesting access.

    Returns:
        The row returned by ``KnowledgeGraphService.fetch_graph_by_id``.

    Raises:
        HTTPException: 404 both when the row is missing and when
            ``can_view_graph`` denies access; the two cases share one message.
    """
    not_found = HTTPException(status_code=404, detail="图谱不存在或无权访问")

    record = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_pk)
    if not record:
        raise not_found

    # Only the fields needed for the permission check are copied over.
    access_fields = {
        "id": int(record["id"]),
        "user_id": int(record["user_id"]),
        "enterprise_id": record.get("enterprise_id"),
        "department_id": record.get("department_id"),
        "creator_id": record.get("creator_id"),
        "visibility": record.get("visibility") or "private",
    }
    graph_record = GraphRecord(**access_fields)

    allowed = await can_view_graph(conn, user, graph_record)
    if allowed:
        return record
    raise not_found
|
||
|
||
|
||
@knowledge_graph_router.get("/{graph_pk}/data", response_model=BaseResponse, summary="获取 Cytoscape 图数据")
async def get_knowledge_graph_data_ep(
    graph_pk: int,
    limit: int = Query(200, ge=10, le=1000),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Return the graph elements of a completed knowledge graph.

    Args:
        graph_pk: Primary key of the graph metadata row.
        limit: Value forwarded to the Neo4j query as ``limit`` (10 to 1000).
        current_user: Authenticated user; must be allowed to view the graph.
        conn: Database connection for the metadata lookup.

    Returns:
        ``BaseResponse`` whose ``data`` is ``{"elements": ...}``.

    Raises:
        HTTPException: 404 when the graph is missing or not viewable; 409
            when its build status is not ``completed``; 500 when the Neo4j
            query fails.
    """
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成,请稍后再试")

    try:
        # The Neo4j call is synchronous; run it in a worker thread so other
        # requests are not stalled while the query executes.
        elements = await asyncio.to_thread(
            neo4j_service.get_knowledge_graph_data, neo4j_gid, limit=limit
        )
    except Exception as e:
        logger.exception("查询知识图谱数据失败")
        raise HTTPException(status_code=500, detail=f"查询失败:{e}") from e

    return BaseResponse(code=200, msg="success", data={"elements": elements})
|
||
|
||
|
||
@knowledge_graph_router.get("/{graph_pk}/search", response_model=BaseResponse, summary="按实体名搜索子图")
async def search_knowledge_graph_ep(
    graph_pk: int,
    q: str = Query(..., description="实体名称关键词"),
    hops: int = Query(1, ge=1, le=3),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Search a completed knowledge graph by entity-name keyword.

    Args:
        graph_pk: Primary key of the graph metadata row.
        q: Keyword forwarded to the Neo4j search as ``keyword``.
        hops: Value forwarded to the Neo4j search as ``hops`` (1 to 3).
        current_user: Authenticated user; must be allowed to view the graph.
        conn: Database connection for the metadata lookup.

    Returns:
        ``BaseResponse`` whose ``data`` is the search result as returned by
        ``neo4j_service.search_knowledge_graph``.

    Raises:
        HTTPException: 404 when the graph is missing or not viewable; 409
            when its build status is not ``completed``; 500 when the search
            fails.
    """
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成")

    try:
        # The Neo4j call is synchronous; run it in a worker thread so other
        # requests are not stalled while the search executes.
        result = await asyncio.to_thread(
            neo4j_service.search_knowledge_graph, neo4j_gid, keyword=q, hops=hops
        )
    except Exception as e:
        logger.exception("搜索知识图谱失败")
        raise HTTPException(status_code=500, detail=f"搜索失败:{e}") from e

    return BaseResponse(code=200, msg="success", data=result)
|
||
|
||
|
||
@knowledge_graph_router.get("/{graph_pk}/expand", response_model=BaseResponse, summary="展开节点邻居")
async def expand_knowledge_graph_node_ep(
    graph_pk: int,
    node: str = Query(..., description="实体名称"),
    hops: int = Query(1, ge=1, le=3),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    """Expand the neighbourhood of one node in a completed knowledge graph.

    Args:
        graph_pk: Primary key of the graph metadata row.
        node: Entity name forwarded to Neo4j as ``node_name``.
        hops: Value forwarded to Neo4j as ``hops`` (1 to 3).
        current_user: Authenticated user; must be allowed to view the graph.
        conn: Database connection for the metadata lookup.

    Returns:
        ``BaseResponse`` whose ``data`` is ``{"elements": ...}``.

    Raises:
        HTTPException: 404 when the graph is missing or not viewable; 409
            when its build status is not ``completed``; 500 when the
            expansion query fails.
    """
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成")

    try:
        # The Neo4j call is synchronous; run it in a worker thread so other
        # requests are not stalled while the query executes.
        elements = await asyncio.to_thread(
            neo4j_service.expand_knowledge_graph_node,
            neo4j_gid,
            node_name=node,
            hops=hops,
        )
    except Exception as e:
        logger.exception("展开节点失败")
        raise HTTPException(status_code=500, detail=f"展开失败:{e}") from e

    return BaseResponse(code=200, msg="success", data={"elements": elements})
|