702 lines
25 KiB
Python
702 lines
25 KiB
Python
"""
|
||
资料文本 → 分块 → LLM 实体关系抽取 → Neo4j 三元组导入
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import io
|
||
import json
|
||
import os
|
||
import re
|
||
import tempfile
|
||
import zipfile
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from langchain_core.messages import HumanMessage, SystemMessage
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||
|
||
from core.config import settings
|
||
from core.llm_catalog import build_chat_model
|
||
from services import neo4j_service
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger from the project's logging helper.
# NOTE(review): log calls in this module use "{}" placeholders, which suggests a
# loguru-style logger rather than stdlib logging — confirm before changing them.
logger = get_logger(__name__)
|
||
|
||
# Knowledge-graph extraction limits. Text is split into chunks of about
# CHUNK_SIZE characters with CHUNK_OVERLAP characters of overlap, and each chunk
# costs one DeepSeek call, made serially. At the input cap, full-size chunks
# advance by CHUNK_SIZE - CHUNK_OVERLAP (780) characters, i.e. roughly 100+
# calls, about 5-10 minutes of background work; 500k characters would need 600+
# calls and is not practical.
CHUNK_SIZE = 900
CHUNK_OVERLAP = 120
MAX_INPUT_CHARS = 80_000
# Derived rather than hard-coded: a fixed cap of 100 was below the ~103 chunks a
# MAX_INPUT_CHARS input produces, so text that passed the length check could
# still be rejected by the chunk check. The splitter often cuts before
# CHUNK_SIZE at separators, hence 25% headroom (evaluates to 128).
MAX_KG_EXTRACT_CHUNKS = (MAX_INPUT_CHARS // (CHUNK_SIZE - CHUNK_OVERLAP) + 1) * 5 // 4

# Minimum length (after stripping surrounding whitespace) for extracted text
# to be treated as usable.
MIN_MEANINGFUL_TEXT_LEN = 30
# Upper bound on PDF pages rendered and sent to the vision model.
MAX_PDF_VISION_PAGES = 50

IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".webp", ".gif"})
# Every upload type accepted for KG extraction; built from IMAGE_EXTENSIONS so
# the two sets cannot drift apart.
NOVEL_ALLOWED_EXTENSIONS = frozenset({".txt", ".pdf", ".docx"}) | IMAGE_EXTENSIONS

# Vision-model prompt for standalone images: describe the scene and transcribe
# all visible text, in Chinese, for later entity/relation extraction.
KG_VISION_PROMPT_IMAGE = (
    "详细描述图片中的内容:场景、人物、物体、图表及所有可见文字(逐字提取)。"
    "用通顺中文输出,便于后续做实体与关系抽取。"
)
# Vision-model prompt for one rendered page of a scanned PDF: reproduce the
# page text and briefly describe the layout, in Chinese.
KG_VISION_PROMPT_PAGE = (
    "这是纸质文档的一页扫描图。请尽量还原页内全部文字(标题、正文、表格、脚注等),"
    "并简要说明版面结构。用中文输出。"
)
|
||
|
||
|
||
def _collapse_blank_lines(text: str) -> str:
|
||
text = re.sub(r"[ \t\r\f\v]+", " ", text)
|
||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||
return text.strip()
|
||
|
||
|
||
def _text_from_txt(raw: bytes) -> str:
    """Decode an uploaded plain-text file and normalise its whitespace.

    Decoding order:
    * UTF-16 when the data starts with a UTF-16 byte-order mark (e.g. files
      saved as "Unicode" by Windows Notepad); such bytes are not valid UTF-8
      and would otherwise end up as GB18030 replacement garbage;
    * UTF-8, dropping a leading UTF-8 BOM (plain "utf-8" keeps it as U+FEFF,
      which ``str.strip`` does not remove);
    * GB18030 with replacement characters as the last resort.
    """
    if raw.startswith((b"\xff\xfe", b"\xfe\xff")):
        # The "utf-16" codec consumes the BOM and uses it to pick endianness.
        text = raw.decode("utf-16", errors="replace")
    else:
        try:
            text = raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            text = raw.decode("gb18030", errors="replace")
    return _collapse_blank_lines(text)
|
||
|
||
|
||
def _text_from_pdf(raw: bytes) -> str:
    """Extract the text layer of a PDF (no OCR).

    pypdf is tried first; pages whose extraction fails or is blank are skipped.
    If the result is shorter than MIN_MEANINGFUL_TEXT_LEN and PyMuPDF is
    installed, PyMuPDF's text extraction is tried as well, and the longer of
    the two results is returned (a failing or emptier fallback no longer
    replaces what pypdf found).

    Raises:
        ValueError: if pypdf cannot open the document at all.
    """
    from pypdf import PdfReader

    try:
        reader = PdfReader(io.BytesIO(raw))
    except Exception as e:
        raise ValueError(f"无法读取 PDF:{e}") from e

    parts: list[str] = []
    for page in reader.pages:
        try:
            t = page.extract_text()
        except Exception:
            # One malformed page should not abort extraction of the others.
            t = ""
        if t and t.strip():
            parts.append(t)
    text = _collapse_blank_lines("\n".join(parts))

    if len(text) >= MIN_MEANINGFUL_TEXT_LEN:
        return text

    try:
        import fitz  # PyMuPDF (optional dependency)
    except ImportError:
        return text

    try:
        doc = fitz.open(stream=raw, filetype="pdf")
        try:
            alt = [doc.load_page(i).get_text() or "" for i in range(doc.page_count)]
        finally:
            doc.close()
    except Exception as e:
        logger.warning("PyMuPDF 回退提取失败: {}", e)
        return text

    alt_text = _collapse_blank_lines("\n".join(alt))
    return alt_text if len(alt_text) > len(text) else text
|
||
|
||
|
||
def _text_from_docx(raw: bytes) -> str:
    """Extract text from a .docx via python-docx (no OCR of embedded images).

    Non-blank paragraphs come first, followed by the non-blank cells of every
    table (row by row). Each piece is stripped, joined with newlines and passed
    through ``_collapse_blank_lines``.

    Raises:
        ValueError: if python-docx is not installed or the file cannot be opened.
    """
    try:
        from docx import Document
    except ImportError as e:
        raise ValueError("服务端未安装 python-docx,无法解析 Word 文档") from e

    try:
        document = Document(io.BytesIO(raw))
    except Exception as e:
        raise ValueError(f"无法读取 Word 文档(.docx):{e}") from e

    paragraph_lines = [
        para.text.strip()
        for para in document.paragraphs
        if para.text and para.text.strip()
    ]
    cell_lines = [
        cell.text.strip()
        for table in document.tables
        for row in table.rows
        for cell in row.cells
        if cell.text and cell.text.strip()
    ]
    return _collapse_blank_lines("\n".join(paragraph_lines + cell_lines))
|
||
|
||
|
||
def _text_meaningful(text: str) -> bool:
    """Return True if *text* has at least MIN_MEANINGFUL_TEXT_LEN characters
    once surrounding whitespace is stripped; empty or None is never meaningful."""
    if not text:
        return False
    return len(text.strip()) >= MIN_MEANINGFUL_TEXT_LEN
|
||
|
||
|
||
def _guess_extension(filename: str | None, raw: bytes) -> str:
    """Decide which supported extension an upload should be processed as.

    A recognised filename suffix (case-insensitive) wins. Otherwise the leading
    bytes are sniffed: PDF, ZIP-based .docx (the filename mentions "docx" or the
    archive has a ``word/`` entry), PNG, JPEG, GIF, BMP and WebP. Names with no
    suffix, or the suffix ".text", fall back to plain text.

    Raises:
        ValueError: if the file matches none of the supported formats.
    """
    fn = (filename or "").lower()
    ext = Path(fn).suffix  # fn is already lower-cased
    if ext in NOVEL_ALLOWED_EXTENSIONS:
        return ext

    if raw.startswith(b"%PDF"):
        return ".pdf"

    if len(raw) > 4 and raw.startswith(b"PK"):
        # .docx files are ZIP archives whose document parts live under "word/".
        if "docx" in fn:
            return ".docx"
        try:
            # Context manager closes the archive even if namelist() raises.
            with zipfile.ZipFile(io.BytesIO(raw)) as zf:
                names = zf.namelist()
        except zipfile.BadZipFile:
            names = []
        if any(n.startswith("word/") for n in names):
            return ".docx"

    image_signatures = (
        (b"\x89PNG\r\n\x1a\n", ".png"),
        (b"\xff\xd8\xff", ".jpg"),
        (b"GIF87a", ".gif"),
        (b"GIF89a", ".gif"),
        (b"BM", ".bmp"),
    )
    for magic, guessed in image_signatures:
        if raw.startswith(magic):
            return guessed
    if len(raw) >= 12 and raw[:4] == b"RIFF" and raw[8:12] == b"WEBP":
        return ".webp"

    if ext in ("", ".text"):
        return ".txt"

    raise ValueError(
        "不支持的文件格式。支持:.txt、.pdf、.docx 及常见图片(.png/.jpg/.jpeg/.bmp/.webp/.gif)"
    )
|
||
|
||
|
||
def _primary_extract(ext: str, raw: bytes) -> str:
    """Run the plain text-layer extractor for *ext* (no OCR, no vision).

    Image extensions have no text layer here and yield an empty string; any
    other unknown extension raises ValueError.
    """
    text_extractors = {
        ".txt": _text_from_txt,
        ".pdf": _text_from_pdf,
        ".docx": _text_from_docx,
    }
    extractor = text_extractors.get(ext)
    if extractor is not None:
        return extractor(raw)
    if ext in IMAGE_EXTENSIONS:
        return ""
    raise ValueError("不支持的文件格式")
|
||
|
||
|
||
def _temp_suffix(ext: str) -> str:
|
||
if ext in (".jpg", ".jpeg"):
|
||
return ".jpg"
|
||
return ext
|
||
|
||
|
||
def _pdf_ocr_with_vector(raw: bytes) -> str:
    """OCR a PDF with the vector service's PDF OCR routine.

    The helper takes a file path, so *raw* is written to a temporary ``.pdf``
    that is always removed afterwards. Returns the whitespace-normalised
    ``page_content`` of the first returned document, or "" if none is returned.
    NOTE(review): only docs[0] is used — confirm the helper returns a single
    combined document rather than one per page.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=".pdf")
    os.close(handle)
    try:
        Path(tmp_path).write_bytes(raw)
        documents = vector_service._process_pdf_with_ocr(tmp_path)
        return _collapse_blank_lines(documents[0].page_content) if documents else ""
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|
||
|
||
|
||
def _docx_enhanced_with_vector(raw: bytes) -> str:
    """Extract a .docx through the vector service's docx-with-images routine.

    *raw* is written to a temporary ``.docx``; the helper returns
    ``(docs, image_paths)``. Every image path that is still a file, and the
    temporary .docx, are deleted afterwards. Returns the whitespace-normalised
    ``page_content`` of the first document, or "" if none is returned.
    NOTE(review): only docs[0] is used — confirm it already includes any
    embedded-image OCR text.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(handle)
    extracted_images: list[str] = []
    try:
        Path(tmp_path).write_bytes(raw)
        documents, extracted_images = vector_service._process_docx_with_images(tmp_path)
        return _collapse_blank_lines(documents[0].page_content) if documents else ""
    finally:
        for image_path in extracted_images:
            try:
                if os.path.isfile(image_path):
                    os.unlink(image_path)
            except OSError:
                pass
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|
||
|
||
|
||
def _image_ocr_with_vector(raw: bytes, ext: str) -> str:
    """OCR an image with the vector service's image OCR routine.

    *raw* is written to a temporary file whose suffix comes from
    ``_temp_suffix(ext)``; the file is always removed. Returns the
    whitespace-normalised ``page_content`` of the first returned document,
    or "" if none is returned.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=_temp_suffix(ext))
    os.close(handle)
    try:
        Path(tmp_path).write_bytes(raw)
        documents = vector_service._process_image_ocr(tmp_path)
        return _collapse_blank_lines(documents[0].page_content) if documents else ""
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|
||
|
||
|
||
def _mime_for_ext(ext: str) -> str:
|
||
return {
|
||
".png": "image/png",
|
||
".jpg": "image/jpeg",
|
||
".jpeg": "image/jpeg",
|
||
".gif": "image/gif",
|
||
".bmp": "image/bmp",
|
||
".webp": "image/webp",
|
||
}.get(ext.lower(), "image/jpeg")
|
||
|
||
|
||
async def _pdf_pages_vision(raw: bytes) -> str:
    """Describe PDF pages with the vision model and join the results.

    Pages are rendered to PNG at 2x zoom with PyMuPDF in a worker thread. Only
    the first MAX_PDF_VISION_PAGES pages are sent (a truncation is logged), and
    at most 3 vision requests run at once. A page whose vision call fails is
    logged and skipped instead of aborting the whole document, matching how
    ``_image_ocr_plus_vision`` treats vision errors.

    Returns:
        Non-empty page descriptions, each prefixed with "[第 N 页]" (1-based)
        and separated by blank lines; "" if rendering fails or nothing is recognised.
    """
    from services.vision_service import VisionService

    def rasterize() -> tuple[list[bytes], int]:
        import fitz

        doc = fitz.open(stream=raw, filetype="pdf")
        try:
            total = doc.page_count
            zoom = fitz.Matrix(2, 2)
            images = [
                doc.load_page(i).get_pixmap(matrix=zoom).tobytes("png")
                for i in range(min(total, MAX_PDF_VISION_PAGES))
            ]
        finally:
            doc.close()
        return images, total

    try:
        pages, total_pages = await asyncio.to_thread(rasterize)
    except Exception as e:
        logger.warning("知识图谱:PDF 渲染为图片失败(视觉回退跳过): {}", e)
        return ""
    if not pages:
        return ""
    if total_pages > len(pages):
        logger.warning(
            "知识图谱:PDF 共 {} 页,视觉回退仅处理前 {} 页", total_pages, len(pages)
        )

    sem = asyncio.Semaphore(3)

    async def describe_page(idx: int, png_bytes: bytes) -> str:
        async with sem:
            try:
                t = await VisionService.get_image_description_from_bytes(
                    png_bytes, prompt=KG_VISION_PROMPT_PAGE, mime_hint="image/png"
                )
            except Exception as e:
                logger.warning("知识图谱:PDF 第 {} 页视觉识别失败,已跳过: {}", idx + 1, e)
                return ""
        return t or ""

    # gather() returns results in argument order, so no re-sorting is needed.
    texts = await asyncio.gather(*(describe_page(i, b) for i, b in enumerate(pages)))
    parts = [f"[第 {idx + 1} 页]\n{t.strip()}" for idx, t in enumerate(texts) if t.strip()]
    return "\n\n".join(parts)
