huoyan-enterprise/backend/services/knowledge_processing_servic...

747 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识加工服务
"""
import io
import json
import tempfile
import os
import uuid
from typing import Optional, List, Tuple
import asyncpg
from datetime import datetime
from models.knowledge_processing import (
KnowledgeProcessingTask,
TaskCreateRequest,
TaskType,
TaskStatus
)
from services.knowledge_base_file_service import KnowledgeBaseFileService
from logger.logging import get_logger
logger = get_logger(__name__)

# Lower-case, dot-prefixed extensions of spreadsheet formats. A MERGE task whose
# inputs all carry one of these extensions is merged with pandas instead of the LLM.
# frozenset: this is a module-level constant and must not be mutated at runtime.
TABLE_EXTENSIONS = frozenset({'.xlsx', '.xls', '.csv'})
class KnowledgeProcessingService:
    """Persistence operations for knowledge-processing tasks.

    All methods are static and run against the ``knowledge_processing_task``
    table through the asyncpg connection supplied by the caller.
    """

    # Columns returned by every query that builds a KnowledgeProcessingTask.
    # Kept in one place so INSERT ... RETURNING and the SELECTs cannot drift apart.
    _TASK_COLUMNS = (
        "id, user_id, knowledge_base_id, task_name, instruction, file_ids, "
        "task_type, status, result, result_file_url, error_message, "
        "created_at, updated_at, started_at, completed_at"
    )

    @staticmethod
    def _affected_rows(command_tag: str) -> int:
        """Return the row count from an asyncpg command tag such as ``"UPDATE 1"``.

        Returns 0 when the tag does not end in an integer.
        """
        try:
            return int(str(command_tag).rsplit(" ", 1)[-1])
        except ValueError:
            return 0

    @staticmethod
    async def create_task(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: int,
        task_data: TaskCreateRequest
    ) -> KnowledgeProcessingTask:
        """
        Create a knowledge-processing task in PENDING state.

        Each file in ``task_data.file_ids`` is looked up through
        ``KnowledgeBaseFileService.get_file_by_id`` with ``user_id``; it must
        exist, belong to knowledge base ``kb_id`` and have status "completed".

        Args:
            conn: database connection
            user_id: owner of the new task
            kb_id: knowledge base ID
            task_data: task creation payload

        Returns:
            KnowledgeProcessingTask: the created task (including result_file_url)

        Raises:
            ValueError: a file is missing, belongs to another knowledge base,
                or has not finished processing
            Exception: any other failure, with the original error chained
        """
        try:
            for file_id in task_data.file_ids:
                file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, user_id)
                if not file:
                    raise ValueError(f"文件 ID {file_id} 不存在")
                if file.knowledge_base_id != kb_id:
                    raise ValueError(f"文件 ID {file_id} 不属于该知识库")
                if file.status != "completed":
                    raise ValueError(f"文件 {file.file_name} 尚未处理完成,无法进行加工")

            # RETURNING uses the shared column list so the returned task has the
            # same fields (notably result_file_url) as get_task_by_id / get_user_tasks.
            row = await conn.fetchrow(
                f"""
                INSERT INTO knowledge_processing_task
                    (user_id, knowledge_base_id, task_name, instruction, file_ids, task_type, status)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                RETURNING {KnowledgeProcessingService._TASK_COLUMNS}
                """,
                user_id, kb_id, task_data.task_name, task_data.instruction,
                task_data.file_ids, task_data.task_type.value, TaskStatus.PENDING.value
            )
            logger.info(f"用户 {user_id} 创建知识加工任务: {task_data.task_name}, 文件数: {len(task_data.file_ids)}")
            return KnowledgeProcessingTask(**dict(row))
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"创建知识加工任务失败: {e}")
            raise Exception(f"创建知识加工任务失败: {str(e)}") from e

    @staticmethod
    async def get_task_by_id(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> Optional[KnowledgeProcessingTask]:
        """
        Fetch one task owned by ``user_id``.

        Args:
            conn: database connection
            task_id: task ID
            user_id: owner's user ID

        Returns:
            Optional[KnowledgeProcessingTask]: the task, or None if it does not
            exist, is owned by another user, or the query failed (error logged).
        """
        try:
            row = await conn.fetchrow(
                f"""
                SELECT {KnowledgeProcessingService._TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id, user_id
            )
            return KnowledgeProcessingTask(**dict(row)) if row else None
        except Exception as e:
            logger.error(f"获取任务失败: {e}")
            return None

    @staticmethod
    async def get_user_tasks(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: Optional[int] = None,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[KnowledgeProcessingTask], int]:
        """
        List a user's tasks, newest first, with optional filters and paging.

        Args:
            conn: database connection
            user_id: owner's user ID
            kb_id: knowledge base ID filter (optional)
            status: task status filter (optional)
            page: 1-based page number; values below 1 are treated as 1
            page_size: rows per page; negative values are treated as 0

        Returns:
            Tuple[List[KnowledgeProcessingTask], int]: (tasks on this page, total matching count)

        Raises:
            Exception: on query failure, with the original error chained
        """
        try:
            # PostgreSQL rejects a negative OFFSET/LIMIT, so clamp paging input.
            page = max(page, 1)
            page_size = max(page_size, 0)
            offset = (page - 1) * page_size

            # Build the WHERE clause; placeholder numbers follow the params list.
            conditions = ["user_id = $1"]
            params: list = [user_id]
            if kb_id is not None:
                params.append(kb_id)
                conditions.append(f"knowledge_base_id = ${len(params)}")
            if status is not None:
                params.append(status)
                conditions.append(f"status = ${len(params)}")
            where_clause = " AND ".join(conditions)

            total = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM knowledge_processing_task
                WHERE {where_clause}
                """,
                *params
            )

            limit_index = len(params) + 1
            rows = await conn.fetch(
                f"""
                SELECT {KnowledgeProcessingService._TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT ${limit_index} OFFSET ${limit_index + 1}
                """,
                *params, page_size, offset
            )
            tasks = [KnowledgeProcessingTask(**dict(row)) for row in rows]
            return tasks, total
        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            raise Exception(f"获取任务列表失败: {str(e)}") from e

    @staticmethod
    async def update_task_status(
        conn: asyncpg.Connection,
        task_id: int,
        status: TaskStatus,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
        result_file_url: Optional[str] = None,
    ) -> bool:
        """
        Set a task's status and the columns that go with it.

        - PROCESSING: also sets ``started_at``.
        - COMPLETED: stores ``result`` and ``result_file_url``, sets ``completed_at``.
        - FAILED: stores ``error_message``, sets ``completed_at``.
        - any other status: only ``status`` changes.

        Args:
            conn: database connection
            task_id: task ID
            status: new status
            result: result JSON (used for COMPLETED)
            error_message: error text (used for FAILED)
            result_file_url: URL of a generated result file (used for COMPLETED)

        Returns:
            bool: True if a row was updated; False if no task has ``task_id``
            or the update raised (the error is logged).
        """
        try:
            if status == TaskStatus.PROCESSING:
                query = """
                    UPDATE knowledge_processing_task
                    SET status = $1, started_at = CURRENT_TIMESTAMP
                    WHERE id = $2
                """
                args = (status.value, task_id)
            elif status == TaskStatus.COMPLETED:
                query = """
                    UPDATE knowledge_processing_task
                    SET status = $1, result = $2, result_file_url = $3, completed_at = CURRENT_TIMESTAMP
                    WHERE id = $4
                """
                args = (status.value, result, result_file_url, task_id)
            elif status == TaskStatus.FAILED:
                query = """
                    UPDATE knowledge_processing_task
                    SET status = $1, error_message = $2, completed_at = CURRENT_TIMESTAMP
                    WHERE id = $3
                """
                args = (status.value, error_message, task_id)
            else:
                query = """
                    UPDATE knowledge_processing_task
                    SET status = $1
                    WHERE id = $2
                """
                args = (status.value, task_id)

            command_tag = await conn.execute(query, *args)
            # asyncpg returns e.g. "UPDATE 0" when the WHERE clause matched nothing.
            if KnowledgeProcessingService._affected_rows(command_tag) == 0:
                logger.warning(f"任务 {task_id} 不存在,状态未更新")
                return False
            logger.info(f"任务 {task_id} 状态更新为: {status.value}")
            return True
        except Exception as e:
            logger.error(f"更新任务状态失败: {e}")
            return False

    @staticmethod
    async def delete_task(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> bool:
        """
        Permanently delete (SQL DELETE) a task owned by ``user_id``.

        Args:
            conn: database connection
            task_id: task ID
            user_id: owner's user ID

        Returns:
            bool: True if exactly one row was deleted; False otherwise or on
            error (the error is logged).
        """
        try:
            command_tag = await conn.execute(
                """
                DELETE FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id, user_id
            )
            if KnowledgeProcessingService._affected_rows(command_tag) == 1:
                logger.info(f"用户 {user_id} 删除任务 {task_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"删除任务失败: {e}")
            return False
class KnowledgeProcessingExecutor:
    """Runs knowledge-processing tasks.

    A MERGE task over two or more Excel/CSV files is merged for real with
    pandas and the resulting workbook is uploaded to OSS. Every other task is
    answered by the ``deepseek-chat`` model using the files' stored text chunks.
    """

    @staticmethod
    def _file_ext(file_type: Optional[str]) -> str:
        """Normalise a stored ``file_type`` to a dot-prefixed, lower-case extension.

        Accepts values with or without a leading dot and tolerates None
        (``"XLSX"`` / ``".xlsx"`` -> ``".xlsx"``, ``None`` -> ``"."``).
        """
        return "." + (file_type or "").strip().lower().lstrip(".")

    @staticmethod
    async def process_task(
        conn: asyncpg.Connection,
        task: KnowledgeProcessingTask
    ) -> Tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """
        Execute a knowledge-processing task.

        Does not raise: any exception is logged with its traceback and
        reported through the error-message slot of the returned tuple.

        Args:
            conn: database connection
            task: task to execute

        Returns:
            Tuple[bool, Optional[str], Optional[str], Optional[str]]:
                (success, result JSON, error message, result file URL)
        """
        try:
            logger.info(f"开始处理任务 {task.id}: {task.task_name}, 类型: {task.task_type}")

            # 1. Load file metadata; file_path is needed to download tables from OSS.
            file_records = []
            for file_id in task.file_ids:
                file_info = await conn.fetchrow(
                    "SELECT id, file_name, file_type, file_path FROM knowledge_base_file WHERE id = $1",
                    file_id
                )
                if not file_info:
                    logger.warning(f"文件 {file_id} 不存在")
                    continue
                file_records.append(dict(file_info))
            if not file_records:
                return False, None, "没有找到有效的文件", None

            # 2. Two or more spreadsheets in a MERGE task are concatenated with
            #    pandas instead of being rewritten by the LLM.
            all_table_files = all(
                KnowledgeProcessingExecutor._file_ext(r['file_type']) in TABLE_EXTENSIONS
                for r in file_records
            )
            if task.task_type == TaskType.MERGE and all_table_files and len(file_records) >= 2:
                logger.info("检测到表格合并任务,使用 pandas 实际合并文件")
                result_json, file_url = await KnowledgeProcessingExecutor._process_table_merge(
                    task, file_records
                )
                logger.info(f"任务 {task.id} 表格合并完成,文件链接: {file_url}")
                return True, result_json, None, file_url

            # 3. Everything else goes through the LLM, fed with the files' text chunks.
            file_contents = []
            for record in file_records:
                chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(conn, record['id'])
                if not chunks:
                    logger.warning(f"文件 {record['id']} 没有内容块")
                    continue
                file_contents.append({
                    'file_id': record['id'],
                    'file_name': record['file_name'],
                    'file_type': record['file_type'],
                    # A NULL chunk content would make join() raise TypeError.
                    'content': "\n\n".join(chunk['content'] or '' for chunk in chunks),
                    'summary': chunks[0].get('summary') or '',
                })
            if not file_contents:
                return False, None, "没有找到有效的文件内容", None

            if task.task_type == TaskType.MERGE:
                result = await KnowledgeProcessingExecutor._process_merge(task, file_contents)
            elif task.task_type == TaskType.COMPARE:
                result = await KnowledgeProcessingExecutor._process_compare(task, file_contents)
            elif task.task_type == TaskType.SUMMARY:
                result = await KnowledgeProcessingExecutor._process_summary(task, file_contents)
            else:
                result = await KnowledgeProcessingExecutor._process_custom(task, file_contents)
            logger.info(f"任务 {task.id} 处理完成")
            return True, result, None, None
        except Exception as e:
            import traceback
            logger.error(f"处理任务失败: {e}")
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            return False, None, str(e), None

    @staticmethod
    async def _process_table_merge(
        task: KnowledgeProcessingTask,
        file_records: List[dict]
    ) -> Tuple[str, Optional[str]]:
        """
        Concatenate Excel/CSV files into one workbook and upload it to OSS.

        Each source is downloaded from OSS when OSS is enabled (falling back to
        ``file_path`` as a local path if the download fails); otherwise it is
        read from ``file_path`` directly. Rows are stacked with ``pd.concat``
        and a ``_来源文件`` column records the originating file name. The
        blocking pandas/OSS work runs in a worker thread via ``asyncio.to_thread``.

        Args:
            task: the task being processed (unused here)
            file_records: dicts with ``id``, ``file_name``, ``file_type``, ``file_path``

        Returns:
            (result_json, oss_file_url); oss_file_url is None when OSS is disabled.

        Raises:
            ValueError: a source file can be obtained neither from OSS nor locally
        """
        import asyncio
        import pandas as pd
        from services.oss_service import get_oss_service

        def _extract_oss_key(file_path: str) -> str:
            """Turn ``https://{bucket}.{endpoint}/{key}`` into ``key``; other values pass through."""
            if file_path.startswith(("http://", "https://")):
                from urllib.parse import urlparse
                return urlparse(file_path).path.lstrip("/")
            return file_path

        def _read_frame(path: str, ext: str) -> "pd.DataFrame":
            """Load one table; CSVs are read as UTF-8 (BOM tolerated), then GB18030."""
            if ext != '.csv':
                return pd.read_excel(path)
            try:
                return pd.read_csv(path, encoding='utf-8-sig')
            except UnicodeDecodeError:
                # CSVs saved by Excel on Chinese-locale Windows are typically GBK/GB18030.
                return pd.read_csv(path, encoding='gb18030')

        def _load_record(record: dict, oss) -> "pd.DataFrame":
            """Fetch one source file (OSS first, then local path) and read it into a DataFrame."""
            file_path = record['file_path']
            ext = KnowledgeProcessingExecutor._file_ext(record['file_type'])
            tmp_path = None
            try:
                if oss.enabled:
                    oss_key = _extract_oss_key(file_path)
                    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                        tmp_path = tmp.name
                    try:
                        oss.download_file(oss_key, tmp_path)
                        read_path = tmp_path
                        logger.info(f"OSS 下载成功: {oss_key} -> {tmp_path}")
                    except Exception as e:
                        logger.warning(f"OSS 下载失败 (key={oss_key}): {e}")
                        if not os.path.isfile(file_path):
                            raise ValueError(
                                f"无法获取文件 {record['file_name']},OSS 下载失败且本地文件不存在"
                            ) from e
                        read_path = file_path
                elif os.path.isfile(file_path):
                    read_path = file_path
                else:
                    raise ValueError(f"OSS 未启用且本地文件不存在: {file_path}")
                df = _read_frame(read_path, ext)
            finally:
                # Remove the temp download on every path, including a failed
                # download with no local fallback.
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)
            # Tag each row with its source file so merged rows stay distinguishable.
            df['_来源文件'] = record['file_name']
            logger.info(f"读取文件 {record['file_name']}: {len(df)} 行")
            return df

        def _do_merge() -> Tuple[bytes, Optional[str], int, str]:
            """Synchronous merge + upload, run in a worker thread.

            Returns (excel_bytes, file_url or None, merged row count, output file name).
            """
            oss = get_oss_service()
            dfs = [_load_record(record, oss) for record in file_records]
            if not dfs:
                raise ValueError("所有文件读取失败,无法合并")
            merged_df = pd.concat(dfs, ignore_index=True)

            buf = io.BytesIO()
            with pd.ExcelWriter(buf, engine='openpyxl') as writer:
                merged_df.to_excel(writer, index=False, sheet_name='合并结果')
            excel_bytes = buf.getvalue()

            file_name = f"merged_{uuid.uuid4().hex[:8]}.xlsx"
            oss_key = f"processing_results/{file_name}"
            file_url = None
            if oss.enabled:
                file_url = oss.upload_file_from_bytes(excel_bytes, oss_key, file_name)
                logger.info(f"合并结果已上传 OSS: {file_url}")
            else:
                logger.warning("OSS 未启用,合并文件未上传")
            return excel_bytes, file_url, len(merged_df), file_name

        _, file_url, row_count, output_name = await asyncio.to_thread(_do_merge)
        result = {
            "type": "table_merge",
            "file_count": len(file_records),
            "files": [{"file_id": r['id'], "file_name": r['file_name']} for r in file_records],
            "merged_rows": row_count,
            "output_file": output_name,
            "download_url": file_url,
        }
        return json.dumps(result, ensure_ascii=False), file_url

    @staticmethod
    def _build_files_text(file_contents: List[dict]) -> str:
        """Render the files as numbered prompt sections (name, optional summary, content)."""
        sections = []
        for idx, file_data in enumerate(file_contents, 1):
            section = f"\n\n【文件{idx}: {file_data['file_name']}】\n"
            if file_data['summary']:
                section += f"摘要: {file_data['summary']}\n\n"
            section += f"内容:\n{file_data['content']}\n"
            sections.append(section + "=" * 80)
        return "".join(sections)

    @staticmethod
    async def _ask_llm(system_prompt: str, user_prompt: str, temperature: float) -> str:
        """Send a system + user message pair to ``deepseek-chat`` and return the reply content."""
        from langchain_core.messages import HumanMessage, SystemMessage
        from core.llm_catalog import build_chat_model

        llm = build_chat_model(
            provider="deepseek",
            api_model="deepseek-chat",
            streaming=False,
            temperature=temperature,
        )
        response = await llm.ainvoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ])
        return response.content

    @staticmethod
    def _llm_result_json(result_type: str, file_contents: List[dict], key: str, value) -> str:
        """Serialise an LLM task result as ``{type, file_count, files, <key>: value}``."""
        result = {
            "type": result_type,
            "file_count": len(file_contents),
            "files": [{"file_id": f['file_id'], "file_name": f['file_name']} for f in file_contents],
            key: value,
        }
        return json.dumps(result, ensure_ascii=False)

    @staticmethod
    async def _process_merge(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Merge the text of several files into one document with the LLM.

        Args:
            task: task whose ``instruction`` is passed to the model
            file_contents: dicts with ``file_id``, ``file_name``, ``content``, ``summary``

        Returns:
            str: JSON with ``type="merge"``, ``file_count``, ``files``, ``merged_content``
        """
        logger.info(f"执行合并任务: {task.task_name}")
        prompt = "\n".join([
            "你是一个文档处理助手。用户需要合并多个文件。",
            f"用户指令:{task.instruction}",
            KnowledgeProcessingExecutor._build_files_text(file_contents),
            "请按照用户的指令,将上述文件合并成一个逻辑通顺的文档。注意:",
            "1. 去除重复内容",
            "2. 保持结构清晰",
            "3. 确保内容连贯",
            "4. 保留所有关键信息",
            "请直接输出合并后的内容,不要添加额外的说明。",
        ])
        merged_content = await KnowledgeProcessingExecutor._ask_llm(
            "你是一个专业的文档处理助手,擅长合并、对比和总结文档。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._llm_result_json(
            "merge", file_contents, "merged_content", merged_content
        )

    @staticmethod
    async def _process_compare(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Compare several files with the LLM (similarities, differences, unique content, overall).

        Args:
            task: task whose ``instruction`` is passed to the model
            file_contents: dicts with ``file_id``, ``file_name``, ``content``, ``summary``

        Returns:
            str: JSON with ``type="compare"``, ``file_count``, ``files``, ``comparison``
        """
        logger.info(f"执行对比任务: {task.task_name}")
        prompt = "\n".join([
            "你是一个文档对比分析助手。用户需要对比分析多个文件。",
            f"用户指令:{task.instruction}",
            KnowledgeProcessingExecutor._build_files_text(file_contents),
            "请按照用户的指令,对上述文件进行对比分析。请从以下几个维度分析:",
            "1. 相似之处:列出文件之间的共同点",
            "2. 差异之处:列出文件之间的不同点",
            "3. 独特内容:每个文件独有的内容",
            "4. 综合分析:整体对比总结",
            "请使用清晰的结构化格式输出结果。",
        ])
        comparison_result = await KnowledgeProcessingExecutor._ask_llm(
            "你是一个专业的文档对比分析助手,擅长发现文档之间的异同点。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._llm_result_json(
            "compare", file_contents, "comparison", comparison_result
        )

    @staticmethod
    async def _process_summary(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Summarise several files with the LLM.

        Args:
            task: task whose ``instruction`` is passed to the model
            file_contents: dicts with ``file_id``, ``file_name``, ``content``, ``summary``

        Returns:
            str: JSON with ``type="summary"``, ``file_count``, ``files``, ``summary``
        """
        logger.info(f"执行总结任务: {task.task_name}")
        prompt = "\n".join([
            "你是一个文档总结助手。用户需要总结多个文件的内容。",
            f"用户指令:{task.instruction}",
            KnowledgeProcessingExecutor._build_files_text(file_contents),
            "请按照用户的指令,对上述文件进行总结。请包含:",
            "1. 每个文件的核心内容",
            "2. 整体主题和要点",
            "3. 关键信息提炼",
            "4. 综合总结",
            "请使用清晰的结构化格式输出结果。",
        ])
        summary_result = await KnowledgeProcessingExecutor._ask_llm(
            "你是一个专业的文档总结助手,擅长提炼关键信息和核心要点。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._llm_result_json(
            "summary", file_contents, "summary", summary_result
        )

    @staticmethod
    async def _process_custom(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Run a free-form user instruction over several files with the LLM.

        Used for any task type other than MERGE, COMPARE and SUMMARY.

        Args:
            task: task whose ``instruction`` is passed to the model
            file_contents: dicts with ``file_id``, ``file_name``, ``content``, ``summary``

        Returns:
            str: JSON with ``type="custom"``, ``file_count``, ``files``, ``result``
        """
        logger.info(f"执行自定义任务: {task.task_name}")
        prompt = "\n".join([
            "你是一个文档处理助手。用户给出了以下文件和指令。",
            f"用户指令:{task.instruction}",
            KnowledgeProcessingExecutor._build_files_text(file_contents),
            "请严格按照用户的指令执行处理,并输出结果。",
        ])
        custom_result = await KnowledgeProcessingExecutor._ask_llm(
            "你是一个专业的文档处理助手,能够根据用户指令灵活处理各种文档任务。",
            prompt,
            temperature=0.5,
        )
        return KnowledgeProcessingExecutor._llm_result_json(
            "custom", file_contents, "result", custom_result
        )