# huoyan-enterprise/backend/services/novel_kg_service.py
"""
资料文本 → 分块 → LLM 实体关系抽取 → Neo4j 三元组导入
"""
from __future__ import annotations
import asyncio
import io
import json
import os
import re
import tempfile
import zipfile
from pathlib import Path
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_text_splitters import RecursiveCharacterTextSplitter
from core.config import settings
from core.llm_catalog import build_chat_model
from services import neo4j_service
from logger.logging import get_logger
logger = get_logger(__name__)
# Knowledge-graph extraction limits. Text is cut into chunks of roughly
# CHUNK_SIZE characters with CHUNK_OVERLAP characters of overlap, and each chunk
# costs one serial DeepSeek call. About 80k characters is therefore ~100 calls,
# which finishes in the background within minutes (the original note's figure is
# garbled; presumably 5-10 minutes -- TODO confirm). A 500k-character input
# would need 600+ calls and is not practical.
MAX_INPUT_CHARS = 80_000
MAX_KG_EXTRACT_CHUNKS = 100
CHUNK_SIZE = 900
CHUNK_OVERLAP = 120
# Extracted text shorter than this is treated as "nothing useful extracted".
MIN_MEANINGFUL_TEXT_LEN = 30
# Upper bound on the number of PDF pages rendered for the vision fallback.
MAX_PDF_VISION_PAGES = 50

# Image suffixes are declared once; the full upload whitelist is built on top
# of them so the two sets cannot drift apart.
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".webp", ".gif"})
NOVEL_ALLOWED_EXTENSIONS = IMAGE_EXTENSIONS | {".txt", ".pdf", ".docx"}

# Prompt sent to the vision model for a standalone image upload.
KG_VISION_PROMPT_IMAGE = (
    "详细描述图片中的内容:场景、人物、物体、图表及所有可见文字(逐字提取)。"
    "用通顺中文输出,便于后续做实体与关系抽取。"
)
# Prompt sent to the vision model for one rasterized PDF page.
KG_VISION_PROMPT_PAGE = (
    "这是纸质文档的一页扫描图。请尽量还原页内全部文字(标题、正文、表格、脚注等),"
    "并简要说明版面结构。用中文输出。"
)
def _collapse_blank_lines(text: str) -> str:
text = re.sub(r"[ \t\r\f\v]+", " ", text)
text = re.sub(r"\n{3,}", "\n\n", text)
return text.strip()
def _text_from_txt(raw: bytes) -> str:
    """Decode a plain-text upload and normalize its whitespace.

    UTF-8 is tried first; on failure the bytes are decoded as GB18030 with
    undecodable bytes replaced, so this function never raises on bad input.

    Args:
        raw: File content.

    Returns:
        Decoded text after ``_collapse_blank_lines``.
    """
    try:
        # "utf-8-sig" accepts UTF-8 with or without a BOM and drops the BOM.
        # With plain "utf-8" a leading U+FEFF survives (str.strip() does not
        # treat it as whitespace) and ends up glued to the first word.
        decoded = raw.decode("utf-8-sig")
    except UnicodeDecodeError:
        decoded = raw.decode("gb18030", errors="replace")
    return _collapse_blank_lines(decoded)
def _text_from_pdf(raw: bytes) -> str:
    """Extract the text layer of a PDF, trying pypdf first and PyMuPDF second.

    Args:
        raw: PDF file content.

    Returns:
        Normalized text. May be short or empty for scanned PDFs; callers decide
        whether to fall back to OCR.

    Raises:
        ValueError: If pypdf cannot open the document at all.
    """
    from pypdf import PdfReader

    try:
        reader = PdfReader(io.BytesIO(raw))
    except Exception as e:
        # The original message ran the label and the error together.
        raise ValueError(f"无法读取 PDF: {e}") from e

    parts: list[str] = []
    for page in reader.pages:
        try:
            page_text = page.extract_text()
        except Exception:
            # Best effort: one unreadable page must not discard the others.
            page_text = ""
        if page_text and page_text.strip():
            parts.append(page_text)
    text = _collapse_blank_lines("\n".join(parts))

    # Same threshold the rest of the module uses for "meaningful" text.
    if len(text) >= MIN_MEANINGFUL_TEXT_LEN:
        return text

    try:
        import fitz  # PyMuPDF, optional
    except ImportError:
        return text

    try:
        doc = fitz.open(stream=raw, filetype="pdf")
        try:
            alt_pages = [
                doc.load_page(i).get_text() or "" for i in range(doc.page_count)
            ]
        finally:
            # Close the document even when a page fails to load.
            doc.close()
        alt_text = _collapse_blank_lines("\n".join(alt_pages))
    except Exception as e:
        logger.warning("PyMuPDF 回退提取失败: {}", e)
        return text

    # Keep whichever extractor recovered more; the fallback used to overwrite
    # the pypdf result even when it came back shorter or empty.
    return alt_text if len(alt_text) > len(text) else text
