821 lines
34 KiB
Python
821 lines
34 KiB
Python
"""
|
||
工具模块
|
||
|
||
定义各种 AI 工具函数,包括网络搜索、文生图、文生视频、RAG 检索等。
|
||
"""
|
||
import time
|
||
import uuid
|
||
import requests
|
||
from typing import Literal, Optional
|
||
from http import HTTPStatus
|
||
from langchain.tools import tool
|
||
from dashscope import ImageSynthesis, VideoSynthesis
|
||
import dashscope
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import json
|
||
|
||
from tavily import TavilyClient
|
||
from core.config import settings
|
||
from core import llm_env
|
||
from utils.datetime_utils import format_beijing_time_for_agent
|
||
from logger.logging import get_logger
|
||
from services.vector_service import get_vector_service
|
||
from services.oss_service import get_oss_service
|
||
|
||
# 获取日志记录器
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
def _dashscope_http_api_base() -> str:
|
||
"""``dashscope`` 原生 SDK 使用的 HTTP 根路径(与 OpenAI 兼容 ``DASHSCOPE_API_BASE`` 可能不同)。"""
|
||
return llm_env.dashscope_native_http_api_base().strip().rstrip("/")
|
||
|
||
|
||
# 初始化 Tavily 客户端
|
||
tavily_client = TavilyClient(api_key=settings.tavily_api_key)
|
||
|
||
|
||
@tool
|
||
def get_current_time() -> str:
|
||
"""
|
||
获取当前中国北京时间(东八区)。
|
||
|
||
当用户询问现在几点、今天日期、星期几、或需要当前时间作为参考时调用此工具。
|
||
"""
|
||
return format_beijing_time_for_agent()
|
||
|
||
|
||
def internet_search(
|
||
query: str,
|
||
max_results: int = 5,
|
||
topic: Literal["general", "news", "finance"] = "general",
|
||
include_raw_content: bool = False,
|
||
):
|
||
"""Run a web search"""
|
||
return tavily_client.search(
|
||
query,
|
||
max_results=max_results,
|
||
include_raw_content=include_raw_content,
|
||
topic=topic,
|
||
)
|
||
|
||
|
||
def _download_and_upload_image_to_oss(image_url: str, image_index: int) -> tuple[str, Optional[str], float, float]:
|
||
"""
|
||
下载单张图片并上传到 OSS
|
||
|
||
Args:
|
||
image_url: 原始图片 URL
|
||
image_index: 图片索引(用于日志)
|
||
|
||
Returns:
|
||
tuple: (原始URL, OSS URL, 下载耗时, 上传耗时)
|
||
"""
|
||
upload_start_time = time.time()
|
||
try:
|
||
logger.info(f"开始下载图片 {image_index}:{image_url}")
|
||
|
||
# 下载图片内容
|
||
download_start = time.time()
|
||
response = requests.get(image_url, timeout=300) # 5分钟超时
|
||
response.raise_for_status()
|
||
image_content = response.content
|
||
download_time = time.time() - download_start
|
||
logger.info(f"图片 {image_index} 下载完成,耗时:{download_time:.2f} 秒,大小:{len(image_content) / 1024 / 1024:.2f} MB")
|
||
|
||
# 生成 OSS 对象名称
|
||
timestamp = int(time.time())
|
||
unique_id = str(uuid.uuid4())[:8]
|
||
# 根据图片内容判断文件扩展名
|
||
content_type = response.headers.get('Content-Type', 'image/png')
|
||
if 'jpeg' in content_type or 'jpg' in content_type:
|
||
ext = 'jpg'
|
||
elif 'png' in content_type:
|
||
ext = 'png'
|
||
elif 'webp' in content_type:
|
||
ext = 'webp'
|
||
else:
|
||
ext = 'png' # 默认使用 png
|
||
oss_object_name = f"images/{timestamp}_{unique_id}_{image_index}.{ext}"
|
||
|
||
# 上传到 OSS
|
||
upload_start = time.time()
|
||
oss_service = get_oss_service()
|
||
oss_url = oss_service.upload_file_from_bytes(
|
||
file_content=image_content,
|
||
oss_object_name=oss_object_name,
|
||
file_name=f"generated_image_{image_index}.{ext}"
|
||
)
|
||
upload_time = time.time() - upload_start
|
||
|
||
total_time = time.time() - upload_start_time
|
||
|
||
if oss_url:
|
||
logger.info(f"✅ 图片 {image_index} 已上传到 OSS:{oss_url}")
|
||
logger.info(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒")
|
||
return image_url, oss_url, download_time, upload_time
|
||
else:
|
||
logger.warning(f"⚠️ 图片 {image_index} OSS 上传失败,使用原始 URL")
|
||
logger.warning(f"📊 图片 {image_index} 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒(上传失败)")
|
||
return image_url, None, download_time, upload_time
|
||
|
||
except Exception as e:
|
||
error_msg = f"❌ 图片 {image_index} 上传到 OSS 失败:{str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return image_url, None, 0.0, 0.0
|
||
|
||
|
||
@tool
|
||
def text_to_image(
|
||
prompt: str,
|
||
negative_prompt: str = "",
|
||
size: str = "1280*720",
|
||
n: int = 1,
|
||
)->str:
|
||
"""
|
||
文生图工具:根据文本描述生成高质量图片。
|
||
|
||
当用户需要生成图片、创建图像、制作插图、设计视觉内容时,使用此工具。
|
||
该工具使用阿里云百炼平台的 AI 图像生成模型,可以根据文字描述生成相应的图片。
|
||
生成的图片会自动上传到 OSS 存储,返回永久可访问的 URL。
|
||
|
||
使用场景:
|
||
- 用户说"生成一张...的图片"、"画一个..."、"创建...的图像"
|
||
- 需要为文章、演示文稿、社交媒体创建配图
|
||
- 用户想要可视化某个概念、场景、物体或人物
|
||
- 需要生成多个不同风格的图片供选择
|
||
|
||
参数说明:
|
||
prompt (必需): 详细描述想要生成的图片内容。应该包含:
|
||
- 主体对象(人物、动物、物品等)
|
||
- 场景和环境(背景、地点、氛围)
|
||
- 风格和艺术效果(写实、卡通、油画、水彩等)
|
||
- 颜色和光线(明亮、昏暗、暖色调等)
|
||
- 构图和视角(正面、侧面、俯视、特写等)
|
||
示例:"一只可爱的橘色小猫坐在窗台上,阳光透过窗户洒在它身上,背景是温馨的客厅,写实风格"
|
||
|
||
negative_prompt (可选): 描述不希望在图片中出现的内容,用于排除不想要的元素。
|
||
示例:"模糊,低质量,文字,水印,变形,多余的手指"
|
||
|
||
size (可选): 图片尺寸,格式为 "宽*高"。支持的官方尺寸:
|
||
- "1280*1280" - 1:1 正方形(适合头像、图标、社交媒体头像)
|
||
- "800*1200" - 2:3 竖向(适合手机壁纸、竖版海报)
|
||
- "1200*800" - 3:2 横向(适合横向展示、横幅)
|
||
- "960*1280" - 3:4 竖向(适合手机屏幕、竖版内容)
|
||
- "1280*960" - 4:3 横向(适合传统显示器比例、横版内容)
|
||
- "720*1280" - 9:16 竖向(适合手机竖屏、短视频封面)
|
||
- "1280*720" - 16:9 横向(默认,适合宽屏显示器、视频封面、网页横幅)
|
||
- "1344*576" - 21:9 超宽屏(适合电影比例、超宽屏展示)
|
||
默认值:"1280*720"
|
||
|
||
n (可选): 生成图片的数量,范围 1-4。生成多张图片时会并行处理以提高效率。
|
||
默认值:1
|
||
|
||
返回值:
|
||
返回包含图片的 Markdown 格式字符串,图片会自动显示在对话中。
|
||
如果生成多张图片,会按顺序展示所有图片。
|
||
|
||
注意事项:
|
||
- 生成图片需要一定时间,请耐心等待
|
||
- 提示词越详细,生成的图片质量越好
|
||
- 生成多张图片时,总耗时会更长,但会并行处理以提高效率
|
||
- 如果用户没有明确指定尺寸,使用默认尺寸即可
|
||
"""
|
||
try:
|
||
api_key = settings.dashscope_api_key
|
||
if not api_key:
|
||
return "错误:未配置 DASHSCOPE_API_KEY 环境变量"
|
||
|
||
dashscope.base_http_api_url = _dashscope_http_api_base()
|
||
|
||
logger.info(f"开始生成图片,prompt: {prompt}, n={n}")
|
||
|
||
# 创建异步任务
|
||
rsp = ImageSynthesis.call(api_key=api_key,
|
||
model="wan2.2-t2i-flash",
|
||
prompt=prompt,
|
||
n=n,
|
||
size=size,
|
||
negative_prompt=negative_prompt,
|
||
prompt_extend=True,
|
||
watermark=True)
|
||
print(f'response: {rsp}')
|
||
if rsp.status_code != HTTPStatus.OK:
|
||
print(f'同步调用失败, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}')
|
||
return "图片生成失败"
|
||
|
||
# 提取图片 URL
|
||
image_urls = []
|
||
if rsp.output and rsp.output.results:
|
||
for result in rsp.output.results:
|
||
if hasattr(result, 'url') and result.url:
|
||
image_urls.append(result.url)
|
||
|
||
if not image_urls:
|
||
return "图片生成完成,但未获取到图片URL"
|
||
|
||
logger.info(f"图片生成成功,共 {len(image_urls)} 张图片,开始上传到 OSS")
|
||
|
||
# 使用多线程下载并上传图片到 OSS
|
||
oss_urls = []
|
||
total_start_time = time.time()
|
||
|
||
if len(image_urls) == 1:
|
||
# 单张图片,直接处理
|
||
_, oss_url, _, _ = _download_and_upload_image_to_oss(image_urls[0], 1)
|
||
oss_urls.append(oss_url if oss_url else image_urls[0])
|
||
else:
|
||
# 多张图片,使用多线程并行处理
|
||
with ThreadPoolExecutor(max_workers=min(len(image_urls), 5)) as executor:
|
||
# 提交所有任务
|
||
future_to_index = {
|
||
executor.submit(_download_and_upload_image_to_oss, url, idx + 1): idx
|
||
for idx, url in enumerate(image_urls)
|
||
}
|
||
|
||
# 收集结果(保持顺序)
|
||
results = [None] * len(image_urls)
|
||
for future in as_completed(future_to_index):
|
||
idx = future_to_index[future]
|
||
try:
|
||
results[idx] = future.result()
|
||
except Exception as e:
|
||
logger.error(f"图片 {idx + 1} 处理异常:{e}", exc_info=True)
|
||
results[idx] = (image_urls[idx], None, 0.0, 0.0)
|
||
|
||
# 按顺序提取 OSS URL
|
||
for original_url, oss_url, _, _ in results:
|
||
oss_urls.append(oss_url if oss_url else original_url)
|
||
|
||
total_time = time.time() - total_start_time
|
||
logger.info(f"✅ 所有图片处理完成,总耗时:{total_time:.2f} 秒")
|
||
|
||
# 构建返回信息(使用 markdown 格式以便前端正确显示图片)
|
||
result_text = f"图片生成成功!共生成 {len(oss_urls)} 张图片,以下是图片连接,请使用 markdown 格式渲染这些图片。\n\n"
|
||
for idx, url in enumerate(oss_urls, 1):
|
||
# 使用 markdown 图片语法,这样前端可以正确渲染
|
||
result_text += f"{url}\n\n"
|
||
|
||
return result_text
|
||
|
||
except Exception as e:
|
||
error_msg = f"生成图片时发生错误: {str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return error_msg
|
||
|
||
|
||
from typing import Optional
|
||
from langchain_core.tools import tool
|
||
import os
|
||
import logging
|
||
import uuid
|
||
import requests
|
||
from dashscope import VideoSynthesis
|
||
from http import HTTPStatus
|
||
from services.oss_service import get_oss_service
|
||
|
||
|
||
@tool
|
||
def text_to_video(
|
||
prompt: str,
|
||
negative_prompt: str = "",
|
||
size: str = "832*480",
|
||
duration: int = 5,
|
||
) -> str:
|
||
"""
|
||
文生视频工具:根据文本描述生成动态视频。
|
||
|
||
当用户需要生成视频、创建动画、制作短视频、需要动态视觉内容时,使用此工具。
|
||
该工具使用阿里云百炼平台的 AI 视频生成模型,可以根据文字描述生成相应的视频。
|
||
生成的视频会自动上传到 OSS 存储,返回永久可访问的 URL。
|
||
|
||
使用场景:
|
||
- 用户说"生成一个...的视频"、"创建一个...的动画"、"制作...的短视频"
|
||
- 需要为产品演示、营销推广创建动态视频内容
|
||
- 社交媒体短视频内容生成(抖音、快手、小红书等)
|
||
- 需要展示动态场景、运动过程、变化效果
|
||
- 用户想要可视化动态概念或过程
|
||
|
||
参数说明:
|
||
prompt (必需): 详细描述想要生成的视频内容。应该包含:
|
||
- 主体对象和动作(什么在做什么)
|
||
- 场景和环境(背景、地点、氛围)
|
||
- 运动方式和动态效果(移动、旋转、变化等)
|
||
- 风格和视觉效果(写实、动画、电影感等)
|
||
- 颜色和光线(明亮、昏暗、暖色调等)
|
||
示例:"一只橘色小猫在窗台上玩耍,阳光透过窗户洒在它身上,它好奇地看向窗外,背景是温馨的客厅,写实风格,画面流畅自然"
|
||
|
||
negative_prompt (可选): 描述不希望在视频中出现的内容,用于排除不想要的元素。
|
||
示例:"模糊,低质量,文字,水印,画面抖动,不自然的运动,变形"
|
||
|
||
size (可选): 视频尺寸,格式为 "宽*高"。支持的尺寸:
|
||
- "832*480" - 标准横向(默认,适合通用视频)
|
||
- "1280*720" - 高清横向(适合高质量视频)
|
||
- "720*1280" - 竖向(适合手机竖屏视频、短视频平台)
|
||
默认值:"832*480"
|
||
|
||
duration (可选): 视频时长,单位为秒。支持的时长:
|
||
- 5 秒(默认,适合短视频)
|
||
- 10 秒(适合中等长度视频)
|
||
- 15 秒(适合较长视频)
|
||
默认值:5
|
||
|
||
返回值:
|
||
返回包含视频的 HTML 格式字符串,视频会自动显示在对话中,用户可以直接播放。
|
||
视频已上传到 OSS,返回的是永久可访问的 URL。
|
||
|
||
注意事项:
|
||
- 视频生成需要较长时间(通常需要几十秒到几分钟),请耐心等待
|
||
- 提示词越详细,生成的视频质量越好
|
||
- 视频生成后会自动下载并上传到 OSS,这个过程可能需要额外时间
|
||
- 如果用户没有明确指定尺寸和时长,使用默认值即可
|
||
- 视频生成是异步过程,完成后会返回可播放的视频链接
|
||
"""
|
||
try:
|
||
# 地域与 ``DASHSCOPE_API_BASE`` 一致;新加坡等请改环境变量(例:https://dashscope-intl.aliyuncs.com/api/v1)
|
||
dashscope.base_http_api_url = _dashscope_http_api_base()
|
||
|
||
api_key = settings.dashscope_api_key
|
||
|
||
# call sync api, will return the result
|
||
start = time.time()
|
||
print('开始时间-->',time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||
rsp = VideoSynthesis.call(api_key=api_key,
|
||
model='wan2.2-t2v-plus',
|
||
prompt=prompt,
|
||
size="832*480",
|
||
duration=5,
|
||
negative_prompt=negative_prompt,
|
||
# audio=True,
|
||
prompt_extend=True,
|
||
watermark=True)
|
||
print("请求结果:",rsp)
|
||
video_url = ""
|
||
result = ""
|
||
if rsp.status_code == HTTPStatus.OK:
|
||
print("请求链接地址:",rsp.output.video_url)
|
||
print("结束时间-->",time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||
print("耗时-->",time.time()-start,"秒")
|
||
video_url = rsp.output.video_url
|
||
|
||
# 下载视频并上传到 OSS
|
||
try:
|
||
upload_start_time = time.time()
|
||
logger.info(f"开始下载视频:{video_url}")
|
||
|
||
# 下载视频内容
|
||
download_start = time.time()
|
||
response = requests.get(video_url, timeout=300) # 5分钟超时
|
||
response.raise_for_status()
|
||
video_content = response.content
|
||
download_time = time.time() - download_start
|
||
logger.info(f"视频下载完成,耗时:{download_time:.2f} 秒,大小:{len(video_content) / 1024 / 1024:.2f} MB")
|
||
|
||
# 生成 OSS 对象名称
|
||
timestamp = int(time.time())
|
||
unique_id = str(uuid.uuid4())[:8]
|
||
oss_object_name = f"videos/{timestamp}_{unique_id}.mp4"
|
||
|
||
# 上传到 OSS
|
||
upload_start = time.time()
|
||
oss_service = get_oss_service()
|
||
oss_url = oss_service.upload_file_from_bytes(
|
||
file_content=video_content,
|
||
oss_object_name=oss_object_name,
|
||
file_name="generated_video.mp4"
|
||
)
|
||
upload_time = time.time() - upload_start
|
||
|
||
# 计算总耗时
|
||
total_time = time.time() - upload_start_time
|
||
|
||
if oss_url:
|
||
logger.info(f"✅ 视频已上传到 OSS:{oss_url}")
|
||
logger.info(f"📊 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒")
|
||
# 使用 OSS URL 替换临时 URL
|
||
video_url = oss_url
|
||
else:
|
||
logger.warning("⚠️ OSS 上传失败,使用原始临时 URL")
|
||
logger.warning(f"📊 上传统计 - 下载耗时:{download_time:.2f} 秒,上传耗时:{upload_time:.2f} 秒,总耗时:{total_time:.2f} 秒(上传失败)")
|
||
# 如果 OSS 上传失败,继续使用原始 URL
|
||
|
||
except Exception as upload_error:
|
||
logger.error(f"❌ 上传视频到 OSS 失败:{upload_error}", exc_info=True)
|
||
# 如果上传失败,继续使用原始临时 URL
|
||
logger.warning("⚠️ 使用原始临时视频 URL")
|
||
|
||
result = f"""<video controls width="100%" style="max-width: 600px;">
|
||
<source src="{video_url}" type="video/mp4">
|
||
您的浏览器不支持视频播放
|
||
</video>"""
|
||
else:
|
||
print('视频请求失败, status_code: %s, code: %s, message: %s' %
|
||
(rsp.status_code, rsp.code, rsp.message))
|
||
result = "视频生成失败"
|
||
logger.info(f"✅ 视频生成完成:{video_url}")
|
||
return result
|
||
|
||
except Exception as e:
|
||
error_msg = f"❌ 生成视频异常:{str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return error_msg
|
||
|
||
|
||
@tool
|
||
def text_to_poster(
|
||
title: str,
|
||
sub_title: str = "",
|
||
body_text: str = "",
|
||
prompt_text_zh: str = "",
|
||
prompt_text_en: str = "",
|
||
size: str = "1280*1280",
|
||
) -> str:
|
||
"""
|
||
创意海报生成工具:根据标题、副标题和正文内容生成创意海报。
|
||
|
||
当用户需要生成海报、创建宣传图、制作营销图片、需要创意设计时,使用此工具。
|
||
该工具使用阿里云百炼平台的文生图(万相)模型,根据海报文案生成相应的创意海报图片。
|
||
生成的海报会自动上传到 OSS 存储,返回永久可访问的 URL。海报右下角带有「AI生成」水印。
|
||
|
||
使用场景:
|
||
- 用户说"生成一张...的海报"、"创建一个...的宣传图"、"制作...的创意海报"
|
||
- 需要为活动、产品、品牌创建宣传海报
|
||
- 社交媒体营销图片生成(微信、微博、小红书等)
|
||
- 需要展示标题、副标题和正文内容的创意设计
|
||
- 用户想要可视化某个主题或概念的海报
|
||
|
||
参数说明:
|
||
title (必需): 海报的主标题。应该简洁有力,能够吸引注意力。
|
||
示例:"春季新品发布"、"限时优惠活动"、"品牌宣传"
|
||
|
||
sub_title (可选): 海报的副标题。用于补充说明主标题或提供更多信息。
|
||
示例:"全场8折起"、"限时3天"、"专业团队打造"
|
||
|
||
body_text (可选): 海报的正文内容。可以包含详细说明、活动规则、联系方式等。
|
||
示例:"活动时间:2024年3月1日-3月31日\n活动地点:全国门店\n咨询热线:400-xxx-xxxx"
|
||
|
||
prompt_text_zh (可选): 中文提示文本,用于描述海报的视觉风格和设计元素。
|
||
示例:"小朋友画的可爱的龙,白色背景"、"温馨的节日氛围,红色和金色主题"
|
||
如果未提供,将根据标题和副标题自动生成。
|
||
|
||
prompt_text_en (可选): 英文提示文本,用于描述海报的视觉风格和设计元素。
|
||
示例:"Children draw a lovely dragon, white background"、"Warm festive atmosphere, red and gold theme"
|
||
如果未提供,将根据标题和副标题自动生成。
|
||
注意:prompt_text_zh 和 prompt_text_en 至少需要设置其中一个。
|
||
|
||
size string (可选)
|
||
|
||
输出图像的分辨率,格式为宽*高。
|
||
|
||
默认值为 1280*1280。
|
||
|
||
总像素在 [1280*1280, 1440*1440] 之间且宽高比范围为 [1:4, 4:1]。例如,768*2700符合要求。
|
||
|
||
示例值:1280*1280。
|
||
|
||
常见比例推荐的分辨率
|
||
|
||
1:1:1280*1280
|
||
|
||
3:4:1104*1472
|
||
|
||
4:3:1472*1104
|
||
|
||
9:16:960*1696
|
||
|
||
16:9:1696*960
|
||
|
||
返回值:
|
||
返回包含海报图片的 Markdown 格式字符串,海报会自动显示在对话中。
|
||
|
||
注意事项:
|
||
- 生成海报需要一定时间,请耐心等待
|
||
- 标题、副标题和正文内容越清晰,生成的海报质量越好
|
||
- 生成的海报带有 AI 水印标识
|
||
"""
|
||
try:
|
||
api_key = settings.dashscope_api_key
|
||
if not api_key:
|
||
return "错误:未配置 DASHSCOPE_API_KEY 环境变量"
|
||
|
||
dashscope.base_http_api_url = _dashscope_http_api_base()
|
||
|
||
logger.info(f"开始生成创意海报,title: {title}, sub_title: {sub_title}, body_text: {(body_text[:50] + '...') if len(body_text) > 50 else body_text}")
|
||
|
||
# 构建海报专用 prompt:将标题、副标题、正文与视觉风格融合为文生图提示词
|
||
prompt_parts = ["创意海报设计,宣传海报,专业排版,醒目吸睛"]
|
||
|
||
if title:
|
||
prompt_parts.append(f"主标题:{title}")
|
||
if sub_title:
|
||
prompt_parts.append(f"副标题:{sub_title}")
|
||
if body_text:
|
||
# 正文可能较长,截取关键信息(限制约 200 字符)
|
||
body_summary = body_text.replace("\n", " ")[:200]
|
||
if len(body_text) > 200:
|
||
body_summary += "..."
|
||
prompt_parts.append(f"正文内容:{body_summary}")
|
||
|
||
# 视觉风格:优先使用用户提供的提示词
|
||
if prompt_text_zh:
|
||
prompt_parts.append(f"视觉风格:{prompt_text_zh}")
|
||
elif prompt_text_en:
|
||
prompt_parts.append(f"Visual style: {prompt_text_en}")
|
||
else:
|
||
prompt_parts.append("精美设计,高质量海报风格")
|
||
|
||
prompt = ",".join(prompt_parts)
|
||
logger.info(f"海报生成 prompt: {prompt[:200]}...")
|
||
|
||
# 使用与 text_to_image 相同的文生图 API(万相 wan2.2-t2i-flash)
|
||
# 海报需要水印,故 watermark=True
|
||
rsp = ImageSynthesis.call(
|
||
api_key=api_key,
|
||
model="wan2.5-t2i-preview",
|
||
prompt=prompt,
|
||
n=1,
|
||
size=size,
|
||
negative_prompt="低分辨率,低画质,画面模糊,文字扭曲,构图混乱,画面过饱和",
|
||
prompt_extend=True,
|
||
watermark=True,
|
||
)
|
||
|
||
if rsp.status_code != HTTPStatus.OK:
|
||
logger.error(f"海报生成失败, status_code: {rsp.status_code}, code: {rsp.code}, message: {rsp.message}")
|
||
return f"海报生成失败:{rsp.message or '请稍后重试'}"
|
||
|
||
image_urls = []
|
||
if rsp.output and rsp.output.results:
|
||
for result in rsp.output.results:
|
||
if hasattr(result, 'url') and result.url:
|
||
image_urls.append(result.url)
|
||
|
||
if not image_urls:
|
||
return "海报生成完成,但未获取到图片URL"
|
||
|
||
image_url = image_urls[0]
|
||
logger.info(f"海报生成成功,图片URL: {image_url},开始上传到 OSS")
|
||
|
||
# 复用 text_to_image 的 OSS 上传逻辑
|
||
_, oss_url, _, _ = _download_and_upload_image_to_oss(image_url, 1)
|
||
final_url = oss_url if oss_url else image_url
|
||
|
||
result_text = f"创意海报生成成功!\n\n{final_url}\n\n"
|
||
result_text += f"**标题**:{title}\n"
|
||
if sub_title:
|
||
result_text += f"**副标题**:{sub_title}\n"
|
||
if body_text:
|
||
result_text += f"**正文**:{body_text}\n"
|
||
|
||
logger.info("✅ 海报生成完成")
|
||
return result_text
|
||
|
||
except Exception as e:
|
||
error_msg = f"❌ 生成海报异常:{str(e)}"
|
||
logger.error(error_msg, exc_info=True)
|
||
return error_msg
|
||
|
||
|
||
def create_rag_retrieve_tool(thread_id: str):
|
||
"""
|
||
创建 RAG 检索工具(用于对话文件)
|
||
|
||
Args:
|
||
thread_id: 会话线程 ID
|
||
|
||
Returns:
|
||
tool: RAG 检索工具
|
||
"""
|
||
vector_service = get_vector_service()
|
||
|
||
@tool(response_format="content_and_artifact")
|
||
def retrieve_context_from_files(query: str):
|
||
"""
|
||
从用户上传的文件中检索相关信息来帮助回答问题。
|
||
|
||
当用户的问题涉及到上传的文件内容时,使用此工具检索相关文档片段。
|
||
例如:用户上传了PDF文件后,询问文件中的具体内容、数据、概念等。
|
||
|
||
Args:
|
||
query: 用户的查询问题
|
||
|
||
Returns:
|
||
tuple: (检索到的文档内容字符串, 检索结果列表)
|
||
"""
|
||
try:
|
||
# 使用向量服务搜索相似文档
|
||
results = vector_service.search_similar_in_thread(
|
||
thread_id=thread_id,
|
||
query=query,
|
||
k=5 # 返回最相关的5个文档片段
|
||
)
|
||
|
||
if not results:
|
||
return "未在文件中找到相关信息。", []
|
||
|
||
# 格式化检索结果
|
||
content_parts = []
|
||
for idx, result in enumerate(results, 1):
|
||
content = result.get("content", "")
|
||
metadata = result.get("metadata", {})
|
||
score = result.get("score", 0)
|
||
|
||
# 构建来源信息
|
||
source_info = []
|
||
if metadata:
|
||
if "source" in metadata:
|
||
source_info.append(f"来源: {metadata['source']}")
|
||
if "page" in metadata:
|
||
source_info.append(f"页码: {metadata['page']}")
|
||
|
||
source_str = f" ({', '.join(source_info)})" if source_info else ""
|
||
|
||
content_parts.append(
|
||
f"[文档片段 {idx}]{source_str}\n{content}\n"
|
||
)
|
||
|
||
content = "\n".join(content_parts)
|
||
return content, results
|
||
|
||
except Exception as e:
|
||
logger.error(f"RAG 检索失败: {e}")
|
||
return f"检索文件内容时出错: {str(e)}", []
|
||
|
||
return retrieve_context_from_files
|
||
|
||
|
||
def create_kb_rag_retrieve_tool(knowledge_base_id: int):
|
||
"""
|
||
创建知识库 RAG 检索工具
|
||
|
||
Args:
|
||
knowledge_base_id: 知识库 ID
|
||
|
||
Returns:
|
||
tool: 知识库 RAG 检索工具
|
||
"""
|
||
vector_service = get_vector_service()
|
||
|
||
@tool(response_format="content_and_artifact")
|
||
def retrieve_context_from_knowledge_base(query: str):
|
||
"""
|
||
从知识库中检索相关信息来帮助回答问题。
|
||
|
||
当用户的问题涉及到知识库中的内容时,使用此工具检索相关文档片段。
|
||
知识库包含了用户预先上传和整理的文件内容。
|
||
|
||
Args:
|
||
query: 用户的查询问题
|
||
|
||
Returns:
|
||
tuple: (检索到的文档内容字符串, 检索结果列表)
|
||
"""
|
||
try:
|
||
# 使用向量服务搜索知识库中的相似文档
|
||
results = vector_service.search_similar(
|
||
knowledge_base_id=knowledge_base_id,
|
||
query=query,
|
||
k=5 # 返回最相关的5个文档片段
|
||
)
|
||
|
||
if not results:
|
||
return "未在知识库中找到相关信息。", []
|
||
|
||
# 格式化检索结果
|
||
content_parts = []
|
||
for idx, result in enumerate(results, 1):
|
||
content = result.get("content", "")
|
||
metadata = result.get("metadata", {})
|
||
score = result.get("score", 0)
|
||
|
||
# 构建来源信息
|
||
source_info = []
|
||
if metadata:
|
||
if "source" in metadata:
|
||
source_info.append(f"来源: {metadata['source']}")
|
||
if "page" in metadata:
|
||
source_info.append(f"页码: {metadata['page']}")
|
||
|
||
source_str = f" ({', '.join(source_info)})" if source_info else ""
|
||
|
||
content_parts.append(
|
||
f"[知识库文档片段 {idx}]{source_str}\n{content}\n"
|
||
)
|
||
|
||
content = "\n".join(content_parts)
|
||
return content, results
|
||
|
||
except Exception as e:
|
||
logger.error(f"知识库 RAG 检索失败: {e}")
|
||
return f"检索知识库内容时出错: {str(e)}", []
|
||
|
||
return retrieve_context_from_knowledge_base
|
||
|
||
|
||
def create_knowledge_graph_rag_retrieve_tool(knowledge_graph_pk: int):
|
||
"""
|
||
创建「知识图谱」绑定的正文向量检索工具(与 Neo4j 实体关系互补)。
|
||
"""
|
||
vector_service = get_vector_service()
|
||
|
||
@tool(response_format="content_and_artifact")
|
||
def retrieve_context_from_knowledge_graph(query: str):
|
||
"""
|
||
从用户选中的知识图谱资料正文中检索相关片段,用于回答细节、原文依据等问题。
|
||
|
||
当问题涉及资料内容、叙述、对话、描写而非仅关系网络时,应使用本工具。
|
||
|
||
Args:
|
||
query: 检索查询(可与用户问题同义改写)
|
||
|
||
Returns:
|
||
tuple: (检索到的文本片段拼接字符串, 检索结果列表)
|
||
"""
|
||
try:
|
||
results = vector_service.search_similar_knowledge_graph(
|
||
knowledge_graph_pk=knowledge_graph_pk,
|
||
query=query,
|
||
k=5,
|
||
)
|
||
if not results:
|
||
return "未在该知识图谱资料正文中找到相关片段。", []
|
||
|
||
content_parts = []
|
||
for idx, result in enumerate(results, 1):
|
||
content = result.get("content", "")
|
||
metadata = result.get("metadata", {}) or {}
|
||
chunk_i = metadata.get("chunk_index", "")
|
||
prefix = f"[资料原文片段 {idx}]"
|
||
if chunk_i != "":
|
||
prefix += f" (块 #{chunk_i})"
|
||
content_parts.append(f"{prefix}\n{content}\n")
|
||
|
||
return "\n".join(content_parts), results
|
||
except Exception as e:
|
||
logger.error(f"知识图谱 RAG 检索失败: {e}")
|
||
return f"检索资料正文时出错: {str(e)}", []
|
||
|
||
return retrieve_context_from_knowledge_graph
|
||
|
||
|
||
def _format_knowledge_graph_neo4j_result(result: dict) -> str:
|
||
"""将 Neo4j search_knowledge_graph 的返回结果转为给模型阅读的文本。"""
|
||
msg = result.get("message")
|
||
if msg:
|
||
return msg
|
||
seeds = result.get("seeds") or []
|
||
elements = result.get("elements") or []
|
||
if not seeds and not elements:
|
||
return "未在知识图谱中找到与关键词匹配的实体或关系。"
|
||
lines: list[str] = []
|
||
if seeds:
|
||
lines.append(f"关键词命中的实体: {', '.join(seeds)}")
|
||
edges: list[str] = []
|
||
for el in elements:
|
||
d = el.get("data") or {}
|
||
if "source" not in d:
|
||
continue
|
||
rel = (d.get("label") or d.get("type") or "关系").strip()
|
||
note = (d.get("note") or "").strip()
|
||
suf = f"(备注: {note})" if note else ""
|
||
edges.append(f"- {d['source']} —[{rel}]→ {d['target']}{suf}")
|
||
if edges:
|
||
lines.append("关系边(来自图数据库,Person/RELATION):")
|
||
lines.extend(edges[:100])
|
||
elif seeds:
|
||
lines.append("已命中实体,但未检索到相连的关系边;可尝试增大 hops 或更换关键词。")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def create_knowledge_graph_neo4j_search_tool(neo4j_graph_id: str):
|
||
"""
|
||
创建基于 Neo4j 的实体/关系查询工具(与正文向量检索互补)。
|
||
"""
|
||
from services.neo4j_service import search_knowledge_graph
|
||
|
||
@tool
|
||
def query_knowledge_graph_relations(entity_keyword: str, hops: int = 2) -> str:
|
||
"""
|
||
在当前绑定的知识图谱(Neo4j)中按关键词查找人物/实体,并返回其关联关系。
|
||
|
||
当用户询问「某人是谁」「某人和谁的关系」「亲属/子女/上下级/合作」等**实体关系**时,应优先使用本工具。
|
||
entity_keyword 为人名或实体名(可只填部分,如姓氏或名);若无结果可换关键词再试。
|
||
hops 为关系扩展深度:1 仅直接关系,2 为两跳内(默认),最大 3。
|
||
|
||
Args:
|
||
entity_keyword: 要查找的实体关键词
|
||
hops: 关系跳数,1–3
|
||
"""
|
||
try:
|
||
kw = (entity_keyword or "").strip()
|
||
if not kw:
|
||
return "请提供非空的实体关键词。"
|
||
h = max(1, min(int(hops), 3))
|
||
result = search_knowledge_graph(neo4j_graph_id, kw, hops=h)
|
||
return _format_knowledge_graph_neo4j_result(result)
|
||
except Exception as e:
|
||
logger.error(f"知识图谱 Neo4j 查询失败: {e}", exc_info=True)
|
||
return f"知识图谱关系查询失败: {e}"
|
||
|
||
return query_knowledge_graph_relations
|