413 lines
12 KiB
Python
413 lines
12 KiB
Python
"""
|
||
认证相关 API 路由
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||
import asyncpg
|
||
|
||
from core.dependencies import get_db, get_current_user
|
||
from core.config import settings
|
||
from core.security import create_token_for_user
|
||
from models.user import (
|
||
User, UserCreate, UserLogin, UserResponse, TokenResponse,
|
||
PhoneRegisterRequest, PhoneLoginRequest, SendSmsCodeRequest, WechatLoginRequest
|
||
)
|
||
from services.user_service import UserService
|
||
from services.sms_service import SmsService
|
||
from services.wechat_service import WechatService
|
||
from services.captcha_service import CaptchaService
|
||
from utils.helpers import BaseResponse
|
||
from logger.logging import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# 创建认证路由
|
||
auth_router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||
|
||
|
||
def get_client_ip(request: Request) -> str:
|
||
"""
|
||
获取客户端 IP 地址
|
||
|
||
Args:
|
||
request: FastAPI Request 对象
|
||
|
||
Returns:
|
||
str: 客户端 IP 地址
|
||
"""
|
||
# 优先从 X-Forwarded-For 获取(代理/负载均衡场景)
|
||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||
if forwarded_for:
|
||
return forwarded_for.split(",")[0].strip()
|
||
|
||
# 从 X-Real-IP 获取
|
||
real_ip = request.headers.get("X-Real-IP")
|
||
if real_ip:
|
||
return real_ip
|
||
|
||
# 直接从 client 获取
|
||
if request.client:
|
||
return request.client.host
|
||
|
||
return "unknown"
|
||
|
||
|
||
@auth_router.get("/captcha/generate", summary="生成图形验证码")
|
||
async def generate_captcha(request: Request):
|
||
"""
|
||
生成图形验证码
|
||
|
||
Returns:
|
||
{
|
||
"captcha_id": str, # 验证码唯一标识
|
||
"image": str, # Base64 编码的图片(data URL 格式)
|
||
"expires_in": int # 过期时间(秒)
|
||
}
|
||
"""
|
||
# 获取客户端 IP
|
||
client_ip = get_client_ip(request)
|
||
|
||
# 检查 IP 是否被封禁
|
||
if await CaptchaService.check_ban(client_ip):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="操作过于频繁,请10分钟后再试"
|
||
)
|
||
|
||
# 检查请求频率限制
|
||
if await CaptchaService.check_rate_limit(client_ip):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||
detail="请求过于频繁,请稍后再试"
|
||
)
|
||
|
||
# 生成验证码
|
||
try:
|
||
result = await CaptchaService.generate_captcha(client_ip)
|
||
return result
|
||
except RuntimeError as e:
|
||
# 字体加载失败的特定错误
|
||
logger.error(f"验证码字体加载失败 [IP: {client_ip}]: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="验证码服务暂时不可用,请联系管理员"
|
||
)
|
||
except Exception as e:
|
||
# 其他未预期的错误
|
||
logger.exception(f"生成验证码失败 [IP: {client_ip}]: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="验证码生成失败,请稍后重试"
|
||
)
|
||
|
||
|
||
@auth_router.post("/register", response_model=TokenResponse, summary="用户注册")
|
||
async def register(
|
||
user_data: UserCreate,
|
||
conn: asyncpg.Connection = Depends(get_db)
|
||
):
|
||
"""
|
||
用户注册接口
|
||
|
||
Args:
|
||
user_data: 用户注册信息
|
||
conn: 数据库连接
|
||
|
||
Returns:
|
||
TokenResponse: 包含 token 和用户信息的响应
|
||
"""
|
||
if not settings.enable_public_register:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="当前未开放自助注册,请联系管理员开通账号",
|
||
)
|
||
# 检查用户名是否已存在
|
||
existing_user = await UserService.get_user_by_username(conn, user_data.username)
|
||
if existing_user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="用户名已存在"
|
||
)
|
||
|
||
# 检查邮箱是否已存在
|
||
existing_email = await UserService.get_user_by_email(conn, user_data.email)
|
||
if existing_email:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="邮箱已被注册"
|
||
)
|
||
|
||
# 创建用户
|
||
try:
|
||
user = await UserService.create_user(conn, user_data)
|
||
except Exception as e:
|
||
logger.exception(f"创建用户失败: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="创建用户失败"
|
||
)
|
||
|
||
# 生成 token
|
||
access_token = create_token_for_user(user.id, user.username)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
user=UserResponse(**user.dict())
|
||
)
|
||
|
||
|
||
@auth_router.post("/login", response_model=TokenResponse, summary="用户登录")
|
||
async def login(
|
||
login_data: UserLogin,
|
||
conn: asyncpg.Connection = Depends(get_db)
|
||
):
|
||
"""
|
||
用户登录接口
|
||
|
||
Args:
|
||
login_data: 用户登录信息
|
||
conn: 数据库连接
|
||
|
||
Returns:
|
||
TokenResponse: 包含 token 和用户信息的响应
|
||
"""
|
||
# 验证用户
|
||
user = await UserService.authenticate_user(
|
||
conn,
|
||
login_data.username,
|
||
login_data.password
|
||
)
|
||
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="用户名或密码错误"
|
||
)
|
||
|
||
# 生成 token
|
||
access_token = create_token_for_user(user.id, user.username)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
user=UserResponse(**user.dict())
|
||
)
|
||
|
||
|
||
@auth_router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
|
||
async def get_me(current_user: User = Depends(get_current_user)):
|
||
"""
|
||
获取当前登录用户信息
|
||
|
||
Args:
|
||
current_user: 当前登录用户
|
||
|
||
Returns:
|
||
UserResponse: 用户信息
|
||
"""
|
||
return UserResponse(**current_user.dict())
|
||
|
||
|
||
# ==================== 手机号注册/登录接口 ====================
|
||
|
||
@auth_router.post("/sms/send", response_model=BaseResponse, summary="发送短信验证码")
|
||
async def send_sms_code(request: SendSmsCodeRequest, http_request: Request):
|
||
"""
|
||
发送短信验证码(需要先验证图形验证码)
|
||
|
||
Args:
|
||
request: 包含手机号、场景、图形验证码 ID 和验证码
|
||
http_request: FastAPI Request 对象,用于获取 IP
|
||
|
||
Returns:
|
||
BaseResponse: 发送结果
|
||
"""
|
||
# 获取客户端 IP
|
||
client_ip = get_client_ip(http_request)
|
||
|
||
# 验证图形验证码
|
||
is_valid = await CaptchaService.verify_captcha(request.captcha_id, request.captcha_code)
|
||
|
||
if not is_valid:
|
||
# 记录验证失败
|
||
await CaptchaService.record_fail(client_ip)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="图形验证码错误或已过期"
|
||
)
|
||
|
||
# 图形验证码验证成功,发送短信验证码
|
||
result = await SmsService.send_code(request.phone, request.scene)
|
||
|
||
if not result["success"]:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=result["message"]
|
||
)
|
||
|
||
return BaseResponse(code=200, msg=result["message"])
|
||
|
||
|
||
@auth_router.post("/phone/register", response_model=TokenResponse, summary="手机号注册")
|
||
async def phone_register(
|
||
request: PhoneRegisterRequest,
|
||
conn: asyncpg.Connection = Depends(get_db)
|
||
):
|
||
"""手机号注册"""
|
||
if not settings.enable_public_register:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="当前未开放自助注册,请联系管理员开通账号",
|
||
)
|
||
# 验证短信验证码
|
||
if not await SmsService.verify_code(request.phone, request.code, "register"):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="验证码错误或已过期"
|
||
)
|
||
|
||
# 检查手机号是否已注册
|
||
existing_user = await UserService.get_user_by_phone(conn, request.phone)
|
||
if existing_user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="该手机号已注册"
|
||
)
|
||
|
||
# 创建用户
|
||
try:
|
||
user = await UserService.create_user_by_phone(
|
||
conn,
|
||
phone=request.phone,
|
||
password=request.password,
|
||
username=request.username
|
||
)
|
||
except ValueError as e:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=str(e)
|
||
)
|
||
except Exception as e:
|
||
logger.exception(f"创建用户失败: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="创建用户失败"
|
||
)
|
||
|
||
# 生成 token
|
||
access_token = create_token_for_user(user.id, user.username)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
user=UserResponse(**user.dict())
|
||
)
|
||
|
||
|
||
@auth_router.post("/phone/login", response_model=TokenResponse, summary="手机号登录")
|
||
async def phone_login(
|
||
request: PhoneLoginRequest,
|
||
conn: asyncpg.Connection = Depends(get_db)
|
||
):
|
||
"""
|
||
手机号登录(未注册自动注册)
|
||
|
||
支持两种方式:
|
||
1. 手机号 + 验证码(未注册自动注册)
|
||
2. 手机号 + 密码
|
||
"""
|
||
user = None
|
||
|
||
if request.code:
|
||
# 验证码登录
|
||
if not await SmsService.verify_code(request.phone, request.code, "login"):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="验证码错误或已过期"
|
||
)
|
||
user = await UserService.get_user_by_phone(conn, request.phone)
|
||
if not user:
|
||
# 未注册,自动创建用户(不设置密码)
|
||
user = await UserService.create_user_by_phone_without_password(conn, request.phone)
|
||
logger.info(f"手机号自动注册: phone={request.phone}")
|
||
else:
|
||
await UserService.update_last_login(conn, user.id)
|
||
elif request.password:
|
||
# 密码登录
|
||
user = await UserService.authenticate_by_phone_password(
|
||
conn, request.phone, request.password
|
||
)
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="用户不存在或密码错误"
|
||
)
|
||
else:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="请提供验证码或密码"
|
||
)
|
||
|
||
# 生成 token
|
||
access_token = create_token_for_user(user.id, user.username)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
user=UserResponse(**user.dict())
|
||
)
|
||
|
||
|
||
# ==================== 微信小程序登录接口 ====================
|
||
|
||
@auth_router.post("/wechat/login", response_model=TokenResponse, summary="微信小程序登录")
|
||
async def wechat_login(
|
||
request: WechatLoginRequest,
|
||
conn: asyncpg.Connection = Depends(get_db)
|
||
):
|
||
"""
|
||
微信小程序登录
|
||
|
||
支持账号合并:如果传入 phone_code,会获取用户手机号,
|
||
若该手机号已有账号则自动绑定,实现多登录方式共享账号。
|
||
"""
|
||
# 获取微信 session
|
||
session_data = await WechatService.code2session(request.code)
|
||
|
||
if not session_data or not session_data.get("openid"):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="微信登录失败"
|
||
)
|
||
|
||
# 如果传入 phone_code,获取用户手机号用于账号合并
|
||
phone = None
|
||
if request.phone_code:
|
||
phone = await WechatService.get_phone_number(request.phone_code)
|
||
if phone:
|
||
logger.info(f"微信登录获取到手机号: {phone[:3]}****{phone[-4:]}")
|
||
|
||
# 创建或更新用户(支持账号合并)
|
||
try:
|
||
user = await UserService.create_or_update_wechat_user(
|
||
conn,
|
||
openid=session_data["openid"],
|
||
unionid=session_data.get("unionid"),
|
||
phone=phone
|
||
)
|
||
except Exception as e:
|
||
logger.exception(f"创建或更新微信用户失败: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="创建或更新用户失败"
|
||
)
|
||
|
||
# 生成 token
|
||
access_token = create_token_for_user(user.id, user.username)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
user=UserResponse(**user.dict())
|
||
)
|
||
|
||
|
||
# 导出路由
|
||
__all__ = ["auth_router"]
|
||
|