|
||
|
||
|
||
async def _image_ocr_plus_vision(raw: bytes, ext: str) -> str:
    """Extract text from an image with OCR and the vision model together.

    OCR (through the vector service, in a worker thread) and, when
    DASHSCOPE_API_KEY is set, the vision model run concurrently. Each is
    best-effort: a failure is logged and treated as an empty result, and a
    ``None`` reply is also treated as empty instead of crashing on ``.strip()``.

    Returns:
        Both results labelled 【视觉理解】/【OCR 文字】 when both are non-empty,
        otherwise whichever one is non-empty (whitespace-normalised), otherwise "".
    """
    from services.vision_service import VisionService

    async def run_ocr() -> str:
        try:
            return (await asyncio.to_thread(_image_ocr_with_vector, raw, ext)) or ""
        except Exception as e:
            logger.warning("知识图谱:图片 OCR 失败或未配置 OCR: {}", e)
            return ""

    async def run_vision() -> str:
        if not settings.dashscope_api_key:
            return ""
        try:
            described = await VisionService.get_image_description_from_bytes(
                raw, prompt=KG_VISION_PROMPT_IMAGE, mime_hint=_mime_for_ext(ext)
            )
        except Exception as e:
            logger.warning("知识图谱:视觉模型失败: {}", e)
            return ""
        return described or ""

    # The two extractors are independent, so overlap their latency.
    ocr_txt, vision_txt = await asyncio.gather(run_ocr(), run_vision())

    if ocr_txt.strip() and vision_txt.strip():
        return _collapse_blank_lines(f"【视觉理解】\n{vision_txt}\n\n【OCR 文字】\n{ocr_txt}")
    if vision_txt.strip():
        return _collapse_blank_lines(vision_txt)
    if ocr_txt.strip():
        return _collapse_blank_lines(ocr_txt)
    return ""
