1923 lines
75 KiB
Python
1923 lines
75 KiB
Python
"""
|
||
向量化处理服务
|
||
"""
|
||
import os
|
||
import io
|
||
import tempfile
|
||
import base64
|
||
import time
|
||
import asyncio
|
||
from typing import List, Tuple, Optional, Dict
|
||
from pathlib import Path
|
||
from dataclasses import dataclass
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
from core.config import settings
|
||
|
||
from langchain_community.document_loaders import (
|
||
PyPDFLoader,
|
||
WebBaseLoader,
|
||
UnstructuredWordDocumentLoader,
|
||
UnstructuredExcelLoader,
|
||
UnstructuredPowerPointLoader,
|
||
TextLoader,
|
||
CSVLoader,
|
||
JSONLoader,
|
||
UnstructuredHTMLLoader,
|
||
UnstructuredMarkdownLoader,
|
||
PythonLoader,
|
||
UnstructuredXMLLoader,
|
||
)
|
||
|
||
# 尝试导入阿里云 OCR SDK
|
||
try:
|
||
from alibabacloud_ocr_api20210707.client import Client as OcrClient
|
||
from alibabacloud_tea_openapi import models as open_api_models
|
||
from alibabacloud_ocr_api20210707 import models as ocr_models
|
||
from alibabacloud_darabonba_stream.client import Client as StreamClient
|
||
from alibabacloud_tea_util import models as util_models
|
||
ALIYUN_OCR_AVAILABLE = True
|
||
except ImportError:
|
||
ALIYUN_OCR_AVAILABLE = False
|
||
OcrClient = None
|
||
StreamClient = None
|
||
|
||
# 尝试导入 python-docx 用于提取 DOCX 中的图片
|
||
try:
|
||
from docx import Document as DocxDocument
|
||
from docx.oxml.text.paragraph import CT_P
|
||
from docx.oxml.table import CT_Tbl
|
||
from docx.table import _Cell, Table
|
||
from docx.text.paragraph import Paragraph
|
||
PYTHON_DOCX_AVAILABLE = True
|
||
except ImportError:
|
||
PYTHON_DOCX_AVAILABLE = False
|
||
DocxDocument = None
|
||
|
||
# 尝试导入 Pillow 用于图片处理
|
||
try:
|
||
from PIL import Image
|
||
PILLOW_AVAILABLE = True
|
||
except ImportError:
|
||
PILLOW_AVAILABLE = False
|
||
Image = None
|
||
|
||
# 尝试导入 PyMuPDF 用于 PDF 处理
|
||
try:
|
||
import fitz # PyMuPDF
|
||
PYMUPDF_AVAILABLE = True
|
||
except ImportError:
|
||
PYMUPDF_AVAILABLE = False
|
||
fitz = None
|
||
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||
from langchain_ollama import OllamaEmbeddings
|
||
from langchain_openai import OpenAIEmbeddings
|
||
from langchain_chroma import Chroma
|
||
import bs4
|
||
|
||
from logger.logging import get_logger
|
||
from core.config import settings
|
||
from services.kb_text_limits import validate_kb_text_length, validate_chat_file_text_length
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
@dataclass
class ProcessResult:
    """Outcome of one document/URL processing run (load, split, vectorise)."""
    # Whether the run completed and produced chunks.
    success: bool
    # One tuple per stored chunk: (chunk_index, content, metadata, vector_id).
    chunks: List[Tuple[int, str, dict, str]]
    # Number of entries in `chunks`.
    chunk_count: int
    # Human-readable failure reason; left as None when no error is reported.
    error_message: Optional[str] = None
    extracted_image_paths: Optional[List[str]] = None  # Paths of images extracted from a DOCX (for use by a vision model)
|
||
|
||
|
||
class VectorService:
|
||
"""向量化处理服务类"""
|
||
|
||
# Maps a lowercase file extension (with leading dot) to a
# (primary_loader, fallback_loader) pair of loader names.
# The names are resolved by _get_loader_for_file via globals(); when a
# fallback is set it is tried before the primary. "image_ocr" and
# "docx_with_images" are special markers, not loader classes.
LOADER_MAPPING: Dict[str, tuple] = {
    # Document types (DOCX uses the enhanced path: extract images and OCR them)
    ".pdf": ("PyPDFLoader", None),
    ".docx": ("docx_with_images", "UnstructuredWordDocumentLoader"),
    ".ppt": ("UnstructuredPowerPointLoader", None),
    ".pptx": ("UnstructuredPowerPointLoader", None),

    # Image files (handled by Aliyun OCR)
    ".png": ("image_ocr", None),
    ".jpg": ("image_ocr", None),
    ".jpeg": ("image_ocr", None),
    ".bmp": ("image_ocr", None),

    # Other document types
    ".txt": ("TextLoader", None),
    ".md": ("UnstructuredMarkdownLoader", "TextLoader"),
    ".xlsx": ("UnstructuredExcelLoader", None),
    ".xls": ("UnstructuredExcelLoader", None),
    ".csv": ("CSVLoader", None),
    ".json": ("JSONLoader", None),
    ".html": ("UnstructuredHTMLLoader", None),
    ".htm": ("UnstructuredHTMLLoader", None),
    ".xml": ("UnstructuredXMLLoader", None),
    ".py": ("PythonLoader", None),
}
|
||
|
||
def __init__(self):
    """Set up the embedding model, text splitter, vector-store path and OCR client.

    The OCR client is optional: when the Aliyun SDK is not installed, the
    credentials are not configured, or client construction fails,
    ``self.ocr_engine`` stays ``None`` and a warning is logged.
    """
    # The DashScope OpenAI-compatible gateway only accepts string input. With
    # the default check_embedding_ctx_length=True the text is first converted
    # to token-id lists via tiktoken, which the gateway rejects with
    # "400: contents is neither str nor list of str".
    # Credentials are deliberately never printed or logged here.
    # NOTE(review): key/base URL come from the ZL_DASHSCOPE_* environment
    # variables rather than from `settings` — confirm this is intended.
    self.embedding = OpenAIEmbeddings(
        model="text-embedding-v4",
        api_key=os.getenv("ZL_DASHSCOPE_API_KEY"),
        base_url=os.getenv("ZL_DASHSCOPE_API_BASE"),
        check_embedding_ctx_length=False,
    )

    # Large chunks keep more surrounding context per chunk; a moderate
    # overlap keeps neighbouring chunks semantically connected.
    self.text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=4096,
        chunk_overlap=200,
        add_start_index=True
    )

    # Vector store path (falls back to a local default when unset).
    self.vector_store_path = settings.chroma_persist_directory or "./chroma_db"

    # Aliyun OCR client: images, scanned PDFs and images embedded in DOCX
    # files all rely on it.
    self.ocr_engine = None
    if ALIYUN_OCR_AVAILABLE and settings.ocr_access_key_id and settings.ocr_access_key_secret:
        try:
            config = open_api_models.Config(
                access_key_id=settings.ocr_access_key_id,
                access_key_secret=settings.ocr_access_key_secret,
                endpoint=settings.ocr_endpoint
            )
            self.ocr_engine = OcrClient(config)
            logger.info("✅ 阿里云 OCR 已启用,将使用云端 OCR 服务识别图片文字")
        except Exception as e:
            # Construction failure is not fatal: the service runs without OCR.
            logger.warning(f"⚠️ 阿里云 OCR 初始化失败: {e}")
    elif not ALIYUN_OCR_AVAILABLE:
        logger.warning("⚠️ 阿里云 OCR SDK 未安装,图片与扫描件 OCR 不可用")
    else:
        logger.info("ℹ️ 未配置阿里云 OCR(需要 OCR_ACCESS_KEY_ID 和 OCR_ACCESS_KEY_SECRET),图片 OCR 将不可用")

    if not self.ocr_engine:
        logger.warning("⚠️ OCR 服务不可用,图片与扫描件内容将无法通过 OCR 提取。请配置阿里云 OCR")
