huoyan-enterprise/backend/api/knowledge_graph_router.py

350 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识图谱 API上传资料文本 → 异步抽取实体关系 → Neo4j + 向量检索
"""
from __future__ import annotations
import asyncio
import uuid
from typing import Optional
import asyncpg
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Query, UploadFile
from core.config import settings
from core.database import get_db_pool
from core.dependencies import get_current_user, get_db
from core.graph_metadata import graph_table_sql
from core.permissions import can_manage_graph, can_view_graph
from models.graph_metadata import GraphRecord
from models.user import User
from services.knowledge_graph_service import KnowledgeGraphService
from services import neo4j_service
from services.novel_kg_service import (
extract_and_import_knowledge_graph,
extract_knowledge_document_text,
)
from utils.helpers import BaseResponse
from logger.logging import get_logger
# Module-level logger named after this module.
logger = get_logger(__name__)
# Router holding every endpoint in this file; all paths are served under /api/knowledge-graph.
knowledge_graph_router = APIRouter(prefix="/api/knowledge-graph", tags=["知识图谱"])
# Upload size cap checked by create_knowledge_graph: 15 MiB.
MAX_UPLOAD_BYTES = 15 * 1024 * 1024
async def _knowledge_graph_build_task(record_id: int, neo4j_gid: str, text: str) -> None:
    """Build one knowledge graph in the background and record the outcome.

    The metadata row is moved to ``processing`` and then to ``completed``
    (with node, edge and RAG chunk counts). Vector indexing is best-effort: if
    it fails, the graph is still marked completed with ``rag_chunk_count = 0``.
    If any other step fails, the Neo4j data and the vector collection for this
    graph are removed on a best-effort basis and the row is marked ``failed``
    with the error text truncated to 4000 characters.

    Args:
        record_id: Primary key of the graph metadata row to update.
        neo4j_gid: Graph identifier passed to the Neo4j import and cleanup calls.
        text: Document text from which entities and relations are extracted.
    """
    pool = await get_db_pool()
    try:
        async with pool.acquire() as conn:
            t = graph_table_sql()
            await conn.execute(
                f"""
                UPDATE {t}
                SET build_status = 'processing', build_error = NULL, updated_at = CURRENT_TIMESTAMP
                WHERE id = $1
                """,
                record_id,
            )

        # The connection is released before extraction so that a long-running
        # build does not hold a pool slot.
        stats = await extract_and_import_knowledge_graph(text, neo4j_gid)

        rag_chunks = 0
        try:
            # Imported lazily inside the try so that a missing or broken vector
            # backend only disables RAG indexing instead of failing the build.
            from services.vector_service import get_vector_service

            def _index():
                vs = get_vector_service()
                return vs.index_knowledge_graph_text(record_id, text)

            # Indexing is a synchronous call; run it off the event loop.
            rag_chunks = await asyncio.to_thread(_index)
        except Exception as rag_err:
            logger.warning("知识图谱向量化失败(仍可查看关系图): {}", rag_err)

        async with pool.acquire() as conn:
            t = graph_table_sql()
            await conn.execute(
                f"""
                UPDATE {t}
                SET build_status = 'completed',
                    node_count = $2,
                    edge_count = $3,
                    rag_chunk_count = $4,
                    build_error = NULL,
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = $1
                """,
                record_id,
                stats["node_count"],
                stats["edge_count"],
                rag_chunks,
            )
        logger.info(
            "知识图谱构建完成 id={} neo4j={} rag_chunks={}",
            record_id,
            neo4j_gid,
            rag_chunks,
        )
    except Exception as e:
        logger.exception("知识图谱构建失败 id={}", record_id)

        # Best-effort rollback of whatever was written. Each step is isolated
        # so one failing cleanup does not prevent the next, and failures are
        # logged rather than silently dropped. The cleanup calls are
        # synchronous, so they run in a worker thread.
        try:
            await asyncio.to_thread(neo4j_service.delete_knowledge_graph, neo4j_gid)
        except Exception as cleanup_err:
            logger.warning("清理 Neo4j 知识图谱数据失败 id={}: {}", record_id, cleanup_err)

        try:
            from services.vector_service import get_vector_service

            def _drop_vectors() -> None:
                get_vector_service().delete_knowledge_graph_collection(record_id)

            await asyncio.to_thread(_drop_vectors)
        except Exception as cleanup_err:
            logger.warning("清理知识图谱向量库失败 id={}: {}", record_id, cleanup_err)

        try:
            async with pool.acquire() as conn:
                t = graph_table_sql()
                await conn.execute(
                    f"""
                    UPDATE {t}
                    SET build_status = 'failed',
                        build_error = $2,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE id = $1
                    """,
                    record_id,
                    # Some exceptions stringify to ""; fall back to repr so the
                    # stored error is never blank.
                    (str(e) or repr(e))[:4000],
                )
        except Exception:
            logger.exception("写入构建失败状态时出错")
@knowledge_graph_router.post("", response_model=BaseResponse, summary="上传资料文件并创建知识图谱")
async def create_knowledge_graph(
    background_tasks: BackgroundTasks,
    name: str = Query(..., description="图谱名称"),
    description: Optional[str] = Query(None, description="图谱描述"),
    visibility: str = Query("private", description="private | department | enterprise"),
    file: UploadFile = File(
        ...,
        description="支持 .txt / .pdf / .docx / 图片;扫描件可走 OCR 与通义视觉提取",
    ),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Create a knowledge graph from an uploaded document.
    #
    # The request is validated, text is extracted from the upload, a metadata
    # row is inserted with build_status 'pending', and the actual graph build
    # is scheduled as a background task. The response carries the enriched
    # metadata row.
    #
    # Raises HTTPException:
    #   503 - settings.deepseek_api_key is not configured.
    #   400 - missing file, blank name, invalid visibility, user without an
    #         enterprise, oversized upload, or text extraction ValueError.
    #   500 - the metadata insert failed.
    if not settings.deepseek_api_key:
        raise HTTPException(status_code=503, detail="服务端未配置 DEEPSEEK_API_KEY无法抽取实体关系")
    if not file.filename:
        raise HTTPException(status_code=400, detail="请上传文件")

    # Cheap request validation runs before the upload is read and before text
    # extraction, so invalid requests fail fast.
    graph_name = name.strip()
    if not graph_name:
        raise HTTPException(status_code=400, detail="图谱名称不能为空")
    try:
        vis = KnowledgeGraphService._validate_visibility(visibility)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    enterprise_id = current_user.enterprise_id
    if enterprise_id is None:
        raise HTTPException(status_code=400, detail="用户未关联企业,无法创建知识图谱")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading the whole file into memory.
    raw = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="文件过大,请控制在 15MB 以内")
    try:
        text = await extract_knowledge_document_text(file.filename, raw)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    graph_id = str(uuid.uuid4())
    # Only the first 255 characters of the filename are stored.
    safe_name = file.filename[:255]

    try:
        t = graph_table_sql()
        row = await conn.fetchrow(
            f"""
            INSERT INTO {t} (
                user_id, enterprise_id, department_id, creator_id, visibility,
                name, description, csv_file_name,
                node_count, edge_count, neo4j_graph_id,
                build_status
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 0, 0, $9, 'pending')
            RETURNING *
            """,
            current_user.id,
            enterprise_id,
            current_user.department_id,
            current_user.id,
            vis,
            graph_name,
            description,
            safe_name,
            graph_id,
        )
    except Exception as e:
        logger.exception("保存知识图谱元数据失败")
        # NOTE(review): the raw exception text is returned to the client; kept
        # for backward compatibility, but it may expose database details.
        raise HTTPException(status_code=500, detail=f"创建图谱记录失败:{e}") from e

    record_id = row["id"]
    background_tasks.add_task(_knowledge_graph_build_task, record_id, graph_id, text)
    enriched = await KnowledgeGraphService.enrich_graph_for_response(conn, dict(row), current_user)
    return BaseResponse(
        code=200,
        msg="已接收文本,正在后台抽取关系并写入图谱",
        data=enriched,
    )
@knowledge_graph_router.get("", response_model=BaseResponse, summary="获取知识图谱列表")
async def list_knowledge_graphs(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Return one page of the knowledge graphs visible to the current user.
    # `page` starts at 1 and `size` is limited to 1..100 by the Query
    # constraints; both are echoed back next to the items and total count.
    visible_graphs, total_count = await KnowledgeGraphService.list_visible_graphs(
        conn, current_user, page, size
    )
    payload = {
        "items": visible_graphs,
        "total": total_count,
        "page": page,
        "size": size,
    }
    return BaseResponse(code=200, msg="success", data=payload)
@knowledge_graph_router.get("/{graph_pk}/info", response_model=BaseResponse, summary="获取知识图谱详情")
async def get_knowledge_graph_info(
    graph_pk: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Return the details of one knowledge graph as seen by the current user.
    # Responds 404 when the service returns nothing for this user and id.
    graph_info = await KnowledgeGraphService.get_graph_for_viewer(conn, graph_pk, current_user)
    if graph_info:
        return BaseResponse(code=200, msg="success", data=graph_info)
    raise HTTPException(status_code=404, detail="图谱不存在或无权访问")
@knowledge_graph_router.delete("/{graph_pk}", response_model=BaseResponse, summary="删除知识图谱")
async def delete_knowledge_graph_ep(
    graph_pk: int,
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Delete a knowledge graph: its Neo4j data, its vector collection, and its
    # metadata row.
    #
    # Neo4j and vector-store deletion are best-effort: a failure is logged and
    # the metadata row is still deleted.
    #
    # Raises HTTPException:
    #   404 - no metadata row exists for graph_pk.
    #   403 - can_manage_graph rejects the current user.
    t = graph_table_sql()
    raw = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_pk)
    if not raw:
        raise HTTPException(status_code=404, detail="图谱不存在或无权访问")
    gr = GraphRecord(
        id=int(raw["id"]),
        user_id=int(raw["user_id"]),
        enterprise_id=raw.get("enterprise_id"),
        department_id=raw.get("department_id"),
        creator_id=raw.get("creator_id"),
        visibility=raw.get("visibility") or "private",
    )
    if not can_manage_graph(current_user, gr):
        raise HTTPException(status_code=403, detail="无权删除该知识图谱")

    # Looked up outside the try so a missing column still surfaces as an error
    # instead of being logged as a Neo4j failure.
    neo4j_gid = raw["neo4j_graph_id"]

    # The storage clients are synchronous; run them in a worker thread so the
    # event loop keeps serving other requests during the deletion.
    try:
        await asyncio.to_thread(neo4j_service.delete_knowledge_graph, neo4j_gid)
    except Exception as e:
        logger.warning("删除 Neo4j 知识图谱数据失败(继续删元数据): {}", e)
    try:
        from services.vector_service import get_vector_service

        def _drop_vectors() -> None:
            get_vector_service().delete_knowledge_graph_collection(graph_pk)

        await asyncio.to_thread(_drop_vectors)
    except Exception as e:
        logger.warning("删除知识图谱向量库失败: {}", e)

    await conn.execute(
        f"DELETE FROM {t} WHERE id = $1",
        graph_pk,
    )
    return BaseResponse(code=200, msg="图谱已删除")
async def _fetch_graph_or_404(conn: asyncpg.Connection, graph_pk: int, user: User):
    """Load a graph metadata row, or raise 404 if it is missing or not viewable.

    The same 404 response is used for "no such row" and for "``can_view_graph``
    denied this user", so the two cases cannot be told apart by the caller.

    Args:
        conn: Database connection used for the lookup and the permission check.
        graph_pk: Primary key of the graph metadata row.
        user: User whose view permission is checked.

    Returns:
        The row returned by ``KnowledgeGraphService.fetch_graph_by_id``.

    Raises:
        HTTPException: 404 when the row is absent or the user may not view it.
    """
    record = await KnowledgeGraphService.fetch_graph_by_id(conn, graph_pk)
    if record:
        # Permission checks work on a GraphRecord built from the row's
        # ownership columns; a missing visibility falls back to "private".
        graph = GraphRecord(
            id=int(record["id"]),
            user_id=int(record["user_id"]),
            enterprise_id=record.get("enterprise_id"),
            department_id=record.get("department_id"),
            creator_id=record.get("creator_id"),
            visibility=record.get("visibility") or "private",
        )
        viewable = await can_view_graph(conn, user, graph)
        if viewable:
            return record
    raise HTTPException(status_code=404, detail="图谱不存在或无权访问")
@knowledge_graph_router.get("/{graph_pk}/data", response_model=BaseResponse, summary="获取 Cytoscape 图数据")
async def get_knowledge_graph_data_ep(
    graph_pk: int,
    limit: int = Query(200, ge=10, le=1000),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Return graph elements for one knowledge graph; `limit` (10..1000) is
    # forwarded to the Neo4j query.
    #
    # Raises HTTPException:
    #   404 - graph missing or not viewable (from _fetch_graph_or_404).
    #   409 - build_status is not 'completed'.
    #   500 - the Neo4j query failed.
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成,请稍后再试")
    try:
        # The Neo4j service call is synchronous; run it in a worker thread so
        # the event loop is not blocked for the duration of the query.
        elements = await asyncio.to_thread(
            neo4j_service.get_knowledge_graph_data, neo4j_gid, limit=limit
        )
    except Exception as e:
        logger.exception("查询知识图谱数据失败")
        raise HTTPException(status_code=500, detail=f"查询失败:{e}") from e
    return BaseResponse(code=200, msg="success", data={"elements": elements})
@knowledge_graph_router.get("/{graph_pk}/search", response_model=BaseResponse, summary="按实体名搜索子图")
async def search_knowledge_graph_ep(
    graph_pk: int,
    q: str = Query(..., description="实体名称关键词"),
    hops: int = Query(1, ge=1, le=3),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Search one knowledge graph by keyword `q`; `hops` (1..3) is forwarded to
    # the Neo4j search. The search result is returned unchanged as `data`.
    #
    # Raises HTTPException:
    #   404 - graph missing or not viewable (from _fetch_graph_or_404).
    #   409 - build_status is not 'completed'.
    #   500 - the Neo4j search failed.
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成")
    try:
        # The Neo4j service call is synchronous; run it in a worker thread so
        # the event loop is not blocked for the duration of the query.
        result = await asyncio.to_thread(
            neo4j_service.search_knowledge_graph, neo4j_gid, keyword=q, hops=hops
        )
    except Exception as e:
        logger.exception("搜索知识图谱失败")
        raise HTTPException(status_code=500, detail=f"搜索失败:{e}") from e
    return BaseResponse(code=200, msg="success", data=result)
@knowledge_graph_router.get("/{graph_pk}/expand", response_model=BaseResponse, summary="展开节点邻居")
async def expand_knowledge_graph_node_ep(
    graph_pk: int,
    node: str = Query(..., description="实体名称"),
    hops: int = Query(1, ge=1, le=3),
    current_user: User = Depends(get_current_user),
    conn: asyncpg.Connection = Depends(get_db),
):
    # Expand the node named `node` in one knowledge graph; `hops` (1..3) is
    # forwarded to the Neo4j expansion. The result is returned as `elements`.
    #
    # Raises HTTPException:
    #   404 - graph missing or not viewable (from _fetch_graph_or_404).
    #   409 - build_status is not 'completed'.
    #   500 - the Neo4j expansion failed.
    raw = await _fetch_graph_or_404(conn, graph_pk, current_user)
    neo4j_gid = raw["neo4j_graph_id"]
    if raw.get("build_status") != "completed":
        raise HTTPException(status_code=409, detail="图谱尚未构建完成")
    try:
        # The Neo4j service call is synchronous; run it in a worker thread so
        # the event loop is not blocked for the duration of the query.
        elements = await asyncio.to_thread(
            neo4j_service.expand_knowledge_graph_node,
            neo4j_gid,
            node_name=node,
            hops=hops,
        )
    except Exception as e:
        logger.exception("展开节点失败")
        raise HTTPException(status_code=500, detail=f"展开失败:{e}") from e
    return BaseResponse(code=200, msg="success", data={"elements": elements})