"""
OSS file storage service.

Thin wrapper around the ``oss2`` SDK that uploads, downloads, deletes and
builds URLs for objects in a single bucket. The service disables itself
(``enabled = False``) when the OSS settings are incomplete or the client
cannot be constructed; every public method then returns a "failure" value
instead of raising.
"""
import os
import time
import tempfile
from typing import Optional
from pathlib import Path
from urllib.parse import urlparse

import oss2
from oss2 import SizedFileAdapter, determine_part_size
from oss2.models import PartInfo

from core.config import settings
from logger.logging import get_logger

logger = get_logger(__name__)


class OSSService:
    """OSS file storage service."""

    def __init__(self):
        """Read the OSS settings and build the bucket client.

        On incomplete configuration or any construction error the instance
        is left with ``enabled = False`` and ``bucket = None``.
        """
        # Read credentials and location from the application settings.
        self.access_key_id = settings.oss_access_key_id
        self.access_key_secret = settings.oss_access_key_secret
        self.endpoint = settings.oss_endpoint
        self.bucket_name = settings.oss_bucket_name

        # Always define the attribute so a disabled service never raises
        # AttributeError when something inspects ``self.bucket``.
        self.bucket = None

        # All four settings are required; otherwise the service is disabled.
        if not all([self.access_key_id, self.access_key_secret, self.endpoint, self.bucket_name]):
            logger.warning("OSS 配置不完整,将使用本地存储")
            self.enabled = False
            self.external_endpoint = ""
            return

        # Public endpoint, used only when generating URLs.
        self.external_endpoint = self._get_external_endpoint(self.endpoint)
        if self.endpoint != self.external_endpoint:
            logger.info(f"端点转换: {self.endpoint} -> {self.external_endpoint}")

        try:
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)

            # NOTE: per the original author's note, oss2.Bucket accepts
            # ``connect_timeout`` but not ``timeout``; a read timeout would
            # have to be configured through a custom session object.
            self.bucket = oss2.Bucket(
                auth,
                self.endpoint,
                self.bucket_name,
                connect_timeout=10  # connection timeout in seconds
            )

            self.enabled = True
            logger.info(f"OSS 服务初始化成功,Bucket: {self.bucket_name}")
            logger.info(f" 上传端点: {self.endpoint}")
            logger.info(f" 访问端点: {self.external_endpoint}")

            # Hint at the internal endpoint when a public one is configured.
            if "internal" not in self.endpoint:
                # Internal endpoints carry a "-internal" suffix on the region
                # label (oss-cn-hangzhou-internal.aliyuncs.com), which is the
                # same form _get_external_endpoint() strips.
                suggested_endpoint = self.endpoint.replace(
                    '.aliyuncs.com', '-internal.aliyuncs.com'
                )
                logger.warning(
                    "未使用内网 Endpoint。如果服务器在阿里云 ECS 上,"
                    f"建议使用内网 endpoint: {suggested_endpoint}"
                )

        except Exception as e:
            logger.error(f"OSS 服务初始化失败: {e}")
            self.bucket = None
            self.enabled = False

    def _get_external_endpoint(self, endpoint: str) -> str:
        """
        Convert an internal endpoint into its public counterpart.

        Rules:
        - oss-cn-hangzhou-internal.aliyuncs.com → oss-cn-hangzhou.aliyuncs.com
        - https://oss-cn-hangzhou-internal.aliyuncs.com → https://oss-cn-hangzhou.aliyuncs.com
        - an endpoint without "-internal" is returned unchanged

        Args:
            endpoint: Original endpoint URL or host name.

        Returns:
            str: Public endpoint; "" when ``endpoint`` is empty.
        """
        if not endpoint:
            logger.warning("端点为空,返回空字符串")
            return ""

        try:
            # Drop the "-internal" marker, including its leading hyphen.
            return endpoint.replace("-internal", "")
        except Exception as e:
            # Reached only for non-string input; fall back to the raw value.
            logger.error(f"端点转换失败: {e},使用原端点")
            return endpoint

    def upload_file(
        self,
        local_file_path: str,
        oss_object_name: str,
        use_multipart: bool = True
    ) -> Optional[str]:
        """
        Upload a local file to OSS.

        Args:
            local_file_path: Path of the file on disk.
            oss_object_name: Object key (storage path) inside the bucket.
            use_multipart: Allow multipart upload for files above 100 MB.

        Returns:
            Optional[str]: URL of the uploaded object, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None

        if not os.path.exists(local_file_path):
            logger.error(f"文件不存在: {local_file_path}")
            return None

        try:
            file_size = os.path.getsize(local_file_path)

            # Files larger than 100 MB go through the multipart path.
            if use_multipart and file_size > 100 * 1024 * 1024:
                logger.info(f"文件较大 ({file_size} 字节),使用分片上传")
                success = self._multipart_upload(local_file_path, oss_object_name, file_size)
            else:
                logger.info(f"使用简单上传: {oss_object_name}")
                result = self.bucket.put_object_from_file(oss_object_name, local_file_path)
                success = result.status == 200

            if not success:
                logger.error(f"文件上传失败: {oss_object_name}")
                return None

            file_url = self.get_file_url(oss_object_name)
            logger.info(f"文件上传成功: {local_file_path} -> {oss_object_name}")
            logger.info(f"OSS URL: {file_url}")
            return file_url

        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def upload_file_from_bytes(
        self,
        file_content: bytes,
        oss_object_name: str,
        file_name: Optional[str] = None
    ) -> Optional[str]:
        """
        Upload an in-memory byte string to OSS.

        Args:
            file_content: Raw file content.
            oss_object_name: Object key (storage path) inside the bucket.
            file_name: Display name used only in log messages.

        Returns:
            Optional[str]: URL of the uploaded object, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过上传")
            return None

        try:
            file_size = len(file_content)
            start_time = time.time()

            # Payloads above 1 MB are spooled to disk and sent in parts.
            if file_size > 1 * 1024 * 1024:
                logger.info(f"文件大小 {file_size/1024/1024:.2f}MB,使用分片上传")

                # Bind the path before writing and keep the write inside the
                # try/finally so the temp file is removed even if the write
                # itself fails.
                tmp_path = None
                try:
                    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
                        tmp_path = tmp_file.name
                        tmp_file.write(file_content)

                    success = self._multipart_upload(tmp_path, oss_object_name, file_size)
                    if not success:
                        logger.error(f"分片上传失败: {oss_object_name}")
                        return None
                finally:
                    # Remove the temporary spool file.
                    if tmp_path and os.path.exists(tmp_path):
                        os.remove(tmp_path)
            else:
                # Small payloads use a single PUT.
                result = self.bucket.put_object(oss_object_name, file_content)
                if result.status != 200:
                    logger.error(f"文件上传失败: {oss_object_name}, 状态码: {result.status}")
                    return None

            # Throughput in MB/s; guarded against a zero elapsed time.
            elapsed = time.time() - start_time
            speed_mbps = (file_size / 1024 / 1024) / elapsed if elapsed > 0 else 0

            file_url = self.get_file_url(oss_object_name)
            logger.info(
                f"文件上传成功: {file_name or oss_object_name} -> {oss_object_name}, "
                f"大小: {file_size/1024/1024:.2f}MB, 耗时: {elapsed:.2f}s, 速度: {speed_mbps:.2f}MB/s"
            )

            # Warn about slow transfers (only meaningful above 1 MB).
            if speed_mbps < 0.5 and file_size > 1024 * 1024:
                logger.warning(
                    f"上传速度较慢 ({speed_mbps:.2f}MB/s),建议检查: "
                    "1) 是否使用内网 endpoint 2) 服务器与 OSS 是否在同一区域"
                )

            return file_url

        except Exception as e:
            logger.error(f"上传文件到 OSS 失败: {e}")
            return None

    def download_file(
        self,
        oss_object_name: str,
        local_file_path: Optional[str] = None
    ) -> Optional[str]:
        """
        Download an object from OSS to the local file system.

        Args:
            oss_object_name: Object key inside the bucket.
            local_file_path: Destination path. When None, the file is saved
                in the system temp directory as ``oss_download_<basename>``.

        Returns:
            Optional[str]: Path of the downloaded file, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法下载")
            return None

        try:
            # Default destination: system temp dir + object base name.
            if local_file_path is None:
                temp_dir = tempfile.gettempdir()
                file_name = Path(oss_object_name).name
                local_file_path = os.path.join(temp_dir, f"oss_download_{file_name}")

            # A bare file name has an empty dirname, and os.makedirs("")
            # raises, so only create the parent when there is one.
            parent_dir = os.path.dirname(local_file_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            self.bucket.get_object_to_file(oss_object_name, local_file_path)

            logger.info(f"文件下载成功: {oss_object_name} -> {local_file_path}")
            return local_file_path

        except Exception as e:
            logger.error(f"从 OSS 下载文件失败: {e}")
            return None

    def delete_file(self, oss_object_name: str) -> bool:
        """
        Delete an object from OSS.

        Args:
            oss_object_name: Object key inside the bucket.

        Returns:
            bool: True when the delete call completed without raising.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,跳过删除")
            return False

        try:
            self.bucket.delete_object(oss_object_name)
            logger.info(f"OSS 文件删除成功: {oss_object_name}")
            return True
        except Exception as e:
            logger.error(f"删除 OSS 文件失败: {e}")
            return False

    def get_file_url(self, oss_object_name: str) -> str:
        """
        Build the unsigned access URL of an object on the public endpoint.

        The object name is inserted verbatim (no percent-encoding).

        Args:
            oss_object_name: Object key inside the bucket.

        Returns:
            str: ``https://{bucket}.{endpoint_host}/{object}``, or "" when
            the service is disabled.
        """
        if not self.enabled:
            return ""

        # Strip any scheme and trailing slash so only the host part remains.
        endpoint_domain = self.external_endpoint.replace('https://', '').replace('http://', '').rstrip('/')

        base_url = f"https://{self.bucket_name}.{endpoint_domain}"
        return f"{base_url}/{oss_object_name}"

    def get_signed_url(self, oss_object_name: str, expires: int = 3600) -> Optional[str]:
        """
        Generate a signed, time-limited GET URL (for private buckets).

        The URL is signed against the public endpoint so it is usable from
        outside the internal network.

        Args:
            oss_object_name: Object key inside the bucket.
            expires: Validity period in seconds (default 3600, i.e. 1 hour).

        Returns:
            Optional[str]: Signed URL, or None on failure.
        """
        if not self.enabled:
            logger.warning("OSS 未启用,无法生成签名 URL")
            return None

        try:
            # A separate bucket object bound to the public endpoint, because
            # the signed URL embeds the endpoint host.
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            external_bucket = oss2.Bucket(
                auth,
                self.external_endpoint,
                self.bucket_name,
                connect_timeout=10
            )

            signed_url = external_bucket.sign_url('GET', oss_object_name, expires)

            logger.debug(f"生成签名 URL 成功: {oss_object_name},有效期: {expires}秒")
            return signed_url

        except Exception as e:
            logger.error(f"生成签名 URL 失败: {e}")
            return None

    def extract_object_name_from_url(
        self,
        url: str,
        kb_id: Optional[int] = None,
        thread_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Extract the object name from an OSS URL.

        Only the path component of the URL is inspected, so the query string
        of a signed URL never leaks into the returned name. The path is not
        percent-decoded.

        Args:
            url: OSS URL.
            kb_id: Knowledge-base ID (optional); matches ``kb_{kb_id}/``.
            thread_id: Chat thread ID (optional); matches ``thread_{thread_id}/``.

        Returns:
            Optional[str]: Object name, or None when it cannot be determined.
        """
        if not self.enabled:
            return None

        try:
            parsed = urlparse(url)
            url_path = parsed.path
            path_parts = url_path.strip('/').split('/')

            # Prefer an exact match on the caller-supplied IDs.
            if kb_id:
                idx = url_path.find(f"kb_{kb_id}/")
                if idx != -1:
                    return url_path[idx:]

            if thread_id:
                idx = url_path.find(f"thread_{thread_id}/")
                if idx != -1:
                    return url_path[idx:]

            # Fallback: first path segment starting with "kb_" or "thread_";
            # the object name is everything from that segment onwards.
            for i, part in enumerate(path_parts):
                if part.startswith('kb_') or part.startswith('thread_'):
                    return '/'.join(path_parts[i:])

            return None

        except Exception as e:
            logger.error(f"从 URL 提取对象名称失败: {e}")
            return None

    def _multipart_upload(
        self,
        local_file_path: str,
        oss_object_name: str,
        file_size: int
    ) -> bool:
        """
        Upload a file in parts.

        If any step fails after the multipart session has been opened, the
        session is aborted so already-uploaded parts are not left behind in
        the bucket.

        Args:
            local_file_path: Path of the file on disk.
            oss_object_name: Object key inside the bucket.
            file_size: Size of the file in bytes.

        Returns:
            bool: True when the upload completed, False otherwise.
        """
        upload_id = None
        try:
            # Let the SDK pick a part size, preferring 100 KB parts.
            part_size = determine_part_size(file_size, preferred_size=100 * 1024)

            # Open the multipart session.
            upload_id = self.bucket.init_multipart_upload(oss_object_name).upload_id
            parts = []

            # Ceiling division: number of parts needed to cover the file.
            num_parts = (file_size + part_size - 1) // part_size
            logger.info(f"开始分片上传,共 {num_parts} 个分片...")

            with open(local_file_path, 'rb') as f:
                for i in range(num_parts):
                    # Byte range of this part; the last one may be shorter.
                    start = i * part_size
                    end = min(start + part_size, file_size)

                    f.seek(start)
                    data = f.read(end - start)

                    # Part numbers are 1-based.
                    result = self.bucket.upload_part(
                        oss_object_name,
                        upload_id,
                        i + 1,
                        data
                    )
                    parts.append(PartInfo(i + 1, result.etag))

                    progress = (i + 1) / num_parts * 100
                    logger.debug(f"上传进度: {progress:.1f}% ({i+1}/{num_parts})")

            # Assemble the uploaded parts into the final object.
            self.bucket.complete_multipart_upload(oss_object_name, upload_id, parts)

            logger.info(f"分片上传完成: {oss_object_name}")
            return True

        except Exception as e:
            logger.error(f"分片上传失败: {e}")
            # Best-effort cleanup of the unfinished session; a failure here
            # must not mask the original error.
            if upload_id is not None:
                try:
                    self.bucket.abort_multipart_upload(oss_object_name, upload_id)
                except Exception as abort_error:
                    logger.warning(f"取消分片上传失败: {abort_error}")
            return False

    def file_exists(self, oss_object_name: str) -> bool:
        """
        Check whether an object exists in the bucket.

        Args:
            oss_object_name: Object key inside the bucket.

        Returns:
            bool: True when the object exists; False when it does not, the
            service is disabled, or the check itself fails.
        """
        if not self.enabled:
            return False

        try:
            return self.bucket.object_exists(oss_object_name)
        except Exception as e:
            logger.error(f"检查文件是否存在失败: {e}")
            return False


# Process-wide OSS service instance, created lazily by get_oss_service().
_oss_service: Optional[OSSService] = None


def get_oss_service() -> OSSService:
    """Return the process-wide OSS service instance, creating it on first use."""
    global _oss_service
    if _oss_service is None:
        _oss_service = OSSService()
    return _oss_service