259 lines
7.2 KiB
Python
259 lines
7.2 KiB
Python
"""
|
||
通用工具函数模块
|
||
|
||
提供 API 响应模型、HTTP 配置、线程池等通用工具。
|
||
"""
|
||
import os
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from pathlib import Path
|
||
from typing import Any, Callable, Dict, Generator, List, Optional, Union
|
||
from urllib.parse import urlparse
|
||
|
||
import httpx
|
||
from fastapi import FastAPI
|
||
from pydantic import BaseModel, ConfigDict, Field
|
||
|
||
from core.config import settings
|
||
|
||
|
||
def get_base_url(url: str) -> str:
    """
    Return the base of *url*: its scheme plus network location.

    Args:
        url: A full URL, e.g. ``"https://host:8000/path?q=1"``.

    Returns:
        ``"scheme://netloc"`` with all trailing slashes removed,
        e.g. ``"https://host:8000"``.
    """
    parts = urlparse(url)
    # Build "scheme://netloc/" and then trim every trailing "/".
    return f"{parts.scheme}://{parts.netloc}/".rstrip("/")
|
||
|
||
|
||
class MsgType:
    """Message type constants (plain ints, not an Enum)."""

    # text, image, audio, video
    TEXT, IMAGE, AUDIO, VIDEO = 1, 2, 3, 4
|
||
|
||
|
||
class BaseResponse(BaseModel):
    """Base response model for API endpoints: status code, message and payload."""

    code: int = Field(200, description="API status code")
    msg: str = Field("success", description="API status message")
    data: Any = Field(None, description="API data")

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # The example is only used for the generated OpenAPI/JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "code": 200,
                "msg": "success",
            }
        }
    )
|
||
|
||
|
||
class ListResponse(BaseResponse):
    """List response model; ``data`` is a required list."""

    data: List[Any] = Field(..., description="List of data")

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # The example is only used for the generated OpenAPI/JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }
    )
|
||
|
||
|
||
def api_address(is_public: bool = False) -> str:
    """
    Return the API server address.

    Args:
        is_public: Whether the public address is wanted.
            NOTE(review): this flag is currently ignored; both values
            return ``settings.api_address``.

    Returns:
        The value of ``settings.api_address``.
    """
    address = settings.api_address
    return address
|
||
|
||
|
||
def set_httpx_config(
    timeout: Optional[float] = None,
    proxy: Union[str, Dict, None] = None,
    unused_proxies: Optional[List[str]] = None,
):
    """
    Configure process-wide httpx defaults and proxy-related environment variables.

    Sets httpx's default connect/read/write timeout, exports the given proxy
    through the ``<scheme>_proxy`` environment variables, and adds local
    services (plus ``unused_proxies``) to the no-proxy list.

    Args:
        timeout: Timeout in seconds; ``None`` uses ``settings.httpx_default_timeout``.
        proxy: Either a single proxy URL used for ``http``, ``https`` and ``all``,
            or a dict keyed by ``"http"``/``"https"``/``"all"`` (with
            ``"<scheme>_proxy"`` keys accepted as a fallback). ``None`` sets no
            proxy variables.
        unused_proxies: Addresses that must bypass the proxy. Only the text
            before the second ``":"`` is kept, so ``"http://127.0.0.1:7861"``
            becomes ``"http://127.0.0.1"``.

    Side effects:
        Mutates ``httpx._config.DEFAULT_TIMEOUT_CONFIG`` and ``os.environ``, and
        replaces ``urllib.request.getproxies`` for the whole process.
    """
    import urllib.request

    if timeout is None:
        timeout = settings.httpx_default_timeout

    # Patch httpx's private module-level default timeout (pool timeout unchanged).
    default_timeout = httpx._config.DEFAULT_TIMEOUT_CONFIG
    default_timeout.connect = timeout
    default_timeout.read = timeout
    default_timeout.write = timeout

    # Proxy settings keyed like environment variables, e.g. {"http_proxy": "..."}.
    proxies: Dict[str, str] = {}
    if isinstance(proxy, str):
        for scheme in ("http", "https", "all"):
            proxies[scheme + "_proxy"] = proxy
    elif isinstance(proxy, dict):
        for scheme in ("http", "https", "all"):
            if p := proxy.get(scheme):
                proxies[scheme + "_proxy"] = p
            elif p := proxy.get(scheme + "_proxy"):
                proxies[scheme + "_proxy"] = p

    for key, value in proxies.items():
        os.environ[key] = value

    # Merge both spellings of any existing no-proxy variable with the local
    # services and caller-supplied hosts, keeping order and dropping duplicates.
    candidates = [
        *os.environ.get("no_proxy", "").split(","),
        *os.environ.get("NO_PROXY", "").split(","),
        "http://127.0.0.1",
        "http://localhost",
        *(":".join(x.split(":")[:2]) for x in (unused_proxies or [])),
    ]
    no_proxy: List[str] = []
    for entry in candidates:
        entry = entry.strip()
        if entry and entry not in no_proxy:
            no_proxy.append(entry)

    no_proxy_value = ",".join(no_proxy)
    # urllib's environment proxy lookup prefers the lowercase variable when both
    # exist, so writing only NO_PROXY could be silently ignored; set both.
    os.environ["NO_PROXY"] = no_proxy_value
    os.environ["no_proxy"] = no_proxy_value

    # urllib.request.getproxies() must map scheme names ("http", "https", "all")
    # to proxy URLs, so strip the "_proxy" suffix used by the env-var keys.
    scheme_proxies = {key[: -len("_proxy")]: value for key, value in proxies.items()}

    def _get_proxies() -> Dict[str, str]:
        # Return a copy so callers cannot mutate the captured mapping.
        return dict(scheme_proxies)

    urllib.request.getproxies = _get_proxies