|
||
|
||
def _ocr_image(self, image_path: str) -> str:
|
||
"""
|
||
使用阿里云 OCR 识别单张图片中的文字。
|
||
|
||
Args:
|
||
image_path: 图片文件路径
|
||
|
||
Returns:
|
||
识别到的文字内容
|
||
"""
|
||
if not self.ocr_engine:
|
||
return ""
|
||
|
||
try:
|
||
return self._ocr_image_aliyun(image_path)
|
||
except Exception as e:
|
||
logger.warning(f"图片 OCR 处理失败: {e}")
|
||
return ""
|
||
|
||
def _ocr_image_aliyun(self, image_path: str) -> str:
    """
    Send one image to the Aliyun general-recognition API and return its text.

    Args:
        image_path: Path of the image file.

    Returns:
        The recognised text. An empty string is returned when the API
        returns no data, no text is recognised, or any step raises.
    """
    import json
    import traceback

    try:
        file_label = os.path.basename(image_path)
        logger.debug(f"🔍 [阿里云OCR] 开始识别图片: {file_label}")

        # Load the whole image into memory as bytes.
        with open(image_path, 'rb') as image_file:
            raw_bytes = image_file.read()

        logger.debug(f"📊 [阿里云OCR] 图片大小: {len(raw_bytes) / 1024:.2f}KB")

        # The request body is built from the bytes through StreamClient.
        stream = StreamClient.read_from_bytes(raw_bytes)
        logger.debug("📦 [阿里云OCR] 字节流已创建")

        ocr_request = ocr_models.RecognizeGeneralRequest(body=stream)
        runtime_options = util_models.RuntimeOptions()

        logger.debug("☁️ [阿里云OCR] 调用 API: recognize_general_with_options")
        response = self.ocr_engine.recognize_general_with_options(ocr_request, runtime_options)
        logger.debug("✅ [阿里云OCR] API 调用成功")

        if not (response and response.body and response.body.data):
            logger.debug(f"⚠️ [阿里云OCR] 未返回数据: {file_label}")
            return ""

        # response.body.data is a JSON document.
        logger.debug("📝 [阿里云OCR] 开始解析返回数据")
        parsed = json.loads(response.body.data)

        if not parsed or 'content' not in parsed:
            logger.debug(f"⚠️ [阿里云OCR] 未识别到文字: {file_label}")
            return ""

        content = parsed['content']
        logger.debug(f"📋 [阿里云OCR] content 类型: {type(content).__name__}")

        # `content` may be a plain string or a list of dicts/strings.
        if isinstance(content, str):
            recognised = content
            logger.debug(f"📄 [阿里云OCR] 直接提取文本: {len(recognised)} 字符")
        elif isinstance(content, list):
            lines = []
            for line_no, entry in enumerate(content, start=1):
                if isinstance(entry, dict) and 'text' in entry:
                    line_text = entry['text']
                    # Blank lines from dict entries are dropped.
                    if line_text and line_text.strip():
                        lines.append(line_text)
                        suffix = '...' if len(line_text) > 50 else ''
                        logger.debug(f" 📌 [阿里云OCR] 行 {line_no}: {line_text[:50]}{suffix}")
                elif isinstance(entry, str):
                    lines.append(entry)
            recognised = "\n".join(lines)
            logger.debug(f"📄 [阿里云OCR] 合并 {len(lines)} 行文本: {len(recognised)} 字符")
        else:
            recognised = str(content)
            logger.debug(f"📄 [阿里云OCR] 转换为字符串: {len(recognised)} 字符")

        if not (recognised and recognised.strip()):
            logger.debug(f"⚠️ [阿里云OCR] 识别结果为空: {file_label}")
            return ""

        logger.info(f"✅ [阿里云OCR] 识别成功: {file_label}, 识别到 {len(recognised)} 字符")
        return recognised

    except Exception as e:
        logger.warning(f"阿里云 OCR 处理失败: {e}")
        logger.debug(traceback.format_exc())
        return ""
|
||
|
||
def _process_image_ocr(self, file_path: str) -> List:
|
||
"""
|
||
使用阿里云 OCR 处理图片并提取文字。
|
||
|
||
Args:
|
||
file_path: 图片文件路径
|
||
|
||
Returns:
|
||
Document 列表
|
||
"""
|
||
if not self.ocr_engine:
|
||
raise Exception("OCR 服务不可用,无法处理图片。请配置阿里云 OCR(OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET)")
|
||
|
||
try:
|
||
full_text = self._ocr_image(file_path)
|
||
|
||
if not full_text:
|
||
logger.warning(f"图片 OCR 未识别到文字: {file_path}")
|
||
return []
|
||
|
||
logger.info(f"图片 OCR 成功(阿里云 OCR),共 {len(full_text)} 字符")
|
||
|
||
# 创建 Document 对象(模拟 LangChain Document 格式)
|
||
from langchain_core.documents import Document
|
||
|
||
doc = Document(
|
||
page_content=full_text,
|
||
metadata={
|
||
"source": file_path,
|
||
"file_type": "image",
|
||
"has_ocr": True,
|
||
"ocr_provider": "aliyun"
|
||
}
|
||
)
|
||
|
||
return [doc]
|
||
|
||
except Exception as e:
|
||
logger.error(f"图片 OCR 处理失败: {e}")
|
||
raise
|
||
|
||
def _extract_images_from_docx(self, docx_path: str) -> List[str]:
    """
    Extract every image from a DOCX file into temporary JPG files.

    Each image is opened with Pillow, flattened onto a white background or
    converted to RGB when needed, and saved as JPEG. If that conversion
    fails, the raw image bytes are written to the same temporary file
    instead. Temporary files that end up unused are removed, so a failure
    never leaves orphans behind.

    Args:
        docx_path: Path of the DOCX file.

    Returns:
        Paths of the temporary image files; an empty list when python-docx
        or Pillow is unavailable or the document cannot be processed.
        Callers own the returned files and must delete them.
    """
    if not PYTHON_DOCX_AVAILABLE or not PILLOW_AVAILABLE:
        logger.warning("⚠️ python-docx 或 Pillow 不可用,无法提取 DOCX 中的图片")
        return []

    logger.info(f"🖼️ [DOCX图片提取] 开始从 DOCX 提取图片: {os.path.basename(docx_path)}")
    image_paths = []

    try:
        doc = DocxDocument(docx_path)
        rels = list(doc.part.rels.values())
        total_rels = len(rels)
        logger.debug(f"📊 [DOCX图片提取] 文档关系总数: {total_rels}")

        # Walk all relationships of the main document part.
        for idx, rel in enumerate(rels):
            # Only image relationships are of interest.
            if "image" not in rel.target_ref:
                continue

            tmp_path = None
            try:
                image_data = rel.target_part.blob

                # Create the temp file and close its handle right away so
                # the path can be reopened for writing.
                fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
                os.close(fd)

                try:
                    image = Image.open(io.BytesIO(image_data))
                    image_size_kb = len(image_data) / 1024
                    logger.debug(f" 📸 [DOCX图片提取] 图片 {idx + 1}: 原始模式={image.mode}, 大小={image.size}, {image_size_kb:.2f}KB")

                    if image.mode in ('RGBA', 'LA', 'P'):
                        logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB")
                        # Flatten onto a white background, using the alpha
                        # channel as the paste mask when one is present.
                        background = Image.new('RGB', image.size, (255, 255, 255))
                        if image.mode == 'P':
                            image = image.convert('RGBA')
                        background.paste(image, mask=image.split()[-1] if image.mode in ('RGBA', 'LA') else None)
                        image = background
                    elif image.mode != 'RGB':
                        logger.debug(f" 🔄 [DOCX图片提取] 图片 {idx + 1}: 转换模式 {image.mode} -> RGB")
                        image = image.convert('RGB')

                    image.save(tmp_path, 'JPEG', quality=95)
                except Exception as e:
                    logger.warning(f"图片 {idx + 1} 格式转换失败: {e},尝试直接保存")
                    # Fall back to the raw bytes, reusing the same temp file
                    # ('wb' truncates any partial JPEG output).
                    with open(tmp_path, 'wb') as raw_file:
                        raw_file.write(image_data)
                else:
                    logger.info(f"✅ [DOCX图片提取] 图片 {idx + 1} 已提取并转换: {os.path.basename(tmp_path)}")

                image_paths.append(tmp_path)

            except Exception as e:
                logger.warning(f"提取图片 {idx + 1} 失败: {e}")
                # Best-effort removal of the temp file that will not be returned.
                if tmp_path is not None:
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                continue

        logger.info(f"从 DOCX 提取了 {len(image_paths)} 张图片(已转换为 JPG 格式)")
        return image_paths

    except Exception as e:
        logger.error(f"提取 DOCX 图片失败: {e}")
        # Nothing is returned on this path, so drop any files already written.
        for leftover in image_paths:
            try:
                os.remove(leftover)
            except OSError:
                pass
        return []
