huoyan-enterprise/backend/services/oss_service.py

486 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
OSS 文件存储服务
"""
import os
import time
import tempfile
from typing import Optional
from pathlib import Path
import oss2
from oss2 import SizedFileAdapter, determine_part_size
from oss2.models import PartInfo
from core.config import settings
from logger.logging import get_logger
logger = get_logger(__name__)
class OSSService:
    """Aliyun OSS file-storage service.

    Wraps an ``oss2.Bucket`` built from ``settings``. When the configuration is
    incomplete or client construction raises, ``enabled`` is False and every
    public method short-circuits (returning None / False / "").
    """

    def __init__(self):
        """Read the OSS settings and build the bucket client.

        ``enabled`` becomes True only when all four settings are non-empty and
        ``oss2.Bucket`` was constructed without raising.
        """
        # Read configuration.
        self.access_key_id = settings.oss_access_key_id
        self.access_key_secret = settings.oss_access_key_secret
        self.endpoint = settings.oss_endpoint
        self.bucket_name = settings.oss_bucket_name
        # Defined up front so the attribute exists even when OSS is disabled.
        self.bucket = None

        # All four values are required; otherwise run in disabled mode.
        if not all([self.access_key_id, self.access_key_secret, self.endpoint, self.bucket_name]):
            logger.warning("OSS 配置不完整,将使用本地存储")
            self.enabled = False
            self.external_endpoint = ""
            return

        # Public endpoint, used only for building URLs; uploads and downloads
        # go through self.endpoint (which may be an intranet endpoint).
        self.external_endpoint = self._get_external_endpoint(self.endpoint)
        if self.endpoint != self.external_endpoint:
            logger.info(f"端点转换: {self.endpoint} -> {self.external_endpoint}")

        try:
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            # oss2.Bucket only takes connect_timeout here; a read timeout would
            # have to be configured through a custom session object.
            self.bucket = oss2.Bucket(
                auth,
                self.endpoint,
                self.bucket_name,
                connect_timeout=10,  # connection timeout in seconds
            )
            self.enabled = True
            logger.info(f"OSS 服务初始化成功Bucket: {self.bucket_name}")
            logger.info(f" 上传端点: {self.endpoint}")
            logger.info(f" 访问端点: {self.external_endpoint}")
            # Hint at the intranet endpoint when a public one is configured.
            if "internal" not in self.endpoint:
                logger.warning(
                    "未使用内网 Endpoint。如果服务器在阿里云 ECS 上,"
                    f"建议使用内网 endpoint: {self._get_internal_endpoint_hint(self.endpoint)}"
                )
        except Exception as e:
            logger.error(f"OSS 服务初始化失败: {e}")
            self.enabled = False

    @staticmethod
    def _get_internal_endpoint_hint(endpoint: str) -> str:
        """Return the intranet variant of a public ``*.aliyuncs.com`` endpoint.

        Inserts ``-internal`` before the domain suffix, e.g.
        ``oss-cn-hangzhou.aliyuncs.com`` -> ``oss-cn-hangzhou-internal.aliyuncs.com``.
        This is the exact inverse of ``_get_external_endpoint``. Endpoints
        without ``.aliyuncs.com`` are returned unchanged.
        """
        return endpoint.replace(".aliyuncs.com", "-internal.aliyuncs.com")

    def _get_external_endpoint(self, endpoint: str) -> str:
        """Convert an intranet endpoint into its public counterpart.

        Rules:
            - oss-cn-hangzhou-internal.aliyuncs.com -> oss-cn-hangzhou.aliyuncs.com
            - https://oss-cn-hangzhou-internal.aliyuncs.com -> https://oss-cn-hangzhou.aliyuncs.com
            - endpoints without "-internal" are returned unchanged

        Args:
            endpoint: Original endpoint (with or without scheme).

        Returns:
            str: Public endpoint; "" for an empty input; the original endpoint
            if the conversion raises.
        """
        if not endpoint:
            logger.warning("端点为空,返回空字符串")
            return ""
        try:
            # Drop every "-internal" marker (including the leading hyphen).
            return endpoint.replace("-internal", "")
        except Exception as e:
            logger.error(f"端点转换失败: {e},使用原端点")
            return endpoint

    def upload_file(
        self,
        local_file_path: str,
        oss_object_name: str,
        use_multipart: bool = True
    ) -> Optional[str]:
        """Upload a local file to OSS.

        Args:
            local_file_path: Path of the file on local disk.
            oss_object_name: Destination object key in the bucket.
            use_multipart: Allow multipart upload for files larger than 100 MB.

        Returns:
            Optional[str]: Public URL of the object (see ``get_file_url``), or
            None if OSS is disabled, the file is missing, or the upload failed.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None
        if not os.path.exists(local_file_path):
            logger.error(f"文件不存在: {local_file_path}")
            return None
        try:
            file_size = os.path.getsize(local_file_path)
            # Files over 100 MB use multipart upload.
            if use_multipart and file_size > 100 * 1024 * 1024:
                logger.info(f"文件较大 ({file_size} 字节),使用分片上传")
                success = self._multipart_upload(local_file_path, oss_object_name, file_size)
            else:
                logger.info(f"使用简单上传: {oss_object_name}")
                result = self.bucket.put_object_from_file(oss_object_name, local_file_path)
                success = result.status == 200
            if success:
                file_url = self.get_file_url(oss_object_name)
                logger.info(f"文件上传成功: {local_file_path} -> {oss_object_name}")
                logger.info(f"OSS URL: {file_url}")
                return file_url
            logger.error(f"文件上传失败: {oss_object_name}")
            return None
        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def upload_file_from_bytes(
        self,
        file_content: bytes,
        oss_object_name: str,
        file_name: Optional[str] = None
    ) -> Optional[str]:
        """Upload an in-memory payload to OSS.

        Payloads larger than 1 MB are written to a temporary file and sent via
        ``_multipart_upload``; smaller ones use a single ``put_object`` call.
        Elapsed time and throughput are logged, with a warning when throughput
        is below 0.5 MB/s for payloads over 1 MB.

        Args:
            file_content: Raw bytes to upload.
            oss_object_name: Destination object key in the bucket.
            file_name: Optional display name, used only in log messages.

        Returns:
            Optional[str]: Public URL of the object, or None on failure or when
            OSS is disabled.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None
        try:
            file_size = len(file_content)
            start_time = time.time()
            if file_size > 1 * 1024 * 1024:
                logger.info(f"文件大小 {file_size/1024/1024:.2f}MB使用分片上传")
                tmp_path = None
                try:
                    # Record the path before writing so a failed write is still
                    # cleaned up by the finally clause.
                    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
                        tmp_path = tmp_file.name
                        tmp_file.write(file_content)
                    success = self._multipart_upload(tmp_path, oss_object_name, file_size)
                    if not success:
                        logger.error(f"分片上传失败: {oss_object_name}")
                        return None
                finally:
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            else:
                # Small payloads: single request.
                result = self.bucket.put_object(oss_object_name, file_content)
                if result.status != 200:
                    logger.error(f"文件上传失败: {oss_object_name}, 状态码: {result.status}")
                    return None

            elapsed = time.time() - start_time
            speed_mbps = (file_size / 1024 / 1024) / elapsed if elapsed > 0 else 0
            file_url = self.get_file_url(oss_object_name)
            logger.info(
                f"文件上传成功: {file_name or oss_object_name} -> {oss_object_name}, "
                f"大小: {file_size/1024/1024:.2f}MB, 耗时: {elapsed:.2f}s, 速度: {speed_mbps:.2f}MB/s"
            )
            # Warn about slow uploads of non-trivial payloads.
            if speed_mbps < 0.5 and file_size > 1024 * 1024:
                logger.warning(
                    f"上传速度较慢 ({speed_mbps:.2f}MB/s),建议检查: "
                    "1) 是否使用内网 endpoint 2) 服务器与 OSS 是否在同一区域"
                )
            return file_url
        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def download_file(
        self,
        oss_object_name: str,
        local_file_path: Optional[str] = None
    ) -> Optional[str]:
        """Download an OSS object to local disk.

        Args:
            oss_object_name: Object key to download.
            local_file_path: Destination path. When None, a uniquely named file
                is created in the system temp directory (prefix
                ``oss_download_``, ending with the object's base name); the
                caller owns that file and should delete it when done.

        Returns:
            Optional[str]: Path of the downloaded file, or None on failure or
            when OSS is disabled.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法下载")
            return None
        created_temp = False
        try:
            if local_file_path is None:
                # A fixed "oss_download_<basename>" path would be shared by all
                # objects with the same base name (kb_1/a.pdf and kb_2/a.pdf),
                # so concurrent downloads could overwrite each other.
                file_name = Path(oss_object_name).name
                fd, local_file_path = tempfile.mkstemp(prefix="oss_download_", suffix=f"_{file_name}")
                os.close(fd)
                created_temp = True
            else:
                # os.makedirs("") raises, so skip it for bare file names.
                parent_dir = os.path.dirname(local_file_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)
            self.bucket.get_object_to_file(oss_object_name, local_file_path)
            logger.info(f"文件下载成功: {oss_object_name} -> {local_file_path}")
            return local_file_path
        except Exception as e:
            logger.error(f"从 OSS 下载文件失败: {e}")
            # Do not leave an empty or partial temp file behind.
            if created_temp and os.path.exists(local_file_path):
                try:
                    os.remove(local_file_path)
                except OSError:
                    pass  # best-effort cleanup; the download error is already logged
            return None

    def delete_file(self, oss_object_name: str) -> bool:
        """Delete an object from OSS.

        Args:
            oss_object_name: Object key to delete.

        Returns:
            bool: True if ``delete_object`` did not raise, False otherwise or
            when OSS is disabled.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过删除")
            return False
        try:
            self.bucket.delete_object(oss_object_name)
            logger.info(f"OSS 文件删除成功: {oss_object_name}")
            return True
        except Exception as e:
            logger.error(f"删除 OSS 文件失败: {e}")
            return False

    def get_file_url(self, oss_object_name: str) -> str:
        """Build the public (unsigned) URL of an object.

        Format: ``https://{bucket_name}.{external_endpoint_host}/{object_name}``.
        The external endpoint is used so the URL is reachable from outside.

        Args:
            oss_object_name: Object key.

        Returns:
            str: Object URL, or "" when OSS is disabled.
        """
        if not self.enabled:
            return ""
        # Strip any scheme and trailing slash from the endpoint.
        endpoint_domain = self.external_endpoint.replace('https://', '').replace('http://', '').rstrip('/')
        base_url = f"https://{self.bucket_name}.{endpoint_domain}"
        # NOTE(review): the object key is not percent-encoded here.
        return f"{base_url}/{oss_object_name}"

    def get_signed_url(self, oss_object_name: str, expires: int = 3600) -> Optional[str]:
        """Generate a signed, time-limited GET URL (for private buckets).

        A temporary bucket bound to the external endpoint is used so the
        signed URL is publicly reachable.

        Args:
            oss_object_name: Object key.
            expires: Signature lifetime in seconds (default 3600, one hour).

        Returns:
            Optional[str]: Signed URL, or None on failure or when disabled.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法生成签名 URL")
            return None
        try:
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            external_bucket = oss2.Bucket(
                auth,
                self.external_endpoint,
                self.bucket_name,
                connect_timeout=10
            )
            signed_url = external_bucket.sign_url('GET', oss_object_name, expires)
            logger.debug(f"生成签名 URL 成功: {oss_object_name},有效期: {expires}")
            return signed_url
        except Exception as e:
            logger.error(f"生成签名 URL 失败: {e}")
            return None

    def extract_object_name_from_url(self, url: str, kb_id: int = None, thread_id: str = None) -> Optional[str]:
        """Recover the object key from an OSS URL (plain or signed).

        Only the URL path is inspected, so query strings (e.g. signature
        parameters) never leak into the result, and the path is
        percent-decoded so keys with non-ASCII characters match the stored key.

        Args:
            url: OSS URL.
            kb_id: Knowledge-base ID; if given, the key is taken from the first
                ``kb_{kb_id}/`` occurrence in the path.
            thread_id: Chat thread ID; if given, the key is taken from the
                first ``thread_{thread_id}/`` occurrence in the path.

        Returns:
            Optional[str]: Object key, or None if none could be found or OSS
            is disabled.
        """
        if not self.enabled:
            return None
        try:
            from urllib.parse import unquote, urlparse
            path = unquote(urlparse(url).path)

            # Prefer an exact match on the supplied IDs.
            if kb_id:
                idx = path.find(f"kb_{kb_id}/")
                if idx != -1:
                    return path[idx:]
            if thread_id:
                idx = path.find(f"thread_{thread_id}/")
                if idx != -1:
                    return path[idx:]

            # Fallback: start at the first path segment beginning with kb_ or thread_.
            path_parts = path.strip('/').split('/')
            for i, part in enumerate(path_parts):
                if part.startswith(('kb_', 'thread_')):
                    return '/'.join(path_parts[i:])
            return None
        except Exception as e:
            logger.error(f"从 URL 提取对象名称失败: {e}")
            return None

    def _multipart_upload(
        self,
        local_file_path: str,
        oss_object_name: str,
        file_size: int
    ) -> bool:
        """Upload a local file to OSS in parts.

        Parts are uploaded sequentially, then committed with
        ``complete_multipart_upload``. If anything fails after the upload was
        initiated, the upload is aborted so the parts already sent are not
        left attached to an incomplete upload.

        Args:
            local_file_path: Path of the file to read.
            oss_object_name: Destination object key.
            file_size: File size in bytes, used to compute part ranges.

        Returns:
            bool: True if the upload was completed, False otherwise.
        """
        upload_id = None
        try:
            # NOTE(review): a 100 KB preferred part size means many sequential
            # upload_part requests for large files; consider a larger value.
            part_size = determine_part_size(file_size, preferred_size=100 * 1024)
            upload_id = self.bucket.init_multipart_upload(oss_object_name).upload_id
            parts = []
            num_parts = (file_size + part_size - 1) // part_size
            logger.info(f"开始分片上传,共 {num_parts} 个分片...")
            with open(local_file_path, 'rb') as f:
                for i in range(num_parts):
                    part_number = i + 1  # OSS part numbers are 1-based
                    start = i * part_size
                    end = min(start + part_size, file_size)
                    f.seek(start)
                    data = f.read(end - start)
                    result = self.bucket.upload_part(
                        oss_object_name,
                        upload_id,
                        part_number,
                        data
                    )
                    parts.append(PartInfo(part_number, result.etag))
                    progress = part_number / num_parts * 100
                    logger.debug(f"上传进度: {progress:.1f}% ({i+1}/{num_parts})")
            self.bucket.complete_multipart_upload(oss_object_name, upload_id, parts)
            logger.info(f"分片上传完成: {oss_object_name}")
            return True
        except Exception as e:
            logger.error(f"分片上传失败: {e}")
            if upload_id is not None:
                self._abort_multipart_upload(oss_object_name, upload_id)
            return False

    def _abort_multipart_upload(self, oss_object_name: str, upload_id: str) -> None:
        """Best-effort abort of an unfinished multipart upload.

        Failures are logged and swallowed so they never mask the original
        upload error.
        """
        try:
            self.bucket.abort_multipart_upload(oss_object_name, upload_id)
        except Exception as abort_err:
            logger.error(f"中止分片上传失败: {abort_err}")

    def file_exists(self, oss_object_name: str) -> bool:
        """Check whether an object exists in the bucket.

        Args:
            oss_object_name: Object key.

        Returns:
            bool: Result of ``object_exists``; False on error or when disabled.
        """
        if not self.enabled:
            return False
        try:
            return self.bucket.object_exists(oss_object_name)
        except Exception as e:
            logger.error(f"检查文件是否存在失败: {e}")
            return False
# Process-wide OSSService instance, created lazily by get_oss_service().
_oss_service: Optional[OSSService] = None


def get_oss_service() -> OSSService:
    """Return the shared OSSService, constructing it on the first call."""
    global _oss_service
    if _oss_service is not None:
        return _oss_service
    _oss_service = OSSService()
    return _oss_service