288 lines
9.5 KiB
Python
288 lines
9.5 KiB
Python
"""后台管理员用户管理"""
|
|
from datetime import datetime, timezone
|
|
from typing import Optional, Tuple, List, Any, Dict
|
|
import asyncpg
|
|
|
|
from core.security import get_password_hash
|
|
from models.user import User
|
|
from admin.schemas import AdminUserCreate, AdminUserUpdate
|
|
from logger.logging import get_logger
|
|
|
|
logger = get_logger(__name__)

# The only values accepted for a user's role; enforced by _validate_role().
_VALID_ROLES = frozenset(("admin", "leader", "employee"))
|
|
|
|
|
|
def _validate_role(role: str) -> str:
    """Return *role* unchanged when it is one of ``_VALID_ROLES``.

    Raises:
        ValueError: if *role* is not an accepted role value.
    """
    if role in _VALID_ROLES:
        return role
    raise ValueError("role 必须是 admin、leader 或 employee")
|
|
|
|
|
|
class AdminUserService:
    """Enterprise-scoped management of rows in the ``user_list`` table.

    Every method is a ``@staticmethod`` that runs its queries on the
    caller-supplied asyncpg connection. No method opens a transaction itself,
    so callers that need the pre-checks and the write to be atomic must wrap
    the call in one. Validation failures are reported as ``ValueError``.
    """

    # Columns returned to callers; ``hashed_password`` is never included.
    _USER_COLUMNS = (
        "id, username, email, phone, display_name, enterprise_id, department_id, "
        "role, is_active, is_first_login, allow_kb_upload, created_at, last_login_at"
    )

    # Columns update_user() may write; any other key in the payload is ignored.
    _UPDATABLE_FIELDS = frozenset(
        {
            "email",
            "phone",
            "display_name",
            "department_id",
            "role",
            "is_active",
            "hashed_password",
            "allow_kb_upload",
        }
    )

    # Update-payload keys for which an explicit ``None`` means "leave unchanged"
    # instead of being written as SQL NULL (or, for password, being hashed).
    # NOTE(review): assumes email/is_active/allow_kb_upload are NOT NULL columns
    # in user_list -- confirm against the schema.
    _NULL_MEANS_UNCHANGED = ("email", "is_active", "allow_kb_upload", "password")

    @staticmethod
    def _like_pattern(term: str) -> str:
        """Build a substring ILIKE pattern that matches *term* literally.

        PostgreSQL's default LIKE escape character is the backslash, so any
        backslash, ``%`` or ``_`` in the user's search text is prefixed with
        one; otherwise those characters would act as wildcards.
        """
        escaped = term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    @staticmethod
    async def _ensure_other_active_admin(
        conn: asyncpg.Connection,
        enterprise_id: int,
        user_id: int,
    ) -> None:
        """Raise unless an active admin other than *user_id* exists in the enterprise.

        Called before an operation that would stop *user_id* from counting as
        an active admin (demotion, deactivation, deletion).

        Raises:
            ValueError: if *user_id* is the enterprise's last active admin.
        """
        others = await conn.fetchval(
            """
            SELECT COUNT(*) FROM user_list
            WHERE enterprise_id = $1 AND role = 'admin' AND is_active = TRUE
              AND id <> $2
            """,
            enterprise_id,
            user_id,
        )
        if int(others or 0) < 1:
            raise ValueError("至少需要保留一名企业管理员")

    @staticmethod
    async def list_users(
        conn: asyncpg.Connection,
        enterprise_id: int,
        page: int = 1,
        page_size: int = 20,
        username: Optional[str] = None,
        email: Optional[str] = None,
        phone: Optional[str] = None,
        display_name: Optional[str] = None,
        department_id: Optional[int] = None,
    ) -> Tuple[List[dict], int]:
        """Return one page of an enterprise's users plus the total match count.

        Args:
            conn: Open asyncpg connection.
            enterprise_id: Only users of this enterprise are listed.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.
            username, email, phone, display_name: Optional case-insensitive
                substring filters. Blank or whitespace-only values are ignored,
                and wildcard characters in them are matched literally.
            department_id: Optional exact department filter.

        Returns:
            ``(rows, total)``. *rows* are dicts with the ``_USER_COLUMNS``
            fields, ordered by descending id. *total* counts all matching rows,
            regardless of paging.
        """
        # PostgreSQL rejects a negative OFFSET or LIMIT, so clamp bad paging input.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        conds = ["enterprise_id = $1"]
        params: List[Any] = [enterprise_id]

        if department_id is not None:
            params.append(department_id)
            conds.append(f"department_id = ${len(params)}")

        text_filters = (
            ("username", username),
            ("email", email),
            ("phone", phone),
            # display_name may be NULL; COALESCE keeps the ILIKE comparison defined.
            ("COALESCE(display_name, '')", display_name),
        )
        for column_sql, raw in text_filters:
            term = (raw or "").strip()
            if term:
                params.append(AdminUserService._like_pattern(term))
                conds.append(f"{column_sql} ILIKE ${len(params)}")

        where_sql = " AND ".join(conds)

        total = await conn.fetchval(
            f"SELECT COUNT(*) FROM user_list WHERE {where_sql}",
            *params,
        )

        limit_ph = len(params) + 1
        offset_ph = len(params) + 2
        rows = await conn.fetch(
            f"""
            SELECT {AdminUserService._USER_COLUMNS}
            FROM user_list
            WHERE {where_sql}
            ORDER BY id DESC
            LIMIT ${limit_ph} OFFSET ${offset_ph}
            """,
            *params,
            page_size,
            offset,
        )
        return [dict(r) for r in rows], int(total or 0)

    @staticmethod
    async def get_user(
        conn: asyncpg.Connection,
        enterprise_id: int,
        user_id: int,
    ) -> Optional[dict]:
        """Return user *user_id* as a dict, or ``None`` if it is not in *enterprise_id*."""
        row = await conn.fetchrow(
            f"""
            SELECT {AdminUserService._USER_COLUMNS}
            FROM user_list
            WHERE id = $1 AND enterprise_id = $2
            """,
            user_id,
            enterprise_id,
        )
        return dict(row) if row else None

    @staticmethod
    async def create_user(
        conn: asyncpg.Connection,
        enterprise_id: int,
        data: AdminUserCreate,
    ) -> dict:
        """Insert a new active user into *enterprise_id* and return it.

        The account is stored with ``is_active`` and ``is_first_login`` set to
        TRUE, ``display_name`` defaulting to the username, and
        ``created_at``/``updated_at`` set to the current UTC time.

        Raises:
            ValueError: if the role is invalid, or if the username or email is
                already used by any row of user_list. Uniqueness is checked
                across all enterprises, not only *enterprise_id*.
        """
        _validate_role(data.role)

        # NOTE(review): these pre-checks and the INSERT are not atomic; a
        # concurrent insert can still collide (surfacing as a DB error if a
        # unique constraint exists).
        if await conn.fetchval(
            "SELECT 1 FROM user_list WHERE username = $1",
            data.username,
        ):
            raise ValueError("用户名已存在")

        if await conn.fetchval(
            "SELECT 1 FROM user_list WHERE email = $1",
            str(data.email),
        ):
            raise ValueError("邮箱已被使用")

        hashed = get_password_hash(data.password)
        row = await conn.fetchrow(
            f"""
            INSERT INTO user_list (
                username, email, phone, hashed_password, display_name,
                enterprise_id, department_id, role, is_first_login,
                is_active, created_at, updated_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, TRUE, $10, $10)
            RETURNING {AdminUserService._USER_COLUMNS}
            """,
            data.username,
            str(data.email),
            data.phone,
            hashed,
            data.display_name or data.username,
            enterprise_id,
            data.department_id,
            data.role,
            True,
            datetime.now(timezone.utc),
        )
        return dict(row)

    @staticmethod
    async def update_user(
        conn: asyncpg.Connection,
        admin: User,
        user_id: int,
        data: AdminUserUpdate,
    ) -> Optional[dict]:
        """Apply the fields explicitly set in *data* to a user of *admin*'s enterprise.

        Only keys set on *data* (``exclude_unset``) and listed in
        ``_UPDATABLE_FIELDS`` are written. A ``password`` is hashed into
        ``hashed_password``. An explicit ``None`` for any key in
        ``_NULL_MEANS_UNCHANGED`` is ignored.

        Returns:
            The updated (or, if nothing is written, current) user dict, or
            ``None`` if *user_id* does not belong to the admin's enterprise.

        Raises:
            ValueError: if the admin tries to disable their own account, the
                role is invalid (including an explicit ``None``), the change
                would leave the enterprise with no active admin, or the new
                email is used by another user.
        """
        enterprise_id = admin.enterprise_id
        target = await conn.fetchrow(
            "SELECT role, is_active FROM user_list WHERE id = $1 AND enterprise_id = $2",
            user_id,
            enterprise_id,
        )
        if not target:
            return None

        updates: Dict[str, Any] = data.model_dump(exclude_unset=True)
        for key in AdminUserService._NULL_MEANS_UNCHANGED:
            if key in updates and updates[key] is None:
                del updates[key]

        if user_id == admin.id and updates.get("is_active") is False:
            raise ValueError("不能禁用当前登录账号")

        # Validate whenever the key is present: a bare None would otherwise be
        # written to the role column unchecked.
        if "role" in updates:
            _validate_role(updates["role"])

        # Guard the "at least one active admin" invariant. Only a target that
        # currently counts as an active admin and would stop counting after
        # this update can break it.
        was_active_admin = target["role"] == "admin" and bool(target["is_active"])
        new_role = updates.get("role", target["role"])
        new_active = updates.get("is_active", target["is_active"])
        if was_active_admin and not (new_role == "admin" and bool(new_active)):
            await AdminUserService._ensure_other_active_admin(conn, enterprise_id, user_id)

        if "email" in updates:
            conflict = await conn.fetchval(
                "SELECT id FROM user_list WHERE email = $1 AND id != $2",
                str(updates["email"]),
                user_id,
            )
            if conflict:
                raise ValueError("邮箱已被使用")

        if "password" in updates:
            updates["hashed_password"] = get_password_hash(updates.pop("password"))

        assignments: List[str] = []
        params: List[Any] = []
        for key, val in updates.items():
            if key not in AdminUserService._UPDATABLE_FIELDS:
                continue
            params.append(val)
            assignments.append(f"{key} = ${len(params)}")

        if not assignments:
            # Nothing writable was supplied; report the row as it stands.
            return await AdminUserService.get_user(conn, enterprise_id, user_id)

        id_ph = len(params) + 1
        enterprise_ph = len(params) + 2
        params.extend([user_id, enterprise_id])
        set_sql = ", ".join(assignments)
        row = await conn.fetchrow(
            f"""
            UPDATE user_list
            SET {set_sql}, updated_at = CURRENT_TIMESTAMP
            WHERE id = ${id_ph} AND enterprise_id = ${enterprise_ph}
            RETURNING {AdminUserService._USER_COLUMNS}
            """,
            *params,
        )
        return dict(row) if row else None

    @staticmethod
    async def delete_user(
        conn: asyncpg.Connection,
        admin: User,
        user_id: int,
    ) -> bool:
        """Remove a user from the admin's enterprise (hard delete).

        Foreign-key violations raised by the DELETE are not handled here; they
        are left for the route layer to catch.

        Returns:
            True if exactly one row was deleted. False if the user does not
            belong to the admin's enterprise.

        Raises:
            ValueError: if the admin tries to delete their own account, or the
                target is the enterprise's last active admin.
        """
        if user_id == admin.id:
            raise ValueError("不能删除当前登录账号")

        target = await conn.fetchrow(
            "SELECT role, is_active FROM user_list WHERE id = $1 AND enterprise_id = $2",
            user_id,
            admin.enterprise_id,
        )
        if not target:
            return False

        # Deleting an inactive admin does not reduce the active-admin count,
        # so only an active admin needs the guard.
        if target["role"] == "admin" and target["is_active"]:
            await AdminUserService._ensure_other_active_admin(
                conn, admin.enterprise_id, user_id
            )

        result = await conn.execute(
            "DELETE FROM user_list WHERE id = $1 AND enterprise_id = $2",
            user_id,
            admin.enterprise_id,
        )
        return result == "DELETE 1"
|