|
||
|
||
def _ocr_images_concurrent(self, image_paths: List[str], max_workers: int = 4) -> List[Tuple[int, str]]:
|
||
"""
|
||
并发处理多张图片的 OCR 识别
|
||
|
||
Args:
|
||
image_paths: 图片路径列表
|
||
max_workers: 最大并发线程数(默认 4)
|
||
|
||
Returns:
|
||
List[Tuple[int, str]]: [(图片索引, OCR文本), ...]
|
||
"""
|
||
if not image_paths:
|
||
return []
|
||
|
||
logger.info(f"🚀 [并发OCR] 开始并发识别 {len(image_paths)} 张图片(并发数: {max_workers}, 引擎: 阿里云OCR)")
|
||
|
||
results = []
|
||
start_time = time.time()
|
||
|
||
def ocr_single_image(idx: int, image_path: str) -> Tuple[int, str]:
|
||
"""处理单张图片"""
|
||
try:
|
||
logger.debug(f" 🔄 [并发OCR] 线程开始处理图片 {idx + 1}")
|
||
ocr_text = self._ocr_image(image_path)
|
||
if ocr_text:
|
||
logger.info(f"✅ [并发OCR] 图片 {idx + 1} 识别成功: {len(ocr_text)} 字符")
|
||
return (idx, ocr_text)
|
||
else:
|
||
logger.info(f"⚠️ [并发OCR] 图片 {idx + 1} 未识别到文字")
|
||
return (idx, "")
|
||
except Exception as e:
|
||
logger.warning(f"❌ [并发OCR] 图片 {idx + 1} 识别失败: {e}")
|
||
return (idx, "")
|
||
|
||
# 使用线程池并发处理
|
||
logger.debug(f"📦 [并发OCR] 创建线程池,max_workers={max_workers}")
|
||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||
# 提交所有任务
|
||
future_to_idx = {
|
||
executor.submit(ocr_single_image, idx, img_path): idx
|
||
for idx, img_path in enumerate(image_paths)
|
||
}
|
||
|
||
# 收集结果
|
||
completed_count = 0
|
||
for future in as_completed(future_to_idx):
|
||
try:
|
||
result = future.result()
|
||
results.append(result)
|
||
completed_count += 1
|
||
logger.debug(f" ✓ [并发OCR] 已完成 {completed_count}/{len(image_paths)} 张图片")
|
||
except Exception as e:
|
||
idx = future_to_idx[future]
|
||
logger.error(f"❌ [并发OCR] 图片 {idx + 1} OCR 任务执行失败: {e}")
|
||
|
||
# 按索引排序,保持原始顺序
|
||
results.sort(key=lambda x: x[0])
|
||
|
||
# 统计结果
|
||
elapsed_time = time.time() - start_time
|
||
success_count = sum(1 for _, text in results if text)
|
||
total_chars = sum(len(text) for _, text in results)
|
||
|
||
logger.info(f"🎉 [并发OCR] 完成!总计 {len(image_paths)} 张图片,成功 {success_count} 张,识别 {total_chars} 字符,耗时 {elapsed_time:.2f}秒")
|
||
logger.info(f"📊 [并发OCR] 平均速度: {elapsed_time/len(image_paths):.2f}秒/张")
|
||
|
||
return results
|
||
|
||
def _process_docx_with_images(self, file_path: str) -> Tuple[List, List[str]]:
    """
    Load a DOCX file and append the OCR text of its embedded images.

    The body text is read with UnstructuredWordDocumentLoader. When an OCR
    client and python-docx are available, embedded images are extracted to
    temporary files and recognised concurrently. If anything fails, the
    method falls back to the plain loader and removes the temporary images
    it had already extracted, since they are not returned on that path.

    Args:
        file_path: Path of the DOCX file.

    Returns:
        Tuple[List, List[str]]: (Document list, paths of the extracted
        temporary images). The caller owns the returned image files.
    """
    from langchain_core.documents import Document

    # Declared before the try block so the failure path can clean up.
    extracted_image_paths = []

    try:
        # 1. Body text via the standard loader.
        loader = UnstructuredWordDocumentLoader(file_path)
        text_docs = loader.load()

        text_content = ""
        if text_docs:
            text_content = "\n\n".join([doc.page_content for doc in text_docs])
            logger.info(f"DOCX 文字内容提取完成,共 {len(text_content)} 字符")

        # 2. Extract embedded images and OCR them concurrently.
        image_texts = []

        if self.ocr_engine and PYTHON_DOCX_AVAILABLE:
            image_paths = self._extract_images_from_docx(file_path)
            extracted_image_paths = image_paths.copy()  # kept for the caller

            if image_paths:
                logger.info(f"开始并发 OCR 识别 {len(image_paths)} 张图片(并发数: 4)")
                ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4)

                for idx, ocr_text in ocr_results:
                    if ocr_text:
                        image_texts.append(f"\n\n[图片 {idx + 1} 内容]\n{ocr_text}")

                logger.info(f"OCR 识别完成,共识别到 {len(image_texts)} 张图片的文字")

        # 3. Merge body text and image text.
        full_content = text_content
        if image_texts:
            full_content += "\n\n" + "\n\n".join(image_texts)
            logger.info(f"DOCX 总内容(文字+图片):{len(full_content)} 字符")

        # 4. Build the Document. The image paths are intentionally kept out
        # of the metadata because ChromaDB does not accept list values.
        doc = Document(
            page_content=full_content,
            metadata={
                "source": file_path,
                "file_type": "docx",
                "has_images": len(image_texts) > 0,
                "image_count": len(image_texts),
                "has_ocr": len(image_texts) > 0
            }
        )

        return [doc], extracted_image_paths

    except Exception as e:
        logger.error(f"处理 DOCX 文件失败: {e}")
        # The fallback returns no image paths, so delete what was extracted
        # (best effort) instead of leaking temporary files.
        for leftover in extracted_image_paths:
            try:
                os.remove(leftover)
            except OSError:
                pass
        # Fall back to the standard loader; same tuple shape as above.
        loader = UnstructuredWordDocumentLoader(file_path)
        return loader.load(), []
|
||
|
||
def _extract_images_from_pdf(self, pdf_path: str) -> List[str]:
    """
    Render every page of a PDF to a temporary JPG file.

    Pages are rendered at 2x zoom to give OCR more detail. A page that
    fails to render is skipped (and its temporary file removed). The PDF
    handle is always closed, also when an error occurs.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        Paths of the temporary page images, in page order; an empty list
        when PyMuPDF or Pillow is unavailable or the PDF cannot be
        processed. Callers own the returned files and must delete them.
    """
    if not PYMUPDF_AVAILABLE or not PILLOW_AVAILABLE:
        logger.warning("⚠️ PyMuPDF 或 Pillow 不可用,无法提取 PDF 页面为图片")
        return []

    logger.info(f"📄 [PDF页面提取] 开始从 PDF 提取页面为图片: {os.path.basename(pdf_path)}")
    image_paths = []
    pdf_document = None

    try:
        pdf_document = fitz.open(pdf_path)
        total_pages = len(pdf_document)
        logger.info(f"📊 [PDF页面提取] PDF 总页数: {total_pages}")

        # zoom=2 doubles the resolution (about 144 DPI); loop-invariant.
        mat = fitz.Matrix(2, 2)

        for page_num in range(total_pages):
            tmp_path = None
            try:
                logger.debug(f" 🔄 [PDF页面提取] 处理第 {page_num + 1}/{total_pages} 页")
                page = pdf_document[page_num]
                pix = page.get_pixmap(matrix=mat)

                # Create the temp file and close its handle before the
                # pixmap is written to that path.
                fd, tmp_path = tempfile.mkstemp(suffix='.jpg')
                os.close(fd)
                pix.save(tmp_path)

                file_size_kb = os.path.getsize(tmp_path) / 1024
                image_paths.append(tmp_path)
                logger.info(f"✅ [PDF页面提取] 第 {page_num + 1} 页已转换: {os.path.basename(tmp_path)} ({file_size_kb:.2f}KB)")

            except Exception as e:
                logger.warning(f"PDF 页面 {page_num + 1} 转换失败: {e}")
                # Remove the temp file unless it was already handed out.
                if tmp_path is not None and tmp_path not in image_paths:
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                continue

        logger.info(f"从 PDF 提取了 {len(image_paths)} 页图片(已转换为 JPG 格式)")
        return image_paths

    except Exception as e:
        logger.error(f"提取 PDF 页面失败: {e}")
        # Nothing is returned on this path, so drop any pages already rendered.
        for leftover in image_paths:
            try:
                os.remove(leftover)
            except OSError:
                pass
        return []

    finally:
        if pdf_document is not None:
            pdf_document.close()