def _text_from_docx(raw: bytes) -> str:
    """Extract paragraph text followed by table-cell text from a .docx file.

    Args:
        raw: .docx file content.

    Returns:
        Non-empty paragraphs, then non-empty table cells, one per line, after
        whitespace normalization.

    Raises:
        ValueError: If python-docx is not installed or the file cannot be
            opened as a Word document.
    """
    try:
        from docx import Document
    except ImportError as e:
        raise ValueError("服务端未安装 python-docx无法解析 Word 文档") from e

    try:
        document = Document(io.BytesIO(raw))
    except Exception as e:
        raise ValueError(f"无法读取 Word 文档(.docx{e}") from e

    paragraph_lines = [
        para.text.strip()
        for para in document.paragraphs
        if para.text and para.text.strip()
    ]
    cell_lines = [
        cell.text.strip()
        for table in document.tables
        for row in table.rows
        for cell in row.cells
        if cell.text and cell.text.strip()
    ]
    return _collapse_blank_lines("\n".join(paragraph_lines + cell_lines))
def _text_meaningful(text: str) -> bool:
    """Return True when *text*, stripped, has at least MIN_MEANINGFUL_TEXT_LEN characters."""
    if not text:
        return False
    return len(text.strip()) >= MIN_MEANINGFUL_TEXT_LEN
def _guess_extension(filename: str | None, raw: bytes) -> str:
    """Decide which supported extension an upload should be treated as.

    A whitelisted filename suffix wins. Otherwise the leading bytes are
    sniffed (PDF, ZIP-based docx, common image formats). A file with no suffix
    or a ``.text`` suffix is finally treated as plain text.

    Args:
        filename: Client-supplied name, may be None.
        raw: File content.

    Returns:
        One of the suffixes in NOVEL_ALLOWED_EXTENSIONS.

    Raises:
        ValueError: If the format cannot be recognised.
    """
    lowered_name = (filename or "").lower()
    suffix = Path(lowered_name).suffix.lower()
    if suffix in NOVEL_ALLOWED_EXTENSIONS:
        return suffix

    if raw.startswith(b"%PDF"):
        return ".pdf"

    if len(raw) > 4 and raw.startswith(b"PK"):
        # A ZIP container: trust a "docx" hint in the name, else look inside.
        if "docx" in lowered_name:
            return ".docx"
        try:
            with zipfile.ZipFile(io.BytesIO(raw)) as archive:
                members = archive.namelist()
        except zipfile.BadZipFile:
            members = []
        if any(member.startswith("word/") for member in members):
            return ".docx"

    # Fixed-prefix image signatures; the prefixes are mutually exclusive.
    image_signatures = (
        (b"\x89PNG\r\n\x1a\n", ".png"),
        (b"\xff\xd8\xff", ".jpg"),
        (b"GIF87a", ".gif"),
        (b"GIF89a", ".gif"),
        (b"BM", ".bmp"),
    )
    for signature, detected in image_signatures:
        if raw.startswith(signature):
            return detected
    if len(raw) >= 12 and raw[:4] == b"RIFF" and raw[8:12] == b"WEBP":
        return ".webp"

    if suffix in ("", ".text"):
        return ".txt"
    raise ValueError(
        "不支持的文件格式。支持:.txt、.pdf、.docx 及常见图片(.png/.jpg/.jpeg/.bmp/.webp/.gif"
    )
