""" 向量化处理服务 """ import os import io import tempfile import base64 import time import asyncio from typing import List, Tuple, Optional, Dict from pathlib import Path from dataclasses import dataclass from concurrent.futures import ThreadPoolExecutor, as_completed from core.config import settings from langchain_community.document_loaders import ( PyPDFLoader, WebBaseLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredPowerPointLoader, TextLoader, CSVLoader, JSONLoader, UnstructuredHTMLLoader, UnstructuredMarkdownLoader, PythonLoader, UnstructuredXMLLoader, ) # 尝试导入阿里云 OCR SDK try: from alibabacloud_ocr_api20210707.client import Client as OcrClient from alibabacloud_tea_openapi import models as open_api_models from alibabacloud_ocr_api20210707 import models as ocr_models from alibabacloud_darabonba_stream.client import Client as StreamClient from alibabacloud_tea_util import models as util_models ALIYUN_OCR_AVAILABLE = True except ImportError: ALIYUN_OCR_AVAILABLE = False OcrClient = None StreamClient = None # 尝试导入 python-docx 用于提取 DOCX 中的图片 try: from docx import Document as DocxDocument from docx.oxml.text.paragraph import CT_P from docx.oxml.table import CT_Tbl from docx.table import _Cell, Table from docx.text.paragraph import Paragraph PYTHON_DOCX_AVAILABLE = True except ImportError: PYTHON_DOCX_AVAILABLE = False DocxDocument = None # 尝试导入 Pillow 用于图片处理 try: from PIL import Image PILLOW_AVAILABLE = True except ImportError: PILLOW_AVAILABLE = False Image = None # 尝试导入 PyMuPDF 用于 PDF 处理 try: import fitz # PyMuPDF PYMUPDF_AVAILABLE = True except ImportError: PYMUPDF_AVAILABLE = False fitz = None from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_ollama import OllamaEmbeddings from langchain_openai import OpenAIEmbeddings from langchain_chroma import Chroma import bs4 from logger.logging import get_logger from core.config import settings logger = get_logger(__name__) @dataclass class ProcessResult: """文档处理结果""" success: bool chunks: List[Tuple[int, str, dict, str]] chunk_count: int error_message: Optional[str] = None extracted_image_paths: Optional[List[str]] = None # DOCX 中提取的图片路径(供视觉模型使用) class VectorService: """向量化处理服务类""" # 文件类型与加载器的映射表 LOADER_MAPPING: Dict[str, tuple] = { # 文档类型(DOCX 使用增强处理,提取图片并 OCR) ".pdf": ("PyPDFLoader", None), ".docx": ("docx_with_images", "UnstructuredWordDocumentLoader"), ".ppt": ("UnstructuredPowerPointLoader", None), ".pptx": ("UnstructuredPowerPointLoader", None), # 图片文件(使用阿里云 OCR) ".png": ("image_ocr", None), ".jpg": ("image_ocr", None), ".jpeg": ("image_ocr", None), ".bmp": ("image_ocr", None), # 其他文档类型 ".txt": ("TextLoader", None), ".md": ("UnstructuredMarkdownLoader", "TextLoader"), ".xlsx": ("UnstructuredExcelLoader", None), ".xls": ("UnstructuredExcelLoader", None), ".csv": ("CSVLoader", None), ".json": ("JSONLoader", None), ".html": ("UnstructuredHTMLLoader", None), ".htm": ("UnstructuredHTMLLoader", None), ".xml": ("UnstructuredXMLLoader", None), ".py": ("PythonLoader", None), } def __init__(self): """初始化向量服务""" # 初始化嵌入模型 # self.embedding = OllamaEmbeddings(model="nomic-embed-text") # DashScope 兼容网关只接受字符串 input;默认 check_embedding_ctx_length=True # 会用 tiktoken 转成 token id 列表再请求,导致 400:contents is neither str nor list of str print(settings.dashscope_api_key, settings.dashscope_api_base) self.embedding = OpenAIEmbeddings( model="text-embedding-v4", api_key=os.getenv("ZL_DASHSCOPE_API_KEY"), # 如果您没有配置环境变量,请在此处用您的API Key进行替换 base_url=os.getenv("ZL_DASHSCOPE_API_BASE"), check_embedding_ctx_length=False, ) # 文本分割器配置(参考 server:增大 chunk_size 保留更多上下文) self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=4096, # 从 1000 增加到 4096,保留更多完整内容 chunk_overlap=200, # 保持适度重叠,确保语义连续性 add_start_index=True ) # 向量库存储路径 self.vector_store_path = settings.chroma_persist_directory or "./chroma_db" # 初始化阿里云 OCR(图片、扫描 PDF、DOCX 内嵌图均依赖云端识别) self.ocr_engine = None if ALIYUN_OCR_AVAILABLE and settings.ocr_access_key_id and settings.ocr_access_key_secret: try: config = open_api_models.Config( access_key_id=settings.ocr_access_key_id, access_key_secret=settings.ocr_access_key_secret, endpoint=settings.ocr_endpoint ) self.ocr_engine = OcrClient(config) logger.info("✅ 阿里云 OCR 已启用,将使用云端 OCR 服务识别图片文字") except Exception as e: logger.warning(f"⚠️ 阿里云 OCR 初始化失败: {e}") elif not ALIYUN_OCR_AVAILABLE: logger.warning("⚠️ 阿里云 OCR SDK 未安装,图片与扫描件 OCR 不可用") else: logger.info("ℹ️ 未配置阿里云 OCR(需要 OCR_ACCESS_KEY_ID 和 OCR_ACCESS_KEY_SECRET),图片 OCR 将不可用") if not self.ocr_engine: logger.warning("⚠️ OCR 服务不可用,图片与扫描件内容将无法通过 OCR 提取。请配置阿里云 OCR") def _ocr_image(self, image_path: str) -> str: """ 使用阿里云 OCR 识别单张图片中的文字。 Args: image_path: 图片文件路径 Returns: 识别到的文字内容 """ if not self.ocr_engine: return "" try: return self._ocr_image_aliyun(image_path) except Exception as e: logger.warning(f"图片 OCR 处理失败: {e}") return "" def _ocr_image_aliyun(self, image_path: str) -> str: """ 使用阿里云 OCR 识别图片中的文字 Args: image_path: 图片文件路径 Returns: 识别到的文字内容 """ try: logger.debug(f"🔍 [阿里云OCR] 开始识别图片: {os.path.basename(image_path)}") # 读取图片文件为字节流 with open(image_path, 'rb') as f: image_bytes = f.read() image_size_kb = len(image_bytes) / 1024 logger.debug(f"📊 [阿里云OCR] 图片大小: {image_size_kb:.2f}KB") # 使用 StreamClient 读取字节流(阿里云 SDK 要求的格式) body_stream = StreamClient.read_from_bytes(image_bytes) logger.debug(f"📦 [阿里云OCR] 字节流已创建") # 构建请求 request = ocr_models.RecognizeGeneralRequest(body=body_stream) # 运行时选项 runtime = util_models.RuntimeOptions() logger.debug(f"☁️ [阿里云OCR] 调用 API: recognize_general_with_options") # 调用阿里云 OCR API(使用 with_options 版本) response = self.ocr_engine.recognize_general_with_options(request, runtime) logger.debug(f"✅ [阿里云OCR] API 调用成功") # 解析结果 if not response or not response.body or not response.body.data: logger.debug(f"⚠️ [阿里云OCR] 未返回数据: {os.path.basename(image_path)}") return "" # 解析 JSON 数据 import json logger.debug(f"📝 [阿里云OCR] 开始解析返回数据") data = json.loads(response.body.data) # 提取 content 字段 if not data or 'content' not in data: logger.debug(f"⚠️ [阿里云OCR] 未识别到文字: {os.path.basename(image_path)}") return "" ocr_content = data['content'] logger.debug(f"📋 [阿里云OCR] content 类型: {type(ocr_content).__name__}") # content 可能是字符串或字典列表 if isinstance(ocr_content, str): full_text = ocr_content logger.debug(f"📄 [阿里云OCR] 直接提取文本: {len(full_text)} 字符") elif isinstance(ocr_content, list): # 如果是列表,提取每个元素的 text 字段 texts = [] for idx, item in enumerate(ocr_content): if isinstance(item, dict) and 'text' in item: text = item['text'] if text and text.strip(): texts.append(text) logger.debug(f" 📌 [阿里云OCR] 行 {idx + 1}: {text[:50]}{'...' if len(text) > 50 else ''}") elif isinstance(item, str): texts.append(item) full_text = "\n".join(texts) logger.debug(f"📄 [阿里云OCR] 合并 {len(texts)} 行文本: {len(full_text)} 字符") else: full_text = str(ocr_content) logger.debug(f"📄 [阿里云OCR] 转换为字符串: {len(full_text)} 字符") if not full_text or not full_text.strip(): logger.debug(f"⚠️ [阿里云OCR] 识别结果为空: {os.path.basename(image_path)}") return "" logger.info(f"✅ [阿里云OCR] 识别成功: {os.path.basename(image_path)}, 识别到 {len(full_text)} 字符") return full_text except Exception as e: logger.warning(f"阿里云 OCR 处理失败: {e}") import traceback logger.debug(traceback.format_exc()) return "" def _process_image_ocr(self, file_path: str) -> List: """ 使用阿里云 OCR 处理图片并提取文字。 Args: file_path: 图片文件路径 Returns: Document 列表 """ if not self.ocr_engine: raise Exception("OCR 服务不可用,无法处理图片。请配置阿里云 OCR(OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET)") try: full_text = self._ocr_image(file_path) if not full_text: logger.warning(f"图片 OCR 未识别到文字: {file_path}") return [] logger.info(f"图片 OCR 成功(阿里云 OCR),共 {len(full_text)} 字符") # 创建 Document 对象(模拟 LangChain Document 格式) from langchain_core.documents import Document doc = Document( page_content=full_text, metadata={ "source": file_path, "file_type": "image", "has_ocr": True, "ocr_provider": "aliyun" } ) return [doc] except Exception as e: logger.error(f"图片 OCR 处理失败: {e}") raise def _extract_images_from_docx(self, docx_path: str) -> List[str]: """ 从 DOCX 文件中提取所有图片并转换为标准格式(PNG/JPG) Args: docx_path: DOCX 文件路径 Returns: 临时图片文件路径列表 """ if not PYTHON_DOCX_AVAILABLE or not PILLOW_AVAILABLE: logger.warning("⚠️ python-docx 或 Pillow 不可用,无法提取 DOCX 中的图片") return [] logger.info(f"🖼️ [DOCX图片提取] 开始从 DOCX 提取图片: {os.path.basename(docx_path)}") image_paths = [] try: doc = DocxDocument(docx_path) total_rels = len(doc.part.rels.values()) logger.debug(f"📊 [DOCX图片提取] 文档关系总数: {total_rels}") # 遍历文档中的所有关系(relationships) for idx, rel in enumerate(doc.part.rels.values()): # 检查是否是图片关系 if "image" in rel.target_ref: try: # 获取图片数据 image_data = rel.target_part.blob # 使用 Pillow 打开图片并转换为标准格式 try: # 从二进制数据创建图片对象 image = Image.open(io.BytesIO(image_data)) image_size_kb = len(image_data) / 1024 logger.debug(f" 📸 [DOCX图片提取] 图片 {idx + 1}: 原始模式={image.mode}, 大小={image.size}, {image_size_kb:.2f}KB") # 转换为 RGB 模式(确保兼容性) if image.mode in ('RGBA', 'LA', 'P'): logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB") # 如果有透明通道,创建白色背景 background = Image.new('RGB', image.size, (255, 255, 255)) if image.mode == 'P': image = image.convert('RGBA') background.paste(image, mask=image.split()[-1] if image.mode in ('RGBA', 'LA') else None) image = background elif image.mode != 'RGB': logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB") image = image.convert('RGB') # 保存为 JPG 格式(阿里云 OCR 支持,且文件更小) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') image.save(tmp_file.name, 'JPEG', quality=95) tmp_file.close() image_paths.append(tmp_file.name) logger.info(f"✅ [DOCX图片提取] 图片 {idx + 1} 已提取并转换: {os.path.basename(tmp_file.name)}") except Exception as e: logger.warning(f"图片 {idx + 1} 格式转换失败: {e},尝试直接保存") # 如果转换失败,尝试直接保存原始数据 tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') tmp_file.write(image_data) tmp_file.close() image_paths.append(tmp_file.name) except Exception as e: logger.warning(f"提取图片 {idx + 1} 失败: {e}") continue logger.info(f"从 DOCX 提取了 {len(image_paths)} 张图片(已转换为 JPG 格式)") return image_paths except Exception as e: logger.error(f"提取 DOCX 图片失败: {e}") return [] def _ocr_images_concurrent(self, image_paths: List[str], max_workers: int = 4) -> List[Tuple[int, str]]: """ 并发处理多张图片的 OCR 识别 Args: image_paths: 图片路径列表 max_workers: 最大并发线程数(默认 4) Returns: List[Tuple[int, str]]: [(图片索引, OCR文本), ...] """ if not image_paths: return [] logger.info(f"🚀 [并发OCR] 开始并发识别 {len(image_paths)} 张图片(并发数: {max_workers}, 引擎: 阿里云OCR)") results = [] start_time = time.time() def ocr_single_image(idx: int, image_path: str) -> Tuple[int, str]: """处理单张图片""" try: logger.debug(f" 🔄 [并发OCR] 线程开始处理图片 {idx + 1}") ocr_text = self._ocr_image(image_path) if ocr_text: logger.info(f"✅ [并发OCR] 图片 {idx + 1} 识别成功: {len(ocr_text)} 字符") return (idx, ocr_text) else: logger.info(f"⚠️ [并发OCR] 图片 {idx + 1} 未识别到文字") return (idx, "") except Exception as e: logger.warning(f"❌ [并发OCR] 图片 {idx + 1} 识别失败: {e}") return (idx, "") # 使用线程池并发处理 logger.debug(f"📦 [并发OCR] 创建线程池,max_workers={max_workers}") with ThreadPoolExecutor(max_workers=max_workers) as executor: # 提交所有任务 future_to_idx = { executor.submit(ocr_single_image, idx, img_path): idx for idx, img_path in enumerate(image_paths) } # 收集结果 completed_count = 0 for future in as_completed(future_to_idx): try: result = future.result() results.append(result) completed_count += 1 logger.debug(f" ✓ [并发OCR] 已完成 {completed_count}/{len(image_paths)} 张图片") except Exception as e: idx = future_to_idx[future] logger.error(f"❌ [并发OCR] 图片 {idx + 1} OCR 任务执行失败: {e}") # 按索引排序,保持原始顺序 results.sort(key=lambda x: x[0]) # 统计结果 elapsed_time = time.time() - start_time success_count = sum(1 for _, text in results if text) total_chars = sum(len(text) for _, text in results) logger.info(f"🎉 [并发OCR] 完成!总计 {len(image_paths)} 张图片,成功 {success_count} 张,识别 {total_chars} 字符,耗时 {elapsed_time:.2f}秒") logger.info(f"📊 [并发OCR] 平均速度: {elapsed_time/len(image_paths):.2f}秒/张") return results def _process_docx_with_images(self, file_path: str) -> Tuple[List, List[str]]: """ 处理 DOCX 文件,提取文字和图片中的文字(使用多线程并发 OCR) Args: file_path: DOCX 文件路径 Returns: Tuple[List, List[str]]: (Document 列表, 提取的图片路径列表) """ from langchain_core.documents import Document try: # 1. 使用标准加载器提取文字内容 loader = UnstructuredWordDocumentLoader(file_path) text_docs = loader.load() text_content = "" if text_docs: text_content = "\n\n".join([doc.page_content for doc in text_docs]) logger.info(f"DOCX 文字内容提取完成,共 {len(text_content)} 字符") # 2. 提取并识别图片中的文字(使用并发处理) image_texts = [] extracted_image_paths = [] # 保存图片路径,稍后用于视觉模型分析 if self.ocr_engine and PYTHON_DOCX_AVAILABLE: image_paths = self._extract_images_from_docx(file_path) extracted_image_paths = image_paths.copy() # 保存副本供后续使用 if image_paths: # 使用多线程并发处理(最多并发 4 张图片) logger.info(f"开始并发 OCR 识别 {len(image_paths)} 张图片(并发数: 4)") ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4) # 整理结果 for idx, ocr_text in ocr_results: if ocr_text: image_texts.append(f"\n\n[图片 {idx + 1} 内容]\n{ocr_text}") logger.info(f"OCR 识别完成,共识别到 {len(image_texts)} 张图片的文字") # 3. 合并文字内容和图片内容 full_content = text_content if image_texts: full_content += "\n\n" + "\n\n".join(image_texts) logger.info(f"DOCX 总内容(文字+图片):{len(full_content)} 字符") # 4. 创建 Document 对象(metadata 中不包含列表类型,避免 ChromaDB 错误) doc = Document( page_content=full_content, metadata={ "source": file_path, "file_type": "docx", "has_images": len(image_texts) > 0, "image_count": len(image_texts), "has_ocr": len(image_texts) > 0 # 注意:不在这里保存 extracted_image_paths,因为 ChromaDB 不支持列表类型 } ) return [doc], extracted_image_paths except Exception as e: logger.error(f"处理 DOCX 文件失败: {e}") # 降级到标准处理方式 loader = UnstructuredWordDocumentLoader(file_path) return loader.load(), [] # 返回元组格式 def _extract_images_from_pdf(self, pdf_path: str) -> List[str]: """ 从 PDF 文件中提取所有页面为图片 Args: pdf_path: PDF 文件路径 Returns: 临时图片文件路径列表 """ if not PYMUPDF_AVAILABLE or not PILLOW_AVAILABLE: logger.warning("⚠️ PyMuPDF 或 Pillow 不可用,无法提取 PDF 页面为图片") return [] logger.info(f"📄 [PDF页面提取] 开始从 PDF 提取页面为图片: {os.path.basename(pdf_path)}") image_paths = [] try: # 打开 PDF 文件 pdf_document = fitz.open(pdf_path) total_pages = len(pdf_document) logger.info(f"📊 [PDF页面提取] PDF 总页数: {total_pages}") # 遍历每一页 for page_num in range(len(pdf_document)): try: logger.debug(f" 🔄 [PDF页面提取] 处理第 {page_num + 1}/{total_pages} 页") page = pdf_document[page_num] # 将页面转换为图片(提高分辨率以提升 OCR 效果) # zoom=2 表示 2 倍分辨率(DPI 约 144) mat = fitz.Matrix(2, 2) pix = page.get_pixmap(matrix=mat) # 保存为临时图片文件 tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') pix.save(tmp_file.name) tmp_file.close() file_size_kb = os.path.getsize(tmp_file.name) / 1024 image_paths.append(tmp_file.name) logger.info(f"✅ [PDF页面提取] 第 {page_num + 1} 页已转换: {os.path.basename(tmp_file.name)} ({file_size_kb:.2f}KB)") except Exception as e: logger.warning(f"PDF 页面 {page_num + 1} 转换失败: {e}") continue pdf_document.close() logger.info(f"从 PDF 提取了 {len(image_paths)} 页图片(已转换为 JPG 格式)") return image_paths except Exception as e: logger.error(f"提取 PDF 页面失败: {e}") return [] def _is_image_pdf(self, docs: List) -> bool: """ 判断 PDF 是否为图片型(扫描版) Args: docs: PyPDFLoader 加载的文档列表 Returns: bool: 如果文本内容很少或为空,则认为是图片型 PDF """ if not docs: return True # 计算总文本长度 total_text = "".join([doc.page_content for doc in docs]) total_chars = len(total_text.strip()) # 如果文本内容少于 100 个字符,认为是图片型 PDF # (排除空格和换行后仍然很少) if total_chars < 100: logger.info(f"检测到图片型 PDF(文本内容少于 100 字符:{total_chars} 字符)") return True return False def _process_pdf_with_ocr(self, file_path: str) -> List: """ 处理图片型 PDF(扫描版),使用 OCR 识别内容 Args: file_path: PDF 文件路径 Returns: Document 列表 """ from langchain_core.documents import Document try: if not self.ocr_engine: raise Exception("OCR 服务不可用,无法处理图片型 PDF") # 1. 提取 PDF 每一页为图片 image_paths = self._extract_images_from_pdf(file_path) if not image_paths: logger.warning("未能从 PDF 提取任何页面") return [] # 2. 使用多线程并发 OCR 识别 logger.info(f"开始并发 OCR 识别 {len(image_paths)} 页 PDF(并发数: 4)") ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4) # 3. 整理每页内容 page_texts = [] for idx, ocr_text in ocr_results: if ocr_text: page_texts.append(f"[第 {idx + 1} 页]\n{ocr_text}") logger.info(f"第 {idx + 1} 页 OCR 识别到 {len(ocr_text)} 字符") else: logger.warning(f"第 {idx + 1} 页 OCR 未识别到文字") # 4. 清理临时图片文件 for img_path in image_paths: try: if os.path.exists(img_path): os.remove(img_path) except: pass if not page_texts: logger.warning("PDF OCR 未识别到任何文字内容") return [] # 5. 合并所有页面内容 full_content = "\n\n".join(page_texts) logger.info(f"PDF OCR 完成,总计 {len(page_texts)} 页,共 {len(full_content)} 字符") # 6. 创建 Document 对象 doc = Document( page_content=full_content, metadata={ "source": file_path, "file_type": "pdf", "is_image_pdf": True, "page_count": len(page_texts), "has_ocr": True, "ocr_provider": "aliyun" } ) return [doc] except Exception as e: logger.error(f"处理图片型 PDF 失败: {e}") import traceback logger.error(traceback.format_exc()) return [] def _get_loader_for_file(self, file_path: str, file_type: str = None): """ 根据文件类型获取合适的文档加载器 Args: file_path: 文件路径 file_type: 文件类型(可选,如果不提供则从文件路径推断) Returns: 加载器实例或 None,或者返回 "image_ocr" 字符串表示需要 OCR 处理 """ # 确定文件扩展名 if file_type: ext = f".{file_type.lower()}" else: ext = Path(file_path).suffix.lower() # 获取加载器配置 loader_config = self.LOADER_MAPPING.get(ext) if not loader_config: logger.warning(f"不支持的文件类型: {ext}") return None primary_loader, fallback_loader = loader_config # 特殊处理图片 OCR if primary_loader == "image_ocr": if not self.ocr_engine: raise Exception("阿里云 OCR 不可用,无法处理图片文件。请配置 OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET") logger.info("使用阿里云 OCR 处理图片文件") return "image_ocr" # 返回特殊标记 # 特殊处理 DOCX(提取图片并 OCR) if primary_loader == "docx_with_images": if self.ocr_engine and PYTHON_DOCX_AVAILABLE: logger.info(f"使用增强模式处理 DOCX(提取图片并 OCR)") return "docx_with_images" # 返回特殊标记 else: logger.info(f"OCR 或 python-docx 不可用,使用标准 DOCX 加载器") # 降级到标准加载器 # 使用备选加载器 if fallback_loader: loader_class = globals().get(fallback_loader) if loader_class: logger.info(f"使用 {fallback_loader} 加载文件") # 特殊处理 TextLoader(需要编码参数) if fallback_loader == "TextLoader": return self._load_text_with_encoding(file_path, loader_class) return loader_class(file_path) # 如果没有备选,使用主加载器 if primary_loader: loader_class = globals().get(primary_loader) if loader_class: logger.info(f"使用 {primary_loader} 加载文件") # 特殊处理 TextLoader if primary_loader == "TextLoader": return self._load_text_with_encoding(file_path, loader_class) return loader_class(file_path) return None def _load_text_with_encoding(self, file_path: str, loader_class): """ 尝试多种编码加载文本文件 Args: file_path: 文件路径 loader_class: TextLoader 类 Returns: 加载器实例或 None """ encodings = ["utf-8", "gbk", "gb2312", "latin-1"] for encoding in encodings: try: loader = loader_class(file_path, encoding=encoding) # 尝试加载以验证编码是否正确 docs = loader.load() logger.info(f"成功使用编码 {encoding} 加载文本文件") return loader except (UnicodeDecodeError, Exception) as e: continue logger.warning(f"无法使用任何编码加载文本文件: {file_path}") return None def get_vector_store(self, collection_name: str) -> Chroma: """ 获取向量库实例 Args: collection_name: 集合名称(使用知识库 ID) Returns: Chroma: 向量库实例 """ persist_directory = os.path.join(self.vector_store_path, collection_name) vector_store = Chroma( host=settings.chroma_host, port=settings.chroma_port, collection_name=collection_name, embedding_function=self.embedding, # persist_directory=persist_directory ) return vector_store async def process_document( self, file_path: str, knowledge_base_id: int, file_type: str = "pdf", file_id: Optional[int] = None, source_url: Optional[str] = None ) -> ProcessResult: """ 处理文档文件:加载、分割、向量化(支持多种文档格式,包括图片 OCR) 支持的文件类型: - PDF、DOCX、PPT/PPTX(支持 OCR 提取图片文字) - 图片:PNG、JPG、JPEG、BMP(需要配置阿里云 OCR) - 其他:TXT、MD、Excel、CSV、JSON、HTML、XML、Python 等 Args: file_path: 文件路径 knowledge_base_id: 知识库 ID file_type: 文件类型(如 pdf、docx、xlsx、png 等) Returns: ProcessResult: 处理结果对象 """ try: logger.info(f"开始处理文件: {file_path}, 类型: {file_type}") # 1. 获取合适的加载器 loader = self._get_loader_for_file(file_path, file_type) if not loader: error_msg = f"不支持的文件类型: {file_type}" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 2. 加载文档(特殊处理图片 OCR 和 DOCX,放到线程池执行) if loader == "image_ocr": logger.info("🔄 在线程池中执行图片 OCR...") docs = await asyncio.to_thread(self._process_image_ocr, file_path) elif loader == "docx_with_images": logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR)...") docs, _ = await asyncio.to_thread(self._process_docx_with_images, file_path) # 忽略图片路径(知识库暂不使用视觉模型) else: # 文档加载也可能是 CPU 密集型的,放到线程池中 logger.info(f"🔄 在线程池中加载文档(类型: {file_type})...") docs = await asyncio.to_thread(loader.load) # 特殊处理:检测 PDF 是否为图片型(扫描版) if file_type.lower() == "pdf" and self._is_image_pdf(docs): logger.info("检测到图片型 PDF(扫描版),切换到 OCR 模式") if self.ocr_engine: logger.info("🔄 在线程池中执行 PDF OCR...") docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path) if not docs: error_msg = "图片型 PDF OCR 识别失败" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) else: error_msg = "检测到图片型 PDF(扫描版),但 OCR 服务不可用" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) logger.info(f"文档加载完成,共 {len(docs)} 个文档片段") if not docs: error_msg = "未能从文件加载到任何内容" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 2. 分割文本 all_splits = self.text_splitter.split_documents(docs) logger.info(f"文本分割完成,共 {len(all_splits)} 个块") # 检查是否有内容 if not all_splits: error_msg = "文档分割后没有内容,可能是空白文档" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 3. 向量化并存储 collection_name = f"kb_{knowledge_base_id}" vector_store = self.get_vector_store(collection_name) # 🔑 关键:在向量化前,将 file_id、chunk_index 和 source_url 添加到 metadata if file_id is not None or source_url is not None: for idx, doc in enumerate(all_splits): if not doc.metadata: doc.metadata = {} if file_id is not None: doc.metadata['file_id'] = file_id doc.metadata['chunk_index'] = idx if source_url is not None: doc.metadata['source'] = source_url # 🔑 替换为 OSS URL if file_id is not None: logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}") if source_url is not None: logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}") # 添加文档到向量库 vector_ids = vector_store.add_documents(documents=all_splits) logger.info(f"向量化完成,共 {len(vector_ids)} 个向量") # 4. 准备返回数据 chunks = [] for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids)): chunks.append(( idx, # chunk_index doc.page_content, # content doc.metadata, # metadata vector_id # vector_id )) return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks)) except Exception as e: error_msg = f"处理文件失败 ({file_type}): {str(e)}" logger.error(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) async def process_pdf( self, file_path: str, knowledge_base_id: int ) -> ProcessResult: """ 处理 PDF 文件:加载、分割、向量化(兼容旧接口) Args: file_path: PDF 文件路径 knowledge_base_id: 知识库 ID Returns: ProcessResult: 处理结果对象 """ return await self.process_document(file_path, knowledge_base_id, "pdf") async def process_url( self, url: str, knowledge_base_id: int ) -> ProcessResult: """ 处理 URL:加载网页内容、分割、向量化 Args: url: 网页 URL knowledge_base_id: 知识库 ID Returns: ProcessResult: 处理结果对象 """ try: logger.info(f"开始处理 URL: {url}") # 1. 加载网页内容(放到线程池执行,避免阻塞事件循环) # 使用 bs4 过滤,只保留主要内容(可以根据需要调整) bs4_strainer = bs4.SoupStrainer() loader = WebBaseLoader( web_paths=(url,), bs_kwargs={"parse_only": bs4_strainer} ) logger.info("🔄 在线程池中加载网页内容...") docs = await asyncio.to_thread(loader.load) logger.info(f"网页加载完成,共 {len(docs)} 个文档") if not docs: error_msg = "未能从 URL 加载到任何内容" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 2. 分割文本 all_splits = self.text_splitter.split_documents(docs) logger.info(f"文本分割完成,共 {len(all_splits)} 个块") # 检查是否有内容 if not all_splits: error_msg = "网页分割后没有内容,可能是空白页面或无法提取文本" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 3. 向量化并存储 collection_name = f"kb_{knowledge_base_id}" vector_store = self.get_vector_store(collection_name) # 添加文档到向量库 vector_ids = vector_store.add_documents(documents=all_splits) logger.info(f"向量化完成,共 {len(vector_ids)} 个向量") # 4. 准备返回数据 chunks = [] for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids)): chunks.append(( idx, # chunk_index doc.page_content, # content doc.metadata, # metadata vector_id # vector_id )) return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks)) except Exception as e: error_msg = f"处理 URL 失败: {str(e)}" logger.error(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) async def process_chat_thread_file( self, file_path: str, thread_id: str, file_type: str = "pdf", file_id: Optional[int] = None, source_url: Optional[str] = None ) -> ProcessResult: """ 处理聊天对话文件:加载、分割、向量化(支持多种格式,包括 URL 和图片 OCR) 支持的文件类型: - PDF、DOCX、PPT/PPTX(支持 OCR 提取图片文字) - 图片:PNG、JPG、JPEG、BMP(需要配置阿里云 OCR) - 其他:TXT、MD、Excel、CSV、JSON、HTML、XML 等 - URL:网页链接 Args: file_path: 文件路径或 URL thread_id: 会话线程 ID file_type: 文件类型(pdf、docx、xlsx、png、url 等) Returns: ProcessResult: 处理结果对象 """ try: logger.info(f"开始处理聊天文件: {file_path}, thread_id: {thread_id}, 类型: {file_type}") docs = [] extracted_image_paths = [] # 用于保存 DOCX 中提取的图片路径 # 特殊处理 URL(放到线程池执行) if file_type == "url": bs4_strainer = bs4.SoupStrainer() loader = WebBaseLoader( web_paths=(file_path,), bs_kwargs={"parse_only": bs4_strainer} ) logger.info("🔄 在线程池中加载网页内容...") docs = await asyncio.to_thread(loader.load) logger.info(f"网页加载完成,共 {len(docs)} 个文档") else: # 使用统一的加载器选择逻辑 loader = self._get_loader_for_file(file_path, file_type) print("loader", loader) if not loader: error_msg = f"不支持的文件类型: {file_type}" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 特殊处理图片 OCR 和 DOCX(放到线程池执行,避免阻塞事件循环) if loader == "image_ocr": logger.info("🔄 在线程池中执行图片 OCR...") docs = await asyncio.to_thread(self._process_image_ocr, file_path) elif loader == "docx_with_images": logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR)...") docs, extracted_image_paths = await asyncio.to_thread(self._process_docx_with_images, file_path) else: # 文档加载也可能是 CPU 密集型的,放到线程池中 logger.info(f"🔄 在线程池中加载文档(类型: {file_type})...") docs = await asyncio.to_thread(loader.load) # 特殊处理:检测 PDF 是否为图片型(扫描版) if file_type.lower() == "pdf" and self._is_image_pdf(docs): logger.info("检测到图片型 PDF(扫描版),切换到 OCR 模式") if self.ocr_engine: logger.info("🔄 在线程池中执行 PDF OCR...") docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path) if not docs: error_msg = "图片型 PDF OCR 识别失败" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) else: error_msg = "检测到图片型 PDF(扫描版),但 OCR 服务不可用" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) logger.info(f"文档加载完成,共 {len(docs)} 个文档片段") if not docs: error_msg = "未能加载到任何内容" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 分割文本 all_splits = self.text_splitter.split_documents(docs) logger.info(f"文本分割完成,共 {len(all_splits)} 个块") # 检查是否有内容 if not all_splits: error_msg = "文档分割后没有内容,可能是空白文档或无法提取文本" logger.warning(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) # 向量化并存储(使用 thread_id 作为集合名) collection_name = f"thread_{thread_id}" vector_store = self.get_vector_store(collection_name) # 🔑 关键:在向量化前,将 file_id、chunk_index 和 source_url 添加到 metadata if file_id is not None or source_url is not None: for idx, doc in enumerate(all_splits): if not doc.metadata: doc.metadata = {} if file_id is not None: doc.metadata['file_id'] = file_id doc.metadata['chunk_index'] = idx if source_url is not None: doc.metadata['source'] = source_url # 🔑 替换为 OSS URL if file_id is not None: logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}") if source_url is not None: logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}") # 添加文档到向量库 vector_ids = vector_store.add_documents(documents=all_splits) logger.info(f"向量化完成,共 {len(vector_ids)} 个向量") # 准备返回数据 chunks = [] for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids)): chunks.append(( idx, # chunk_index doc.page_content, # content doc.metadata, # metadata vector_id # vector_id )) return ProcessResult( success=True, chunks=chunks, chunk_count=len(chunks), extracted_image_paths=extracted_image_paths if extracted_image_paths else None ) except Exception as e: error_msg = f"处理聊天文件失败: {str(e)}" logger.error(error_msg) return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg) def search_similar_in_thread( self, thread_id: str, query: str, k: int = 5, file_id: Optional[int] = None, score_threshold: float = 0.0 ) -> List[dict]: """ 在聊天对话中搜索相似文档(增强版:支持过滤和阈值) Args: thread_id: 会话线程 ID query: 查询文本 k: 返回结果数量 file_id: 可选,仅搜索指定文件的内容 score_threshold: 相似度阈值(0-1,越小越相似) Returns: List[dict]: 相似文档列表,按相关性排序 """ try: collection_name = f"thread_{thread_id}" vector_store = self.get_vector_store(collection_name) # 构建过滤条件(如果指定了 file_id) filter_dict = None if file_id is not None: filter_dict = {"file_id": file_id} # 相似度搜索(增加k值以便过滤后仍有足够结果) search_k = k * 3 if file_id else k * 2 results = vector_store.similarity_search_with_score( query, k=search_k, filter=filter_dict ) # 格式化结果并应用阈值过滤 formatted_results = [] for doc, score in results: # ChromaDB 使用距离(越小越相似),过滤掉不相关的结果 if score <= score_threshold or score_threshold == 0.0: formatted_results.append({ "content": doc.page_content, "metadata": doc.metadata, "score": float(score), "file_summary": doc.metadata.get("file_summary", "") # 包含 summary }) # 达到所需数量即可停止 if len(formatted_results) >= k: break logger.info(f"向量检索完成: 查询='{query[:50]}...', 结果数={len(formatted_results)}, file_id={file_id}, 阈值={score_threshold}") return formatted_results except Exception as e: logger.error(f"搜索相似文档失败: {e}") return [] def get_all_file_chunks( self, thread_id: str, file_id: int ) -> List[dict]: """ 获取指定文件的所有chunks(完整内容)- 参考 server 实现 Args: thread_id: 会话线程 ID file_id: 文件 ID Returns: List[dict]: 所有 chunk 列表(按 chunk_index 排序) """ try: collection_name = f"thread_{thread_id}" logger.info(f"🔍 开始获取文件chunks: collection={collection_name}, file_id={file_id}") vector_store = self.get_vector_store(collection_name) # 获取该文件的所有 chunks(使用 filter) # 注意:ChromaDB 的 get 方法需要使用 where 参数 all_docs = vector_store.get( where={"file_id": file_id}, include=["documents", "metadatas"] ) logger.info(f"📦 ChromaDB返回结果: documents数量={len(all_docs.get('documents', []))}, metadatas数量={len(all_docs.get('metadatas', []))}") # 格式化结果并按 chunk_index 排序 chunks = [] if all_docs and 'documents' in all_docs and all_docs['documents']: logger.info(f"✅ 检测到 {len(all_docs['documents'])} 个文档") for idx, (doc_content, metadata) in enumerate(zip(all_docs['documents'], all_docs['metadatas'])): chunk_index = metadata.get("chunk_index", idx) has_summary = "file_summary" in metadata logger.info(f" - Chunk {idx}: chunk_index={chunk_index}, 内容长度={len(doc_content)}, 有摘要={has_summary}, metadata_keys={list(metadata.keys())}") chunks.append({ "content": doc_content, "metadata": metadata, "chunk_index": chunk_index, "file_summary": metadata.get("file_summary", "") }) # 按 chunk_index 排序 chunks.sort(key=lambda x: x['chunk_index']) logger.info(f"✅ 排序完成,共 {len(chunks)} 个chunks,chunk_index范围: {chunks[0]['chunk_index']} ~ {chunks[-1]['chunk_index']}") else: logger.warning(f"⚠️ ChromaDB未返回任何文档: all_docs keys={list(all_docs.keys()) if all_docs else 'None'}") logger.info(f"📊 最终返回: file_id={file_id}, 总chunks数量={len(chunks)}") return chunks except Exception as e: logger.error(f"❌ 获取文件所有chunks失败: {e}") import traceback logger.error(traceback.format_exc()) return [] def update_file_summary_in_vectors( self, thread_id: str, file_id: int, summary: str ) -> bool: """ 更新 ChromaDB 中指定文件所有向量的 summary metadata Args: thread_id: 会话线程 ID file_id: 文件 ID summary: 文件摘要 Returns: bool: 是否更新成功 """ try: collection_name = f"thread_{thread_id}" vector_store = self.get_vector_store(collection_name) # 获取该文件的所有向量 all_docs = vector_store.get( where={"file_id": file_id}, include=["metadatas"] ) if not all_docs or 'ids' not in all_docs or not all_docs['ids']: logger.warning(f"未找到 file_id={file_id} 的向量") return False # 更新每个向量的 metadata(添加 file_summary) vector_ids = all_docs['ids'] updated_metadatas = [] for metadata in all_docs['metadatas']: updated_metadata = metadata.copy() updated_metadata['file_summary'] = summary updated_metadatas.append(updated_metadata) # ChromaDB 更新 metadata vector_store._collection.update( ids=vector_ids, metadatas=updated_metadatas ) logger.info(f"✅ 已更新 ChromaDB 中 {len(vector_ids)} 个向量的 summary (file_id={file_id})") return True except Exception as e: logger.error(f"更新 ChromaDB summary 失败: {e}") import traceback logger.error(traceback.format_exc()) return False def update_kb_file_summary_in_vectors( self, knowledge_base_id: int, file_id: int, summary: str ) -> bool: """ 更新 ChromaDB 中指定知识库文件所有向量的 summary metadata Args: knowledge_base_id: 知识库 ID file_id: 文件 ID summary: 文件摘要 Returns: bool: 是否更新成功 """ try: collection_name = f"kb_{knowledge_base_id}" vector_store = self.get_vector_store(collection_name) # 获取该文件的所有向量 all_docs = vector_store.get( where={"file_id": file_id}, include=["metadatas"] ) if not all_docs or 'ids' not in all_docs or not all_docs['ids']: logger.warning(f"未找到 file_id={file_id} 的向量 (kb_id={knowledge_base_id})") return False # 更新每个向量的 metadata(添加 file_summary) vector_ids = all_docs['ids'] updated_metadatas = [] for metadata in all_docs['metadatas']: updated_metadata = metadata.copy() updated_metadata['file_summary'] = summary updated_metadatas.append(updated_metadata) # ChromaDB 更新 metadata vector_store._collection.update( ids=vector_ids, metadatas=updated_metadatas ) logger.info(f"✅ 已更新 ChromaDB 中 {len(vector_ids)} 个向量的 summary (kb_id={knowledge_base_id}, file_id={file_id})") return True except Exception as e: logger.error(f"更新 ChromaDB summary 失败 (kb_id={knowledge_base_id}, file_id={file_id}): {e}") import traceback logger.error(traceback.format_exc()) return False def delete_thread_vectors( self, thread_id: str, vector_ids: List[str] ) -> bool: """ 删除聊天对话中的向量 Args: thread_id: 会话线程 ID vector_ids: 向量 ID 列表 Returns: bool: 是否删除成功 """ try: if not vector_ids: return True collection_name = f"thread_{thread_id}" vector_store = self.get_vector_store(collection_name) # 删除指定的向量 vector_store.delete(ids=vector_ids) logger.info(f"从会话 {thread_id} 删除 {len(vector_ids)} 个向量") return True except Exception as e: logger.error(f"删除向量失败: {e}") return False def delete_thread_collection( self, thread_id: str ) -> bool: """ 删除聊天对话的整个向量集合 Args: thread_id: 会话线程 ID Returns: bool: 是否删除成功 """ try: collection_name = f"thread_{thread_id}" persist_directory = os.path.join(self.vector_store_path, collection_name) # 删除向量库目录 import shutil if os.path.exists(persist_directory): shutil.rmtree(persist_directory) logger.info(f"删除向量库: {collection_name}") return True return False except Exception as e: logger.error(f"删除向量库失败: {e}") return False def search_similar( self, knowledge_base_id: int, query: str, k: int = 5 ) -> List[dict]: """ 在知识库中搜索相似文档 Args: knowledge_base_id: 知识库 ID query: 查询文本 k: 返回结果数量 Returns: List[dict]: 相似文档列表 """ try: collection_name = f"kb_{knowledge_base_id}" vector_store = self.get_vector_store(collection_name) # 相似度搜索 results = vector_store.similarity_search_with_score(query, k=k) # 格式化结果 formatted_results = [] for doc, score in results: formatted_results.append({ "content": doc.page_content, "metadata": doc.metadata, "score": float(score) }) return formatted_results except Exception as e: logger.error(f"搜索相似文档失败: {e}") return [] def delete_vectors_by_ids( self, knowledge_base_id: int, vector_ids: List[str] ) -> bool: """ 根据向量 ID 列表删除向量库中的向量 Args: knowledge_base_id: 知识库 ID vector_ids: 向量 ID 列表 Returns: bool: 是否删除成功 """ try: if not vector_ids: return True collection_name = f"kb_{knowledge_base_id}" vector_store = self.get_vector_store(collection_name) # 删除指定的向量 vector_store.delete(ids=vector_ids) logger.info(f"从知识库 {knowledge_base_id} 删除 {len(vector_ids)} 个向量") return True except Exception as e: logger.error(f"删除向量失败: {e}") return False def delete_collection(self, knowledge_base_id: int) -> bool: """ 删除知识库的整个向量集合 Args: knowledge_base_id: 知识库 ID Returns: bool: 是否删除成功 """ try: collection_name = f"kb_{knowledge_base_id}" persist_directory = os.path.join(self.vector_store_path, collection_name) # 删除向量库目录 import shutil if os.path.exists(persist_directory): shutil.rmtree(persist_directory) logger.info(f"删除向量库: {collection_name}") return True return False except Exception as e: logger.error(f"删除向量库失败: {e}") return False # ----- 知识图谱 RAG(Chroma 集合名 knowledge_graph_{graphs.id};兼容旧名 novel_kg_*) ----- def delete_knowledge_graph_collection(self, knowledge_graph_pk: int) -> bool: """删除知识图谱对应的 Chroma 持久化目录(含旧集合 novel_kg_*)。""" import shutil ok = True for name in (f"knowledge_graph_{knowledge_graph_pk}", f"novel_kg_{knowledge_graph_pk}"): try: persist_directory = os.path.join(self.vector_store_path, name) if os.path.exists(persist_directory): shutil.rmtree(persist_directory) logger.info(f"删除知识图谱向量库: {name}") except Exception as e: logger.error(f"删除知识图谱向量库失败 {name}: {e}") ok = False return ok def index_knowledge_graph_text(self, knowledge_graph_pk: int, text: str) -> int: """ 将资料全文分块写入 Chroma。 Returns: 写入的分块数量 """ from langchain_core.documents import Document if not text or not text.strip(): return 0 self.delete_knowledge_graph_collection(knowledge_graph_pk) doc = Document( page_content=text.strip(), metadata={"source": f"knowledge_graph_{knowledge_graph_pk}", "kind": "knowledge_graph"}, ) splits = self.text_splitter.split_documents([doc]) if not splits: return 0 for idx, d in enumerate(splits): if not d.metadata: d.metadata = {} d.metadata["chunk_index"] = idx collection_name = f"knowledge_graph_{knowledge_graph_pk}" vector_store = self.get_vector_store(collection_name) vector_store.add_documents(splits) logger.info(f"知识图谱向量化完成 knowledge_graph_pk={knowledge_graph_pk} chunks={len(splits)}") return len(splits) def search_similar_knowledge_graph( self, knowledge_graph_pk: int, query: str, k: int = 5, ) -> List[dict]: """在知识图谱向量集合中检索(优先新集合名,兼容 novel_kg_*)。""" for name in (f"knowledge_graph_{knowledge_graph_pk}", f"novel_kg_{knowledge_graph_pk}"): persist_directory = os.path.join(self.vector_store_path, name) if not os.path.exists(persist_directory): continue try: vector_store = self.get_vector_store(name) results = vector_store.similarity_search_with_score(query, k=k) formatted: List[dict] = [] for doc, score in results: formatted.append({ "content": doc.page_content, "metadata": doc.metadata, "score": float(score), }) return formatted except Exception as e: logger.error(f"知识图谱向量检索失败 ({name}): {e}") return [] # ==================== 增强的 RAG 功能 ==================== def add_summary_chunk( self, collection_name: str, file_id: int, file_name: str, summary_text: str, metadata: Optional[dict] = None ) -> Optional[str]: """ 添加文件摘要 chunk Args: collection_name: 集合名称 file_id: 文件 ID file_name: 文件名 summary_text: 摘要文本 metadata: 额外的元数据 Returns: Optional[str]: 摘要 chunk 的 ID """ try: vector_store = self.get_vector_store(collection_name) # 构建摘要 metadata summary_metadata = { "file_id": file_id, "file_name": file_name, "chunk_type": "summary", # 标记为摘要类型 "chunk_index": -1, # 摘要没有索引 } if metadata: summary_metadata.update(metadata) # 添加到向量库 ids = vector_store.add_texts( texts=[summary_text], metadatas=[summary_metadata] ) logger.info(f"添加摘要 chunk 成功: {file_name}") return ids[0] if ids else None except Exception as e: logger.error(f"添加摘要 chunk 失败: {e}") return None def search_by_chunk_type( self, collection_name: str, query: str, chunk_type: str = "text", top_k: int = 5, filter_metadata: Optional[dict] = None ) -> List[Tuple[str, dict, float]]: """ 基于 chunk 类型检索 Args: collection_name: 集合名称 query: 查询文本 chunk_type: chunk 类型 (summary, text) top_k: 返回结果数量 filter_metadata: 额外的过滤条件 Returns: List[Tuple[str, dict, float]]: (内容, 元数据, 分数) 的列表 """ try: vector_store = self.get_vector_store(collection_name) # 构建过滤条件 where_filter = {"chunk_type": chunk_type} if filter_metadata: where_filter.update(filter_metadata) # 执行检索 results = vector_store.similarity_search_with_score( query=query, k=top_k, filter=where_filter ) return [ (doc.page_content, doc.metadata, score) for doc, score in results ] except Exception as e: logger.error(f"基于类型检索失败: {e}") return [] def get_file_all_chunks( self, collection_name: str, file_id: int, chunk_type: str = "text" ) -> List[Tuple[str, dict]]: """ 获取文件的所有 chunks(用于全文检索) Args: collection_name: 集合名称 file_id: 文件 ID chunk_type: chunk 类型 Returns: List[Tuple[str, dict]]: (内容, 元数据) 的列表 """ try: vector_store = self.get_vector_store(collection_name) # 使用 where 过滤获取所有匹配的文档 results = vector_store.get( where={ "file_id": file_id, "chunk_type": chunk_type } ) if not results or "documents" not in results: return [] documents = results["documents"] metadatas = results["metadatas"] # 按 chunk_index 排序 chunks = list(zip(documents, metadatas)) chunks.sort(key=lambda x: x[1].get("chunk_index", 0)) return chunks except Exception as e: logger.error(f"获取文件所有 chunks 失败: {e}") return [] def hybrid_search( self, collection_name: str, query: str, top_k: int = 5, filter_metadata: Optional[dict] = None, dense_weight: float = 0.7, sparse_weight: float = 0.3 ) -> List[Tuple[str, dict, float]]: """ 混合检索(dense + sparse) 注意:这是一个简化版本,真正的混合检索需要支持 BM25 等稀疏检索算法。 ChromaDB 目前主要支持 dense 向量检索。 Args: collection_name: 集合名称 query: 查询文本 top_k: 返回结果数量 filter_metadata: 过滤条件 dense_weight: dense 检索权重 sparse_weight: sparse 检索权重(当前未实现) Returns: List[Tuple[str, dict, float]]: (内容, 元数据, 分数) 的列表 """ try: # TODO: 实现真正的混合检索 # 当前仅使用 dense 检索 logger.warning("混合检索功能当前仅支持 dense 向量检索,需要额外集成 BM25") vector_store = self.get_vector_store(collection_name) # 执行 dense 检索 results = vector_store.similarity_search_with_score( query=query, k=top_k, filter=filter_metadata ) return [ (doc.page_content, doc.metadata, score) for doc, score in results ] except Exception as e: logger.error(f"混合检索失败: {e}") return [] def search_with_rerank( self, collection_name: str, query: str, top_k: int = 5, rerank_top_k: int = 3, filter_metadata: Optional[dict] = None ) -> List[Tuple[str, dict, float]]: """ 带重排序的检索 Args: collection_name: 集合名称 query: 查询文本 top_k: 初次检索数量 rerank_top_k: 重排序后返回数量 filter_metadata: 过滤条件 Returns: List[Tuple[str, dict, float]]: (内容, 元数据, 分数) 的列表 """ try: # 先执行初次检索,获取较多结果 results = self.search_similar( collection_name=collection_name, query=query, top_k=top_k, filter_metadata=filter_metadata ) # TODO: 使用 rerank 模型对结果进行重排序 # 当前仅返回 top_k 结果 return results[:rerank_top_k] except Exception as e: logger.error(f"重排序检索失败: {e}") return [] def get_summary_by_file_ids( self, collection_name: str, file_ids: List[int] ) -> Dict[int, str]: """ 批量获取文件摘要 Args: collection_name: 集合名称 file_ids: 文件 ID 列表 Returns: Dict[int, str]: 文件 ID 到摘要的映射 """ try: vector_store = self.get_vector_store(collection_name) summaries = {} for file_id in file_ids: results = vector_store.get( where={ "file_id": file_id, "chunk_type": "summary" } ) if results and "documents" in results and results["documents"]: summaries[file_id] = results["documents"][0] return summaries except Exception as e: logger.error(f"批量获取摘要失败: {e}") return {} # 全局向量服务实例 _vector_service = None def get_vector_service() -> VectorService: """获取全局向量服务实例""" global _vector_service if _vector_service is None: _vector_service = VectorService() return _vector_service