huoyan-enterprise/backend/utils/helpers.py

259 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
通用工具函数模块
提供 API 响应模型、HTTP 配置、线程池等通用工具。
"""
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from urllib.parse import urlparse
import httpx
from fastapi import FastAPI
from pydantic import BaseModel, Field
from core.config import settings
def get_base_url(url: str) -> str:
    """
    Extract the base URL (scheme + netloc) from a full URL.

    Args:
        url: The complete URL to reduce.

    Returns:
        ``"<scheme>://<netloc>"`` with trailing slashes stripped; the path,
        query string and fragment are dropped.
    """
    parts = urlparse(url)
    # Trailing slashes are stripped from the result, so an empty netloc
    # collapses "scheme://" down to "scheme:".
    return f"{parts.scheme}://{parts.netloc}".rstrip("/")
class MsgType:
    """Integer constants identifying the kind of a message."""

    # Plain ints, numbered consecutively starting at 1.
    TEXT, IMAGE, AUDIO, VIDEO = range(1, 5)
class BaseResponse(BaseModel):
    """Base response envelope for API endpoints: status code, message, payload."""

    code: int = Field(default=200, description="API status code")
    msg: str = Field(default="success", description="API status message")
    data: Any = Field(default=None, description="API data")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {"example": {"code": 200, "msg": "success"}}
class ListResponse(BaseResponse):
    """Response envelope whose ``data`` field is a required list."""

    # ``...`` marks the field as required (no default value).
    data: List[Any] = Field(..., description="List of data")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            },
        }
def api_address(is_public: bool = False) -> str:
    """
    Return the API server address.

    Args:
        is_public: Whether the public address is wanted. The flag is not
            consulted at present, so the same configured value is returned
            either way.

    Returns:
        The value of ``settings.api_address``.
    """
    address = settings.api_address
    return address
def set_httpx_config(
    timeout: Optional[float] = None,
    proxy: Union[str, Dict, None] = None,
    unused_proxies: Optional[List[str]] = None,
):
    """
    Configure process-wide httpx timeouts and proxy environment variables.

    Sets the default httpx connect/read/write timeouts, exports the proxy
    settings as ``http_proxy`` / ``https_proxy`` / ``all_proxy`` environment
    variables, and extends the proxy bypass list (``NO_PROXY``) with the
    local loopback addresses plus the given service addresses.

    Args:
        timeout: Timeout in seconds. Falls back to
            ``settings.httpx_default_timeout`` when ``None``.
        proxy: Either a single proxy URL applied to http, https and all, or
            a dict keyed by ``"http"`` / ``"https"`` / ``"all"`` (the
            ``"<scheme>_proxy"`` spelling is accepted as a fallback key).
        unused_proxies: Addresses that must bypass the proxy. Only the first
            two colon-separated parts of each entry are kept, which drops
            the port from ``"scheme://host:port"``. ``None`` means no extra
            addresses.
    """
    import urllib.request

    if timeout is None:
        timeout = settings.httpx_default_timeout
    # NOTE: ``httpx._config`` is a private module; this mutates the shared
    # default timeout object in place.
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

    # Export the proxy settings as process-level environment variables.
    proxies = {}
    if isinstance(proxy, str):
        for n in ("http", "https", "all"):
            proxies[n + "_proxy"] = proxy
    elif isinstance(proxy, dict):
        for n in ("http", "https", "all"):
            if p := proxy.get(n):
                proxies[n + "_proxy"] = p
            elif p := proxy.get(n + "_proxy"):
                proxies[n + "_proxy"] = p
    for k, v in proxies.items():
        os.environ[k] = v

    # Build the bypass list. Both spellings of the variable are read so an
    # existing NO_PROXY value is merged rather than overwritten.
    candidates = []
    for name in ("no_proxy", "NO_PROXY"):
        candidates += [x.strip() for x in os.environ.get(name, "").split(",")]
    candidates += [
        "http://127.0.0.1",
        "http://localhost",
    ]
    candidates += [":".join(x.split(":")[:2]) for x in (unused_proxies or [])]
    # Drop empty entries and duplicates while keeping first-seen order, so
    # repeated calls do not grow the variable.
    no_proxy = list(dict.fromkeys(host for host in candidates if host))

    joined = ",".join(no_proxy)
    os.environ["NO_PROXY"] = joined
    # urllib's environment lookup prefers the lowercase spelling, so keep a
    # pre-existing ``no_proxy`` in sync instead of letting it go stale.
    if "no_proxy" in os.environ:
        os.environ["no_proxy"] = joined

    def _get_proxies():
        return proxies

    # NOTE(review): the stdlib ``getproxies()`` returns a mapping keyed by
    # bare scheme ("http", "https"), whereas this replacement is keyed by
    # "<scheme>_proxy". Left unchanged to preserve behavior — confirm that
    # callers of ``urllib.request.getproxies`` expect this shape.
    urllib.request.getproxies = _get_proxies
def run_in_thread_pool(
    func: Callable,
    params: Optional[List[Dict]] = None,
) -> Generator:
    """
    Run ``func`` once per parameter set in a thread pool and yield results.

    Args:
        func: The callable to execute.
        params: List of keyword-argument dicts; each dict produces one call
            ``func(**kwargs)``. ``None`` is treated as an empty list.

    Yields:
        The return value of each call, in completion order (not submission
        order). A call that raises is reported on stdout and skipped.
    """
    if params is None:
        params = []
    with ThreadPoolExecutor() as pool:
        tasks = [pool.submit(func, **kwargs) for kwargs in params]
        for future in as_completed(tasks):
            # Only guard the result retrieval: an exception thrown into the
            # generator by the consumer must not be reported as a worker
            # failure and swallowed.
            try:
                result = future.result()
            except Exception as e:
                print(f"error in sub thread: {e}\n")
                continue
            yield result
def get_server_configs() -> Dict:
    """Return the server configuration exposed to the front end."""
    configs = {"api_address": api_address()}
    return configs
def make_fastapi_offline(
    app: FastAPI,
    static_dir: Path = Path(__file__).resolve().parent.parent / "static" / "api_server",
    static_url: str = "/static-offline-docs",
    docs_url: Optional[str] = "/docs",
    redoc_url: Optional[str] = "/redoc",
) -> None:
    """
    Configure offline FastAPI documentation pages.

    Uses local static files in place of a CDN so the API docs can be viewed
    offline: the existing Swagger UI / ReDoc routes are removed and
    re-registered with asset URLs that point under ``static_url``.

    Args:
        app: The FastAPI application instance to modify.
        static_dir: Directory holding the static documentation assets.
        static_url: URL prefix under which the static files are mounted.
        docs_url: Path of the Swagger UI page; ``None`` skips Swagger UI.
        redoc_url: Path of the ReDoc page; ``None`` skips ReDoc.
    """
    from fastapi import Request
    from fastapi.openapi.docs import (
        get_redoc_html,
        get_swagger_ui_html,
        get_swagger_ui_oauth2_redirect_html,
    )
    from fastapi.staticfiles import StaticFiles
    from starlette.responses import HTMLResponse

    # Capture the app's schema and OAuth2 redirect paths once; the route
    # handlers below close over these values.
    openapi_url = app.openapi_url
    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url

    def remove_route(url: str) -> None:
        """Remove the first registered route whose path matches ``url`` (case-insensitive)."""
        index = None
        for i, r in enumerate(app.routes):
            if r.path.lower() == url.lower():
                index = i
                break
        # ``index`` stays None when no route matched; nothing is removed then.
        if isinstance(index, int):
            app.routes.pop(index)

    # Mount the static files.
    # NOTE(review): the original indentation was lost in extraction; this
    # assumes only the mount is guarded by the existence check and the docs
    # routes below are registered regardless — confirm against the real file.
    if static_dir.exists():
        app.mount(
            static_url,
            StaticFiles(directory=str(static_dir)),
            name="static-offline-docs",
        )

    if docs_url is not None:
        # Drop the existing routes at these paths before registering the
        # replacements.
        remove_route(docs_url)
        remove_route(swagger_ui_oauth2_redirect_url)

        @app.get(docs_url, include_in_schema=False)
        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
            # Prefix asset URLs with the ASGI root_path taken from the request scope.
            root = request.scope.get("root_path")
            favicon = f"{root}{static_url}/favicon.png"
            return get_swagger_ui_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - Swagger UI",
                oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
                swagger_favicon_url=favicon,
            )

        @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
        async def swagger_ui_redirect() -> HTMLResponse:
            return get_swagger_ui_oauth2_redirect_html()

    if redoc_url is not None:
        remove_route(redoc_url)

        @app.get(redoc_url, include_in_schema=False)
        async def redoc_html(request: Request) -> HTMLResponse:
            # Prefix asset URLs with the ASGI root_path taken from the request scope.
            root = request.scope.get("root_path")
            favicon = f"{root}{static_url}/favicon.png"
            return get_redoc_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - ReDoc",
                redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
                # Do not load Google Fonts, which would need network access.
                with_google_fonts=False,
                redoc_favicon_url=favicon,
            )