125 lines
3.2 KiB
Python
125 lines
3.2 KiB
Python
"""
|
||
安全相关工具:JWT、密码加密等
|
||
"""
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Optional, Dict, Any
|
||
|
||
import bcrypt
|
||
from jose import JWTError, jwt
|
||
|
||
from core.config import settings
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# JWT 配置(从统一配置获取)
|
||
SECRET_KEY = settings.jwt_secret_key
|
||
ALGORITHM = settings.jwt_algorithm
|
||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.jwt_expire_minutes
|
||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||
"""
|
||
验证密码
|
||
|
||
Args:
|
||
plain_password: 明文密码
|
||
hashed_password: 哈希后的密码
|
||
|
||
Returns:
|
||
bool: 密码是否匹配
|
||
"""
|
||
try:
|
||
# bcrypt 限制密码最大长度为 72 字节
|
||
# 如果密码超过 72 字节,需要截断(与哈希时保持一致)
|
||
password_bytes = plain_password.encode('utf-8')
|
||
if len(password_bytes) > 72:
|
||
password_bytes = password_bytes[:72]
|
||
|
||
# 使用 bcrypt 直接验证
|
||
return bcrypt.checkpw(password_bytes, hashed_password.encode('utf-8'))
|
||
except Exception as e:
|
||
logger.error(f"密码验证失败: {e}")
|
||
return False
|
||
|
||
|
||
def get_password_hash(password: str) -> str:
|
||
"""
|
||
对密码进行哈希加密
|
||
|
||
Args:
|
||
password: 明文密码
|
||
|
||
Returns:
|
||
str: 哈希后的密码
|
||
"""
|
||
# bcrypt 限制密码最大长度为 72 字节
|
||
# 如果密码超过 72 字节,需要截断
|
||
password_bytes = password.encode('utf-8')
|
||
if len(password_bytes) > 72:
|
||
password_bytes = password_bytes[:72]
|
||
|
||
# 使用 bcrypt 直接哈希,使用默认的 rounds (12)
|
||
salt = bcrypt.gensalt()
|
||
hashed = bcrypt.hashpw(password_bytes, salt)
|
||
return hashed.decode('utf-8')
|
||
|
||
|
||
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||
"""
|
||
创建 JWT access token
|
||
|
||
Args:
|
||
data: 要编码到 token 中的数据
|
||
expires_delta: token 过期时间,如果为 None 则使用默认值
|
||
|
||
Returns:
|
||
str: JWT token
|
||
"""
|
||
to_encode = data.copy()
|
||
|
||
if expires_delta:
|
||
expire = datetime.now(timezone.utc) + expires_delta
|
||
else:
|
||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||
|
||
to_encode.update({"exp": expire})
|
||
|
||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||
return encoded_jwt
|
||
|
||
|
||
def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
解码 JWT token
|
||
|
||
Args:
|
||
token: JWT token
|
||
|
||
Returns:
|
||
Optional[Dict[str, Any]]: 解码后的数据,如果 token 无效则返回 None
|
||
"""
|
||
try:
|
||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||
return payload
|
||
except JWTError as e:
|
||
logger.warning(f"JWT 解码失败: {e}")
|
||
return None
|
||
|
||
|
||
def create_token_for_user(user_id: int, username: str) -> str:
|
||
"""
|
||
为用户创建 token
|
||
|
||
Args:
|
||
user_id: 用户 ID
|
||
username: 用户名
|
||
|
||
Returns:
|
||
str: JWT token
|
||
"""
|
||
return create_access_token(
|
||
data={"sub": str(user_id), "username": username}
|
||
)
|
||
|