171 lines
5.8 KiB
Python
171 lines
5.8 KiB
Python
"""
|
||
数据库连接管理模块
|
||
|
||
统一管理所有数据库连接池:
|
||
- asyncpg Pool: 用于一般的数据库操作
|
||
- psycopg AsyncConnectionPool: 用于 LangGraph Checkpointer
|
||
"""
|
||
from typing import Optional
|
||
import asyncio
|
||
|
||
import asyncpg
|
||
from psycopg_pool import AsyncConnectionPool
|
||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||
|
||
from core.config import settings
|
||
from core.graph_metadata import ensure_graph_metadata, reset_graph_metadata
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)

# Module-wide (global) database connection state. Each starts as None, is
# populated lazily by get_db_pool() / get_checkpointer(), and is reset to None
# by close_db_pool().
_asyncpg_pool: Optional[asyncpg.Pool] = None  # asyncpg pool for general CRUD queries
_psycopg_pool: Optional[AsyncConnectionPool] = None  # psycopg pool backing the LangGraph checkpointer
_checkpointer: Optional[AsyncPostgresSaver] = None  # LangGraph checkpointer built on _psycopg_pool
|
||
|
||
|
||
# Serialises first-time creation of the asyncpg pool, so concurrent callers
# cannot each build (and then leak) their own pool. The lock is created lazily
# inside get_db_pool() rather than at import time, so it is never bound to an
# event loop other than the one actually running the application.
_asyncpg_init_lock: Optional[asyncio.Lock] = None


async def _close_asyncpg_pool_quietly(pool: asyncpg.Pool) -> None:
    """Best-effort close of an asyncpg pool whose initialisation failed.

    Errors are logged rather than raised, so the caller's retry/raise logic
    still reports the original initialisation error.
    """
    try:
        await pool.close()
    except Exception as close_err:
        logger.warning(f"关闭初始化失败的 asyncpg 连接池时出错: {close_err}")


