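"""
Vision model service.

Built on the OpenAI-compatible ``chat.completions`` interface (see the
Alibaba Cloud Bailian vision documentation), using a multimodal model that
accepts image input.

The default model is ``qwen3-vl-plus``. The gateway URL and API key are read
from ``ZL_OPENAI_*`` first, falling back to ``DASHSCOPE_*`` / ``settings``
when those are not configured.
"""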
|
||
import asyncio
|
||
import base64
|
||
import os
|
||
from typing import Optional
|
||
|
||
from openai import OpenAI, AsyncOpenAI
|
||
from core.llm_env import tongyi_openai_compatible_base_url
|
||
from core.config import settings
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger obtained from the project's ``get_logger`` helper.
logger = get_logger(__name__)
|
||
|
||
# Default matches the OpenAI-compatible vision example; override it with the
# ZL_OPENAI_VISION_MODEL environment variable. The value is stripped before the
# fallback is applied so that a blank or whitespace-only variable yields the
# default instead of an empty model name.
VISION_CHAT_MODEL = (os.getenv("ZL_OPENAI_VISION_MODEL") or "").strip() or "qwen3-vl-plus"
|
||
|
||
|
||
def _vision_openai_api_key() -> str:
    """Return the API key used for the vision client.

    A non-blank ``ZL_OPENAI_API_KEY`` environment variable takes precedence;
    otherwise ``settings.dashscope_api_key`` is used. The result is stripped
    of surrounding whitespace and may be an empty string.
    """
    env_key = os.environ.get("ZL_OPENAI_API_KEY")
    env_key = env_key.strip() if env_key else ""
    return env_key or (settings.dashscope_api_key or "").strip()
|
||
|
||
|
||
def _vision_openai_base_url() -> str:
    """Return the base URL used for the vision client.

    A non-blank ``ZL_OPENAI_BASE_URL`` environment variable takes precedence;
    otherwise the value of ``tongyi_openai_compatible_base_url()`` is used.
    In both cases surrounding whitespace and trailing slashes are removed.
    """
    override = os.environ.get("ZL_OPENAI_BASE_URL") or ""
    override = override.strip().rstrip("/")
    if not override:
        return tongyi_openai_compatible_base_url().strip().rstrip("/")
    return override
|
||
|
||
|
||
def _vision_client_signature() -> tuple[str, str]:
    """Return the currently resolved ``(api_key, base_url)`` pair.

    The pair is compared against the one a cached client was built from to
    decide whether that client must be rebuilt.
    """
    api_key = _vision_openai_api_key()
    base_url = _vision_openai_base_url()
    return api_key, base_url
|
||
|
||
|
||
def _is_vision_image_url(url: str) -> bool:
    """Report whether *url* is an image reference this module accepts.

    Accepted forms are ``http://`` / ``https://`` URLs and ``data:image/...``
    URLs that contain a ``base64,`` marker. Matching is case-sensitive, and
    empty values are rejected.
    """
    if not url:
        return False
    is_remote = url.startswith(("http://", "https://"))
    is_inline = url.startswith("data:image/") and "base64," in url
    return is_remote or is_inline
|
||
|
||
|
||
def image_bytes_to_data_url(image_bytes: bytes, mime_hint: Optional[str] = None) -> str:
    """Encode raw image bytes as an OpenAI/DashScope-compatible base64 data URL.

    Args:
        image_bytes: Raw bytes of the image file.
        mime_hint: Optional MIME type supplied by the caller. It is honoured
            only when it names an ``image/*`` type; parameters such as
            ``; charset=...`` are dropped and the value is lower-cased.
            Blank or non-image hints (for example
            ``application/octet-stream``) are ignored.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``. Without a
        usable hint the MIME type is sniffed from the leading magic bytes
        (PNG, JPEG, GIF, BMP, WebP) and defaults to ``image/jpeg`` when no
        signature matches.
    """
    # Normalise the hint so the resulting prefix is a lower-case "data:image/..."
    hint = (mime_hint or "").split(";", 1)[0].strip().lower()
    usable_hint = hint.startswith("image/") and len(hint) > len("image/")

    if usable_hint:
        mime = hint
    else:
        # Only the first 12 bytes are needed for every signature checked here.
        head = bytes(image_bytes[:12])
        if head.startswith(b"\x89PNG\r\n\x1a\n"):
            mime = "image/png"
        elif head.startswith(b"\xff\xd8\xff"):
            mime = "image/jpeg"
        elif head.startswith((b"GIF87a", b"GIF89a")):
            mime = "image/gif"
        elif head.startswith(b"BM"):
            mime = "image/bmp"
        elif len(head) >= 12 and head[:4] == b"RIFF" and head[8:12] == b"WEBP":
            mime = "image/webp"
        else:
            mime = "image/jpeg"

    payload = base64.standard_b64encode(image_bytes).decode("ascii")
    return f"data:{mime};base64,{payload}"
|
||
|
||
|
||
class VisionService:
    """Vision-model service built on the OpenAI SDK ``chat.completions`` API.

    Images are sent as ``image_url`` content parts, given either as an
    http(s) URL or as a base64 ``data:image/...`` URL. The public methods
    return ``""`` (after logging) when the URL is rejected or when the API
    call raises, instead of propagating the error.
    """

    # Prompt used by the description methods when the caller supplies none.
    _DEFAULT_PROMPT = "图中的主要内容是什么?回答以'图片'开头, 500字以内"

    # Cached clients and the (api_key, base_url) signature each one was built
    # from; a client is rebuilt whenever the current signature differs.
    _client_cache: Optional[AsyncOpenAI] = None
    _sync_client_cache: Optional[OpenAI] = None
    _async_client_sig: Optional[tuple[str, str]] = None
    _sync_client_sig: Optional[tuple[str, str]] = None

    @staticmethod
    def _build_messages(image_url: str, text: str) -> list[dict]:
        """Build the single user message: the image part followed by the text part."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": text},
                ],
            }
        ]

    @classmethod
    async def _get_async_client(cls) -> AsyncOpenAI:
        """Return the cached async client, rebuilding it if the credentials changed.

        No ``asyncio.Lock`` is held here: nothing in this method awaits, so
        coroutines on one event loop cannot interleave inside it, and a
        class-level lock would have to be created at import time (which on
        Python < 3.10 ties it to whatever event loop exists at that moment).
        """
        sig = _vision_client_signature()
        if cls._client_cache is None or cls._async_client_sig != sig:
            # Record the signature only after construction succeeds, so a
            # failed rebuild cannot leave the old client cached under the
            # new signature.
            cls._client_cache = AsyncOpenAI(api_key=sig[0], base_url=sig[1])
            cls._async_client_sig = sig
        return cls._client_cache

    @classmethod
    def _get_sync_client(cls) -> OpenAI:
        """Return the cached sync client, rebuilding it if the credentials changed.

        No locking is performed in this method.
        """
        sig = _vision_client_signature()
        if cls._sync_client_cache is None or cls._sync_client_sig != sig:
            # Same ordering as the async variant: store the signature last.
            cls._sync_client_cache = OpenAI(api_key=sig[0], base_url=sig[1])
            cls._sync_client_sig = sig
        return cls._sync_client_cache

    @classmethod
    async def get_image_description(
        cls,
        image_url: str,
        prompt: str = _DEFAULT_PROMPT,
    ) -> str:
        """Describe an image (asynchronous).

        Args:
            image_url: An ``http://`` / ``https://`` URL, or a base64
                ``data:image/...`` URL.
            prompt: Instruction text sent to the model alongside the image.

        Returns:
            The description text, or ``""`` when the URL is rejected, the
            request raises, or the model returns no content.
        """
        if not _is_vision_image_url(image_url):
            logger.warning(f"无效的图片 URL: {image_url[:80] if image_url else ''}")
            return ""

        try:
            client = await cls._get_async_client()
            completion = await client.chat.completions.create(
                model=VISION_CHAT_MODEL,
                messages=cls._build_messages(image_url, prompt),
            )
            description = completion.choices[0].message.content or ""
            if description:
                logger.info(f"成功获取图片描述: {description[:50]}...")
            return description
        except Exception as e:
            # Best-effort API: report the failure and return an empty result.
            logger.error(f"获取图片描述失败: {e}")
            return ""

    @classmethod
    async def get_image_description_from_bytes(
        cls,
        image_bytes: bytes,
        prompt: str = _DEFAULT_PROMPT,
        mime_hint: Optional[str] = None,
    ) -> str:
        """Describe an in-memory image (asynchronous) via a base64 data URL.

        Intended for cases where no publicly reachable URL exists, such as
        knowledge-graph uploads.

        Args:
            image_bytes: Raw bytes of the image.
            prompt: Instruction text sent to the model alongside the image.
            mime_hint: Optional MIME type passed to ``image_bytes_to_data_url``.

        Returns:
            The description text, or ``""`` when no API key is configured,
            ``image_bytes`` is empty, or the underlying call yields nothing.
        """
        if not _vision_openai_api_key():
            logger.warning("未配置 ZL_OPENAI_API_KEY 或 DASHSCOPE_API_KEY,无法进行视觉理解")
            return ""
        if not image_bytes:
            return ""
        data_url = image_bytes_to_data_url(image_bytes, mime_hint)
        return await cls.get_image_description(data_url, prompt=prompt)

    @classmethod
    def get_image_description_sync(
        cls,
        image_url: str,
        prompt: str = _DEFAULT_PROMPT,
    ) -> str:
        """Describe an image (synchronous).

        Args:
            image_url: An ``http://`` / ``https://`` URL, or a base64
                ``data:image/...`` URL.
            prompt: Instruction text sent to the model alongside the image.

        Returns:
            The description text, or ``""`` when the URL is rejected, the
            request raises, or the model returns no content.
        """
        if not _is_vision_image_url(image_url):
            logger.warning(f"无效的图片 URL: {image_url[:80] if image_url else ''}")
            return ""

        try:
            client = cls._get_sync_client()
            completion = client.chat.completions.create(
                model=VISION_CHAT_MODEL,
                messages=cls._build_messages(image_url, prompt),
            )
            description = completion.choices[0].message.content or ""
            if description:
                logger.info(f"成功获取图片描述: {description[:50]}...")
            return description
        except Exception as e:
            # Best-effort API: report the failure and return an empty result.
            logger.error(f"获取图片描述失败: {e}")
            return ""

    @classmethod
    async def analyze_image_with_question(
        cls,
        image_url: str,
        question: str,
    ) -> str:
        """Answer a question about an image (asynchronous).

        Args:
            image_url: An ``http://`` / ``https://`` URL, or a base64
                ``data:image/...`` URL.
            question: The user's question about the image.

        Returns:
            The model's answer, or ``""`` when the URL is rejected, the
            request raises, or the model returns no content.
        """
        if not _is_vision_image_url(image_url):
            logger.warning(f"无效的图片 URL: {image_url[:80] if image_url else ''}")
            return ""

        try:
            client = await cls._get_async_client()
            completion = await client.chat.completions.create(
                model=VISION_CHAT_MODEL,
                messages=cls._build_messages(image_url, question),
            )
            answer = completion.choices[0].message.content or ""
            if answer:
                logger.info("成功分析图片并回答问题")
            return answer
        except Exception as e:
            # Best-effort API: report the failure and return an empty result.
            logger.error(f"分析图片失败: {e}")
            return ""
|
||
|
||
|
||
# Batch-processing helper
async def batch_get_image_descriptions(
    image_urls: list[str],
    prompt: str = "图中的主要内容是什么?回答以'图片'开头, 500字以内",
) -> dict[str, str]:
    """Fetch descriptions for several images concurrently.

    Args:
        image_urls: Image URLs to describe. Duplicates are requested once.
        prompt: Instruction text sent to the model alongside each image.

    Returns:
        A mapping from each distinct URL to its description. A URL whose
        request ended in an exception (including cancellation) maps to ``""``.
    """
    # The result is keyed by URL, so repeated URLs would collapse into one
    # entry anyway; request each distinct URL once, keeping first-seen order.
    unique_urls = list(dict.fromkeys(image_urls))

    outcomes = await asyncio.gather(
        *(VisionService.get_image_description(url, prompt) for url in unique_urls),
        return_exceptions=True,
    )

    result: dict[str, str] = {}
    for url, outcome in zip(unique_urls, outcomes):
        # gather(return_exceptions=True) can also hand back CancelledError,
        # which derives from BaseException rather than Exception.
        if isinstance(outcome, BaseException):
            logger.error(f"获取图片描述失败 {url}: {outcome}")
            result[url] = ""
        else:
            result[url] = outcome
    return result
|