|
||
|
||
|
||
def _cannot_extract_message() -> str:
|
||
return (
|
||
"未能从文件中提取到足够文本。请配置阿里云 OCR(OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET)"
|
||
"和/或通义视觉(DASHSCOPE_API_KEY),或换用可复制文字的 PDF / 文本文件。"
|
||
)
|
||
|
||
|
||
async def extract_knowledge_document_text(filename: str | None, raw: bytes) -> str:
    """Extract the full text of an uploaded file for knowledge-graph building.

    Strategy by type:
    * .txt: plain decoding only.
    * images: OCR and vision model (``_image_ocr_plus_vision``).
    * .pdf: text layer, then Vector OCR (same as the knowledge base), then
      per-page DashScope vision when DASHSCOPE_API_KEY is set.
    * .docx: text layer, then the enhanced extractor (body + embedded-image OCR).

    OCR/vision stages are best-effort: an exception in one (e.g. OCR not
    configured) is logged and the next stage is tried, so the documented
    fallback order actually runs and the user gets the guidance message.

    Raises:
        ValueError: empty input, unsupported format, or no stage produced
            meaningful text.
    """
    if not raw:
        raise ValueError("文件内容为空")

    ext = _guess_extension(filename, raw)

    if ext == ".txt":
        text = _primary_extract(ext, raw)
        if not _text_meaningful(text):
            raise ValueError("文本文件内容过短或为空")
        return text

    if ext in IMAGE_EXTENSIONS:
        merged = await _image_ocr_plus_vision(raw, ext)
        if not _text_meaningful(merged):
            raise ValueError(_cannot_extract_message())
        return merged

    text = _primary_extract(ext, raw)
    if _text_meaningful(text):
        return text

    if ext == ".pdf":
        try:
            ocr_text = await asyncio.to_thread(_pdf_ocr_with_vector, raw)
        except Exception as e:
            logger.warning("知识图谱:PDF Vector OCR 失败或未配置 OCR: {}", e)
            ocr_text = ""
        if _text_meaningful(ocr_text):
            logger.info("知识图谱:PDF 使用 Vector OCR 提取成功")
            return ocr_text
        if settings.dashscope_api_key:
            try:
                vision_text = await _pdf_pages_vision(raw)
            except Exception as e:
                logger.warning("知识图谱:PDF 通义视觉按页提取失败: {}", e)
                vision_text = ""
            if _text_meaningful(vision_text):
                logger.info("知识图谱:PDF 使用通义视觉按页提取成功")
                return vision_text
        raise ValueError(_cannot_extract_message())

    if ext == ".docx":
        try:
            enhanced = await asyncio.to_thread(_docx_enhanced_with_vector, raw)
        except Exception as e:
            logger.warning("知识图谱:DOCX 增强提取失败: {}", e)
            enhanced = ""
        if _text_meaningful(enhanced):
            logger.info("知识图谱:DOCX 使用增强提取(正文+内嵌图 OCR)成功")
            return enhanced
        raise ValueError(_cannot_extract_message())

    raise ValueError("不支持的文件格式")
