287 lines
9.4 KiB
Python
287 lines
9.4 KiB
Python
"""
视觉模型服务

基于阿里云通义千问视觉模型 (qwen-vl-max-latest) 提供图片理解能力。

参考 server/aaa/jenius_attachment_knowledge_base/jenius_rag_util.py 的实现。
"""
|
||
import asyncio
|
||
import base64
|
||
from typing import Optional
|
||
|
||
from openai import OpenAI, AsyncOpenAI
|
||
from core.llm_env import tongyi_openai_compatible_base_url
|
||
from core.config import settings
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
|
||
|
||
|
||
def _is_vision_image_url(url: str) -> bool:
|
||
if not url:
|
||
return False
|
||
if url.startswith(("http://", "https://")):
|
||
return True
|
||
if url.startswith("data:image/") and "base64," in url:
|
||
return True
|
||
return False
|
||
|
||
|
||
def _sniff_image_mime(image_bytes: bytes) -> Optional[str]:
    """Return the MIME type implied by the image's magic bytes, or None if unknown."""
    # Copy the header to real bytes so bytearray/memoryview inputs support startswith.
    head = bytes(image_bytes[:12])
    if head.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if head.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if head.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if head.startswith(b"BM"):
        return "image/bmp"
    if head.startswith(b"RIFF") and head[8:12] == b"WEBP":
        return "image/webp"
    if head.startswith((b"II*\x00", b"MM\x00*")):
        return "image/tiff"
    return None


def image_bytes_to_data_url(image_bytes: bytes, mime_hint: Optional[str] = None) -> str:
    """Encode local image bytes as an OpenAI/DashScope-compatible data URL.

    Args:
        image_bytes: Raw image file contents.
        mime_hint: Explicit MIME type such as ``"image/png"``. When omitted or
            empty, the type is sniffed from the file's magic bytes (PNG, JPEG,
            GIF, BMP, WEBP, TIFF).

    Returns:
        A ``data:<mime>;base64,<payload>`` string. If the format is not
        recognised, ``image/jpeg`` is used as the fallback MIME type.
    """
    mime = mime_hint
    if not mime:
        # Treat "" like None: an empty hint carries no information. It must not
        # bypass sniffing, or e.g. PNG data would be mislabelled as JPEG.
        mime = _sniff_image_mime(image_bytes) or "image/jpeg"
    payload = base64.standard_b64encode(image_bytes).decode("ascii")
    return f"data:{mime};base64,{payload}"
