"""
Knowledge processing service (知识加工服务).

Provides task CRUD (``KnowledgeProcessingService``) and the task executor
(``KnowledgeProcessingExecutor``). Multi-file Excel/CSV merges are done with
pandas; every other task type is delegated to an LLM over the files' text
chunks.
"""
import asyncio
import io
import json
import os
import tempfile
import traceback
import uuid
from datetime import datetime
from typing import Optional, List, Tuple
from urllib.parse import urlparse

import asyncpg

from models.knowledge_processing import (
    KnowledgeProcessingTask, TaskCreateRequest, TaskType, TaskStatus
)
from services.knowledge_base_file_service import KnowledgeBaseFileService
from logger.logging import get_logger

logger = get_logger(__name__)

# File extensions treated as tables (merged with pandas instead of an LLM).
TABLE_EXTENSIONS = {'.xlsx', '.xls', '.csv'}

# Column list used by every query that builds a KnowledgeProcessingTask, so
# the model always receives the same set of fields (including result_file_url).
_TASK_COLUMNS = """
    id, user_id, knowledge_base_id, task_name, instruction, file_ids,
    task_type, status, result, result_file_url, error_message,
    created_at, updated_at, started_at, completed_at
"""


def _normalize_ext(file_type: Optional[str]) -> str:
    """Return *file_type* as a lower-case extension with exactly one leading dot.

    Tolerates ``None`` and values that already start with a dot, e.g.
    ``"XLSX"`` -> ``".xlsx"``, ``".csv"`` -> ``".csv"``, ``None`` -> ``"."``.
    """
    return "." + (file_type or "").strip().lower().lstrip(".")


def _extract_oss_key(file_path: str) -> str:
    """Extract the OSS object key from a full URL; return other paths unchanged.

    A URL of the form ``https://{bucket}.{endpoint}/{key}`` yields ``{key}``
    (the URL path without its leading ``/``).
    """
    if file_path.startswith(("http://", "https://")):
        return urlparse(file_path).path.lstrip("/")
    return file_path