|
||
|
||
|
||
def extract_knowledge_plain_text(filename: str | None, raw: bytes) -> str:
    """Extract text using only the regular text-layer parsers (no OCR / vision).

    Knowledge-graph endpoints should use ``extract_knowledge_document_text``;
    images are rejected here because they have no text layer.

    Raises:
        ValueError: empty input, image input, unsupported format, or no text found.
    """
    if not raw:
        raise ValueError("文件内容为空")
    ext = _guess_extension(filename, raw)
    if ext in IMAGE_EXTENSIONS:
        raise ValueError("图片请使用 extract_knowledge_document_text(含 OCR/视觉)")
    extracted = _primary_extract(ext, raw)
    has_text = bool((extracted or "").strip())
    if has_text:
        return extracted
    raise ValueError("未能从文件中提取到文本,若为扫描版 PDF 请先 OCR 后再上传")
|
||
|
||
|
||
def split_novel_text(text: str) -> list[str]:
    """Split *text* into overlapping chunks for per-chunk triplet extraction.

    Uses CHUNK_SIZE / CHUNK_OVERLAP. Separators are tried from coarse to fine:
    paragraphs, lines, sentence ends, clause punctuation, spaces, then single
    characters. Both full-width (CJK) and ASCII punctuation are listed;
    previously only the ASCII ``! ? ; ,`` were present, so long Chinese
    passages without "。" fell through to arbitrary character-level cuts.

    Returns:
        The chunks, or [] when *text* is empty or whitespace-only.
    """
    text = text.strip()
    if not text:
        return []
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=[
            "\n\n", "\n",
            "。", "！", "？", "!", "?",
            "；", ";", "，", ",",
            " ", "",
        ],
    )
    return splitter.split_text(text)
