huoyan-enterprise/backend/core/graph_metadata.py

169 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
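Compatibility layer for the graph metadata table name (graphs / star_graph) and the chat_threads knowledge-graph foreign-key column name (knowledge_graph_id / novel_graph_id).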
"""
from __future__ import annotations
import asyncio
from typing import Final, Optional
import asyncpg
from logger.logging import get_logger
logger = get_logger(__name__)
# Whitelists consulted by the *_sql helpers in this module before a table or
# column name is handed back for use in SQL text.
_ALLOWED_TABLES: Final[frozenset[str]] = frozenset({"graphs", "star_graph"})
_ALLOWED_KG_COLS: Final[frozenset[str]] = frozenset({"knowledge_graph_id", "novel_graph_id"})
# _lock serializes the probe in ensure_graph_metadata; _ready records that the
# probe has completed so later calls return immediately.
_lock = asyncio.Lock()
_ready: bool = False
# Graph metadata table name; ensure_graph_metadata may switch it to "star_graph".
GRAPH_TABLE: str = "graphs"
# None = not probed yet, or chat_threads has no knowledge-graph foreign-key column;
# ensure_graph_metadata sets it to the actual column name or leaves it as None.
CHAT_THREAD_KG_COLUMN: Optional[str] = None
# Whether chat_threads has an ip column (the application writes it when INSERTing a
# session; when the column is missing it is omitted so the INSERT does not fail and
# leave the session list permanently empty).
CHAT_THREADS_HAS_IP_COLUMN: bool = False
# Whether chat_threads has llm_provider / llm_model (records the model most recently
# selected for the session; see migrations/add_chat_threads_llm_columns.sql).
CHAT_THREADS_HAS_LLM_COLUMNS: bool = False
def graph_table_sql() -> str:
    """Return the graph metadata table name for use in SQL text.

    Raises:
        RuntimeError: if ``GRAPH_TABLE`` is not one of the whitelisted names.
    """
    table = GRAPH_TABLE
    if table in _ALLOWED_TABLES:
        return table
    raise RuntimeError(f"invalid GRAPH_TABLE: {table!r}")
def chat_thread_kg_column_sql() -> str:
    """Return the chat_threads column that binds a thread to a graph.

    Intended only for code paths that really must write that column.

    Raises:
        RuntimeError: if no such column was detected, or if the recorded
            column name is not in the whitelist.
    """
    column = CHAT_THREAD_KG_COLUMN
    if column is None:
        raise RuntimeError(
            "chat_threads 缺少 knowledge_graph_id / novel_graph_id 列,请执行 migrations/knowledge_graph_and_processing.sql"
        )
    if column in _ALLOWED_KG_COLS:
        return column
    raise RuntimeError(f"invalid CHAT_THREAD_KG_COLUMN: {column!r}")
def chat_thread_kg_select_fragment_sql() -> str:
    """Return a SELECT-list fragment that always yields ``knowledge_graph_id``.

    When chat_threads has no physical graph column the fragment is a typed
    NULL, so queries such as the session list do not reference a missing
    column and fail.

    Raises:
        RuntimeError: if the recorded column name is not in the whitelist.
    """
    column = CHAT_THREAD_KG_COLUMN
    if column is None:
        return "NULL::integer AS knowledge_graph_id"
    if column in _ALLOWED_KG_COLS:
        return f"{column} AS knowledge_graph_id"
    raise RuntimeError(f"invalid CHAT_THREAD_KG_COLUMN: {column!r}")
def chat_threads_has_kg_column() -> bool:
    """Report whether a graph-binding column name is recorded for chat_threads."""
    column_missing = CHAT_THREAD_KG_COLUMN is None
    return not column_missing
def chat_threads_has_ip_column() -> bool:
    """Return the current value of the ``CHAT_THREADS_HAS_IP_COLUMN`` flag."""
    has_ip = CHAT_THREADS_HAS_IP_COLUMN
    return has_ip
def chat_threads_has_llm_columns() -> bool:
    """Return the current value of the ``CHAT_THREADS_HAS_LLM_COLUMNS`` flag."""
    has_llm = CHAT_THREADS_HAS_LLM_COLUMNS
    return has_llm
def chat_thread_llm_select_fragment_sql() -> str:
    """Return a SELECT-list fragment yielding ``llm_provider`` and ``llm_model``.

    When the LLM columns flag is off, typed NULLs are returned under the same
    aliases so a database that has not been migrated does not fail the query.
    """
    if not CHAT_THREADS_HAS_LLM_COLUMNS:
        return "NULL::varchar AS llm_provider, NULL::varchar AS llm_model"
    return "llm_provider, llm_model"
async def _public_table_exists(conn: asyncpg.Connection, table: str) -> bool:
    """Return True when information_schema.tables lists ``public.<table>``."""
    found = await conn.fetchval(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = $1::text
        )
        """,
        table,
    )
    return bool(found)


async def _chat_threads_has_column(conn: asyncpg.Connection, column: str) -> bool:
    """Return True when information_schema.columns lists ``public.chat_threads.<column>``."""
    found = await conn.fetchval(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = 'chat_threads'
            AND column_name = $1::text
        )
        """,
        column,
    )
    return bool(found)


async def ensure_graph_metadata(conn: asyncpg.Connection) -> None:
    """Probe the database once and record the table/column names in module globals.

    Resolves, against the ``public`` schema:

    * ``GRAPH_TABLE``: ``graphs`` if present, else ``star_graph`` if present,
      else ``graphs`` with a warning that migrations are needed.
    * ``CHAT_THREAD_KG_COLUMN``: ``knowledge_graph_id`` if present, else
      ``novel_graph_id`` if present, else ``None``.
    * ``CHAT_THREADS_HAS_IP_COLUMN``: whether ``chat_threads.ip`` exists.
    * ``CHAT_THREADS_HAS_LLM_COLUMNS``: whether both ``chat_threads.llm_provider``
      and ``chat_threads.llm_model`` exist.

    Only whitelisted literal names are ever assigned. Once the probe has
    completed, later calls return immediately until ``reset_graph_metadata``
    is called.

    Args:
        conn: connection used to run the information_schema queries.
    """
    global _ready, GRAPH_TABLE, CHAT_THREAD_KG_COLUMN, CHAT_THREADS_HAS_IP_COLUMN, CHAT_THREADS_HAS_LLM_COLUMNS
    if _ready:
        return
    async with _lock:
        # Re-check after acquiring the lock: another caller may have finished
        # the probe while this one was waiting.
        if _ready:
            return

        has_g = await _public_table_exists(conn, "graphs")
        has_s = await _public_table_exists(conn, "star_graph")
        if has_g:
            GRAPH_TABLE = "graphs"
        elif has_s:
            GRAPH_TABLE = "star_graph"
            logger.info("图谱元数据表使用 PostgreSQL 表名 star_graph建议统一为 graphs")
        else:
            GRAPH_TABLE = "graphs"
            logger.warning("未找到 public.graphs 或 public.star_graph请先执行数据库迁移")

        has_kg = await _chat_threads_has_column(conn, "knowledge_graph_id")
        has_ng = await _chat_threads_has_column(conn, "novel_graph_id")
        if has_kg:
            CHAT_THREAD_KG_COLUMN = "knowledge_graph_id"
        elif has_ng:
            CHAT_THREAD_KG_COLUMN = "novel_graph_id"
            logger.info("chat_threads 使用列 novel_graph_id可迁移为 knowledge_graph_id")
        else:
            CHAT_THREAD_KG_COLUMN = None
            logger.warning(
                "chat_threads 未找到 knowledge_graph_id / novel_graph_id会话列表仍可查询图谱绑定需执行迁移"
            )

        CHAT_THREADS_HAS_IP_COLUMN = await _chat_threads_has_column(conn, "ip")

        # The SELECT fragment built from this flag names both llm_provider and
        # llm_model, so the flag may only be set when both columns exist;
        # probing llm_provider alone would break on a partially migrated table.
        has_llm_provider = await _chat_threads_has_column(conn, "llm_provider")
        has_llm_model = await _chat_threads_has_column(conn, "llm_model")
        CHAT_THREADS_HAS_LLM_COLUMNS = has_llm_provider and has_llm_model

        # Set last so a failure part-way through leaves the probe to be retried.
        _ready = True
def reset_graph_metadata() -> None:
    """Restore the probed module globals to their import-time defaults.

    Clearing ``_ready`` makes the next ``ensure_graph_metadata`` call probe
    the database again.
    """
    global _ready
    global GRAPH_TABLE, CHAT_THREAD_KG_COLUMN
    global CHAT_THREADS_HAS_IP_COLUMN, CHAT_THREADS_HAS_LLM_COLUMNS
    GRAPH_TABLE = "graphs"
    CHAT_THREAD_KG_COLUMN = None
    CHAT_THREADS_HAS_IP_COLUMN = CHAT_THREADS_HAS_LLM_COLUMNS = False
    _ready = False