"""Database connection management.

Centralizes every database connection pool used by the application:

- asyncpg ``Pool``: general-purpose database (CRUD) operations.
- psycopg ``AsyncConnectionPool``: backs the LangGraph checkpointer.

Pools are created lazily on first use and torn down by ``close_db_pool``.
"""
import asyncio
from typing import Awaitable, Callable, Optional, Tuple, TypeVar, Union

import asyncpg
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

from core.config import settings
from core.graph_metadata import ensure_graph_metadata, reset_graph_metadata
from logger.logging import get_logger

logger = get_logger(__name__)

_T = TypeVar("_T")

# Module-wide connection pools (created lazily, shared by all callers).
_asyncpg_pool: Optional[asyncpg.Pool] = None
_psycopg_pool: Optional[AsyncConnectionPool] = None
_checkpointer: Optional[AsyncPostgresSaver] = None

# Serialize first-time initialization so concurrent callers do not each build
# (and leak) their own pool. The locks are created lazily inside a coroutine:
# on Python < 3.10 an asyncio.Lock binds to the event loop that is current when
# it is constructed, which at import time may not be the loop serving requests.
_asyncpg_init_lock: Optional[asyncio.Lock] = None
_checkpointer_init_lock: Optional[asyncio.Lock] = None

# Retry policy shared by both initializers.
_MAX_INIT_RETRIES = 3
_INITIAL_RETRY_DELAY = 2  # seconds; doubled after every failed attempt


async def _close_quietly(pool: Union[asyncpg.Pool, AsyncConnectionPool], label: str) -> None:
    """Close a pool during failure cleanup, logging instead of raising errors.

    Only called while another exception is already propagating, so a cleanup
    failure must not mask the original error.
    """
    try:
        await pool.close()
    except Exception as close_err:
        logger.warning(f"清理 {label} 连接池时出错: {close_err}")


async def _init_with_retries(
    factory: Callable[[], Awaitable[_T]],
    failure_message: str,
    exhausted_message: str,
) -> _T:
    """Await ``factory()`` with exponential backoff until it succeeds.

    Args:
        factory: Coroutine function that builds the resource. It must release
            anything it partially created before raising.
        failure_message: Log prefix used for each failed attempt.
        exhausted_message: Logged once after the final attempt fails.

    Returns:
        The value ``factory()`` returned on the first successful attempt.

    Raises:
        Exception: the error from the final attempt, re-raised unchanged.
            Cancellation (``BaseException``) is never retried.
    """
    delay = _INITIAL_RETRY_DELAY
    for attempt in range(1, _MAX_INIT_RETRIES + 1):
        try:
            return await factory()
        except Exception as e:
            logger.error(f"{failure_message} (尝试 {attempt}/{_MAX_INIT_RETRIES}): {e}")
            if attempt == _MAX_INIT_RETRIES:
                logger.error(exhausted_message)
                raise
            logger.info(f"将在 {delay} 秒后重试...")
            await asyncio.sleep(delay)
            delay *= 2  # exponential backoff
    # Only reachable if _MAX_INIT_RETRIES is misconfigured below 1.
    raise RuntimeError("_MAX_INIT_RETRIES must be at least 1")