|
||
|
||
|
||
class VisionService:
    """Image-understanding service backed by Alibaba Cloud Tongyi Qwen-VL.

    Requests go to the OpenAI-compatible endpoint returned by
    ``tongyi_openai_compatible_base_url()`` and authenticate with
    ``settings.dashscope_api_key``. The OpenAI clients are created lazily and
    cached on the class, so all callers share them. Public methods do not raise
    for rejected input or request failures; they log and return ``""``.
    """

    # Vision model used by every request; override on a subclass to switch models.
    MODEL_NAME = "qwen-vl-max-latest"

    # System prompts for plain description vs. question answering.
    _DESCRIBE_SYSTEM_PROMPT = "You are a helpful assistant."
    _QA_SYSTEM_PROMPT = (
        "You are a helpful assistant that can analyze images and answer questions about them."
    )

    # Lazily created, class-wide shared clients.
    _client_cache: Optional[AsyncOpenAI] = None
    _sync_client_cache: Optional[OpenAI] = None
    # Serialises creation of the async client.
    # NOTE(review): created at class-definition time; on Python < 3.10 an
    # asyncio.Lock binds to the event loop that is current at import.
    _lock = asyncio.Lock()

    @classmethod
    async def _get_async_client(cls) -> AsyncOpenAI:
        """Return the shared async client, creating it on first use."""
        if cls._client_cache is not None:
            return cls._client_cache

        async with cls._lock:
            # Re-check inside the lock in case another coroutine created it first.
            if cls._client_cache is None:
                cls._client_cache = AsyncOpenAI(
                    api_key=settings.dashscope_api_key,
                    base_url=tongyi_openai_compatible_base_url(),
                )
        return cls._client_cache

    @classmethod
    def _get_sync_client(cls) -> OpenAI:
        """Return the shared sync client, creating it on first use."""
        if cls._sync_client_cache is not None:
            return cls._sync_client_cache

        cls._sync_client_cache = OpenAI(
            api_key=settings.dashscope_api_key,
            base_url=tongyi_openai_compatible_base_url(),
        )
        return cls._sync_client_cache

    @staticmethod
    def _accepts_url(image_url: str) -> bool:
        """Return True if image_url is usable; otherwise log a warning and return False."""
        if _is_vision_image_url(image_url):
            return True
        logger.warning(f"无效的图片 URL: {image_url[:80] if image_url else ''}")
        return False

    @staticmethod
    def _build_messages(image_url: str, text: str, system_prompt: str) -> list[dict]:
        """Build the chat messages for one image followed by one text instruction."""
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": text},
                ],
            },
        ]

    @staticmethod
    def _first_choice_text(completion) -> str:
        """Return the first choice's text, or "" when the model returned no content."""
        # ChatCompletionMessage.content is Optional[str]; normalise None to "" so
        # callers always get a str and slicing for the log line cannot fail.
        return completion.choices[0].message.content or ""

    @classmethod
    async def _complete_async(cls, image_url: str, text: str, system_prompt: str) -> str:
        """Run one async vision chat completion and return its text. May raise."""
        client = await cls._get_async_client()
        completion = await client.chat.completions.create(
            model=cls.MODEL_NAME,
            messages=cls._build_messages(image_url, text, system_prompt),
        )
        return cls._first_choice_text(completion)

    @classmethod
    def _complete_sync(cls, image_url: str, text: str, system_prompt: str) -> str:
        """Run one blocking vision chat completion and return its text. May raise."""
        client = cls._get_sync_client()
        completion = client.chat.completions.create(
            model=cls.MODEL_NAME,
            messages=cls._build_messages(image_url, text, system_prompt),
        )
        return cls._first_choice_text(completion)

    @classmethod
    async def get_image_description(
        cls,
        image_url: str,
        prompt: str = "图中的主要内容是什么?回答以'图片'开头, 500字以内"
    ) -> str:
        """Describe an image (async).

        Args:
            image_url: An ``http(s)://`` URL or a ``data:image/...;base64,`` data URL.
            prompt: Instruction that guides the description.

        Returns:
            The description text, or ``""`` if the URL is rejected, the model
            returned no content, or the request failed.
        """
        if not cls._accepts_url(image_url):
            return ""

        try:
            description = await cls._complete_async(
                image_url, prompt, cls._DESCRIBE_SYSTEM_PROMPT
            )
        except Exception as e:
            logger.error(f"获取图片描述失败: {e}")
            return ""

        logger.info(f"成功获取图片描述: {description[:50]}...")
        return description

    @classmethod
    async def get_image_description_from_bytes(
        cls,
        image_bytes: bytes,
        prompt: str = "图中的主要内容是什么?回答以'图片'开头, 500字以内",
        mime_hint: Optional[str] = None,
    ) -> str:
        """Describe an in-memory image (async) by sending it as a data URL.

        Intended for cases without a public URL, such as knowledge-graph uploads.

        Args:
            image_bytes: Raw image file contents.
            prompt: Instruction that guides the description.
            mime_hint: Optional MIME type; sniffed from the bytes when omitted.

        Returns:
            The description text, or ``""`` when no DashScope API key is
            configured, the bytes are empty, or the request failed.
        """
        if not settings.dashscope_api_key:
            logger.warning("未配置 DASHSCOPE_API_KEY,无法进行视觉理解")
            return ""
        if not image_bytes:
            return ""
        data_url = image_bytes_to_data_url(image_bytes, mime_hint)
        return await cls.get_image_description(data_url, prompt=prompt)

    @classmethod
    def get_image_description_sync(
        cls,
        image_url: str,
        prompt: str = "图中的主要内容是什么?回答以'图片'开头, 500字以内"
    ) -> str:
        """Describe an image (blocking).

        Args:
            image_url: An ``http(s)://`` URL or a ``data:image/...;base64,`` data URL.
            prompt: Instruction that guides the description.

        Returns:
            The description text, or ``""`` if the URL is rejected, the model
            returned no content, or the request failed.
        """
        if not cls._accepts_url(image_url):
            return ""

        try:
            description = cls._complete_sync(
                image_url, prompt, cls._DESCRIBE_SYSTEM_PROMPT
            )
        except Exception as e:
            logger.error(f"获取图片描述失败: {e}")
            return ""

        logger.info(f"成功获取图片描述: {description[:50]}...")
        return description

    @classmethod
    async def analyze_image_with_question(
        cls,
        image_url: str,
        question: str
    ) -> str:
        """Answer a question about an image (async).

        Args:
            image_url: An ``http(s)://`` URL or a ``data:image/...;base64,`` data URL.
            question: The user's question about the image.

        Returns:
            The model's answer, or ``""`` if the URL is rejected, the model
            returned no content, or the request failed.
        """
        if not cls._accepts_url(image_url):
            return ""

        try:
            answer = await cls._complete_async(
                image_url, question, cls._QA_SYSTEM_PROMPT
            )
        except Exception as e:
            logger.error(f"分析图片失败: {e}")
            return ""

        logger.info("成功分析图片并回答问题")
        return answer
|
||
|
||
|
||
# Batch-processing helpers


async def batch_get_image_descriptions(
    image_urls: list[str],
    prompt: str = "图中的主要内容是什么?回答以'图片'开头, 500字以内",
    max_concurrency: Optional[int] = None,
) -> dict[str, str]:
    """Describe many images concurrently.

    Args:
        image_urls: Image URLs to describe. Duplicate URLs are requested once.
        prompt: Instruction passed to every description request.
        max_concurrency: Maximum number of simultaneous model requests.
            ``None`` (the default) or a non-positive value means unbounded,
            matching the previous behaviour.

    Returns:
        Mapping from each distinct URL to its description; ``""`` for any URL
        whose request failed or was cancelled.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_urls = list(dict.fromkeys(image_urls))
    semaphore = (
        asyncio.Semaphore(max_concurrency)
        if max_concurrency is not None and max_concurrency > 0
        else None
    )

    async def _describe(url: str) -> str:
        """Describe one URL, honouring the optional concurrency bound."""
        if semaphore is None:
            return await VisionService.get_image_description(url, prompt)
        async with semaphore:
            return await VisionService.get_image_description(url, prompt)

    descriptions = await asyncio.gather(
        *(_describe(url) for url in unique_urls), return_exceptions=True
    )

    result: dict[str, str] = {}
    for url, desc in zip(unique_urls, descriptions):
        # Check BaseException, not Exception: with return_exceptions=True a
        # cancelled child yields asyncio.CancelledError, which is a
        # BaseException and would otherwise be stored as the "description".
        if isinstance(desc, BaseException):
            logger.error(f"获取图片描述失败 {url}: {desc}")
            result[url] = ""
        else:
            result[url] = desc
    return result
|