747 lines
26 KiB
Python
747 lines
26 KiB
Python
"""
|
||
知识加工服务
|
||
"""
|
||
import io
|
||
import json
|
||
import tempfile
|
||
import os
|
||
import uuid
|
||
from typing import Optional, List, Tuple
|
||
import asyncpg
|
||
from datetime import datetime
|
||
|
||
from models.knowledge_processing import (
|
||
KnowledgeProcessingTask,
|
||
TaskCreateRequest,
|
||
TaskType,
|
||
TaskStatus
|
||
)
|
||
from services.knowledge_base_file_service import KnowledgeBaseFileService
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger shared by the service and executor classes below.
logger = get_logger(__name__)

# Dotted, lowercase file extensions that are treated as tabular data.
TABLE_EXTENSIONS = {'.csv', '.xls', '.xlsx'}
|
||
|
||
|
||
class KnowledgeProcessingService:
    """Persistence helpers for rows of the ``knowledge_processing_task`` table."""

    # Single source of truth for the columns used to build a
    # KnowledgeProcessingTask, so INSERT ... RETURNING and the SELECTs
    # cannot drift apart (the INSERT previously omitted result_file_url).
    _TASK_COLUMNS = (
        "id, user_id, knowledge_base_id, task_name, instruction, file_ids, "
        "task_type, status, result, result_file_url, error_message, "
        "created_at, updated_at, started_at, completed_at"
    )

    @staticmethod
    async def create_task(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: int,
        task_data: TaskCreateRequest
    ) -> KnowledgeProcessingTask:
        """
        Create a knowledge-processing task in PENDING state.

        Every referenced file is validated first: it must be retrievable for
        this user, belong to ``kb_id`` and have status ``"completed"``.

        Args:
            conn: Database connection.
            user_id: ID of the user creating the task.
            kb_id: ID of the knowledge base the files must belong to.
            task_data: Task creation payload (name, instruction, file IDs, type).

        Returns:
            KnowledgeProcessingTask: The newly inserted task.

        Raises:
            ValueError: If a file is missing, belongs to another knowledge
                base, or has not finished processing.
            Exception: For any other failure; the original error is chained
                as ``__cause__``.
        """
        try:
            # Validate every file before touching the task table.
            for file_id in task_data.file_ids:
                file = await KnowledgeBaseFileService.get_file_by_id(conn, file_id, user_id)
                if not file:
                    raise ValueError(f"文件 ID {file_id} 不存在")
                if file.knowledge_base_id != kb_id:
                    raise ValueError(f"文件 ID {file_id} 不属于该知识库")
                if file.status != "completed":
                    raise ValueError(f"文件 {file.file_name} 尚未处理完成,无法进行加工")

            row = await conn.fetchrow(
                f"""
                INSERT INTO knowledge_processing_task
                (user_id, knowledge_base_id, task_name, instruction, file_ids, task_type, status)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                RETURNING {KnowledgeProcessingService._TASK_COLUMNS}
                """,
                user_id, kb_id, task_data.task_name, task_data.instruction,
                task_data.file_ids, task_data.task_type.value, TaskStatus.PENDING.value
            )

            logger.info(f"用户 {user_id} 创建知识加工任务: {task_data.task_name}, 文件数: {len(task_data.file_ids)}")
            return KnowledgeProcessingTask(**dict(row))

        except ValueError:
            # Validation errors are part of the public contract; pass them through.
            raise
        except Exception as e:
            logger.error(f"创建知识加工任务失败: {e}")
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"创建知识加工任务失败: {str(e)}") from e

    @staticmethod
    async def get_task_by_id(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> Optional[KnowledgeProcessingTask]:
        """
        Fetch one task by ID, restricted to the given owner.

        Args:
            conn: Database connection.
            task_id: Task ID.
            user_id: ID of the user who must own the task.

        Returns:
            Optional[KnowledgeProcessingTask]: The task, or ``None`` when no
            matching row exists or the lookup fails (failures are logged,
            not raised).
        """
        try:
            row = await conn.fetchrow(
                f"""
                SELECT {KnowledgeProcessingService._TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id, user_id
            )
            return KnowledgeProcessingTask(**dict(row)) if row else None

        except Exception as e:
            # Deliberately best-effort: callers treat None as "not available".
            logger.error(f"获取任务失败: {e}")
            return None

    @staticmethod
    async def get_user_tasks(
        conn: asyncpg.Connection,
        user_id: int,
        kb_id: Optional[int] = None,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[KnowledgeProcessingTask], int]:
        """
        List a user's tasks, newest first, with optional filters and paging.

        Args:
            conn: Database connection.
            user_id: ID of the owning user.
            kb_id: Optional knowledge-base ID filter.
            status: Optional task-status filter.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Number of tasks per page.

        Returns:
            Tuple[List[KnowledgeProcessingTask], int]: (tasks on this page,
            total number of tasks matching the filters).

        Raises:
            Exception: If the query fails; the original error is chained.
        """
        try:
            # A page below 1 would produce a negative OFFSET, which PostgreSQL rejects.
            page = max(page, 1)
            offset = (page - 1) * page_size

            # Build the WHERE clause; placeholders are numbered from the
            # current length of ``params`` so the two always stay in step.
            filters = ["user_id = $1"]
            params = [user_id]
            if kb_id is not None:
                params.append(kb_id)
                filters.append(f"knowledge_base_id = ${len(params)}")
            if status is not None:
                params.append(status)
                filters.append(f"status = ${len(params)}")
            where_clause = " AND ".join(filters)

            total = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM knowledge_processing_task
                WHERE {where_clause}
                """,
                *params
            )

            limit_index = len(params) + 1
            rows = await conn.fetch(
                f"""
                SELECT {KnowledgeProcessingService._TASK_COLUMNS}
                FROM knowledge_processing_task
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT ${limit_index} OFFSET ${limit_index + 1}
                """,
                *params, page_size, offset
            )

            tasks = [KnowledgeProcessingTask(**dict(row)) for row in rows]
            return tasks, total

        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            raise Exception(f"获取任务列表失败: {str(e)}") from e

    @staticmethod
    async def update_task_status(
        conn: asyncpg.Connection,
        task_id: int,
        status: TaskStatus,
        result: Optional[str] = None,
        error_message: Optional[str] = None,
        result_file_url: Optional[str] = None,
    ) -> bool:
        """
        Update a task's status and the columns that go with that status.

        * PROCESSING sets ``started_at``.
        * COMPLETED stores ``result`` / ``result_file_url`` and sets ``completed_at``.
        * FAILED stores ``error_message`` and sets ``completed_at``.
        * Any other status only changes ``status``.

        Args:
            conn: Database connection.
            task_id: Task ID.
            status: New status.
            result: Result payload, written only for COMPLETED.
            error_message: Error text, written only for FAILED.
            result_file_url: Result file URL, written only for COMPLETED.

        Returns:
            bool: ``True`` if the update was applied; ``False`` if no row
            matched ``task_id`` or the statement failed (failures are logged).
        """
        try:
            # Choose the SET clause and its extra bind values per status.
            if status == TaskStatus.PROCESSING:
                assignments = "status = $1, started_at = CURRENT_TIMESTAMP"
                extra_args = []
            elif status == TaskStatus.COMPLETED:
                assignments = (
                    "status = $1, result = $2, result_file_url = $3, "
                    "completed_at = CURRENT_TIMESTAMP"
                )
                extra_args = [result, result_file_url]
            elif status == TaskStatus.FAILED:
                assignments = "status = $1, error_message = $2, completed_at = CURRENT_TIMESTAMP"
                extra_args = [error_message]
            else:
                assignments = "status = $1"
                extra_args = []

            # task_id is always the last placeholder.
            args = [status.value, *extra_args, task_id]
            outcome = await conn.execute(
                f"""
                UPDATE knowledge_processing_task
                SET {assignments}
                WHERE id = ${len(args)}
                """,
                *args
            )

            # asyncpg returns the command tag ("UPDATE <n>"); zero rows means
            # the task does not exist, which must not be reported as success.
            if outcome == "UPDATE 0":
                logger.warning(f"任务 {task_id} 不存在,状态未更新")
                return False

            logger.info(f"任务 {task_id} 状态更新为: {status.value}")
            return True

        except Exception as e:
            logger.error(f"更新任务状态失败: {e}")
            return False

    @staticmethod
    async def delete_task(
        conn: asyncpg.Connection,
        task_id: int,
        user_id: int
    ) -> bool:
        """
        Permanently delete a task owned by the given user.

        Args:
            conn: Database connection.
            task_id: Task ID.
            user_id: ID of the user who must own the task.

        Returns:
            bool: ``True`` if exactly one row was deleted; ``False`` if
            nothing matched or the statement failed (failures are logged).
        """
        try:
            outcome = await conn.execute(
                """
                DELETE FROM knowledge_processing_task
                WHERE id = $1 AND user_id = $2
                """,
                task_id, user_id
            )

            # The command tag tells us whether a row was actually removed.
            if outcome != "DELETE 1":
                return False

            logger.info(f"用户 {user_id} 删除任务 {task_id}")
            return True

        except Exception as e:
            logger.error(f"删除任务失败: {e}")
            return False
|
||
|
||
|
||
class KnowledgeProcessingExecutor:
    """Runs knowledge-processing tasks: pandas table merges and LLM text jobs."""

    @staticmethod
    def _table_extension(file_type: Optional[str]) -> str:
        """Normalise a stored file type ("csv", ".CSV", None) to ".csv"-style, or "" if empty."""
        cleaned = (file_type or "").strip().lower().lstrip(".")
        return f".{cleaned}" if cleaned else ""

    @staticmethod
    def _format_files_for_prompt(file_contents: List[dict]) -> str:
        """Render every file (title, optional summary, content, separator) as prompt text."""
        sections = []
        for idx, file_data in enumerate(file_contents, 1):
            pieces = [f"\n\n【文件{idx}: {file_data['file_name']}】\n"]
            if file_data['summary']:
                pieces.append(f"摘要: {file_data['summary']}\n\n")
            pieces.append(f"内容:\n{file_data['content']}\n")
            pieces.append("=" * 80)
            sections.append("".join(pieces))
        return "".join(sections)

    @staticmethod
    async def _run_llm(system_prompt: str, prompt: str, temperature: float = 0.3) -> str:
        """Send one system + human message pair to the deepseek chat model and return its text."""
        # Imported lazily, as in the original methods, so importing this
        # module does not require the LLM stack.
        from langchain_core.messages import HumanMessage, SystemMessage
        from core.llm_catalog import build_chat_model

        llm = build_chat_model(
            provider="deepseek",
            api_model="deepseek-chat",
            streaming=False,
            temperature=temperature,
        )
        response = await llm.ainvoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=prompt),
        ])
        return response.content

    @staticmethod
    def _build_result(result_type: str, file_contents: List[dict], payload_key: str, payload: str) -> str:
        """Serialise the common result envelope plus one task-specific payload field."""
        result = {
            "type": result_type,
            "file_count": len(file_contents),
            "files": [{"file_id": f['file_id'], "file_name": f['file_name']} for f in file_contents],
            payload_key: payload,
        }
        return json.dumps(result, ensure_ascii=False)

    @staticmethod
    async def process_task(
        conn: asyncpg.Connection,
        task: KnowledgeProcessingTask
    ) -> Tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """
        Execute a knowledge-processing task.

        A MERGE task whose files are all spreadsheets/CSVs (and which has at
        least two of them) is merged with pandas. Every other task is sent to
        the LLM using the files' stored text chunks.

        Args:
            conn: Database connection.
            task: The task to execute.

        Returns:
            Tuple[bool, Optional[str], Optional[str], Optional[str]]:
                (success, result JSON, error message, result file URL).
                Exceptions are caught, logged and reported through the
                first and third elements rather than raised.
        """
        try:
            logger.info(f"开始处理任务 {task.id}: {task.task_name}, 类型: {task.task_type}")

            # 1. Load file metadata (file_path is needed for the OSS download).
            file_records = []
            for file_id in task.file_ids:
                file_info = await conn.fetchrow(
                    "SELECT id, file_name, file_type, file_path FROM knowledge_base_file WHERE id = $1",
                    file_id
                )
                if not file_info:
                    logger.warning(f"文件 {file_id} 不存在")
                    continue
                file_records.append(dict(file_info))

            if not file_records:
                return False, None, "没有找到有效的文件", None

            # 2. Table merge gets a dedicated pandas path. The extension is
            #    normalised so a missing or dot-prefixed file_type neither
            #    crashes nor hides a real table file.
            all_table_files = all(
                KnowledgeProcessingExecutor._table_extension(r['file_type']) in TABLE_EXTENSIONS
                for r in file_records
            )
            is_merge = task.task_type == TaskType.MERGE

            if is_merge and all_table_files and len(file_records) >= 2:
                logger.info("检测到表格合并任务,使用 pandas 实际合并文件")
                result_json, file_url = await KnowledgeProcessingExecutor._process_table_merge(
                    task, file_records
                )
                logger.info(f"任务 {task.id} 表格合并完成,文件链接: {file_url}")
                return True, result_json, None, file_url

            # 3. Everything else goes through the LLM and needs the text chunks.
            file_contents = []
            for record in file_records:
                chunks = await KnowledgeBaseFileService.get_file_chunks_from_db(conn, record['id'])
                if not chunks:
                    logger.warning(f"文件 {record['id']} 没有内容块")
                    continue
                file_contents.append({
                    'file_id': record['id'],
                    'file_name': record['file_name'],
                    'file_type': record['file_type'],
                    'content': "\n\n".join(chunk['content'] for chunk in chunks),
                    # The file-level summary is read from the first chunk.
                    'summary': chunks[0].get('summary', ''),
                })

            if not file_contents:
                return False, None, "没有找到有效的文件内容", None

            if task.task_type == TaskType.MERGE:
                result = await KnowledgeProcessingExecutor._process_merge(task, file_contents)
            elif task.task_type == TaskType.COMPARE:
                result = await KnowledgeProcessingExecutor._process_compare(task, file_contents)
            elif task.task_type == TaskType.SUMMARY:
                result = await KnowledgeProcessingExecutor._process_summary(task, file_contents)
            else:
                result = await KnowledgeProcessingExecutor._process_custom(task, file_contents)

            logger.info(f"任务 {task.id} 处理完成")
            return True, result, None, None

        except Exception as e:
            logger.error(f"处理任务失败: {e}")
            import traceback
            logger.error(f"错误堆栈: {traceback.format_exc()}")
            return False, None, str(e), None

    @staticmethod
    async def _process_table_merge(
        task: KnowledgeProcessingTask,
        file_records: List[dict]
    ) -> Tuple[str, Optional[str]]:
        """
        Concatenate Excel / CSV files row-wise into one workbook and upload it to OSS.

        Each input is downloaded from OSS when OSS is enabled, falling back to
        ``file_path`` as a local path; otherwise it is read locally. A
        ``_来源文件`` column records which file each row came from.

        Args:
            task: The task being executed (not read by this method).
            file_records: Rows with ``id``, ``file_name``, ``file_type`` and ``file_path``.

        Returns:
            (result_json, oss_file_url); the URL is ``None`` when OSS is disabled.

        Raises:
            ValueError: If an input file cannot be obtained from OSS or disk.
        """
        import asyncio
        import pandas as pd
        from services.oss_service import get_oss_service

        def _extract_oss_key(file_path: str) -> str:
            """Return the OSS object key for a full URL, or the path unchanged."""
            if file_path.startswith(("http://", "https://")):
                # URL shape: https://{bucket}.{endpoint}/{key} -> keep only the key.
                from urllib.parse import urlparse
                return urlparse(file_path).path.lstrip("/")
            return file_path

        def _load_frame(record: dict, oss):
            """Read one input file into a DataFrame, always removing any temp download."""
            file_path = record['file_path']  # OSS URL or local path
            ext = KnowledgeProcessingExecutor._table_extension(record['file_type'])
            tmp_path = None
            # The whole acquisition sits inside try/finally so the temp file is
            # removed even when the download fails and no local copy exists.
            try:
                if oss.enabled:
                    oss_key = _extract_oss_key(file_path)
                    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
                        tmp_path = tmp.name
                    try:
                        oss.download_file(oss_key, tmp_path)
                        read_path = tmp_path
                        logger.info(f"OSS 下载成功: {oss_key} -> {tmp_path}")
                    except Exception as e:
                        logger.warning(f"OSS 下载失败 (key={oss_key}): {e}")
                        # Fall back to a local file if the path points at one.
                        if os.path.isfile(file_path):
                            read_path = file_path
                        else:
                            raise ValueError(
                                f"无法获取文件 {record['file_name']}:OSS 下载失败且本地文件不存在"
                            ) from e
                elif os.path.isfile(file_path):
                    read_path = file_path
                else:
                    raise ValueError(f"OSS 未启用且本地文件不存在: {file_path}")

                if ext == '.csv':
                    df = pd.read_csv(read_path, encoding='utf-8')
                else:
                    df = pd.read_excel(read_path)
                # Tag every row with its source file so rows stay attributable.
                df['_来源文件'] = record['file_name']
                logger.info(f"读取文件 {record['file_name']},{len(df)} 行")
                return df
            finally:
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)

        def _do_merge() -> Tuple[Optional[str], int, str]:
            """Blocking pandas + OSS work; run in a worker thread."""
            oss = get_oss_service()
            dfs = [_load_frame(record, oss) for record in file_records]
            if not dfs:
                raise ValueError("所有文件读取失败,无法合并")

            merged_df = pd.concat(dfs, ignore_index=True)

            # Serialise the merged table to an in-memory .xlsx.
            buf = io.BytesIO()
            with pd.ExcelWriter(buf, engine='openpyxl') as writer:
                merged_df.to_excel(writer, index=False, sheet_name='合并结果')
            excel_bytes = buf.getvalue()

            file_name = f"merged_{uuid.uuid4().hex[:8]}.xlsx"
            oss_key = f"processing_results/{file_name}"
            file_url = None
            if oss.enabled:
                file_url = oss.upload_file_from_bytes(excel_bytes, oss_key, file_name)
                logger.info(f"合并结果已上传 OSS: {file_url}")
            else:
                # NOTE(review): without OSS the merged workbook is discarded
                # and the task still reports success with no download URL.
                logger.warning("OSS 未启用,合并文件未上传")

            return file_url, len(merged_df), file_name

        file_url, row_count, output_name = await asyncio.to_thread(_do_merge)

        result = {
            "type": "table_merge",
            "file_count": len(file_records),
            "files": [{"file_id": r['id'], "file_name": r['file_name']} for r in file_records],
            "merged_rows": row_count,
            "output_file": output_name,
            "download_url": file_url,
        }
        return json.dumps(result, ensure_ascii=False), file_url

    @staticmethod
    async def _process_merge(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Merge the text of several files into one document via the LLM.

        Args:
            task: The task (its name is logged, its instruction goes into the prompt).
            file_contents: Dicts with ``file_id``, ``file_name``, ``content`` and ``summary``.

        Returns:
            str: JSON with ``type="merge"``, file metadata and ``merged_content``.
        """
        logger.info(f"执行合并任务: {task.task_name}")

        files_text = KnowledgeProcessingExecutor._format_files_for_prompt(file_contents)

        prompt = f"""你是一个文档处理助手。用户需要合并多个文件。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,将上述文件合并成一个逻辑通顺的文档。注意:
1. 去除重复内容
2. 保持结构清晰
3. 确保内容连贯
4. 保留所有关键信息

请直接输出合并后的内容,不要添加额外的说明。"""

        merged_content = await KnowledgeProcessingExecutor._run_llm(
            "你是一个专业的文档处理助手,擅长合并、对比和总结文档。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._build_result(
            "merge", file_contents, "merged_content", merged_content
        )

    @staticmethod
    async def _process_compare(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Compare several files via the LLM.

        Args:
            task: The task (its name is logged, its instruction goes into the prompt).
            file_contents: Dicts with ``file_id``, ``file_name``, ``content`` and ``summary``.

        Returns:
            str: JSON with ``type="compare"``, file metadata and ``comparison``.
        """
        logger.info(f"执行对比任务: {task.task_name}")

        files_text = KnowledgeProcessingExecutor._format_files_for_prompt(file_contents)

        prompt = f"""你是一个文档对比分析助手。用户需要对比分析多个文件。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,对上述文件进行对比分析。请从以下几个维度分析:
1. 相似之处:列出文件之间的共同点
2. 差异之处:列出文件之间的不同点
3. 独特内容:每个文件独有的内容
4. 综合分析:整体对比总结

请使用清晰的结构化格式输出结果。"""

        comparison_result = await KnowledgeProcessingExecutor._run_llm(
            "你是一个专业的文档对比分析助手,擅长发现文档之间的异同点。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._build_result(
            "compare", file_contents, "comparison", comparison_result
        )

    @staticmethod
    async def _process_summary(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Summarise several files via the LLM.

        Args:
            task: The task (its name is logged, its instruction goes into the prompt).
            file_contents: Dicts with ``file_id``, ``file_name``, ``content`` and ``summary``.

        Returns:
            str: JSON with ``type="summary"``, file metadata and ``summary``.
        """
        logger.info(f"执行总结任务: {task.task_name}")

        files_text = KnowledgeProcessingExecutor._format_files_for_prompt(file_contents)

        prompt = f"""你是一个文档总结助手。用户需要总结多个文件的内容。

用户指令:{task.instruction}

{files_text}

请按照用户的指令,对上述文件进行总结。请包含:
1. 每个文件的核心内容
2. 整体主题和要点
3. 关键信息提炼
4. 综合总结

请使用清晰的结构化格式输出结果。"""

        summary_result = await KnowledgeProcessingExecutor._run_llm(
            "你是一个专业的文档总结助手,擅长提炼关键信息和核心要点。",
            prompt,
            temperature=0.3,
        )
        return KnowledgeProcessingExecutor._build_result(
            "summary", file_contents, "summary", summary_result
        )

    @staticmethod
    async def _process_custom(task: KnowledgeProcessingTask, file_contents: List[dict]) -> str:
        """
        Run a free-form user instruction over several files via the LLM.

        Args:
            task: The task (its name is logged, its instruction goes into the prompt).
            file_contents: Dicts with ``file_id``, ``file_name``, ``content`` and ``summary``.

        Returns:
            str: JSON with ``type="custom"``, file metadata and ``result``.
        """
        logger.info(f"执行自定义任务: {task.task_name}")

        files_text = KnowledgeProcessingExecutor._format_files_for_prompt(file_contents)

        prompt = f"""你是一个文档处理助手。用户给出了以下文件和指令。

用户指令:{task.instruction}

{files_text}

请严格按照用户的指令执行处理,并输出结果。"""

        # Custom tasks use a higher temperature than the fixed task types.
        custom_result = await KnowledgeProcessingExecutor._run_llm(
            "你是一个专业的文档处理助手,能够根据用户指令灵活处理各种文档任务。",
            prompt,
            temperature=0.5,
        )
        return KnowledgeProcessingExecutor._build_result(
            "custom", file_contents, "result", custom_result
        )
|