|
||
|
||
def _is_image_pdf(self, docs: List) -> bool:
|
||
"""
|
||
判断 PDF 是否为图片型(扫描版)
|
||
|
||
Args:
|
||
docs: PyPDFLoader 加载的文档列表
|
||
|
||
Returns:
|
||
bool: 如果文本内容很少或为空,则认为是图片型 PDF
|
||
"""
|
||
if not docs:
|
||
return True
|
||
|
||
# 计算总文本长度
|
||
total_text = "".join([doc.page_content for doc in docs])
|
||
total_chars = len(total_text.strip())
|
||
|
||
# 如果文本内容少于 100 个字符,认为是图片型 PDF
|
||
# (排除空格和换行后仍然很少)
|
||
if total_chars < 100:
|
||
logger.info(f"检测到图片型 PDF(文本内容少于 100 字符:{total_chars} 字符)")
|
||
return True
|
||
|
||
return False
|
||
|
||
def _process_pdf_with_ocr(self, file_path: str) -> List:
    """
    OCR an image-only (scanned) PDF and return its text as one Document.

    Every page is rendered to a temporary image, recognised concurrently,
    and the per-page texts are joined with a page marker. The temporary
    page images are removed in a ``finally`` block, so they are cleaned up
    on success, on early return and on failure alike.

    Args:
        file_path: Path of the PDF file.

    Returns:
        A one-element Document list, or an empty list when OCR is
        unavailable, no page could be rendered, no text was recognised, or
        any step raised (the error is logged).
    """
    from langchain_core.documents import Document

    # Declared before the try block so cleanup always sees it.
    image_paths = []

    try:
        if not self.ocr_engine:
            raise Exception("OCR 服务不可用,无法处理图片型 PDF")

        # 1. Render each PDF page to an image.
        image_paths = self._extract_images_from_pdf(file_path)

        if not image_paths:
            logger.warning("未能从 PDF 提取任何页面")
            return []

        # 2. OCR the pages concurrently.
        logger.info(f"开始并发 OCR 识别 {len(image_paths)} 页 PDF(并发数: 4)")
        ocr_results = self._ocr_images_concurrent(image_paths, max_workers=4)

        # 3. Collect the text of each page that produced any.
        page_texts = []
        for idx, ocr_text in ocr_results:
            if ocr_text:
                page_texts.append(f"[第 {idx + 1} 页]\n{ocr_text}")
                logger.info(f"第 {idx + 1} 页 OCR 识别到 {len(ocr_text)} 字符")
            else:
                logger.warning(f"第 {idx + 1} 页 OCR 未识别到文字")

        if not page_texts:
            logger.warning("PDF OCR 未识别到任何文字内容")
            return []

        # 4. Merge all pages.
        full_content = "\n\n".join(page_texts)
        logger.info(f"PDF OCR 完成,总计 {len(page_texts)} 页,共 {len(full_content)} 字符")

        # 5. Build the Document.
        doc = Document(
            page_content=full_content,
            metadata={
                "source": file_path,
                "file_type": "pdf",
                "is_image_pdf": True,
                "page_count": len(page_texts),
                "has_ocr": True,
                "ocr_provider": "aliyun"
            }
        )

        return [doc]

    except Exception as e:
        logger.error(f"处理图片型 PDF 失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []

    finally:
        # Best-effort removal of the temporary page images.
        for img_path in image_paths:
            try:
                os.remove(img_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.debug(f"清理临时图片失败: {img_path}: {cleanup_error}")
|
||
|
||
def _get_loader_for_file(self, file_path: str, file_type: str = None):
|
||
"""
|
||
根据文件类型获取合适的文档加载器
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
file_type: 文件类型(可选,如果不提供则从文件路径推断)
|
||
|
||
Returns:
|
||
加载器实例或 None,或者返回 "image_ocr" 字符串表示需要 OCR 处理
|
||
"""
|
||
# 确定文件扩展名
|
||
if file_type:
|
||
ext = f".{file_type.lower()}"
|
||
else:
|
||
ext = Path(file_path).suffix.lower()
|
||
|
||
# 获取加载器配置
|
||
loader_config = self.LOADER_MAPPING.get(ext)
|
||
if not loader_config:
|
||
logger.warning(f"不支持的文件类型: {ext}")
|
||
return None
|
||
|
||
primary_loader, fallback_loader = loader_config
|
||
|
||
# 特殊处理图片 OCR
|
||
if primary_loader == "image_ocr":
|
||
if not self.ocr_engine:
|
||
raise Exception("阿里云 OCR 不可用,无法处理图片文件。请配置 OCR_ACCESS_KEY_ID 与 OCR_ACCESS_KEY_SECRET")
|
||
logger.info("使用阿里云 OCR 处理图片文件")
|
||
return "image_ocr" # 返回特殊标记
|
||
|
||
# 特殊处理 DOCX(提取图片并 OCR)
|
||
if primary_loader == "docx_with_images":
|
||
if self.ocr_engine and PYTHON_DOCX_AVAILABLE:
|
||
logger.info(f"使用增强模式处理 DOCX(提取图片并 OCR)")
|
||
return "docx_with_images" # 返回特殊标记
|
||
else:
|
||
logger.info(f"OCR 或 python-docx 不可用,使用标准 DOCX 加载器")
|
||
# 降级到标准加载器
|
||
|
||
# 使用备选加载器
|
||
if fallback_loader:
|
||
loader_class = globals().get(fallback_loader)
|
||
if loader_class:
|
||
logger.info(f"使用 {fallback_loader} 加载文件")
|
||
# 特殊处理 TextLoader(需要编码参数)
|
||
if fallback_loader == "TextLoader":
|
||
return self._load_text_with_encoding(file_path, loader_class)
|
||
return loader_class(file_path)
|
||
|
||
# 如果没有备选,使用主加载器
|
||
if primary_loader:
|
||
loader_class = globals().get(primary_loader)
|
||
if loader_class:
|
||
logger.info(f"使用 {primary_loader} 加载文件")
|
||
# 特殊处理 TextLoader
|
||
if primary_loader == "TextLoader":
|
||
return self._load_text_with_encoding(file_path, loader_class)
|
||
return loader_class(file_path)
|
||
|
||
return None
|
||
|
||
def _load_text_with_encoding(self, file_path: str, loader_class):
    """
    Find an encoding with which a text file loads, trying several in turn.

    Args:
        file_path: Path of the text file.
        loader_class: The TextLoader class (called with an ``encoding``
            keyword).

    Returns:
        The first loader whose ``load()`` succeeded, or None when every
        encoding failed.
    """
    for candidate in ("utf-8", "gbk", "gb2312", "latin-1"):
        try:
            text_loader = loader_class(file_path, encoding=candidate)
            # Load once purely to check that this encoding works.
            text_loader.load()
            logger.info(f"成功使用编码 {candidate} 加载文本文件")
            return text_loader
        except Exception:
            # Any failure moves on to the next encoding.
            continue

    logger.warning(f"无法使用任何编码加载文本文件: {file_path}")
    return None
|
||
|
||
def get_vector_store(self, collection_name: str) -> Chroma:
    """
    Build a Chroma vector store bound to one collection.

    The store connects to the Chroma server at ``settings.chroma_host`` /
    ``settings.chroma_port`` and embeds with ``self.embedding``. No local
    persist directory is used.

    Args:
        collection_name: Collection name (callers in this class pass
            "kb_<knowledge base id>").

    Returns:
        Chroma: The vector store instance.
    """
    return Chroma(
        host=settings.chroma_host,
        port=settings.chroma_port,
        collection_name=collection_name,
        embedding_function=self.embedding,
    )
|
||
|
||
async def process_document(
    self,
    file_path: str,
    knowledge_base_id: int,
    file_type: str = "pdf",
    file_id: Optional[int] = None,
    source_url: Optional[str] = None
) -> ProcessResult:
    """
    Process a document file: load, split, embed and store it.

    Supported file types:
    - PDF, DOCX, PPT/PPTX (scanned PDFs and DOCX images go through OCR)
    - Images: PNG, JPG, JPEG, BMP (require Aliyun OCR to be configured)
    - Others: TXT, MD, Excel, CSV, JSON, HTML, XML, Python, etc.

    Blocking work (loading, OCR, embedding and storing) runs in worker
    threads so the event loop is not blocked.

    Args:
        file_path: Path of the file.
        knowledge_base_id: Knowledge base ID; chunks are stored in the
            collection "kb_<id>".
        file_type: File type without the dot (pdf, docx, xlsx, png, ...).
            When falsy, the type is taken from the file path's suffix.
        file_id: Optional file ID; when given, ``file_id`` and
            ``chunk_index`` are written into each chunk's metadata.
        source_url: Optional URL that replaces each chunk's ``source``
            metadata.

    Returns:
        ProcessResult: The outcome. Failures are reported through
        ``success=False`` and ``error_message`` rather than raised.
    """
    def _failed(message: str) -> ProcessResult:
        """Log an expected failure and wrap it in a ProcessResult."""
        logger.warning(message)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=message)

    try:
        logger.info(f"开始处理文件: {file_path}, 类型: {file_type}")

        # 1. Pick a loader.
        loader = self._get_loader_for_file(file_path, file_type)

        if not loader:
            return _failed(f"不支持的文件类型: {file_type}")

        # 2. Load the document in a worker thread.
        if loader == "image_ocr":
            logger.info("🔄 在线程池中执行图片 OCR...")
            docs = await asyncio.to_thread(self._process_image_ocr, file_path)
        elif loader == "docx_with_images":
            logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR)...")
            docs, docx_image_paths = await asyncio.to_thread(self._process_docx_with_images, file_path)
            # The knowledge base does not use the extracted images, so the
            # temporary files are removed here instead of being leaked.
            for image_path in docx_image_paths:
                try:
                    os.remove(image_path)
                except OSError as cleanup_error:
                    logger.debug(f"清理临时图片失败: {image_path}: {cleanup_error}")
        else:
            logger.info(f"🔄 在线程池中加载文档(类型: {file_type})...")
            docs = await asyncio.to_thread(loader.load)

            # Mirror _get_loader_for_file: without file_type, use the suffix.
            effective_type = (file_type or Path(file_path).suffix.lstrip(".")).lower()

            # A PDF with (almost) no text is treated as scanned and OCR'd.
            if effective_type == "pdf" and self._is_image_pdf(docs):
                logger.info("检测到图片型 PDF(扫描版),切换到 OCR 模式")
                if not self.ocr_engine:
                    return _failed("检测到图片型 PDF(扫描版),但 OCR 服务不可用")

                logger.info("🔄 在线程池中执行 PDF OCR...")
                docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path)
                if not docs:
                    return _failed("图片型 PDF OCR 识别失败")

        logger.info(f"文档加载完成,共 {len(docs)} 个文档片段")

        if not docs:
            return _failed("未能从文件加载到任何内容")

        # Enforce the knowledge-base text limit before splitting.
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_kb_text_length(total_chars)
        except ValueError as e:
            return _failed(str(e))

        # 3. Split the text.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")

        if not all_splits:
            return _failed("文档分割后没有内容,可能是空白文档")

        # Validate again now that the chunk count is known.
        try:
            validate_kb_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _failed(str(e))

        # 4. Embed and store.
        collection_name = f"kb_{knowledge_base_id}"
        vector_store = self.get_vector_store(collection_name)

        # file_id / chunk_index / source must be in the metadata before
        # the chunks are written to the vector store.
        if file_id is not None or source_url is not None:
            for idx, doc in enumerate(all_splits):
                if not doc.metadata:
                    doc.metadata = {}
                if file_id is not None:
                    doc.metadata['file_id'] = file_id
                    doc.metadata['chunk_index'] = idx
                if source_url is not None:
                    doc.metadata['source'] = source_url  # replaced by the given URL

            if file_id is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}")
            if source_url is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}")

        # add_documents is a blocking call; run it off the event loop like
        # the loading steps above.
        vector_ids = await asyncio.to_thread(vector_store.add_documents, documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # 5. Build the return data: (chunk_index, content, metadata, vector_id).
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]

        return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks))

    except Exception as e:
        error_msg = f"处理文件失败 ({file_type}): {str(e)}"
        logger.error(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
|
||
|
||
async def process_pdf(
    self,
    file_path: str,
    knowledge_base_id: int
) -> ProcessResult:
    """
    Process a PDF file: load, split and vectorise (kept for the old API).

    Delegates to ``process_document`` with the file type fixed to "pdf".

    Args:
        file_path: Path of the PDF file.
        knowledge_base_id: Knowledge base ID.

    Returns:
        ProcessResult: The processing result.
    """
    pdf_result = await self.process_document(file_path, knowledge_base_id, "pdf")
    return pdf_result
|
||
|
||
async def process_url(
    self,
    url: str,
    knowledge_base_id: int
) -> ProcessResult:
    """
    Process a URL: load the web page, split, embed and store it.

    Page loading and the embed-and-store step are blocking calls and run
    in worker threads so the event loop is not blocked.

    Args:
        url: Web page URL.
        knowledge_base_id: Knowledge base ID; chunks are stored in the
            collection "kb_<id>".

    Returns:
        ProcessResult: The outcome. Failures are reported through
        ``success=False`` and ``error_message`` rather than raised.
    """
    def _failed(message: str) -> ProcessResult:
        """Log an expected failure and wrap it in a ProcessResult."""
        logger.warning(message)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=message)

    try:
        logger.info(f"开始处理 URL: {url}")

        # 1. Load the page in a worker thread.
        # NOTE(review): SoupStrainer() is built without criteria, so it
        # presumably does not narrow the page to its main content — adjust
        # if filtering is wanted.
        bs4_strainer = bs4.SoupStrainer()
        loader = WebBaseLoader(
            web_paths=(url,),
            bs_kwargs={"parse_only": bs4_strainer}
        )
        logger.info("🔄 在线程池中加载网页内容...")
        docs = await asyncio.to_thread(loader.load)
        logger.info(f"网页加载完成,共 {len(docs)} 个文档")

        if not docs:
            return _failed("未能从 URL 加载到任何内容")

        # Enforce the knowledge-base text limit before splitting.
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_kb_text_length(total_chars)
        except ValueError as e:
            return _failed(str(e))

        # 2. Split the text.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")

        if not all_splits:
            return _failed("网页分割后没有内容,可能是空白页面或无法提取文本")

        # Validate again now that the chunk count is known.
        try:
            validate_kb_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _failed(str(e))

        # 3. Embed and store.
        collection_name = f"kb_{knowledge_base_id}"
        vector_store = self.get_vector_store(collection_name)

        # add_documents is a blocking call; run it off the event loop like
        # the page load above.
        vector_ids = await asyncio.to_thread(vector_store.add_documents, documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # 4. Build the return data: (chunk_index, content, metadata, vector_id).
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]

        return ProcessResult(success=True, chunks=chunks, chunk_count=len(chunks))

    except Exception as e:
        error_msg = f"处理 URL 失败: {str(e)}"
        logger.error(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
|
||
|
||
async def process_chat_thread_file(
    self,
    file_path: str,
    thread_id: str,
    file_type: str = "pdf",
    file_id: Optional[int] = None,
    source_url: Optional[str] = None
) -> ProcessResult:
    """
    Load, split and vectorize a file (or URL) attached to a chat thread.

    Supported inputs, as dispatched by ``_get_loader_for_file``:
    - PDF, DOCX, PPT/PPTX (with OCR of embedded images where available)
    - Images: PNG, JPG, JPEG, BMP (requires the OCR engine)
    - Others: TXT, MD, Excel, CSV, JSON, HTML, XML, ...
    - URL: a web page, loaded through ``WebBaseLoader``

    Args:
        file_path: Local file path, or the URL when ``file_type == "url"``.
        thread_id: Chat thread ID; chunks go to the ``thread_{thread_id}`` collection.
        file_type: File type (pdf, docx, xlsx, png, url, ...).
        file_id: Optional file ID. When given, every chunk's metadata gets
            ``file_id`` and ``chunk_index`` before it is stored.
        source_url: Optional URL written to every chunk's ``source`` metadata,
            replacing the value set by the loader.

    Returns:
        ProcessResult: on success, ``chunks`` holds
        ``(chunk_index, content, metadata, vector_id)`` tuples; on failure,
        ``success`` is False and ``error_message`` describes the problem.
        This method does not raise: any exception is logged and converted
        into a failed result.
    """

    def _fail(message: str) -> ProcessResult:
        # Shared "log a warning and return a failed result" path.
        logger.warning(message)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=message)

    try:
        logger.info(f"开始处理聊天文件: {file_path}, thread_id: {thread_id}, 类型: {file_type}")

        docs = []
        extracted_image_paths = []  # image paths extracted from a DOCX, handed back to the caller

        if file_type == "url":
            # Web pages: the blocking fetch/parse runs in a worker thread.
            bs4_strainer = bs4.SoupStrainer()
            loader = WebBaseLoader(
                web_paths=(file_path,),
                bs_kwargs={"parse_only": bs4_strainer}
            )
            logger.info("🔄 在线程池中加载网页内容...")
            docs = await asyncio.to_thread(loader.load)
            logger.info(f"网页加载完成,共 {len(docs)} 个文档")
        else:
            # Shared loader-selection logic.
            loader = self._get_loader_for_file(file_path, file_type)
            # Was a stray print(); keep the diagnostic but route it through logging.
            logger.debug(f"loader: {loader}")

            if not loader:
                return _fail(f"不支持的文件类型: {file_type}")

            # The string sentinels select dedicated pipelines; everything runs
            # in a worker thread so the event loop is not blocked.
            if loader == "image_ocr":
                logger.info("🔄 在线程池中执行图片 OCR...")
                docs = await asyncio.to_thread(self._process_image_ocr, file_path)
            elif loader == "docx_with_images":
                logger.info("🔄 在线程池中处理 DOCX 文件(提取图片并 OCR)...")
                docs, extracted_image_paths = await asyncio.to_thread(
                    self._process_docx_with_images, file_path
                )
            else:
                logger.info(f"🔄 在线程池中加载文档(类型: {file_type})...")
                docs = await asyncio.to_thread(loader.load)

                # Scanned (image-only) PDFs yield no usable text: fall back to OCR.
                if file_type.lower() == "pdf" and self._is_image_pdf(docs):
                    logger.info("检测到图片型 PDF(扫描版),切换到 OCR 模式")
                    if not self.ocr_engine:
                        return _fail("检测到图片型 PDF(扫描版),但 OCR 服务不可用")
                    logger.info("🔄 在线程池中执行 PDF OCR...")
                    docs = await asyncio.to_thread(self._process_pdf_with_ocr, file_path)
                    if not docs:
                        return _fail("图片型 PDF OCR 识别失败")

            logger.info(f"文档加载完成,共 {len(docs)} 个文档片段")

        if not docs:
            return _fail("未能加载到任何内容")

        # First length check on the raw text, before paying for splitting.
        total_chars = sum(len(d.page_content or "") for d in docs)
        try:
            validate_chat_file_text_length(total_chars)
        except ValueError as e:
            return _fail(str(e))

        # Split into chunks.
        all_splits = self.text_splitter.split_documents(docs)
        logger.info(f"文本分割完成,共 {len(all_splits)} 个块")

        if not all_splits:
            return _fail("文档分割后没有内容,可能是空白文档或无法提取文本")

        # Second check, now that the chunk count is known.
        try:
            validate_chat_file_text_length(total_chars, chunk_count=len(all_splits))
        except ValueError as e:
            return _fail(str(e))

        # One collection per chat thread.
        collection_name = f"thread_{thread_id}"
        vector_store = self.get_vector_store(collection_name)

        # Metadata must be attached BEFORE add_documents so that it is stored
        # alongside the vectors and can be used for filtering later.
        if file_id is not None or source_url is not None:
            for idx, doc in enumerate(all_splits):
                if not doc.metadata:
                    doc.metadata = {}
                if file_id is not None:
                    doc.metadata['file_id'] = file_id
                    doc.metadata['chunk_index'] = idx
                if source_url is not None:
                    doc.metadata['source'] = source_url  # replace loader path with the given URL

            if file_id is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 file_id={file_id}")
            if source_url is not None:
                logger.info(f"✅ 已为 {len(all_splits)} 个chunks设置 source={source_url}")

        # NOTE(review): add_documents is called synchronously here, unlike the
        # loaders above — confirm whether it should also move to a worker thread.
        vector_ids = vector_store.add_documents(documents=all_splits)
        logger.info(f"向量化完成,共 {len(vector_ids)} 个向量")

        # (chunk_index, content, metadata, vector_id) per stored chunk.
        chunks = [
            (idx, doc.page_content, doc.metadata, vector_id)
            for idx, (doc, vector_id) in enumerate(zip(all_splits, vector_ids))
        ]

        return ProcessResult(
            success=True,
            chunks=chunks,
            chunk_count=len(chunks),
            extracted_image_paths=extracted_image_paths if extracted_image_paths else None
        )

    except Exception as e:
        error_msg = f"处理聊天文件失败: {str(e)}"
        logger.error(error_msg)
        return ProcessResult(success=False, chunks=[], chunk_count=0, error_message=error_msg)
