""" 知识库服务 """ from typing import Any, Dict, List, Optional, Tuple import asyncpg from core.permissions import can_manage_kb, can_view_kb from models.knowledge_base import KnowledgeBase, KnowledgeBaseCreate, KnowledgeBaseUpdate from models.user import User from logger.logging import get_logger logger = get_logger(__name__) def _kb_model_dump(kb: KnowledgeBase) -> Dict[str, Any]: return kb.model_dump() if hasattr(kb, "model_dump") else kb.dict() _KB_FIELDS = """ id, user_id, enterprise_id, department_id, creator_id, visibility, name, description, created_at, updated_at, is_deleted, deleted_at """ class KnowledgeBaseService: """知识库服务类""" @staticmethod async def enrich_kb_for_response( conn: asyncpg.Connection, kb: KnowledgeBase, viewer: User, ) -> Dict[str, Any]: """补充创建者、部门名称及是否本人创建,用于 API 返回。""" data = _kb_model_dump(kb) row = await conn.fetchrow( """ SELECT u.username AS creator_username, COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name, d.name AS department_name FROM knowledge_base kb LEFT JOIN user_list u ON u.id = kb.creator_id LEFT JOIN department d ON d.id = kb.department_id WHERE kb.id = $1 """, kb.id, ) if row: data["creator_username"] = row["creator_username"] data["creator_display_name"] = row["creator_display_name"] data["department_name"] = row["department_name"] else: data["creator_username"] = None data["creator_display_name"] = None data["department_name"] = None cid = kb.creator_id data["is_mine"] = bool( viewer.id is not None and ( (cid is not None and cid == viewer.id) or (cid is None and kb.user_id == viewer.id) ) ) return data @staticmethod def _validate_visibility(v: str) -> str: if v not in ("private", "department", "enterprise"): raise ValueError("visibility 必须是 private、department 或 enterprise") return v @staticmethod async def create_knowledge_base( conn: asyncpg.Connection, user: User, kb_data: KnowledgeBaseCreate ) -> KnowledgeBase: """创建知识库(写入企业、部门、创建者与可见性)。""" user_id = user.id vis = KnowledgeBaseService._validate_visibility(kb_data.visibility) enterprise_id = user.enterprise_id if enterprise_id is None: raise ValueError("用户未关联企业,无法创建知识库") try: existing = await conn.fetchrow( """ SELECT id FROM knowledge_base WHERE user_id = $1 AND name = $2 AND is_deleted = FALSE """, user_id, kb_data.name, ) if existing: raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") deleted_existing = await conn.fetchrow( """ SELECT id FROM knowledge_base WHERE user_id = $1 AND name = $2 AND is_deleted = TRUE """, user_id, kb_data.name, ) if deleted_existing: logger.info(f"发现已删除的同名知识库 ID: {deleted_existing['id']},将彻底删除") await conn.execute( """ DELETE FROM knowledge_base WHERE id = $1 AND user_id = $2 AND is_deleted = TRUE """, deleted_existing["id"], user_id, ) row = await conn.fetchrow( f""" INSERT INTO knowledge_base ( user_id, enterprise_id, department_id, creator_id, visibility, name, description ) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING {_KB_FIELDS.strip()} """, user_id, enterprise_id, user.department_id, user_id, vis, kb_data.name, kb_data.description, ) logger.info(f"用户 {user_id} 创建知识库: {kb_data.name}") return KnowledgeBase(**dict(row)) except ValueError: raise except asyncpg.UniqueViolationError as e: error_msg = str(e) if "uk_user_knowledge_base_name" in error_msg or "user_id" in error_msg.lower(): deleted_kb = await conn.fetchrow( """ SELECT id FROM knowledge_base WHERE user_id = $1 AND name = $2 AND is_deleted = TRUE """, user_id, kb_data.name, ) if deleted_kb: raise ValueError( f"知识库名称 '{kb_data.name}' 已被使用(已删除)。" f"请先彻底删除已删除的知识库,或使用其他名称。" ) raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") logger.error(f"创建知识库时发生唯一约束冲突: {e}") raise Exception("创建知识库失败: 唯一约束冲突") except Exception as e: logger.error(f"创建知识库失败: {e}") raise Exception(f"创建知识库失败: {str(e)}") @staticmethod async def fetch_knowledge_base_by_id( conn: asyncpg.Connection, kb_id: int, ) -> Optional[KnowledgeBase]: """按主键读取未删除的知识库(不做权限过滤)。""" row = await conn.fetchrow( f""" SELECT {_KB_FIELDS.strip()} FROM knowledge_base WHERE id = $1 AND is_deleted = FALSE """, kb_id, ) if row: return KnowledgeBase(**dict(row)) return None @staticmethod async def get_knowledge_base_by_id( conn: asyncpg.Connection, kb_id: int, user: User, ) -> Optional[KnowledgeBase]: """获取知识库(企业版:按可见性与角色过滤)。""" kb = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id) if kb is None: return None if not can_view_kb(user, kb): return None return kb @staticmethod async def list_visible_knowledge_bases( conn: asyncpg.Connection, user: User, page: int = 1, page_size: int = 20, ) -> Tuple[List[Dict[str, Any]], int]: """列出当前用户可见的知识库(企业版 SQL 过滤),含创建者/部门 JOIN。""" offset = (page - 1) * page_size enterprise_id = user.enterprise_id if enterprise_id is None: return [], 0 role = user.role or "employee" dept_id = user.department_id uid = user.id where_sql = """ kb.is_deleted = FALSE AND kb.enterprise_id = $1 AND ( $2::text = 'admin' OR kb.creator_id = $3 OR ($2::text = 'leader' AND kb.department_id IS NOT NULL AND kb.department_id = $4) OR (kb.visibility = 'department' AND kb.department_id IS NOT NULL AND kb.department_id = $4) OR (kb.visibility = 'enterprise') ) """ total = await conn.fetchval( f""" SELECT COUNT(*) FROM knowledge_base kb WHERE {where_sql} """, enterprise_id, role, uid, dept_id, ) rows = await conn.fetch( f""" SELECT kb.id, kb.user_id, kb.enterprise_id, kb.department_id, kb.creator_id, kb.visibility, kb.name, kb.description, kb.created_at, kb.updated_at, u.username AS creator_username, COALESCE(NULLIF(TRIM(u.display_name), ''), u.username) AS creator_display_name, d.name AS department_name FROM knowledge_base kb LEFT JOIN user_list u ON u.id = kb.creator_id LEFT JOIN department d ON d.id = kb.department_id WHERE {where_sql} ORDER BY kb.created_at DESC LIMIT $5 OFFSET $6 """, enterprise_id, role, uid, dept_id, page_size, offset, ) items: List[Dict[str, Any]] = [] for r in rows: d = dict(r) cid = d.get("creator_id") d["is_mine"] = bool( uid is not None and ( (cid is not None and cid == uid) or (cid is None and d.get("user_id") == uid) ) ) items.append(d) return items, int(total or 0) @staticmethod async def update_knowledge_base( conn: asyncpg.Connection, kb_id: int, user: User, kb_data: KnowledgeBaseUpdate, ) -> Optional[KnowledgeBase]: """更新知识库(仅创建者或企业管理员)。""" existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id) if existing is None: return None if not can_manage_kb(user, existing): return None update_fields: List[str] = [] params: List = [] param_index = 1 if kb_data.name is not None: conflict = await conn.fetchrow( """ SELECT id FROM knowledge_base WHERE user_id = $1 AND name = $2 AND id != $3 AND is_deleted = FALSE """, existing.user_id, kb_data.name, kb_id, ) if conflict: raise ValueError(f"知识库名称 '{kb_data.name}' 已存在") update_fields.append(f"name = ${param_index}") params.append(kb_data.name) param_index += 1 if kb_data.description is not None: update_fields.append(f"description = ${param_index}") params.append(kb_data.description) param_index += 1 if kb_data.visibility is not None: KnowledgeBaseService._validate_visibility(kb_data.visibility) update_fields.append(f"visibility = ${param_index}") params.append(kb_data.visibility) param_index += 1 if not update_fields: return existing params.append(kb_id) query = f""" UPDATE knowledge_base SET {', '.join(update_fields)} WHERE id = ${param_index} AND is_deleted = FALSE RETURNING {_KB_FIELDS.strip()} """ row = await conn.fetchrow(query, *params) if row: logger.info(f"用户 {user.id} 更新知识库 {kb_id}") return KnowledgeBase(**dict(row)) return None @staticmethod async def delete_knowledge_base( conn: asyncpg.Connection, kb_id: int, user: User, ) -> bool: """软删除知识库(仅创建者或企业管理员)。""" existing = await KnowledgeBaseService.fetch_knowledge_base_by_id(conn, kb_id) if existing is None: return False if not can_manage_kb(user, existing): return False result = await conn.execute( """ UPDATE knowledge_base SET is_deleted = TRUE, deleted_at = CURRENT_TIMESTAMP WHERE id = $1 AND is_deleted = FALSE """, kb_id, ) if result == "UPDATE 1": logger.info(f"用户 {user.id} 删除知识库 {kb_id}") return True return False