|
||
|
||
|
||
def _parse_triplet_json(content: str) -> list[dict[str, Any]]:
|
||
raw = content.strip()
|
||
m = re.search(r"\[[\s\S]*\]", raw)
|
||
if m:
|
||
raw = m.group(0)
|
||
try:
|
||
data = json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
return []
|
||
if not isinstance(data, list):
|
||
return []
|
||
out: list[dict[str, Any]] = []
|
||
for item in data:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
subj = item.get("subject") or item.get("head") or item.get("s")
|
||
obj = item.get("object") or item.get("tail") or item.get("o")
|
||
rel = item.get("relation") or item.get("predicate") or item.get("p")
|
||
note = item.get("note") or item.get("evidence") or ""
|
||
if subj is None or obj is None:
|
||
continue
|
||
out.append({
|
||
"subject": str(subj).strip(),
|
||
"object": str(obj).strip(),
|
||
"relation_type": str(rel).strip() if rel else "相关",
|
||
"note": str(note).strip() if note else "",
|
||
})
|
||
return out
|
||
|
||
|
||
def _triplet_llm():
    """Build the non-streaming DeepSeek chat model used for triplet extraction.

    Temperature is set to 0.2, presumably to keep the JSON output stable.
    """
    model_options = {
        "provider": "deepseek",
        "api_model": "deepseek-chat",
        "streaming": False,
        "temperature": 0.2,
    }
    return build_chat_model(**model_options)