def _primary_extract(ext: str, raw: bytes) -> str:
    """Run the plain (no OCR, no vision) parser that matches *ext*.

    Args:
        ext: Extension as returned by ``_guess_extension``.
        raw: File content.

    Returns:
        Parsed text; always the empty string for image extensions, which have
        no text layer.

    Raises:
        ValueError: If *ext* is neither a document nor an image extension.
    """
    parsers = {
        ".txt": _text_from_txt,
        ".pdf": _text_from_pdf,
        ".docx": _text_from_docx,
    }
    parser = parsers.get(ext)
    if parser is not None:
        return parser(raw)
    if ext in IMAGE_EXTENSIONS:
        return ""
    raise ValueError("不支持的文件格式")
def _temp_suffix(ext: str) -> str:
if ext in (".jpg", ".jpeg"):
return ".jpg"
return ext
def _pdf_ocr_with_vector(raw: bytes) -> str:
    """OCR a PDF through the vector service, using a temporary file.

    The bytes are written to a temp ``.pdf`` because the vector service is
    called with a path. The temp file is removed afterwards on a best-effort
    basis.

    Args:
        raw: PDF file content.

    Returns:
        Normalized ``page_content`` of the first returned document, or ``""``
        when the service returns nothing.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=".pdf")
    os.close(handle)
    try:
        Path(tmp_path).write_bytes(raw)
        documents = vector_service._process_pdf_with_ocr(tmp_path)
        if documents:
            return _collapse_blank_lines(documents[0].page_content)
        return ""
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
def _docx_enhanced_with_vector(raw: bytes) -> str:
    """Extract docx text via the vector service's image-aware pipeline.

    The bytes are written to a temp ``.docx``. The service returns documents
    plus a list of image paths; those image files and the temp docx are
    removed afterwards on a best-effort basis.

    Args:
        raw: .docx file content.

    Returns:
        Normalized ``page_content`` of the first returned document, or ``""``
        when the service returns nothing.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(handle)
    img_paths: list[str] = []
    try:
        Path(tmp_path).write_bytes(raw)
        documents, img_paths = vector_service._process_docx_with_images(tmp_path)
        if documents:
            return _collapse_blank_lines(documents[0].page_content)
        return ""
    finally:
        # Image files reported by the service first, then our own temp docx.
        for image_path in img_paths:
            try:
                if os.path.isfile(image_path):
                    os.unlink(image_path)
            except OSError:
                pass
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
def _image_ocr_with_vector(raw: bytes, ext: str) -> str:
    """OCR a single image through the vector service, using a temporary file.

    Args:
        raw: Image file content.
        ext: Image extension; it selects the temp-file suffix.

    Returns:
        Normalized ``page_content`` of the first returned document, or ``""``
        when the service returns nothing.
    """
    from services.vector_service import get_vector_service

    vector_service = get_vector_service()
    handle, tmp_path = tempfile.mkstemp(suffix=_temp_suffix(ext))
    os.close(handle)
    try:
        Path(tmp_path).write_bytes(raw)
        documents = vector_service._process_image_ocr(tmp_path)
        if documents:
            return _collapse_blank_lines(documents[0].page_content)
        return ""
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
def _mime_for_ext(ext: str) -> str:
return {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".bmp": "image/bmp",
".webp": "image/webp",
}.get(ext.lower(), "image/jpeg")
async def _pdf_pages_vision(raw: bytes) -> str:
    """Describe a PDF page by page with the vision model.

    Up to MAX_PDF_VISION_PAGES pages are rendered to PNG in a worker thread,
    then sent to the vision service with at most three requests in flight.

    Args:
        raw: PDF file content.

    Returns:
        Page texts joined by blank lines, each prefixed with its page label.
        Returns ``""`` when rendering fails or no page produced any text.
    """
    from services.vision_service import VisionService

    def rasterize() -> list[bytes]:
        import fitz

        doc = fitz.open(stream=raw, filetype="pdf")
        try:
            page_total = min(doc.page_count, MAX_PDF_VISION_PAGES)
            # Matrix(2, 2) renders at twice the default scale.
            return [
                doc.load_page(i).get_pixmap(matrix=fitz.Matrix(2, 2)).tobytes("png")
                for i in range(page_total)
            ]
        finally:
            doc.close()

    try:
        pages = await asyncio.to_thread(rasterize)
    except Exception as e:
        logger.warning("知识图谱PDF 渲染为图片失败(视觉回退跳过): {}", e)
        return ""
    if not pages:
        return ""

    sem = asyncio.Semaphore(3)

    async def one_page(idx: int, png_bytes: bytes) -> str:
        async with sem:
            try:
                described = await VisionService.get_image_description_from_bytes(
                    png_bytes, prompt=KG_VISION_PROMPT_PAGE, mime_hint="image/png"
                )
            except Exception as e:
                # One failing page must not throw away every other page:
                # without this, gather() re-raises the first error.
                logger.warning("知识图谱:PDF 第 {} 页视觉识别失败: {}", idx + 1, e)
                return ""
        return described or ""

    # gather() returns results in argument order, so no re-sorting is needed.
    page_texts = await asyncio.gather(
        *(one_page(i, png) for i, png in enumerate(pages))
    )
    parts = [
        f"[第 {idx + 1} 页]\n{page_text.strip()}"
        for idx, page_text in enumerate(page_texts)
        if page_text.strip()
    ]
    return "\n\n".join(parts)