|
||
|
||
def search_similar_in_thread(
    self,
    thread_id: str,
    query: str,
    k: int = 5,
    file_id: Optional[int] = None,
    score_threshold: float = 0.0
) -> List[dict]:
    """
    Search a chat thread's collection for chunks similar to ``query``.

    Args:
        thread_id: Chat thread ID (collection ``thread_{thread_id}``).
        query: Query text.
        k: Maximum number of results to return.
        file_id: Optional; restrict the search to chunks of this file.
        score_threshold: Maximum accepted score. The original note says the
            store returns a distance (smaller means more similar). The
            default ``0.0`` disables threshold filtering.

    Returns:
        List[dict]: up to ``k`` dicts with ``content``, ``metadata``,
        ``score`` and ``file_summary``, in the order returned by the store.
        Returns an empty list on any error (the error is logged).
    """
    try:
        collection_name = f"thread_{thread_id}"
        vector_store = self.get_vector_store(collection_name)

        # `is not None` throughout, so a file_id of 0 is treated as a real filter.
        has_file_filter = file_id is not None
        filter_dict = {"file_id": file_id} if has_file_filter else None

        # Over-fetch so that enough candidates survive the threshold filter.
        search_k = k * 3 if has_file_filter else k * 2
        results = vector_store.similarity_search_with_score(
            query,
            k=search_k,
            filter=filter_dict
        )

        threshold_disabled = score_threshold == 0.0
        formatted_results = []
        for doc, score in results:
            if not (threshold_disabled or score <= score_threshold):
                continue
            formatted_results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score),
                "file_summary": doc.metadata.get("file_summary", "")
            })
            # Stop as soon as the requested number of results is collected.
            if len(formatted_results) >= k:
                break

        logger.info(f"向量检索完成: 查询='{query[:50]}...', 结果数={len(formatted_results)}, file_id={file_id}, 阈值={score_threshold}")
        return formatted_results

    except Exception as e:
        logger.error(f"搜索相似文档失败: {e}")
        return []