async def get_db_pool() -> asyncpg.Pool:
    """Return the shared asyncpg connection pool, creating it on first use.

    The pool is used for general database CRUD operations. On first use it is
    built from ``settings``. It is then verified with ``SELECT 1``, and
    ``ensure_graph_metadata`` is run on one pooled connection. Only after
    that does the pool become visible to other callers.

    Creation is attempted up to 3 times. Between attempts the function sleeps
    with exponential backoff (2 s, then 4 s). A failed attempt closes its pool
    before retrying.

    Returns:
        The initialised ``asyncpg.Pool``.

    Raises:
        Exception: the error from the final attempt once all retries are
        exhausted.
    """
    global _asyncpg_pool, _asyncpg_init_lock

    # Fast path: already initialised.
    if _asyncpg_pool is not None:
        return _asyncpg_pool

    if _asyncpg_init_lock is None:
        # No await between the check and the assignment, so this cannot race
        # with another coroutine on the same event loop.
        _asyncpg_init_lock = asyncio.Lock()

    async with _asyncpg_init_lock:
        # Another coroutine may have completed initialisation while we waited.
        if _asyncpg_pool is not None:
            return _asyncpg_pool

        logger.info(f"初始化 asyncpg 数据库连接池: {settings.db_user}@{settings.db_host}:{settings.db_port}/{settings.db_name}")

        max_retries = 3
        retry_delay = 2  # seconds

        for attempt in range(max_retries):
            # Build into a local: the global is only set once the pool is verified.
            pool: Optional[asyncpg.Pool] = None
            try:
                pool = await asyncpg.create_pool(
                    host=settings.db_host,
                    port=settings.db_port,
                    database=settings.db_name,
                    user=settings.db_user,
                    password=settings.db_password,
                    min_size=settings.db_pool_min_size,
                    max_size=settings.db_pool_max_size,
                    command_timeout=settings.db_command_timeout,
                    timeout=30,  # connection timeout: 30 seconds
                    server_settings={
                        'application_name': 'huoyan-enterprise',
                        'jit': 'off',  # disable PostgreSQL JIT for stability
                    },
                )

                # Probe the pool and make sure graph metadata exists.
                async with pool.acquire() as conn:
                    await conn.execute("SELECT 1")
                    await ensure_graph_metadata(conn)
            except Exception as e:
                logger.error(f"asyncpg 数据库连接池初始化失败 (尝试 {attempt + 1}/{max_retries}): {e}")

                if pool is not None:
                    await _close_asyncpg_pool_quietly(pool)

                if attempt < max_retries - 1:
                    logger.info(f"将在 {retry_delay} 秒后重试...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                else:
                    logger.error("数据库连接池初始化失败,已达到最大重试次数")
                    raise
            else:
                # Publish only after successful verification.
                _asyncpg_pool = pool
                logger.info("asyncpg 数据库连接池初始化成功")
                break

    return _asyncpg_pool
|
||
|
||
|
||
# Serialises first-time creation of the psycopg pool and the checkpointer, so
# concurrent callers cannot each build (and then leak) their own pool. The lock
# is created lazily inside get_checkpointer() rather than at import time, so
# it is never bound to an event loop other than the running one.
_checkpointer_init_lock: Optional[asyncio.Lock] = None


async def _close_psycopg_pool_quietly(pool: AsyncConnectionPool) -> None:
    """Best-effort close of a psycopg pool whose initialisation failed.

    Errors are logged rather than raised, so the caller's retry/raise logic
    still reports the original initialisation error.
    """
    try:
        await pool.close()
    except Exception as close_err:
        logger.warning(f"关闭初始化失败的 psycopg 连接池时出错: {close_err}")


async def get_checkpointer() -> AsyncPostgresSaver:
    """Return the shared LangGraph checkpointer, creating it on first use.

    The checkpointer is backed by a psycopg ``AsyncConnectionPool`` and is
    used to persist LangGraph state. On first use the pool is opened, an
    ``AsyncPostgresSaver`` is wrapped around it, and ``setup()`` is awaited.
    Only after ``setup()`` succeeds are the pool and checkpointer made visible
    to other callers.

    Creation is attempted up to 3 times. Between attempts the function sleeps
    with exponential backoff (2 s, then 4 s). A failed attempt closes its pool
    before retrying.

    Returns:
        The initialised ``AsyncPostgresSaver``.

    Raises:
        Exception: the error from the final attempt once all retries are
        exhausted.
    """
    global _psycopg_pool, _checkpointer, _checkpointer_init_lock

    # Fast path: already initialised.
    if _checkpointer is not None:
        return _checkpointer

    if _checkpointer_init_lock is None:
        # No await between the check and the assignment, so this cannot race
        # with another coroutine on the same event loop.
        _checkpointer_init_lock = asyncio.Lock()

    async with _checkpointer_init_lock:
        # Another coroutine may have completed initialisation while we waited.
        if _checkpointer is not None:
            return _checkpointer

        logger.info("初始化 psycopg 连接池和 Checkpointer...")

        max_retries = 3
        retry_delay = 2  # seconds

        for attempt in range(max_retries):
            # Build into locals: globals are only set once setup() succeeds.
            pool: Optional[AsyncConnectionPool] = None
            try:
                pool = AsyncConnectionPool(
                    conninfo=settings.db_uri_psycopg,
                    max_size=settings.checkpointer_pool_max_size,
                    open=False,  # opened explicitly via `await pool.open()` below
                    timeout=30,  # default wait (seconds) when acquiring a connection from the pool
                    # Options applied to every pooled psycopg connection.
                    kwargs={
                        "autocommit": True,
                        "prepare_threshold": 0
                    },
                )
                await pool.open()

                checkpointer = AsyncPostgresSaver(pool)
                await checkpointer.setup()
            except Exception as e:
                logger.error(f"Checkpointer 初始化失败 (尝试 {attempt + 1}/{max_retries}): {e}")

                if pool is not None:
                    await _close_psycopg_pool_quietly(pool)

                if attempt < max_retries - 1:
                    logger.info(f"将在 {retry_delay} 秒后重试...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                else:
                    logger.error("Checkpointer 初始化失败,已达到最大重试次数")
                    raise
            else:
                # Publish only after setup() has completed.
                _psycopg_pool = pool
                _checkpointer = checkpointer
                logger.info("Checkpointer 初始化成功")
                break

    return _checkpointer
|
||
|
||
|
||
async def close_db_pool():
    """Close every database connection pool managed by this module.

    The module globals are detached (set to None) before any pool is closed.
    As a result, a later get_db_pool() / get_checkpointer() call builds a
    fresh pool even if closing the old one fails.

    Closing the psycopg pool is always attempted, even when closing the
    asyncpg pool raises. Any such exception still propagates to the caller
    afterwards.
    """
    global _asyncpg_pool, _psycopg_pool, _checkpointer

    # Take ownership of the current pools and clear the globals up front.
    asyncpg_pool, _asyncpg_pool = _asyncpg_pool, None
    psycopg_pool, _psycopg_pool = _psycopg_pool, None
    _checkpointer = None  # the checkpointer is only usable with its psycopg pool

    try:
        # Close the asyncpg pool.
        if asyncpg_pool is not None:
            logger.info("关闭 asyncpg 数据库连接池...")
            try:
                await asyncpg_pool.close()
            finally:
                # Presumably clears graph-metadata state tied to this pool;
                # run it even if close() fails.
                reset_graph_metadata()
            logger.info("asyncpg 数据库连接池已关闭")
    finally:
        # Close the psycopg pool regardless of what happened above.
        if psycopg_pool is not None:
            logger.info("关闭 psycopg 连接池...")
            await psycopg_pool.close()
            logger.info("psycopg 连接池已关闭")
|
||
|
||
|
||
async def get_db_connection():
    """Lend out one pooled asyncpg connection (for dependency injection).

    This is an async generator. It acquires a connection from the shared pool
    returned by ``get_db_pool()``, yields it once, and releases it back to the
    pool when the generator is closed.
    """
    async with (await get_db_pool()).acquire() as conn:
        yield conn
|
||
|