async def _image_ocr_plus_vision(raw: bytes, ext: str) -> str:
    """Combine OCR text and a vision-model description for one image.

    Both steps are best effort: a failure in either is logged and treated as
    "no text". The vision step only runs when ``settings.dashscope_api_key``
    is set.

    Args:
        raw: Image file content.
        ext: Image extension, used for the temp-file suffix and MIME hint.

    Returns:
        Both texts under labelled headings when both exist, otherwise whichever
        one exists, otherwise ``""``.
    """
    from services.vision_service import VisionService

    ocr_txt = ""
    try:
        ocr_txt = await asyncio.to_thread(_image_ocr_with_vector, raw, ext)
    except Exception as e:
        logger.warning("知识图谱:图片 OCR 失败或未配置 OCR: {}", e)

    vision_txt = ""
    if settings.dashscope_api_key:
        try:
            # Guard against a None result the same way the PDF page path does;
            # otherwise the .strip() calls below raise AttributeError.
            vision_txt = (
                await VisionService.get_image_description_from_bytes(
                    raw, prompt=KG_VISION_PROMPT_IMAGE, mime_hint=_mime_for_ext(ext)
                )
            ) or ""
        except Exception as e:
            logger.warning("知识图谱:视觉模型失败: {}", e)

    has_ocr = bool(ocr_txt.strip())
    has_vision = bool(vision_txt.strip())
    if has_ocr and has_vision:
        return _collapse_blank_lines(
            f"【视觉理解】\n{vision_txt}\n\n【OCR 文字】\n{ocr_txt}"
        )
    if has_vision:
        return _collapse_blank_lines(vision_txt)
    if has_ocr:
        return _collapse_blank_lines(ocr_txt)
    return ""
def _cannot_extract_message() -> str:
return (
"未能从文件中提取到足够文本。请配置阿里云 OCROCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET"
"和/或通义视觉DASHSCOPE_API_KEY或换用可复制文字的 PDF / 文本文件。"
)
async def extract_knowledge_document_text(filename: str | None, raw: bytes) -> str:
    """Extract the full text of an uploaded knowledge-graph document.

    Order of attempts: regular parsing, then vector-service OCR, then the
    vision model (page by page for PDFs, whole image for pictures). Each
    fallback is best effort: a failure is logged and the next one is tried.

    Args:
        filename: Client-supplied file name, may be None.
        raw: File content.

    Returns:
        The extracted text, at least MIN_MEANINGFUL_TEXT_LEN characters long.

    Raises:
        ValueError: If the content is empty, the format is unsupported, or no
            attempt yields enough text.
    """
    if not raw:
        raise ValueError("文件内容为空")
    ext = _guess_extension(filename, raw)

    if ext == ".txt":
        text = _primary_extract(ext, raw)
        if not _text_meaningful(text):
            raise ValueError("文本文件内容过短或为空")
        return text

    if ext in IMAGE_EXTENSIONS:
        merged = await _image_ocr_plus_vision(raw, ext)
        if not _text_meaningful(merged):
            raise ValueError(_cannot_extract_message())
        return merged

    text = _primary_extract(ext, raw)
    if _text_meaningful(text):
        return text

    if ext == ".pdf":
        # Mirror the image path: an OCR error (for example OCR not configured)
        # must not abort the chain before the vision fallback is tried.
        try:
            ocr_text = await asyncio.to_thread(_pdf_ocr_with_vector, raw)
        except Exception as e:
            logger.warning("知识图谱:PDF OCR 失败或未配置 OCR: {}", e)
            ocr_text = ""
        if _text_meaningful(ocr_text):
            logger.info("知识图谱PDF 使用 Vector OCR 提取成功")
            return ocr_text
        if settings.dashscope_api_key:
            try:
                vision_text = await _pdf_pages_vision(raw)
            except Exception as e:
                logger.warning("知识图谱:PDF 视觉提取失败: {}", e)
                vision_text = ""
            if _text_meaningful(vision_text):
                logger.info("知识图谱PDF 使用通义视觉按页提取成功")
                return vision_text
        raise ValueError(_cannot_extract_message())

    if ext == ".docx":
        try:
            enhanced = await asyncio.to_thread(_docx_enhanced_with_vector, raw)
        except Exception as e:
            logger.warning("知识图谱:DOCX 增强提取失败: {}", e)
            enhanced = ""
        if _text_meaningful(enhanced):
            logger.info("知识图谱DOCX 使用增强提取(正文+内嵌图 OCR成功")
            return enhanced
        raise ValueError(_cannot_extract_message())

    raise ValueError("不支持的文件格式")