async def _create_asyncpg_pool() -> asyncpg.Pool:
    """Create one asyncpg pool and verify it before it is handed out.

    Runs ``SELECT 1`` and ``ensure_graph_metadata`` on a pooled connection.
    If either fails (or the task is cancelled) the new pool is closed again.
    """
    pool = await asyncpg.create_pool(
        host=settings.db_host,
        port=settings.db_port,
        database=settings.db_name,
        user=settings.db_user,
        password=settings.db_password,
        min_size=settings.db_pool_min_size,
        max_size=settings.db_pool_max_size,
        command_timeout=settings.db_command_timeout,
        timeout=30,  # connection timeout: 30 seconds
        server_settings={
            'application_name': 'huoyan-enterprise',
            'jit': 'off',  # JIT disabled for stability
        },
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute("SELECT 1")
            await ensure_graph_metadata(conn)
    except BaseException:
        # Never leak a half-initialized pool, including on cancellation.
        await _close_quietly(pool, "asyncpg")
        raise
    logger.info("asyncpg 数据库连接池初始化成功")
    return pool


async def get_db_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first use.

    Used for general database CRUD operations. Creation is attempted up to
    three times, sleeping 2 s and then 4 s between attempts.

    Raises:
        Exception: whatever the last creation attempt raised.
    """
    global _asyncpg_pool, _asyncpg_init_lock
    if _asyncpg_pool is not None:
        return _asyncpg_pool
    if _asyncpg_init_lock is None:
        _asyncpg_init_lock = asyncio.Lock()
    async with _asyncpg_init_lock:
        # Another caller may have completed initialization while we waited.
        if _asyncpg_pool is None:
            logger.info(f"初始化 asyncpg 数据库连接池: {settings.db_user}@{settings.db_host}:{settings.db_port}/{settings.db_name}")
            # Publish only a fully verified pool, so no caller can ever obtain
            # a pool that a failed attempt is about to close.
            _asyncpg_pool = await _init_with_retries(
                _create_asyncpg_pool,
                "asyncpg 数据库连接池初始化失败",
                "数据库连接池初始化失败,已达到最大重试次数",
            )
    return _asyncpg_pool


async def _create_checkpointer() -> Tuple[AsyncConnectionPool, AsyncPostgresSaver]:
    """Open a psycopg pool and build a checkpointer on top of it.

    ``setup()`` is awaited before returning. On any failure (or cancellation)
    the pool is closed again before the error propagates.
    """
    pool = AsyncConnectionPool(
        conninfo=settings.db_uri_psycopg,
        max_size=settings.checkpointer_pool_max_size,
        open=False,
        timeout=30,  # connection timeout: 30 seconds
        kwargs={
            "autocommit": True,
            "prepare_threshold": 0,
        },
    )
    try:
        await pool.open()
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
    except BaseException:
        await _close_quietly(pool, "psycopg")
        raise
    logger.info("Checkpointer 初始化成功")
    return pool, checkpointer


async def get_checkpointer() -> AsyncPostgresSaver:
    """Return the shared LangGraph checkpointer, creating it on first use.

    Backed by a psycopg ``AsyncConnectionPool`` and used to persist LangGraph
    state. Creation is attempted up to three times, sleeping 2 s and then 4 s
    between attempts.

    Raises:
        Exception: whatever the last creation attempt raised.
    """
    global _psycopg_pool, _checkpointer, _checkpointer_init_lock
    if _checkpointer is not None:
        return _checkpointer
    if _checkpointer_init_lock is None:
        _checkpointer_init_lock = asyncio.Lock()
    async with _checkpointer_init_lock:
        # Another caller may have completed initialization while we waited.
        if _checkpointer is None:
            logger.info("初始化 psycopg 连接池和 Checkpointer...")
            # Publish the pool and checkpointer only after setup() succeeded.
            _psycopg_pool, _checkpointer = await _init_with_retries(
                _create_checkpointer,
                "Checkpointer 初始化失败",
                "Checkpointer 初始化失败,已达到最大重试次数",
            )
    return _checkpointer


async def close_db_pool() -> None:
    """Close every database connection pool managed by this module.

    Globals are cleared before closing, so a failed close never leaves a broken
    pool cached. The psycopg pool is closed even if closing asyncpg raises.
    """
    global _asyncpg_pool, _psycopg_pool, _checkpointer
    try:
        if _asyncpg_pool is not None:
            logger.info("关闭 asyncpg 数据库连接池...")
            pool, _asyncpg_pool = _asyncpg_pool, None
            try:
                await pool.close()
            finally:
                reset_graph_metadata()
            logger.info("asyncpg 数据库连接池已关闭")
    finally:
        if _psycopg_pool is not None:
            logger.info("关闭 psycopg 连接池...")
            psycopg_pool, _psycopg_pool, _checkpointer = _psycopg_pool, None, None
            await psycopg_pool.close()
            logger.info("psycopg 连接池已关闭")


async def get_db_connection():
    """Yield one pooled asyncpg connection (for dependency injection).

    The connection is acquired from the shared pool and released back to it
    when the generator is closed.
    """
    pool = await get_db_pool()
    async with pool.acquire() as connection:
        yield connection