169 lines
6.0 KiB
Python
169 lines
6.0 KiB
Python
"""
Compatibility helpers for the graph metadata table name (graphs / star_graph) and
the knowledge-graph foreign-key column name on chat_threads, plus detection of the
optional chat_threads columns ip and llm_provider / llm_model.
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from typing import Final, Optional
|
||
|
||
import asyncpg
|
||
|
||
from logger.logging import get_logger
|
||
|
||
# Module logger; ensure_graph_metadata reports the resolved table/column names through it.
logger = get_logger(__name__)
|
||
|
||
# Whitelists checked before a table/column name is handed out for use in SQL text.
_ALLOWED_TABLES: Final[frozenset[str]] = frozenset(("graphs", "star_graph"))
_ALLOWED_KG_COLS: Final[frozenset[str]] = frozenset(("knowledge_graph_id", "novel_graph_id"))

# Serializes the one-time probe in ensure_graph_metadata; _ready records that it has completed.
_lock = asyncio.Lock()
_ready: bool = False

# Resolved graph metadata table name ("graphs" until probed).
GRAPH_TABLE: str = "graphs"

# None = not probed yet (or chat_threads has no knowledge-graph FK column);
# ensure_graph_metadata sets it to the actual column name or leaves it None.
CHAT_THREAD_KG_COLUMN: Optional[str] = None

# Whether chat_threads has an `ip` column (the app writes it when INSERTing a thread;
# if the column is missing it is omitted, so a failing INSERT does not leave the
# thread list permanently empty).
CHAT_THREADS_HAS_IP_COLUMN: bool = False

# Whether chat_threads has llm_provider / llm_model (the model most recently chosen
# for the thread; see migrations/add_chat_threads_llm_columns.sql).
CHAT_THREADS_HAS_LLM_COLUMNS: bool = False
|
||
|
||
|
||
def graph_table_sql() -> str:
    """Return the resolved graph metadata table name for use in SQL text.

    Raises:
        RuntimeError: if ``GRAPH_TABLE`` is not one of the whitelisted table names.
    """
    table = GRAPH_TABLE
    if table in _ALLOWED_TABLES:
        return table
    raise RuntimeError(f"invalid GRAPH_TABLE: {table!r}")
|
||
|
||
|
||
def chat_thread_kg_column_sql() -> str:
    """Return the chat_threads column that binds a thread to a graph.

    Meant only for code paths that really must write that column.

    Raises:
        RuntimeError: if no such column was detected on chat_threads (the
            migration has not been applied), or if the cached name is not
            whitelisted.
    """
    column = CHAT_THREAD_KG_COLUMN
    if column is None:
        raise RuntimeError(
            "chat_threads 缺少 knowledge_graph_id / novel_graph_id 列,请执行 migrations/knowledge_graph_and_processing.sql"
        )
    if column in _ALLOWED_KG_COLS:
        return column
    raise RuntimeError(f"invalid CHAT_THREAD_KG_COLUMN: {column!r}")
|
||
|
||
|
||
def chat_thread_kg_select_fragment_sql() -> str:
    """Return a SELECT-list fragment exposing the graph binding as ``knowledge_graph_id``.

    If chat_threads has no such column, a typed NULL is selected instead, so
    queries such as the thread listing do not fail (HTTP 500) by referencing a
    column that does not exist.

    Raises:
        RuntimeError: if the cached column name is not whitelisted.
    """
    if CHAT_THREAD_KG_COLUMN is None:
        return "NULL::integer AS knowledge_graph_id"
    # Reuse the writer-side accessor: for a non-None column it performs the same
    # whitelist check and raises the same RuntimeError.
    return f"{chat_thread_kg_column_sql()} AS knowledge_graph_id"
|
||
|
||
|
||
def chat_threads_has_kg_column() -> bool:
    """Tell whether a graph-binding column was detected on chat_threads (as last probed)."""
    detected = CHAT_THREAD_KG_COLUMN
    return detected is not None
|
||
|
||
|
||
def chat_threads_has_ip_column() -> bool:
    """Tell whether chat_threads was found to have an ``ip`` column (as last probed)."""
    has_ip = CHAT_THREADS_HAS_IP_COLUMN
    return has_ip
|
||
|
||
|
||
def chat_threads_has_llm_columns() -> bool:
    """Tell whether chat_threads was found to have the llm_provider / llm_model columns."""
    has_llm = CHAT_THREADS_HAS_LLM_COLUMNS
    return has_llm
|
||
|
||
|
||
def chat_thread_llm_select_fragment_sql() -> str:
    """Return the SELECT-list fragment for ``llm_provider`` / ``llm_model``.

    Falls back to typed NULLs when the columns were not detected, so queries on a
    database without that migration do not fail with HTTP 500.
    """
    return (
        "llm_provider, llm_model"
        if CHAT_THREADS_HAS_LLM_COLUMNS
        else "NULL::varchar AS llm_provider, NULL::varchar AS llm_model"
    )
|
||
|
||
|
||
async def ensure_graph_metadata(conn: asyncpg.Connection) -> None:
    """Resolve the graph table name and chat_threads column names once.

    Probes ``information_schema`` in the ``public`` schema and caches the
    results in module globals (only whitelisted names are ever stored):

    * ``GRAPH_TABLE``: ``"graphs"`` if that table exists, else ``"star_graph"``
      if that exists, else ``"graphs"`` (with a warning).
    * ``CHAT_THREAD_KG_COLUMN``: ``"knowledge_graph_id"``, else
      ``"novel_graph_id"``, else ``None``.
    * ``CHAT_THREADS_HAS_IP_COLUMN``: whether ``chat_threads.ip`` exists.
    * ``CHAT_THREADS_HAS_LLM_COLUMNS``: whether BOTH ``llm_provider`` and
      ``llm_model`` exist.

    After a successful run, later calls return immediately until
    ``reset_graph_metadata()`` clears the cache. All globals are assigned
    together at the end, with no ``await`` in between. A probe that raises
    therefore leaves the previous values untouched, and the next call retries.

    Args:
        conn: Open asyncpg connection used for the probe queries.
    """
    global _ready, GRAPH_TABLE, CHAT_THREAD_KG_COLUMN, CHAT_THREADS_HAS_IP_COLUMN, CHAT_THREADS_HAS_LLM_COLUMNS

    # Fast path: no lock needed once probing has completed.
    if _ready:
        return

    async with _lock:
        # Another coroutine may have finished probing while we waited for the lock.
        if _ready:
            return

        # One round trip for both candidate table names; names go in as a bound
        # parameter, never interpolated.
        table_rows = await conn.fetch(
            """
            SELECT table_name::text AS table_name
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name::text = ANY($1::text[])
            """,
            ["graphs", "star_graph"],
        )
        tables = {row["table_name"] for row in table_rows}

        if "graphs" in tables:
            graph_table = "graphs"
        elif "star_graph" in tables:
            graph_table = "star_graph"
            logger.info("图谱元数据表使用 PostgreSQL 表名 star_graph,建议统一为 graphs")
        else:
            graph_table = "graphs"
            logger.warning("未找到 public.graphs 或 public.star_graph,请先执行数据库迁移")

        # One round trip for every optional chat_threads column we care about.
        column_rows = await conn.fetch(
            """
            SELECT column_name::text AS column_name
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = 'chat_threads'
              AND column_name::text = ANY($1::text[])
            """,
            ["knowledge_graph_id", "novel_graph_id", "ip", "llm_provider", "llm_model"],
        )
        columns = {row["column_name"] for row in column_rows}

        kg_column: Optional[str]
        if "knowledge_graph_id" in columns:
            kg_column = "knowledge_graph_id"
        elif "novel_graph_id" in columns:
            kg_column = "novel_graph_id"
            logger.info("chat_threads 使用列 novel_graph_id,可迁移为 knowledge_graph_id")
        else:
            kg_column = None
            logger.warning(
                "chat_threads 未找到 knowledge_graph_id / novel_graph_id;会话列表仍可查询,图谱绑定需执行迁移"
            )

        # Publish everything at once (no await between these assignments).
        GRAPH_TABLE = graph_table
        CHAT_THREAD_KG_COLUMN = kg_column
        CHAT_THREADS_HAS_IP_COLUMN = "ip" in columns
        # Both columns are required: chat_thread_llm_select_fragment_sql() selects
        # llm_provider AND llm_model, so a half-applied migration must count as absent.
        CHAT_THREADS_HAS_LLM_COLUMNS = {"llm_provider", "llm_model"} <= columns

        _ready = True
|
||
|
||
|
||
def reset_graph_metadata() -> None:
    """Discard cached probe results and restore the module defaults.

    Because ``_ready`` becomes False, the next ``ensure_graph_metadata`` call
    queries the database again.
    """
    global _ready, GRAPH_TABLE, CHAT_THREAD_KG_COLUMN, CHAT_THREADS_HAS_IP_COLUMN, CHAT_THREADS_HAS_LLM_COLUMNS
    _ready, GRAPH_TABLE, CHAT_THREAD_KG_COLUMN = False, "graphs", None
    CHAT_THREADS_HAS_IP_COLUMN = CHAT_THREADS_HAS_LLM_COLUMNS = False
|