class KnowledgeProcessingService:
    """CRUD operations for knowledge-processing tasks."""

    @staticmethod
    async def create_task(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: int,
        task_data: TaskCreateRequest
    ) -> KnowledgeProcessingTask:
        """
        Create a knowledge-processing task in PENDING state.

        Every referenced file must be visible to ``user_id``, belong to
        ``kb_id`` and have status ``"completed"``.

        Args:
            conn: database connection
            user_id: ID of the user creating the task
            kb_id: knowledge base the files must belong to
            task_data: task creation payload

        Returns:
            KnowledgeProcessingTask: the newly inserted task

        Raises:
            ValueError: a file does not exist, belongs to another knowledge
                base, or has not finished processing
            Exception: any other failure (original exception chained as cause)
        """
        try:
            # Validate every file before inserting anything.
            for file_id in task_data.file_ids:
                file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, user_id)
                if not file:
                    raise ValueError(f"文件 ID {file_id} 不存在")
                if file.knowledge_base_id != kb_id:
                    raise ValueError(f"文件 ID {file_id} 不属于该知识库")
                if file.status != "completed":
                    raise ValueError(f"文件 {file.file_name} 尚未处理完成,无法进行加工")

            row = await conn.fetchrow(
                f"""
                INSERT INTO knowledge_processing_task
                    (user_id, knowledge_base_id, task_name, instruction, file_ids, task_type, status)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                RETURNING {_TASK_COLUMNS}
                """,
                user_id,
                kb_id,
                task_data.task_name,
                task_data.instruction,
                task_data.file_ids,
                task_data.task_type.value,
                TaskStatus.PENDING.value
            )

            logger.info(f"用户 {user_id} 创建知识加工任务: {task_data.task_name}, 文件数: {len(task_data.file_ids)}")
            return KnowledgeProcessingTask(**dict(row))

        except ValueError:
            raise
        except Exception as e:
            logger.error(f"创建知识加工任务失败: {e}")
            raise Exception(f"创建知识加工任务失败: {str(e)}") from e

    @staticmethod
    async def get_task_by_id(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> Optional[KnowledgeProcessingTask]:
        """
        Fetch one task owned by ``user_id``.

        Args:
            conn: database connection
            task_id: task ID
            user_id: owner ID (tasks of other users are not returned)

        Returns:
            Optional[KnowledgeProcessingTask]: the task, or None if it does not
            exist, belongs to another user, or the query failed (best effort:
            errors are logged, not raised)
        """
        try:
            row = await conn.fetchrow(
                f"""
                SELECT {_TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id,
                user_id
            )
            return KnowledgeProcessingTask(**dict(row)) if row else None

        except Exception as e:
            logger.error(f"获取任务失败: {e}")
            return None

    @staticmethod
    async def get_user_tasks(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: Optional[int] = None,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[KnowledgeProcessingTask], int]:
        """
        List a user's tasks, newest first, with optional filters and paging.

        Args:
            conn: database connection
            user_id: owner ID
            kb_id: optional knowledge base filter
            status: optional status filter
            page: 1-based page number (values below 1 are treated as 1)
            page_size: number of tasks per page

        Returns:
            Tuple[List[KnowledgeProcessingTask], int]: (tasks on this page,
            total number of matching tasks)

        Raises:
            Exception: on any query failure (original exception chained)
        """
        try:
            # PostgreSQL rejects a negative OFFSET, so clamp the page number.
            page = max(page, 1)
            offset = (page - 1) * page_size

            # Build the WHERE clause with positional parameters ($1, $2, ...);
            # only placeholders are interpolated, never user values.
            conditions = ["user_id = $1"]
            params = [user_id]
            if kb_id is not None:
                params.append(kb_id)
                conditions.append(f"knowledge_base_id = ${len(params)}")
            if status is not None:
                params.append(status)
                conditions.append(f"status = ${len(params)}")
            where_clause = " AND ".join(conditions)

            total = await conn.fetchval(
                f"""
                SELECT COUNT(*)
                FROM knowledge_processing_task
                WHERE {where_clause}
                """,
                *params
            )

            limit_index = len(params) + 1
            rows = await conn.fetch(
                f"""
                SELECT {_TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT ${limit_index} OFFSET ${limit_index + 1}
                """,
                *params,
                page_size,
                offset
            )

            tasks = [KnowledgeProcessingTask(**dict(row)) for row in rows]
            return tasks, total

        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            raise Exception(f"获取任务列表失败: {str(e)}") from e

    @staticmethod
    async def update_task_status(
        conn: asyncpg.Connection,
        task_id: int,
        status: TaskStatus,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
        result_file_url: Optional[str] = None,
    ) -> bool:
        """
        Update a task's status and the timestamps/fields tied to it.

        - PROCESSING: sets ``started_at``
        - COMPLETED: stores ``result`` / ``result_file_url``, sets ``completed_at``
        - FAILED: stores ``error_message``, sets ``completed_at``
        - any other status: only the status column changes

        Args:
            conn: database connection
            task_id: task ID
            status: new status
            result: processing result (used for COMPLETED)
            error_message: error text (used for FAILED)
            result_file_url: generated result file URL (used for COMPLETED)

        Returns:
            bool: True if a task row was updated; False if no such task exists
            or the update failed (errors are logged, not raised)
        """
        try:
            if status == TaskStatus.PROCESSING:
                command = await conn.execute(
                    """
                    UPDATE knowledge_processing_task
                    SET status = $1, started_at = CURRENT_TIMESTAMP
                    WHERE id = $2
                    """,
                    status.value,
                    task_id
                )
            elif status == TaskStatus.COMPLETED:
                command = await conn.execute(
                    """
                    UPDATE knowledge_processing_task
                    SET status = $1, result = $2, result_file_url = $3, completed_at = CURRENT_TIMESTAMP
                    WHERE id = $4
                    """,
                    status.value,
                    result,
                    result_file_url,
                    task_id
                )
            elif status == TaskStatus.FAILED:
                command = await conn.execute(
                    """
                    UPDATE knowledge_processing_task
                    SET status = $1, error_message = $2, completed_at = CURRENT_TIMESTAMP
                    WHERE id = $3
                    """,
                    status.value,
                    error_message,
                    task_id
                )
            else:
                command = await conn.execute(
                    """
                    UPDATE knowledge_processing_task
                    SET status = $1
                    WHERE id = $2
                    """,
                    status.value,
                    task_id
                )

            # asyncpg returns the command tag, e.g. "UPDATE 1"; a trailing 0
            # means no row matched task_id.
            if command.rsplit(" ", 1)[-1] == "0":
                logger.warning(f"任务 {task_id} 不存在,状态未更新")
                return False

            logger.info(f"任务 {task_id} 状态更新为: {status.value}")
            return True

        except Exception as e:
            logger.error(f"更新任务状态失败: {e}")
            return False

    @staticmethod
    async def delete_task(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> bool:
        """
        Permanently delete a task owned by ``user_id``.

        Args:
            conn: database connection
            task_id: task ID
            user_id: owner ID

        Returns:
            bool: True if exactly one row was deleted; False otherwise or on
            error (errors are logged, not raised)
        """
        try:
            result = await conn.execute(
                """
                DELETE FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id,
                user_id
            )

            if result == "DELETE 1":
                logger.info(f"用户 {user_id} 删除任务 {task_id}")
                return True
            return False

        except Exception as e:
            logger.error(f"删除任务失败: {e}")
            return False


class KnowledgeProcessingExecutor:
    """Runs knowledge-processing tasks (pandas table merge or LLM processing)."""

    @staticmethod
    async def process_task(
        conn: asyncpg.Connection,
        task: KnowledgeProcessingTask
    ) -> Tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """
        Execute a knowledge-processing task.

        A MERGE task over two or more table files (see ``TABLE_EXTENSIONS``)
        is merged with pandas into a new Excel file. Every other task is sent
        to the LLM with the files' text chunks.

        Returns:
            Tuple[bool, Optional[str], Optional[str], Optional[str]]:
            (success, result JSON, error message, result file URL).
            Errors are reported through the tuple, never raised.
        """
        try:
            logger.info(f"开始处理任务 {task.id}: {task.task_name}, 类型: {task.task_type}")

            # 1. Load file metadata (file_path is needed for the OSS download).
            file_records = []
            for file_id in task.file_ids:
                file_info = await conn.fetchrow(
                    "SELECT id, file_name, file_type, file_path FROM knowledge_base_file WHERE id = $1",
                    file_id
                )
                if not file_info:
                    logger.warning(f"文件 {file_id} 不存在")
                    continue
                file_records.append(dict(file_info))

            if not file_records:
                return False, None, "没有找到有效的文件", None

            # 2. Table merge (Excel/CSV) uses the dedicated pandas path.
            all_table_files = all(
                _normalize_ext(r['file_type']) in TABLE_EXTENSIONS for r in file_records
            )
            if task.task_type == TaskType.MERGE and all_table_files and len(file_records) >= 2:
                logger.info("检测到表格合并任务,使用 pandas 实际合并文件")
                result_json, file_url = await KnowledgeProcessingExecutor._process_table_merge(
                    task, file_records
                )
                logger.info(f"任务 {task.id} 表格合并完成,文件链接: {file_url}")
                return True, result_json, None, file_url

            # 3. Regular tasks: feed the files' text chunks to the LLM.
            file_contents = []
            for record in file_records:
                chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(conn, record['id'])
                if not chunks:
                    logger.warning(f"文件 {record['id']} 没有内容块")
                    continue
                file_contents.append({
                    'file_id': record['id'],
                    'file_name': record['file_name'],
                    'file_type': record['file_type'],
                    'content': "\n\n".join(chunk['content'] for chunk in chunks),
                    # The first chunk's summary stands in for the file summary.
                    'summary': chunks[0].get('summary', ''),
                })

            if not file_contents:
                return False, None, "没有找到有效的文件内容", None

            if task.task_type == TaskType.MERGE:
                result = await KnowledgeProcessingExecutor._process_merge(task, file_contents)
            elif task.task_type == TaskType.COMPARE:
                result = await KnowledgeProcessingExecutor._process_compare(task, file_contents)
            elif task.task_type == TaskType.SUMMARY:
                result = await KnowledgeProcessingExecutor._process_summary(task, file_contents)
            else:
                result = await KnowledgeProcessingExecutor._process_custom(task, file_contents)

            logger.info(f"任务 {task.id} 处理完成")
            return True, result, None, None

        except Exception as e:
            logger.error(f"处理任务失败: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            return False, None, str(e), None

    @staticmethod
    async def _process_table_merge(
        task: KnowledgeProcessingTask,
        file_records: List[dict]
    ) -> Tuple[str, Optional[str]]:
        """
        Concatenate Excel/CSV files into one new Excel workbook and upload it.

        Each source file is downloaded from OSS when OSS is enabled (falling
        back to ``file_path`` as a local path), read with pandas, tagged with a
        ``_来源文件`` column holding its file name, and concatenated row-wise.
        The blocking pandas/OSS work runs in a worker thread.

        Returns:
            (result_json, oss_file_url); the URL is None when OSS is disabled.

        Raises:
            ValueError: a source file can be obtained neither from OSS nor locally
        """
        import pandas as pd
        from services.oss_service import get_oss_service

        def _read_csv(path: str):
            """Read a CSV as UTF-8, falling back to GB18030 (common for Chinese exports)."""
            try:
                return pd.read_csv(path, encoding='utf-8')
            except UnicodeDecodeError:
                return pd.read_csv(path, encoding='gb18030')

        def _do_merge() -> Tuple[Optional[str], int, str]:
            """Synchronous merge + upload; returns (file_url, merged_row_count, output_name)."""
            oss = get_oss_service()
            dfs = []
            for record in file_records:
                file_path = record['file_path']  # OSS URL or local path
                ext = _normalize_ext(record['file_type'])
                tmp_path = None
                # One try/finally around the whole record so the temp file is
                # removed even when the download fails and we raise.
                try:
                    if oss.enabled:
                        oss_key = _extract_oss_key(file_path)
                        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                            tmp_path = tmp.name
                        try:
                            oss.download_file(oss_key, tmp_path)
                            read_path = tmp_path
                            logger.info(f"OSS 下载成功: {oss_key} -> {tmp_path}")
                        except Exception as e:
                            logger.warning(f"OSS 下载失败 (key={oss_key}): {e}")
                            # Fall back to reading file_path as a local file.
                            if os.path.isfile(file_path):
                                read_path = file_path
                            else:
                                raise ValueError(f"无法获取文件 {record['file_name']}:OSS 下载失败且本地文件不存在") from e
                    elif os.path.isfile(file_path):
                        read_path = file_path
                    else:
                        raise ValueError(f"OSS 未启用且本地文件不存在: {file_path}")

                    df = _read_csv(read_path) if ext == '.csv' else pd.read_excel(read_path)
                    # Source column so merged rows can be traced back to their file.
                    df['_来源文件'] = record['file_name']
                    dfs.append(df)
                    logger.info(f"读取文件 {record['file_name']},{len(df)} 行")
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)

            if not dfs:
                raise ValueError("所有文件读取失败,无法合并")

            merged_df = pd.concat(dfs, ignore_index=True)

            buf = io.BytesIO()
            with pd.ExcelWriter(buf, engine='openpyxl') as writer:
                merged_df.to_excel(writer, index=False, sheet_name='合并结果')
            excel_bytes = buf.getvalue()

            file_name = f"merged_{uuid.uuid4().hex[:8]}.xlsx"
            oss_key = f"processing_results/{file_name}"
            file_url = None
            if oss.enabled:
                file_url = oss.upload_file_from_bytes(excel_bytes, oss_key, file_name)
                logger.info(f"合并结果已上传 OSS: {file_url}")
            else:
                logger.warning("OSS 未启用,合并文件未上传")

            return file_url, len(merged_df), file_name

        file_url, row_count, output_name = await asyncio.to_thread(_do_merge)

        result = {
            "type": "table_merge",
            "file_count": len(file_records),
            "files": [{"file_id": r['id'], "file_name": r['file_name']} for r in file_records],
            "merged_rows": row_count,
            "output_file": output_name,
            "download_url": file_url,
        }
        return json.dumps(result, ensure_ascii=False), file_url

    @staticmethod
    def _build_files_text(file_contents: List[dict]) -> str:
        """Render all files (header, optional summary, content, separator) for an LLM prompt."""
        parts = []
        for idx, file_data in enumerate(file_contents, 1):
            parts.append(f"\n\n【文件{idx}: {file_data['file_name']}】\n")
            if file_data['summary']:
                parts.append(f"摘要: {file_data['summary']}\n\n")
            parts.append(f"内容:\n{file_data['content']}\n")
            parts.append("=" * 80)
        return "".join(parts)

    @staticmethod
    async def _invoke_llm(system_prompt: str, prompt: str, temperature: float):
        """Send one system + user message pair to the deepseek chat model and return its content."""
        from langchain_core.messages import HumanMessage, SystemMessage
        from core.llm_catalog import build_chat_model

        llm = build_chat_model(
            provider="deepseek",
            api_model="deepseek-chat",
            streaming=False,
            temperature=temperature,
        )
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=prompt)
        ]
        response = await llm.ainvoke(messages)
        return response.content

    @staticmethod
    def _dump_result(result_type: str, file_contents: List[dict], content_key: str, content) -> str:
        """Serialize an LLM task result as JSON with the common ``type``/``file_count``/``files`` fields."""
        result = {
            "type": result_type,
            "file_count": len(file_contents),
            "files": [{"file_id": f['file_id'], "file_name": f['file_name']} for f in file_contents],
            content_key: content,
        }
        return json.dumps(result, ensure_ascii=False)

    @staticmethod
    async def _process_merge(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Merge several text documents into one with the LLM.

        Args:
            task: the task (its ``instruction`` is embedded in the prompt)
            file_contents: per-file dicts with file_id/file_name/content/summary

        Returns:
            str: JSON with ``type="merge"`` and the text under ``merged_content``
        """
        logger.info(f"执行合并任务: {task.task_name}")
        files_text = KnowledgeProcessingExecutor._build_files_text(file_contents)

        prompt = f"""你是一个文档处理助手。用户需要合并多个文件。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,将上述文件合并成一个逻辑通顺的文档。注意:
1. 去除重复内容
2. 保持结构清晰
3. 确保内容连贯
4. 保留所有关键信息

请直接输出合并后的内容,不要添加额外的说明。"""

        merged_content = await KnowledgeProcessingExecutor._invoke_llm(
            "你是一个专业的文档处理助手,擅长合并、对比和总结文档。", prompt, 0.3
        )
        return KnowledgeProcessingExecutor._dump_result(
            "merge", file_contents, "merged_content", merged_content
        )

    @staticmethod
    async def _process_compare(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Compare several documents with the LLM.

        Args:
            task: the task (its ``instruction`` is embedded in the prompt)
            file_contents: per-file dicts with file_id/file_name/content/summary

        Returns:
            str: JSON with ``type="compare"`` and the analysis under ``comparison``
        """
        logger.info(f"执行对比任务: {task.task_name}")
        files_text = KnowledgeProcessingExecutor._build_files_text(file_contents)

        prompt = f"""你是一个文档对比分析助手。用户需要对比分析多个文件。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,对上述文件进行对比分析。请从以下几个维度分析:
1. 相似之处:列出文件之间的共同点
2. 差异之处:列出文件之间的不同点
3. 独特内容:每个文件独有的内容
4. 综合分析:整体对比总结

请使用清晰的结构化格式输出结果。"""

        comparison_result = await KnowledgeProcessingExecutor._invoke_llm(
            "你是一个专业的文档对比分析助手,擅长发现文档之间的异同点。", prompt, 0.3
        )
        return KnowledgeProcessingExecutor._dump_result(
            "compare", file_contents, "comparison", comparison_result
        )

    @staticmethod
    async def _process_summary(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Summarize several documents with the LLM.

        Args:
            task: the task (its ``instruction`` is embedded in the prompt)
            file_contents: per-file dicts with file_id/file_name/content/summary

        Returns:
            str: JSON with ``type="summary"`` and the text under ``summary``
        """
        logger.info(f"执行总结任务: {task.task_name}")
        files_text = KnowledgeProcessingExecutor._build_files_text(file_contents)

        prompt = f"""你是一个文档总结助手。用户需要总结多个文件的内容。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,对上述文件进行总结。请包含:
1. 每个文件的核心内容
2. 整体主题和要点
3. 关键信息提炼
4. 综合总结

请使用清晰的结构化格式输出结果。"""

        summary_result = await KnowledgeProcessingExecutor._invoke_llm(
            "你是一个专业的文档总结助手,擅长提炼关键信息和核心要点。", prompt, 0.3
        )
        return KnowledgeProcessingExecutor._dump_result(
            "summary", file_contents, "summary", summary_result
        )

    @staticmethod
    async def _process_custom(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Run a free-form user instruction over the documents with the LLM.

        Uses a slightly higher temperature (0.5) than the predefined task types.

        Args:
            task: the task (its ``instruction`` is embedded in the prompt)
            file_contents: per-file dicts with file_id/file_name/content/summary

        Returns:
            str: JSON with ``type="custom"`` and the output under ``result``
        """
        logger.info(f"执行自定义任务: {task.task_name}")
        files_text = KnowledgeProcessingExecutor._build_files_text(file_contents)

        prompt = f"""你是一个文档处理助手。用户给出了以下文件和指令。

用户指令:{task.instruction}

{files_text}

请严格按照用户的指令执行处理,并输出结果。"""

        custom_result = await KnowledgeProcessingExecutor._invoke_llm(
            "你是一个专业的文档处理助手,能够根据用户指令灵活处理各种文档任务。", prompt, 0.5
        )
        return KnowledgeProcessingExecutor._dump_result(
            "custom", file_contents, "result", custom_result
        )