|
||
|
||
|
||
async def extract_triplets_from_chunk(chunk: str, chunk_index: int) -> list[dict[str, Any]]:
    """Ask the DeepSeek chat model for the relation triplets in one text chunk.

    Args:
        chunk: One piece of the document (e.g. an element of ``split_novel_text``).
        chunk_index: Chunk number, only used to label the chunk inside the prompt
            (the caller in this module passes 1-based numbers).

    Returns:
        Triplet dicts as produced by ``_parse_triplet_json`` (keys subject /
        object / relation_type / note); [] when the reply cannot be parsed.

    Raises:
        ValueError: if DEEPSEEK_API_KEY is not configured. Errors from the LLM
            call itself are not caught here.
    """
    if not settings.deepseek_api_key:
        raise ValueError("未配置 DEEPSEEK_API_KEY,无法抽取实体关系")

    llm = _triplet_llm()
    # The prompt defines what counts as an entity and as a relation, with
    # positive and negative examples. Doubled braces ({{ }}) are literal JSON
    # braces inside the f-string; only {chunk_index} and {chunk} are interpolated.
    prompt = f"""你是知识图谱构建专家。阅读下列文本片段(可能是中文或英文),抽取其中**实体之间的关系三元组**。

## 实体定义(重要!)
实体必须是**具体的名词性对象**,包括:
- 人物:具体的人名(如「贾宝玉」「James」「Bronny」)
- 组织:公司、机构、团队名称(如「荣国府」「Apple Inc.」「NASA」)
- 地点:具体的地名、场所(如「大观园」「Beijing」「New York」)
- 物品:具体的物体、产品名称(如「通灵宝玉」「iPhone」)
- 概念:重要的抽象概念、系统、模块名(如「知识库系统」「User Management Module」)

## 严禁作为实体的内容
- ❌ 动作短语:「离去」「到来」「leaving」「arriving」
- ❌ 泛指代词:「他」「她」「父亲」「母亲」「he」「she」「father」「mother」(除非是专有称呼)
- ❌ 描述性短语:「甄士隐离去」「John's departure」
- ❌ 动词短语:「听闻此信」「heard the news」

## 关系定义(重要!)
关系应描述**实体之间的静态联系**,不是动作,包括:
- 人际关系:夫妻/spouse、父子/father-son、母女/mother-daughter、师徒/mentor-disciple、朋友/friend
- 社会关系:雇佣/employed_by、所属/belongs_to、管理/manages、合作/cooperates_with
- 位置关系:位于/located_in、毗邻/adjacent_to、包含/contains、居住于/resides_in
- 属性关系:拥有/owns、制造/manufactures、创建/created_by

## 示例(正确)
中文原文:"封氏是甄士隐的嫡妻"
✅ {{"subject": "甄士隐", "relation": "夫妻", "object": "封氏", "note": "封氏是甄士隐的嫡妻"}}

中文原文:"封肃是甄士隐的岳父"
✅ {{"subject": "封肃", "relation": "岳父", "object": "甄士隐"}}

英文原文:"Bronny is LeBron James's son"
✅ {{"subject": "LeBron James", "relation": "father-son", "object": "Bronny", "note": "Bronny is LeBron James's son"}}

英文原文:"Apple Inc. is headquartered in Cupertino"
✅ {{"subject": "Apple Inc.", "relation": "located_in", "object": "Cupertino"}}

## 示例(错误)
❌ {{"subject": "封氏", "relation": "听闻", "object": "甄士隐离去"}} // "甄士隐离去"不是实体,"听闻"是动作
❌ {{"subject": "封氏", "relation": "依靠", "object": "父亲"}} // "父亲"是泛指,不是具体实体
❌ {{"subject": "John", "relation": "left", "object": "office"}} // "left"是动作,不是关系

## 输出要求
1. 只输出一个 JSON 数组,不要 Markdown、不要解释文字
2. 数组中每个元素包含:subject(主体实体名)、relation(关系类型,简洁表达)、object(客体实体名)
3. 可选字段 note:原文证据(≤50字符)
4. 使用原文中的具体名称,确保 subject 和 object 都是上述定义的实体
5. relation 用原文语言表达(中文文本用中文关系,英文文本用英文关系)
6. 若本段没有符合要求的实体关系,输出空数组 []

【文本片段 #{chunk_index}】
{chunk}
"""
    messages = [
        SystemMessage(content="你是知识图谱构建专家。只输出合法 JSON 数组,严格遵守实体和关系定义,键名使用英文 subject/relation/object/note。"),
        HumanMessage(content=prompt),
    ]
    response = await llm.ainvoke(messages)
    return _parse_triplet_json(response.content)
