huoyan-enterprise/backend/tools/tools.py

821 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
工具模块
定义各种 AI 工具函数包括网络搜索、文生图、文生视频、RAG 检索等。
"""
import time
import uuid
import requests
from typing import Literal, Optional
from http import HTTPStatus
from langchain.tools import tool
from dashscope import ImageSynthesis, VideoSynthesis
import dashscope
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
from tavily import TavilyClient
from core.config import settings
from core import llm_env
from utils.datetime_utils import format_beijing_time_for_agent
from logger.logging import get_logger
from services.vector_service import get_vector_service
from services.oss_service import get_oss_service
# 获取日志记录器
logger = get_logger(__name__)
def _dashscope_http_api_base() -> str:
"""``dashscope`` 原生 SDK 使用的 HTTP 根路径(与 OpenAI 兼容 ``DASHSCOPE_API_BASE`` 可能不同)。"""
return llm_env.dashscope_native_http_api_base().strip().rstrip("/")
# 初始化 Tavily 客户端
tavily_client = TavilyClient(api_key=settings.tavily_api_key)
@tool
def get_current_time() -> str:
"""
获取当前中国北京时间(东八区)。
当用户询问现在几点、今天日期、星期几、或需要当前时间作为参考时调用此工具。
"""
return format_beijing_time_for_agent()
def internet_search(
query: str,
max_results: int = 5,
topic: Literal["general", "news", "finance"] = "general",
include_raw_content: bool = False,
):
"""Run a web search"""
return tavily_client.search(
query,
max_results=max_results,
include_raw_content=include_raw_content,
topic=topic,
)
def _download_and_upload_image_to_oss(image_url: str, image_index: int) -> tuple[str, Optional[str], float, float]:
"""
下载单张图片并上传到 OSS
Args:
image_url: 原始图片 URL
image_index: 图片索引(用于日志)
Returns:
tuple: (原始URL, OSS URL, 下载耗时, 上传耗时)
"""
upload_start_time = time.time()
try:
logger.info(f"开始下载图片 {image_index}{image_url}")
# 下载图片内容
download_start = time.time()
response = requests.get(image_url, timeout=300) # 5分钟超时
response.raise_for_status()
image_content = response.content
download_time = time.time() - download_start
logger.info(f"图片 {image_index} 下载完成,耗时:{download_time:.2f} 秒,大小:{len(image_content) / 1024 / 1024:.2f} MB")
# 生成 OSS 对象名称
timestamp = int(time.time())
unique_id = str(uuid.uuid4())[:8]
# 根据图片内容判断文件扩展名
content_type = response.headers.get('Content-Type', 'image/png')
if 'jpeg' in content_type or 'jpg' in content_type:
ext = 'jpg'
elif 'png' in content_type:
ext = 'png'
elif 'webp' in content_type:
ext = 'webp'
else:
ext = 'png' # 默认使用 png
oss_object_name = f"images/{timestamp}_{unique_id}_{image_index}.{ext}"
# 上传到 OSS
upload_start = time.time()
oss_service = get_oss_service()
oss_url = oss_service.upload_file_from_bytes(
file_content=image_content,
oss_object_name=oss_object_name,
file_name=f"generated_image_{image_index}.{ext}"
)
upload_time = time.time() - upload_start
total_time = time.time() - upload_start_time
if oss_url:
logger.info(f"✅ 图片 {image_index} 已上传到 OSS{oss_url}")
logger.info(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f}")
return image_url, oss_url, download_time, upload_time
else:
logger.warning(f"⚠️ 图片 {image_index} OSS 上传失败,使用原始 URL")
logger.warning(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒(上传失败)")
return image_url, None, download_time, upload_time
except Exception as e:
error_msg = f"❌ 图片 {image_index} 上传到 OSS 失败:{str(e)}"
logger.error(error_msg, exc_info=True)
return image_url, None, 0.0, 0.0
@tool
def text_to_image(
prompt: str,
negative_prompt: str = "",
size: str = "1280*720",
n: int = 1,
)->str:
"""
文生图工具:根据文本描述生成高质量图片。
当用户需要生成图片、创建图像、制作插图、设计视觉内容时,使用此工具。
该工具使用阿里云百炼平台的 AI 图像生成模型,可以根据文字描述生成相应的图片。
生成的图片会自动上传到 OSS 存储,返回永久可访问的 URL。
使用场景:
- 用户说"生成一张...的图片""画一个...""创建...的图像"
- 需要为文章、演示文稿、社交媒体创建配图
- 用户想要可视化某个概念、场景、物体或人物
- 需要生成多个不同风格的图片供选择
参数说明:
prompt (必需): 详细描述想要生成的图片内容。应该包含:
- 主体对象(人物、动物、物品等)
- 场景和环境(背景、地点、氛围)
- 风格和艺术效果(写实、卡通、油画、水彩等)
- 颜色和光线(明亮、昏暗、暖色调等)
- 构图和视角(正面、侧面、俯视、特写等)
示例:"一只可爱的橘色小猫坐在窗台上,阳光透过窗户洒在它身上,背景是温馨的客厅,写实风格"
negative_prompt (可选): 描述不希望在图片中出现的内容,用于排除不想要的元素。
示例:"模糊,低质量,文字,水印,变形,多余的手指"
size (可选): 图片尺寸,格式为 "宽*高"。支持的官方尺寸:
- "1280*1280" - 1:1 正方形(适合头像、图标、社交媒体头像)
- "800*1200" - 2:3 竖向(适合手机壁纸、竖版海报)
- "1200*800" - 3:2 横向(适合横向展示、横幅)
- "960*1280" - 3:4 竖向(适合手机屏幕、竖版内容)
- "1280*960" - 4:3 横向(适合传统显示器比例、横版内容)
- "720*1280" - 9:16 竖向(适合手机竖屏、短视频封面)
- "1280*720" - 16:9 横向(默认,适合宽屏显示器、视频封面、网页横幅)
- "1344*576" - 21:9 超宽屏(适合电影比例、超宽屏展示)
默认值:"1280*720"
n (可选): 生成图片的数量,范围 1-4。生成多张图片时会并行处理以提高效率。
默认值1
返回值:
返回包含图片的 Markdown 格式字符串,图片会自动显示在对话中。
如果生成多张图片,会按顺序展示所有图片。
注意事项:
- 生成图片需要一定时间,请耐心等待
- 提示词越详细,生成的图片质量越好
- 生成多张图片时,总耗时会更长,但会并行处理以提高效率
- 如果用户没有明确指定尺寸,使用默认尺寸即可
"""
try:
api_key = settings.dashscope_api_key
if not api_key:
return "错误:未配置 DASHSCOPE_API_KEY 环境变量"
dashscope.base_http_api_url = _dashscope_http_api_base()
logger.info(f"开始生成图片prompt: {prompt}, n={n}")
# 创建异步任务
rsp = ImageSynthesis.call(api_key=api_key,
model="wan2.2-t2i-flash",
prompt=prompt,
n=n,
size=size,
negative_prompt=negative_prompt,
prompt_extend=True,
watermark=True)
print(f'response: {rsp}')
if rsp.status_code != HTTPStatus.OK:
print(f'同步调用失败, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}')
return "图片生成失败"
# 提取图片 URL
image_urls = []
if rsp.output and rsp.output.results:
for result in rsp.output.results:
if hasattr(result, 'url') and result.url:
image_urls.append(result.url)
if not image_urls:
return "图片生成完成但未获取到图片URL"
logger.info(f"图片生成成功,共 {len(image_urls)} 张图片,开始上传到 OSS")
# 使用多线程下载并上传图片到 OSS
oss_urls = []
total_start_time = time.time()
if len(image_urls) == 1:
# 单张图片,直接处理
_, oss_url, _, _ = _download_and_upload_image_to_oss(image_urls[0], 1)
oss_urls.append(oss_url if oss_url else image_urls[0])
else:
# 多张图片,使用多线程并行处理
with ThreadPoolExecutor(max_workers=min(len(image_urls), 5)) as executor:
# 提交所有任务
future_to_index = {
executor.submit(_download_and_upload_image_to_oss, url, idx + 1): idx
for idx, url in enumerate(image_urls)
}
# 收集结果(保持顺序)
results = [None] * len(image_urls)
for future in as_completed(future_to_index):
idx = future_to_index[future]
try:
results[idx] = future.result()
except Exception as e:
logger.error(f"图片 {idx + 1} 处理异常:{e}", exc_info=True)
results[idx] = (image_urls[idx], None, 0.0, 0.0)
# 按顺序提取 OSS URL
for original_url, oss_url, _, _ in results:
oss_urls.append(oss_url if oss_url else original_url)
total_time = time.time() - total_start_time
logger.info(f"✅ 所有图片处理完成,总耗时:{total_time:.2f}")
# 构建返回信息(使用 markdown 格式以便前端正确显示图片)
result_text = f"图片生成成功!共生成 {len(oss_urls)} 张图片,以下是图片连接,请使用 markdown 格式渲染这些图片。\n\n"
for idx, url in enumerate(oss_urls, 1):
# 使用 markdown 图片语法,这样前端可以正确渲染
result_text += f"{url}\n\n"
return result_text
except Exception as e:
error_msg = f"生成图片时发生错误: {str(e)}"
logger.error(error_msg, exc_info=True)
return error_msg
from typing import Optional
from langchain_core.tools import tool
import os
import logging
import uuid
import requests
from dashscope import VideoSynthesis
from http import HTTPStatus
from services.oss_service import get_oss_service
@tool
def text_to_video(
prompt: str,
negative_prompt: str = "",
size: str = "832*480",
duration: int = 5,
) -> str:
"""
文生视频工具:根据文本描述生成动态视频。
当用户需要生成视频、创建动画、制作短视频、需要动态视觉内容时,使用此工具。
该工具使用阿里云百炼平台的 AI 视频生成模型,可以根据文字描述生成相应的视频。
生成的视频会自动上传到 OSS 存储,返回永久可访问的 URL。
使用场景:
- 用户说"生成一个...的视频""创建一个...的动画""制作...的短视频"
- 需要为产品演示、营销推广创建动态视频内容
- 社交媒体短视频内容生成(抖音、快手、小红书等)
- 需要展示动态场景、运动过程、变化效果
- 用户想要可视化动态概念或过程
参数说明:
prompt (必需): 详细描述想要生成的视频内容。应该包含:
- 主体对象和动作(什么在做什么)
- 场景和环境(背景、地点、氛围)
- 运动方式和动态效果(移动、旋转、变化等)
- 风格和视觉效果(写实、动画、电影感等)
- 颜色和光线(明亮、昏暗、暖色调等)
示例:"一只橘色小猫在窗台上玩耍,阳光透过窗户洒在它身上,它好奇地看向窗外,背景是温馨的客厅,写实风格,画面流畅自然"
negative_prompt (可选): 描述不希望在视频中出现的内容,用于排除不想要的元素。
示例:"模糊,低质量,文字,水印,画面抖动,不自然的运动,变形"
size (可选): 视频尺寸,格式为 "宽*高"。支持的尺寸:
- "832*480" - 标准横向(默认,适合通用视频)
- "1280*720" - 高清横向(适合高质量视频)
- "720*1280" - 竖向(适合手机竖屏视频、短视频平台)
默认值:"832*480"
duration (可选): 视频时长,单位为秒。支持的时长:
- 5 秒(默认,适合短视频)
- 10 秒(适合中等长度视频)
- 15 秒(适合较长视频)
默认值5
返回值:
返回包含视频的 HTML 格式字符串,视频会自动显示在对话中,用户可以直接播放。
视频已上传到 OSS返回的是永久可访问的 URL。
注意事项:
- 视频生成需要较长时间(通常需要几十秒到几分钟),请耐心等待
- 提示词越详细,生成的视频质量越好
- 视频生成后会自动下载并上传到 OSS这个过程可能需要额外时间
- 如果用户没有明确指定尺寸和时长,使用默认值即可
- 视频生成是异步过程,完成后会返回可播放的视频链接
"""
try:
# 地域与 ``DASHSCOPE_API_BASE`` 一致新加坡等请改环境变量https://dashscope-intl.aliyuncs.com/api/v1
dashscope.base_http_api_url = _dashscope_http_api_base()
api_key = settings.dashscope_api_key
# call sync api, will return the result
start = time.time()
print('开始时间-->',time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
rsp = VideoSynthesis.call(api_key=api_key,
model='wan2.2-t2v-plus',
prompt=prompt,
size="832*480",
duration=5,
negative_prompt=negative_prompt,
# audio=True,
prompt_extend=True,
watermark=True)
print("请求结果:",rsp)
video_url = ""
result = ""
if rsp.status_code == HTTPStatus.OK:
print("请求链接地址:",rsp.output.video_url)
print("结束时间-->",time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
print("耗时-->",time.time()-start,"")
video_url = rsp.output.video_url
# 下载视频并上传到 OSS
try:
upload_start_time = time.time()
logger.info(f"开始下载视频:{video_url}")
# 下载视频内容
download_start = time.time()
response = requests.get(video_url, timeout=300) # 5分钟超时
response.raise_for_status()
video_content = response.content
download_time = time.time() - download_start
logger.info(f"视频下载完成,耗时:{download_time:.2f} 秒,大小:{len(video_content) / 1024 / 1024:.2f} MB")
# 生成 OSS 对象名称
timestamp = int(time.time())
unique_id = str(uuid.uuid4())[:8]
oss_object_name = f"videos/{timestamp}_{unique_id}.mp4"
# 上传到 OSS
upload_start = time.time()
oss_service = get_oss_service()
oss_url = oss_service.upload_file_from_bytes(
file_content=video_content,
oss_object_name=oss_object_name,
file_name="generated_video.mp4"
)
upload_time = time.time() - upload_start
# 计算总耗时
total_time = time.time() - upload_start_time
if oss_url:
logger.info(f"✅ 视频已上传到 OSS{oss_url}")
logger.info(f"📊 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f}")
# 使用 OSS URL 替换临时 URL
video_url = oss_url
else:
logger.warning("⚠️ OSS 上传失败,使用原始临时 URL")
logger.warning(f"📊 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒(上传失败)")
# 如果 OSS 上传失败,继续使用原始 URL
except Exception as upload_error:
logger.error(f"❌ 上传视频到 OSS 失败:{upload_error}", exc_info=True)
# 如果上传失败,继续使用原始临时 URL
logger.warning("⚠️ 使用原始临时视频 URL")
result = f"""<video controls width="100%" style="max-width: 600px;">
<source src="{video_url}" type="video/mp4">
您的浏览器不支持视频播放
</video>"""
else:
print('视频请求失败, status_code: %s, code: %s, message: %s' %
(rsp.status_code, rsp.code, rsp.message))
result = "视频生成失败"
logger.info(f"✅ 视频生成完成:{video_url}")
return result
except Exception as e:
error_msg = f"❌ 生成视频异常:{str(e)}"
logger.error(error_msg, exc_info=True)
return error_msg
@tool
def text_to_poster(
title: str,
sub_title: str = "",
body_text: str = "",
prompt_text_zh: str = "",
prompt_text_en: str = "",
size: str = "1280*1280",
) -> str:
"""
创意海报生成工具:根据标题、副标题和正文内容生成创意海报。
当用户需要生成海报、创建宣传图、制作营销图片、需要创意设计时,使用此工具。
该工具使用阿里云百炼平台的文生图(万相)模型,根据海报文案生成相应的创意海报图片。
生成的海报会自动上传到 OSS 存储,返回永久可访问的 URL。海报右下角带有「AI生成」水印。
使用场景:
- 用户说"生成一张...的海报""创建一个...的宣传图""制作...的创意海报"
- 需要为活动、产品、品牌创建宣传海报
- 社交媒体营销图片生成(微信、微博、小红书等)
- 需要展示标题、副标题和正文内容的创意设计
- 用户想要可视化某个主题或概念的海报
参数说明:
title (必需): 海报的主标题。应该简洁有力,能够吸引注意力。
示例:"春季新品发布""限时优惠活动""品牌宣传"
sub_title (可选): 海报的副标题。用于补充说明主标题或提供更多信息。
示例:"全场8折起""限时3天""专业团队打造"
body_text (可选): 海报的正文内容。可以包含详细说明、活动规则、联系方式等。
示例:"活动时间2024年3月1日-3月31日\n活动地点:全国门店\n咨询热线400-xxx-xxxx"
prompt_text_zh (可选): 中文提示文本,用于描述海报的视觉风格和设计元素。
示例:"小朋友画的可爱的龙,白色背景""温馨的节日氛围,红色和金色主题"
如果未提供,将根据标题和副标题自动生成。
prompt_text_en (可选): 英文提示文本,用于描述海报的视觉风格和设计元素。
示例:"Children draw a lovely dragon, white background""Warm festive atmosphere, red and gold theme"
如果未提供,将根据标题和副标题自动生成。
注意prompt_text_zh 和 prompt_text_en 至少需要设置其中一个。
size string (可选)
输出图像的分辨率,格式为宽*高。
默认值为 1280*1280。
总像素在 [1280*1280, 1440*1440] 之间且宽高比范围为 [1:4, 4:1]。例如768*2700符合要求。
示例值1280*1280。
常见比例推荐的分辨率
1:11280*1280
3:41104*1472
4:31472*1104
9:16960*1696
16:91696*960
返回值:
返回包含海报图片的 Markdown 格式字符串,海报会自动显示在对话中。
注意事项:
- 生成海报需要一定时间,请耐心等待
- 标题、副标题和正文内容越清晰,生成的海报质量越好
- 生成的海报带有 AI 水印标识
"""
try:
api_key = settings.dashscope_api_key
if not api_key:
return "错误:未配置 DASHSCOPE_API_KEY 环境变量"
dashscope.base_http_api_url = _dashscope_http_api_base()
logger.info(f"开始生成创意海报title: {title}, sub_title: {sub_title}, body_text: {(body_text[:50] + '...') if len(body_text) > 50 else body_text}")
# 构建海报专用 prompt将标题、副标题、正文与视觉风格融合为文生图提示词
prompt_parts = ["创意海报设计,宣传海报,专业排版,醒目吸睛"]
if title:
prompt_parts.append(f"主标题:{title}")
if sub_title:
prompt_parts.append(f"副标题:{sub_title}")
if body_text:
# 正文可能较长,截取关键信息(限制约 200 字符)
body_summary = body_text.replace("\n", " ")[:200]
if len(body_text) > 200:
body_summary += "..."
prompt_parts.append(f"正文内容:{body_summary}")
# 视觉风格:优先使用用户提供的提示词
if prompt_text_zh:
prompt_parts.append(f"视觉风格:{prompt_text_zh}")
elif prompt_text_en:
prompt_parts.append(f"Visual style: {prompt_text_en}")
else:
prompt_parts.append("精美设计,高质量海报风格")
prompt = "".join(prompt_parts)
logger.info(f"海报生成 prompt: {prompt[:200]}...")
# 使用与 text_to_image 相同的文生图 API万相 wan2.2-t2i-flash
# 海报需要水印,故 watermark=True
rsp = ImageSynthesis.call(
api_key=api_key,
model="wan2.5-t2i-preview",
prompt=prompt,
n=1,
size=size,
negative_prompt="低分辨率,低画质,画面模糊,文字扭曲,构图混乱,画面过饱和",
prompt_extend=True,
watermark=True,
)
if rsp.status_code != HTTPStatus.OK:
logger.error(f"海报生成失败, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}")
return f"海报生成失败:{rsp.message or '请稍后重试'}"
image_urls = []
if rsp.output and rsp.output.results:
for result in rsp.output.results:
if hasattr(result, 'url') and result.url:
image_urls.append(result.url)
if not image_urls:
return "海报生成完成但未获取到图片URL"
image_url = image_urls[0]
logger.info(f"海报生成成功图片URL: {image_url},开始上传到 OSS")
# 复用 text_to_image 的 OSS 上传逻辑
_, oss_url, _, _ = _download_and_upload_image_to_oss(image_url, 1)
final_url = oss_url if oss_url else image_url
result_text = f"创意海报生成成功!\n\n{final_url}\n\n"
result_text += f"**标题**{title}\n"
if sub_title:
result_text += f"**副标题**{sub_title}\n"
if body_text:
result_text += f"**正文**{body_text}\n"
logger.info("✅ 海报生成完成")
return result_text
except Exception as e:
error_msg = f"❌ 生成海报异常:{str(e)}"
logger.error(error_msg, exc_info=True)
return error_msg
def create_rag_retrieve_tool(thread_id: str):
"""
创建 RAG 检索工具(用于对话文件)
Args:
thread_id: 会话线程 ID
Returns:
tool: RAG 检索工具
"""
vector_service = get_vector_service()
@tool(response_format="content_and_artifact")
def retrieve_context_from_files(query: str):
"""
从用户上传的文件中检索相关信息来帮助回答问题。
当用户的问题涉及到上传的文件内容时,使用此工具检索相关文档片段。
例如用户上传了PDF文件后询问文件中的具体内容、数据、概念等。
Args:
query: 用户的查询问题
Returns:
tuple: (检索到的文档内容字符串, 检索结果列表)
"""
try:
# 使用向量服务搜索相似文档
results = vector_service.search_similar_in_thread(
thread_id=thread_id,
query=query,
k=5 # 返回最相关的5个文档片段
)
if not results:
return "未在文件中找到相关信息。", []
# 格式化检索结果
content_parts = []
for idx, result in enumerate(results, 1):
content = result.get("content", "")
metadata = result.get("metadata", {})
score = result.get("score", 0)
# 构建来源信息
source_info = []
if metadata:
if "source" in metadata:
source_info.append(f"来源: {metadata['source']}")
if "page" in metadata:
source_info.append(f"页码: {metadata['page']}")
source_str = f" ({', '.join(source_info)})" if source_info else ""
content_parts.append(
f"[文档片段 {idx}]{source_str}\n{content}\n"
)
content = "\n".join(content_parts)
return content, results
except Exception as e:
logger.error(f"RAG 检索失败: {e}")
return f"检索文件内容时出错: {str(e)}", []
return retrieve_context_from_files
def create_kb_rag_retrieve_tool(knowledge_base_id: int):
"""
创建知识库 RAG 检索工具
Args:
knowledge_base_id: 知识库 ID
Returns:
tool: 知识库 RAG 检索工具
"""
vector_service = get_vector_service()
@tool(response_format="content_and_artifact")
def retrieve_context_from_knowledge_base(query: str):
"""
从知识库中检索相关信息来帮助回答问题。
当用户的问题涉及到知识库中的内容时,使用此工具检索相关文档片段。
知识库包含了用户预先上传和整理的文件内容。
Args:
query: 用户的查询问题
Returns:
tuple: (检索到的文档内容字符串, 检索结果列表)
"""
try:
# 使用向量服务搜索知识库中的相似文档
results = vector_service.search_similar(
knowledge_base_id=knowledge_base_id,
query=query,
k=5 # 返回最相关的5个文档片段
)
if not results:
return "未在知识库中找到相关信息。", []
# 格式化检索结果
content_parts = []
for idx, result in enumerate(results, 1):
content = result.get("content", "")
metadata = result.get("metadata", {})
score = result.get("score", 0)
# 构建来源信息
source_info = []
if metadata:
if "source" in metadata:
source_info.append(f"来源: {metadata['source']}")
if "page" in metadata:
source_info.append(f"页码: {metadata['page']}")
source_str = f" ({', '.join(source_info)})" if source_info else ""
content_parts.append(
f"[知识库文档片段 {idx}]{source_str}\n{content}\n"
)
content = "\n".join(content_parts)
return content, results
except Exception as e:
logger.error(f"知识库 RAG 检索失败: {e}")
return f"检索知识库内容时出错: {str(e)}", []
return retrieve_context_from_knowledge_base
def create_knowledge_graph_rag_retrieve_tool(knowledge_graph_pk: int):
"""
创建「知识图谱」绑定的正文向量检索工具(与 Neo4j 实体关系互补)。
"""
vector_service = get_vector_service()
@tool(response_format="content_and_artifact")
def retrieve_context_from_knowledge_graph(query: str):
"""
从用户选中的知识图谱资料正文中检索相关片段,用于回答细节、原文依据等问题。
当问题涉及资料内容、叙述、对话、描写而非仅关系网络时,应使用本工具。
Args:
query: 检索查询(可与用户问题同义改写)
Returns:
tuple: (检索到的文本片段拼接字符串, 检索结果列表)
"""
try:
results = vector_service.search_similar_knowledge_graph(
knowledge_graph_pk=knowledge_graph_pk,
query=query,
k=5,
)
if not results:
return "未在该知识图谱资料正文中找到相关片段。", []
content_parts = []
for idx, result in enumerate(results, 1):
content = result.get("content", "")
metadata = result.get("metadata", {}) or {}
chunk_i = metadata.get("chunk_index", "")
prefix = f"[资料原文片段 {idx}]"
if chunk_i != "":
prefix += f" (块 #{chunk_i})"
content_parts.append(f"{prefix}\n{content}\n")
return "\n".join(content_parts), results
except Exception as e:
logger.error(f"知识图谱 RAG 检索失败: {e}")
return f"检索资料正文时出错: {str(e)}", []
return retrieve_context_from_knowledge_graph
def _format_knowledge_graph_neo4j_result(result: dict) -> str:
"""将 Neo4j search_knowledge_graph 的返回结果转为给模型阅读的文本。"""
msg = result.get("message")
if msg:
return msg
seeds = result.get("seeds") or []
elements = result.get("elements") or []
if not seeds and not elements:
return "未在知识图谱中找到与关键词匹配的实体或关系。"
lines: list[str] = []
if seeds:
lines.append(f"关键词命中的实体: {', '.join(seeds)}")
edges: list[str] = []
for el in elements:
d = el.get("data") or {}
if "source" not in d:
continue
rel = (d.get("label") or d.get("type") or "关系").strip()
note = (d.get("note") or "").strip()
suf = f"(备注: {note}" if note else ""
edges.append(f"- {d['source']} —[{rel}]→ {d['target']}{suf}")
if edges:
lines.append("关系边来自图数据库Person/RELATION:")
lines.extend(edges[:100])
elif seeds:
lines.append("已命中实体,但未检索到相连的关系边;可尝试增大 hops 或更换关键词。")
return "\n".join(lines)
def create_knowledge_graph_neo4j_search_tool(neo4j_graph_id: str):
"""
创建基于 Neo4j 的实体/关系查询工具(与正文向量检索互补)。
"""
from services.neo4j_service import search_knowledge_graph
@tool
def query_knowledge_graph_relations(entity_keyword: str, hops: int = 2) -> str:
"""
在当前绑定的知识图谱Neo4j中按关键词查找人物/实体,并返回其关联关系。
当用户询问「某人是谁」「某人和谁的关系」「亲属/子女/上下级/合作」等**实体关系**时,应优先使用本工具。
entity_keyword 为人名或实体名(可只填部分,如姓氏或名);若无结果可换关键词再试。
hops 为关系扩展深度1 仅直接关系2 为两跳内(默认),最大 3。
Args:
entity_keyword: 要查找的实体关键词
hops: 关系跳数13
"""
try:
kw = (entity_keyword or "").strip()
if not kw:
return "请提供非空的实体关键词。"
h = max(1, min(int(hops), 3))
result = search_knowledge_graph(neo4j_graph_id, kw, hops=h)
return _format_knowledge_graph_neo4j_result(result)
except Exception as e:
logger.error(f"知识图谱 Neo4j 查询失败: {e}", exc_info=True)
return f"知识图谱关系查询失败: {e}"
return query_knowledge_graph_relations