huoyan-enterprise/backend/core/dependencies.py

234 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI 依赖项
"""
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import asyncpg
from core.database import get_db_pool
from core.security import decode_access_token
from models.user import User
from services.user_service import UserService
from logger.logging import get_logger
logger = get_logger(__name__)
# HTTP Bearer 认证方案
security = HTTPBearer()
async def get_db() -> asyncpg.Connection:
"""获取数据库连接(依赖注入)"""
pool = await get_db_pool()
async with pool.acquire() as connection:
yield connection
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
conn: asyncpg.Connection = Depends(get_db)
) -> User:
"""
获取当前登录用户(必须登录)
Args:
credentials: HTTP Bearer 认证凭证
conn: 数据库连接
Returns:
User: 当前登录的用户
Raises:
HTTPException: 如果 token 无效或用户不存在
"""
token = credentials.credentials
# 解码 token
payload = decode_access_token(token)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
headers={"WWW-Authenticate": "Bearer"},
)
# 从 payload 中获取用户 ID
user_id_str = payload.get("sub")
if user_id_str is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
headers={"WWW-Authenticate": "Bearer"},
)
try:
user_id = int(user_id_str)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
headers={"WWW-Authenticate": "Bearer"},
)
# 从数据库获取用户
user = await UserService.get_user_by_id(conn, user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在",
headers={"WWW-Authenticate": "Bearer"},
)
# 检查用户是否激活
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="用户已被禁用",
)
return user
async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
"""仅企业管理员role=admin可访问后台管理接口。"""
if getattr(current_user, "role", None) != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要企业管理员权限",
)
return current_user
async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
conn: asyncpg.Connection = Depends(get_db)
) -> Optional[User]:
"""
获取当前登录用户(可选,不强制登录)
Args:
credentials: HTTP Bearer 认证凭证(可选)
conn: 数据库连接
Returns:
Optional[User]: 当前登录的用户,如果未登录则返回 None
"""
if credentials is None:
return None
try:
token = credentials.credentials
# 解码 token
payload = decode_access_token(token)
if payload is None:
return None
# 从 payload 中获取用户 ID
user_id_str = payload.get("sub")
if user_id_str is None:
return None
try:
user_id = int(user_id_str)
except ValueError:
return None
# 从数据库获取用户
user = await UserService.get_user_by_id(conn, user_id)
if user is None or not user.is_active:
return None
return user
except Exception as e:
logger.warning(f"获取当前用户时发生错误: {e}")
return None
# 审核服务单例实例
_moderation_service: Optional["ModerationService"] = None
async def get_moderation_service():
"""
获取或创建审核服务实例(依赖注入)
实现单例模式,复用 ModerationService 实例以提高性能。
行为:
- 如果 MODERATION_ENABLED 为 False返回 NoOpModerationService空操作实现
- 如果 MODERATION_ENABLED 为 True验证凭证并返回 ModerationService 实例
- 使用全局变量缓存服务实例,避免重复创建
Returns:
ModerationService 或 NoOpModerationService: 审核服务实例
Raises:
RuntimeError: 如果审核已启用但凭证配置缺失
Example:
>>> @router.post("/chat/completion")
>>> async def chat_completion(
>>> moderation_service = Depends(get_moderation_service)
>>> ):
>>> result = await moderation_service.moderate_text(text)
"""
global _moderation_service
# 导入配置和服务(延迟导入避免循环依赖)
from core.config import get_settings
from services.moderation_service import ModerationService, NoOpModerationService
settings = get_settings()
# 如果审核功能被禁用,返回空操作服务
if not settings.moderation_enabled:
logger.info("审核功能已禁用 - 返回 NoOpModerationService")
return NoOpModerationService()
# 如果服务实例尚未创建,创建新实例
if _moderation_service is None:
# 验证必需的凭证配置
if not settings.aliyun_access_key_id:
error_msg = (
"审核服务配置错误: ALIYUN_ACCESS_KEY_ID 未设置。"
"请在 .env 文件中配置 ALIYUN_ACCESS_KEY_ID"
"或设置 MODERATION_ENABLED=false 禁用审核功能。"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
if not settings.aliyun_access_key_secret:
error_msg = (
"审核服务配置错误: ALIYUN_ACCESS_KEY_SECRET 未设置。"
"请在 .env 文件中配置 ALIYUN_ACCESS_KEY_SECRET"
"或设置 MODERATION_ENABLED=false 禁用审核功能。"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
# 创建审核服务实例
_moderation_service = ModerationService(
access_key_id=settings.aliyun_access_key_id,
access_key_secret=settings.aliyun_access_key_secret,
region=settings.aliyun_moderation_region,
timeout=settings.moderation_timeout_seconds,
service_type=settings.moderation_service_type,
image_service_type=settings.image_moderation_service_type
)
logger.info(
f"审核服务实例已创建(增强版)- 区域: {settings.aliyun_moderation_region}, "
f"文本服务类型: {settings.moderation_service_type}, "
f"图片服务类型: {settings.image_moderation_service_type}, "
f"超时: {settings.moderation_timeout_seconds}"
)
return _moderation_service