|
||
|
||
|
||
# Maximum number of characters of the full text included in the single
# fallback request made by extract_triplets_fallback_manual; longer text is cut
# and a truncation notice is appended.
FALLBACK_TEXT_CAP = 20_000
|
||
|
||
|
||
async def extract_triplets_fallback_manual(text: str) -> list[dict[str, Any]]:
    """Run one extraction over the whole (truncated) text as a last resort.

    Used when chunk-by-chunk extraction produced no usable triplets. Text
    longer than FALLBACK_TEXT_CAP characters is cut, and a note telling the
    model that the text was truncated is appended.

    Returns:
        Triplet dicts from ``_parse_triplet_json``. Returns [] when
        DEEPSEEK_API_KEY is not configured; unlike ``extract_triplets_from_chunk``
        this function does not raise in that case. Errors from the LLM call
        itself are not caught here.
    """
    if not settings.deepseek_api_key:
        return []

    body = text.strip()
    if len(body) > FALLBACK_TEXT_CAP:
        body = body[:FALLBACK_TEXT_CAP] + "\n\n...(正文过长,已截断;若关系主要在后续章节,可考虑拆分为多文件上传)"

    llm = _triplet_llm()
    # Same entity/relation rules as the per-chunk prompt, but the model is asked
    # to extract as many triplets as possible from the whole document.
    prompt = f"""你是知识图谱构建专家。请从下列文本(可能是中文或英文)中抽取**尽量多**的实体关系三元组。

## 实体定义(重要!)
实体必须是**具体的名词性对象**:
- 人物:具体的人名(如「贾宝玉」「James」「Bronny」)
- 组织:公司、机构、团队名称(如「荣国府」「Apple Inc.」)
- 地点:具体的地名、场所(如「大观园」「Beijing」「New York」)
- 物品:具体的物体、产品名称(如「通灵宝玉」「iPhone」)
- 概念:重要的系统、模块、功能名

## 严禁作为实体
❌ 动作短语:「离去」「到来」「leaving」「arriving」
❌ 泛指代词:「父亲」「母亲」「he」「she」「father」「mother」
❌ 描述性短语:「甄士隐离去」「John's departure」

## 关系定义(重要!)
关系应描述**实体之间的静态联系**,不是动作:
- 人际关系:夫妻/spouse、父子/father-son、母女/mother-daughter、师徒/mentor
- 社会关系:雇佣/employed_by、所属/belongs_to、管理/manages
- 位置关系:位于/located_in、毗邻/adjacent_to、居住于/resides_in
- 属性关系:拥有/owns、制造/manufactures、创建/created_by

## 输出要求
1. 只输出一个 JSON 数组,不要 Markdown
2. 每项包含:subject(主体实体名)、relation(关系类型,简洁表达)、object(客体实体名)
3. 可选字段 note:原文证据(≤50字符)
4. 使用原文中的具体名称,确保 subject 和 object 都是上述定义的实体
5. relation 用原文语言表达(中文文本用中文关系,英文文本用英文关系)
6. 不要编造原文没有的实体
7. 至少尝试抽取若干条;若全文无任何结构信息,才输出 []

【资料正文】
{body}
"""
    messages = [
        SystemMessage(content="你是知识图谱构建专家。只输出合法 JSON 数组,严格遵守实体和关系定义,键名 subject/relation/object/note。"),
        HumanMessage(content=prompt),
    ]
    response = await llm.ainvoke(messages)
    return _parse_triplet_json(response.content)