|
||
|
||
def get_all_file_chunks(
    self,
    thread_id: str,
    file_id: int
) -> List[dict]:
    """
    Fetch every stored chunk of one file in a chat thread.

    Args:
        thread_id: Chat thread ID (collection ``thread_{thread_id}``).
        file_id: File ID used as the metadata filter.

    Returns:
        List[dict]: dicts with ``content``, ``metadata``, ``chunk_index`` and
        ``file_summary``, sorted by ``chunk_index``. Returns an empty list
        when nothing matches or on any error (the error is logged).
    """
    try:
        collection_name = f"thread_{thread_id}"
        logger.info(f"🔍 开始获取文件chunks: collection={collection_name}, file_id={file_id}")
        store = self.get_vector_store(collection_name)

        # The store's get() takes the metadata filter through `where`.
        all_docs = store.get(
            where={"file_id": file_id},
            include=["documents", "metadatas"]
        )

        logger.info(f"📦 ChromaDB返回结果: documents数量={len(all_docs.get('documents', []))}, metadatas数量={len(all_docs.get('metadatas', []))}")

        chunks = []
        found = bool(all_docs) and 'documents' in all_docs and bool(all_docs['documents'])
        if not found:
            logger.warning(f"⚠️ ChromaDB未返回任何文档: all_docs keys={list(all_docs.keys()) if all_docs else 'None'}")
        else:
            logger.info(f"✅ 检测到 {len(all_docs['documents'])} 个文档")
            pairs = zip(all_docs['documents'], all_docs['metadatas'])
            for idx, (doc_content, metadata) in enumerate(pairs):
                # Fall back to the retrieval position when no index was stored.
                chunk_index = metadata.get("chunk_index", idx)
                has_summary = "file_summary" in metadata
                logger.info(f" - Chunk {idx}: chunk_index={chunk_index}, 内容长度={len(doc_content)}, 有摘要={has_summary}, metadata_keys={list(metadata.keys())}")
                entry = {
                    "content": doc_content,
                    "metadata": metadata,
                    "chunk_index": chunk_index,
                    "file_summary": metadata.get("file_summary", "")
                }
                chunks.append(entry)

            # Restore document order.
            chunks = sorted(chunks, key=lambda item: item['chunk_index'])
            logger.info(f"✅ 排序完成,共 {len(chunks)} 个chunks,chunk_index范围: {chunks[0]['chunk_index']} ~ {chunks[-1]['chunk_index']}")

        logger.info(f"📊 最终返回: file_id={file_id}, 总chunks数量={len(chunks)}")
        return chunks

    except Exception as e:
        import traceback
        logger.error(f"❌ 获取文件所有chunks失败: {e}")
        logger.error(traceback.format_exc())
        return []
|
||
|
||
def update_file_summary_in_vectors(
    self,
    thread_id: str,
    file_id: int,
    summary: str
) -> bool:
    """
    Write a file summary into the metadata of all of that file's vectors
    in a chat thread's collection.

    Args:
        thread_id: Chat thread ID (collection ``thread_{thread_id}``).
        file_id: File ID whose vectors are updated.
        summary: Summary text stored under the ``file_summary`` metadata key.

    Returns:
        bool: True when the metadata was updated; False when no vectors were
        found for the file or an error occurred (the error is logged).
    """
    try:
        store = self.get_vector_store(f"thread_{thread_id}")

        # Fetch the current metadata of every vector belonging to the file.
        all_docs = store.get(
            where={"file_id": file_id},
            include=["metadatas"]
        )

        vector_ids = all_docs['ids'] if all_docs and 'ids' in all_docs else None
        if not vector_ids:
            logger.warning(f"未找到 file_id={file_id} 的向量")
            return False

        # New dicts, so the fetched metadata objects are left untouched.
        updated_metadatas = [
            {**metadata, 'file_summary': summary}
            for metadata in all_docs['metadatas']
        ]

        # Metadata-only update through the underlying collection object.
        store._collection.update(
            ids=vector_ids,
            metadatas=updated_metadatas
        )

        logger.info(f"✅ 已更新 ChromaDB 中 {len(vector_ids)} 个向量的 summary (file_id={file_id})")
        return True

    except Exception as e:
        import traceback
        logger.error(f"更新 ChromaDB summary 失败: {e}")
        logger.error(traceback.format_exc())
        return False