def extract_knowledge_plain_text(filename: str | None, raw: bytes) -> str:
    """Extract text using only the regular text-layer parsers (no OCR, no vision).

    The knowledge-graph endpoint should use ``extract_knowledge_document_text``
    instead.

    Args:
        filename: Client-supplied file name, may be None.
        raw: File content.

    Returns:
        The parsed text.

    Raises:
        ValueError: If the content is empty, the file is an image, the format
            is unsupported, or no text could be extracted.
    """
    if not raw:
        raise ValueError("文件内容为空")

    detected = _guess_extension(filename, raw)
    if detected in IMAGE_EXTENSIONS:
        raise ValueError("图片请使用 extract_knowledge_document_text含 OCR/视觉)")

    extracted = _primary_extract(detected, raw)
    if (extracted or "").strip():
        return extracted
    raise ValueError("未能从文件中提取到文本,若为扫描版 PDF 请先 OCR 后再上传")
def split_novel_text(text: str) -> list[str]:
    """Split text into overlapping chunks for per-chunk triplet extraction.

    Args:
        text: Full document text.

    Returns:
        Chunks of about CHUNK_SIZE characters with CHUNK_OVERLAP overlap, or an
        empty list when the text is blank.
    """
    text = text.strip()
    if not text:
        return []
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        # Paragraph > line > sentence > clause > word > character.
        # The sentence/clause entries had degraded to empty strings, which the
        # splitter treats as "split per character" and which made every later
        # separator unreachable. They are written as escapes so tooling that
        # drops or flags full-width punctuation cannot blank them again.
        separators=[
            "\n\n",
            "\n",
            "\u3002",  # ideographic full stop
            "\uff01",  # full-width exclamation mark
            "\uff1f",  # full-width question mark
            "\uff1b",  # full-width semicolon
            "\uff0c",  # full-width comma
            " ",
            "",
        ],
    )
    return splitter.split_text(text)
def _parse_triplet_json(content: str) -> list[dict[str, Any]]:
raw = content.strip()
m = re.search(r"\[[\s\S]*\]", raw)
if m:
raw = m.group(0)
try:
data = json.loads(raw)
except json.JSONDecodeError:
return []
if not isinstance(data, list):
return []
out: list[dict[str, Any]] = []
for item in data:
if not isinstance(item, dict):
continue
subj = item.get("subject") or item.get("head") or item.get("s")
obj = item.get("object") or item.get("tail") or item.get("o")
rel = item.get("relation") or item.get("predicate") or item.get("p")
note = item.get("note") or item.get("evidence") or ""
if subj is None or obj is None:
continue
out.append({
"subject": str(subj).strip(),
"object": str(obj).strip(),
"relation_type": str(rel).strip() if rel else "相关",
"note": str(note).strip() if note else "",
})
return out
def _triplet_llm():
    """Build the non-streaming DeepSeek chat model used for triplet extraction."""
    model_options = {
        "provider": "deepseek",
        "api_model": "deepseek-chat",
        "streaming": False,
        "temperature": 0.2,
    }
    return build_chat_model(**model_options)
