759 lines
29 KiB
Python
759 lines
29 KiB
Python
"""
|
||
工具模块
|
||
|
||
定义各种 AI 工具函数,包括网络搜索、文生图、文生视频、RAG 检索等。
|
||
"""
|
||
import os
|
||
import time
|
||
import uuid
|
||
import requests
|
||
from typing import Literal, Optional
|
||
from openai import OpenAI
|
||
from langchain.tools import tool
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import json
|
||
|
||
from tavily import TavilyClient
|
||
from core.config import settings
|
||
from utils.datetime_utils import format_beijing_time_for_agent
|
||
from logger.logging import get_logger
|
||
from services.vector_service import get_vector_service
|
||
from services.oss_service import get_oss_service
|
||
|
||
# Module-level logger shared by every tool defined in this file.
logger = get_logger(__name__)


# Tavily web-search client used by ``internet_search``.
# NOTE(review): this client is constructed at import time from
# ``settings.tavily_api_key``. If TavilyClient validates the key eagerly, a
# missing key would make this whole module fail to import, including the
# unrelated image/video/RAG tools. Confirm, and consider lazy construction.
tavily_client = TavilyClient(api_key=settings.tavily_api_key)
|
||
|
||
|
||
@tool
def get_current_time() -> str:
    """
    获取当前中国北京时间(东八区)。

    当用户询问现在几点、今天日期、星期几、或需要当前时间作为参考时调用此工具。
    """
    # The docstring above is the runtime tool description shown to the model,
    # so it is kept verbatim. All formatting is delegated to the shared
    # datetime helper; this tool only exposes it to the agent.
    beijing_now_text = format_beijing_time_for_agent()
    return beijing_now_text
|
||
|
||
|
||
def internet_search(
    query: str,
    max_results: int = 5,
    topic: Literal["general", "news", "finance"] = "general",
    include_raw_content: bool = False,
):
    """Run a web search"""
    # The one-line docstring is kept verbatim: if this function is registered
    # as an agent tool, it may be surfaced to the model as the description.
    # Search options are forwarded unchanged to the module-level Tavily
    # client, and its response is returned as-is.
    search_options = {
        "max_results": max_results,
        "include_raw_content": include_raw_content,
        "topic": topic,
    }
    return tavily_client.search(query, **search_options)