|
||
|
||
def update_kb_file_summary_in_vectors(
    self,
    knowledge_base_id: int,
    file_id: int,
    summary: str
) -> bool:
    """
    Write a file summary into the metadata of all of that file's vectors
    in a knowledge base's collection.

    Args:
        knowledge_base_id: Knowledge base ID (collection ``kb_{knowledge_base_id}``).
        file_id: File ID whose vectors are updated.
        summary: Summary text stored under the ``file_summary`` metadata key.

    Returns:
        bool: True when the metadata was updated; False when no vectors were
        found for the file or an error occurred (the error is logged).
    """
    try:
        store = self.get_vector_store(f"kb_{knowledge_base_id}")

        # Fetch the current metadata of every vector belonging to the file.
        all_docs = store.get(
            where={"file_id": file_id},
            include=["metadatas"]
        )

        vector_ids = all_docs['ids'] if all_docs and 'ids' in all_docs else None
        if not vector_ids:
            logger.warning(f"未找到 file_id={file_id} 的向量 (kb_id={knowledge_base_id})")
            return False

        # New dicts, so the fetched metadata objects are left untouched.
        updated_metadatas = [
            {**metadata, 'file_summary': summary}
            for metadata in all_docs['metadatas']
        ]

        # Metadata-only update through the underlying collection object.
        store._collection.update(
            ids=vector_ids,
            metadatas=updated_metadatas
        )

        logger.info(f"✅ 已更新 ChromaDB 中 {len(vector_ids)} 个向量的 summary (kb_id={knowledge_base_id}, file_id={file_id})")
        return True

    except Exception as e:
        import traceback
        logger.error(f"更新 ChromaDB summary 失败 (kb_id={knowledge_base_id}, file_id={file_id}): {e}")
        logger.error(traceback.format_exc())
        return False
|
||
|
||
def delete_thread_vectors(
    self,
    thread_id: str,
    vector_ids: List[str]
) -> bool:
    """
    Delete specific vectors from a chat thread's collection.

    Args:
        thread_id: Chat thread ID (collection ``thread_{thread_id}``).
        vector_ids: IDs of the vectors to delete.

    Returns:
        bool: True on success, and also when ``vector_ids`` is empty (nothing
        to do); False when the deletion raised (the error is logged).
    """
    # Nothing to delete counts as success and never touches the store.
    if not vector_ids:
        return True

    try:
        store = self.get_vector_store(f"thread_{thread_id}")
        store.delete(ids=vector_ids)
        logger.info(f"从会话 {thread_id} 删除 {len(vector_ids)} 个向量")
    except Exception as e:
        logger.error(f"删除向量失败: {e}")
        return False
    return True
|
||
|
||
def delete_thread_collection(
    self,
    thread_id: str
) -> bool:
    """
    Remove the whole on-disk vector collection of a chat thread.

    Args:
        thread_id: Chat thread ID (directory ``thread_{thread_id}`` under
            ``self.vector_store_path``).

    Returns:
        bool: True when the directory existed and was removed; False when it
        did not exist or removal failed (the error is logged).
    """
    import shutil

    try:
        collection_name = f"thread_{thread_id}"
        target_dir = os.path.join(self.vector_store_path, collection_name)

        if not os.path.exists(target_dir):
            return False

        # Deleting the persistence directory drops the collection.
        shutil.rmtree(target_dir)
        logger.info(f"删除向量库: {collection_name}")
        return True

    except Exception as e:
        logger.error(f"删除向量库失败: {e}")
        return False
|
||
|
||
def search_similar(
    self,
    knowledge_base_id: int,
    query: str,
    k: int = 5
) -> List[dict]:
    """
    Run a similarity search inside a knowledge base's collection.

    Args:
        knowledge_base_id: Knowledge base ID (collection ``kb_{knowledge_base_id}``).
        query: Query text.
        k: Number of results requested from the store.

    Returns:
        List[dict]: dicts with ``content``, ``metadata`` and ``score`` in the
        order returned by the store; an empty list on any error (logged).
    """
    try:
        store = self.get_vector_store(f"kb_{knowledge_base_id}")
        hits = store.similarity_search_with_score(query, k=k)
        return [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score)
            }
            for doc, score in hits
        ]
    except Exception as e:
        logger.error(f"搜索相似文档失败: {e}")
        return []
|
||
|
||
def delete_vectors_by_ids(
    self,
    knowledge_base_id: int,
    vector_ids: List[str]
) -> bool:
    """
    Delete specific vectors from a knowledge base's collection.

    Args:
        knowledge_base_id: Knowledge base ID (collection ``kb_{knowledge_base_id}``).
        vector_ids: IDs of the vectors to delete.

    Returns:
        bool: True on success, and also when ``vector_ids`` is empty (nothing
        to do); False when the deletion raised (the error is logged).
    """
    # Nothing to delete counts as success and never touches the store.
    if not vector_ids:
        return True

    try:
        store = self.get_vector_store(f"kb_{knowledge_base_id}")
        store.delete(ids=vector_ids)
        logger.info(f"从知识库 {knowledge_base_id} 删除 {len(vector_ids)} 个向量")
    except Exception as e:
        logger.error(f"删除向量失败: {e}")
        return False
    return True
|
||
|
||
def delete_collection(self, knowledge_base_id: int) -> bool:
    """
    Remove the whole on-disk vector collection of a knowledge base.

    Args:
        knowledge_base_id: Knowledge base ID (directory
            ``kb_{knowledge_base_id}`` under ``self.vector_store_path``).

    Returns:
        bool: True when the directory existed and was removed; False when it
        did not exist or removal failed (the error is logged).
    """
    import shutil

    try:
        collection_name = f"kb_{knowledge_base_id}"
        target_dir = os.path.join(self.vector_store_path, collection_name)

        if not os.path.exists(target_dir):
            return False

        # Deleting the persistence directory drops the collection.
        shutil.rmtree(target_dir)
        logger.info(f"删除向量库: {collection_name}")
        return True

    except Exception as e:
        logger.error(f"删除向量库失败: {e}")
        return False
|
||
|
||
# ----- 知识图谱 RAG(Chroma 集合名 knowledge_graph_{graphs.id};兼容旧名 novel_kg_*) -----
|
||
|
||
def delete_knowledge_graph_collection(self, knowledge_graph_pk: int) -> bool:
    """
    Delete the persisted vector directories of a knowledge graph.

    Both the current name (``knowledge_graph_{pk}``) and the legacy name
    (``novel_kg_{pk}``) are removed if present.

    Returns:
        bool: False if removing either directory raised (the error is
        logged), True otherwise — including when neither directory exists.
    """
    import shutil

    failed = False
    for prefix in ("knowledge_graph", "novel_kg"):
        name = f"{prefix}_{knowledge_graph_pk}"
        try:
            target_dir = os.path.join(self.vector_store_path, name)
            if not os.path.exists(target_dir):
                continue
            shutil.rmtree(target_dir)
            logger.info(f"删除知识图谱向量库: {name}")
        except Exception as e:
            # Keep going so the other directory still gets a chance.
            logger.error(f"删除知识图谱向量库失败 {name}: {e}")
            failed = True
    return not failed
|
||
|
||
def index_knowledge_graph_text(self, knowledge_graph_pk: int, text: str) -> int:
    """
    Split the full source text of a knowledge graph and store it as vectors.

    Any existing collection for the graph (current or legacy name) is
    deleted first, so the index is rebuilt from scratch.

    Args:
        knowledge_graph_pk: Primary key of the knowledge graph.
        text: Full text to index.

    Returns:
        int: number of chunks written; 0 when the text is empty/blank or
        splitting produced nothing.
    """
    from langchain_core.documents import Document

    content = text.strip() if text else ""
    if not content:
        return 0

    collection_name = f"knowledge_graph_{knowledge_graph_pk}"

    # Rebuild: drop whatever was indexed before.
    self.delete_knowledge_graph_collection(knowledge_graph_pk)

    source_doc = Document(
        page_content=content,
        metadata={"source": f"knowledge_graph_{knowledge_graph_pk}", "kind": "knowledge_graph"},
    )
    splits = self.text_splitter.split_documents([source_doc])
    if not splits:
        return 0

    # Record each chunk's position in the original text.
    for position, piece in enumerate(splits):
        piece.metadata = piece.metadata or {}
        piece.metadata["chunk_index"] = position

    store = self.get_vector_store(collection_name)
    store.add_documents(splits)
    chunk_total = len(splits)
    logger.info(f"知识图谱向量化完成 knowledge_graph_pk={knowledge_graph_pk} chunks={chunk_total}")
    return chunk_total
|
||
|
||
def search_similar_knowledge_graph(
    self,
    knowledge_graph_pk: int,
    query: str,
    k: int = 5,
) -> List[dict]:
    """
    Similarity search in a knowledge graph's vector collection.

    The current collection name (``knowledge_graph_{pk}``) is tried first,
    then the legacy ``novel_kg_{pk}``. The first collection whose directory
    exists and whose search succeeds supplies the result.

    Returns:
        List[dict]: dicts with ``content``, ``metadata`` and ``score``; an
        empty list when no collection exists or every attempt failed.
    """
    candidates = (
        f"knowledge_graph_{knowledge_graph_pk}",
        f"novel_kg_{knowledge_graph_pk}",
    )
    for name in candidates:
        # Skip names that were never persisted, rather than opening them.
        if not os.path.exists(os.path.join(self.vector_store_path, name)):
            continue
        try:
            store = self.get_vector_store(name)
            hits = store.similarity_search_with_score(query, k=k)
            return [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": float(score),
                }
                for doc, score in hits
            ]
        except Exception as e:
            # Log and fall through to the next candidate name.
            logger.error(f"知识图谱向量检索失败 ({name}): {e}")
    return []