async def extract_triplets_from_chunk(chunk: str, chunk_index: int) -> list[dict[str, Any]]:
    """Ask the LLM for the relation triplets contained in one text chunk.

    Args:
        chunk: Text of the chunk.
        chunk_index: Chunk number shown to the model in the prompt.

    Returns:
        Triplets parsed from the model reply (possibly empty).

    Raises:
        ValueError: If ``settings.deepseek_api_key`` is not configured.
    """
    if not settings.deepseek_api_key:
        raise ValueError("未配置 DEEPSEEK_API_KEY无法抽取实体关系")

    model = _triplet_llm()
    system_text = "你是知识图谱构建专家。只输出合法 JSON 数组,严格遵守实体和关系定义,键名使用英文 subject/relation/object/note。"
    # The prompt body is runtime text and is kept exactly as authored.
    user_text = f"""你是知识图谱构建专家。阅读下列文本片段(可能是中文或英文),抽取其中**实体之间的关系三元组**。
## 实体定义(重要!)
实体必须是**具体的名词性对象**,包括:
- 人物具体的人名如「贾宝玉」「James」「Bronny」
- 组织公司、机构、团队名称如「荣国府」「Apple Inc.」「NASA」
- 地点具体的地名、场所如「大观园」「Beijing」「New York」
- 物品具体的物体、产品名称如「通灵宝玉」「iPhone」
- 概念重要的抽象概念、系统、模块名如「知识库系统」「User Management Module」
## 严禁作为实体的内容
- ❌ 动作短语「离去」「到来」「leaving」「arriving」
- ❌ 泛指代词「他」「她」「父亲」「母亲」「he」「she」「father」「mother」除非是专有称呼
- ❌ 描述性短语「甄士隐离去」「John's departure」
- ❌ 动词短语「听闻此信」「heard the news」
## 关系定义(重要!)
关系应描述**实体之间的静态联系**,不是动作,包括:
- 人际关系:夫妻/spouse、父子/father-son、母女/mother-daughter、师徒/mentor-disciple、朋友/friend
- 社会关系:雇佣/employed_by、所属/belongs_to、管理/manages、合作/cooperates_with
- 位置关系:位于/located_in、毗邻/adjacent_to、包含/contains、居住于/resides_in
- 属性关系:拥有/owns、制造/manufactures、创建/created_by
## 示例(正确)
中文原文:"封氏是甄士隐的嫡妻"
{{"subject": "甄士隐", "relation": "夫妻", "object": "封氏", "note": "封氏是甄士隐的嫡妻"}}
中文原文:"封肃是甄士隐的岳父"
{{"subject": "封肃", "relation": "岳父", "object": "甄士隐"}}
英文原文:"Bronny is LeBron James's son"
{{"subject": "LeBron James", "relation": "father-son", "object": "Bronny", "note": "Bronny is LeBron James's son"}}
英文原文:"Apple Inc. is headquartered in Cupertino"
{{"subject": "Apple Inc.", "relation": "located_in", "object": "Cupertino"}}
## 示例(错误)
{{"subject": "封氏", "relation": "听闻", "object": "甄士隐离去"}} // "甄士隐离去"不是实体,"听闻"是动作
{{"subject": "封氏", "relation": "依靠", "object": "父亲"}} // "父亲"是泛指,不是具体实体
{{"subject": "John", "relation": "left", "object": "office"}} // "left"是动作,不是关系
## 输出要求
1. 只输出一个 JSON 数组,不要 Markdown、不要解释文字
2. 数组中每个元素包含subject主体实体名、relation关系类型简洁表达、object客体实体名
3. 可选字段 note原文证据≤50字符
4. 使用原文中的具体名称,确保 subject 和 object 都是上述定义的实体
5. relation 用原文语言表达(中文文本用中文关系,英文文本用英文关系)
6. 若本段没有符合要求的实体关系,输出空数组 []
【文本片段 #{chunk_index}
{chunk}
"""
    reply = await model.ainvoke(
        [SystemMessage(content=system_text), HumanMessage(content=user_text)]
    )
    return _parse_triplet_json(reply.content)
# Maximum number of characters sent in the single summary-extraction call.
FALLBACK_TEXT_CAP = 20_000


