"""
OSS 文件存储服务 (OSS file storage service).
"""
|
||
import os
|
||
import time
|
||
import tempfile
|
||
from typing import Optional
|
||
from pathlib import Path
|
||
import oss2
|
||
from oss2 import SizedFileAdapter, determine_part_size
|
||
from oss2.models import PartInfo
|
||
|
||
from core.config import settings
|
||
from logger.logging import get_logger
|
||
|
||
# Module-level logger used by OSSService.
logger = get_logger(__name__)
|
||
|
||
|
||
class OSSService:
    """OSS file storage service.

    When any of the four OSS settings is missing, or the client cannot be
    created, the instance ends up with ``enabled = False`` and every public
    operation becomes a logged no-op that returns ``None`` / ``False`` / ``""``.
    """

    # upload_file() switches to multipart upload above this size (bytes).
    FILE_MULTIPART_THRESHOLD = 100 * 1024 * 1024
    # upload_file_from_bytes() switches to multipart upload above this size (bytes).
    BYTES_MULTIPART_THRESHOLD = 1 * 1024 * 1024
    # Preferred size of one multipart part (bytes). Parts are uploaded one after
    # another, so a tiny part size (previously 100 KiB) turned every upload into
    # hundreds of sequential requests; larger parts keep the request count low.
    MULTIPART_PART_SIZE = 10 * 1024 * 1024

    def __init__(self):
        """Create the OSS client from ``settings``; disable the service on any problem."""
        # Read credentials and location from configuration.
        self.access_key_id = settings.oss_access_key_id
        self.access_key_secret = settings.oss_access_key_secret
        self.endpoint = settings.oss_endpoint
        self.bucket_name = settings.oss_bucket_name
        # Defined up front so the attribute exists even when the service is disabled.
        self.bucket = None

        # All four settings are required.
        if not all([self.access_key_id, self.access_key_secret, self.endpoint, self.bucket_name]):
            logger.warning("OSS 配置不完整,将使用本地存储")
            self.enabled = False
            self.external_endpoint = ""
            return

        # Public endpoint, used when building URLs handed out to clients.
        self.external_endpoint = self._get_external_endpoint(self.endpoint)
        if self.endpoint != self.external_endpoint:
            logger.info(f"端点转换: {self.endpoint} -> {self.external_endpoint}")

        try:
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            # oss2.Bucket only takes connect_timeout here; a read timeout would
            # require passing a custom requests.Session via the session argument.
            self.bucket = oss2.Bucket(
                auth,
                self.endpoint,
                self.bucket_name,
                connect_timeout=10,  # connection timeout in seconds
            )
            self.enabled = True
            logger.info(f"OSS 服务初始化成功,Bucket: {self.bucket_name}")
            logger.info(f" 上传端点: {self.endpoint}")
            logger.info(f" 访问端点: {self.external_endpoint}")

            if "internal" not in self.endpoint:
                # Internal endpoints carry "-internal" right before the domain,
                # e.g. oss-cn-hangzhou-internal.aliyuncs.com (not
                # oss-cn-hangzhou.internal.aliyuncs.com).
                suggested = self.endpoint.replace('.aliyuncs.com', '-internal.aliyuncs.com', 1)
                logger.warning(
                    "未使用内网 Endpoint。如果服务器在阿里云 ECS 上,"
                    f"建议使用内网 endpoint: {suggested}"
                )
        except Exception as e:
            logger.error(f"OSS 服务初始化失败: {e}")
            self.enabled = False

    def _get_external_endpoint(self, endpoint: str) -> str:
        """Convert an internal OSS endpoint to its public counterpart.

        Conversion rules:
            - oss-cn-hangzhou-internal.aliyuncs.com -> oss-cn-hangzhou.aliyuncs.com
            - https://oss-cn-hangzhou-internal.aliyuncs.com -> https://oss-cn-hangzhou.aliyuncs.com
            - an endpoint without "-internal." is returned unchanged

        Only the first "-internal" that ends a host label (i.e. is followed by
        a dot) is removed, so unrelated text such as "my-internal-gw" is kept.

        Args:
            endpoint: Original endpoint, with or without a scheme.

        Returns:
            str: Public endpoint; "" when ``endpoint`` is empty; the input
            unchanged when it is not a string.
        """
        if not endpoint:
            logger.warning("端点为空,返回空字符串")
            return ""

        try:
            return endpoint.replace("-internal.", ".", 1)
        except (AttributeError, TypeError) as e:
            # Best effort: a non-string endpoint is passed through untouched.
            logger.error(f"端点转换失败: {e},使用原端点")
            return endpoint

    def upload_file(
        self,
        local_file_path: str,
        oss_object_name: str,
        use_multipart: bool = True
    ) -> Optional[str]:
        """Upload a local file to OSS.

        Args:
            local_file_path: Path of the local file.
            oss_object_name: Target OSS object name (storage path).
            use_multipart: Allow multipart upload for files larger than
                ``FILE_MULTIPART_THRESHOLD``.

        Returns:
            Optional[str]: Public URL of the object, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None

        # isfile() rather than exists(): a directory cannot be uploaded.
        if not os.path.isfile(local_file_path):
            logger.error(f"文件不存在: {local_file_path}")
            return None

        try:
            file_size = os.path.getsize(local_file_path)

            if use_multipart and file_size > self.FILE_MULTIPART_THRESHOLD:
                logger.info(f"文件较大 ({file_size} 字节),使用分片上传")
                success = self._multipart_upload(local_file_path, oss_object_name, file_size)
            else:
                logger.info(f"使用简单上传: {oss_object_name}")
                result = self.bucket.put_object_from_file(oss_object_name, local_file_path)
                success = result.status == 200

            if not success:
                logger.error(f"文件上传失败: {oss_object_name}")
                return None

            file_url = self.get_file_url(oss_object_name)
            logger.info(f"文件上传成功: {local_file_path} -> {oss_object_name}")
            logger.info(f"OSS URL: {file_url}")
            return file_url
        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def upload_file_from_bytes(
        self,
        file_content: bytes,
        oss_object_name: str,
        file_name: Optional[str] = None
    ) -> Optional[str]:
        """Upload in-memory content to OSS.

        Payloads larger than ``BYTES_MULTIPART_THRESHOLD`` use multipart upload,
        streamed straight from memory.

        Args:
            file_content: File content.
            oss_object_name: Target OSS object name (storage path).
            file_name: Display name used in log messages only.

        Returns:
            Optional[str]: Public URL of the object, or None on failure.
        """
        import io

        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None

        try:
            file_size = len(file_content)
            # Monotonic clock: immune to wall-clock adjustments while timing.
            start_time = time.monotonic()

            if file_size > self.BYTES_MULTIPART_THRESHOLD:
                logger.info(f"文件大小 {file_size/1024/1024:.2f}MB,使用分片上传")
                # Read parts from memory instead of a temporary file: no extra
                # disk I/O and no temp file that could be left behind on error.
                uploaded = self._multipart_upload_stream(
                    io.BytesIO(file_content), oss_object_name, file_size
                )
                if not uploaded:
                    logger.error(f"分片上传失败: {oss_object_name}")
                    return None
            else:
                result = self.bucket.put_object(oss_object_name, file_content)
                if result.status != 200:
                    logger.error(f"文件上传失败: {oss_object_name}, 状态码: {result.status}")
                    return None

            elapsed = time.monotonic() - start_time
            speed_mbps = (file_size / 1024 / 1024) / elapsed if elapsed > 0 else 0

            file_url = self.get_file_url(oss_object_name)
            logger.info(
                f"文件上传成功: {file_name or oss_object_name} -> {oss_object_name}, "
                f"大小: {file_size/1024/1024:.2f}MB, 耗时: {elapsed:.2f}s, 速度: {speed_mbps:.2f}MB/s"
            )

            # Warn about slow uploads of non-trivial payloads.
            if speed_mbps < 0.5 and file_size > 1024 * 1024:
                logger.warning(
                    f"上传速度较慢 ({speed_mbps:.2f}MB/s),建议检查: "
                    "1) 是否使用内网 endpoint 2) 服务器与 OSS 是否在同一区域"
                )

            return file_url
        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def download_file(
        self,
        oss_object_name: str,
        local_file_path: Optional[str] = None
    ) -> Optional[str]:
        """Download an OSS object to a local file.

        Args:
            oss_object_name: OSS object name.
            local_file_path: Destination path. When None, the file is written to
                ``<system temp dir>/oss_download_<object base name>``.

        Returns:
            Optional[str]: The local path written, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法下载")
            return None

        try:
            if local_file_path is None:
                file_name = Path(oss_object_name).name
                local_file_path = os.path.join(tempfile.gettempdir(), f"oss_download_{file_name}")

            # A bare file name has no directory part; os.makedirs("") would raise.
            parent_dir = os.path.dirname(local_file_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            self.bucket.get_object_to_file(oss_object_name, local_file_path)
            logger.info(f"文件下载成功: {oss_object_name} -> {local_file_path}")
            return local_file_path
        except Exception as e:
            logger.error(f"从 OSS 下载文件失败: {e}")
            return None

    def delete_file(self, oss_object_name: str) -> bool:
        """Delete an OSS object.

        Args:
            oss_object_name: OSS object name.

        Returns:
            bool: True if the delete call succeeded, False otherwise.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过删除")
            return False

        try:
            self.bucket.delete_object(oss_object_name)
            logger.info(f"OSS 文件删除成功: {oss_object_name}")
            return True
        except Exception as e:
            logger.error(f"删除 OSS 文件失败: {e}")
            return False

    def get_file_url(self, oss_object_name: str) -> str:
        """Build the public URL of an object on the external endpoint.

        Format: ``https://{bucket_name}.{endpoint_domain}/{object_name}``.

        Args:
            oss_object_name: OSS object name.

        Returns:
            str: Object URL, or "" when the service is disabled.
        """
        if not self.enabled:
            return ""

        # Drop any scheme/trailing slash from the endpoint; the URL is always https.
        endpoint_domain = self.external_endpoint.replace('https://', '').replace('http://', '').rstrip('/')
        base_url = f"https://{self.bucket_name}.{endpoint_domain}"
        return f"{base_url}/{oss_object_name}"

    def get_signed_url(self, oss_object_name: str, expires: int = 3600) -> Optional[str]:
        """Generate a temporary signed GET URL (for private buckets).

        The URL is signed against the external endpoint so it is reachable from
        outside the internal network.

        Args:
            oss_object_name: OSS object name.
            expires: Validity in seconds (default 3600).

        Returns:
            Optional[str]: Signed URL, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法生成签名 URL")
            return None

        try:
            # A separate Bucket bound to the external endpoint, so the signed
            # URL's host is the public one.
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            external_bucket = oss2.Bucket(
                auth,
                self.external_endpoint,
                self.bucket_name,
                connect_timeout=10
            )
            signed_url = external_bucket.sign_url('GET', oss_object_name, expires)
            logger.debug(f"生成签名 URL 成功: {oss_object_name},有效期: {expires}秒")
            return signed_url
        except Exception as e:
            logger.error(f"生成签名 URL 失败: {e}")
            return None

    def extract_object_name_from_url(self, url: str, kb_id: int = None, thread_id: str = None) -> Optional[str]:
        """Extract the OSS object name from an object URL.

        Only the URL path is inspected (percent-decoded), so the query string of
        a signed URL never ends up in the result. The object name is the path
        from the first ``kb_<kb_id>`` / ``thread_<thread_id>`` segment onward;
        failing that, from the first segment starting with ``kb_`` or
        ``thread_``.

        Args:
            url: OSS URL (plain or signed).
            kb_id: Knowledge-base ID (optional, for knowledge-base files).
            thread_id: Conversation thread ID (optional, for chat files).

        Returns:
            Optional[str]: Object name, or None if it cannot be determined.
        """
        if not self.enabled or not url:
            return None

        try:
            from urllib.parse import unquote, urlparse

            path = unquote(urlparse(url).path)
            path_parts = path.strip('/').split('/')

            # Exact segment match on the supplied IDs first (kb before thread).
            preferred_segments = []
            if kb_id is not None:
                preferred_segments.append(f"kb_{kb_id}")
            if thread_id:
                preferred_segments.append(f"thread_{thread_id}")
            for segment in preferred_segments:
                if segment in path_parts:
                    return '/'.join(path_parts[path_parts.index(segment):])

            # Fallback: first segment that looks like a kb_/thread_ prefix.
            for index, part in enumerate(path_parts):
                if part.startswith(('kb_', 'thread_')):
                    return '/'.join(path_parts[index:])

            return None
        except Exception as e:
            logger.error(f"从 URL 提取对象名称失败: {e}")
            return None

    def _multipart_upload(
        self,
        local_file_path: str,
        oss_object_name: str,
        file_size: int
    ) -> bool:
        """Multipart-upload a local file.

        The file is opened before any upload is initiated, so an unreadable file
        does not leave an orphaned multipart upload behind.

        Args:
            local_file_path: Path of the local file.
            oss_object_name: Target OSS object name.
            file_size: File size in bytes.

        Returns:
            bool: True on success.
        """
        try:
            with open(local_file_path, 'rb') as f:
                return self._multipart_upload_stream(f, oss_object_name, file_size)
        except OSError as e:
            logger.error(f"分片上传失败: {e}")
            return False

    def _multipart_upload_stream(
        self,
        fileobj,
        oss_object_name: str,
        file_size: int,
        part_size: Optional[int] = None
    ) -> bool:
        """Multipart-upload ``file_size`` bytes read from a seekable binary stream.

        On any failure after the upload has been initiated, the upload is aborted
        so its already-uploaded parts are not left behind in the bucket.

        Args:
            fileobj: Seekable binary file object positioned anywhere.
            oss_object_name: Target OSS object name.
            file_size: Number of bytes to upload.
            part_size: Preferred part size; defaults to ``MULTIPART_PART_SIZE``.

        Returns:
            bool: True on success.
        """
        upload_id = None
        try:
            part_size = determine_part_size(
                file_size, preferred_size=part_size or self.MULTIPART_PART_SIZE
            )
            upload_id = self.bucket.init_multipart_upload(oss_object_name).upload_id

            num_parts = (file_size + part_size - 1) // part_size  # ceiling division
            logger.info(f"开始分片上传,共 {num_parts} 个分片...")

            parts = []
            for index in range(num_parts):
                part_number = index + 1  # OSS part numbers start at 1
                start = index * part_size
                fileobj.seek(start)
                data = fileobj.read(min(part_size, file_size - start))

                result = self.bucket.upload_part(oss_object_name, upload_id, part_number, data)
                parts.append(PartInfo(part_number, result.etag))

                progress = part_number / num_parts * 100
                logger.debug(f"上传进度: {progress:.1f}% ({part_number}/{num_parts})")

            self.bucket.complete_multipart_upload(oss_object_name, upload_id, parts)
            logger.info(f"分片上传完成: {oss_object_name}")
            return True
        except Exception as e:
            logger.error(f"分片上传失败: {e}")
            if upload_id is not None:
                self._abort_multipart_upload(oss_object_name, upload_id)
            return False

    def _abort_multipart_upload(self, oss_object_name: str, upload_id: str) -> None:
        """Best-effort abort of an unfinished multipart upload (errors are only logged)."""
        try:
            self.bucket.abort_multipart_upload(oss_object_name, upload_id)
            logger.info(f"已取消未完成的分片上传: {oss_object_name}")
        except Exception as e:
            logger.warning(f"取消分片上传失败: {e}")

    def file_exists(self, oss_object_name: str) -> bool:
        """Check whether an object exists.

        Args:
            oss_object_name: OSS object name.

        Returns:
            bool: True if it exists; False if it does not, the service is
            disabled, or the check failed.
        """
        if not self.enabled:
            return False

        try:
            return self.bucket.object_exists(oss_object_name)
        except Exception as e:
            logger.error(f"检查文件是否存在失败: {e}")
            return False
|
||
|
||
|
||
# Process-wide OSSService singleton, created lazily by get_oss_service().
_oss_service: Optional[OSSService] = None


def get_oss_service() -> OSSService:
    """Return the shared OSSService, constructing it on the first call."""
    global _oss_service
    if _oss_service is not None:
        return _oss_service
    _oss_service = OSSService()
    return _oss_service
|
||
|