|
||
|
||
# ==================== 增强的 RAG 功能 ====================
|
||
|
||
def add_summary_chunk(
    self,
    collection_name: str,
    file_id: int,
    file_name: str,
    summary_text: str,
    metadata: Optional[dict] = None
) -> Optional[str]:
    """
    Store a file's summary as its own chunk in the given collection.

    Args:
        collection_name: Target collection name.
        file_id: File ID recorded in the chunk metadata.
        file_name: File name recorded in the chunk metadata.
        summary_text: Summary text to embed and store.
        metadata: Optional extra metadata; its keys override the defaults
            (``file_id``, ``file_name``, ``chunk_type``, ``chunk_index``).

    Returns:
        Optional[str]: ID of the stored summary chunk, or None when the store
        returned no ID or an error occurred (the error is logged).
    """
    try:
        store = self.get_vector_store(collection_name)

        summary_metadata = {
            "file_id": file_id,
            "file_name": file_name,
            "chunk_type": "summary",  # marks this chunk as a summary
            "chunk_index": -1,  # summaries have no position in the document
            **(metadata or {}),
        }

        new_ids = store.add_texts(
            texts=[summary_text],
            metadatas=[summary_metadata]
        )

        logger.info(f"添加摘要 chunk 成功: {file_name}")
        if not new_ids:
            return None
        return new_ids[0]

    except Exception as e:
        logger.error(f"添加摘要 chunk 失败: {e}")
        return None
|
||
|
||
def search_by_chunk_type(
    self,
    collection_name: str,
    query: str,
    chunk_type: str = "text",
    top_k: int = 5,
    filter_metadata: Optional[dict] = None
) -> List[Tuple[str, dict, float]]:
    """
    Similarity search restricted to chunks of one ``chunk_type``.

    Args:
        collection_name: Collection name.
        query: Query text.
        chunk_type: Value of the ``chunk_type`` metadata key to match
            (e.g. ``"summary"`` or ``"text"``).
        top_k: Number of results requested from the store.
        filter_metadata: Optional extra filter conditions. A ``chunk_type``
            key here overrides the ``chunk_type`` argument.

    Returns:
        List[Tuple[str, dict, float]]: ``(content, metadata, score)`` tuples;
        an empty list on any error (the error is logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)

        conditions = {"chunk_type": chunk_type}
        if filter_metadata:
            conditions.update(filter_metadata)

        # Chroma accepts exactly one top-level key in a `where` filter, so
        # several conditions must be combined explicitly with `$and`.
        if len(conditions) == 1:
            where_filter = conditions
        else:
            where_filter = {"$and": [{key: value} for key, value in conditions.items()]}

        results = vector_store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=where_filter
        )

        return [
            (doc.page_content, doc.metadata, score)
            for doc, score in results
        ]

    except Exception as e:
        logger.error(f"基于类型检索失败: {e}")
        return []
|
||
|
||
def get_file_all_chunks(
    self,
    collection_name: str,
    file_id: int,
    chunk_type: str = "text"
) -> List[Tuple[str, dict]]:
    """
    Fetch all chunks of one file and one chunk type (for full-text use).

    Args:
        collection_name: Collection name.
        file_id: File ID to match.
        chunk_type: Value of the ``chunk_type`` metadata key to match.

    Returns:
        List[Tuple[str, dict]]: ``(content, metadata)`` tuples sorted by the
        ``chunk_index`` metadata value (missing index sorts as 0); an empty
        list when nothing matches or on any error (the error is logged).
    """
    try:
        vector_store = self.get_vector_store(collection_name)

        # Chroma accepts exactly one top-level key in a `where` filter; a
        # plain two-key dict is rejected, so the conditions are joined
        # with `$and`.
        results = vector_store.get(
            where={
                "$and": [
                    {"file_id": file_id},
                    {"chunk_type": chunk_type},
                ]
            }
        )

        if not results or not results.get("documents"):
            return []

        documents = results["documents"]
        metadatas = results.get("metadatas") or [None] * len(documents)

        # Normalise missing metadata to {} so the sort key never fails.
        chunks = [(content, meta or {}) for content, meta in zip(documents, metadatas)]
        chunks.sort(key=lambda item: item[1].get("chunk_index", 0))
        return chunks

    except Exception as e:
        logger.error(f"获取文件所有 chunks 失败: {e}")
        return []
|
||
|
||
def hybrid_search(
    self,
    collection_name: str,
    query: str,
    top_k: int = 5,
    filter_metadata: Optional[dict] = None,
    dense_weight: float = 0.7,
    sparse_weight: float = 0.3
) -> List[Tuple[str, dict, float]]:
    """
    Hybrid (dense + sparse) search — currently dense-only.

    This is a placeholder: only the dense vector search is performed and a
    warning is logged on every call. ``dense_weight`` and ``sparse_weight``
    are accepted but not used by the current implementation.

    Args:
        collection_name: Collection name.
        query: Query text.
        top_k: Number of results requested from the store.
        filter_metadata: Optional metadata filter passed to the store.
        dense_weight: Weight of the dense score (unused for now).
        sparse_weight: Weight of the sparse score (unused for now).

    Returns:
        List[Tuple[str, dict, float]]: ``(content, metadata, score)`` tuples;
        an empty list on any error (the error is logged).
    """
    try:
        # TODO: implement real hybrid retrieval (e.g. add BM25 for the sparse side).
        logger.warning("混合检索功能当前仅支持 dense 向量检索,需要额外集成 BM25")

        store = self.get_vector_store(collection_name)
        hits = store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=filter_metadata
        )

        output = []
        for doc, score in hits:
            output.append((doc.page_content, doc.metadata, score))
        return output

    except Exception as e:
        logger.error(f"混合检索失败: {e}")
        return []
|
||
|
||
def search_with_rerank(
    self,
    collection_name: str,
    query: str,
    top_k: int = 5,
    rerank_top_k: int = 3,
    filter_metadata: Optional[dict] = None
) -> List[Tuple[str, dict, float]]:
    """
    Retrieve candidates and return the best ``rerank_top_k`` of them.

    No rerank model is applied yet: the candidates keep the order returned
    by the vector store and are simply truncated.

    Args:
        collection_name: Collection name.
        query: Query text.
        top_k: Number of candidates fetched in the first pass.
        rerank_top_k: Number of results returned; values <= 0 yield an
            empty list.
        filter_metadata: Optional metadata filter passed to the store.

    Returns:
        List[Tuple[str, dict, float]]: ``(content, metadata, score)`` tuples;
        an empty list on any error (the error is logged).
    """
    try:
        # Query the collection directly. The previous implementation called
        # self.search_similar() with keyword arguments it does not accept
        # (it takes knowledge_base_id/query/k), so every call raised
        # TypeError and this method always returned [].
        vector_store = self.get_vector_store(collection_name)
        candidates = vector_store.similarity_search_with_score(
            query=query,
            k=top_k,
            filter=filter_metadata
        )

        # TODO: re-order `candidates` with a rerank model before truncating.
        limit = max(rerank_top_k, 0)
        return [
            (doc.page_content, doc.metadata, score)
            for doc, score in candidates[:limit]
        ]

    except Exception as e:
        logger.error(f"重排序检索失败: {e}")
        return []
|
||
|
||
def get_summary_by_file_ids(
    self,
    collection_name: str,
    file_ids: List[int]
) -> Dict[int, str]:
    """
    Look up the stored summary chunk for each of several files.

    Args:
        collection_name: Collection name.
        file_ids: File IDs to look up.

    Returns:
        Dict[int, str]: mapping of file ID to summary text. Files without a
        summary chunk are omitted. Returns an empty dict on any error (the
        error is logged), even if some files were already resolved.
    """
    try:
        vector_store = self.get_vector_store(collection_name)

        summaries = {}
        for file_id in file_ids:
            # Chroma accepts exactly one top-level key in a `where` filter;
            # a plain two-key dict is rejected, so the conditions are joined
            # with `$and`.
            results = vector_store.get(
                where={
                    "$and": [
                        {"file_id": file_id},
                        {"chunk_type": "summary"},
                    ]
                }
            )

            documents = results.get("documents") if results else None
            if documents:
                # If several summary chunks exist, the first one returned wins.
                summaries[file_id] = documents[0]

        return summaries

    except Exception as e:
        logger.error(f"批量获取摘要失败: {e}")
        return {}
|
||
|
||
|
||
# Module-wide VectorService instance, created lazily on first request.
_vector_service = None


def get_vector_service() -> VectorService:
    """Return the shared VectorService, creating it on the first call."""
    global _vector_service
    if _vector_service is not None:
        return _vector_service
    _vector_service = VectorService()
    return _vector_service
|
||
|