|
||
|
||
|
||
def _download_and_upload_image_to_oss(image_url: str, image_index: int) -> tuple[str, Optional[str], float, float]:
|
||
"""
|
||
下载单张图片并上传到 OSS
|
||
|
||
Args:
|
||
image_url: 原始图片 URL
|
||
image_index: 图片索引(用于日志)
|
||
|
||
Returns:
|
||
tuple: (原始URL, OSS URL, 下载耗时, 上传耗时)
|
||
"""
|
||
upload_start_time = time.time()
|
||
try:
|
||
logger.info(f"开始下载图片 {image_index}:{image_url}")
|
||
|
||
# 下载图片内容
|
||
download_start = time.time()
|
||
response = requests.get(image_url, timeout=300) # 5分钟超时
|
||
response.raise_for_status()
|
||
image_content = response.content
|
||
download_time = time.time() - download_start
|
||
logger.info(f"图片 {image_index} 下载完成,耗时:{download_time:.2f} 秒,大小:{len(image_content) / 1024 / 1024:.2f} MB")
|
||
|
||
# 生成 OSS 对象名称
|
||
timestamp = int(time.time())
|
||
unique_id = str(uuid.uuid4())[:8]
|
||
# 根据图片内容判断文件扩展名
|
||
content_type = response.headers.get('Content-Type', 'image/png')
|
||
if 'jpeg' in content_type or 'jpg' in content_type:
|
||
ext = 'jpg'
|
||
elif 'png' in content_type:
|
||
ext = 'png'
|
||
elif 'webp' in content_type:
|
||
ext = 'webp'
|
||
else:
|
||
ext = 'png' # 默认使用 png
|
||
oss_object_name = f"images/{timestamp}_{unique_id}_{image_index}.{ext}"
|
||
|
||
# 上传到 OSS
|
||
upload_start = time.time()
|
||
oss_service = get_oss_service()
|
||
oss_url = oss_service.upload_file_from_bytes(
|
||
file_content=image_content,
|
||
oss_object_name=oss_object_name,
|
||
file_name=f"generated_image_{image_index}.{ext}"
|
||
)
|
||
upload_time = time.time() - upload_start
|
||
|
||
total_time = time.time() - upload_start_time
|
||
|
||
if oss_url:
|
||
logger.info(f"✅ 图片 {image_index} 已上传到 OSS:{oss_url}")
|
||
logger.info(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒")
|
||
return image_url, oss_url, download_time, upload_time
|
||
else:
|
||
logger.warning(f"⚠️ 图片 {image_index} OSS 上传失败,使用原始 URL")
|
||
logger.warning(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒(上传失败)")
|
||
return image_url, None, download_time, upload_time
|
||
|
||
except Exception as e:
|
||
error_msg = f"❌ 图片 {image_index} 上传到 OSS 失败:{str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return image_url, None, 0.0, 0.0
|
||
|
||
|
||
def _normalize_openai_image_size(size: str) -> str:
|
||
"""将「宽*高」转为 OpenAI 兼容接口要求的「宽x高」。"""
|
||
return (size or "").strip().replace("*", "x")
|
||
|
||
|
||
def _upload_images_to_oss(image_urls: list[str]) -> list[str]:
    """Re-host generated images on OSS, preserving the input order.

    Each output entry is the OSS URL when the upload succeeded, otherwise the
    original (temporary) URL. A single image is processed inline; several
    images are processed concurrently by a thread pool of at most 5 workers.
    An empty input yields an empty list.
    """
    if not image_urls:
        return []
    if len(image_urls) == 1:
        _, oss_url, _, _ = _download_and_upload_image_to_oss(image_urls[0], 1)
        return [oss_url or image_urls[0]]

    # Results are stored by index so the output order matches the input even
    # though futures complete in arbitrary order.
    results: list[Optional[tuple]] = [None] * len(image_urls)
    with ThreadPoolExecutor(max_workers=min(len(image_urls), 5)) as executor:
        future_to_index = {
            executor.submit(_download_and_upload_image_to_oss, url, idx + 1): idx
            for idx, url in enumerate(image_urls)
        }
        for future in as_completed(future_to_index):
            idx = future_to_index[future]
            try:
                results[idx] = future.result()
            except Exception as e:
                logger.error(f"图片 {idx + 1} 处理异常:{e}", exc_info=True)
                results[idx] = (image_urls[idx], None, 0.0, 0.0)

    return [oss_url or original_url for original_url, oss_url, _, _ in results]


@tool
def text_to_image(
    prompt: str,
    negative_prompt: str = "",
    size: str = "1024x1024",
    n: int = 1,
    prompt_extend: bool = True,
    watermark: bool = False,
    seed: Optional[int] = None,
) -> str:
    """
    文生图工具:根据文本描述生成高质量图片。

    通过 OpenAI 兼容接口 ``/v1/images/generations`` 调用通义千问文生图模型(如同步返回的 Qwen-Image 系列)。
    生成的图片会自动上传到 OSS,返回永久可访问的 URL。

    使用场景:
    - 用户说「生成一张…的图片」「画一个…」「创建…的图像」
    - 需要配图、插图或可视化某个概念/场景

    参数说明:
        prompt: 正向提示词(建议详细描述主体、场景、风格、光线与构图)。
        negative_prompt: 反向提示词;不需要可留空。
        size: 分辨率,OpenAI 标准 ``宽x高``(也可用 ``宽*高``,会自动转为 ``x``)。
            qwen-image-2.0 系列常用:``2048x2048``(默认 1:1)、``2688x1536``(16:9)、
            ``1536x2688``(9:16)、``2368x1728``(4:3)、``1728x2368``(3:4)。
            plus / qwen-image 系列可参考:``1664x928``、``1472x1104``、``1328x1328`` 等。
        n: 生成张数;qwen-image-2.0 系列多为 1–6,其它系列常为 1。
        prompt_extend: 是否开启提示词智能改写(True 时模型会润色提示词)。
        watermark: 是否在图像上加 Qwen-Image 水印。
        seed: 随机种子 [0, 2147483647];不传则由服务端随机。

    返回值:
        Markdown 文本,包含可供渲染的图片链接(OSS 永久 URL,失败时回退临时 URL)。

    注意:
        平台返回的原始图像 URL 通常约 24 小时有效;本工具会尽快转存 OSS。
    """
    # English summary (the Chinese docstring above is the runtime tool
    # description and is intentionally left unchanged):
    # Generate images through the OpenAI-compatible images API, re-host them
    # on OSS, and return a text block listing one URL per paragraph.
    # Errors are never raised; they are logged and returned as a message string.
    # NOTE(review): the default size "1024x1024" is not among the
    # qwen-image-2.0 sizes listed in the docstring. Confirm the model accepts it.
    try:
        api_key = (os.getenv("ZL_DASHSCOPE_API_KEY") or "").strip()
        base_url = (os.getenv("ZL_DASHSCOPE_API_BASE") or "").strip().rstrip("/")
        # The messages name the variables actually read above (ZL_ prefix).
        if not api_key:
            return "错误:未配置 ZL_DASHSCOPE_API_KEY 环境变量"
        if not base_url:
            return "错误:未配置 ZL_DASHSCOPE_API_BASE 环境变量"

        client = OpenAI(api_key=api_key, base_url=base_url)
        model_image = "qwen-image-2.0"

        size_norm = _normalize_openai_image_size(size)
        # Clamp to 1..6 (per the docstring, the qwen-image-2.0 range).
        n_req = max(1, min(int(n), 6))

        # Provider-specific options that are not part of the OpenAI schema
        # are sent through extra_body.
        extra_body: dict = {
            "prompt_extend": prompt_extend,
            "watermark": watermark,
        }
        negative = (negative_prompt or "").strip()
        if negative:
            extra_body["negative_prompt"] = negative
        if seed is not None:
            extra_body["seed"] = seed

        logger.info(
            f"开始生成图片(OpenAI 兼容 images/generations),model={model_image}, n={n_req}, size={size_norm}"
        )

        response = client.images.generate(
            model=model_image,
            prompt=prompt,
            size=size_norm,
            n=n_req,
            extra_body=extra_body,
        )

        image_urls: list[str] = [
            url for url in (getattr(item, "url", None) for item in (response.data or [])) if url
        ]
        if not image_urls:
            return "图片生成完成,但未获取到图片URL"

        logger.info(f"图片生成成功,共 {len(image_urls)} 张图片,开始上传到 OSS")

        total_start_time = time.time()
        oss_urls = _upload_images_to_oss(image_urls)
        total_time = time.time() - total_start_time
        logger.info(f"✅ 所有图片处理完成,总耗时:{total_time:.2f} 秒")

        # Bare URLs, one per paragraph. The header line asks the model to
        # render them as Markdown images.
        result_text = f"图片生成成功!共生成 {len(oss_urls)} 张图片,以下是图片连接,请使用 markdown 格式渲染这些图片。\n\n"
        result_text += "".join(f"{url}\n\n" for url in oss_urls)
        return result_text

    except Exception as e:
        error_msg = f"生成图片时发生错误: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg
|
||
|
||
|
||
from typing import Optional
|
||
from langchain_core.tools import tool
|
||
import logging
|
||
import uuid
|
||
import requests
|
||
from services.oss_service import get_oss_service
|
||
|
||
|
||
def _transfer_video_to_oss(video_url: str) -> Optional[str]:
    """Download a generated video and upload it to OSS.

    Args:
        video_url: Temporary video URL returned by the video API.

    Returns:
        The OSS URL. Returns ``None`` when the OSS service returns a falsy URL
        or any step raises; errors are logged and never propagated.
    """
    try:
        overall_start = time.time()
        response = requests.get(video_url, timeout=300)
        response.raise_for_status()
        video_content = response.content
        download_time = time.time() - overall_start
        logger.info(
            f"视频下载完成,耗时:{download_time:.2f} 秒,大小:{len(video_content) / 1024 / 1024:.2f} MB"
        )

        # Timestamp + short random id keeps object names unique.
        timestamp = int(time.time())
        unique_id = str(uuid.uuid4())[:8]
        oss_object_name = f"videos/{timestamp}_{unique_id}.mp4"

        upload_start = time.time()
        oss_service = get_oss_service()
        oss_url = oss_service.upload_file_from_bytes(
            file_content=video_content,
            oss_object_name=oss_object_name,
            file_name="generated_video.mp4",
        )
        upload_time = time.time() - upload_start
        total_time = time.time() - overall_start

        if oss_url:
            logger.info(f"✅ 视频已上传到 OSS:{oss_url}")
            logger.info(
                f"📊 上传统计 - 下载:{download_time:.2f}s,上传:{upload_time:.2f}s,总计:{total_time:.2f}s"
            )
            return oss_url

        logger.warning("⚠️ OSS 上传失败,使用平台返回的临时视频 URL")
        return None
    except Exception as upload_error:
        logger.error(f"❌ 上传视频到 OSS 失败:{upload_error}", exc_info=True)
        logger.warning("⚠️ 使用原始临时视频 URL")
        return None


@tool
def text_to_video(
    prompt: str,
    negative_prompt: str = "",
    size: str = "1280x720",
    duration: int = 5,
) -> str:
    """
    文生视频工具:根据文本描述生成动态视频。

    通过 OpenAI 兼容接口(``POST /v1/videos`` / ``GET /v1/videos/{id}``)调用万相文生视频模型
    ``wan2.6-t2v``。任务为异步队列:queued → in_progress → completed / failed;
    SDK 使用 ``create_and_poll`` 自动轮询直至结束。
    生成的视频会自动下载并上传到 OSS,返回永久可访问的播放地址。

    使用场景:
    - 用户说「生成一个…的视频」「做一个短视频」「需要动态画面」
    - 产品演示、营销素材、社交媒体短视频脚本可视化

    参数说明:
        prompt: 视频内容描述(中英文均可,建议写清主体、动作、镜头与风格)。
        negative_prompt: 不希望出现的内容;若填写则附加到提示中约束生成方向。
        size: 分辨率,OpenAI 标准 ``宽x高``(也可用 ``宽*高``,会自动转为 ``x``)。
            例如 ``1280x720``、``1920x1080``、``720x1280``(竖屏)等。
        duration: 时长(秒),平台支持 **2–15** 秒,超出范围会自动裁剪到该区间。

    返回值:
        包含 ``<video>`` 的 HTML 片段,可直接播放;优先使用 OSS URL。
        平台返回的原始视频链接通常约 24 小时内有效,请及时依赖 OSS 结果。

    注意事项:
        - 生成耗时常为数十秒至数分钟,请耐心等待轮询完成
        - 提示词越具体,成片越可控
    """
    # English summary (the Chinese docstring above is the runtime tool
    # description and is intentionally left unchanged):
    # 1. Submit a wan2.6-t2v job and poll it every 5 s until it finishes.
    # 2. Re-host the video on OSS, falling back to the temporary URL.
    # 3. Return an HTML <video> snippet.
    # Errors are logged and returned as message strings, never raised.
    try:
        api_key = (os.getenv("ZL_DASHSCOPE_API_KEY") or "").strip()
        base_url = (os.getenv("ZL_DASHSCOPE_API_BASE") or "").strip().rstrip("/")
        # The messages name the variables actually read above (ZL_ prefix).
        if not api_key:
            return "错误:未配置 ZL_DASHSCOPE_API_KEY 环境变量"
        if not base_url:
            return "错误:未配置 ZL_DASHSCOPE_API_BASE 环境变量"

        client = OpenAI(api_key=api_key, base_url=base_url)
        model_video = "wan2.6-t2v"

        size_norm = _normalize_openai_image_size(size)
        seconds = max(2, min(int(duration), 15))  # clamp to the 2–15 s range

        # Negative guidance is appended to the prompt text, because this call
        # sends no separate negative-prompt field.
        negative = (negative_prompt or "").strip()
        full_prompt = f"{prompt}\n(避免出现以下内容:{negative})" if negative else prompt

        # Any size containing "1920" (landscape or portrait) is sent as 1080P;
        # everything else as 720P.
        resolution = "1080P" if "1920" in size_norm else "720P"
        extra_body = {"resolution": resolution, "duration": seconds}

        start = time.time()
        logger.info(
            f"提交文生视频任务(videos API):model={model_video}, size={size_norm}, seconds={seconds}"
        )

        result = client.videos.create_and_poll(
            model=model_video,
            prompt=full_prompt,
            size=size_norm,
            seconds=seconds,
            poll_interval_ms=5000,
            extra_body=extra_body,
        )

        status = getattr(result, "status", None)
        video_url = getattr(result, "url", None) or ""

        if status != "completed" or not video_url:
            err_obj = getattr(result, "error", None)
            err_msg = getattr(err_obj, "message", None) if err_obj else None
            if not err_msg:
                err_msg = str(status) if status else "未知状态"
            logger.error(f"视频任务未成功:status={status}, error={err_msg}")
            return f"视频生成失败:{err_msg}"

        logger.info(f"视频任务完成,耗时 {time.time() - start:.1f}s,开始下载并上传 OSS:{video_url}")

        # Prefer the OSS copy; keep the temporary platform URL if re-hosting fails.
        video_url = _transfer_video_to_oss(video_url) or video_url

        logger.info(f"✅ 视频生成流程结束:{video_url}")
        return f"""<video controls width="100%" style="max-width: 600px;">
<source src="{video_url}" type="video/mp4">
您的浏览器不支持视频播放
</video>"""

    except Exception as e:
        error_msg = f"❌ 生成视频异常:{str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg
|
||
|
||
|
||
@tool
def text_to_poster(
    title: str,
    sub_title: str = "",
    body_text: str = "",
    prompt_text_zh: str = "",
    prompt_text_en: str = "",
    size: str = "2048x2048",
) -> str:
    """
    创意海报生成工具:根据标题、副标题和正文生成海报图。

    与 ``text_to_image`` 相同,通过 OpenAI 兼容接口 ``/v1/images/generations``
    调用通义千问文生图模型生成海报;结果上传到 OSS 后以 Markdown 图片链接返回。
    海报场景默认 **打水印**、开启提示词扩展,并附带画质相关反向提示。

    使用场景:
    - 「生成一张…海报」「活动宣传图」「营销配图」等

    参数说明:
        title: 主标题。
        sub_title: 副标题。
        body_text: 正文(过长时在 prompt 内截取约 200 字摘要)。
        prompt_text_zh / prompt_text_en: 视觉风格描述;均未填则用默认「高质量海报风格」。
        size: 分辨率 ``宽x高``(也可用 ``宽*高``);正方形海报常用 ``2048x2048``,
            亦可选 ``2688x1536``、``1536x2688`` 等与 ``text_to_image`` 一致的取值。

    返回值:
        Markdown 文本:海报图片 URL(OSS)及标题信息。

    注意:
        原始生成 URL 有效期有限,请以 OSS 链接为准。
    """
    # English summary (the Chinese docstring above is the runtime tool
    # description and is intentionally left unchanged):
    # 1. Build a poster prompt from title / subtitle / body / style.
    # 2. Generate one image with watermark and prompt extension enabled.
    # 3. Re-host it on OSS (falling back to the temporary URL) and return the
    #    URL followed by the text fields.
    # Errors are logged and returned as message strings, never raised.
    try:
        api_key = (os.getenv("ZL_DASHSCOPE_API_KEY") or "").strip()
        base_url = (os.getenv("ZL_DASHSCOPE_API_BASE") or "").strip().rstrip("/")
        # The messages name the variables actually read above (ZL_ prefix).
        if not api_key:
            return "错误:未配置 ZL_DASHSCOPE_API_KEY 环境变量"
        if not base_url:
            return "错误:未配置 ZL_DASHSCOPE_API_BASE 环境变量"

        client = OpenAI(api_key=api_key, base_url=base_url)
        model_image = "qwen-image-2.0"

        logger.info(
            f"开始生成创意海报(OpenAI 兼容 images/generations),title: {title}, "
            f"sub_title: {sub_title}, body_text: {(body_text[:50] + '...') if len(body_text) > 50 else body_text}"
        )

        prompt_parts = ["创意海报设计,宣传海报,专业排版,醒目吸睛,文字清晰可读"]
        if title:
            prompt_parts.append(f"主标题:{title}")
        if sub_title:
            prompt_parts.append(f"副标题:{sub_title}")
        if body_text:
            # Keep the prompt bounded: flatten newlines and cap the body at
            # 200 characters, marking truncation with "...".
            body_summary = body_text.replace("\n", " ")[:200]
            if len(body_text) > 200:
                body_summary += "..."
            prompt_parts.append(f"正文内容:{body_summary}")

        # Chinese style description wins over English; otherwise use a default.
        if prompt_text_zh:
            prompt_parts.append(f"视觉风格:{prompt_text_zh}")
        elif prompt_text_en:
            prompt_parts.append(f"Visual style: {prompt_text_en}")
        else:
            prompt_parts.append("精美设计,高质量海报风格")

        prompt = ",".join(prompt_parts)
        logger.info(f"海报生成 prompt: {prompt[:200]}...")

        size_norm = _normalize_openai_image_size(size)
        negative = "低分辨率,低画质,画面模糊,文字扭曲,构图混乱,画面过饱和"
        extra_body: dict = {
            "prompt_extend": True,
            "watermark": True,
            "negative_prompt": negative,
        }

        response = client.images.generate(
            model=model_image,
            prompt=prompt,
            size=size_norm,
            n=1,
            extra_body=extra_body,
        )

        # First non-empty URL in the response, if any.
        image_url = next(
            filter(None, (getattr(item, "url", None) for item in (response.data or []))),
            None,
        )
        if not image_url:
            return "海报生成完成,但未获取到图片URL"

        logger.info(f"海报生成成功,图片URL: {image_url},开始上传到 OSS")

        _, oss_url, _, _ = _download_and_upload_image_to_oss(image_url, 1)
        final_url = oss_url or image_url

        result_text = f"创意海报生成成功!\n\n{final_url}\n\n"
        result_text += f"**标题**:{title}\n"
        if sub_title:
            result_text += f"**副标题**:{sub_title}\n"
        if body_text:
            result_text += f"**正文**:{body_text}\n"

        logger.info("✅ 海报生成完成")
        return result_text

    except Exception as e:
        error_msg = f"❌ 生成海报异常:{str(e)}"
        logger.error(error_msg, exc_info=True)
        return error_msg
|
||
|
||
|
||
def create_rag_retrieve_tool(thread_id: str):
    """
    Create a RAG retrieval tool over the files uploaded to one conversation.

    Args:
        thread_id: Conversation thread ID whose files are searched.

    Returns:
        tool: A LangChain tool declared with
        ``response_format="content_and_artifact"``. It returns
        ``(formatted_text, raw_results)``.
    """
    vector_service = get_vector_service()

    @tool(response_format="content_and_artifact")
    def retrieve_context_from_files(query: str):
        """
        从用户上传的文件中检索相关信息来帮助回答问题。

        当用户的问题涉及到上传的文件内容时,使用此工具检索相关文档片段。
        例如:用户上传了PDF文件后,询问文件中的具体内容、数据、概念等。

        Args:
            query: 用户的查询问题

        Returns:
            tuple: (检索到的文档内容字符串, 检索结果列表)
        """
        # The Chinese docstring above is the runtime tool description.
        try:
            # Up to 5 most similar chunks from this thread's files.
            results = vector_service.search_similar_in_thread(
                thread_id=thread_id,
                query=query,
                k=5,
            )
            if not results:
                return "未在文件中找到相关信息。", []

            # Format each hit as "[文档片段 i] (来源: ..., 页码: ...)\n<content>".
            content_parts = []
            for idx, result in enumerate(results, 1):
                content = result.get("content", "")
                metadata = result.get("metadata", {})

                source_info = []
                if metadata:
                    if "source" in metadata:
                        source_info.append(f"来源: {metadata['source']}")
                    if "page" in metadata:
                        source_info.append(f"页码: {metadata['page']}")
                source_str = f" ({', '.join(source_info)})" if source_info else ""

                content_parts.append(
                    f"[文档片段 {idx}]{source_str}\n{content}\n"
                )

            # Text goes to the model; the raw results are the tool artifact.
            return "\n".join(content_parts), results

        except Exception as e:
            logger.error(f"RAG 检索失败: {e}", exc_info=True)
            return f"检索文件内容时出错: {str(e)}", []

    return retrieve_context_from_files
|
||
|
||
|
||
def create_kb_rag_retrieve_tool(knowledge_base_id: int):
    """
    Create a RAG retrieval tool over one knowledge base.

    Args:
        knowledge_base_id: ID of the knowledge base to search.

    Returns:
        tool: A LangChain tool declared with
        ``response_format="content_and_artifact"``. It returns
        ``(formatted_text, raw_results)``.
    """
    vector_service = get_vector_service()

    @tool(response_format="content_and_artifact")
    def retrieve_context_from_knowledge_base(query: str):
        """
        从知识库中检索相关信息来帮助回答问题。

        当用户的问题涉及到知识库中的内容时,使用此工具检索相关文档片段。
        知识库包含了用户预先上传和整理的文件内容。

        Args:
            query: 用户的查询问题

        Returns:
            tuple: (检索到的文档内容字符串, 检索结果列表)
        """
        # The Chinese docstring above is the runtime tool description.
        try:
            # Up to 5 most similar chunks from this knowledge base.
            results = vector_service.search_similar(
                knowledge_base_id=knowledge_base_id,
                query=query,
                k=5,
            )
            if not results:
                return "未在知识库中找到相关信息。", []

            # Format each hit as "[知识库文档片段 i] (来源: ..., 页码: ...)\n<content>".
            content_parts = []
            for idx, result in enumerate(results, 1):
                content = result.get("content", "")
                metadata = result.get("metadata", {})

                source_info = []
                if metadata:
                    if "source" in metadata:
                        source_info.append(f"来源: {metadata['source']}")
                    if "page" in metadata:
                        source_info.append(f"页码: {metadata['page']}")
                source_str = f" ({', '.join(source_info)})" if source_info else ""

                content_parts.append(
                    f"[知识库文档片段 {idx}]{source_str}\n{content}\n"
                )

            # Text goes to the model; the raw results are the tool artifact.
            return "\n".join(content_parts), results

        except Exception as e:
            logger.error(f"知识库 RAG 检索失败: {e}", exc_info=True)
            return f"检索知识库内容时出错: {str(e)}", []

    return retrieve_context_from_knowledge_base
|
||
|
||
|
||
def create_knowledge_graph_rag_retrieve_tool(knowledge_graph_pk: int):
    """
    Create a passage-level vector retrieval tool bound to one knowledge graph.

    It complements the Neo4j entity/relationship tool: this one searches the
    source text attached to the knowledge graph, not the graph structure.

    Args:
        knowledge_graph_pk: Primary key of the knowledge graph record.

    Returns:
        tool: A LangChain tool declared with
        ``response_format="content_and_artifact"``. It returns
        ``(formatted_text, raw_results)``.
    """
    vector_service = get_vector_service()

    @tool(response_format="content_and_artifact")
    def retrieve_context_from_knowledge_graph(query: str):
        """
        从用户选中的知识图谱资料正文中检索相关片段,用于回答细节、原文依据等问题。

        当问题涉及资料内容、叙述、对话、描写而非仅关系网络时,应使用本工具。

        Args:
            query: 检索查询(可与用户问题同义改写)

        Returns:
            tuple: (检索到的文本片段拼接字符串, 检索结果列表)
        """
        # The Chinese docstring above is the runtime tool description.
        try:
            results = vector_service.search_similar_knowledge_graph(
                knowledge_graph_pk=knowledge_graph_pk,
                query=query,
                k=5,
            )
            if not results:
                return "未在该知识图谱资料正文中找到相关片段。", []

            content_parts = []
            for idx, result in enumerate(results, 1):
                content = result.get("content", "")
                metadata = result.get("metadata", {}) or {}
                chunk_i = metadata.get("chunk_index", "")
                prefix = f"[资料原文片段 {idx}]"
                # Only show the chunk marker when a real index is present, so a
                # null value does not render as "#None".
                if chunk_i is not None and chunk_i != "":
                    prefix += f" (块 #{chunk_i})"
                content_parts.append(f"{prefix}\n{content}\n")

            return "\n".join(content_parts), results

        except Exception as e:
            logger.error(f"知识图谱 RAG 检索失败: {e}", exc_info=True)
            return f"检索资料正文时出错: {str(e)}", []

    return retrieve_context_from_knowledge_graph
|
||
|
||
|
||
def _format_knowledge_graph_neo4j_result(result: dict) -> str:
|
||
"""将 Neo4j search_knowledge_graph 的返回结果转为给模型阅读的文本。"""
|
||
msg = result.get("message")
|
||
if msg:
|
||
return msg
|
||
seeds = result.get("seeds") or []
|
||
elements = result.get("elements") or []
|
||
if not seeds and not elements:
|
||
return "未在知识图谱中找到与关键词匹配的实体或关系。"
|
||
lines: list[str] = []
|
||
if seeds:
|
||
lines.append(f"关键词命中的实体: {', '.join(seeds)}")
|
||
edges: list[str] = []
|
||
for el in elements:
|
||
d = el.get("data") or {}
|
||
if "source" not in d:
|
||
continue
|
||
rel = (d.get("label") or d.get("type") or "关系").strip()
|
||
note = (d.get("note") or "").strip()
|
||
suf = f"(备注: {note})" if note else ""
|
||
edges.append(f"- {d['source']} —[{rel}]→ {d['target']}{suf}")
|
||
if edges:
|
||
lines.append("关系边(来自图数据库,Person/RELATION):")
|
||
lines.extend(edges[:100])
|
||
elif seeds:
|
||
lines.append("已命中实体,但未检索到相连的关系边;可尝试增大 hops 或更换关键词。")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def create_knowledge_graph_neo4j_search_tool(neo4j_graph_id: str):
    """
    Build a Neo4j entity/relationship lookup tool bound to a single graph.

    Complements the passage-level vector retrieval tool. The returned tool
    calls ``services.neo4j_service.search_knowledge_graph`` and renders its
    result with ``_format_knowledge_graph_neo4j_result``.

    Args:
        neo4j_graph_id: Identifier of the graph the returned tool queries.

    Returns:
        The ``query_knowledge_graph_relations`` LangChain tool.
    """
    # Imported inside the factory, so the Neo4j service module is only
    # loaded when this factory is called.
    from services.neo4j_service import search_knowledge_graph

    @tool
    def query_knowledge_graph_relations(entity_keyword: str, hops: int = 2) -> str:
        """
        在当前绑定的知识图谱(Neo4j)中按关键词查找人物/实体,并返回其关联关系。

        当用户询问「某人是谁」「某人和谁的关系」「亲属/子女/上下级/合作」等**实体关系**时,应优先使用本工具。
        entity_keyword 为人名或实体名(可只填部分,如姓氏或名);若无结果可换关键词再试。
        hops 为关系扩展深度:1 仅直接关系,2 为两跳内(默认),最大 3。

        Args:
            entity_keyword: 要查找的实体关键词
            hops: 关系跳数,1–3
        """
        # The Chinese docstring above is the runtime tool description.
        # Failures (including a non-numeric hops) are logged and returned as text.
        try:
            keyword = (entity_keyword or "").strip()
            if keyword:
                depth = min(3, max(1, int(hops)))  # clamp hops into 1..3
                raw = search_knowledge_graph(neo4j_graph_id, keyword, hops=depth)
                return _format_knowledge_graph_neo4j_result(raw)
            return "请提供非空的实体关键词。"
        except Exception as exc:
            logger.error(f"知识图谱 Neo4j 查询失败: {exc}", exc_info=True)
            return f"知识图谱关系查询失败: {exc}"

    return query_knowledge_graph_relations
|