|
||
|
||
|
||
def run_in_thread_pool(
    func: Callable,
    params: Optional[List[Dict]] = None,
) -> Generator:
    """
    Run ``func`` in a thread pool, once per keyword-argument dict.

    Results are yielded in completion order, not submission order. A task that
    raises is logged (with traceback) and skipped, so one failure does not stop
    the remaining results from being yielded.

    Args:
        func: Callable to execute.
        params: Keyword-argument dicts; one task is submitted per dict.
            ``None`` (the default) means no tasks.

    Yields:
        The return value of each task that completed without raising.
    """
    import logging

    logger = logging.getLogger(__name__)
    with ThreadPoolExecutor() as pool:
        tasks = [pool.submit(func, **kwargs) for kwargs in (params or [])]
        for future in as_completed(tasks):
            # Keep the try narrow: only the task's own exception is handled here,
            # not exceptions thrown into this generator at the yield point.
            try:
                result = future.result()
            except Exception as e:
                logger.error("error in sub thread: %s", e, exc_info=True)
                continue
            yield result
|
||
|
||
|
||
def get_server_configs() -> Dict:
    """Collect the server configuration values that are exposed to the frontend."""
    configs = dict(api_address=api_address())
    return configs
|
||
|
||
|
||
def make_fastapi_offline(
    app: FastAPI,
    static_dir: Path = Path(__file__).resolve().parent.parent / "static" / "api_server",
    static_url: str = "/static-offline-docs",
    docs_url: Optional[str] = "/docs",
    redoc_url: Optional[str] = "/redoc",
) -> None:
    """
    Serve FastAPI's Swagger UI and ReDoc pages from local static files.

    Replaces the existing ``docs_url`` / ``redoc_url`` routes with versions that
    load ``swagger-ui-bundle.js``, ``swagger-ui.css``, ``redoc.standalone.js``
    and ``favicon.png`` from ``static_url`` instead of a CDN, so the API docs
    work offline.

    Args:
        app: FastAPI application to patch.
        static_dir: Directory holding the static doc assets; it is mounted at
            ``static_url`` only if it exists.
        static_url: URL prefix under which ``static_dir`` is mounted.
        docs_url: Swagger UI path, or ``None`` to leave Swagger UI untouched.
        redoc_url: ReDoc path, or ``None`` to leave ReDoc untouched.
    """
    from fastapi import Request
    from fastapi.openapi.docs import (
        get_redoc_html,
        get_swagger_ui_html,
        get_swagger_ui_oauth2_redirect_html,
    )
    from fastapi.staticfiles import StaticFiles
    from starlette.responses import HTMLResponse

    openapi_url = app.openapi_url
    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url

    def remove_route(url: Optional[str]) -> None:
        """Remove the first route whose path matches *url* (case-insensitive)."""
        if url is None:
            return
        for i, route in enumerate(app.routes):
            # Some route types (e.g. host-based routes) have no ``path``.
            if getattr(route, "path", "").lower() == url.lower():
                app.routes.pop(i)
                return

    def _root_path(request: Request) -> str:
        """ASGI root path without a trailing slash ("" when not set)."""
        return (request.scope.get("root_path") or "").rstrip("/")

    # Mount the static assets.
    if static_dir.exists():
        app.mount(
            static_url,
            StaticFiles(directory=str(static_dir)),
            name="static-offline-docs",
        )

    # Without an OpenAPI schema URL the doc pages would have nothing to render.
    if openapi_url is None:
        return

    if docs_url is not None:
        remove_route(docs_url)
        remove_route(swagger_ui_oauth2_redirect_url)

        @app.get(docs_url, include_in_schema=False)
        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
            root = _root_path(request)
            # Prefix the redirect URL with the root path, as FastAPI's own docs do.
            oauth2_redirect_url = swagger_ui_oauth2_redirect_url
            if oauth2_redirect_url:
                oauth2_redirect_url = root + oauth2_redirect_url
            return get_swagger_ui_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - Swagger UI",
                oauth2_redirect_url=oauth2_redirect_url,
                init_oauth=app.swagger_ui_init_oauth,
                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
                swagger_favicon_url=f"{root}{static_url}/favicon.png",
                swagger_ui_parameters=app.swagger_ui_parameters,
            )

        if swagger_ui_oauth2_redirect_url:

            @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
            async def swagger_ui_redirect() -> HTMLResponse:
                return get_swagger_ui_oauth2_redirect_html()

    if redoc_url is not None:
        remove_route(redoc_url)

        @app.get(redoc_url, include_in_schema=False)
        async def redoc_html(request: Request) -> HTMLResponse:
            root = _root_path(request)
            return get_redoc_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - ReDoc",
                redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
                with_google_fonts=False,
                redoc_favicon_url=f"{root}{static_url}/favicon.png",
            )
|