|
||
|
||
|
||
def _is_valid_entity(name: str) -> bool:
|
||
"""
|
||
检查是否为有效实体名称。
|
||
过滤掉明显的动作短语、泛指代词等(支持中英文)。
|
||
"""
|
||
name = name.strip()
|
||
if not name:
|
||
return False
|
||
|
||
name_lower = name.lower()
|
||
|
||
# 过滤泛指代词(中文)
|
||
invalid_generic_zh = {"他", "她", "它", "他们", "她们", "我", "你", "我们", "你们",
|
||
"父亲", "母亲", "儿子", "女儿", "兄弟", "姐妹", "爷爷", "奶奶"}
|
||
if name in invalid_generic_zh:
|
||
return False
|
||
|
||
# 过滤泛指代词(英文)
|
||
invalid_generic_en = {"he", "she", "it", "they", "i", "you", "we",
|
||
"father", "mother", "son", "daughter", "brother", "sister",
|
||
"grandfather", "grandmother", "him", "her", "his", "their"}
|
||
if name_lower in invalid_generic_en:
|
||
return False
|
||
|
||
# 过滤明显的动作短语(中文动词)
|
||
action_verbs_zh = ["离去", "到来", "哭泣", "听闻", "看见", "说道", "笑道",
|
||
"走来", "回来", "进来", "出去", "过来", "起来", "下去"]
|
||
if any(verb in name for verb in action_verbs_zh):
|
||
return False
|
||
|
||
# 过滤明显的动作短语(英文动词)
|
||
action_verbs_en = ["leaving", "arriving", "crying", "hearing", "seeing", "saying",
|
||
"coming", "going", "walking", "running", "departure", "arrival"]
|
||
if any(verb in name_lower for verb in action_verbs_en):
|
||
return False
|
||
|
||
# 实体名称不应过长(可能是描述性短语)
|
||
# 英文实体名称可以稍长一些(考虑空格)
|
||
max_len = 30 if any(c.isascii() and c.isalpha() for c in name) else 20
|
||
if len(name) > max_len:
|
||
return False
|
||
|
||
return True
|
||
|
||
|
||
def merge_triplets(chunks: list[list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Flatten per-chunk triplet lists, drop invalid triplets and de-duplicate.

    A triplet is dropped when its subject or object is empty, the two are
    equal, or either fails ``_is_valid_entity``. Duplicates are detected on
    (subject, object, relation_type) after stripping; the first occurrence wins.
    An empty relation becomes "相关". Output fields are truncated to 200
    (subject/object), 120 (relation) and 500 (note) characters.
    """
    seen_keys: set[tuple[str, str, str]] = set()
    result: list[dict[str, Any]] = []
    for triplet in (t for group in chunks for t in group):
        subject = (triplet.get("subject") or "").strip()
        obj = (triplet.get("object") or "").strip()
        relation = (triplet.get("relation_type") or "").strip() or "相关"

        if not (subject and obj) or subject == obj:
            continue
        if not (_is_valid_entity(subject) and _is_valid_entity(obj)):
            continue

        key = (subject, obj, relation)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        result.append({
            "subject": subject[:200],
            "object": obj[:200],
            "relation_type": relation[:120],
            "note": (triplet.get("note") or "")[:500],
        })
    return result
|
||
|
||
|
||
def validate_knowledge_graph_text_length(text: str) -> None:
    """Reject text longer than MAX_INPUT_CHARS before KG extraction starts.

    Raises:
        ValueError: with a user-facing message suggesting splitting the file or
            using the knowledge-base feature for long documents.
    """
    char_count = len(text)
    if char_count <= MAX_INPUT_CHARS:
        return
    raise ValueError(
        f"提取的正文过长(约 {char_count:,} 字),知识图谱构建上限为 {MAX_INPUT_CHARS:,} 字。"
        "请拆分为多个文件分别上传,或使用知识库功能处理长文档问答。"
    )
|
||
|
||
|
||
async def extract_and_import_knowledge_graph(text: str, graph_id: str) -> dict[str, Any]:
    """Extract relation triplets from *text* chunk by chunk and import them into Neo4j.

    Steps: validate the length, split with ``split_novel_text``, call the LLM
    once per chunk (serially), merge and de-duplicate, fall back to one
    whole-text extraction when nothing was found, then write the triplets with
    ``neo4j_service.import_knowledge_graph_triplets`` in a worker thread.

    A chunk whose LLM call fails with a non-ValueError exception (e.g. a
    network timeout) is logged and skipped, so one flaky request does not
    discard all completed calls; if every chunk fails, the last error is
    re-raised.

    Args:
        text: Full document text.
        graph_id: Target graph identifier, passed through to the Neo4j import.

    Returns:
        The stats dict returned by the Neo4j import (``node_count`` is read here).

    Raises:
        ValueError: text too long or empty, too many chunks, or
            DEEPSEEK_API_KEY missing.
    """
    validate_knowledge_graph_text_length(text)

    chunks = split_novel_text(text)
    if not chunks:
        raise ValueError("文本为空")

    total = len(chunks)
    if total > MAX_KG_EXTRACT_CHUNKS:
        raise ValueError(
            f"文本分块过多({total} 块,上限 {MAX_KG_EXTRACT_CHUNKS} 块),"
            f"请将正文控制在约 {MAX_INPUT_CHARS:,} 字以内后重试。"
        )

    logger.info("知识图谱:共 {} 个文本块", total)
    batch_results: list[list[dict[str, Any]]] = []
    failed = 0
    last_error: Exception | None = None
    for index, chunk in enumerate(chunks, start=1):
        try:
            triplets = await extract_triplets_from_chunk(chunk, index)
        except ValueError:
            # Configuration problems (e.g. missing DEEPSEEK_API_KEY) would fail
            # every chunk the same way, so surface them immediately.
            raise
        except Exception as e:
            failed += 1
            last_error = e
            logger.warning("知识图谱:块 {}/{} 抽取失败,已跳过: {}", index, total, e)
            triplets = []
        else:
            logger.info("块 {}/{} 抽取到 {} 条关系", index, total, len(triplets))
        batch_results.append(triplets)

    if last_error is not None and failed == total:
        raise last_error
    if failed:
        logger.warning("知识图谱:{}/{} 个文本块抽取失败,结果可能不完整", failed, total)

    merged = merge_triplets(batch_results)
    if not merged:
        logger.warning("知识图谱:分块抽取无结果,尝试说明文档/产品手册汇总抽取")
        fb = await extract_triplets_fallback_manual(text)
        merged = merge_triplets([fb])

    # The Neo4j import is called synchronously; keep it off the event loop.
    stats = await asyncio.to_thread(
        neo4j_service.import_knowledge_graph_triplets, merged, graph_id
    )
    if stats["node_count"] == 0:
        logger.warning(
            "知识图谱 graph_id={}:未写入任何关系节点,仍可使用向量检索(RAG)回答;"
            "Neo4j 关系查询工具可能无数据。",
            graph_id,
        )
    return stats
|