async def extract_triplets_fallback_manual(text: str) -> list[dict[str, Any]]:
    """Run one summary extraction over a truncated copy of the whole text.

    Used when per-chunk extraction produced nothing at all.

    Args:
        text: Full document text.

    Returns:
        Triplets parsed from the model reply; an empty list when
        ``settings.deepseek_api_key`` is not configured.
    """
    if not settings.deepseek_api_key:
        return []

    stripped = text.strip()
    truncation_note = "\n\n...(正文过长,已截断;若关系主要在后续章节,可考虑拆分为多文件上传)"
    body = (
        stripped[:FALLBACK_TEXT_CAP] + truncation_note
        if len(stripped) > FALLBACK_TEXT_CAP
        else stripped
    )

    model = _triplet_llm()
    system_text = "你是知识图谱构建专家。只输出合法 JSON 数组,严格遵守实体和关系定义,键名 subject/relation/object/note。"
    # The prompt body is runtime text and is kept exactly as authored.
    user_text = f"""你是知识图谱构建专家。请从下列文本(可能是中文或英文)中抽取**尽量多**的实体关系三元组。
## 实体定义(重要!)
实体必须是**具体的名词性对象**
- 人物具体的人名如「贾宝玉」「James」「Bronny」
- 组织公司、机构、团队名称如「荣国府」「Apple Inc.」)
- 地点具体的地名、场所如「大观园」「Beijing」「New York」
- 物品具体的物体、产品名称如「通灵宝玉」「iPhone」
- 概念:重要的系统、模块、功能名
## 严禁作为实体
❌ 动作短语「离去」「到来」「leaving」「arriving」
❌ 泛指代词「父亲」「母亲」「he」「she」「father」「mother」
❌ 描述性短语「甄士隐离去」「John's departure」
## 关系定义(重要!)
关系应描述**实体之间的静态联系**,不是动作:
- 人际关系:夫妻/spouse、父子/father-son、母女/mother-daughter、师徒/mentor
- 社会关系:雇佣/employed_by、所属/belongs_to、管理/manages
- 位置关系:位于/located_in、毗邻/adjacent_to、居住于/resides_in
- 属性关系:拥有/owns、制造/manufactures、创建/created_by
## 输出要求
1. 只输出一个 JSON 数组,不要 Markdown
2. 每项包含subject主体实体名、relation关系类型简洁表达、object客体实体名
3. 可选字段 note原文证据≤50字符
4. 使用原文中的具体名称,确保 subject 和 object 都是上述定义的实体
5. relation 用原文语言表达(中文文本用中文关系,英文文本用英文关系)
6. 不要编造原文没有的实体
7. 至少尝试抽取若干条;若全文无任何结构信息,才输出 []
【资料正文】
{body}
"""
    reply = await model.ainvoke(
        [SystemMessage(content=system_text), HumanMessage(content=user_text)]
    )
    return _parse_triplet_json(reply.content)
