huoyan-enterprise/backend/core/database.py

171 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库连接管理模块
统一管理所有数据库连接池:
- asyncpg Pool: 用于一般的数据库操作
- psycopg AsyncConnectionPool: 用于 LangGraph Checkpointer
"""
from typing import Optional
import asyncio
import asyncpg
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from core.config import settings
from core.graph_metadata import ensure_graph_metadata, reset_graph_metadata
from logger.logging import get_logger
logger = get_logger(__name__)
# Module-level connection singletons. They are created lazily by
# get_db_pool() / get_checkpointer() and reset to None by close_db_pool().
# asyncpg pool used for general CRUD queries.
_asyncpg_pool: Optional[asyncpg.Pool] = None
# psycopg async pool that backs the LangGraph checkpointer.
_psycopg_pool: Optional[AsyncConnectionPool] = None
# LangGraph Postgres checkpointer built on top of _psycopg_pool.
_checkpointer: Optional[AsyncPostgresSaver] = None
async def _open_verified_asyncpg_pool() -> asyncpg.Pool:
    """Create one asyncpg pool and verify it before it is handed out.

    Runs a ``SELECT 1`` probe and ``ensure_graph_metadata`` on a pooled
    connection. If anything fails (including cancellation), the freshly
    created pool is closed before the exception propagates, so it never leaks.
    """
    pool = await asyncpg.create_pool(
        host=settings.db_host,
        port=settings.db_port,
        database=settings.db_name,
        user=settings.db_user,
        password=settings.db_password,
        min_size=settings.db_pool_min_size,
        max_size=settings.db_pool_max_size,
        command_timeout=settings.db_command_timeout,
        timeout=30,  # connect timeout for each new connection, in seconds
        server_settings={
            'application_name': 'huoyan-enterprise',
            'jit': 'off'  # disable JIT to improve stability
        }
    )
    try:
        # Probe the connection and make sure the graph metadata is in place.
        async with pool.acquire() as conn:
            await conn.execute("SELECT 1")
            await ensure_graph_metadata(conn)
    except BaseException:
        # Best-effort cleanup: a close failure must not mask the real error.
        try:
            await pool.close()
        except Exception as close_err:
            logger.warning(f"关闭初始化失败的 asyncpg 连接池时出错: {close_err}")
        raise
    return pool


async def get_db_pool() -> asyncpg.Pool:
    """Return the shared asyncpg connection pool, creating it on first use.

    Used for general database CRUD operations. Creation is attempted up to
    three times, sleeping 2 s and then 4 s between attempts (exponential
    backoff). The module-level pool is published only after it has been
    verified, so concurrent callers never observe a half-initialised pool.

    Returns:
        The process-wide ``asyncpg.Pool``.

    Raises:
        Exception: whatever the last failed attempt raised (from
            ``asyncpg.create_pool``, the probe query, or
            ``ensure_graph_metadata``).
    """
    global _asyncpg_pool
    if _asyncpg_pool is not None:
        return _asyncpg_pool

    logger.info(f"初始化 asyncpg 数据库连接池: {settings.db_user}@{settings.db_host}:{settings.db_port}/{settings.db_name}")
    max_retries = 3
    retry_delay = 2  # seconds
    for attempt in range(max_retries):
        try:
            pool = await _open_verified_asyncpg_pool()
            logger.info("asyncpg 数据库连接池初始化成功")
            break
        except Exception as e:
            logger.error(f"asyncpg 数据库连接池初始化失败 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                logger.info(f"将在 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2  # exponential backoff
            else:
                logger.error("数据库连接池初始化失败,已达到最大重试次数")
                raise

    if _asyncpg_pool is not None:
        # Another coroutine finished initialising while we were awaiting.
        # Keep its pool and discard ours so only one pool ever exists.
        existing = _asyncpg_pool
        await pool.close()
        return existing

    _asyncpg_pool = pool
    return pool
async def _open_checkpointer() -> tuple[AsyncConnectionPool, AsyncPostgresSaver]:
    """Open a psycopg pool and build a fully set-up checkpointer on top of it.

    If opening the pool or ``setup()`` fails (including cancellation), the
    pool is closed before the exception propagates, so it never leaks.
    """
    pool = AsyncConnectionPool(
        conninfo=settings.db_uri_psycopg,
        max_size=settings.checkpointer_pool_max_size,
        open=False,
        # Max seconds a caller waits to obtain a connection from the pool.
        timeout=30,
        kwargs={
            "autocommit": True,
            "prepare_threshold": 0
        },
    )
    try:
        await pool.open()
        checkpointer = AsyncPostgresSaver(pool)
        # LangGraph's one-time schema setup for the checkpoint tables.
        await checkpointer.setup()
    except BaseException:
        # Best-effort cleanup: a close failure must not mask the real error.
        try:
            await pool.close()
        except Exception as close_err:
            logger.warning(f"关闭初始化失败的 psycopg 连接池时出错: {close_err}")
        raise
    return pool, checkpointer


async def get_checkpointer() -> AsyncPostgresSaver:
    """Return the shared LangGraph checkpointer, creating it on first use.

    Uses a psycopg ``AsyncConnectionPool`` for LangGraph state persistence.
    Creation is attempted up to three times, sleeping 2 s and then 4 s
    between attempts. The checkpointer and its pool are published to module
    state only after ``setup()`` has completed. This means concurrent callers
    never receive a checkpointer that is still being set up, or one whose
    pool is about to be closed.

    Returns:
        The process-wide ``AsyncPostgresSaver``.

    Raises:
        Exception: whatever the last failed attempt raised.
    """
    global _psycopg_pool, _checkpointer
    if _checkpointer is not None:
        return _checkpointer

    logger.info("初始化 psycopg 连接池和 Checkpointer...")
    max_retries = 3
    retry_delay = 2  # seconds
    for attempt in range(max_retries):
        try:
            pool, checkpointer = await _open_checkpointer()
            logger.info("Checkpointer 初始化成功")
            break
        except Exception as e:
            logger.error(f"Checkpointer 初始化失败 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                logger.info(f"将在 {retry_delay} 秒后重试...")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2  # exponential backoff
            else:
                logger.error("Checkpointer 初始化失败,已达到最大重试次数")
                raise

    if _checkpointer is not None:
        # Another coroutine finished initialising while we were awaiting.
        # Keep its checkpointer and close our pool so none is leaked.
        existing = _checkpointer
        await pool.close()
        return existing

    _psycopg_pool = pool
    _checkpointer = checkpointer
    return checkpointer
async def close_db_pool() -> None:
    """Close all database connection pools and reset the module state.

    The module-level references are detached *before* any awaiting. This
    stops other coroutines from picking up a pool that is being closed, and
    makes the next ``get_db_pool()`` / ``get_checkpointer()`` call create
    fresh ones. Each pool is closed independently, so a failure closing the
    asyncpg pool does not prevent the psycopg pool from being closed. The
    first error still propagates to the caller.
    """
    global _asyncpg_pool, _psycopg_pool, _checkpointer
    asyncpg_pool, _asyncpg_pool = _asyncpg_pool, None
    psycopg_pool, _psycopg_pool = _psycopg_pool, None
    _checkpointer = None

    try:
        # Close the asyncpg pool.
        if asyncpg_pool is not None:
            logger.info("关闭 asyncpg 数据库连接池...")
            try:
                await asyncpg_pool.close()
            finally:
                # The pool reference is already gone, so the graph metadata
                # state must be reset even if close() failed.
                reset_graph_metadata()
            logger.info("asyncpg 数据库连接池已关闭")
    finally:
        # Close the psycopg pool (backing the checkpointer) regardless.
        if psycopg_pool is not None:
            logger.info("关闭 psycopg 连接池...")
            await psycopg_pool.close()
            logger.info("psycopg 连接池已关闭")
async def get_db_connection():
    """Dependency-injection helper that yields one pooled asyncpg connection.

    The shared pool is obtained (and lazily created) through ``get_db_pool()``.
    The connection stays checked out while the generator is suspended at its
    ``yield``. It is released back to the pool when the generator resumes to
    completion or is closed.
    """
    shared_pool = await get_db_pool()
    acquisition = shared_pool.acquire()
    async with acquisition as pooled_conn:
        yield pooled_conn