def _is_valid_entity(name: str) -> bool:
"""
检查是否为有效实体名称。
过滤掉明显的动作短语、泛指代词等(支持中英文)。
"""
name = name.strip()
if not name:
return False
name_lower = name.lower()
# 过滤泛指代词(中文)
invalid_generic_zh = {"", "", "", "他们", "她们", "", "", "我们", "你们",
"父亲", "母亲", "儿子", "女儿", "兄弟", "姐妹", "爷爷", "奶奶"}
if name in invalid_generic_zh:
return False
# 过滤泛指代词(英文)
invalid_generic_en = {"he", "she", "it", "they", "i", "you", "we",
"father", "mother", "son", "daughter", "brother", "sister",
"grandfather", "grandmother", "him", "her", "his", "their"}
if name_lower in invalid_generic_en:
return False
# 过滤明显的动作短语(中文动词)
action_verbs_zh = ["离去", "到来", "哭泣", "听闻", "看见", "说道", "笑道",
"走来", "回来", "进来", "出去", "过来", "起来", "下去"]
if any(verb in name for verb in action_verbs_zh):
return False
# 过滤明显的动作短语(英文动词)
action_verbs_en = ["leaving", "arriving", "crying", "hearing", "seeing", "saying",
"coming", "going", "walking", "running", "departure", "arrival"]
if any(verb in name_lower for verb in action_verbs_en):
return False
# 实体名称不应过长(可能是描述性短语)
# 英文实体名称可以稍长一些(考虑空格)
max_len = 30 if any(c.isascii() and c.isalpha() for c in name) else 20
if len(name) > max_len:
return False
return True
def merge_triplets(chunks: list[list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Flatten per-chunk triplets, filter invalid ones, and de-duplicate.

    A triplet is dropped when its subject or object is blank, when both are
    equal, or when either fails ``_is_valid_entity``. Duplicates are detected
    on the (subject, object, relation) tuple; first occurrence wins.

    Args:
        chunks: One list of triplet dicts per extracted chunk.

    Returns:
        De-duplicated triplets with each field truncated to a fixed length.
    """
    emitted: set[tuple[str, str, str]] = set()
    result: list[dict[str, Any]] = []
    for triplet in (t for group in chunks for t in group):
        subject = (triplet.get("subject") or "").strip()
        target = (triplet.get("object") or "").strip()
        relation = (triplet.get("relation_type") or "").strip() or "相关"

        # Basic sanity checks before the more expensive entity filter.
        if not subject or not target or subject == target:
            continue
        if not (_is_valid_entity(subject) and _is_valid_entity(target)):
            continue

        identity = (subject, target, relation)
        if identity in emitted:
            continue
        emitted.add(identity)
        result.append({
            "subject": subject[:200],
            "object": target[:200],
            "relation_type": relation[:120],
            "note": (triplet.get("note") or "")[:500],
        })
    return result
def validate_knowledge_graph_text_length(text: str) -> None:
    """Reject text that is too long for the knowledge-graph extraction pipeline.

    Intended to run before upload/build so over-long input never reaches the
    per-chunk LLM loop.

    Args:
        text: Extracted document text.

    Raises:
        ValueError: If ``len(text)`` exceeds MAX_INPUT_CHARS.
    """
    char_count = len(text)
    if char_count <= MAX_INPUT_CHARS:
        return
    too_long = f"提取的正文过长(约 {char_count:,} 字),知识图谱构建上限为 {MAX_INPUT_CHARS:,} 字。"
    advice = "请拆分为多个文件分别上传,或使用知识库功能处理长文档问答。"
    raise ValueError(too_long + advice)
async def extract_and_import_knowledge_graph(text: str, graph_id: str) -> dict[str, Any]:
    """Extract triplets from the whole text chunk by chunk and import them into Neo4j.

    Chunks are processed serially, one LLM call each. If no chunk yields a
    valid triplet, a single summary extraction over the truncated text is
    tried before importing.

    Args:
        text: Full document text.
        graph_id: Identifier passed through to the Neo4j import.

    Returns:
        The statistics dict returned by
        ``neo4j_service.import_knowledge_graph_triplets``.

    Raises:
        ValueError: If the text is empty, too long, or splits into more than
            MAX_KG_EXTRACT_CHUNKS chunks.
    """
    validate_knowledge_graph_text_length(text)
    chunks = split_novel_text(text)
    if not chunks:
        raise ValueError("文本为空")
    if len(chunks) > MAX_KG_EXTRACT_CHUNKS:
        raise ValueError(
            f"文本分块过多({len(chunks)} 块,上限 {MAX_KG_EXTRACT_CHUNKS} 块),"
            f"请将正文控制在约 {MAX_INPUT_CHARS:,} 字以内后重试。"
        )
    logger.info("知识图谱:共 {} 个文本块", len(chunks))

    batch_results: list[list[dict[str, Any]]] = []
    for i, ch in enumerate(chunks):
        triplets = await extract_triplets_from_chunk(ch, i + 1)
        logger.info("{}/{} 抽取到 {} 条关系", i + 1, len(chunks), len(triplets))
        batch_results.append(triplets)
        if i > 0 and i % 5 == 0:
            # Explicit yield so other tasks get a turn on long runs.
            await asyncio.sleep(0)

    merged = merge_triplets(batch_results)
    if not merged:
        logger.warning("知识图谱:分块抽取无结果,尝试说明文档/产品手册汇总抽取")
        fb = await extract_triplets_fallback_manual(text)
        merged = merge_triplets([fb])

    # The Neo4j import is a blocking call; run it in a worker thread.
    # asyncio.to_thread is what the rest of this module uses, and it avoids
    # asyncio.get_event_loop(), which is discouraged inside coroutines.
    stats = await asyncio.to_thread(
        neo4j_service.import_knowledge_graph_triplets, merged, graph_id
    )
    if stats["node_count"] == 0:
        logger.warning(
            "知识图谱 graph_id={}未写入任何关系节点仍可使用向量检索RAG回答"
            "Neo4j 关系查询工具可能无数据。",
            graph_id,
        )
    return stats